有大量二维矩阵作为样本，均为连续数据。数据具有空间连续性，因此采用卷积网络，通过 DCGAN 生成二维矩阵。因为是连续变量，所以损失函数采用 nn.MSELoss()（该损失作用于判别器输出与真/假标签 1/0 之间，相当于最小二乘 GAN 的目标函数）。文章来源地址：https://www.toymoban.com/news/detail-722877.html
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from DemDataset import create_netCDF_Dem_trainLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter
# Mini-batch size shared by the data loader and the training loop below.
batch_size = 16

# Build the training DataLoader (project-local helper from DemDataset).
dataloader = create_netCDF_Dem_trainLoader(batch_size)
# Generator built from transposed convolutions (ConvTranspose2d).
class Generator(nn.Module):
    """DCGAN-style generator producing single-channel maps in [-1, 1].

    Every stage is ConvTranspose2d(kernel_size=4, stride=2, padding=1), which
    doubles H and W; seven stages turn a (N, 100, 1, 1) latent tensor into a
    (N, 1, 128, 128) output. The final Tanh bounds values to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Channel widths of the upsampling stages that are followed by
        # BatchNorm + ReLU; the last stage (32 -> 1) is followed by Tanh.
        widths = [100, 512, 512, 256, 128, 64, 32]
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
        layers.append(
            nn.ConvTranspose2d(widths[-1], 1, kernel_size=4, stride=2, padding=1)
        )
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent tensor ``z`` (N, 100, H, W) to (N, 1, 128*H, 128*W)."""
        return self.model(z)
# Discriminator with Conv2D structure
class Discriminator(nn.Module):
    """Convolutional discriminator scoring single-channel 2D maps.

    Every Conv2d(kernel_size=4, stride=2, padding=1) halves H and W, so a
    (N, 1, 128, 128) input yields a (N, 1, 1, 1) tensor of raw, unbounded
    scores (there is no final sigmoid).
    """

    def __init__(self):
        super().__init__()
        # Channel widths of the downsampling stages followed by LeakyReLU(0.2);
        # a final 512 -> 1 convolution produces the score.
        widths = (1, 32, 64, 128, 256, 512, 512)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2),
            ]
        stages.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=1))
        self.model = nn.Sequential(*stages)

    def forward(self, img):
        """Return the raw score tensor for the batch ``img``."""
        return self.model(img)
# Initialize GAN components
generator = Generator()
discriminator = Discriminator()

# Define loss function and optimizers.
# MSE between the discriminator's raw scores and 1 (real) / 0 (fake) targets
# is the least-squares GAN objective; it is applied to D's outputs, not to
# the data values themselves.
criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# Separate TensorBoard runs so real and generated samples can be compared.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0  # TensorBoard global step, advanced once per logged batch

# Training loop
num_epochs = 200
for epoch in range(num_epochs):
    for batch_idx, real_data in enumerate(dataloader):
        # assumes the loader yields bare tensors shaped (N, 1, 128, 128)
        # to match the generator's output -- TODO confirm against DemDataset
        real_data = real_data.to(device)
        n = real_data.size(0)  # last batch may be smaller than batch_size

        # ---- Train Discriminator ----
        optimizer_D.zero_grad()
        z = torch.randn(n, 100, 1, 1, device=device)
        fake_data = generator(z)
        real_pred = discriminator(real_data)
        # detach(): the D update must not backpropagate into the generator.
        fake_pred = discriminator(fake_data.detach())
        # Targets are built with *_like so they match D's output shape
        # exactly ((N, 1, 1, 1) for 128x128 inputs). (N, 1) labels would make
        # MSELoss broadcast to (N, 1, N, 1) and emit a size-mismatch warning.
        d_loss_real = criterion(real_pred, torch.ones_like(real_pred))
        d_loss_fake = criterion(fake_pred, torch.zeros_like(fake_pred))
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        # Stray gradients this puts on D's parameters are cleared by
        # optimizer_D.zero_grad() at the start of the next iteration.
        optimizer_G.zero_grad()
        z = torch.randn(n, 100, 1, 1, device=device)
        fake_data = generator(z)
        fake_pred = discriminator(fake_data)
        # The generator is rewarded when D scores its samples as real (1).
        g_loss = criterion(fake_pred, torch.ones_like(fake_pred))
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if batch_idx % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
            with torch.no_grad():
                # Each grid is built from its own tensor and sent to its own
                # writer. normalize=True rescales to [0, 1], which add_image
                # expects; tanh outputs in [-1, 1] would otherwise be clipped.
                img_grid_real = torchvision.utils.make_grid(real_data, normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_data, normalize=True)
                writer_fake.add_image("fake_img", img_grid_fake, global_step=step)
                writer_real.add_image("real_img", img_grid_real, global_step=step)
            step += 1

# Flush pending TensorBoard events and release the log files.
writer_real.close()
writer_fake.close()

# After training, sample one array from the generator: a (1, 1, 128, 128)
# tensor (squeeze it for a plain 2D array). eval() makes BatchNorm use its
# running statistics instead of those of this single sample, and no_grad()
# skips building an autograd graph.
generator.eval()
with torch.no_grad():
    z = torch.randn(1, 100, 1, 1, device=device)
    generated_array = generator(z)
文章来源：https://www.toymoban.com/news/detail-722877.html
到这里，关于用来生成二维矩阵的 DCGAN 的文章就介绍完了。如果您还想了解更多内容，请在右上角搜索 TOY 模板网以前的文章或继续浏览下面的相关文章，希望大家以后多多支持 TOY 模板网！