YOLOv5-seg: how it works

YOLOv5-seg is a variant of the YOLOv5 model for image (instance) segmentation.
Its backbone is identical to YOLOv5's; the detection head inherits from the Detect head and adds segmentation-mask prediction:

Head

# yolo.py
class Segment(Detect):
    # YOLOv5 Segment head for segmentation models
    def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True):
        super().__init__(nc, anchors, ch, inplace)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.no = 5 + nc + self.nm  # number of outputs per anchor
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.proto = Proto(ch[0], self.npr, self.nm)  # protos
        self.detect = Detect.forward

    def forward(self, x):
        p = self.proto(x[0])  # prototype masks from the highest-resolution feature map
        x = self.detect(self, x)  # run the standard Detect forward pass
        return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1])
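
Note the three return modes of forward: during training it returns the raw per-scale predictions together with the prototypes (x, p); in export mode, the concatenated inference tensor and the prototypes; in normal inference, the concatenated predictions, the prototypes, and the raw per-scale outputs. This is why the prediction code below unpacks model(...)[:2] into pred and proto.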

Model output

visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred, proto = model(im, augment=augment, visualize=visualize)[:2]
  • Prototype masks (protos): the proto branch produces prototype masks for the segmentation model. Prototype masks are a set of basic shapes or patterns used to compose the segmentation mask of each instance; they can be viewed as templates that capture typical object shapes and contours. During training, the model learns to generate a set of prototypes that captures the shape variation and features of the objects.
  • In yolov5-seg, the final instance masks are obtained by multiplying the prototype masks with the predicted mask coefficients (a linear combination per instance; see the sketch below). This lets the model keep the shape and structural features encoded in the prototypes while still predicting a precise mask for every instance.
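
A minimal sketch of that combination (shapes match the example below; variable names are illustrative, not YOLOv5 API):

import torch

nm, mh, mw = 32, 104, 160         # number of prototypes and prototype resolution
n = 5                             # hypothetical number of detected instances
protos = torch.randn(nm, mh, mw)  # prototype masks from the proto branch
coefs = torch.randn(n, nm)        # per-instance mask coefficients from the detect head

# each instance mask is a weighted sum of the 32 prototypes, squashed to [0, 1]
masks = (coefs @ protos.view(nm, -1)).sigmoid().view(n, mh, mw)
print(masks.shape)                # torch.Size([5, 104, 160])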

ProtoNet is similar to the fully convolutional network (FCN) used for semantic segmentation: it produces per-pixel outputs over the image, and the final masks are upsampled (bilinear interpolation) back to the original image resolution.
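
For reference, here is the Proto module from YOLOv5's models/common.py (Conv is YOLOv5's Conv2d + BatchNorm2d + SiLU block); it takes the highest-resolution (stride-8) feature map and applies one 2x upsample, which is why the prototypes come out at 1/4 of the input resolution:

# common.py
class Proto(nn.Module):
    # YOLOv5 mask Proto module for segmentation models
    def __init__(self, c1, c_=256, c2=32):  # ch_in, number of protos, number of masks
        super().__init__()
        self.cv1 = Conv(c1, c_, k=3)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.cv2 = Conv(c_, c_, k=3)
        self.cv3 = Conv(c_, c2)

    def forward(self, x):
        return self.cv3(self.cv2(self.upsample(self.cv1(x))))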

  • Predictions (preds): the model's final output for the input image, containing the bounding boxes, objectness confidence, class probabilities, and the mask coefficients from which the instance masks are composed.

Each anchor's output feature count is box coordinates (4) + objectness (1) + number of classes + number of masks. For example, with 1 class and 32 masks, each anchor outputs 4 + 1 + 1 + 32 = 38 features; with three anchors per location, that is 114 features in total.

With 1 class and 32 masks:

  • pred:[1,16380,38]
  • proto:[1,32,104,160]
  • model()[2]: list of tensors: (1, 3, 52, 80, 38), (1, 3, 26, 40, 38), (1, 3, 13, 20, 38)

Note that 16380 = 3 × (52×80 + 26×40 + 13×20) = 3 × (4160 + 1040 + 260), i.e. the total number of anchor positions summed over the three detection scales.
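
A quick sanity check of that count (these grids correspond to strides 8/16/32 on a 416×640 letterboxed input):

na = 3                                    # anchors per scale
grids = [(52, 80), (26, 40), (13, 20)]    # stride-8, -16 and -32 feature maps
print(na * sum(h * w for h, w in grids))  # 16380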

The mask pipeline works roughly as follows:

  1. Prototype generation: during training, the model learns to generate a set of prototype masks that capture the shape variation and features of the objects.
  2. Mask prediction: at inference time, the model predicts a mask for each object instance detected in the input image. These masks are built from the prototypes, adjusted by the learned per-instance coefficients.
  3. Mask refinement: the predicted masks may be refined in post-processing to fit the actual object contours more tightly, e.g. by cropping to the box and thresholding.
  4. Instance segmentation output: finally, every detected instance gets a corresponding mask that precisely represents the object's position and shape in the image.

Post-processing

The raw pred is post-processed with non-maximum suppression. With 1 class and 32 masks, the post-processed output here is a tensor of shape (1, 38): one detection, whose 38 values are (x1, y1, x2, y2, conf, cls) plus the 32 mask coefficients.

pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det, nm=32)
# general.py (requires time, torch, torchvision; xywh2xyxy, box_iou and LOGGER are YOLOv5 helpers)
def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    if isinstance(prediction, (list, tuple)):  # YOLOv5 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - nm - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box
    merge = False  # use merge-NMS

    t = time.time()
    mi = 5 + nc  # mask start index
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # keep candidates above the confidence threshold

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box/Mask
        box = xywh2xyxy(x[:, :4])  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        mask = x[:, mi:]  # mask coefficients (zero columns if no masks)

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = x[:, 5:mi].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
        else:
            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # class offset: shift boxes per class so NMS is class-wise (0 = class-agnostic)
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
            
        # Merge-NMS (optional): merge overlapping boxes using a weighted mean,
        # where the weights are IoU x confidence. If redundant is set, keep only
        # boxes that overlap at least one other box.

        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
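
Each row of the per-image NMS output packs box, score, class and mask coefficients; a small slicing sketch (variable names are illustrative):

det = output[0]                          # detections for one image, shape (n, 6 + nm)
boxes, scores = det[:, :4], det[:, 4]    # xyxy boxes and confidences
classes_, coefs = det[:, 5], det[:, 6:]  # class indices and the 32 mask coefficients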

Processing the masks

for i, det in enumerate(pred):
    if len(det):
        masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True)  # HWC
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
        if save_txt:
            segments = reversed(masks2segments(masks))
            segments = [scale_segments(im.shape[2:], x, im0.shape, normalize=True) for x in segments]
def process_mask(protos, masks_in, bboxes, shape, upsample=False):
    c, mh, mw = protos.shape  # CHW
    ih, iw = shape
    # sigmoid(mask coefficients @ prototypes), reshaped to (n, mh, mw): per-pixel mask probabilities
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW
    # copy the boxes so the originals are not modified
    downsampled_bboxes = bboxes.clone()
    # scale box coordinates down to the prototype-mask resolution so boxes and masks line up
    downsampled_bboxes[:, 0] *= mw / iw
    downsampled_bboxes[:, 2] *= mw / iw
    downsampled_bboxes[:, 3] *= mh / ih
    downsampled_bboxes[:, 1] *= mh / ih
    # crop each mask to its bounding box, keeping only the pixels inside
    masks = crop_mask(masks, downsampled_bboxes)  # CHW
    if upsample:
        # bilinear upsampling back to the network input resolution
        masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0]  # CHW
    # binarize: pixels with probability > 0.5 are foreground
    return masks.gt_(0.5)
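
crop_mask is not shown in the excerpt; it zeroes every mask pixel outside the corresponding box, roughly as in YOLOv5's utils/segment/general.py:

def crop_mask(masks, boxes):
    # zero out mask pixels that fall outside each mask's bounding box
    n, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # each (n, 1, 1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # column indices (1, 1, w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # row indices (1, h, 1)
    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))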

Extracting the contour coordinates of each segmented region

def masks2segments(masks, strategy='largest'):
    # Convert masks (n, 160, 160) into polygon segments (n, xy)
    segments = []
    # iterate over each mask in the batch
    for x in masks.int().cpu().numpy().astype('uint8'):
        # find the external contours of the binary mask
        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if c:
            if strategy == 'concat':  # concatenate all segments
                c = np.concatenate([x.reshape(-1, 2) for x in c])
            elif strategy == 'largest':  # select largest segment
                # keep the contour with the most points, reshaped to an (m, 2) array
                c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
        else:
            c = np.zeros((0, 2))  # no segments found
        segments.append(c.astype('float32'))
    return segments

segments = reversed(masks2segments(masks))

Rescaling to the original image scale

segments = [scale_segments(im.shape[2:], x, im0.shape, normalize=True) for x in segments]
def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
    # Rescale segment coords (n, 2) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    segments[:, 0] -= pad[0]  # x padding
    segments[:, 1] -= pad[1]  # y padding
    segments /= gain
    clip_segments(segments, img0_shape)
    if normalize:
        segments[:, 0] /= img0_shape[1]  # width
        segments[:, 1] /= img0_shape[0]  # height
    return segments
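
For intuition, take the hypothetical 416×640 letterboxed input used above and a 1080×1920 original image:

img1_shape, img0_shape = (416, 640), (1080, 1920)       # network input (h, w), original (h, w)
gain = min(416 / 1080, 640 / 1920)                      # 1/3 = old / new
pad = (640 - 1920 * gain) / 2, (416 - 1080 * gain) / 2  # (0.0, 28.0): letterbox bars top and bottom

# each segment point is shifted by the padding, divided by gain,
# and optionally normalized by the original width and height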

Exporting mask coordinates (0-1)

for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
    if save_txt:  # Write to file
        segj = segments[j].reshape(-1)  # (n,2) to (n*2)
        line = (cls, *segj, conf) if save_conf else (cls, *segj)  # label format
        with open(f'{txt_path}.txt', 'a') as f:
            f.write(('%g ' * len(line)).rstrip() % line + '\n')
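
Each line of the resulting .txt follows the YOLO segmentation label format: the class index followed by normalized x y pairs (plus the confidence when save_conf is set). An illustrative line (values made up):

0 0.415 0.302 0.420 0.311 0.433 0.318 0.429 0.325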
