• A DCGAN for generating 2D matrices


    The training samples are a large number of 2D matrices of continuous values. Because the data is spatially continuous, convolutional networks are a natural fit, so a DCGAN is used to generate the matrices. Since the values are continuous, the adversarial loss uses nn.MSELoss() on the discriminator scores, which makes this setup effectively a least-squares GAN (LSGAN).
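    Each of the generator's seven ConvTranspose2d layers with kernel_size=4, stride=2, padding=1 doubles the spatial size, since out = (in - 1) * stride - 2 * padding + kernel_size = 2 * in, so the 1x1 latent grows to 128x128 (1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128). A quick standalone sanity check (not part of the original script):

    import torch
    import torch.nn as nn

    # One generator stage: kernel 4, stride 2, padding 1 doubles H and W.
    up = nn.ConvTranspose2d(100, 512, kernel_size=4, stride=2, padding=1)
    z = torch.randn(1, 100, 1, 1)
    print(up(z).shape)  # torch.Size([1, 512, 2, 2])

    # Applying the output-size formula seven times: 1 -> 128
    size = 1
    for _ in range(7):
        size = (size - 1) * 2 - 2 * 1 + 4
    print(size)  # 128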

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from DemDataset import create_netCDF_Dem_trainLoader

    batch_size = 16

    # Load data
    dataloader = create_netCDF_Dem_trainLoader(batch_size)

    # Generator with ConvTranspose2d structure: upsamples a (100, 1, 1) latent
    # vector to a (1, 128, 128) matrix, doubling the spatial size at each layer
    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.model = nn.Sequential(
                nn.ConvTranspose2d(100, 512, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(),
                nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
                nn.Tanh(),  # output in [-1, 1]; real samples should be scaled to match
            )

        def forward(self, z):
            img = self.model(z)
            return img

    # Discriminator with Conv2d structure: downsamples a (1, 128, 128) input
    # to a single score per sample
    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            self.model = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
                nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1),
            )

        def forward(self, img):
            # Flatten the (N, 1, 1, 1) feature map to (N, 1) so it matches
            # the shape of the real/fake label tensors in the loss
            validity = self.model(img).view(-1, 1)
            return validity

    # Initialize GAN components
    generator = Generator()
    discriminator = Discriminator()

    # Define loss function and optimizers (MSE on the scores = least-squares GAN)
    criterion = nn.MSELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)

    writer_real = SummaryWriter("logs/real")
    writer_fake = SummaryWriter("logs/fake")
    step = 0

    # Training loop
    num_epochs = 200
    for epoch in range(num_epochs):
        for batch_idx, real_data in enumerate(dataloader):
            real_data = real_data.to(device)

            # Train Discriminator: push real scores toward 1, fake scores toward 0
            optimizer_D.zero_grad()
            real_labels = torch.ones(real_data.size(0), 1).to(device)
            fake_labels = torch.zeros(real_data.size(0), 1).to(device)
            z = torch.randn(real_data.size(0), 100, 1, 1).to(device)
            fake_data = generator(z)
            real_pred = discriminator(real_data)
            fake_pred = discriminator(fake_data.detach())
            d_loss_real = criterion(real_pred, real_labels)
            d_loss_fake = criterion(fake_pred, fake_labels)
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()

            # Train Generator: push scores on fresh fakes toward 1
            optimizer_G.zero_grad()
            z = torch.randn(real_data.size(0), 100, 1, 1).to(device)
            fake_data = generator(z)
            fake_pred = discriminator(fake_data)
            g_loss = criterion(fake_pred, real_labels)
            g_loss.backward()
            optimizer_G.step()

            # Print progress and log sample grids to TensorBoard
            if batch_idx % 100 == 0:
                print(
                    f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] "
                    f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]"
                )
                with torch.no_grad():
                    img_grid_real = torchvision.utils.make_grid(
                        real_data  # , normalize=True
                    )
                    img_grid_fake = torchvision.utils.make_grid(
                        fake_data  # , normalize=True
                    )
                    writer_fake.add_image("fake_img", img_grid_fake, global_step=step)
                    writer_real.add_image("real_img", img_grid_real, global_step=step)
                    step += 1

    # After training, generate a 2D array by sampling from the generator
    z = torch.randn(1, 100, 1, 1).to(device)
    generated_array = generator(z)
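    The create_netCDF_Dem_trainLoader function comes from a local DemDataset module that the post does not show. Below is a minimal sketch of what such a loader could look like; the netCDF file name ("dem_tiles.nc"), the variable name ("dem"), and tiles already cut to 128x128 are all assumptions for illustration, not the author's actual code. The min-max scaling to [-1, 1] matters because the generator ends in Tanh:

    import numpy as np
    import torch
    from netCDF4 import Dataset
    from torch.utils.data import DataLoader

    def create_netCDF_Dem_trainLoader(batch_size):
        # Hypothetical file/variable names; replace with the real ones.
        ds = Dataset("dem_tiles.nc")
        dem = np.asarray(ds.variables["dem"][:], dtype=np.float32)  # (N, 128, 128)
        ds.close()

        # Scale to [-1, 1] so real samples match the generator's Tanh range.
        lo, hi = dem.min(), dem.max()
        dem = 2.0 * (dem - lo) / (hi - lo) - 1.0

        tiles = torch.from_numpy(dem).unsqueeze(1)  # (N, 1, 128, 128)
        # A plain tensor is a valid map-style dataset, so each batch arrives as
        # a bare tensor, matching how the training loop above consumes it.
        return DataLoader(tiles, batch_size=batch_size, shuffle=True, drop_last=True)

    With this scaling, generated samples can be mapped back to the original units by inverting it, e.g. elevation = (generated_array.squeeze().detach().cpu().numpy() + 1) / 2 * (hi - lo) + lo.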

• Original article: https://blog.csdn.net/yanfeng1022/article/details/133929883