• Pytorch Advanced(三) Neural Style Transfer


    神经风格迁移在之前的博客中已经用 Keras 实现过了,那个实现比较复杂(参见之前的 Keras 版本)。

    这里用pytorch重新实现一次,原理图如下:


    1. from __future__ import division
    2. from torchvision import models
    3. from torchvision import transforms
    4. from PIL import Image
    5. import argparse
    6. import torch
    7. import torchvision
    8. import torch.nn as nn
    9. import numpy as np
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    加载图像

    def load_image(image_path, transform=None, max_size=None, shape=None):
        """Load an image from disk and optionally convert it to a batched tensor.

        Args:
            image_path: Path of the image file to open.
            transform: Optional callable applied to the PIL image (e.g. a
                ``transforms.Compose`` ending in ``ToTensor``). A leading batch
                dimension is added to its result with ``unsqueeze(0)``.
            max_size: If given, rescale the image so that its longer side is
                ``max_size`` pixels, keeping the aspect ratio (smaller images are
                scaled up as well).
            shape: If given, resize to exactly this size, using PIL's
                ``(width, height)`` convention.

        Returns:
            The transformed tensor moved to ``device`` when ``transform`` is
            given; otherwise the (resized) RGB PIL image.
        """
        # Force 3-channel RGB: PNGs with an alpha channel or grayscale images
        # would otherwise not match the 3-channel normalization / VGG input.
        # The context manager closes the underlying file handle.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if max_size:
            scale = max_size / max(image.size)
            # Truncate like the original astype(int), but never collapse to 0 px.
            size = tuple(max(1, int(dim * scale)) for dim in image.size)
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
            image = image.resize(size, Image.LANCZOS)
        if shape:
            image = image.resize(tuple(shape), Image.LANCZOS)
        if transform:
            image = transform(image).unsqueeze(0)
            return image.to(device)
        # Without a transform the result is still a PIL image, which has no .to().
        return image

    这里用的模型是 VGG-19,用到的是网络中 5 个卷积层(conv1_1 ~ conv5_1)的输出。

    class VGGNet(nn.Module):
        """Frozen VGG-19 feature extractor that returns intermediate activations.

        By default the taps are the children of ``vgg19().features`` named
        ``'0', '5', '10', '19', '28'`` (conv1_1, conv2_1, conv3_1, conv4_1,
        conv5_1). Each of these convolutions is followed by an in-place ReLU,
        which overwrites the stored tensor, so the returned maps are in fact the
        post-ReLU activations.
        """

        def __init__(self, select=('0', '5', '10', '19', '28')):
            """Load an ImageNet-pretrained VGG-19 and freeze its weights.

            Args:
                select: Names of the ``features`` children whose outputs are
                    returned by ``forward``. Defaults to conv1_1 ~ conv5_1.
            """
            super(VGGNet, self).__init__()
            self.select = list(select)
            self.vgg = self._load_pretrained_features()
            # Only the input image is optimized. eval() does not stop autograd,
            # so freeze the weights explicitly; otherwise gradients for every
            # VGG parameter are computed (and accumulated) on each backward pass.
            for param in self.vgg.parameters():
                param.requires_grad_(False)

        @staticmethod
        def _load_pretrained_features():
            """Return the convolutional part of an ImageNet-pretrained VGG-19."""
            weights_enum = getattr(models, 'VGG19_Weights', None)
            if weights_enum is not None:
                # torchvision >= 0.13 deprecates pretrained=True in favour of weights=.
                return models.vgg19(weights=weights_enum.IMAGENET1K_V1).features
            return models.vgg19(pretrained=True).features

        def forward(self, x):
            """Run ``x`` through the VGG-19 feature layers and collect the taps.

            Args:
                x: Image batch; assumed to be (N, 3, H, W) and normalized with the
                    ImageNet mean/std - TODO confirm for other callers.

            Returns:
                List of tensors, one per selected layer, in network order.
            """
            features = []
            for name, layer in self.vgg.named_children():
                x = layer(x)
                if name in self.select:
                    features.append(x)
            return features

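     模型结构如下。可以看到 VGG-19 的 features 部分是用 nn.Sequential 搭建的,所以子模块的编号就是层号,我们要保存的是编号为 ['0', '5', '10', '19', '28'] 的层的输出。注意这些卷积层后面紧跟的是 ReLU(inplace=True),原地操作会覆盖已保存的张量,所以实际保存下来的是经过 ReLU 之后的激活值。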

    1. VGG(
    2. (features): Sequential(
    3. (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    4. (1): ReLU(inplace)
    5. (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    6. (3): ReLU(inplace)
    7. (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    8. (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    9. (6): ReLU(inplace)
    10. (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    11. (8): ReLU(inplace)
    12. (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    13. (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    14. (11): ReLU(inplace)
    15. (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    16. (13): ReLU(inplace)
    17. (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    18. (15): ReLU(inplace)
    19. (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    20. (17): ReLU(inplace)
    21. (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    22. (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    23. (20): ReLU(inplace)
    24. (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    25. (22): ReLU(inplace)
    26. (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    27. (24): ReLU(inplace)
    28. (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    29. (26): ReLU(inplace)
    30. (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    31. (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    32. (29): ReLU(inplace)
    33. (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    34. (31): ReLU(inplace)
    35. (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    36. (33): ReLU(inplace)
    37. (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    38. (35): ReLU(inplace)
    39. (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    40. )
    41. (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
    42. (classifier): Sequential(
    43. (0): Linear(in_features=25088, out_features=4096, bias=True)
    44. (1): ReLU(inplace)
    45. (2): Dropout(p=0.5)
    46. (3): Linear(in_features=4096, out_features=4096, bias=True)
    47. (4): ReLU(inplace)
    48. (5): Dropout(p=0.5)
    49. (6): Linear(in_features=4096, out_features=1000, bias=True)
    50. )
    51. )

     训练:

    接下来对训练过程进行解释:

    1、加载风格图像和内容图像。之前的博客中使用一幅加噪图像作为优化的初始图像,这里则使用内容图像的拷贝作为初始值。

    2、我们需要优化的就是这份作为目标的内容图像拷贝(target),因此 target 设置了 requires_grad=True。

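    3、VGGNet 的参数不需要优化(优化器中只传入了 target),所以将网络设置为验证(eval)模式。注意 eval() 只会切换 Dropout/BatchNorm 等层的行为,并不会冻结参数;如果想避免为 VGG 参数计算梯度,还需要对其参数调用 requires_grad_(False)。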

    4、将3幅图像输入网络,得到总共15个输出(每个图像有5层的输出)

    5、内容损失:这里是遍历5个层的输出来计算损失,而在keras版本中只用了第4层的输出计算损失

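    6、风格损失:同样计算格拉姆矩阵(Gram matrix),将每一层的风格损失叠加得到总的风格损失,计算公式与 Keras 版本有所不同。对第 $l$ 层特征 $F_l \in \mathbb{R}^{c_l \times h_l w_l}$,有 $G_l = F_l F_l^\top$,风格损失为 $L_{style} = \sum_l \frac{1}{c_l h_l w_l}\,\mathrm{mean}\big((G_l^{target} - G_l^{style})^2\big)$,总损失为 $L = L_{content} + \text{style\_weight} \cdot L_{style}$。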

    7、反向传播

    # ImageNet statistics VGG-19 was trained with; also used to undo the
    # normalization before saving images.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)


    def _gram_matrix(feature):
        """Return the (C, C) Gram matrix of a single-image (1, C, H, W) feature map."""
        _, c, h, w = feature.size()
        flat = feature.view(c, h * w)
        return torch.mm(flat, flat.t())


    def main(config):
        """Run neural style transfer and periodically save the generated image.

        Args:
            config: Namespace with ``content``, ``style``, ``max_size``,
                ``total_step``, ``log_step``, ``sample_step``, ``style_weight``
                and ``lr`` attributes (as produced by the argparse block).

        Side effects:
            Prints the losses every ``log_step`` steps and writes
            ``output-<step>.png`` to the working directory every
            ``sample_step`` steps.
        """
        # VGG-19 was trained on ImageNet images normalized with these statistics,
        # so the inputs are normalized the same way.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ])
        # Exact inverse of the normalization above (x = y * std + mean),
        # written as Normalize(mean=-mean/std, std=1/std).
        denorm = transforms.Normalize(
            mean=tuple(-m / s for m, s in zip(_IMAGENET_MEAN, _IMAGENET_STD)),
            std=tuple(1.0 / s for s in _IMAGENET_STD),
        )

        content = load_image(config.content, transform, max_size=config.max_size)
        # Make the style image the same size as the content image. PIL's resize
        # takes (width, height), whereas the tensor layout is (N, C, H, W).
        style = load_image(config.style, transform,
                           shape=(content.size(3), content.size(2)))

        # Start from a copy of the content image; this is the only tensor the
        # optimizer updates.
        target = content.clone().requires_grad_(True)
        optimizer = torch.optim.Adam([target], lr=config.lr, betas=(0.5, 0.999))
        vgg = VGGNet().to(device).eval()

        # The content and style images never change, so compute their features
        # (and the style Gram matrices) once instead of on every step.
        with torch.no_grad():
            content_features = vgg(content)
            style_grams = [_gram_matrix(f) for f in vgg(style)]

        for step in range(config.total_step):
            target_features = vgg(target)
            content_loss = 0
            style_loss = 0
            for f_target, f_content, g_style in zip(
                    target_features, content_features, style_grams):
                # Content loss: MSE between the feature maps.
                content_loss += torch.mean((f_target - f_content) ** 2)
                # Style loss: MSE between Gram matrices, scaled by feature-map size.
                _, c, h, w = f_target.size()
                g_target = _gram_matrix(f_target)
                style_loss += torch.mean((g_target - g_style) ** 2) / (c * h * w)

            loss = content_loss + config.style_weight * style_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (step + 1) % config.log_step == 0:
                print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}'
                       .format(step + 1, config.total_step,
                               content_loss.item(), style_loss.item()))

            if (step + 1) % config.sample_step == 0:
                # Detach so the saved copy is not part of the autograd graph;
                # squeeze only the batch dim so 1-pixel spatial dims survive.
                img = target.detach().clone().squeeze(0)
                img = denorm(img).clamp_(0, 1)
                torchvision.utils.save_image(img, 'output-{}.png'.format(step + 1))

    写在 if __name__ == "__main__": 后面的语句只有在直接运行本脚本时才会执行,被其他模块 import 时不会执行。

    Python 的命令行解析工具 argparse 可以很优雅地添加参数。

    但是在 Jupyter 中运行时,sys.argv 里会带有内核自身的参数(例如 -f),无法直接用 parse_args() 解析,所以参考了其他博客的方法:先去掉 -f,再用 parse_known_args() 忽略无法识别的参数(记得修改读取图片的路径)。

    import sys
    import warnings


    def parse_config(argv=None):
        """Parse the command-line options of the style-transfer script.

        Works from a shell and inside Jupyter. The ``-f <connection file>`` pair
        that the Jupyter kernel launcher adds is removed. Any other unrecognized
        argument triggers a warning instead of aborting, so typos in option names
        do not go unnoticed.

        Args:
            argv: Argument list to parse; defaults to ``sys.argv[1:]``.

        Returns:
            argparse.Namespace with the options defined below.
        """
        argv = list(sys.argv[1:] if argv is None else argv)
        # Jupyter passes "-f <kernel json>"; drop the flag together with its
        # value (workaround adapted from a blog post).
        if '-f' in argv:
            idx = argv.index('-f')
            del argv[idx:idx + 2]
        parser = argparse.ArgumentParser()
        parser.add_argument('--content', type=str, default='png/content.png')
        parser.add_argument('--style', type=str, default='png/style.png')
        parser.add_argument('--max_size', type=int, default=400)
        parser.add_argument('--total_step', type=int, default=2000)
        parser.add_argument('--log_step', type=int, default=10)
        parser.add_argument('--sample_step', type=int, default=500)
        parser.add_argument('--style_weight', type=float, default=100)
        parser.add_argument('--lr', type=float, default=0.003)
        # parse_known_args tolerates extra arguments a notebook kernel may pass
        # (see https://blog.csdn.net/ken_for_learning/article/details/89675904),
        # but report them so misspelled options are not silently ignored.
        config, unknown = parser.parse_known_args(argv)
        if unknown:
            warnings.warn('Ignoring unrecognized arguments: {}'.format(' '.join(unknown)))
        return config


    if __name__ == "__main__":
        config = parse_config()
        print(config)
        main(config)

  • 相关阅读:
    微信每日早安推送 Windows版
    django中session值的数据类型是dict,需要手动save(),更新才会传递到其他页面。
    华大单片机KEIL添加ST-LINK解决方法
    chatgpt赋能python:Python中的数字转换
    算法竞赛进阶指南 关押罪犯
    uni-app 5小时快速入门 9 基础组件
    Redis高级及实战
    重磅!!!监控分布式NVIDIA-GPU状态
    在 VirtualBox 虚拟机上搭建 openEuler 并安装桌面环境 详细教程
    【图像边缘检测】基于matlab自适应阈值的八方向和四方向sobel图像边缘检测【含Matlab源码 2058期】
  • 原文地址:https://blog.csdn.net/qq_41828351/article/details/90899153