為了更好地揭示特征與目標變量之間的復雜非線性關系,在SHAP散點圖中引入了LOWESS擬合曲線,這條曲線是通過局部加權回歸法生成的,能夠平滑數據點之間的變化,幫助直觀地捕捉數據的趨勢走向
在對SHAP解釋圖的進一步優化中,特別添加了SHAP值為0時擬合曲線的交點標注。這一點非常重要,因為SHAP值為0時意味著該特征在該點附近對模型的預測結果沒有顯著影響。標注這一交點,可以幫助識別出特征值在哪些區間對目標變量的影響是從無到有或從正到負的轉變
代碼實現
數據集加載
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
import warnings
warnings.filterwarnings("ignore")
df = pd.read_csv('Dataset.csv')
# 劃分特征和目標變量
X = df.drop(['target'], axis=1)
y = df['target']
# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42, stratify=df['target'])
df.head()
首先,需要加載數據集并將其劃分為特征 X 和目標變量 y,然后進行訓練集和測試集的劃分。目標變量是我們要預測的值,X 是輸入的特征,這是一個分類任務,目標是預測患者是否患有心臟病。雖然是分類任務,但無論是分類問題還是回歸問題,SHAP 依賴圖的使用方式和原理是相同的,都可以用來解釋模型中各個特征對預測結果的貢獻
訓練機器學習模型
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV
# GBT模型參數
params_gbt = {
'learning_rate': 0.02, # 學習率,控制每一步的步長,用于防止過擬合。典型值范圍:0.01 - 0.1
'max_depth': 3, # 樹的深度,控制模型復雜度
'random_state': 42, # 隨機種子,用于重現模型的結果
'subsample': 0.7, # 每次迭代時隨機選擇的樣本比例,用于增加模型的泛化能力
}
# 初始化Gradient Boosting分類模型
model_gbt = GradientBoostingClassifier(**params_gbt)
# 定義參數網格,用于網格搜索
param_grid = {
'n_estimators': [100, 200, 300], # 樹的數量
'max_depth': [3, 4, 5], # 樹的深度
'learning_rate': [0.01, 0.1], # 學習率
}
# 使用GridSearchCV進行網格搜索和k折交叉驗證
grid_search = GridSearchCV(
estimator=model_gbt,
param_grid=param_grid,
scoring='neg_log_loss', # 評價指標為負對數損失
cv=5, # 5折交叉驗證
n_jobs=-1, # 并行計算
verbose=1 # 輸出詳細進度信息
)
# 訓練模型
grid_search.fit(X_train, y_train)
# 使用最優參數訓練模型
best_model = grid_search.best_estimator_
代碼通過網格搜索 (GridSearchCV) 對Gradient Boosting分類模型的超參數進行優化,并通過5折交叉驗證選擇出最優的模型參數
shap值計算整理
import shap
explainer = shap.TreeExplainer(best_model)
# 計算shap值為numpy.array數組
shap_values_numpy = explainer.shap_values(X)
shap_values_df = pd.DataFrame(shap_values_numpy, columns=X.columns)
shap_values_df.head()
計算模型的SHAP值,并將其轉換為DataFrame格式,方便后續進行自定義繪圖分析
基礎繪圖
# 繪制散點圖,x軸是'age'特征,y軸是SHAP值
plt.figure(figsize=(6, 4),dpi=1200)
plt.scatter(df['age'], shap_values_df['age'], s=10)
# 添加shap=0的橫線
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
plt.xlabel('Age', fontsize=12)
plt.ylabel('SHAP value for\nAge', fontsize=12)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.savefig("SHAP Dependence Plot_1.pdf", format='pdf',bbox_inches='tight')
plt.show()
繪制了一個基礎的SHAP依賴圖,其中x軸代表特征“age”(年齡),y軸代表該特征的SHAP值,即年齡對模型預測的影響大小,散點圖展示了不同年齡的SHAP值,黑色虛線表示SHAP值為0的基準線,表示在該點年齡對預測沒有顯著正負影響,此圖幫助直觀地理解特征“age”對模型預測結果的影響方向和程度,當然更具體的解釋參考文章——復現SCI文章 SHAP 依賴圖可視化以增強機器學習模型的可解釋性
通過擬合曲線與交點標注繪圖
import seaborn as sns
from scipy.optimize import fsolve
# 繪制散點圖
plt.figure(figsize=(8, 6), dpi=300)
plt.scatter(df['age'], shap_values_df['age'], s=20, label='SHAP values', alpha=0.7)
# 添加LOWESS擬合曲線
sns.regplot(x=df['age'], y=shap_values_df['age'], scatter=False, lowess=True, color='lightcoral', label='LOWESS Curve')
# 使用 LOWESS 數據生成擬合曲線
lowess_data = sns.regplot(x=df['age'], y=shap_values_df['age'], scatter=False, lowess=True, color='lightcoral')
line = lowess_data.get_lines()[0] # 擬合線條對象
x_fit = line.get_xdata() # LOWESS 擬合線的 x 軸數據
y_fit = line.get_ydata() # LOWESS 擬合線的 y 軸數據
# 找出所有與 y=0 相交的 x 值
def find_zero_crossings(x_fit, y_fit):
crossings = []
for i in range(1, len(y_fit)):
if (y_fit[i-1] < 0 and y_fit[i] > 0) or (y_fit[i-1] > 0 and y_fit[i] < 0):
# 使用插值法找到 x_fit 和 y_fit 中 y 值接近 0 的 x 值
crossing = fsolve(lambda x: np.interp(x, x_fit, y_fit), x_fit[i])[0]
crossings.append(crossing)
return crossings
x_intercepts = find_zero_crossings(x_fit, y_fit)
# 在圖中標注所有的 x_intercepts
for x_intercept in x_intercepts:
plt.axvline(x=x_intercept, color='blue', linestyle='--', label=f'Intersection at Age = {x_intercept:.2f}')
plt.text(x_intercept, 0.2, f'Age = {x_intercept:.2f}', color='blue', fontsize=10, verticalalignment='bottom')
# 添加shap=0的橫線
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1, label='SHAP = 0')
# 添加圖例
plt.legend()
# 設置標簽和標題
plt.xlabel('Age', fontsize=12)
plt.ylabel('SHAP value for\nAge', fontsize=12)
# 去除上和右的邊框線
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# 保存并顯示圖像
plt.savefig("SHAP Dependence Plot_with_Multiple_Intersections.pdf", format='pdf', bbox_inches='tight')
plt.show()
在這幅圖中,通過LOWESS擬合曲線和SHAP解釋圖來深入分析年齡(Age)對模型預測結果的影響。下面著重解釋擬合線與交點的含義:LOWESS擬合曲線LOWESS曲線(紅色曲線)是局部加權回歸曲線,它用來平滑數據中的非線性趨勢。在這幅圖中,它表示了年齡對目標變量的平均影響趨勢,從曲線中可以看出,隨著年齡的變化,SHAP值也隨之波動,通過這條擬合曲線,可以識別出不同年齡區間對模型預測的不同貢獻
交點
圖中用藍色虛線標注了兩個交點,分別表示SHAP值曲線與y=0的交點,這兩個交點表示特定年齡時,SHAP值為零,即在這些點上,年齡對模型的影響由負向或正向逐漸轉變
通過擬合曲線和交點,可以更直觀地理解特征“年齡”對模型預測結果的非線性影響,尤其是這些交點,它們揭示了年齡在特定區間中對目標變量的關鍵變化點,有助于理解模型如何處理年齡這個特征,以及如何做出更精準的解釋
多特征SHAP依賴圖:通過擬合曲線與交點分析模型特征貢獻
# 定義繪制 SHAP 依賴圖的函數
def plot_shap_dependence(feature_list, df, shap_values_df, file_name="SHAP_Dependence_Plots.pdf"):
fig, axs = plt.subplots(2, 3, figsize=(12, 8), dpi=1200)
plt.subplots_adjust(hspace=0.4, wspace=0.4)
# 循環繪制每個特征的 SHAP 依賴圖
for i, feature in enumerate(feature_list):
row = i // 3 # 行號
col = i % 3 # 列號
ax = axs[row, col]
# 繪制散點圖,x軸是特征值,y軸是SHAP值
ax.scatter(df[feature], shap_values_df[feature], s=10, alpha=0.7)
# 添加 LOWESS 擬合曲線
sns.regplot(x=df[feature], y=shap_values_df[feature], scatter=False, lowess=True, color='lightcoral', ax=ax)
# 使用 LOWESS 數據生成擬合曲線
lowess_data = sns.regplot(x=df[feature], y=shap_values_df[feature], scatter=False, lowess=True, color='lightcoral', ax=ax)
line = lowess_data.get_lines()[0] # 擬合線條對象
x_fit = line.get_xdata() # LOWESS 擬合線的 x 軸數據
y_fit = line.get_ydata() # LOWESS 擬合線的 y 軸數據
# 找出所有與 y=0 相交的 x 值
def find_zero_crossings(x_fit, y_fit):
crossings = []
for i in range(1, len(y_fit)):
if (y_fit[i-1] < 0 and y_fit[i] > 0) or (y_fit[i-1] > 0 and y_fit[i] < 0):
crossing = fsolve(lambda x: np.interp(x, x_fit, y_fit), x_fit[i])[0]
crossings.append(crossing)
return crossings
x_intercepts = find_zero_crossings(x_fit, y_fit)
# 在圖中標注所有的 x_intercepts
for x_intercept in x_intercepts:
ax.axvline(x=x_intercept, color='blue', linestyle='--') # 標注虛線
ax.text(x_intercept, 0.1, f'{x_intercept:.2f}', color='black', fontsize=10, verticalalignment='bottom') # 將文本標注顏色改為淡紅色
# 添加shap=0的橫線
ax.axhline(y=0, color='black', linestyle='-.', linewidth=1)
# 設置x和y軸標簽
ax.set_xlabel(feature, fontsize=12)
ax.set_ylabel(f'SHAP value for\n{feature}', fontsize=12)
# 隱藏頂部和右側的脊柱
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# 隱藏最后一個空圖表的坐標軸 (若畫布未關閉)
axs[1, 2].axis('off')
# 保存為 PDF 文件
plt.savefig(file_name, format='pdf', bbox_inches='tight')
plt.show()
# 使用函數繪制 age、trestbps、chol、thalach、oldpeak 的 shap 依賴圖
feature_list = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']
plot_shap_dependence(feature_list, df, shap_values_df)
通過繪制多個特征的SHAP依賴圖,結合LOWESS擬合曲線與交點標注,分析各特征對模型預測的影響,當然,也可以采用其他擬合曲線,而不僅限于LOWESS,這里主要是基于參考文獻中所使用的LOWESS擬合曲線進行分析,這里的解釋同前文Age解釋原理相同
本文章轉載微信公眾號@Python機器學習AI