輸入門:決定當前時刻??添加多少新信息。

候選記憶單元:生成新的候選信息。

當前記憶單元:綜合遺忘信息和新輸入信息。

輸出門:決定輸出多少信息。

最終輸出

其中, 是sigmoid函數, 是逐元素乘積。

Attention機制

Attention機制在處理長序列時特別有用,因為它可以幫助模型關注序列中的重要部分。它通過計算不同時間步的加權和來突出不同的輸入數據對輸出的重要性。

Attention機制的核心公式:

其中,score函數可以是點積、雙線性或其他相似性測量方法。

加權和

最終,Attention機制的輸出是輸入時間步的加權和,能使模型更有效地關注重要的信息。

模型訓練

為了演示Attention-LSTM在時序預測中的應用,我們使用一個模擬的時序數據集進行預測。

Pytorch代碼實現

下面是使用PyTorch構建Attention-LSTM模型的代碼示例。我們使用一個簡單的正弦波數據集來說明。

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 創建模擬數據
def create_sin_wave(seq_len, n_samples):
x = np.linspace(0, 50, n_samples)
data = np.sin(x)
return data

# Attention機制實現
class Attention(nn.Module):
def __init__(self, hidden_dim):
super(Attention, self).__init__()
self.hidden_dim = hidden_dim
self.attn = nn.Linear(hidden_dim, hidden_dim)
self.context = nn.Linear(hidden_dim, 1, bias=False)

def forward(self, hidden_states):
attn_weights = torch.tanh(self.attn(hidden_states))
attn_weights = self.context(attn_weights).squeeze(2)
attn_weights = torch.softmax(attn_weights, dim=1)
context_vector = torch.sum(attn_weights.unsqueeze(2) * hidden_states, dim=1)
return context_vector, attn_weights

# LSTM模型實現
class AttentionLSTM(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, n_layers):
super(AttentionLSTM, self).__init__()
self.hidden_dim = hidden_dim
self.n_layers = n_layers
self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
self.attention = Attention(hidden_dim)
self.fc = nn.Linear(hidden_dim, output_dim)

def forward(self, x):
h_0 = torch.zeros(self.n_layers, x.size(0), self.hidden_dim).to(x.device)
c_0 = torch.zeros(self.n_layers, x.size(0), self.hidden_dim).to(x.device)
out, _ = self.lstm(x, (h_0, c_0))
context_vector, attn_weights = self.attention(out)
out = self.fc(context_vector)
return out, attn_weights

# 生成數據集
seq_len = 20
n_samples = 1000
data = create_sin_wave(seq_len, n_samples)
data = torch.tensor(data, dtype=torch.float32).unsqueeze(1)

# 準備訓練集和測試集
def create_inout_sequences(data, seq_len):
inout_seq = []
L = len(data)
for i in range(L-seq_len):
train_seq = data[i:i+seq_len]
train_label = data[i+seq_len:i+seq_len+1]
inout_seq.append((train_seq, train_label))
return inout_seq

train_seq = create_inout_sequences(data, seq_len)

train_X = torch.stack([s[0] for s in train_seq])
train_Y = torch.stack([s[1] for s in train_seq])

# 訓練模型
input_dim = 1
hidden_dim = 64
output_dim = 1
n_layers = 2
n_epochs = 100
learning_rate = 0.001

model = AttentionLSTM(input_dim, hidden_dim, output_dim, n_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(n_epochs):
optimizer.zero_grad()
output, attn_weights = model(train_X)
loss = criterion(output, train_Y)
loss.backward()
optimizer.step()
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}')

# 可視化結果
model.eval()
with torch.no_grad():
pred, attn_weights = model(train_X)

# 繪制實際值與預測值
plt.figure(figsize=(14, 7))
plt.plot(data.numpy(), label='True Data')
plt.plot(range(seq_len, seq_len + len(pred)), pred.numpy(), label='Predicted Data')
plt.xlabel('Time step')
plt.ylabel('Value')
plt.title('Attention-LSTM: True vs Predicted')
plt.legend()
plt.show()
# 繪制注意力權重
attn_weights = attn_weights.numpy()
plt.figure(figsize=(14, 7))
plt.imshow(attn_weights.T, aspect='auto', cmap='viridis')
plt.colorbar()
plt.xlabel('Time step')
plt.ylabel('Attention Weights')
plt.title('Attention Weights Distribution')
plt.show()

代碼說明

  1. 數據生成:我們創建了一個正弦波數據集,用于模擬時序數據。
  2. Attention機制實現:定義了一個Attention類,用于計算注意力權重和上下文向量。
  3. LSTM模型實現:定義了一個AttentionLSTM類,該類包含LSTM層、Attention層和全連接層。
  4. 訓練模型:通過梯度下降訓練模型。
  5. 可視化結果:繪制預測值和實際值的對比圖,以及注意力權重的分布圖。

Attention機制能夠增強LSTM模型的性能,使其在處理長時間依賴關系時更加有效。通過引入Attention機制,模型能夠自動關注時序數據中重要的時間步,從而提高預測的準確性。

這種Attention-LSTM模型在金融預測、氣象分析、醫療診斷等領域都有廣泛的應用潛力。大家在未來的研究可以探索不同類型的注意力機制以及其在不同應用場景中的效果。

本文章轉載微信公眾號@深夜努力寫Python

上一篇:

機器學習中的數據歸一化:提升模型性能與收斂速度的關鍵步驟

下一篇:

突破最強算法模型,回歸算法!!!
#你可能也喜歡這些API文章!

我們有何不同?

API服務商零注冊

多API并行試用

數據驅動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內容創意新穎性、情感共鳴力、商業轉化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費