!pip install -q diffusers

# 導入所需要的包
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using device: {device}')
# 輸出
Using device: cuda

上述代碼會輸出當前的運行環境是GPU還是CPU。

二、加載MNIST數據集

MNIST數據集是一個小型數據集,存儲的是0-9的手寫數字圖像,每張圖片都是28×28的灰度圖像;經過ToTensor()轉換後,每個像素的取值範圍是[0,1]。下面加載該數據集,並展示部分數據:

# Download (if needed) the MNIST training split. ToTensor converts every
# image to a float tensor with pixel values in [0, 1].
dataset = torchvision.datasets.MNIST(
    root="mnist/",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# A small shuffled loader, used here only to peek at one batch of 8 digits.
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))

print('Input shape:', x.shape)
print('Labels:', y)

# Tile the batch into a single image; plot only its first channel.
grid = torchvision.utils.make_grid(x)
plt.imshow(grid[0], cmap='Greys');
# 輸出
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([7, 8, 4, 2, 3, 6, 0, 2])

三、擴(kuò)散模型的退化過(guò)程

所謂退化過程,其實就是對輸入數據加入噪聲的過程。由於MNIST數據集的像素範圍在[0,1],我們加入的噪聲也應保持在相同的範圍內,這樣就可以很容易地把輸入數據與噪聲按比例混合:設噪聲量為amount(取值在[0,1]之間),加噪結果為 x*(1-amount) + noise*amount,其值仍落在[0,1]內。代碼如下:

def corrupt(x, amount):
    """Corrupt the input x by mixing it with noise according to amount.

    Each sample is linearly interpolated towards uniform noise::

        out = x * (1 - amount) + noise * amount

    where ``noise`` comes from ``torch.rand_like(x)`` (uniform on [0, 1),
    same shape, dtype and device as ``x``). An amount of 0 leaves a sample
    unchanged; an amount of 1 replaces it entirely with noise.

    Args:
        x: Batched tensor whose first dimension indexes samples, e.g. a
            ``(batch, channels, height, width)`` image batch. Any rank works.
        amount: Noise amount per sample, i.e. one entry per sample of ``x``
            (any shape that flattens to ``batch`` elements), or a single
            Python / 0-d tensor scalar applied to every sample. It is moved
            to ``x.device`` if it lives elsewhere.

    Returns:
        A tensor with the same shape as ``x``.
    """
    noise = torch.rand_like(x)
    amount = torch.as_tensor(amount, device=x.device)
    # Shape the amounts as (batch, 1, 1, ...) so each sample's amount
    # broadcasts over all of its non-batch dimensions, whatever x's rank.
    # reshape (unlike view) also accepts non-contiguous inputs.
    amount = amount.reshape(-1, *([1] * (x.dim() - 1)))
    return x * (1 - amount) + noise * amount

接下來,我們看一下逐步加噪的效果,代碼如下:

# Two stacked panels: the clean batch on top, a corrupted copy below.
fig, axs = plt.subplots(2, 1, figsize=(12, 5))

# One noise amount per image, evenly spaced from 0 (clean) to 1 (pure noise),
# so the corruption grows from left to right.
amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)

panels = (
    ('Input data', x),
    ('Corrupted data (-- amount increases -->)', noised_x),
)
for ax, (title, images) in zip(axs, panels):
    ax.set_title(title)
    ax.imshow(torchvision.utils.make_grid(images)[0], cmap='Greys')

從上圖可以看出,從左到右加入的噪聲逐步增多;當噪聲量接近1時,數據看起來就像純粹的隨機噪聲。

四、構建一個簡單的UNet模型

UNet模型與自編碼器有異曲同工之妙:兩者都是先把輸入逐步下採樣壓縮,再逐步上採樣恢復到原始尺寸;不同的是,UNet還通過跳躍連接(skip connection)把下採樣路徑上的特徵直接傳遞給上採樣路徑。UNet最初是用於完成醫學圖像分割任務的,網絡結構如下所示:

代碼如下:

class BasicUNet(nn.Module):
    """A minimal UNet: three convolutions going down, three coming back up.

    Every convolution is 5x5 with padding 2, so it keeps the spatial size;
    resolution changes only through 2x max-pooling after the first two down
    layers and 2x upsampling before the last two up layers. The outputs of
    the first two down layers are added back as skip connections on the way
    up, so input height and width must be divisible by 4 for the shapes to
    line up (e.g. 28 -> 14 -> 7 -> 14 -> 28).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # (in, out) channel pairs, listed in the order the layers are applied.
        down_channels = [(in_channels, 32), (32, 64), (64, 64)]
        up_channels = [(64, 64), (64, 32), (32, out_channels)]
        self.down_layers = nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size=5, padding=2)
            for c_in, c_out in down_channels
        )
        self.up_layers = nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size=5, padding=2)
            for c_in, c_out in up_channels
        )
        self.act = nn.SiLU()  # applied after every convolution
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        skips = []
        # Encoder: conv + activation. The first two results are saved for the
        # skip connections and then pooled down to half resolution.
        for depth, conv in enumerate(self.down_layers):
            x = self.act(conv(x))
            if depth < 2:
                skips.append(x)
                x = self.downscale(x)
        # Decoder: from the second up layer onward, upsample and add the
        # matching skip (popped in reverse order of saving) before the conv.
        for depth, conv in enumerate(self.up_layers):
            if depth > 0:
                x = self.upscale(x) + skips.pop()
            x = self.act(conv(x))
        return x

我們來檢驗一下模型輸入輸出的shape是否符合預期(輸出應與輸入的shape相同),代碼如下:

# Sanity check: a random batch of eight 1x28x28 images should come back out
# of the network with exactly the same shape.
net = BasicUNet()
x = torch.rand(size=(8, 1, 28, 28))
net(x).shape
# 輸出
torch.Size([8, 1, 28, 28])

再來看一下模型的參數量,代碼如下:

# Total number of scalar parameters across every parameter tensor of the net.
sum(p.numel() for p in net.parameters())
# 輸出
309057

至此,我們已經完成了數據加載和UNet模型的構建;當然,UNet模型的結構也可以有不同的設計。

五、擴散模型訓練

擴散模型應該學習什麼?其實可以有很多不同的目標,比如學習預測所加的噪聲。這裡我們先從一個簡單的例子開始:輸入數據為帶噪聲的MNIST圖像,模型應直接輸出對應的最佳數字預測(即去噪後的乾淨圖像),因此訓練目標是最小化預測值與原始乾淨圖像之間的MSE。訓練代碼如下:

# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 3

# Create the network and move its parameters to the selected device
net = BasicUNet()
net.to(device)

# Our loss function: MSE between the prediction and the clean image
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Keeping a record of the per-batch losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):
    epoch_start = len(losses)  # index of this epoch's first entry in `losses`

    # Labels are unused: the training target is the clean image itself.
    for x, _ in train_dataloader:

        # Get some data and prepare the corrupted version
        x = x.to(device)
        # One noise amount per sample, uniform on [0, 1). Created directly on
        # `device` to avoid a CPU allocation plus a host-to-device copy per batch.
        noise_amount = torch.rand(x.shape[0], device=device)
        noisy_x = corrupt(x, noise_amount)

        # Get the model prediction
        pred = net(noisy_x)

        # Calculate the loss: how close is the output to the true 'clean' x?
        loss = loss_fn(pred, x)

        # Backprop and update the params
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print out the average of the loss values recorded during this epoch
    epoch_losses = losses[epoch_start:]
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);
# 輸出
Finished epoch 0. Average loss for this epoch: 0.024689
Finished epoch 1. Average loss for this epoch: 0.019226
Finished epoch 2. Average loss for this epoch: 0.017939

訓練過程的loss曲線如下圖所示:

六、擴散模型效果評估

我們選取一部分數據(這裡直接取自訓練集)來評估一下模型的預測效果,代碼如下:

#@markdown Visualizing model predictions on noisy inputs:

# Fetch a batch from the training loader and keep only the first 8 images
# so they fit in one row of each plot.
x, y = next(iter(train_dataloader))
x = x[:8]

# Corrupt with a range of amounts: left to right -> more corruption,
# from 0 (clean input) to 1 (pure noise).
amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)

# Get the model predictions. Under no_grad no autograd graph is recorded,
# so the output can be moved back to the CPU without a separate detach().
with torch.no_grad():
    preds = net(noised_x.to(device)).cpu()

# Plot inputs, corrupted inputs and predictions, clipped to [0, 1] for display.
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');

從上圖可以看出,對於噪聲量較低的輸入,模型的預測效果很不錯;但隨著噪聲量增加,輸入中可供模型利用的信息越來越少。到amount=1時,模型的輸出只是一團接近整個數據集均值的模糊圖像,因為模型「猜不出」原圖時只能給出一個折中的預測。這也說明,單步預測無法直接從純噪聲生成清晰的數字,因此擴散模型需要從噪聲出發、分多步逐步去噪(詳見下一篇的採樣過程)。

Note:我們的訓練並不太充分,讀者可以嘗試不同的超參數(例如增加訓練輪數n_epochs、調整學習率lr或batch_size)來優化模型。

文章轉自微信公眾號@ArronAI

上一篇:

擴散模型實戰(三):擴散模型的應用

下一篇:

擴散模型實戰(五):採樣過程
#你可能也喜歡這些API文章!

我們有何不同?

API服務商零註冊

多API並行試用

數據驅動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內容創意新穎性、情感共鳴力、商業轉化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費