import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Set the random seed so that results are reproducible

torch.manual_seed(42)

第二步:準備數據集

# Preprocessing: convert each image to a tensor, then normalize the single
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST training split, downloaded into ./data when it is not present yet.
train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)

# Shuffled mini-batches of 64 samples for the training loop.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

第三步:構建生成器網絡

class Generator(nn.Module):
    """Fully connected generator that maps a noise vector to a flat image.

    Args:
        latent_dim: Size of the input noise vector. Defaults to 100.
        image_dim: Size of the flattened output image. Defaults to 784
            (28 * 28).
    """

    def __init__(self, latent_dim=100, image_dim=784):
        super().__init__()
        # Three widening Linear -> LeakyReLU -> BatchNorm stages followed by
        # the output projection.
        self.gen = nn.Sequential(
            # Input is random noise of size latent_dim.
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            # Output layer: one value per pixel, squashed to [-1, 1] by Tanh.
            nn.Linear(1024, image_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return generated flat images for a batch of noise vectors ``x``."""
        return self.gen(x)

第四步:構建判別器網絡

class Discriminator(nn.Module):
    """Fully connected discriminator that scores flat images as real or fake.

    Args:
        image_dim: Size of the flattened input image. Defaults to 784
            (28 * 28).
    """

    def __init__(self, image_dim=784):
        super().__init__()
        # Three narrowing Linear -> LeakyReLU -> Dropout stages followed by
        # a single-unit output.
        self.disc = nn.Sequential(
            # Input is a flattened image of size image_dim.
            nn.Linear(image_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            # Output a single probability per sample via Sigmoid.
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one probability in [0, 1] for each flat image in ``x``."""
        return self.disc(x)

第五步:訓練模型

# Instantiate both networks.
generator = Generator()
discriminator = Discriminator()

# Binary cross-entropy loss and one Adam optimizer per network.
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training parameters.
num_epochs = 100
latent_dim = 100
fixed_noise = torch.randn(16, latent_dim)  # reused later for visualization

# Training loop: one discriminator step and one generator step per batch.
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        batch_size = real_images.size(0)

        # Targets are built per batch because the batch size is read from
        # the data rather than assumed.
        real_label = torch.ones(batch_size, 1)
        fake_label = torch.zeros(batch_size, 1)

        # Flatten each image to a 784-element vector.
        real_images = real_images.view(-1, 784)

        # ---- Discriminator step ----
        d_optimizer.zero_grad()
        output_real = discriminator(real_images)
        d_loss_real = criterion(output_real, real_label)

        # Generate fake images; detach() keeps the discriminator loss from
        # back-propagating into the generator.
        noise = torch.randn(batch_size, latent_dim)
        fake_images = generator(noise)
        output_fake = discriminator(fake_images.detach())
        d_loss_fake = criterion(output_fake, fake_label)

        # Total discriminator loss.
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step ----
        # The same fake batch is scored again (not detached) against the
        # "real" target so gradients reach the generator.
        g_optimizer.zero_grad()
        output_fake = discriminator(fake_images)
        g_loss = criterion(output_fake, real_label)
        g_loss.backward()
        g_optimizer.step()

        # Report progress every 100 batches.
        if i % 100 == 0:
            print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(train_loader)}], '
                  f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

第六步:生成和顯示結果

def show_images(images):
    """Display a batch of images as a grid with four images per row.

    Args:
        images: Tensor that can be reshaped to (-1, 1, 28, 28), e.g. a
            batch of flat 784-element generator outputs.
    """
    # detach() lets this work on tensors that still track gradients;
    # cpu() makes the data reachable from NumPy.
    grid = torchvision.utils.make_grid(
        images.detach().reshape(-1, 1, 28, 28), nrow=4, padding=2, normalize=True
    ).cpu()

    plt.figure(figsize=(4, 4))
    plt.axis("off")
    # Move the channel axis last, which is the layout imshow expects.
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    plt.show()

# Generate sample images from the fixed noise vectors.
# Switch to eval mode first so the generator's BatchNorm layers use their
# running statistics instead of the statistics of this small batch.
generator.eval()
with torch.no_grad():
    fake_images = generator(fixed_noise)
    show_images(fake_images)

小貼士:

  1. 增加訓練穩定性:

* 使用標籤平滑化技術

* 添加梯度裁剪

* 嘗試不同的激活函數

  2. 提高生成質量:

* 增加網絡層數和通道數

* 使用更複雜的網絡結構(如卷積層)

* 調整學習率和批次大小

  3. 避免模式崩潰:

* 使用Wasserstein GAN

* 實現批次歸一化

* 適當調整判別器和生成器的訓練比例

這個GAN實現雖然簡單,但已經能產生不錯的效果了。記得調參是提升效果的關鍵,慢慢調整,你一定能生成出更逼真的圖像!

文章轉自微信公眾號@寒江映孤月

上一篇:

時空圖神經網絡ST-GNN的概念以及Pytorch實現

下一篇:

使用Pytorch實現頻譜歸一化生成對抗網絡(SN-GAN)
#你可能也喜歡這些API文章!

我們有何不同?

API服務商零註冊

多API並行試用

數據驅動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內容創意新穎性、情感共鳴力、商業轉化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費