• pytorch 实现一个最简单的 GAN:用mnist数据集生成新图像



    用 pytorch 实现一个最简单的GAN:用mnist数据集生成新图像

    一、代码

    训练细节见代码注释:

    # @Time    : 2022/9/25
    # @Function: 用pytorch实现一个最简单的GAN,用MNIST数据集生成新图片
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    import torchvision.datasets as datasets
    from torch.utils.data import DataLoader
    import torchvision.transforms as transforms
    from torch.utils.tensorboard import SummaryWriter
    
    import os
    import shutil
    from tqdm import tqdm
    
    
    # 判别器,判断一张图片来源于真实数据集的概率,输入0-1之间的数,数值越大表示数据来源于真实数据集的概率越高。
    class Discriminator(nn.Module):
        def __init__(self, img_dim):
            super().__init__()
            self.disc = nn.Sequential(
                nn.Linear(in_features=img_dim, out_features=128),
                nn.LeakyReLU(0.01),
                nn.Linear(128, 1),
                nn.Sigmoid(),  # 将输出值映射到0-1之间
            )
    
        def forward(self, x):
            return self.disc(x)
    
    
    # 生成器,用随机噪声生成图片
    class Generator(nn.Module):
        def __init__(self, noise_dim, img_dim):
            super().__init__()
            self.gen = nn.Sequential(
                nn.Linear(noise_dim, 256),
                nn.LeakyReLU(0.01),
                nn.Linear(256, img_dim),
                nn.Tanh(),
                # normalize inputs to [-1, 1] so make outputs [-1, 1]
                # 一般二分类问题中,隐藏层用Tanh函数,输出层用Sigmod函数
            )
    
        def forward(self, x):
            return self.gen(x)
    
    
    if __name__ == '__main__':
        device = "cuda" if torch.cuda.is_available() else "cpu"
        lr = 3e-4
        noise_dim = 50  # noise
        image_dim = 28 * 28 * 1  # 784
        batch_size = 32
        num_epochs = 200
    
        # dataset
        transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5), (0.5))])
        dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        fixed_noise = torch.randn((batch_size, noise_dim)).to(device)
    
        D = Discriminator(image_dim).to(device)
        G = Generator(noise_dim, image_dim).to(device)
        opt_disc = optim.Adam(D.parameters(), lr=lr)
        opt_gen = optim.Adam(G.parameters(), lr=lr)
        criterion = nn.BCELoss()     # 二分类交叉熵损失函数
    
        # 存放log的文件夹
        log_dir = "log"
        if (os.path.exists(log_dir)):
            shutil.rmtree(log_dir)
        writer = SummaryWriter(log_dir)
    
        for epoch in tqdm(range(num_epochs), desc='epochs'):
            # GAN不需要真实label
            for batch_idx, (img, _) in enumerate(loader):
                img = img.view(-1, 784).to(device)
                batch_size = img.shape[0]
    
                # 训练判别器: max log(D(x)) + log(1 - D(G(z)))
                noise = torch.randn(batch_size, noise_dim).to(device)
                fake_img = G(noise)    # 根据随机噪声生成虚假数据
                disc_fake = D(fake_img)    # 判别器判断生成数据为真的概率
                # torch.zeros_like(x) 表示生成与 x 形状相同、元素全为0的张量
                lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))    # 虚假数据与0计算损失
                disc_real = D(img)    # 判别器判断真实数据为真的概率
                lossD_real = criterion(disc_real, torch.ones_like(disc_real))     # 真实数据与1计算损失
                lossD = (lossD_real + lossD_fake) / 2
    
                D.zero_grad()
                lossD.backward(retain_graph=True)
                opt_disc.step()
    
                # 训练生成器: 在此过程中将判别器固定,min log(1 - D(G(z))) <-> max log(D(G(z))
                output = D(fake_img)
                lossG = criterion(output, torch.ones_like(output))
                G.zero_grad()
                lossG.backward()
                opt_gen.step()
    
                if batch_idx == 0:
                    # print( f"Epoch [{epoch+1}/{num_epochs}]  Batch {batch_idx}/{len(loader)}   lossD = {lossD:.4f}, lossG = {lossG:.4f}")
                    with torch.no_grad():
                        # 用固定的噪声数据生成图像,以对比经过不同epoch训练后的生成器的生成能力
                        fake_img = G(fixed_noise).reshape(-1, 1, 28, 28)
                        real_img = img.reshape(-1, 1, 28, 28)
    
                        # make_grid的作用是将若干幅图像拼成一幅图像
                        img_grid_fake = torchvision.utils.make_grid(fake_img, normalize=True)
                        img_grid_real = torchvision.utils.make_grid(real_img, normalize=True)
    
                        writer.add_image("Fake Images", img_grid_fake, global_step=epoch)
                        writer.add_image("Real Images", img_grid_real, global_step=epoch)
                        writer.add_scalar(tag="lossD", scalar_value=lossD, global_step=epoch)
                        writer.add_scalar(tag="lossG", scalar_value=lossG, global_step=epoch)
    

    二、生成结果

    2.1 loss的变化

    使用 tensorboard可视化,生成器和判别器的loss变化如下:
    在这里插入图片描述
    这里训练了200个epoch,每个epoch保存了一次loss。按照之前每个batch保存一次loss的结果来看,在训练100个epoch左右时,生成器和判别器的loss达到平衡,可以视为收敛,之后模型过拟合了。

    2.2 生成的虚假图像的变化

    使用相同的噪声生成图像,以观测经过不同epoch训练后的生成器的生成能力(以假乱真能力):

    epoch=3:

    在这里插入图片描述
    epoch=20:

    在这里插入图片描述
    epoch=53:

    在这里插入图片描述
    epoch=141:
    在这里插入图片描述
    epoch=199:

    在这里插入图片描述

    三、不足之处

    程序还有很多不足之处:

    (1)程序实现的是最早的GAN版本,生成器是一个MLP(多层感知机)而不是神经网络,因此特征提取和生成能力较差。

    (2)图像的生成效果与超参数设置有很大关系,如学习率的设置(包括学习率的演化策略)、训练次数、随机噪声的维度,甚至数据集的归一化参数(transforms.Normalize((0.5), (0.5)))都会对生成效果产生一定影响。

    (3)理论上损失函数只要能够适用于二分类即可,如MSE,但一般使用BCE。有一种观点认为BCE的形式与GAN的理论代价函数是一致的,二者可以互推,可以参考 GAN网络概述及LOSS函数详解

  • 相关阅读:
    【基于FreeRTOS的STM32F103系统】内存管理及任务调度
    浅析 em 和 rem
    时间,空间复杂度讲解——夯实根基
    Java版工程行业管理系统源码-专业的工程管理软件- 工程项目各模块及其功能点清单
    一百八十三、大数据离线数仓完整流程——步骤二、在Hive的ODS层建外部表并加载HDFS中的数据
    go语言Array 与 Slice
    实战指南:使用 xUnit.DependencyInjection 在单元测试中实现依赖注入【完整教程】
    北京ib国际学校大盘点
    基础算法一:大整数模积运算
    【电压质量】提高隔离电源系统的电压质量(Simulink实现)
  • 原文地址:https://blog.csdn.net/qq_43799400/article/details/127043011