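我有大量二维矩阵作为训练样本，矩阵中的数值为连续数据。由于数据具有空间连续性，因此使用卷积网络，通过 DCGAN 生成二维矩阵。因为是连续变量，所以损失函数采用 nn.MSELoss()。（注意：在 GAN 中，该损失计算的是判别器输出与真/假标签之间的误差，而不是直接比较生成数据与真实数据；用 MSE 代替 BCE 即最小二乘 GAN（LSGAN）的做法。）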
# Third-party dependencies.
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.tensorboard import SummaryWriter

# Project-local data-loader factory.
from DemDataset import create_netCDF_Dem_trainLoader

# Number of samples per training batch.
batch_size = 16

# Build the training data loader.
# NOTE(review): the training loop moves each yielded item straight to the
# device and feeds it to a 1-channel Conv2d, so this presumably yields float
# tensors shaped (N, 1, H, W). The generator ends in Tanh, so the data should
# also be scaled into [-1, 1] — confirm against DemDataset.
dataloader = create_netCDF_Dem_trainLoader(batch_size)


# Generator with Conv2D structure
class Generator(nn.Module):
    """DCGAN-style generator: latent noise -> single-channel 2-D map.

    Each of the seven ``ConvTranspose2d(kernel_size=4, stride=2, padding=1)``
    layers doubles the spatial size, so an input of shape ``(N, 100, 1, 1)``
    becomes an output of shape ``(N, 1, 128, 128)``. Hidden layers use
    BatchNorm + ReLU; the final ``Tanh`` bounds every value to ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()
        # Channel widths along the hidden up-sampling stages.
        widths = (100, 512, 512, 256, 128, 64, 32)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend((
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ))
        # Output stage: collapse to one channel and squash into [-1, 1].
        layers.extend((
            nn.ConvTranspose2d(widths[-1], 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ))
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent noise ``z`` of shape ``(N, 100, 1, 1)`` to generated maps."""
        return self.model(z)


# Discriminator with Conv2D structure
class Discriminator(nn.Module):
    """Convolutional discriminator scoring single-channel 2-D maps.

    Each of the seven ``Conv2d(kernel_size=4, stride=2, padding=1)`` layers
    halves the spatial size, so a ``(N, 1, 128, 128)`` input is reduced to a
    ``(N, 1, 1, 1)`` score. Hidden layers use ``LeakyReLU(0.2)``; there is no
    final sigmoid, so the score is an unbounded real number (the training
    loop regresses it toward 0/1 targets with ``nn.MSELoss``).
    """

    def __init__(self):
        super().__init__()
        widths = [1, 32, 64, 128, 256, 512, 512]
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.LeakyReLU(0.2))
        # Final projection to a single score channel, with no activation.
        stages.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=1))
        self.model = nn.Sequential(*stages)

    def forward(self, img):
        """Return the raw (un-squashed) score map for ``img``."""
        return self.model(img)


# Select the compute device up front.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Initialize GAN components and place them on the selected device.
# Module.to() moves a module in place, so the optimizers created below still
# reference the same parameters.
generator = Generator()
discriminator = Discriminator()
generator.to(device)
discriminator.to(device)

# Loss and optimizers: MSE on the discriminator's raw scores, and one Adam
# optimizer per network with identical hyper-parameters.
criterion = nn.MSELoss()
optimizer_G = optim.Adam(
    generator.parameters(), lr=0.0002, betas=(0.5, 0.999)
)
optimizer_D = optim.Adam(
    discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)
)

# One TensorBoard writer per image stream, plus the global logging step.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0

# Training loop.
# Least-squares GAN updates: the discriminator regresses its raw score toward
# 1 for real maps and 0 for generated ones; the generator is then updated so
# that the discriminator scores its samples as 1.
num_epochs = 200
for epoch in range(num_epochs):
    for batch_idx, real_data in enumerate(dataloader):
        # NOTE(review): assumes each batch is a float tensor shaped
        # (N, 1, H, W), already scaled to the generator's Tanh range [-1, 1]
        # — confirm against DemDataset.
        real_data = real_data.to(device)

        # ---- Train the discriminator ----
        optimizer_D.zero_grad()
        z = torch.randn(real_data.size(0), 100, 1, 1, device=device)
        fake_data = generator(z)
        real_pred = discriminator(real_data)
        # detach() keeps this step from back-propagating into the generator.
        fake_pred = discriminator(fake_data.detach())
        # Build targets with exactly the prediction's shape. The discriminator
        # returns (N, 1, 1, 1) for 128x128 inputs; an (N, 1) target would be
        # broadcast against it (MSELoss warns, and for other input sizes the
        # broadcast can fail outright).
        real_labels = torch.ones_like(real_pred)
        fake_labels = torch.zeros_like(fake_pred)
        d_loss_real = criterion(real_pred, real_labels)
        d_loss_fake = criterion(fake_pred, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # ---- Train the generator ----
        optimizer_G.zero_grad()
        z = torch.randn(real_data.size(0), 100, 1, 1, device=device)
        fake_data = generator(z)
        fake_pred = discriminator(fake_data)
        # The generator is rewarded when its samples are scored as real (1).
        g_loss = criterion(fake_pred, torch.ones_like(fake_pred))
        g_loss.backward()
        optimizer_G.step()

        # ---- Progress report and TensorBoard images ----
        if batch_idx % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
            with torch.no_grad():
                # normalize=True min-max rescales each grid into [0, 1], the
                # range add_image expects for float images; generator output
                # lies in [-1, 1] and the real data range is not known here.
                img_grid_real = torchvision.utils.make_grid(real_data, normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_data, normalize=True)
            writer_real.add_image("real_img", img_grid_real, global_step=step)
            writer_fake.add_image("fake_img", img_grid_fake, global_step=step)
            step += 1

# Flush pending TensorBoard events to disk once training is done.
writer_real.close()
writer_fake.close()

# After training, sample a single 2-D map from the generator.
# eval() makes BatchNorm normalise with its running statistics rather than
# the statistics of this one-sample batch (which would also update the
# running statistics); no_grad() avoids building an unused autograd graph.
generator.eval()
with torch.no_grad():
    z = torch.randn(1, 100, 1, 1, device=device)
    generated_array = generator(z)
# generated_array is a (1, 1, H, W) tensor on `device`; use
# generated_array.squeeze().cpu().numpy() for a plain 2-D NumPy array.