本文將通過醫學數據,使用 Python 演示如何復現 SHAP 依賴圖,并詳細解釋連續性特征對模型預測結果的影響
SHAP 依賴圖用于可視化單個特征對機器學習模型預測結果的影響,具體來說,x 軸是特征值,y 軸是 SHAP 值(度量特征對預測結果的重要性),這些圖可以直觀地顯示出某個特征是對模型預測起正向還是負向作用
代碼實現
數據集加載
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
import warnings
warnings.filterwarnings("ignore")
df = pd.read_csv('Dataset.csv')
# 劃分特征和目標變量
X = df.drop(['target'], axis=1)
y = df['target']
# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42, stratify=df['target'])
df.head()
首先,需要加載數據集并將其劃分為特征 X 和目標變量 y,然后進行訓練集和測試集的劃分。目標變量是我們要預測的值,X 是輸入的特征,這是一個分類任務,目標是預測患者是否患有心臟病。雖然是分類任務,但無論是分類問題還是回歸問題,SHAP 依賴圖的使用方式和原理是相同的,都可以用來解釋模型中各個特征對預測結果的貢獻
訓練機器學習模型
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV
# GBT模型參數
params_gbt = {
'learning_rate': 0.02, # 學習率,控制每一步的步長,用于防止過擬合。典型值范圍:0.01 - 0.1
'max_depth': 3, # 樹的深度,控制模型復雜度
'random_state': 42, # 隨機種子,用于重現模型的結果
'subsample': 0.7, # 每次迭代時隨機選擇的樣本比例,用于增加模型的泛化能力
}
# 初始化Gradient Boosting分類模型
model_gbt = GradientBoostingClassifier(**params_gbt)
# 定義參數網格,用于網格搜索
param_grid = {
'n_estimators': [100, 200, 300], # 樹的數量
'max_depth': [3, 4, 5], # 樹的深度
'learning_rate': [0.01, 0.1], # 學習率
}
# 使用GridSearchCV進行網格搜索和k折交叉驗證
grid_search = GridSearchCV(
estimator=model_gbt,
param_grid=param_grid,
scoring='neg_log_loss', # 評價指標為負對數損失
cv=5, # 5折交叉驗證
n_jobs=-1, # 并行計算
verbose=1 # 輸出詳細進度信息
)
# 訓練模型
grid_search.fit(X_train, y_train)
# 使用最優參數訓練模型
best_model = grid_search.best_estimator_
這里使用了梯度提升樹(GBT),這是一個強大且常用的機器學習算法,通過網格搜索進行參數優化
計算 SHAP 值
import shap
explainer = shap.TreeExplainer(best_model)
# 計算shap值為numpy.array數組
shap_values_numpy = explainer.shap_values(X)
# 計算shap值為Explanation格式
shap_values_Explanation = explainer(X)
模型訓練完畢后,可以使用 shap 包來計算 SHAP 值,SHAP 值用于衡量特定特征對模型輸出的影響,這里分別通過 explainer.shap_values(X) 計算 SHAP 值為數組格式以便自定義繪制,和通過 explainer(X) 計算為 Explanation 格式,直接使用 SHAP 自帶的繪圖函數進行可視化
默認參數下繪制
# 繪制 'age' 特征的SHAP依賴圖
shap.dependence_plot('age', shap_values_Explanation.values, X, show=False)
plt.savefig("SHAP Dependence Plot_1.pdf", format='pdf',bbox_inches='tight',dpi=1200)
圖展示了 age(年齡) 特征對模型預測結果的 SHAP 值的依賴關系,說明不同年齡段如何影響模型的預測
從圖中可以看到:
展示了年齡對模型預測的非線性影響,同時揭示了另一個特征(thal)如何與年齡共同作用,影響預測結果,然而,與文獻中的圖表樣式相比,仍存在一些細微的差別繪制無顏色條的年齡 SHAP 依賴圖
# 繪制 'age' 特征的 SHAP 依賴圖,不顯示顏色條
shap.dependence_plot('age', shap_values_Explanation.values, X, interaction_index=None, show=False)
# 添加 SHAP=0 的橫線
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
plt.savefig("SHAP Dependence Plot_2.pdf", format='pdf',bbox_inches='tight',dpi=1200)
plt.show()
在這里,通過設置 interaction_index=None 可以關閉顏色條,不顯示交互特征的影響。不過,該函數目前沒有內置參數可以直接在 SHAP 值為 0 的位置添加一條橫線。為了實現這一功能,可以利用 matplotlib 的 plt.axhline() 方法,在繪制依賴圖后手動添加橫線
接下來,還可以通過 explainer.shap_values(X) 格式繪制這個shap依賴圖,以便實現自定義繪圖
將 SHAP 值轉換為 DataFrame 格式以便于自定義繪圖
shap_values_df = pd.DataFrame(shap_values_numpy, columns=X.columns)
shap_values_df.head()
單個shap依賴圖繪制
# 繪制散點圖,x軸是'age'特征,y軸是SHAP值
plt.figure(figsize=(6, 4),dpi=1200)
plt.scatter(df['age'], shap_values_df['age'], s=10)
# 添加shap=0的橫線
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
plt.xlabel('Age', fontsize=12)
plt.ylabel('SHAP value for\nAge', fontsize=12)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.savefig("SHAP Dependence Plot_3.pdf", format='pdf',bbox_inches='tight')
plt.show()
代碼生成一個 SHAP 值依賴圖,其中展示了特征 age 對模型輸出的貢獻,同時對圖表進行了一些格式上的優化,比如隱藏不必要的邊框線條、在 SHAP=0 處添加一條基準線,并最終將圖像保存為高分辨率的 PDF 文件。相比于直接使用 shap.dependence_plot() 的默認作圖方式,這種方法提供了更高的靈活性,特別是在定制化繪圖方面,可以根據不同場景、需求對圖表進行高度定制,從而提高可視化的效果和表達的準確性
多個shap依賴圖繪制
# 定義繪制 SHAP 依賴圖的函數
def plot_shap_dependence(feature_list, df, shap_values_df, file_name="SHAP_Dependence_Plots.pdf"):
fig, axs = plt.subplots(2, 3, figsize=(12, 8), dpi=1200)
plt.subplots_adjust(hspace=0.4, wspace=0.4)
# 循環繪制每個特征的 SHAP 依賴圖
for i, feature in enumerate(feature_list):
row = i // 3 # 行號
col = i % 3 # 列號
ax = axs[row, col]
# 繪制散點圖,x軸是特征值,y軸是SHAP值
ax.scatter(df[feature], shap_values_df[feature], s=10)
# 添加shap=0的橫線
ax.axhline(y=0, color='black', linestyle='-.', linewidth=1)
# 設置x和y軸標簽
ax.set_xlabel(feature, fontsize=12)
ax.set_ylabel(f'SHAP value for\n{feature}', fontsize=12)
# 隱藏頂部和右側的脊柱
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# 隱藏最后一個空圖表的坐標軸 (若畫布未關閉)
axs[1, 2].axis('off')
plt.savefig(file_name, format='pdf', bbox_inches='tight')
plt.show()
# 使用函數繪制age、trestbps、chol、thalach、oldpeak的shap依賴圖
feature_list = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']
plot_shap_dependence(feature_list, df, shap_values_df)
這段代碼定義一個函數 plot_shap_dependence,用于繪制給定特征列表的 SHAP 依賴圖,生成 2 行 3 列的圖表布局,并在 SHAP=0 處添加基準線,最后保存為高分辨率 PDF,該圖的樣式基本上與文獻中的 SHAP 依賴圖形式一致,包括散點圖、SHAP 值為 0 的基準線、去掉頂部和右側脊柱的簡潔圖形設計等
本文章轉載微信公眾號@Python機器學習AI