• TorchVision Transforms API 大升级,支持目标检测、实例/语义分割及视频类任务


    内容导读:TorchVision Transforms API 扩展升级,现已支持目标检测、实例及语义分割以及视频类任务。新 API 尚处于测试阶段,开发者可以试用体验。

    本文首发自微信公众号:PyTorch 开发者社区

    在这里插入图片描述

    TorchVision 现已针对 Transforms API 进行了扩展, 具体如下:

    • 除用于图像分类外,现在还可以用其进行目标检测、实例及语义分割以及视频分类等任务;

    • 支持从 TorchVision 直接导入 SoTA 数据增强,如 MixUp、 CutMix、Large Scale Jitter 以及 SimpleCopyPaste。

    • 支持使用全新的 functional transforms 转换视频、Bounding box 以及分割掩码 (Segmentation Mask)。

    Transforms 当前的局限性

    稳定版 TorchVision Transforms API,也也就是我们常说的 Transforms V1,只支持单个图像,因此,只适用于分类任务:

    from torchvision import transforms
    trans = transforms.Compose([
       transforms.ColorJitter(contrast=0.5),
       transforms.RandomRotation(30),
       transforms.CenterCrop(480),
    ])
    imgs = trans(imgs)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    上述方法不支持需要使用 Label 的目标检测、分割或分类 Transforms, 如 MixUp 及 cutMix。这使分类以外的计算机视觉任务都不能用 Transforms API 执行必要的扩展。同时,这也加大了用 TorchVision 原语训练高精度模型的难度。

    为了克服这个局限性,TorchVision 在其 reference script 中提供了自定义实现, 用于演示所有任务中的增强是如何执行的。

    尽管这种做法使得开发者能够训练出高精度的分类、目标检测及分割模型,但做法比较粗糙,TorchVision 二进制文件中还是不能导入 Transforms。

    全新的 Transforms API

    Transforms V2 API 支持视频、bounding box、label 以及分割掩码, 这意味着它为许多计算机视觉任务提供了本地支持。新的解决方案是一种更为直接的替代方案:

    from torchvision.prototype import transforms
    # Exactly the same interface as V1:
    trans = transforms.Compose([
        transforms.ColorJitter(contrast=0.5),
        transforms.RandomRotation(30),
        transforms.CenterCrop(480),
    ])
    imgs, bboxes, labels = trans(imgs, bboxes, labels)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    全新的 Transform Class 无需强制执行特定的顺序或结构,就可以接收任意数量的输入:

    # Already supported:
    trans(imgs)  # Image Classification
    trans(videos)  # Video Tasks
    trans(imgs_or_videos, labels)  # MixUp/CutMix-style Transforms
    trans(imgs, bboxes, labels)  # Object Detection
    trans(imgs, bboxes, masks, labels)  # Instance Segmentation
    trans(imgs, masks)  # Semantic Segmentation
    trans({"image": imgs, "box": bboxes, "tag": labels})  # Arbitrary Structure
    # Future support:
    trans(imgs, bboxes, labels, keypoints)  # Keypoint Detection
    trans(stereo_images, disparities, masks)  # Depth Perception
    trans(image1, image2, optical_flows, masks)  # Optical Flow
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    functional API 已经更新,支持所有输入必要的 signal processing kernel,如 resizing, cropping, affine transforms, padding 等:

    from torchvision.prototype.transforms import functional as F
    # High-level dispatcher, accepts any supported input type, fully BC
    F.resize(inpt, resize=[224, 224])
    # Image tensor kernel
    F.resize_image_tensor(img_tensor, resize=[224, 224], antialias=True)
    # PIL image kernel
    F.resize_image_pil(img_pil, resize=[224, 224], interpolation=BILINEAR)
    # Video kernel
    F.resize_video(video, resize=[224, 224], antialias=True)
    # Mask kernel
    F.resize_mask(mask, resize=[224, 224])
    # Bounding box kernel
    F.resize_bounding_box(bbox, resize=[224, 224], spatial_size=[256, 256])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    API 使用 Tensor subclassing 来包装输入,附加有用的元数据,并 dispatch 到正确的内核。 利用 TorchData Data Pipe 的 Datasets V2 相关工作完成后,就不再需要手动包装输入了。目前,用户可以通过以下方式手动包装输入:

    from torchvision.prototype import features
    imgs = features.Image(images, color_space=ColorSpace.RGB)
    vids = features.Video(videos, color_space=ColorSpace.RGB)
    masks = features.Mask(target["masks"])
    bboxes = features.BoundingBox(target["boxes"], format=BoundingBoxFormat.XYXY, spatial_size=imgs.spatial_size)
    labels = features.Label(target["labels"], categories=["dog", "cat"])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    除新 API 之外,PyTorch 官方还为 SoTA 研究中用到的一些数据增强提供了重要实现,如 MixUp、 CutMix、Large Scale Jitter、 SimpleCopyPaste、AutoAugmentation 方法以及一些新的 Geometric、Colour 和 Type Conversion transforms。

    该 API 继续支持 single image 或 batched input image 的 PIL 和 Tensor 后端,并在 functional API 上保留了 JIT-scriptability。这使得图像映射得以从 uint8 延迟到 float, 带来了性能的进一步提升。

    它目前可以在 TorchVision 的原型区域 (prototype area) 中使用,并且支持从 nightly build 版本中导入。经验证,新 API 与先前实现的准确性一致。

    当前的局限性

    functional API (kernel) 仍然保持 JIT-scriptable 及 fully-BC,Transform Class 提供了相同的接口,却无法使用脚本。

    这是因为 Transform Class 使用的是张量子类 (Tensor Subclassing),且接收任意数量的输入,这是 JIT 所不支持的。该局限将在后续版本中不断优化。

    一个端到端示

    以下是一个新 API 示例,它可以同时使用 PIL 图像和张量。

    测试图片:

    在这里插入图片描述
    代码示例:

    import PIL
    from torchvision import io, utils
    from torchvision.prototype import features, transforms as T
    from torchvision.prototype.transforms import functional as F
    # Defining and wrapping input to appropriate Tensor Subclasses
    path = "COCO_val2014_000000418825.jpg"
    img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
    # img = PIL.Image.open(path)
    bboxes = features.BoundingBox(
        [[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
         [148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
         [422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
         [435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
         [469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
         [452, 39, 463, 63], [424, 38, 429, 50]],
        format=features.BoundingBoxFormat.XYXY,
        spatial_size=F.get_spatial_size(img),
    )
    labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])
    # Defining and applying Transforms V2
    trans = T.Compose(
        [
            T.ColorJitter(contrast=0.5),
            T.RandomRotation(30),
            T.CenterCrop(480),
        ]
    )
    img, bboxes, labels = trans(img, bboxes, labels)
    # Visualizing results
    viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
    F.to_pil_image(viz).show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31

    —— 完 ——

  • 相关阅读:
    深信服AC跨三层取mac,绑定ip/mac
    解决问题:Unable to connect to Redis
    自制Linux功能板-新增功能(基于RTMP流媒体传输协议的视频监控)
    Spring中用了哪些设计模式
    vue调用post方法并且后端代码需要接收ids
    git clone拉取代码错误解决方法
    关于 Rancher 与防火墙 firewalld 的一些注意事项
    R语言—基本统计分析
    VMware 与 SmartX 超融合 I/O 路径对比与性能影响解析
    PAL/NTSC/1080I和interlaced scan(隔行扫描)
  • 原文地址:https://blog.csdn.net/HyperAI/article/details/127766162