• coco.py文件详解


    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
    """
    COCO dataset which returns image_id for evaluation.
    
    Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
    """
    from pathlib import Path#处理文件和目录路径的模块
    
    import torch
    import torch.utils.data#用于创建和操作张量(tensors),以及构建数据加载器用于训练和测试
    import torchvision
    from pycocotools import mask as coco_mask#处理 COCO 数据集的 Python 工具包
    import datasets.transforms as T#自定义的数据预处理和增强操作
    
    
    class CocoDetection(torchvision.datasets.CocoDetection):
        def __init__(self, img_folder, ann_file, transforms, return_masks):
            super(CocoDetection, self).__init__(img_folder, ann_file)#img_folder 是图像文件夹的路径,ann_file 是 COCO 标注文件的路径
            self._transforms = transforms#transforms 是一系列用于数据预处理和增强的转换操作
            self.prepare = ConvertCocoPolysToMask(return_masks)#return_masks 是一个布尔值,表示是否返回目标的遮罩信息
    
        '''
        __getitem__(self, idx) 方法重写了父类的同名方法。它用于获取指定索引 idx 对应的图像和目标。
        首先,通过调用父类的 __getitem__ 方法,获取原始的图像和目标数据。
        然后,从 self.ids 中获取对应索引的图像 ID,并将图像 ID 和目标数据组织成一个字典 target: {'image_id': image_id, 'annotations': target}。
        接下来,通过调用 self.prepare() 方法对图像和目标数据进行进一步处理,如将多边形转换为遮罩。
        最后,如果存在数据转换操作 self._transforms,则将图像和目标数据传递给它们进行处理。
        返回处理后的图像和目标数据。
        '''
        def __getitem__(self, idx):
            img, target = super(CocoDetection, self).__getitem__(idx)
            image_id = self.ids[idx]
            target = {'image_id': image_id, 'annotations': target}
            img, target = self.prepare(img, target)
            if self._transforms is not None:
                img, target = self._transforms(img, target)
            return img, target
    
    '''segmentations 是包含多边形分割信息的列表。
    height 和 width 分别是目标图像的高度和宽度。
    masks 是一个空列表,用于存储遮罩掩膜。
    对于每个多边形 polygons,使用 coco_mask.frPyObjects() 函数将其转换为 COCO 格式的 RLE 编码,并且 height 和 width 参数告诉该函数需要生成怎样尺寸的遮罩掩膜。
    使用 coco_mask.decode() 函数将 RLE 编码转换为实际的遮罩掩膜。
    如果遮罩掩膜的维数小于三,则在最后一维上添加一个新的维度。
    将遮罩掩膜转换为 PyTorch 张量类型,并且只保留第二维和第三维的像素信息。在第二维和第三维上,利用 any() 函数对所有像素点进行逻辑或运算,最终将多边形分割信息转换成二值化的遮罩掩膜。
    将转换后的遮罩掩膜添加到 masks 列表中。
    测试 masks 是否为空列表,如果是,则创建一个全 0 的遮罩掩膜,大小为 (0, height, width)。
    最后,通过 torch.stack() 函数将列表中所有遮罩掩膜沿着新的第 0 维进行叠加,得到一个形状为 (N, height, width) 的张量,其中 N 是分割的数量。'''
    def convert_coco_poly_to_mask(segmentations, height, width):
        masks = []
        for polygons in segmentations:
            rles = coco_mask.frPyObjects(polygons, height, width)
            mask = coco_mask.decode(rles)
            if len(mask.shape) < 3:
                mask = mask[..., None]
            mask = torch.as_tensor(mask, dtype=torch.uint8)
            mask = mask.any(dim=2)
            masks.append(mask)
        if masks:
            masks = torch.stack(masks, dim=0)
        else:
            masks = torch.zeros((0, height, width), dtype=torch.uint8)
        return masks
    
    '''__init__ 方法用于初始化 ConvertCocoPolysToMask 类的实例。它接受一个参数 return_masks,默认为 False。该参数用于指定是否返回遮罩掩膜。
    
    __call__ 方法是类的可调用方法,在实例被调用时会执行。它接受两个参数 image 和 target,分别表示输入的图像和目标。
    
    首先,获取图像的宽度和高度,并将其保存为变量 w 和 h。
    获取目标的图像ID,并将其转换为PyTorch张量类型。
    从目标中提取注释信息 anno。
    过滤掉包含 iscrowd 属性的注释对象或 iscrowd 值为0的注释对象。
    提取注释对象的边界框,并将其转换为PyTorch张量类型。然后对边界框进行归一化处理(从绝对坐标转换为相对坐标)。同时,将边界框的坐标限制在图像边界内。
    提取注释对象的类别标签,并将其转换为PyTorch张量类型。
    如果设置了 return_masks 为 True,则提取注释对象的多边形分割信息,并调用 convert_coco_poly_to_mask 函数将分割信息转换为遮罩掩膜。
    检查是否存在关键点信息,并将其提取为PyTorch张量类型。
    过滤掉无效的边界框,即宽度和高度小于等于0的边界框。
    将过滤后的边界框、类别标签、遮罩掩膜(如果设置了 return_masks)、关键点信息存储到 target 字典中。
    提取注释对象的区域面积和 iscrowd 属性,并将其存储到 target 字典中。
    存储原始图像的尺寸信息和当前图像的尺寸信息到 target 字典中。
    返回图像和 target 字典作为输出。'''
    class ConvertCocoPolysToMask(object):
        def __init__(self, return_masks=False):
            self.return_masks = return_masks
    
        def __call__(self, image, target):
            w, h = image.size
    
            image_id = target["image_id"]
            image_id = torch.tensor([image_id])
    
            anno = target["annotations"]
    
            anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
    
            boxes = [obj["bbox"] for obj in anno]
            # guard against no boxes via resizing
            boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
            boxes[:, 2:] += boxes[:, :2]
            boxes[:, 0::2].clamp_(min=0, max=w)
            boxes[:, 1::2].clamp_(min=0, max=h)
    
            classes = [obj["category_id"] for obj in anno]
            classes = torch.tensor(classes, dtype=torch.int64)
    
            if self.return_masks:
                segmentations = [obj["segmentation"] for obj in anno]
                masks = convert_coco_poly_to_mask(segmentations, h, w)
    
            keypoints = None
            if anno and "keypoints" in anno[0]:
                keypoints = [obj["keypoints"] for obj in anno]
                keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
                num_keypoints = keypoints.shape[0]
                if num_keypoints:
                    keypoints = keypoints.view(num_keypoints, -1, 3)
    
            keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
            boxes = boxes[keep]
            classes = classes[keep]
            if self.return_masks:
                masks = masks[keep]
            if keypoints is not None:
                keypoints = keypoints[keep]
    
            target = {}
            target["boxes"] = boxes
            target["labels"] = classes
            if self.return_masks:
                target["masks"] = masks
            target["image_id"] = image_id
            if keypoints is not None:
                target["keypoints"] = keypoints
    
            # for conversion to coco api
            area = torch.tensor([obj["area"] for obj in anno])
            iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
            target["area"] = area[keep]
            target["iscrowd"] = iscrowd[keep]
    
            target["orig_size"] = torch.as_tensor([int(h), int(w)])
            target["size"] = torch.as_tensor([int(h), int(w)])
    
            return image, target
    
    '''首先创建了一个 normalize 的转换操作,它将图像转换为张量并进行标准化。具体来说,它使用了均值 [0.485, 0.456, 0.406] 和标准差 [0.229, 0.224, 0.225] 进行标准化。
    定义了一系列尺度 scales,用于对图像进行随机调整大小。
    如果 image_set 的值为 'train',则返回一系列的转换操作:
    随机水平翻转图像。
    随机选择以下两种转换操作之一:
    将图像随机调整为 scales 中的某个尺度,并保证最大边长不超过1333像素。
    先随机调整图像的短边长度为 [400, 500, 600] 中的某个值,然后随机裁剪出大小为 [384, 600] 的图像,并将其随机调整为 scales 中的某个尺度,并保证最大边长不超过1333像素。
    对图像进行归一化操作。
    如果 image_set 的值为 'val',则返回一系列的转换操作:
    将图像随机调整为 [800] 中的某个尺度,并保证最大边长不超过1333像素。
    对图像进行归一化操作。
    如果 image_set 的值不是 'train' 或 'val',则抛出一个异常,表示未知的 image_set 值。'''
    def make_coco_transforms(image_set):
    
        normalize = T.Compose([
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    
        scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
    
        if image_set == 'train':
            return T.Compose([
                T.RandomHorizontalFlip(),
                T.RandomSelect(
                    T.RandomResize(scales, max_size=1333),
                    T.Compose([
                        T.RandomResize([400, 500, 600]),
                        T.RandomSizeCrop(384, 600),
                        T.RandomResize(scales, max_size=1333),
                    ])
                ),
                normalize,
            ])
    
        if image_set == 'val':
            return T.Compose([
                T.RandomResize([800], max_size=1333),
                normalize,
            ])
    
        raise ValueError(f'unknown {image_set}')
    
    '''首先,根据传入的 args.coco_path 构建 COCO 数据集的根路径 root。
    然后,对 root 进行存在性检查,如果该路径不存在,则抛出异常。
    定义了一个变量 mode,其值为 'instances',表示数据集的模式。
    定义了一个字典 PATHS,其中包含了不同 image_set 对应的图像文件夹路径和注释文件路径。
    根据传入的 image_set 从 PATHS 字典中获取对应的图像文件夹路径和注释文件路径,并分别赋值给 img_folder 和 ann_file 变量。
    调用 CocoDetection 类创建一个 COCO 数据集对象 dataset。CocoDetection 是一个用于处理 COCO 数据集的类,它接收图像文件夹路径、注释文件路径、转换操作和返回掩码选项作为参数。
    最后,返回创建的 COCO 数据集对象 dataset。'''
    def build(image_set, args):
        root = Path(args.coco_path)
        assert root.exists(), f'provided COCO path {root} does not exist'
        mode = 'instances'
        PATHS = {
            "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
            "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
        }
    
        img_folder, ann_file = PATHS[image_set]
        dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks)
        return dataset
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
  • 相关阅读:
    kotlin完成 Code War 题目 解析分子公式
    JAVA JDBC 概述
    Flink 实时数仓(二)【ODS 层开发】
    辉芒微IO单片机FT60F010A-URT
    基于深度学习的小学语文“输出驱动”教学研究课题方案
    开咖啡店需要注意什么?知名咖啡店总结五点
    <图像处理> Harris角点检测
    通过阅读源码解决项目难题:GToken替换JWT实现SSO单点登录
    [Java安全]—Mybatis注入
    vite3 + vue3 异步加载路由后挂载 APP 实例,生产环境下页面空白问题解决
  • 原文地址:https://blog.csdn.net/weixin_43722052/article/details/132948276