• Python图像处理【22】基于卷积神经网络的图像去雾


    0. 前言

    单图像去雾 (dehazing) 是一个具有挑战性的图像恢复问题。为了解决这个问题,大多数算法都采用经典的大气散射模型,该模型是一种基于单一散射和均匀大气介质假设的简化物理模型,但现实环境中的雾霾表述更加复杂。

    1. 渐进特征融合网络

    在本节中,我们将学习如何使用输入自适应端到端深度学习预训练去雾模型,即渐进特征融合网络 (Progressive Feature Fusion Network, PFFNet),并通过使用 Pytorch 来执行模糊图像的去雾操作。渐进特征融合所采用的 U-Net 架构编码器 - 解码器网络,可直接学习从模糊图像到清晰图像的高度非线性转换函数。深度神经网架构如下图所示:

    PFFNet
    从以上体系结构图可以看出:

    • 编码器由五个卷积层组成,每个卷积层之后都有非线性 ReLU 激活函数;第一层用于从原始模糊图像中相对较大的局部感受野上的提取特征,然后,依次执行四次下采样卷积操作,以获取图像金字塔
    • 特征转换模块由基于残差的模块组成,深层网络可以表示非常复杂的特征,也可以学习到许多不同尺度的特征,但同时,在使用反向传播进行训练时,经常会遇到消失的梯度问题,而残差网络就是为了解决这一问题而被提出的,可以用于训练更深的网络
    • 解码器由四个反卷积层和一个卷积层组成,与编码器相反,解码器的反卷积层顺序堆叠以恢复图像结构细节

    2. 图像去雾

    2.1 网络构建

    (1) 首先下载预训练网络模型,并导入所需的库,模块和函数:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from PIL import Image
    from torch.autograd import Variable
    from torchvision.transforms import ToTensor, ToPILImage, Normalize, Resize
    #from torchviz import make_dot
    import matplotlib.pylab as plt 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    (2) 定义与深神经网络中不同层相对应的 ConvLayerUpsampleConvLayer 类,所有网络层都继承自 Pytorchnn.module 类;每个层都需要实现自己的 init() (用于初始化参数/成员变量/层)和 forward() 方法(定义前向传播过程中的计算):

    class ConvLayer(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, stride):
            super(ConvLayer, self).__init__()
            reflection_padding = kernel_size // 2
            self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
            self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
    
        def forward(self, x):
            out = self.reflection_pad(x)
            out = self.conv2d(out)
            return out
    
    class UpsampleConvLayer(torch.nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, stride):
          super(UpsampleConvLayer, self).__init__()
          reflection_padding = kernel_size // 2
          self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
          self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)
    
        def forward(self, x):
            out = self.reflection_pad(x)
            out = self.conv2d(out)
            return out
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    (3) 接下来,我们用两个 ConvLayer 类实例定义类 ResidualBlock,在 ConvLayer 类实例之间使用 PReLU 激活函数,该类同样继承自 nn.module,并定义 forward() 方法用于前向传播:

    class ResidualBlock(nn.Module):
        def __init__(self, channels):
            super(ResidualBlock, self).__init__()
            self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
            self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
            self.relu = nn.PReLU()
    
        def forward(self, x):
            residual = x
            out = self.relu(self.conv1(x))
            out = self.conv2(out) * 0.1
            out = torch.add(out, residual)
            return out 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    (4) 定义继承自 nn.conv2d 类的 MeanShift 类,通过将 requires_grad 的参数设置为 False,冻结 MeanShift 层:

    class MeanShift(nn.Conv2d):
        def __init__(self, rgb_range, rgb_mean, sign):
            super(MeanShift, self).__init__(3, 3, kernel_size=1)
            self.weight.data = torch.eye(3).view(3, 3, 1, 1)
            self.bias.data = float(sign) * torch.Tensor(rgb_mean) * rgb_range
    
            # Freeze the MeanShift layer
            for params in self.parameters():
                params.requires_grad = False
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    (5) 最后,根据所定义的神经网络层定义深度神经网络类 Net,该类同样需要定义 init() 方法。网络使用了五个 ConvLayer,然后使用四个 UPSampleconvLayer,最后通过 ConvLayer 层后输出,网络使用 LeakyReLU 作为激活函数。
    同样,需要定义向前传播方法 forward(),并在每个激活函数后使用双线性上采样:

    class Net(nn.Module):
        def __init__(self, res_blocks=18):
            super(Net, self).__init__()
    
            rgb_mean = (0.5204, 0.5167, 0.5129)
            self.sub_mean = MeanShift(1., rgb_mean, -1)
            self.add_mean = MeanShift(1., rgb_mean, 1)
    
            self.conv_input = ConvLayer(3, 16, kernel_size=11, stride=1)
            self.conv2x = ConvLayer(16, 32, kernel_size=3, stride=2)
            self.conv4x = ConvLayer(32, 64, kernel_size=3, stride=2)
            self.conv8x = ConvLayer(64, 128, kernel_size=3, stride=2)
            self.conv16x = ConvLayer(128, 256, kernel_size=3, stride=2)
    
            self.dehaze = nn.Sequential()
            for i in range(1, res_blocks):
                self.dehaze.add_module('res%d' % i, ResidualBlock(256))
    
            self.convd16x = UpsampleConvLayer(256, 128, kernel_size=3, stride=2)
            self.convd8x = UpsampleConvLayer(128, 64, kernel_size=3, stride=2)
            self.convd4x = UpsampleConvLayer(64, 32, kernel_size=3, stride=2)
            self.convd2x = UpsampleConvLayer(32, 16, kernel_size=3, stride=2)
    
            self.conv_output = ConvLayer(16, 3, kernel_size=3, stride=1)
    ()
            self.relu = nn.LeakyReLU(0.2)
    
        def forward(self, x):
            x = self.relu(self.conv_input(x))
            res2x = self.relu(self.conv2x(x))
            res4x = self.relu(self.conv4x(res2x))
    
            res8x = self.relu(self.conv8x(res4x))
            res16x = self.relu(self.conv16x(res8x))
    
            res_dehaze = res16x
            res16x = self.dehaze(res16x)
            res16x = torch.add(res_dehaze, res16x)
    
            res16x = self.relu(self.convd16x(res16x))
            res16x = F.upsample(res16x, res8x.size()[2:], mode='bilinear')
            res8x = torch.add(res16x, res8x)
    
            res8x = self.relu(self.convd8x(res8x))
            res8x = F.upsample(res8x, res4x.size()[2:], mode='bilinear')
            res4x = torch.add(res8x, res4x)
    
            res4x = self.relu(self.convd4x(res4x))
            res4x = F.upsample(res4x, res2x.size()[2:], mode='bilinear')
            res2x = torch.add(res4x, res2x)
    
            res2x = self.relu(self.convd2x(res2x))
            res2x = F.upsample(res2x, x.size()[2:], mode='bilinear')
            x = torch.add(res2x, x)
    
            x = self.conv_output(x)
    
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58

    (6) 定义预训练模型参数位置以及模型使用的残差块数量:

    rb = 13
    checkpoint = "I-HAZE_O-HAZE.pth"
    
    • 1
    • 2

    (7) 实例化 Net() 类并使用 load_state_dict() 方法从检查点加载预训练权重。由于我们不需要训练模型,因此使用测试模式:

    net = Net(rb)
    net.load_state_dict(torch.load(checkpoint)['state_dict'])
    net.eval()
    
    • 1
    • 2
    • 3

    2.2 模型测试

    (1) 接下来,使用 open() 函数读取输入图像:

    im_path = "pic.png"
    im = Image.open(im_path)
    h, w = im.size
    print(h, w)
    
    • 1
    • 2
    • 3
    • 4

    (2) 使用 torchvision.transforms 模块中的 ToTensor() 将图像转换为张量对象以输入网络,然后使用输入图像在模型上运行正向传递过程计算输出,最后将输出转换为图像:

    imt = ToTensor()(im)
    imt = Variable(imt).view(1, -1, w, h)
    #im = im.cuda()
    with torch.no_grad():
        imt = net(imt)
    out = torch.clamp(imt, 0., 1.)
    out = out.cpu()
    out = out.data[0]
    out = ToPILImage()(out)
    
    def plot_image(image, title=None, sz=10):
        plt.imshow(image)
        plt.title(title, size=sz)
        plt.axis('off')
    plt.figure(figsize=(20,10))
    plt.subplot(121), plot_image(im, 'hazed input')
    plt.subplot(122), plot_image(out, 'de-hazed output')
    plt.tight_layout()
    plt.show() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    去雾结果

    小结

    图像去雾已成为计算机视觉的重要研究方向,在雾、霾等恶劣天气下拍摄的的图像通常由于大气散射的作用,图像质量严重下降使颜色偏灰白色,对比度降低,物体特征难以辨认,还会影响图像的分析与处理。因此,需要使用图像去雾技术来增强或修复图像,以改善视觉效果并便于图像的后续处理。在本节中,我们学习了一种基于卷积神经网络的图像去雾模型,通过使用训练后的模型可以显著改善图像视觉效果。

    系列链接

    Python图像处理【1】图像与视频处理基础
    Python图像处理【2】探索Python图像处理库
    Python图像处理【3】Python图像处理库应用
    Python图像处理【4】图像线性变换
    Python图像处理【5】图像扭曲/逆扭曲
    Python图像处理【6】通过哈希查找重复和类似的图像
    Python图像处理【7】采样、卷积与离散傅里叶变换
    Python图像处理【8】使用低通滤波器模糊图像
    Python图像处理【9】使用高通滤波器执行边缘检测
    Python图像处理【10】基于离散余弦变换的图像压缩
    Python图像处理【11】利用反卷积执行图像去模糊
    Python图像处理【12】基于小波变换执行图像去噪
    Python图像处理【13】使用PIL执行图像降噪
    Python图像处理【14】基于非线性滤波器的图像去噪
    Python图像处理【15】基于非锐化掩码锐化图像
    Python图像处理【16】OpenCV直方图均衡化
    Python图像处理【17】指纹增强和细节提取
    Python图像处理【18】边缘检测详解
    Python图像处理【19】基于霍夫变换的目标检测
    Python图像处理【20】图像金字塔
    Python图像处理【21】基于卷积神经网络增强微光图像

  • 相关阅读:
    Java8 新特性之Stream(五)-- Stream的3种创建方法
    【Python安全攻防】【网络安全】一、常见被动信息搜集手段
    8051(c51)单片机从汇编到C语言,从Boot到应用实践教程
    ValidatingWebhookConfiguration 设计说明
    RBTree(红黑树)模拟实现(插入)
    VSCode Linux的C++代码格式化配置
    图解算法,原理逐步揭开「GitHub 热点速览」
    Java定时器
    Linux ARM平台开发系列讲解(调试篇) 1.3.2 RK3399移植Ubuntu文件系统步骤
    git代码管理工具保姆级教学
  • 原文地址:https://blog.csdn.net/qq_30167691/article/details/136613668