Data Augmentation Series (6): Keypoint Augmentation with Albumentations


    In this notebook we show how to apply Albumentations to a keypoint augmentation problem. You can use any pixel-level augmentation on an image with keypoints, because pixel-level augmentations do not affect keypoint coordinates.
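    To make this concrete, here is a minimal sketch (using the image and keypoints that will be loaded in step 3 below) that applies a purely pixel-level transform and verifies that the keypoint coordinates come back unchanged:

    # A pixel-level transform changes colors, not geometry, so keypoints pass through untouched.
    transform = A.Compose(
        [A.RandomBrightnessContrast(p=1)],
        keypoint_params=A.KeypointParams(format='xy')
    )
    transformed = transform(image=image, keypoints=keypoints)
    for (x0, y0), (x1, y1) in zip(keypoints, transformed['keypoints']):
        assert abs(x0 - x1) < 1e-6 and abs(y0 - y1) < 1e-6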

    Note: by default, augmentations that work with keypoints do not change a keypoint's label after the transform. This can be a problem if your labels are side-specific. For example, if you have a keypoint labeled left arm and apply a HorizontalFlip augmentation, you will get a keypoint that still carries the left arm label, even though it now visually looks like a right arm keypoint.

    If you work with keypoints of this type, consider using FlipSymmetricKeypoints from albumentations-experimental, an experimental augmentation created specifically to handle this case:

    pip install -U albumentations_experimental

    from albumentations_experimental import FlipSymmetricKeypoints
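    If you prefer to stay on plain Albumentations, here is a minimal workaround sketch (the left_arm/right_arm labels and the swap table are illustrative assumptions, not part of the library; image is the image loaded in step 3):

    # Carry labels through the transform via label_fields, then swap
    # side-specific labels by hand after a guaranteed flip (p=1).
    SWAP = {'left_arm': 'right_arm', 'right_arm': 'left_arm'}
    transform = A.Compose(
        [A.HorizontalFlip(p=1)],
        keypoint_params=A.KeypointParams(format='xy', label_fields=['class_labels'])
    )
    transformed = transform(
        image=image,
        keypoints=[(100, 200), (300, 200)],
        class_labels=['left_arm', 'right_arm'],
    )
    swapped_labels = [SWAP.get(label, label) for label in transformed['class_labels']]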

    1. Import the required packages

    import random
    import cv2
    from matplotlib import pyplot as plt
    import albumentations as A
    
    KEYPOINT_COLOR = (0, 255, 0)  # Green
    

    2. Define a function to visualize keypoints on an image

    def vis_keypoints(image, keypoints, color=KEYPOINT_COLOR, diameter=15):
        image = image.copy()
    
        for (x, y) in keypoints:
            cv2.circle(image, (int(x), int(y)), diameter, color, -1)
    
        plt.figure(figsize=(8, 8))
        plt.axis('off')
        plt.imshow(image)
    

    3. Get an image and its annotations

    We will use the xy format for the keypoint coordinates. Each keypoint is defined by two values: x, the position on the x-axis, and y, the position on the y-axis.
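    Albumentations also understands other keypoint formats (for example yx, and angle- or scale-aware variants such as xya and xys), selected through the format argument of A.KeypointParams. For instance, if your annotations were stored in (y, x) order, the pipeline would be configured like this:

    transform = A.Compose(
        [A.HorizontalFlip(p=1)],
        keypoint_params=A.KeypointParams(format='yx')  # annotations given as (y, x) pairs
    )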

    image = cv2.imread('keypoints_image.jpg')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    keypoints = [
        (100, 100),
        (720, 410),
        (1100, 400),
        (1700, 30),
        (300, 650),
        (1570, 590),
        (560, 800),
        (1300, 750),
        (900, 1000),
        (910, 780),
        (670, 670),
        (830, 670),
        (1000, 670),
        (1150, 670),
        (820, 900),
        (1000, 900),
    ]
    

    4. Visualize the original image with its keypoints

    vis_keypoints(image, keypoints)
    

    [Image: the original image with keypoints drawn in green]

    5. Define a simple augmentation pipeline

    transform = A.Compose(
        [A.HorizontalFlip(p=1)],
        keypoint_params=A.KeypointParams(format='xy')
    )
    transformed = transform(image=image, keypoints=keypoints)
    vis_keypoints(transformed['image'], transformed['keypoints'])
    

    [Image: HorizontalFlip result, with the keypoints mirrored horizontally]

    6. Some more examples of augmentation pipelines

    transform = A.Compose(
        [A.VerticalFlip(p=1)],
        keypoint_params=A.KeypointParams(format='xy')
    )
    transformed = transform(image=image, keypoints=keypoints)
    vis_keypoints(transformed['image'], transformed['keypoints'])
    

    [Image: VerticalFlip result]

    # For visualization purposes we fix the random seed, so the augmentation will always produce the same result.
    # In a real computer vision pipeline you should not fix the seed before applying transforms to images, because
    # then the pipeline would always output the same image. The point of augmentation is a different transform each time.
    random.seed(7)
    transform = A.Compose(
        [A.RandomCrop(width=768, height=768, p=1)],
        keypoint_params=A.KeypointParams(format='xy')
    )
    transformed = transform(image=image, keypoints=keypoints)
    vis_keypoints(transformed['image'], transformed['keypoints'])
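    Note that when a crop moves a keypoint outside the visible area, Albumentations drops it from the output by default. This behavior is controlled by the remove_invisible argument of A.KeypointParams; a minimal sketch that keeps out-of-frame keypoints instead:

    # Keep keypoints even when the crop pushes them outside the image frame.
    transform = A.Compose(
        [A.RandomCrop(width=768, height=768, p=1)],
        keypoint_params=A.KeypointParams(format='xy', remove_invisible=False)
    )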
    

    [Image: RandomCrop result]

    random.seed(7)
    transform = A.Compose(
        [A.Rotate(p=0.5)],
        keypoint_params=A.KeypointParams(format='xy')
    )
    transformed = transform(image=image, keypoints=keypoints)
    vis_keypoints(transformed['image'], transformed['keypoints'])
    

    [Image: Rotate result]

    transform = A.Compose(
        [A.CenterCrop(height=512, width=512, p=1)],
        keypoint_params=A.KeypointParams(format='xy')
    )
    transformed = transform(image=image, keypoints=keypoints)
    vis_keypoints(transformed['image'], transformed['keypoints'])
    

    [Image: CenterCrop result]

    random.seed(7)
    transform = A.Compose(
        [A.ShiftScaleRotate(p=0.5)],
        keypoint_params=A.KeypointParams(format='xy')
    )
    transformed = transform(image=image, keypoints=keypoints)
    vis_keypoints(transformed['image'], transformed['keypoints'])
    

    [Image: ShiftScaleRotate result]

    7. An example of a complex augmentation pipeline

    random.seed(7)
    transform = A.Compose([
        A.RandomSizedCrop(min_max_height=(256, 1025), height=512, width=512, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.OneOf([
            A.HueSaturationValue(p=0.5),
            A.RGBShift(p=0.7)
        ], p=1),
        A.RandomBrightnessContrast(p=0.5)
    ],
        keypoint_params=A.KeypointParams(format='xy'),
    )
    transformed = transform(image=image, keypoints=keypoints)
    vis_keypoints(transformed['image'], transformed['keypoints'])
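    A note on A.OneOf: it applies exactly one transform from its list each time it fires. The outer p (here p=1) decides whether the block runs at all, while the inner p values (0.5 and 0.7) are normalized into selection weights, so every call applies either HueSaturationValue or RGBShift, never both.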
    

    [Image: result of the complex pipeline]

    8. BONUS: data augmentation for Keras

    import numpy as np
    import imageio
    import os
    import random
    import matplotlib.pyplot as plt
    import pandas as pd
    import albumentations as A
    import cv2
    import json
    from tensorflow.keras.utils import Sequence
    
    def extract_coordinates(df):
        full_coordinates = df['region_shape_attributes']
        ls_coordinates = []
        for coordinates in full_coordinates:
            coordinates = json.loads(coordinates)
            ls_coordinates.append([coordinates['cx'], coordinates['cy']])
        return np.array(ls_coordinates, dtype=np.float32)
    
    def rescale_image(image):
        return (image / np.max(image) * 255.).astype(np.float32)
    
    class CustomSeq(Sequence):
        def __init__(self, path2imgs, df, batch_size, augmentations=None, mode='train'):
            self.path2imgs = path2imgs
            self.df = df
            self.img_list = self.df['filename']
            self.y = extract_coordinates(self.df)
            self.batch_size = batch_size
            self.augmentations = augmentations
            self.mode = mode.lower()
            # Index order used for batching; reshuffled between epochs in training mode.
            self.indexes = list(range(len(self.img_list)))
            
        def __len__(self):
            return int(np.ceil(len(self.df) / float(self.batch_size)))
        
        def on_epoch_end(self):
            # Reshuffle the sample order after every epoch in training mode.
            if self.mode == 'train':
                random.shuffle(self.indexes)
        
        def get_batch_labels(self, idx, shapes):
            batch_idx = self.indexes[idx * self.batch_size: (idx + 1) * self.batch_size]
            return self.y[batch_idx]
        
        def get_batch_images(self, idx):
            x_batch = []
            shapes = []
            batch_idx = self.indexes[idx * self.batch_size: (idx + 1) * self.batch_size]
            img_names = self.img_list.iloc[batch_idx]
            for img_name in img_names:
                image = imageio.imread(os.path.join(self.path2imgs, img_name))
                image = rescale_image(image)
                x_batch.append(image)
                shapes.append(image.shape)
            return x_batch, np.array(shapes)
        
        def __getitem__(self, idx):
            x_batch, shapes = self.get_batch_images(idx)
            y_batch = self.get_batch_labels(idx, shapes)
            if self.augmentations:
                # Iterate over the images and their keypoints in the batch
                for i, (x_item, y_item) in enumerate(zip(x_batch, y_batch)):
                    transformed = self.augmentations(image=x_item, keypoints=np.expand_dims(y_item, axis=0))
                    # Overwrite the image and keypoint values in the batch with their augmented versions
                    x_batch[i], y_batch[i] = transformed['image'], transformed['keypoints'][0]
            return x_batch, y_batch
    
    
    transform = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        #A.InvertImg(p=0.5),
        #A.ShiftScaleRotate(shift_limit=0, scale_limit=0.1, rotate_limit=15, p=0.5),
        A.ToFloat(max_value=255.)
    ], keypoint_params=A.KeypointParams(format='xy'))
    
    path2png_imgs = os.getcwd()
    df = pd.read_csv('vgg_annotate_crop.csv', header=0)
    
    data = CustomSeq(path2png_imgs, df, 1, augmentations=transform)
    
    images, point = data[0]
    
    center = tuple(int(v) for v in point[0])
    images[0] = cv2.circle(images[0], center, 30, 1, -1)
    plt.figure(figsize=(10, 10))
    plt.imshow(images[0], cmap='gray')
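    Note that __getitem__ returns x_batch as a Python list rather than a stacked NumPy array, which is convenient here because augmented images may differ in shape. To feed these batches to model.fit you would typically resize the images to a common shape (for example with A.Resize in the pipeline) and stack them with np.stack before returning.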
    
    

    References

    https://github.com/albumentations-team/albumentations_examples/blob/master/notebooks/example_keypoints.ipynb

    Original post: https://blog.csdn.net/weixin_43229348/article/details/121496274