• GAN生成哆啦A梦,亲测疯狂训练50000epoch的效果,俺的菜菜电脑吃不消



    最近闲着没事学了一下GAN网络,感觉这个东西挺有趣的,所以就打算自己进行动手实践一下这个,满足自己的好奇。最后做出来了,所以就打算跟大家分享一下这个东西,其中的原理就不用跟大家说了,因为现在其他的博客介绍原理都很全面,大家可以去看一下其他博客看一下GAN的原理,顺便推一下公式,我这里就放上我参照其他多位博主的博客进行融合敲出来的代码吧。

    废话少说,先摆上我的训练结果,证明俺的代码是可以运行的叭,防止大家觉得我的代码是不能运行的......

    编译环境:

    python3.9.7 其实我觉得就算是python3.6都能运行

        

     挺可爱的嗷!!!哈哈


    第一步:

    我的训练图片是从百度上面拿的,拿来的图片格式是.webp格式的图片
    然后我进行训练的时候使用的图片分辨率是64X64,怎么说呢,就是尽量使用64X64的分辨率的图片进行训练叭,因为这样可以让你的电脑能够运行的起来,我刚开始进行训练的是医学的数据938X636结果就在电脑报错了:**RuntimeError: CUDA out of memory. Tried to allocate 6.71 GiB (GPU 0; 15.78 GiB total capacity; 13.49 GiB already allocated; 984.75 MiB free; 13.51 GiB reserved in total by PyTorch)**因为电脑的配置跟不上呀
     所以最后我对收集到的数据进行分辨率的修改,编程64X64就能轻轻松松地进行训练了
     

    先给大伙看看我整体的文件格式,因为路径那些都是配套,按照我这样才能进行运行,这个对小白比较友好,不然就自己进行修改代码了

                                                                         

    解释一下:
    # gan_A是总的文件夹 
    # images是放置网页进行获取的图片,我拿的图片格式是.webp格式的图片,这个很重要,因为后面进行图片分辨率的转化的时候,我的代码就是针对这个格式的图片的,如果不是这种格式的话,在我的代码image_tool.py那里需要进行修改一些东西。
     # img # 是每一万轮就对训练G产生的图片进行保存 
    # result是对webp格式的图片进行转化后保存的一个文件夹,也就是gan_1进行数据的训练的数据
     # saved_models是每隔一万轮对G的模型进行保存 

    image_tool.py代码如下:

    1. # -*- coding:utf-8 -*-
    2. # @Time : 2022-04-18 15:07
    3. # @Author : DaFuChen
    4. # @File : image_tool.py
    5. # @software: PyCharm
    6. # 进行分辨率的重新修改 进行值的变化修改
    7. # 导入需要的模块
    8. from glob import glob
    9. from PIL import Image
    10. import os
    11. # 图片路径
    12. # 使用 glob模块 获得文件夹内所有jpg图像
    13. img_path = glob("./images/*.webp")
    14. # img_path = glob("./images/*.jpg")
    15. # img_path = glob("./images/*.jpeg")
    16. # 存储(输出)路径
    17. path_save = "./result"
    18. for i, file in enumerate(img_path):
    19. name = os.path.join(path_save, "%d.jpg" % i)
    20. im = Image.open(file)
    21. # im.thumbnail((720,1280))
    22. reim = im.resize((64, 64))
    23. print(im.format, reim.size, reim.mode)
    24. reim.save(name, im.format)

    gan_1.py代码如下:

    1. # -*- coding:utf-8 -*-
    2. # @Time : 2022-04-18 15:05
    3. # @Author : DaFuChen
    4. # @File : gan_1.py
    5. # @software: PyCharm
    6. import argparse
    7. import os
    8. import numpy as np
    9. import math
    10. import torchvision.transforms as transforms
    11. from torchvision.utils import save_image
    12. from torch.utils.data import DataLoader, Dataset
    13. from torchvision import datasets
    14. from torch.autograd import Variable
    15. from PIL import Image
    16. import torch.nn as nn
    17. import torch.nn.functional as F
    18. import torch
    19. # 输出图片保存路径 没有就会自动进行创建
    20. os.makedirs("img", exist_ok=True)
    21. # 参数设置
    22. parser = argparse.ArgumentParser()
    23. parser.add_argument("--n_epochs", type=int, default=1000001, help="number of epochs of training")
    24. parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
    25. parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    26. parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    27. parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    28. parser.add_argument("--gpu", type=int, default=0, help="number of cpu threads to use during batch generation")
    29. # 输入噪声向量维度,默认100
    30. parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    31. # [enforce fail at .\c10 core\CPUAllocator.cpp:79]DefaultCPUAllocator:内存不足:您试图分配7203840000字节。 改改网络结构吧 tm
    32. # 输入图片维度,默认64*64*3 但是进行修改之后两个值 938 625 但是这个设计的网络结构不稳定 很难搞 这个是撑不住了
    33. parser.add_argument("--img_size1", type=int, default=64, help="size of each image dimension")
    34. parser.add_argument("--img_size2", type=int, default=64, help="size of each image dimension")
    35. parser.add_argument("--channels", type=int, default=3, help="number of image channels")
    36. # 每隔一个sample_interval的批次进行一次图片的保存
    37. parser.add_argument("--sample_interval", type=int, default=10000, help="interval betwen image samples")
    38. # 其实是创建了一个对象 之后可以调用它其中的参数值 使用了它的这一个类里面刻画的一些属性
    39. opt = parser.parse_args()
    40. print(opt)
    41. # 图像的分辨率值
    42. img_shape = (opt.channels, opt.img_size1, opt.img_size2)
    43. cuda = True if torch.cuda.is_available() else False
    44. class Generator(nn.Module):
    45. def __init__(self):
    46. super(Generator, self).__init__()
    47. def block(in_feat, out_feat, normalize=True):
    48. layers = [nn.Linear(in_feat, out_feat)]
    49. if normalize:
    50. layers.append(nn.BatchNorm1d(out_feat, 0.8))
    51. layers.append(nn.LeakyReLU(0.2, inplace=True))
    52. return layers
    53. self.model = nn.Sequential(
    54. *block(opt.latent_dim, 128, normalize=False),
    55. *block(128, 256),
    56. *block(256, 512),
    57. *block(512, 1024),
    58. nn.Linear(1024, int(np.prod(img_shape))),
    59. nn.Tanh()
    60. )
    61. def forward(self, z):
    62. img = self.model(z)
    63. img = img.view(img.size(0), *img_shape)
    64. return img
    65. class Discriminator(nn.Module):
    66. def __init__(self):
    67. super(Discriminator, self).__init__()
    68. self.model = nn.Sequential(
    69. # 基本的一个操作 进行线性化的拟合 然后再relu一下取个好的效果
    70. nn.Linear(int(np.prod(img_shape)), 512),
    71. nn.LeakyReLU(0.2, inplace=True),
    72. nn.Linear(512, 256),
    73. nn.LeakyReLU(0.2, inplace=True),
    74. nn.Linear(256, 1),
    75. nn.Sigmoid(),
    76. )
    77. def forward(self, img):
    78. img_flat = img.view(img.size(0), -1)
    79. validity = self.model(img_flat)
    80. return validity
    81. # Loss function
    82. adversarial_loss = torch.nn.BCELoss()
    83. # Initialize generator and discriminator
    84. generator = Generator()
    85. discriminator = Discriminator()
    86. # 将属性放进GPU进行训练
    87. if cuda:
    88. generator.cuda()
    89. discriminator.cuda()
    90. adversarial_loss.cuda()
    91. # Configure data loader
    92. img_transform = transforms.Compose([
    93. # transforms.ToPILImage(),
    94. transforms.ToTensor(),
    95. transforms.Normalize((0.5,), (0.5,)) # (x-mean) / std
    96. ])
    97. class MyData(Dataset): # 继承Dataset
    98. def __init__(self, root_dir, transform=None): # __init__是初始化该类的一些基础参数
    99. self.root_dir = root_dir # 文件目录
    100. self.transform = transform # 变换
    101. self.images = os.listdir(self.root_dir) # 目录里的所有文件
    102. def __len__(self): # 返回整个数据集的大小
    103. return len(self.images)
    104. def __getitem__(self, index): # 根据索引index返回dataset[index]
    105. image_index = self.images[index] # 根据索引index获取该图片
    106. img_path = os.path.join(self.root_dir, image_index) # 获取索引为index的图片的路径名
    107. img = Image.open(img_path) # 读取该图片
    108. if self.transform:
    109. img = self.transform(img)
    110. return img # 返回该样本
    111. # 输入图片所在文件夹
    112. mydataset = MyData(
    113. root_dir='./result/', transform=img_transform
    114. )
    115. # data loader 数据载入
    116. dataloader = DataLoader(
    117. dataset=mydataset, batch_size=opt.batch_size, shuffle=True
    118. )
    119. # 下面这一块是可以省略
    120. # os.makedirs("./data/MNIST", exist_ok=True)
    121. # dataloader = torch.utils.data.DataLoader(
    122. # datasets.MNIST(
    123. # "./data/MNIST",
    124. # train=True,
    125. # download=True,
    126. # transform=transforms.Compose(
    127. # [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    128. # ),
    129. # ),
    130. # batch_size=opt.batch_size,
    131. # shuffle=True,
    132. # )
    133. # Optimizers
    134. optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    135. optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    136. Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    137. # ----------
    138. # Training
    139. # ----------
    140. for epoch in range(opt.n_epochs):
    141. for i, img in enumerate(dataloader):
    142. imgs = img
    143. # Adversarial ground truths
    144. valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
    145. fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)
    146. # Configure input
    147. real_imgs = Variable(imgs.type(Tensor))
    148. # -----------------
    149. # Train Generator
    150. # -----------------
    151. optimizer_G.zero_grad()
    152. # Sample noise as generator input
    153. z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
    154. # Generate a batch of images
    155. gen_imgs = generator(z)
    156. # Loss measures generator's ability to fool the discriminator
    157. g_loss = adversarial_loss(discriminator(gen_imgs), valid)
    158. g_loss.backward()
    159. optimizer_G.step()
    160. # ---------------------
    161. # Train Discriminator
    162. # ---------------------
    163. optimizer_D.zero_grad()
    164. # Measure discriminator's ability to classify real from generated samples
    165. real_loss = adversarial_loss(discriminator(real_imgs), valid)
    166. fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
    167. d_loss = (real_loss + fake_loss) / 2
    168. d_loss.backward()
    169. optimizer_D.step()
    170. print(
    171. "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
    172. % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
    173. )
    174. batches_done = epoch * len(dataloader) + i
    175. if batches_done % opt.sample_interval == 0:
    176. save_image(gen_imgs.data[:25], "./img/%d.png" % batches_done, nrow=5, normalize=True)
    177. torch.save(generator, "saved_models/generator_%d.pth" % epoch)

  • 相关阅读:
    【SQL Server数据库】视图的使用
    【SHELL】推箱子游戏
    使用SSH通过FinalShell远程连接Ubuntu服务器
    应用计量经济学问题~
    【js&vue】联合gtp仿写一个简单的vue框架,以此深度学习JavaScript
    1. 获取数据-requests.get()
    cos 腾讯云直传
    uniapp获取手机号一键登录和退出登录功能
    【云原生】Spring Cloud是什么?Spring Cloud版本介绍
    【附源码】计算机毕业设计SSM铜仁学院毕业就业管理系统
  • 原文地址:https://blog.csdn.net/blockshowtouse/article/details/126549697