接下來(lái)作者將利用KAN進(jìn)行對(duì)鳶尾花的分類(lèi)實(shí)現(xiàn),體現(xiàn)它相對(duì)于MLP無(wú)法比擬的可解釋性、交互性特點(diǎn),當(dāng)然KAN也有其缺點(diǎn)就是目前版本訓(xùn)練速度較慢
import pandas as pd
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import numpy as np
iris = load_iris()
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['target'] = iris.target
df = iris_df[iris_df['target'] != 2] # 只要0和1完成一個(gè)二分類(lèi)問(wèn)題
df.head()
這里將數(shù)據(jù)簡(jiǎn)單梳理為二分類(lèi)問(wèn)題,并且將這個(gè)分類(lèi)問(wèn)題看作回歸問(wèn)題,去探討不同輸出維度下的KAN
from sklearn.model_selection import train_test_split
import torch
train_input, test_input, train_label, test_label = train_test_split(df.iloc[:, 0:4], df['target'],
test_size=0.2, random_state=42, stratify=df['target'])
# 將 DataFrame 和 Series 轉(zhuǎn)換為 np.array
train_input = train_input.to_numpy()
test_input = test_input.to_numpy()
train_label = train_label.to_numpy()
test_label = test_label.to_numpy()
# 轉(zhuǎn)換為pytorch張量
dataset = {}
dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label[:,None])
dataset['test_label'] = torch.from_numpy(test_label[:,None])
分割數(shù)據(jù)集且將原始的 DataFrame 數(shù)據(jù)轉(zhuǎn)換為適合在 PyTorch 中使用的張量形式
from kan import KAN
model = KAN(width=[4,1], grid=3, k=3)
# 初始化繪制KAN
model(dataset['train_input']);
model.plot(beta=100)
這里創(chuàng)建一個(gè)KAN:4D輸入,1D輸出,沒(méi)有隱藏的神經(jīng)元,三次樣條 (k=3),3個(gè)網(wǎng)格間隔 (grid=3),如果要添加隱藏的神經(jīng)元在width中添加既可,它表示每層中的神經(jīng)元數(shù),例如,[2,5,5,3] 表示 2D 輸入,3D 輸出,具有 2 層 5 個(gè)隱藏神經(jīng)元,創(chuàng)建這樣的模型后對(duì)其可視化,當(dāng)前這個(gè)模型還沒(méi)有進(jìn)行訓(xùn)練,接下來(lái)訓(xùn)練這個(gè)模型
# 定義訓(xùn)練集準(zhǔn)確率計(jì)算函數(shù)
def train_acc():
# 使用模型對(duì)訓(xùn)練輸入進(jìn)行預(yù)測(cè),取預(yù)測(cè)值的第一個(gè)輸出并四舍五入
# 將預(yù)測(cè)值與訓(xùn)練標(biāo)簽進(jìn)行比較,計(jì)算準(zhǔn)確率
return torch.mean((torch.round(model(dataset['train_input'])[:, 0]) == dataset['train_label'][:, 0]).float())
# 定義測(cè)試集準(zhǔn)確率計(jì)算函數(shù)
def test_acc():
# 使用模型對(duì)測(cè)試輸入進(jìn)行預(yù)測(cè),取預(yù)測(cè)值的第一個(gè)輸出并四舍五入
# 將預(yù)測(cè)值與測(cè)試標(biāo)簽進(jìn)行比較,計(jì)算準(zhǔn)確率
return torch.mean((torch.round(model(dataset['test_input'])[:, 0]) == dataset['test_label'][:, 0]).float())
# 訓(xùn)練模型,使用LBFGS優(yōu)化器,訓(xùn)練20步,計(jì)算訓(xùn)練和測(cè)試集的準(zhǔn)確率
results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc))
model.plot()
定義了兩個(gè)函數(shù) train_acc() 和 test_acc() 分別用于計(jì)算訓(xùn)練集和測(cè)試集上模型的準(zhǔn)確率,然后使用 LBFGS 優(yōu)化器對(duì)模型進(jìn)行訓(xùn)練,訓(xùn)練步數(shù)為 20 步,并同時(shí)計(jì)算并輸出訓(xùn)練和測(cè)試集的準(zhǔn)確率,最后對(duì)模型進(jìn)行可視化,對(duì)比模型初始可視化可以發(fā)現(xiàn)激活函數(shù)明顯不一樣了,這就是KAN對(duì)激活函數(shù)學(xué)習(xí)的一個(gè)結(jié)果,接下來(lái)我們把這個(gè)模型進(jìn)行解釋性輸出
lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','tan','abs']
model.auto_symbolic(lib=lib)
formula = model.symbolic_formula()[0][0]
formula
可以發(fā)現(xiàn)KAN模型相對(duì)其其它深度學(xué)習(xí)框架,它可以輸出一個(gè)具體的公式,當(dāng)然這個(gè)KAN是單輸出所以只有一個(gè)公式,通過(guò)這個(gè)公式它不在是一個(gè)黑箱模型,而是可以被我們所解釋的模型,實(shí)際上把相應(yīng)的X值輸入公式并進(jìn)行四舍五入返回的值就是0或1也就是我們的實(shí)際類(lèi)別,接下來(lái)通過(guò)這個(gè)公式來(lái)輸出在訓(xùn)練集、測(cè)試集上的模型精確度
def acc(formula, X, y):
batch = X.shape[0] # 獲取批量大小
correct = 0 # 初始化正確預(yù)測(cè)的數(shù)量
for i in range(batch):
# 構(gòu)建替換字典,將 x_1, x_2, x_3, x_4 替換為當(dāng)前樣本的值
subs_dict = {'x_1': X[i, 0], 'x_2': X[i, 1], 'x_3': X[i, 2], 'x_4': X[i, 3]}
# 使用給定的公式對(duì)當(dāng)前樣本進(jìn)行預(yù)測(cè),并將結(jié)果轉(zhuǎn)換為浮點(diǎn)數(shù)
prediction = float(formula.subs(subs_dict))
# 四舍五入預(yù)測(cè)值,與真實(shí)標(biāo)簽進(jìn)行比較
if np.round(prediction) == y[i, 0]:
correct += 1
# 計(jì)算準(zhǔn)確率
accuracy = correct / batch
return accuracy
# 計(jì)算訓(xùn)練集和測(cè)試集的準(zhǔn)確率
train_accuracy = acc(formula, dataset['train_input'], dataset['train_label'])
test_accuracy = acc(formula, dataset['test_input'], dataset['test_label'])
print('train acc of the formula:', train_accuracy)
print('test acc of the formula:', test_accuracy)
通過(guò)準(zhǔn)確率可知這個(gè)單輸出的二分類(lèi)KAN模型,表現(xiàn)的很好只是在訓(xùn)練集上出現(xiàn)了一點(diǎn)錯(cuò)誤,接下來(lái)我們重新去構(gòu)建一個(gè)二輸出的KAN模型
dataset = {}
dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label).type(torch.long)
dataset['test_label'] = torch.from_numpy(test_label).type(torch.long)
model = KAN(width=[4,2], grid=3, k=3)
model(dataset['train_input']);
model.plot(beta=100)
這個(gè)模型相對(duì)于第一個(gè)模型只去修改了它的輸出維數(shù)為二,同樣還是把它看作是一個(gè)回歸模型
def train_acc():
return torch.mean((torch.argmax(model(dataset['train_input']), dim=1) == dataset['train_label']).float())
def test_acc():
return torch.mean((torch.argmax(model(dataset['test_input']), dim=1) == dataset['test_label']).float())
results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), loss_fn=torch.nn.CrossEntropyLoss())
model.plot()
同樣是對(duì)激活函數(shù)進(jìn)行學(xué)習(xí),并可視化
lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
model.auto_symbolic(lib=lib)
formula1, formula2 = model.symbolic_formula()[0]
formula1
formula2
這是一個(gè)輸出維數(shù)為二的KAN模型相應(yīng)的它的輸出都有與它一一對(duì)應(yīng)的數(shù)學(xué)公式來(lái)進(jìn)行解釋
def acc(formula1, formula2, X, y):
batch = X.shape[0]
correct = 0
for i in range(batch):
logit1 = np.array(formula1.subs('x_1', X[i,0]).subs('x_2', X[i,1]).subs('x_3', X[i,2]).subs('x_4', X[i,3])).astype(np.float64)
logit2 = np.array(formula2.subs('x_1', X[i,0]).subs('x_2', X[i,1]).subs('x_3', X[i,2]).subs('x_4', X[i,3])).astype(np.float64)
correct += (logit2 > logit1) == y[i]
return correct/batch
print('train acc of the formula:', acc(formula1, formula2, dataset['train_input'], dataset['train_label']))
print('test acc of the formula:', acc(formula1, formula2, dataset['test_input'], dataset['test_label']))
相應(yīng)的計(jì)算這個(gè)KAN模型的準(zhǔn)確率,可以發(fā)現(xiàn)這個(gè)輸出維數(shù)為二的KAN比輸出維度為一的KAN要好,這個(gè)KAN模型在這個(gè)數(shù)據(jù)集上百分比預(yù)測(cè)正確,這里利用的是預(yù)測(cè)結(jié)果(即 logit2 > logit1 的布爾值)與真實(shí)標(biāo)簽 y[i] 相等,則返回 True(1),否則返回 False(0),來(lái)進(jìn)行準(zhǔn)確率計(jì)算,到這里就完成了這個(gè)分類(lèi)模型的構(gòu)建,讀者也可以嘗試對(duì)所有數(shù)據(jù)集進(jìn)行三分類(lèi)KAN構(gòu)建,下面是作者對(duì)完整鳶尾花數(shù)據(jù)進(jìn)行構(gòu)建的KAN模型可視化
本文章轉(zhuǎn)載微信公眾號(hào)@Python機(jī)器學(xué)習(xí)AI
對(duì)比大模型API的內(nèi)容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉(zhuǎn)化潛力
一鍵對(duì)比試用API 限時(shí)免費(fèi)