PyTorch Notes (14): CenterNet Object Detection from Scratch


Preface

"What is truly good is plain, and it is elegant" — the saying fits CenterNet perfectly. Drawing on the paper and the official source code, I extracted the essentials of the object-detection part, rebuilt it in PyTorch with simplicity and readability in mind, and added comments throughout.
This article uses the code as its entry point; for the theory behind CenterNet, the paper walkthrough listed in the references ("扔掉anchor!真正的CenterNet——Objects as Points论文解读") is recommended.

The code for this article is available on GitHub; the official source code is linked in the references at the end.

Model Construction

Model structure: resnet18 backbone + upsampling + 3 output headers (figure from the original paper).
resnet_backbone
The following 6 lines were added to `__init__()` in the stock torchvision resnet.py. self.layer5 through self.layer7 perform the upsampling; self.hm, self.wh and self.reg are the model's three output headers: the per-class keypoint heatmap, the width/height regression, and the sub-pixel center offset, respectively.

self.layer4 = self._make_layer(block, 512, layers[3], stride=2)  # /32
#the line above is stock torchvision source
self.layer5 = self._make_deconv_layer(512, 256)  # /16
self.layer6 = self._make_deconv_layer(256, 128)  # /8
self.layer7 = self._make_deconv_layer(128, 64)  # /4

self.hm = self._make_header_layer(64, num_classes)  # heatmap
self.wh = self._make_header_layer(64, 2)  # width and height
self.reg = self._make_header_layer(64, 2)  # regress offset
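As a sanity check, the added layers can be exercised outside the full ResNet. The sketch below (assuming a 512x512 input, so layer4 emits a 16x16 map, and SeaShips' 6 classes) chains the three deconv blocks and one header as free functions:

```python
import torch
import torch.nn as nn

# Standalone versions of the blocks above (free functions, not ResNet methods).
def make_deconv_layer(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, 1, 1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1),  # doubles H and W
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

def make_header_layer(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv2d(in_ch, in_ch, 3, 1, 1),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_ch, out_ch, 1, 1),
    )

up = nn.Sequential(make_deconv_layer(512, 256),   # /32 -> /16
                   make_deconv_layer(256, 128),   # /16 -> /8
                   make_deconv_layer(128, 64))    # /8  -> /4
hm_head = make_header_layer(64, 6)                # 6 classes assumed (SeaShips)

x = torch.randn(1, 512, 16, 16)                   # simulated layer4 output (512/32 = 16)
feat = up(x)
hm = hm_head(feat)
print(feat.shape, hm.shape)                       # (1, 64, 128, 128), (1, 6, 128, 128)
```

Each `ConvTranspose2d(..., 4, 2, 1)` doubles the spatial size, so three of them take the /32 feature map back to /4, which is why the heatmap is 128x128 for a 512x512 input.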

The upsampling function is shown below. The official code uses Deformable Convolutional Networks for the convolutions here; for ease of running, plain convolutions are used instead.

# add for upsample
def _make_deconv_layer(self, in_ch, out_ch):
	deconv = nn.Sequential(
		nn.Conv2d(in_ch, out_ch, 3, 1, 1),
		nn.BatchNorm2d(out_ch),
		nn.ReLU(inplace=True),
		nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1),
		nn.BatchNorm2d(out_ch),
		nn.ReLU(inplace=True)
	)
	return deconv

The function used to build the three headers:

# add for three headers
def _make_header_layer(self, in_ch, out_ch):
	header = nn.Sequential(
		nn.Conv2d(in_ch, in_ch, 3, 1, 1),
		nn.ReLU(inplace=True),
		nn.Conv2d(in_ch, out_ch, 1, 1)
	)
	return header

Dataset Construction

SeaShips dataset: 7000 images, all at 1920x1080 resolution, covering six classes of ships (dataset link). First, a test image from the model after 72 training epochs.

test_image

Data initialization

Dataset follows the standard torch format. In `__getitem__`: `list_bbox_cls` is `[(bbox1, cls1), (bbox2, cls2), ...]`; `real_w`, `real_h` are the original image's width and height; `down_ratio` is the downsampling factor (default 4); `heatmap_size` is the side length of the model's output heatmap (default 128); `hm`, `wh` and `reg` were introduced above; `max_objs` is the maximum number of objects expected in one image (ships per image are few, default 32); `ind` is the 1-D index of each keypoint within the flattened 2-D heatmap; `reg_mask` is a 0/1 mask marking which of the `max_objs` slots hold a real object.

class CTDataset(Dataset):
    def __init__(self, opt, data, transform=None):
        '''
        Dataset construction
        :param opt: config options
        :param data: [(img_path, [(bbox1,cls1),(bbox2,cls2),...]), ...], bbox = (left, top, right, bottom)
        :param transform:
        '''

        self.images = data
        self.opt = opt
        self.transform = transform

    def __getitem__(self, index):
        img_path, list_bbox_cls = self.images[index]
        img = Image.open(img_path)
        real_w, real_h = img.size
        if self.transform: img = self.transform(img)
        heatmap_size = self.opt.input_size // self.opt.down_ratio
        # heatmap
        hm = np.zeros((self.opt.num_classes, heatmap_size, heatmap_size), dtype=np.float32)
        # width and height
        wh = np.zeros((self.opt.max_objs, 2), dtype=np.float32)
        # regression
        reg = np.zeros((self.opt.max_objs, 2), dtype=np.float32)
        # index in 1D heatmap
        ind = np.zeros((self.opt.max_objs), dtype=np.int64)  # np.int is removed in recent NumPy
        # 1=there is a target in the list 0=there is not
        reg_mask = np.zeros((self.opt.max_objs), dtype=np.uint8)

        # get the absolute ratio
        w_ratio = self.opt.input_size / real_w / self.opt.down_ratio
        h_ratio = self.opt.input_size / real_h / self.opt.down_ratio

        for i, (bbox, cls) in enumerate(list_bbox_cls):
            # original bbox size -> heatmap bbox size
            bbox = bbox[0] * w_ratio, bbox[1] * h_ratio, bbox[2] * w_ratio, bbox[3] * h_ratio
            width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            # center point(x,y)
            center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
            center_int = center.astype(np.int64)
            reg[i] = center - center_int
            wh[i] = 1. * width, 1. * height
            reg_mask[i] = 1
            ind[i] = center_int[1] * heatmap_size + center_int[0]
            radius = utils.gaussian_radius((height, width))
            # the radius must be an integer
            radius = max(0, int(radius))
            utils.draw_umich_gaussian(hm[cls], center_int, radius)
        return (img, hm, wh, reg, ind, reg_mask)

The original image is first resized to a fixed size (512x512), and the scale factors (w_ratio, h_ratio) from the original image to the heatmap (128x128) are computed. The for loop then scales each bbox's coordinates by those factors and obtains the bbox's width and height on the heatmap. As the code shows, the reg offset is exactly the fractional part lost when the scaled center is truncated to an integer.
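A worked numeric example of that bookkeeping, with assumed values (a 1920x1080 SeaShips image, input_size=512, down_ratio=4, and a made-up bbox):

```python
import numpy as np

# Worked example of the target bookkeeping above; all concrete numbers are assumed.
input_size, down_ratio = 512, 4
real_w, real_h = 1920, 1080
heatmap_size = input_size // down_ratio              # 128
w_ratio = input_size / real_w / down_ratio           # original x -> heatmap x
h_ratio = input_size / real_h / down_ratio           # original y -> heatmap y

bbox = (600, 300, 900, 540)                          # (left, top, right, bottom), original pixels
bbox = (bbox[0] * w_ratio, bbox[1] * h_ratio, bbox[2] * w_ratio, bbox[3] * h_ratio)
center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
center_int = center.astype(np.int64)
reg = center - center_int                            # the fractional part lost by truncation
ind = center_int[1] * heatmap_size + center_int[0]   # 1-D index into the flattened 128x128 map
print(center_int, reg, ind)
```

The bbox lands at heatmap center (50.0, 49.78), so center_int is (50, 49), reg keeps the truncated 0.78 in y, and ind = 49 * 128 + 50 = 6322.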

Generating the Gaussian heatmap

First, the radius is computed from the bbox's height and width; see the comments below. The underlying idea is simply solving quadratic equations.

def gaussian_radius(det_size, min_overlap=0.7):
    '''
    Compute the Gaussian radius.
    Method from CornerNet: https://arxiv.org/pdf/1808.01244.pdf
    It solves a quadratic for each of three cases (one corner shifted inward and one outward,
    both inward, both outward): https://github.com/princeton-vl/CornerNet/issues/110
    :param det_size: bbox size on the feature map (h, w)
    :param min_overlap: minimum IoU
    :return: the smallest radius that still guarantees IoU >= min_overlap
    '''
    height, width = det_size

    a1 = 1
    b1 = (height + width)
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
    r1 = (b1 + sq1) / 2

    a2 = 4
    b2 = 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
    r2 = (b2 + sq2) / 2

    a3 = 4 * min_overlap
    b3 = -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
    r3 = (b3 + sq3) / 2
    return min(r1, r2, r3)
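A quick numeric check of the function (copied verbatim) for an assumed 20x30 box on the feature map; the third case gives the minimum, and the integer radius actually drawn ends up as 6:

```python
import numpy as np

# gaussian_radius copied verbatim from above; the box size is an assumed example.
def gaussian_radius(det_size, min_overlap=0.7):
    height, width = det_size

    a1 = 1
    b1 = (height + width)
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
    r1 = (b1 + sq1) / 2

    a2 = 4
    b2 = 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
    r2 = (b2 + sq2) / 2

    a3 = 4 * min_overlap
    b3 = -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
    r3 = (b3 + sq3) / 2
    return min(r1, r2, r3)

r = gaussian_radius((20, 30))      # (h, w) on the heatmap
radius = max(0, int(r))            # the integer radius used for drawing
print(r, radius)
```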

The function that generates the Gaussian from the radius is below; note that sigma must not be 0.

def gaussian2D(shape, sigma=1):
    '''
    Produce a Gaussian kernel of the given shape
    :param shape: (diameter, diameter)
    :param sigma:
    :return: array of shape (radius*2+1, radius*2+1)
    '''
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h

Here is how the Gaussian is laid out from the radius in detail. Suppose radius=3; the diameter is then 7 (keypoint included). np.ogrid builds the relative x and y coordinates, so their combination forms a grid with the center point (the keypoint) at the origin (0, 0); for example, the top-left reachable point is (-3, -3). Passing every grid point through the Gaussian kernel produces the distribution below (when several Gaussians overlap at the same position, the larger value is kept):
gaussian_map
Because radius takes part in slicing, it must be an integer; and since radius=0 would make the divisor sigma equal to 0 in the Gaussian, the radius is converted to a diameter (radius=0 gives sigma=1/6).
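To make this concrete, here is gaussian2D (copied from above) evaluated for radius 3, i.e. a 7x7 kernel with sigma = diameter/6 as used in draw_umich_gaussian below:

```python
import numpy as np

# gaussian2D copied from above; evaluated for an assumed radius of 3.
def gaussian2D(shape, sigma=1):
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h

radius = 3
diameter = 2 * radius + 1                  # 7, keypoint included
g = gaussian2D((diameter, diameter), sigma=diameter / 6)
print(g.shape)                             # (7, 7)
print(g[radius, radius])                   # exp(0) = 1.0 at the keypoint
print(g[0, 0])                             # corner (-3, -3): much smaller
```

The peak is exactly 1.0 at the keypoint and the values fall off toward the corners, which is what the heatmap target relies on.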

def draw_umich_gaussian(heatmap, center, radius, k=1):
    '''
    Draw a Gaussian of the given radius on the heatmap around the center point
    :param heatmap: feature map (128x128)
    :param center: (x, y)
    :param radius: radius
    :param k:
    :return:
    '''
    # extend outward from the center, total length 2 * radius + 1; the paper uses sigma = radius/3,
    # and working with the diameter also avoids sigma = 0 when radius = 0
    diameter = 2 * radius + 1
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)

    x, y = center[0], center[1]

    height, width = heatmap.shape[0:2]

    # clip at the heatmap borders
    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)

    # align the Gaussian patch with the heatmap patch
    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:  # TODO debug
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
    return heatmap

Model Training

The training loop itself is straightforward; the two loss functions deserve the attention.

Keypoint focal loss

$$
L_{k}=\frac{-1}{N}\sum_{xyc}
\begin{cases}
(1-\hat{Y}_{xyc})^{\alpha}\log(\hat{Y}_{xyc}), & \text{if } Y_{xyc}=1 \\
(1-Y_{xyc})^{\beta}(\hat{Y}_{xyc})^{\alpha}\log(1-\hat{Y}_{xyc}), & \text{otherwise}
\end{cases}
$$
The code follows the formula closely; note that output must first be passed through a sigmoid (and clamped).

class FocalLoss(nn.Module):

    def __init__(self, alpha=2, beta=4):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, output, target, sigmoid=True):
        if sigmoid:  # clamp is important
            output = torch.clamp(output.sigmoid(), min=1e-4, max=1 - 1e-4)

        pos_index = target.eq(1).float()
        neg_index = target.lt(1).float()

        pos_loss = torch.pow(1 - output, self.alpha) * torch.log(output) * pos_index
        neg_loss = torch.pow(1 - target, self.beta) * torch.pow(output, self.alpha) * torch.log(1 - output) * neg_index

        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()
        pos_num = pos_index.sum()
        loss = 0
        loss = loss - (pos_loss + neg_loss) / pos_num if pos_num > 0 else loss - neg_loss

        return loss
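A minimal functional check of the loss above (a function-form transcription of the class, not the official code): a confident correct heatmap must cost far less than a confident wrong one.

```python
import torch

# Function-form transcription of the FocalLoss above, for a quick sanity check.
def focal_loss(output, target, alpha=2, beta=4):
    output = torch.clamp(torch.sigmoid(output), 1e-4, 1 - 1e-4)
    pos = target.eq(1).float()
    neg = target.lt(1).float()
    pos_loss = ((1 - output) ** alpha * torch.log(output) * pos).sum()
    neg_loss = ((1 - target) ** beta * output ** alpha * torch.log(1 - output) * neg).sum()
    num_pos = pos.sum()
    return -(pos_loss + neg_loss) / num_pos if num_pos > 0 else -neg_loss

target = torch.zeros(1, 1, 4, 4)
target[0, 0, 2, 2] = 1.0                  # one keypoint on a toy 4x4 heatmap
good = torch.full((1, 1, 4, 4), -6.0)
good[0, 0, 2, 2] = 6.0                    # high logit at the keypoint only
bad = torch.full((1, 1, 4, 4), 6.0)
bad[0, 0, 2, 2] = -6.0                    # high logit everywhere except the keypoint
print(focal_loss(good, target).item(), focal_loss(bad, target).item())
```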

When first reading CenterNet, I wondered where the classification loss had gone. In fact the keypoints already carry the class information: in hm of shape (batch_size, nb_classes, h, w), each class (channel) has its own heatmap at every position, and a keypoint belongs to whichever class scores highest there.

Regression loss

$$
L_{size}=\frac{1}{N}\sum_{k=1}^{N}\left|\hat{S}_{p_k}-s_{k}\right|
$$
The same L1 loss is used for both the offset reg and the size wh; N is the number of objects. One API used throughout the code is torch.gather (fetching data by index).
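Since torch.gather does the heavy lifting in the loss below, here is a tiny demo of the exact pattern: a flattened (batch, h*w, 2) map indexed by per-object 1-D keypoint indices (all numbers made up):

```python
import torch

# Demo of the gather pattern used by RegL1Loss: pick per-object (dx, dy) rows
# out of a flattened feature map by 1-D keypoint index.
output = torch.arange(2 * 16 * 2, dtype=torch.float32).view(2, 16, 2)  # (batch=2, 4*4, 2)
index = torch.tensor([[5, 9], [0, 15]])          # (batch, max_objs=2): 1-D keypoint indices
index = index.unsqueeze(2).expand(2, 2, 2)       # gather needs index to match the last dim
picked = torch.gather(output, 1, index)          # (batch, max_objs, 2)
print(picked[0, 0])                              # row 5 of sample 0
```

Each `picked[b, j]` is exactly `output[b, index[b, j]]`, i.e. the two channel values at that object's keypoint.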

class RegL1Loss(nn.Module):

    def __init__(self):
        super(RegL1Loss, self).__init__()
        # only elements that contain an object are counted
        self.l1_loss = nn.L1Loss(reduction='sum')
        self.eps = 1e-4

    def forward(self, output: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, index: torch.Tensor):
        '''
        l1_loss from wh_loss or off_loss
        :param output: the output of model (batch,2,128,128)
        :param target: (batch,max_objs,2)   max_objs=128
        :param mask: (batch,max_objs)
        :param index: (batch,max_objs)
        :return:
        '''
        # index is a 1-D index, so flatten the output feature map to match
        batch = output.size(0)
        output = output.view(batch, 2, -1).transpose(1, 2).contiguous()  # (batch,128*128,2)
        # torch.gather requires index to match output's shape in every dim except the gather dim
        index = index.unsqueeze(2).expand(batch, -1, 2)  # (batch,max_objs,2)
        # get the target number
        pos_num = mask.sum()
        # pick the feature values at the target positions
        output = torch.gather(output, 1, index)  # (batch,max_objs,2)
        # zero out the slots with no object: first expand the mask's dimensions
        mask = mask.unsqueeze(2).expand_as(output).float()  # (batch,max_objs,2)

        loss = self.l1_loss(output * mask, target * mask)
        loss = loss / (pos_num + self.eps)
        return loss

For comparison, the official version is below. Because mask is broadcast to the extra dimension, mask.sum() is in theory 2N, which effectively makes loss = loss * 0.5 (probably a slip by the authors; the bug has been reported).

def forward(self, output, mask, ind, target):
    pred = _transpose_and_gather_feat(output, ind)
    mask = mask.unsqueeze(2).expand_as(pred).float()
    # loss = F.l1_loss(pred * mask, target * mask, reduction='elementwise_mean')
    loss = F.l1_loss(pred * mask, target * mask, size_average=False)
    loss = loss / (mask.sum() + 1e-4)
    return loss

Model Testing

The model outputs three maps: out_hm, out_wh, out_reg. The first step is to locate the peak points in out_hm. What counts as a peak? A point whose value is greater than those of its 8 surrounding neighbors. If that sounds familiar, it should: it is exactly a max_pool operation!
max_pool
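The "a peak is a value unchanged by a 3x3 max-pool" trick can be seen on a toy heatmap (numbers made up):

```python
import torch
import torch.nn.functional as F

# Peak extraction via max_pool: only values equal to their 3x3-neighborhood
# maximum survive; weaker neighbors are zeroed out.
hm = torch.zeros(1, 1, 5, 5)
hm[0, 0, 2, 2] = 0.9     # a true peak
hm[0, 0, 2, 3] = 0.5     # its weaker neighbor, should be suppressed
pooled = F.max_pool2d(hm, 3, stride=1, padding=1)
keep = (pooled == hm).float()
peaks = keep * hm
print(peaks[0, 0])
```

Only the 0.9 survives; the 0.5 next to it is replaced by 0 because its pooled value (0.9) differs from its own.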

def hm_topk(hm, k):
    # use max_pool to keep only the peak points
    batch, cls, h, w = hm.size()
    out = F.max_pool2d(hm, 3, 1, 1)
    keep_max = (out == hm).float()
    hm = keep_max * hm
    # per-class top-k on the heatmap (after sigmoid); topk_indexs values lie in 0..h*w
    topk_scores, topk_indexs = hm.view(batch, cls, -1).topk(k)  # (batch,cls,k)
    # top-k over all classes together
    topk_scores, topk_ind = topk_scores.view(batch, -1).topk(k)  # (batch,k)
    # recover the class: topk_scores holds the k best per class, so floor-dividing topk_ind by k gives the class
    topk_cls = topk_ind // k
    # topk_indexs viewed as (batch, cls*k); gathering with topk_ind yields the final 1-D indices
    topk_indexs = topk_indexs.view(batch, -1).gather(1, topk_ind)
    # convert the winning 1-D indices back to (x, y) coordinates
    topk_ys, topk_xs = topk_indexs // w, topk_indexs % w
    return topk_scores, topk_indexs, topk_cls, topk_xs, topk_ys
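A toy walk-through of the two-stage top-k index math above (k=2, 2 classes, a 4x4 map; the peak positions are made up):

```python
import torch

# Two-stage top-k: per-class top-k, then top-k over classes, then recover
# class id and (x, y) from the surviving flat indices.
batch, cls, h, w, k = 1, 2, 4, 4, 2
hm = torch.zeros(batch, cls, h, w)
hm[0, 0, 1, 2] = 0.9          # class-0 peak at (x=2, y=1)
hm[0, 1, 3, 0] = 0.8          # class-1 peak at (x=0, y=3)
topk_scores, topk_indexs = hm.view(batch, cls, -1).topk(k)   # per class: (1, 2, 2)
topk_scores, topk_ind = topk_scores.view(batch, -1).topk(k)  # over classes: (1, 2)
top_cls = topk_ind // k                                      # flat slot // k -> class id
topk_indexs = topk_indexs.view(batch, -1).gather(1, topk_ind)
ys, xs = topk_indexs // w, topk_indexs % w
print(top_cls.tolist(), xs.tolist(), ys.tolist())
```

The two peaks come back as class 0 at (2, 1) and class 1 at (0, 3), confirming the slot-to-class and 1-D-to-2-D arithmetic.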

The heatmap_bbox function decodes the model outputs into the top-k bbox, cls and scores:

def heatmap_bbox(hm, wh, reg, k=100):
    scores, indexs, cls, xs, ys = hm_topk(hm.sigmoid_(), k)
    batch = reg.size(0)
    # transpose first so the 2 offsets of each keypoint can be gathered together
    reg = reg.view(batch, 2, -1).transpose(2, 1).contiguous()  # (batch,w*h,2)
    reg_indexs = indexs.unsqueeze(2).expand(batch, -1, 2)  # (batch,k,2)
    reg = reg.gather(1, reg_indexs)  # (batch,k,2)
    xs = xs.float() + reg[:, :, 0]
    ys = ys.float() + reg[:, :, 1]
    # wh via reg_indexs
    wh = wh.view(batch, 2, -1).transpose(2, 1).contiguous().gather(1, reg_indexs)  # (batch,k,2)
    # bbox via xs and wh
    bbox = xs - wh[:, :, 0] / 2, ys - wh[:, :, 1] / 2, xs + wh[:, :, 0] / 2, ys + wh[:, :, 1] / 2
    bbox = torch.stack(bbox, -1)  # (batch,k,4)
    return bbox, cls, scores

Without further post-processing, overlapping boxes may appear. The max_pool step does act as an NMS, but only within a single class; if the model is not trained well enough, keypoints of different classes can cluster at the same location. Two ways to handle this:
1> Set a score threshold and drop every box below it; this fails when the clustered keypoints all score above the threshold.
2> Run soft_nms over all boxes; the official code provides this option (recommended).
box_overlap

References

Training the CenterNet network on the SeaShips dataset
Official CenterNet source code
扔掉anchor!真正的CenterNet——Objects as Points论文解读 (an Objects as Points paper walkthrough)
