
Building a Diffusion Model from Scratch on MNIST
!pip install -q diffusers
# Import the required packages
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
# Output
Using device: cuda
This prints whether the runtime is using a GPU or the CPU.
The MNIST dataset is a small dataset of handwritten digits 0-9; each image is a 28×28 grayscale image whose pixel values lie in the range [0, 1]. Let's load the dataset and display a few samples:
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
# Output
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([7, 8, 4, 2, 3, 6, 0, 2])
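A quick check (not in the original notebook) confirms the [0, 1] pixel range that the corruption step below relies on:
print('Pixel range:', x.min().item(), 'to', x.max().item())  # Expect roughly 0.0 to 1.0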
The so-called degradation process is simply the process of adding noise to the input data. Since the MNIST pixel values lie in [0, 1], the noise we add should stay in the same range, which makes it easy to mix the input data with the noise. The code is as follows:
def corrupt(x, amount):
    """Corrupt the input x by mixing it with noise according to amount."""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)  # Reshape amount so broadcasting works
    return x * (1 - amount) + noise * amount
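As a quick sanity check (a small sketch, not part of the original article): amount=0 should return the input unchanged, while amount=1 should return pure noise:
x_test = torch.rand(2, 1, 28, 28)
assert torch.allclose(corrupt(x_test, torch.zeros(2)), x_test)  # No corruption at amount=0
print(corrupt(x_test, torch.ones(2)).std())  # Pure U(0,1) noise: std around 0.29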
Next, let's look at the effect of adding noise step by step. The code is as follows:
# Plotting the input data
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
# Adding noise
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)
# Plotting the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');
As the figure shows, the amount of added noise increases from left to right; when the noise amount approaches 1, the data looks like pure random noise.
The UNet model has much in common with an autoencoder. UNet was originally developed for segmentation tasks in medical imaging, and its network structure is shown in the figure below. The code is as follows:
class BasicUNet(nn.Module):
    """A minimal UNet implementation."""
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        self.act = nn.SiLU()  # The activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))  # Through the layer and the activation function
            if i < 2:  # For all but the third (final) down layer:
                h.append(x)  # Storing output for skip connection
                x = self.downscale(x)  # Downscale ready for the next layer
        for i, l in enumerate(self.up_layers):
            if i > 0:  # For all except the first up layer
                x = self.upscale(x)  # Upscale
                x += h.pop()  # Fetching stored output (skip connection)
            x = self.act(l(x))  # Through the layer and the activation function
        return x
Let's check that the model's input and output shapes change as expected. The code is as follows:
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
# Output
torch.Size([8, 1, 28, 28])
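Beyond this end-to-end check, we can also peek at the internal feature-map shapes. The following is a small sketch (not from the original article) that registers a forward hook on each convolution:
probe = BasicUNet()
for name, module in probe.named_modules():
    if isinstance(module, nn.Conv2d):
        module.register_forward_hook(
            lambda m, inp, out, name=name: print(f'{name}: {tuple(out.shape)}'))
probe(torch.rand(1, 1, 28, 28));
# Expected: resolution halves on the way down (28 -> 14 -> 7)
# and doubles on the way back up (7 -> 14 -> 28)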
Now let's look at the model's parameter count. The code is as follows:
sum([p.numel() for p in net.parameters()])
# Output
309057
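To see where the 309,057 parameters come from, we can list them layer by layer; each 5×5 convolution contributes in_channels × out_channels × 25 weights plus one bias per output channel:
for name, p in net.named_parameters():
    print(f'{name}: {p.numel()}')
# e.g. down_layers.0.weight: 1*32*5*5 = 800, down_layers.0.bias: 32,
# down_layers.1.weight: 32*64*5*5 = 51200, ... with all terms summing to 309057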
At this point we have finished loading the data and building the UNet model; of course, the UNet architecture could be designed in other ways.
What should a diffusion model learn? There are many possible objectives, such as learning the noise. Let's start with a simple example: the input is noisy MNIST data, and the diffusion model should output its best prediction of the corresponding clean digit, so the training objective is the MSE between the prediction and the ground truth. The training code is as follows:
# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# How many runs through the data should we do?
n_epochs = 3
# Create the network
net = BasicUNet()
net.to(device)
# Our loss function
loss_fn = nn.MSELoss()
# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
# Keeping a record of the losses for later viewing
losses = []
# The training loop
for epoch in range(n_epochs):
    for x, y in train_dataloader:
        # Get some data and prepare the corrupted version
        x = x.to(device)  # Data on the GPU
        noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
        noisy_x = corrupt(x, noise_amount)  # Create our noisy x
        # Get the model prediction
        pred = net(noisy_x)
        # Calculate the loss
        loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?
        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Store the loss for later
        losses.append(loss.item())
    # Print out the average of the loss values for this epoch:
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')
# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);
# Output
Finished epoch 0. Average loss for this epoch: 0.024689
Finished epoch 1. Average loss for this epoch: 0.019226
Finished epoch 2. Average loss for this epoch: 0.017939
The loss curve over training is shown in the figure below:
Let's pick some data and evaluate the model's predictions. The code is as follows:
# Visualizing model predictions on noisy inputs
# Fetch some data
x, y = next(iter(train_dataloader))
x = x[:8] # Only using the first 8 for easy plotting
# Corrupt with a range of amounts
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)
# Get the model predictions
with torch.no_grad():
    preds = net(noised_x.to(device)).detach().cpu()
# Plot
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');
As the figure shows, the model's predictions are quite good for inputs with a low amount of noise; when amount = 1, the model's output is close to the mean of the whole dataset. This is exactly how diffusion models operate.
Note: our training here is by no means thorough; readers are encouraged to try different hyperparameters to improve the model.
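As mentioned earlier, predicting the clean image is only one possible objective; a common alternative is to predict the noise instead. A minimal sketch of that variant of the training step (hypothetical, assuming the same network, optimizer, and mixing scheme as above) might look like this:
# A sketch of the noise-prediction variant (not the code used above):
for x, y in train_dataloader:
    x = x.to(device)
    noise = torch.rand_like(x)  # The noise we will mix in, in [0, 1]
    amount = torch.rand(x.shape[0], 1, 1, 1, device=device)
    noisy_x = x * (1 - amount) + noise * amount  # Same mixing as corrupt()
    pred = net(noisy_x)
    loss = loss_fn(pred, noise)  # Target is the noise, not the clean x
    opt.zero_grad()
    loss.backward()
    opt.step()
With this objective, generating an image would involve iteratively subtracting the predicted noise contribution rather than reading the clean image straight off the model output.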