import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Data preprocessing and loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # scale pixels to [-1, 1] to match the generator's Tanh output
])

# Load the MNIST training set
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True
)

# Create the data loader
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True
)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Generator: a stack of simple fully connected layers
        self.gen = nn.Sequential(
            # Input is random noise of size latent_dim (100)
            nn.Linear(100, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            # Output layer: a 28*28 = 784-dimensional image, squashed to [-1, 1]
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Discriminator: fully connected layers with dropout for regularization
        self.disc = nn.Sequential(
            # Input is a flattened image (784-dimensional)
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            # Output a single probability: real vs. fake
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)
# Initialize the models
generator = Generator()
discriminator = Discriminator()

# Loss function and optimizers
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training hyperparameters
num_epochs = 100
latent_dim = 100
fixed_noise = torch.randn(16, latent_dim)  # fixed noise for visualizing progress
# Training loop
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        batch_size = real_images.size(0)

        # Labels for real and fake samples
        real_label = torch.ones(batch_size, 1)
        fake_label = torch.zeros(batch_size, 1)

        # Flatten images to 784-dimensional vectors
        real_images = real_images.view(-1, 784)

        # Train the discriminator
        d_optimizer.zero_grad()
        output_real = discriminator(real_images)
        d_loss_real = criterion(output_real, real_label)

        # Generate fake images; detach so this pass does not update the generator
        noise = torch.randn(batch_size, latent_dim)
        fake_images = generator(noise)
        output_fake = discriminator(fake_images.detach())
        d_loss_fake = criterion(output_fake, fake_label)

        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # Train the generator: try to fool the discriminator into predicting "real"
        g_optimizer.zero_grad()
        output_fake = discriminator(fake_images)
        g_loss = criterion(output_fake, real_label)
        g_loss.backward()
        g_optimizer.step()

        if i % 100 == 0:
            print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(train_loader)}], '
                  f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
def show_images(images):
    plt.figure(figsize=(4, 4))
    plt.axis("off")
    plt.imshow(np.transpose(torchvision.utils.make_grid(
        images.reshape(-1, 1, 28, 28), nrow=4, padding=2, normalize=True
    ).cpu(), (1, 2, 0)))
    plt.show()

# Generate sample images from the fixed noise
generator.eval()  # switch BatchNorm to inference mode
with torch.no_grad():
    fake_images = generator(fixed_noise)
show_images(fake_images)
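If you also want to save the generated grid to disk rather than only display it, torchvision's save_image utility handles the same reshaping and normalization. This is a minimal sketch; the output filename is an illustrative choice, not from the original code:

# Save the same 4x4 grid of generated digits to a PNG file
torchvision.utils.save_image(
    fake_images.reshape(-1, 1, 28, 28),
    'gan_samples.png',  # illustrative output path
    nrow=4, padding=2, normalize=True
)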
Tips for improving results:
* Use label smoothing (see the sketch after this list)
* Add gradient clipping (also shown below)
* Try different activation functions
* Increase the number of layers and channels
* Use a more sophisticated architecture (e.g., convolutional layers)
* Tune the learning rate and batch size
* Use a Wasserstein GAN
* Apply batch normalization
* Balance how often the discriminator and the generator are trained
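As a minimal sketch of the first two tips, here is how the discriminator step from the training loop above could be modified with one-sided label smoothing and gradient-norm clipping. The 0.9 smoothing target and max_norm=1.0 are illustrative values, not part of the original code:

# One-sided label smoothing: target 0.9 instead of 1.0 for real samples,
# which keeps the discriminator from becoming overconfident
real_label = torch.full((batch_size, 1), 0.9)  # illustrative smoothing value
fake_label = torch.zeros(batch_size, 1)

d_optimizer.zero_grad()
d_loss_real = criterion(discriminator(real_images), real_label)
d_loss_fake = criterion(discriminator(fake_images.detach()), fake_label)
d_loss = d_loss_real + d_loss_fake
d_loss.backward()

# Gradient clipping: cap the gradient norm before the optimizer update
torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
d_optimizer.step()

The same clip_grad_norm_ call can be applied to generator.parameters() right before g_optimizer.step().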
This GAN implementation is simple, but it can already produce decent results. Remember that tuning is the key to better output: adjust the hyperparameters patiently, and you will be able to generate much more realistic images!
Article reposted from the WeChat official account @寒江映孤月.