在介紹之前先簡要回顧一下現有的模型

Transformer:以其注意力機制而聞名,其中序列的任何部分都可以動態地與任何其他部分相互作用,特別是具有因果注意力機制的的Transformer,擅長處理序列中的單個元素。但是它們帶來了顯著的計算和內存成本,與序列長度的平方(L2)成比例。

循環神經網絡(rnn): rnn只考慮當前輸入和最后一個隱藏狀態,按順序更新隱藏狀態。這種方法允許它們潛在地處理無限序列長度和恒定的內存需求。但是rnn的簡單性是一個缺點,限制了它們記住長期依賴關系的能力。此外,rnn中的時間反向傳播(BPTT)是內存密集型的,并且可能遭受梯度消失或爆炸的影響,盡管有LSTM等創新部分結解決了這個問題。

State Space Models(S4):這些模型已經顯示出很好的特性。它們提供了一種平衡,比rnn更有效地捕獲遠程依賴關系,同時比transformer更高效地使用內存。

接下來Manba登場!

Mamba

選擇性狀態空間:Mamba建立在狀態空間模型的概念之上,但引入了一個新的變化。它利用選擇性狀態空間,支持跨長序列更高效和有效地捕獲相關信息。

線性時間復雜度:與Transformer不同,Mamba在序列長度方面以線性時間運行。這個屬性使得它特別適合涉及非常長的序列的任務,而傳統模型在這方面會遇到困難。

Mamba以其選擇性狀態空間的概念引入了傳統狀態空間模型的一個有趣的改進。這種方法稍微放松了標準狀態空間模型的嚴格狀態轉換,使其更具適應性和靈活性(有點類似于lstm)。并且Mamba保留了狀態空間模型的高效計算特性,使其能夠在一次掃描中執行整個序列的前向傳遞-這一特性更讓人想起Transformer

在訓練期間,Mamba的行為類似于Transformer,同時處理整個序列。而lstm必須一步一步地計算前向傳遞,即使所有輸入都是已知的。在推理中,Mamba的行為更符合傳統的循環模型,提供有效的序列處理。

先驗狀態空間模型(ssm)的一個關鍵限制是其剛性的、輸入不變的結構。這些模型為整個序列使用一組固定參數(我們稱它們為a和B)。這種結構甚至比lstm等模型更具限制性,在lstm中,信號的轉換可能依賴于先前的隱藏狀態和輸入。

Mamba則一種范式轉換,即如何計算向下一個隱藏狀態的過渡?在Mamba的體系結構中,轉換依賴于當前輸入,這種方法在傳統ssm的固定計算和循環神經網絡的輸入依賴動態性之間取得了平衡。

主要組成如下:

固定主干:從一個隱藏狀態到下一個隱藏狀態的轉換仍然是一個固定的計算(由a矩陣定義),允許跨序列的預計算。

輸入相關轉換:輸入影響下一個隱藏狀態(由B矩陣定義)的方式取決于當前輸入,而不是之前的隱藏狀態。與傳統ssm相比,這種輸入依賴性提供了更大的靈活性。

為了滿足這種方法的計算需求,Mamba使用了一種硬件感知算法。該算法使用掃描操作而不是卷積來循環執行計算,這樣在gpu上非常高效的。盡管輸入依賴轉換帶來了算法復雜性,但這種效率對于保持高性能至關重要。

Mamba和選擇性狀態空間模型不是同義詞。Mamba是一個使用選擇性狀態空間概念的實現。這種區別是至關重要的,因為它突出了Mamba的獨特貢獻:在保持計算效率的同時,使SSM框架更加靈活和響應輸入。

SRAM和HBM

gpu包含兩種主要類型的內存:HBM (High Bandwidth memory)和SRAM (Static Random-Access memory)。HBM雖然帶寬很高,但與更快但更小的SRAM相比,它的訪問時間相對較慢。Mamba則使用SRAM在矩陣乘法期間進行快速訪問,這是其計算的關鍵。

計算中的主要瓶頸通常不是計算本身,而是數據在內存類型之間的移動。Mamba通過顯著減少傳輸大量數據的需求來解決這個問題。它通過直接在SRAM中執行算法的關鍵部分(如離散化和遞歸計算)來實現,從而減少延遲。

還引入了一個融合選擇掃描層,使其內存需求與使用flash attention的優化Transformer實現相當。這一層對于保持效率至關重要,尤其是在處理模型中依賴于輸入的元素時。

結果

Mamba代表了序列建模的重大進步,特別是在其高效使用GPU內存和計算策略方面。它具有高效率處理長序列的能力,使其成為各種應用的有前途的模型,我們下面來使用Pytorch代碼來對其進復現。

Pytorch復現

導入基本庫

 import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
from einops import rearrange
from tqdm import tqdm

import math
import os
import urllib.request
from zipfile import ZipFile

from transformers import AutoTokenizer

torch.autograd.set_detect_anomaly(True)

設置標志和超參數

 # Configuration flags and hyperparameters
USE_MAMBA = 1
DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

定義超參數和初始化

 d_model = 8
state_size = 128 # Example state size
seq_len = 100 # Example sequence length
batch_size = 256 # Example batch size
last_batch_size = 81 # only for the very last batch of the dataset
current_batch_size = batch_size
different_batch_size = False
h_new = None
temp_buffer = None

這里的超參數,如模型維度(d_model)、狀態大小、序列長度和批大小。

S6模塊是Mamba架構中的一個復雜組件,負責通過一系列線性變換和離散化過程處理輸入序列。它在捕獲序列的時間動態方面起著關鍵作用,這是序列建模任務(如語言建模)的一個關鍵方面。這里包括張量運算和自定義離散化方法來處理序列數據的復雜需求。

class S6(nn.Module):
def __init__(self, seq_len, d_model, state_size, device):
super(S6, self).__init__()

self.fc1 = nn.Linear(d_model, d_model, device=device)
self.fc2 = nn.Linear(d_model, state_size, device=device)
self.fc3 = nn.Linear(d_model, state_size, device=device)

self.seq_len = seq_len
self.d_model = d_model
self.state_size = state_size

self.A = nn.Parameter(F.normalize(torch.ones(d_model, state_size, device=device), p=2, dim=-1))
nn.init.xavier_uniform_(self.A)

self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)

self.delta = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)

# h [batch_size, seq_len, d_model, state_size]
self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)

def discretization(self):

self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)

self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))

return self.dA, self.dB

def forward(self, x):
# Algorithm 2 MAMBA paper
self.B = self.fc2(x)
self.C = self.fc3(x)
self.delta = F.softplus(self.fc1(x))

self.discretization()

if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:

global current_batch_size
current_batch_size = x.shape[0]

if self.h.shape[0] != current_batch_size:
different_batch_size = True

h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h[:current_batch_size, ...]) + rearrange(x, "b l d -> b l d 1") * self.dB

else:
different_batch_size = False
h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h) + rearrange(x, "b l d -> b l d 1") * self.dB

# y [batch_size, seq_len, d_model]
self.y = torch.einsum('bln,bldn->bld', self.C, h_new)

global temp_buffer
temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()

return self.y

else:
# h [batch_size, seq_len, d_model, state_size]
h = torch.zeros(x.size(0), self.seq_len, self.d_model, self.state_size, device=x.device)
y = torch.zeros_like(x)

h = torch.einsum('bldn,bldn->bldn', self.dA, h) + rearrange(x, "b l d -> b l d 1") * self.dB

# y [batch_size, seq_len, d_model]
y = torch.einsum('bln,bldn->bld', self.C, h)

return y

這個S6的模塊,可以處理離散化過程和正向傳播。

MambaBlock類是一個定制的神經網絡模塊,被設計為Mamba模型的關鍵構建塊。它封裝了幾個層和操作來處理輸入數據。

包括線性投影、卷積、激活函數、自定義S6模塊和殘差連接。該塊是Mamba模型的基本組件,負責通過一系列轉換處理輸入序列,以捕獲數據中的相關模式和特征。這些不同層和操作的組合允許MambaBlock有效地處理復雜的序列建模任務。MambaBlock是Mamba核心功能。

 class MambaBlock(nn.Module):
def __init__(self, seq_len, d_model, state_size, device):
super(MambaBlock, self).__init__()

self.inp_proj = nn.Linear(d_model, 2*d_model, device=device)
self.out_proj = nn.Linear(2*d_model, d_model, device=device)

# For residual skip connection
self.D = nn.Linear(d_model, 2*d_model, device=device)

# Set _no_weight_decay attribute on bias
self.out_proj.bias._no_weight_decay = True

# Initialize bias to a small constant value
nn.init.constant_(self.out_proj.bias, 1.0)

self.S6 = S6(seq_len, 2*d_model, state_size, device)

# Add 1D convolution with kernel size 3
self.conv = nn.Conv1d(seq_len, seq_len, kernel_size=3, padding=1, device=device)

# Add linear layer for conv output
self.conv_linear = nn.Linear(2*d_model, 2*d_model, device=device)

# rmsnorm
self.norm = RMSNorm(d_model, device=device)

def forward(self, x):
"""
x_proj.shape = torch.Size([batch_size, seq_len, 2*d_model])
x_conv.shape = torch.Size([batch_size, seq_len, 2*d_model])
x_conv_act.shape = torch.Size([batch_size, seq_len, 2*d_model])
"""
# Refer to Figure 3 in the MAMBA paper

x = self.norm(x)

x_proj = self.inp_proj(x)

# Add 1D convolution with kernel size 3
x_conv = self.conv(x_proj)

x_conv_act = F.silu(x_conv)

# Add linear layer for conv output
x_conv_out = self.conv_linear(x_conv_act)

x_ssm = self.S6(x_conv_out)
x_act = F.silu(x_ssm) # Swish activation can be implemented as x * sigmoid(x)

# residual skip connection with nonlinearity introduced by multiplication
x_residual = F.silu(self.D(x))

x_combined = x_act * x_residual

x_out = self.out_proj(x_combined)

return x_out

Mamba模型

包括一系列MambaBlock模塊。每個塊都順序處理輸入數據,一個塊的輸出作為下一個塊的輸入。這種順序處理允許模型捕獲輸入數據中的復雜模式和關系,使其對涉及順序建模的任務有效。多個塊的堆疊是深度學習架構中的常見設計,因為它使模型能夠學習數據的分層表示。

class Mamba(nn.Module):
def __init__(self, seq_len, d_model, state_size, device):
super(Mamba, self).__init__()
self.mamba_block1 = MambaBlock(seq_len, d_model, state_size, device)
self.mamba_block2 = MambaBlock(seq_len, d_model, state_size, device)
self.mamba_block3 = MambaBlock(seq_len, d_model, state_size, device)

def forward(self, x):
x = self.mamba_block1(x)
x = self.mamba_block2(x)
x = self.mamba_block3(x)
return x

RMSNorm是一個自定義規范化層,這一層用于規范神經網絡的激活,這可以幫助穩定和加快訓練。

 class RMSNorm(nn.Module):
def __init__(self,
d_model: int,
eps: float = 1e-5,
device: str ='cuda'):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model, device=device))

def forward(self, x):
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

return output

這一層的用法:

 x = torch.rand(batch_size, seq_len, d_model, device=device)
# Create the Mamba model
mamba = Mamba(seq_len, d_model, state_size, device)

# rmsnorm
norm = RMSNorm(d_model)
x = norm(x)

# Forward pass
test_output = mamba(x)
print(f"test_output.shape = {test_output.shape}") # Should be [batch_size, seq_len, d_model]

上面就是模型的全部基本代碼,下面就可以進行數據準備和訓練

我們自定義一個Enwiki8Dataset

 class Enwiki8Dataset(Dataset):
def __init__(self, data):
self.data = data

def __len__(self):
return len(self.data['input_ids'])

def __getitem__(self, idx):
item = {key: val[idx].clone().detach() for key, val in self.data.items()}
return item

pad_sequences_3d用于將一批序列填充到統一的長度,確保批中的每個序列具有相同數量的元素(或時間步長)。這在許多機器學習任務中尤其重要,因為輸入數據必須具有一致的形狀。

 # Define a function for padding
def pad_sequences_3d(sequences, max_len=None, pad_value=0):
# Assuming sequences is a tensor of shape (batch_size, seq_len, feature_size)
batch_size, seq_len, feature_size = sequences.shape

if max_len is None:
max_len = seq_len + 1

# Initialize padded_sequences with the pad_value
padded_sequences = torch.full((batch_size, max_len, feature_size), fill_value=pad_value, dtype=sequences.dtype, device=sequences.device)
# Pad each sequence to the max_len
padded_sequences[:, :seq_len, :] = sequences

return padded_sequences

訓練過程:

def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
model.train()
total_loss = 0
for batch in data_loader:
optimizer.zero_grad()

input_data = batch['input_ids'].clone().to(device)
attention_mask = batch['attention_mask'].clone().to(device)

target = input_data[:, 1:]
input_data = input_data[:, :-1]

# Pad all the sequences in the batch:
input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)

if USE_MAMBA:
output = model(input_data)
loss = criterion(output, target)

loss.backward(retain_graph=True)

for name, param in model.named_parameters():
if 'out_proj.bias' not in name:
# clip weights but not bias for out_proj
torch.nn.utils.clip_grad_norm_(param, max_norm=max_grad_norm)

if DEBUGGING_IS_ON:
for name, parameter in model.named_parameters():
if parameter.grad is not None:
print(f"{name} gradient: {parameter.grad.data.norm(2)}")
else:
print(f"{name} has no gradient")

if USE_MAMBA and DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
model.S6.h[:current_batch_size, ...].copy_(temp_buffer)

optimizer.step()

total_loss += loss.item()
return total_loss / len(data_loader)

評估函數:

def evaluate(model, data_loader, criterion, device):
model.eval()
total_loss = 0
with torch.no_grad():
for batch in data_loader:
input_data = batch['input_ids'].clone().detach().to(device)
attention_mask = batch['attention_mask'].clone().detach().to(device)

target = input_data[:, 1:]
input_data = input_data[:, :-1]

# Pad all the sequences in the batch:
input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)

if USE_MAMBA:
output = model(input_data)
loss = criterion(output, target)
total_loss += loss.item()
return total_loss / len(data_loader)

最后,calculate_perplexity用于評估語言模型(如Mamba)的性能。

def calculate_perplexity(loss):
return math.exp(loss)

load_enwiki8_dataset函數用于下載和提取enwiki8數據集,該數據集通常用于對語言模型進行基準測試。

 def load_enwiki8_dataset():
print(f"Download and extract enwiki8 data")
url = "http://mattmahoney.net/dc/enwik8.zip"
urllib.request.urlretrieve(url, "enwik8.zip")

with ZipFile("enwik8.zip") as f:
data = f.read("enwik8").decode("utf-8")

return data

encode_dataset函數設計用于標記和編碼數據集,為神經網絡模型(如Mamba)處理數據集做準備。

# Tokenize and encode the dataset
def encode_dataset(tokenizer, text_data):
def batch_encode(tokenizer, text_data, batch_size=1000):
# Tokenize in batches
batched_input_ids = []
for i in range(0, len(text_data), batch_size):
batch = text_data[i:i+batch_size]
inputs = tokenizer(batch, add_special_tokens=True, truncation=True,
padding='max_length', max_length=seq_len,
return_tensors='pt')
batched_input_ids.append(inputs['input_ids'])
return torch.cat(batched_input_ids)

# Assuming enwiki8_data is a list of sentences
input_ids = batch_encode(tokenizer, enwiki8_data)

# vocab_size is the number of unique tokens in the tokenizer's vocabulary
global vocab_size
vocab_size = len(tokenizer.vocab) # Note that for some tokenizers, we might access the vocab directly
print(f"vocab_size = {vocab_size}")

# Create an embedding layer
# embedding_dim is the size of the embedding vectors (MAMBA model's D)
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

# Pass input_ids through the embedding layer # This will change input_ids from shape [B, L] to [B, L, D] def batch_embedding_calls(input_ids, embedding_layer, batch_size=256): # Check if input_ids is already a tensor, if not convert it if not isinstance(input_ids, torch.Tensor): input_ids = torch.tensor(input_ids, dtype=torch.long) # Calculate the number of batches needed num_batches = math.ceil(input_ids.size(0) / batch_size) # List to hold the output embeddings output_embeddings = [] # Process each batch for i in range(num_batches): # Calculate start and end indices for the current batch start_idx = i * batch_size end_idx = start_idx + batch_size # Get the batch input_id_batch = input_ids[start_idx:end_idx] # Call the embedding layer with torch.no_grad(): # No need gradients for this operation batch_embeddings = embedding_layer(input_id_batch) # Append the result to the list output_embeddings.append(batch_embeddings) # Concatenate the embeddings from each batch into a single tensor all_embeddings = torch.cat(output_embeddings, dim=0) return all_embeddings # input_ids is a list or tensor of the input IDs and embedding_layer is model's embedding layer if USE_MAMBA: # Set batch_size to a value that works for memory constraints encoded_inputs = batch_embedding_calls(input_ids, embedding_layer, batch_size=1).float() attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.dtype) return encoded_inputs, attention_mask

下面就可以進行訓練了

 # Load a pretrained tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Assuming encoded_inputs is a preprocessed tensor of shape [num_samples, seq_len, d_model]
encoded_inputs_file = 'encoded_inputs_mamba.pt'

if os.path.exists(encoded_inputs_file):
print("Loading pre-tokenized data...")
encoded_inputs = torch.load(encoded_inputs_file)
else:
print("Tokenizing raw data...")
enwiki8_data = load_enwiki8_dataset()
encoded_inputs, attention_mask = encode_dataset(tokenizer, enwiki8_data)
torch.save(encoded_inputs, encoded_inputs_file)
print(f"finished tokenizing data")

# Combine into a single dictionary
data = {
'input_ids': encoded_inputs,
'attention_mask': attention_mask
}

# Split the data into train and validation sets
total_size = len(data['input_ids'])
train_size = int(total_size * 0.8)

train_data = {key: val[:train_size] for key, val in data.items()}
val_data = {key: val[train_size:] for key, val in data.items()}

train_dataset = Enwiki8Dataset(train_data)
val_dataset = Enwiki8Dataset(val_data)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize the model

model = Mamba(seq_len, d_model, state_size, device).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-6)

# Training loop
num_epochs = 25 # Number of epochs to train for

for epoch in tqdm(range(num_epochs)): # loop over the dataset multiple times
train_loss = train(model, tokenizer, train_loader, optimizer, criterion, device, max_grad_norm=10.0, DEBUGGING_IS_ON=False)
val_loss = evaluate(model, val_loader, criterion, device)
val_perplexity = calculate_perplexity(val_loss)
print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}')

以上就是訓練的完整代碼

總結

我們介紹了Mamba的概念和架構,并且從頭開始構建Mamba復現,這樣可以將理論轉化為實踐。通過這種動手的方法,可以看到Mamba序列建模方法和效率。如果你想直接使用,可以看論文提供的代碼。

文章轉自微信公眾號@算法進階

上一篇:

Python數據挖掘算法入門與實踐

下一篇:

圖神經網絡(GNN)和神經網絡的關系
#你可能也喜歡這些API文章!

我們有何不同?

API服務商零注冊

多API并行試用

數據驅動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內容創意新穎性、情感共鳴力、商業轉化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費