• pytroch 颜色增强ColorJitter,墙裂推荐


    目录

    函数参数解释:

    随机亮度测试,非常方便,墙裂推荐:

    单项测试:

    举例:

    yolov5颜色增强示例,效果差不多,opencv的:


    函数参数解释:

    函数名:
    torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
    函数解析:
    随机改变一个图像的亮度、对比度、饱和度和色调。如果图像是 tensor,那么它的 shape 为[…,1或3,H,W],其中…表示 batch。如果图像是PIL图像,那么不支持模式 “1”、“I”、"F "和带有透明度(alpha通道)的模式。

    参数:
    brightness (类型为 float 或 tuple: float (min, max)) - 亮度的偏移程度。 brightness_factor可以是 [max(0, 1 - brightness), 1 + brightness],也可以直接给出最大、最小值的范围 [min, max],然后从中随机采样。brightness_factor 值应该是非负数。

    contrast (类型为 float 或 tuple: float (min, max)) - 对比度的偏移程度。 contrast_factor 可以是 [max(0, 1 - contrast), 1 + contrast],也可以直接给出最大、最小值的范围 [min, max],然后从中随机采样。contrast_factor 值应该是非负数。

    saturation (类型为 float 或 tuple: float (min, max)) - 饱和度的偏移程度。 saturation_factor 可以是 [max(0, 1 - saturation), 1 + saturation],也可以直接给出最大、最小值的范围 [min, max],然后从中随机采样。saturation_factor 值应该是非负数。

    hue (类型为 float 或 tuple: float (min, max)) - 色调的偏移程度。hue_factor 可以是 [-hue, hue],也可以直接给出最大、最小值的范围 [min, max],然后从中随机采样,它的值应当满足 0<= hue <= 0.5 或者 -0.5<= min <= max <= 0.5。为了使色调偏移,输入图像的像素值必须是非负值,以便转换到 HSV 颜色空间。因此,如果将图像归一化到一个有负值的区间,或者在使用这个函数之前使用会产生负值的插值方法,那么它就不会起作用。

    随机亮度测试,非常方便,墙裂推荐:

    参数是测试过的经验值

    1. import cv2
    2. import numpy as np
    3. import torch
    4. from PIL import Image
    5. from torchvision.transforms import ColorJitter
    6. import random
    7. class CustomColorJitter(ColorJitter):
    8. def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
    9. super(CustomColorJitter, self).__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
    10. def get_params(self, brightness, contrast, saturation, hue):
    11. self.last_brightness =None
    12. self.bright_param=None
    13. if brightness is not None:
    14. self.last_brightness = brightness[0] + random.uniform(0, 1) * (brightness[1] - brightness[0])
    15. self.bright_param=(self.last_brightness,self.last_brightness)
    16. self.contrast_param = None
    17. if contrast is not None:
    18. self.last_contrast = contrast[0] + random.uniform(0, 1) * (contrast[1] - contrast[0])
    19. self.contrast_param=(self.last_contrast,self.last_contrast)
    20. self.saturation_param = None
    21. if saturation is not None:
    22. self.last_saturation = saturation[0] + random.uniform(0, 1) * (saturation[1] - saturation[0])
    23. self.saturation_param=(self.last_saturation,self.last_saturation)
    24. self.hue_param=None
    25. if hue is not None:
    26. self.last_hue = hue[0] + random.uniform(0, 1) * (hue[1] - hue[0])
    27. self.hue_param=(self.last_hue,self.last_hue)
    28. return super().get_params(brightness=self.bright_param, contrast=self.contrast_param,
    29. saturation=self.saturation_param, hue=self.hue_param)
    30. img_path = "./aaa.png"
    31. debug=True
    32. if debug:
    33. transform = CustomColorJitter(brightness=[0.6, 1.3], contrast=[0.5, 1.5], saturation=[0.5, 1.5], hue=[-0.02, 0.02])
    34. # transform = CustomColorJitter(hue=[-0.02, 0.02])
    35. # transform = CustomColorJitter(saturation=[0.5, 1.5])
    36. # transform = CustomColorJitter( contrast=[0.5, 1.5])
    37. # transform = CustomColorJitter( brightness=[0.7, 1.3])
    38. else:
    39. transform = ColorJitter(brightness=[0.6, 1.3], contrast=[0.5, 1.5], saturation=[0.5, 1.5], hue=[-0.02, 0.02])
    40. while True:
    41. img = cv2.imread(img_path)
    42. pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    43. image = transform(pil_img)
    44. if debug:
    45. if transform.bright_param is not None:
    46. print("Last brightness value:", transform.bright_param)
    47. if transform.contrast_param is not None:
    48. print("Last contrast value:", transform.contrast_param)
    49. if transform.saturation_param is not None:
    50. print("Last saturation value:", transform.saturation_param)
    51. if transform.hue_param is not None:
    52. print("Last hue value:", transform.hue_param)
    53. img_cv = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    54. cv2.imshow("img_o", img)
    55. cv2.imshow("img_cv", img_cv)
    56. cv2.waitKey(0)

    单项测试:

    1. import cv2
    2. import numpy as np
    3. import torch
    4. import torchvision.transforms as f
    5. from PIL import Image
    6. from torchvision.transforms import ColorJitter
    7. import random
    8. class CustomColorJitter(ColorJitter):
    9. def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
    10. super(CustomColorJitter, self).__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
    11. def get_params(self, brightness, contrast, saturation, hue):
    12. self.last_brightness = brightness[0] + random.uniform(0, 1) * (brightness[1] - brightness[0])
    13. return super().get_params(brightness=(self.last_brightness, self.last_brightness), contrast=contrast, saturation=saturation, hue=hue)
    14. img_path = "./aaa.png"
    15. transform = CustomColorJitter(brightness=[0.5, 1.5])
    16. while True:
    17. img = cv2.imread(img_path)
    18. pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    19. image = transform(pil_img)
    20. print("Last brightness value:", transform.last_brightness)
    21. img_cv = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    22. cv2.imshow("img_o", img)
    23. cv2.imshow("img_cv", img_cv)
    24. cv2.waitKey(0)

    举例:

    以下内容转自:https://blog.csdn.net/lxhRichard/article/details/128083192
    1. 以随机亮度为例

    1. import torch
    2. import torchvision.transforms as f
    3. from PIL import Image
    4. img_path = "./1.jpg"
    5. img = Image.open(img_path)
    6. trans = f.ColorJitter(brightness=[0.01,0.05])
    7. image = trans(img)
    8. image.show()



    输出对比:


    2. 以随机对比度为例

    1. import torch
    2. import torchvision.transforms as f
    3. from PIL import Image
    4. img_path = "./1.jpg"
    5. img = Image.open(img_path)
    6. trans = f.ColorJitter(contrast=[0.3,0.6])
    7. image = trans(img)
    8. image.show()



    输出对比:


    3. 以随机饱和度为例

    1. import torch
    2. import torchvision.transforms as f
    3. from PIL import Image
    4. img_path = "./1.jpg"
    5. img = Image.open(img_path)
    6. trans = f.ColorJitter(saturation=[0.2,0.5])
    7. image = trans(img)
    8. image.show()



    输出对比:


    4. 以随机色调为例

    1. import torch
    2. import torchvision.transforms as f
    3. from PIL import Image
    4. img_path = "./1.jpg"
    5. img = Image.open(img_path)
    6. trans = f.ColorJitter(hue=[-0.1,0.2])
    7. image = trans(img)
    8. image.show()



    输出对比:


    5. 综合调整:

    1. import torch
    2. import torchvision.transforms as f
    3. from PIL import Image
    4. img_path = "./1.jpg"
    5. img = Image.open(img_path)
    6. trans = f.ColorJitter(brightness=0.6, contrast=0.7, saturation=0.5, hue=0.1)
    7. image = trans(img)
    8. image.show()


    输出对比:


    官方文档链接:https://pytorch.org/vision/stable/generated/torchvision.transforms.ColorJitter.html?highlight=transforms+colorjitter#torchvision.transforms.ColorJitter
     

    yolov5颜色增强示例,效果差不多,opencv的:

    1. import cv2
    2. import numpy as np
    3. def augment_hsv(img, h_gain=0.015, s_gain=0.7, v_gain=0.4):
    4. r = np.random.uniform(-1, 1, 3) * [h_gain, s_gain, v_gain] + 1 # random gains
    5. hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    6. print(r[0], r[1], r[2])
    7. dtype = img.dtype # uint8
    8. x = np.arange(0, 256, dtype=np.int16)
    9. lut_hue = ((x * r[0]) % 180).astype(dtype)
    10. lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    11. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
    12. img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    13. cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
    14. if __name__ == '__main__':
    15. img_path = "./aaa.png"
    16. while True:
    17. img_o = cv2.imread(img_path)
    18. img=img_o.copy()
    19. augment_hsv(img)
    20. cv2.imshow("img_o", img_o)
    21. cv2.imshow('HSV Augmented Image', img)
    22. cv2.waitKey(0)

  • 相关阅读:
    深入解析智慧互联网医院系统源码:医院小程序开发的架构到实现
    软件设计模式系列之二十五——访问者模式
    PyTorch多GPU训练时同步梯度是mean还是sum?
    推动制造业数字化转型是发展数字经济的重要环节
    矩阵与线性变换
    06.webpack性能优化--构建速度
    2021年03月 Python(三级)真题解析#中国电子学会#全国青少年软件编程等级考试
    Maven打Jar包,启动报NoClassDefFoundError错误
    难得五年来第一次暑假没有出海,即使最终没有逃过8月份的CPT外业
    粒子群算法(PSO)优化长短期记忆神经网络的数据回归预测,PSO-LSTM回归预测,多输入单输出模型
  • 原文地址:https://blog.csdn.net/jacke121/article/details/132995313