import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
import warnings
warnings.filterwarnings("ignore")
df = pd.read_csv('Dataset.csv')
# 劃分特征和目標(biāo)變量
X = df.drop(['target'], axis=1)
y = df['target']
# 劃分訓(xùn)練集和測(cè)試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42, stratify=df['target'])
df.head()

導(dǎo)入數(shù)據(jù)并預(yù)處理:讀取數(shù)據(jù)集,提取特征和目標(biāo)變量、劃分訓(xùn)練集和測(cè)試集:將數(shù)據(jù)集拆分為訓(xùn)練集和測(cè)試集,便于后續(xù)模型的訓(xùn)練和評(píng)估、該數(shù)據(jù)集中使用了14個(gè)特征變量,這些特征包括了患者的基本信息和多項(xiàng)醫(yī)療檢測(cè)指標(biāo),以下是每個(gè)特征的詳細(xì)說(shuō)明:

age(年齡):患者的年齡,單位為年

sex(性別):患者性別,1代表男性,0代表女性

cp(胸痛類型):患者的胸痛類型,有4種可能取值:1:典型心絞痛、2:非典型心絞痛、3:非心絞痛、4:無(wú)癥狀

trestbps(靜息血壓):患者入院時(shí)的靜息血壓,單位為mm Hg

chol(膽固醇):血清膽固醇值,單位為mg/dl

fbs(空腹血糖):空腹血糖值是否大于120 mg/dl,1為真,0為假

restecg(靜息心電圖結(jié)果):靜息心電圖的結(jié)果,有3個(gè)可能取值:0:正常、1:存在ST-T波異常(T波反轉(zhuǎn)或ST段升高或降低超過(guò)0.05 mV)、2:顯示可能或確定的左心室肥大

thalach(最大心率):運(yùn)動(dòng)測(cè)試中達(dá)到的最大心率

exang(運(yùn)動(dòng)誘發(fā)型心絞痛):是否在運(yùn)動(dòng)時(shí)誘發(fā)心絞痛,1為是,0為否

oldpeak(運(yùn)動(dòng)引發(fā)的ST段抑制值):與靜息時(shí)相比的運(yùn)動(dòng)引發(fā)的ST段抑制

slope(ST段峰值的斜率):運(yùn)動(dòng)高峰時(shí)ST段的斜率,有3種取值:1:上升、2:平坦、3:下降

ca(主要血管的數(shù)量):使用熒光鏡檢查染色的主要血管數(shù)量,取值為0到3

thal(地中海貧血癥):血液中的地中海貧血類型,有3個(gè)可能取值:0:正常、1:固定缺陷、2:可逆缺陷

num(診斷結(jié)果):目標(biāo)變量,表示是否診斷出心臟病,0為小于50%直徑縮小(無(wú)心臟病),1為大于50%直徑縮小(有心臟病)

模型構(gòu)建

import xgboost as xgb
from sklearn.model_selection import GridSearchCV

# XGBoost模型參數(shù)
params_xgb = {
'learning_rate': 0.02, # 學(xué)習(xí)率,控制每一步的步長(zhǎng),用于防止過(guò)擬合。典型值范圍:0.01 - 0.1
'booster': 'gbtree', # 提升方法,這里使用梯度提升樹(shù)(Gradient Boosting Tree)
'objective': 'binary:logistic', # 損失函數(shù),這里使用邏輯回歸,用于二分類任務(wù)
'max_leaves': 127, # 每棵樹(shù)的葉子節(jié)點(diǎn)數(shù)量,控制模型復(fù)雜度。較大值可以提高模型復(fù)雜度但可能導(dǎo)致過(guò)擬合
'verbosity': 1, # 控制 XGBoost 輸出信息的詳細(xì)程度,0表示無(wú)輸出,1表示輸出進(jìn)度信息
'seed': 42, # 隨機(jī)種子,用于重現(xiàn)模型的結(jié)果
'nthread': -1, # 并行運(yùn)算的線程數(shù)量,-1表示使用所有可用的CPU核心
'colsample_bytree': 0.6, # 每棵樹(shù)隨機(jī)選擇的特征比例,用于增加模型的泛化能力
'subsample': 0.7, # 每次迭代時(shí)隨機(jī)選擇的樣本比例,用于增加模型的泛化能力
'eval_metric': 'logloss' # 評(píng)價(jià)指標(biāo),這里使用對(duì)數(shù)損失(logloss)
}

# 初始化XGBoost分類模型
model_xgb = xgb.XGBClassifier(**params_xgb)

# 定義參數(shù)網(wǎng)格,用于網(wǎng)格搜索
param_grid = {
'n_estimators': [100, 200, 300, 400, 500], # 樹(shù)的數(shù)量
'max_depth': [3, 4, 5, 6, 7], # 樹(shù)的深度
'learning_rate': [0.01, 0.02, 0.05, 0.1], # 學(xué)習(xí)率
}

# 使用GridSearchCV進(jìn)行網(wǎng)格搜索和k折交叉驗(yàn)證
grid_search = GridSearchCV(
estimator=model_xgb,
param_grid=param_grid,
scoring='neg_log_loss', # 評(píng)價(jià)指標(biāo)為負(fù)對(duì)數(shù)損失
cv=5, # 5折交叉驗(yàn)證
n_jobs=-1, # 并行計(jì)算
verbose=1 # 輸出詳細(xì)進(jìn)度信息
)

# 訓(xùn)練模型
grid_search.fit(X_train, y_train)

# 輸出最優(yōu)參數(shù)
print("Best parameters found: ", grid_search.best_params_)
print("Best Log Loss score: ", -grid_search.best_score_)

# 使用最優(yōu)參數(shù)訓(xùn)練模型
best_model = grid_search.best_estimator_

使用XGBoost分類器構(gòu)建一個(gè)二分類模型,并通過(guò)網(wǎng)格搜索和5折交叉驗(yàn)證(GridSearchCV)來(lái)優(yōu)化模型參數(shù),首先,定義了XGBoost模型的初始參數(shù)(如學(xué)習(xí)率、提升方法、樹(shù)的葉子節(jié)點(diǎn)數(shù)等)和參數(shù)搜索網(wǎng)格(如樹(shù)的數(shù)量、深度、學(xué)習(xí)率等),然后,通過(guò)網(wǎng)格搜索在不同的參數(shù)組合下訓(xùn)練模型,目標(biāo)是最小化負(fù)對(duì)數(shù)損失(logloss)作為評(píng)估指標(biāo),最后,輸出最優(yōu)的參數(shù)組合,并使用該參數(shù)重新訓(xùn)練模型

SHAP值計(jì)算

import shap
explainer = shap.TreeExplainer(best_model)
# 計(jì)算shap值為numpy.array數(shù)組
shap_values_numpy = explainer.shap_values(X)
# 計(jì)算shap值為Explanation格式
shap_values_Explanation = explainer(X)

使用shap.TreeExplainer創(chuàng)建一個(gè)解釋器對(duì)象explainer,它用于解釋基于樹(shù)模型(如XGBoost、決策樹(shù)等)的預(yù)測(cè),分別計(jì)算SHAP值并返回為numpy數(shù)組格式和更復(fù)雜的Explanation格式,后者適合更高級(jí)的分析和圖形展示

XGBoost模型特征重要性可視化

# 獲取XGBoost模型的特征貢獻(xiàn)度(重要性)
feature_importances = best_model.feature_importances_
# 將特征和其重要性一起排序
sorted_indices = np.argsort(feature_importances)[::-1] # 逆序排列,重要性從高到低
sorted_features = X_train.columns[sorted_indices]
sorted_importances = feature_importances[sorted_indices]
# 繪制按重要性排序的特征貢獻(xiàn)性柱狀圖
plt.figure(figsize=(10, 6), dpi=1200)
plt.barh(sorted_features, sorted_importances, color='steelblue')
plt.xlabel('Importance', fontsize=14)
plt.ylabel('Features', fontsize=14)
plt.title('Sorted Feature Importance', fontsize=16)
plt.gca().invert_yaxis()
plt.savefig("Sorted Feature Importance.pdf", format='pdf',bbox_inches='tight')
# 顯示圖表
plt.show()

通過(guò)提取XGBoost模型的內(nèi)置特征重要性,生成并展示了基于模型自身計(jì)算的特征排名,幫助理解每個(gè)特征對(duì)模型預(yù)測(cè)結(jié)果的貢獻(xiàn);與后續(xù)基于SHAP值的特征排名不同,它僅反映模型的分裂結(jié)構(gòu)

基于SHAP numpy數(shù)組格式的特征重要性總結(jié)圖

# 繪制SHAP值總結(jié)圖(Summary Plot)
plt.figure(figsize=(10, 5), dpi=1200)
shap.summary_plot(shap_values_numpy, X, plot_type="bar", show=False)
plt.title('SHAP_numpy Sorted Feature Importance')
plt.savefig("SHAP_numpy Sorted Feature Importance.pdf", format='pdf',bbox_inches='tight')
plt.tight_layout()
plt.show()

使用shap_values_numpy(即基于SHAP值的numpy數(shù)組格式)繪制特征重要性總結(jié)圖,展示各個(gè)特征在整體模型預(yù)測(cè)中的貢獻(xiàn),細(xì)心的讀者可發(fā)現(xiàn)兩種方法的排名不一致,這是因?yàn)閄GBoost基于樹(shù)分裂計(jì)算特征重要性,而SHAP值根據(jù)每個(gè)特征對(duì)個(gè)體預(yù)測(cè)的影響進(jìn)行評(píng)估,因此兩者方法不同導(dǎo)致排名差異

基于SHAP Explanation格式的特征重要性條形圖

plt.figure(figsize=(10, 5), dpi=1200)
shap.plots.bar(shap_values_Explanation, show=False)
plt.title('SHAP_Explanation Sorted Feature Importance')
plt.savefig("SHAP_Explanation Sorted Feature Importance.pdf", format='pdf',bbox_inches='tight')
plt.tight_layout()
plt.show()

使用shap_values_Explanation(即SHAP值的Explanation格式)繪制特征重要性條形圖,展示每個(gè)特征對(duì)模型預(yù)測(cè)的貢獻(xiàn)大小,并保存為PDF文件,該格式提供了更豐富的特征信息,便于更深入的分析和可視化,

這張圖展示了部分特征的SHAP值重要性,但并未完全展示所有特征,較不重要的特征已被合并為“Sum of 4 other features”,如果希望展示更多特征,可以通過(guò)設(shè)置max_display參數(shù),調(diào)整展示的特征數(shù)量,以便觀察更完整的特征貢獻(xiàn)排名

# 設(shè)置 max_display 值
max_display = 13
plt.figure(figsize=(10, 5), dpi=1200)

# 創(chuàng)建 SHAP 值條形圖,并使用 max_display 參數(shù)限制最多顯示的特征數(shù)量
shap.plots.bar(shap_values_Explanation, max_display=max_display, show=False)
plt.title(f'SHAP Explanation Sorted Feature Importance (Top {max_display})')
plt.savefig(f"SHAP_Explanation_Sorted_Feature_Importance_Top_{max_display}.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()

基于SHAP Explanation的單個(gè)樣本特征重要性與實(shí)際數(shù)據(jù)可視化

plt.figure(figsize=(10, 5), dpi=1200)
# 創(chuàng)建 SHAP 值條形圖,展示數(shù)據(jù)
shap.plots.bar(shap_values_Explanation[1], show_data=True, show=False, max_display=13)
plt.title('SHAP Explanation for Instance (Feature Importance with Data)')
plt.savefig("SHAP_Explanation_Instance_Feature_Importance_with_Data.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()

使用shap_values_Explanation[1](即第一個(gè)樣本的SHAP值)繪制特征重要性條形圖,并通過(guò)show_data=True顯示每個(gè)特征的具體數(shù)值,最后將圖表保存為PDF文件,該圖展示了某個(gè)具體樣本的各個(gè)特征如何正負(fù)向影響模型的預(yù)測(cè)結(jié)果,紅色表示正貢獻(xiàn),藍(lán)色表示負(fù)貢獻(xiàn)

基于SHAP Explanation的單個(gè)樣本瀑布圖可視化

plt.figure(figsize=(10, 5), dpi=1200)
# 繪制第1個(gè)樣本的 SHAP 瀑布圖,并設(shè)置 show=False 以避免直接顯示
shap.plots.waterfall(shap_values_Explanation[1], show=False, max_display=13)
# 保存圖像為 PDF 文件
plt.savefig("SHAP_Waterfall_Plot_Sample_1.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()

使用shap_values_Explanation[1](即第一個(gè)樣本的SHAP值)繪制了該樣本的瀑布圖,通過(guò)waterfall函數(shù)展示各個(gè)特征如何逐步影響最終的模型預(yù)測(cè)值,設(shè)置max_display=13限制最多展示13個(gè)特征,該瀑布圖其中紅色表示正向貢獻(xiàn),藍(lán)色表示負(fù)向影響,最終累積影響得出模型的預(yù)測(cè)值為1.543,其中E[f(X)] = -0.172是模型的基準(zhǔn)值,表示在沒(méi)有特征信息的情況下,模型的平均預(yù)測(cè)輸出,它代表了模型對(duì)所有樣本的整體預(yù)測(cè)傾向

基于SHAP Explanation值的按性別分組特征重要性可視化

#  'sex' 列包含性別信息,0 代表女性,1 代表男性
sex = ["Women" if df.loc[i, "sex"] == 0 else "Men" for i in range(df.shape[0])]

plt.figure(figsize=(10, 5), dpi=1200)
# 使用 SHAP 的 cohorts 方法根據(jù) sex 進(jìn)行分組并繪制條形圖,限制顯示最多13個(gè)特征
shap.plots.bar(shap_values_Explanation.cohorts(sex).abs.mean(0), max_display=13, show=False)
plt.title('SHAP Explanation Sorted Feature Importance by Sex')
plt.savefig("SHAP_Explanation_Sorted_Feature_Importance_by_Sex.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()

使用shap_values_Explanation按sex(性別,原始特征為分類數(shù)據(jù))特征分組,通過(guò)SHAP的cohorts方法計(jì)算并繪制各個(gè)特征對(duì)不同性別群體模型預(yù)測(cè)的平均貢獻(xiàn)度,并限制顯示13個(gè)最重要的特征,結(jié)果展示了按性別(男性和女性)分組后,特征對(duì)模型預(yù)測(cè)的重要性。無(wú)論是男性還是女性,cp(胸痛類型)、thal(地中海貧血類型)和ca(主要血管數(shù)量)都是最重要的特征,且貢獻(xiàn)度相近。然而,某些特征的貢獻(xiàn)在性別之間存在差異,如oldpeak在男性中的貢獻(xiàn)更顯著,這表明模型在預(yù)測(cè)不同性別群體時(shí),部分特征的影響力存在明顯差異,右下角數(shù)值分別代表各類別的樣本量

基于SHAP Explanation值的按自動(dòng)分組年齡的特征重要性可視化

# 例如,將年齡劃分為三類:青年(<30歲),中年(30-60歲),老年(>60歲)
age_groups = ["Young" if df.loc[i, "age"] < 30 else "Middle-aged" if df.loc[i, "age"] <= 60 else "Senior" for i in range(df.shape[0])]
# 使用 SHAP 的 cohorts 方法根據(jù) age_groups 進(jìn)行分組
v = shap_values_Explanation.cohorts(age_groups).abs.mean(0)
plt.figure(figsize=(10, 5), dpi=1200)
# 繪制 SHAP 條形圖
shap.plots.bar(v, show=False, max_display=13)
plt.title('SHAP Explanation Sorted Feature Importance by Automatically Grouped Age')
plt.savefig("SHAP_Explanation_Sorted_Feature_Importance_by_Grouped_Age.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()

使用了shap_values_Explanation,根據(jù)年齡(age)這一連續(xù)性特征將樣本分為青年、中年和老年三類,利用SHAP的cohorts方法計(jì)算各組的平均特征貢獻(xiàn)度,并繪制條形圖。該圖展示了在不同年齡段中,特征對(duì)模型預(yù)測(cè)的平均重要性,顯示了特征影響隨年齡變化的差異,結(jié)果顯示特征thal和cp在各年齡段的貢獻(xiàn)度較為接近,但chol(膽固醇)在老年群體中的影響顯著高于中年和青年,說(shuō)明不同年齡段特征對(duì)模型預(yù)測(cè)的影響存在差異

基于SHAP Explanation值的特征聚類可視化(cutoff=0.5)

# 計(jì)算 clustering 結(jié)果和 shap_values_Explanation
clustering = shap.utils.hclust(X, y)
plt.figure(figsize=(10, 5), dpi=1200)
shap.plots.bar(shap_values_Explanation,
clustering=clustering,
clustering_cutoff=0.5,
show=False, max_display=13)
plt.title('SHAP Explanation with Clustering (cutoff=0.5)')
plt.savefig("SHAP_Explanation_with_Clustering.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()

使用shap_values_Explanation,結(jié)合層次聚類方法(hclust),根據(jù)特征對(duì)模型的貢獻(xiàn)進(jìn)行聚類,設(shè)置了聚類截?cái)嘀担╟lustering_cutoff=0.5),以展示相關(guān)特征的分組情況,通過(guò)條形圖展示了前13個(gè)特征的平均SHAP值貢獻(xiàn),顯示了模型中每個(gè)特征的相對(duì)重要性,聚類可視化的作用是通過(guò)層次聚類分析特征之間的相關(guān)性,幫助識(shí)別哪些特征在對(duì)模型預(yù)測(cè)產(chǎn)生類似影響,通過(guò)設(shè)置聚類截?cái)嘀担╟utoff=0.5),我們可以將影響力相似的特征分組,從而更好地理解模型的決策結(jié)構(gòu)以及特征之間的相互作用,從當(dāng)前的可視化結(jié)果來(lái)看,沒(méi)有明確顯示出類似的特征被分組在一起

該圖為生成模擬數(shù)據(jù)基于SHAP值的特征重要性,結(jié)合層次聚類(clustering_cutoff=0.5)對(duì)相似特征進(jìn)行了分組,圖中的feature_0和feature_3被劃分在同一組(有一個(gè)灰色的聚類括號(hào)),這表明這兩個(gè)特征在模型中對(duì)預(yù)測(cè)的影響非常相似,屬于相關(guān)性較強(qiáng)的特征,相比之前未能顯示明顯相關(guān)特征的圖,這個(gè)圖表通過(guò)聚類更好地揭示了特征之間的相似性,有助于更清晰地理解哪些特征在模型決策中發(fā)揮了類似的作用,當(dāng)特征被聚類在同一組時(shí),說(shuō)明它們對(duì)模型的預(yù)測(cè)有類似的貢獻(xiàn),可能在數(shù)據(jù)上表現(xiàn)出較高的相關(guān)性或冗余性,這意味著這些特征可能描述了相似的現(xiàn)象,或者它們?cè)谀P椭刑峁┑莫?dú)立信息量較低。在模型優(yōu)化中,發(fā)現(xiàn)這些相關(guān)特征可以幫助我們減少冗余特征,簡(jiǎn)化模型,或者進(jìn)一步探索這些特征之間的關(guān)系

基于SHAP Explanation值的特征散點(diǎn)圖可視化

# 指定特征的名稱為 'cp'(胸痛類型)
feature_name = 'cp'
# 找到指定特征的索引
feature_index = shap_values_Explanation.feature_names.index(feature_name)
plt.figure(figsize=(10, 5), dpi=1200)
# 使用 SHAP 的 scatter 方法繪制指定特征的散點(diǎn)圖
shap.plots.scatter(shap_values_Explanation[:, feature_index], show=False)
plt.title(f'SHAP Scatter Plot for Feature: {feature_name}')
plt.savefig(f"SHAP_Scatter_Plot_{feature_name}.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()

使用shap_values_Explanation,繪制特定特征cp(胸痛類型)的SHAP散點(diǎn)圖,通過(guò)shap.plots.scatter展示該特征對(duì)模型預(yù)測(cè)的貢獻(xiàn)值,并保存為PDF文件,圖表橫軸顯示特征cp的不同取值,縱軸顯示該特征的SHAP值(即它對(duì)模型輸出的影響),可以發(fā)現(xiàn)cp的取值越高,其對(duì)應(yīng)的SHAP值越大,表明較高的cp值(如4)對(duì)模型的正向預(yù)測(cè)影響更大,而較低的cp值(如1)則對(duì)模型的負(fù)向預(yù)測(cè)影響更顯著

基于SHAP Explanation值的力圖匯總可視化

# 初始化 JS 庫(kù)
shap.initjs()
# 使用 force_plot 方法可視化所有樣本的解釋
shap.force_plot(explainer.expected_value, shap_values_Explanation.values, X)

首先初始化SHAP的JavaScript庫(kù) (shap.initjs()),然后使用force_plot方法可視化所有樣本的解釋結(jié)果,通過(guò)傳入模型的期望輸出值(explainer.expected_value)以及所有樣本的shap_values_Explanation.values,展示整體模型的預(yù)測(cè)解釋力圖(force plot),并為每個(gè)樣本的特征貢獻(xiàn)進(jìn)行匯總展示,生成的力圖是交互式的,用戶可以通過(guò)點(diǎn)擊和懸停在圖中的元素上,查看每個(gè)特征對(duì)模型預(yù)測(cè)的正向或負(fù)向影響,從而更直觀地理解模型的決策過(guò)程

基于SHAP Explanation值的單個(gè)樣本決策圖可視化

# 獲取模型的期望輸出值(平均預(yù)測(cè)值)
expected_value = explainer.expected_value
# 選擇第1個(gè)樣本的 SHAP 值
shap_values = shap_values_numpy[1]
# 決策圖的特征名展示
features_display = X
# 繪制 SHAP 決策圖
plt.figure(figsize=(10, 5), dpi=1200)
shap.decision_plot(expected_value, shap_values, features_display, show=False)
# 保存圖像為 PDF
plt.savefig("shap_decision_plot_samples.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()

使用shap_values_numpy[1](即第1個(gè)樣本的SHAP值),并結(jié)合模型的期望輸出值(explainer.expected_value)生成了一個(gè)SHAP決策圖,通過(guò)shap.decision_plot展示各個(gè)特征對(duì)模型預(yù)測(cè)結(jié)果的貢獻(xiàn),并保存為PDF文件,決策圖顯示了各個(gè)特征如何一步步影響模型的最終輸出值,紅線向右偏表示正向影響,向左偏表示負(fù)向影響,幫助理解模型如何基于各特征做出決策

基于SHAP Explanation值的錯(cuò)分類樣本決策圖可視化

# 獲取模型的期望輸出值(平均預(yù)測(cè)值)
expected_value = explainer.expected_value
# 計(jì)算預(yù)測(cè)值:根據(jù) SHAP 值求和加上期望值,再將結(jié)果通過(guò)閾值判斷為分類輸出
y_pred = (shap_values_numpy.sum(1) + expected_value) > 0
# 計(jì)算錯(cuò)分類的樣本
misclassified = y_pred != y
# 決策圖展示的特征
features_display = X
# 繪制 SHAP 決策圖,帶有錯(cuò)分類的樣本高亮
plt.figure(figsize=(10, 5), dpi=1200) # 設(shè)置畫(huà)布大小和分辨率
shap.decision_plot(expected_value, shap_values_numpy, features_display,
link='logit', highlight=misclassified, show=False)
# 保存圖像為 PDF 文件
plt.savefig("shap_decision_plot_with_misclassified.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()

使用shap_values_numpy計(jì)算模型預(yù)測(cè)的SHAP值,并結(jié)合模型的期望輸出值(expected_value)生成了一個(gè)決策圖。通過(guò)shap.decision_plot展示所有樣本的特征貢獻(xiàn),并高亮顯示了錯(cuò)分類的樣本,最終生成的圖表展示了分類模型中哪些樣本被錯(cuò)誤預(yù)測(cè),并保存為PDF文件,決策圖展示了各特征對(duì)模型預(yù)測(cè)的影響軌跡,藍(lán)色和紅色線條分別代表對(duì)模型輸出的負(fù)向和正向貢獻(xiàn),錯(cuò)分類的樣本被高亮顯示,幫助分析哪些特征對(duì)這些錯(cuò)誤預(yù)測(cè)影響最大,從而可以進(jìn)一步優(yōu)化模型或調(diào)整特征

基于SHAP Explanation值的僅錯(cuò)分類樣本決策圖可視化

# 獲取錯(cuò)分類樣本的索引
misclassified_indices = misclassified[misclassified].index

# 過(guò)濾出錯(cuò)分類樣本的 SHAP 值
shap_values_misclassified = shap_values_numpy[misclassified_indices]

# 保留原始特征名稱,過(guò)濾出錯(cuò)分類樣本的特征
features_display_misclassified = X.loc[misclassified_indices]

# 繪制 SHAP 決策圖,顯示錯(cuò)分類的樣本
plt.figure(figsize=(10, 5), dpi=1200) # 設(shè)置畫(huà)布大小和分辨率
shap.decision_plot(expected_value,
shap_values_misclassified,
link='logit',
highlight=True, # 高亮顯示這些樣本
feature_names=list(X.columns), # 轉(zhuǎn)換為列表以避免 TypeError
show=False)
plt.savefig("shap_decision_plot_misclassified_only.pdf", format='pdf', bbox_inches='tight')
plt.tight_layout()
plt.show()

使用shap_values_numpy,通過(guò)提取錯(cuò)分類樣本的SHAP值,并結(jié)合模型的期望輸出值(expected_value),生成了錯(cuò)分類樣本的決策圖。通過(guò)shap.decision_plot高亮顯示僅錯(cuò)分類的樣本,并展示了每個(gè)特征對(duì)這些樣本模型預(yù)測(cè)的影響,幫助我們深入分析哪些特征導(dǎo)致了錯(cuò)誤預(yù)測(cè),并為后續(xù)模型調(diào)整或特征工程提供參考

文章轉(zhuǎn)自微信公眾號(hào)@Python機(jī)器學(xué)習(xí)AI

上一篇:

?使用Keras函數(shù)式API進(jìn)行深度學(xué)習(xí)

下一篇:

如何用SHAP解讀集成學(xué)習(xí)Stacking中的基學(xué)習(xí)器和元學(xué)習(xí)器以及整體模型貢獻(xiàn)
#你可能也喜歡這些API文章!

我們有何不同?

API服務(wù)商零注冊(cè)

多API并行試用

數(shù)據(jù)驅(qū)動(dòng)選型,提升決策效率

查看全部API→
??

熱門(mén)場(chǎng)景實(shí)測(cè),選對(duì)API

#AI文本生成大模型API

對(duì)比大模型API的內(nèi)容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉(zhuǎn)化潛力

25個(gè)渠道
一鍵對(duì)比試用API 限時(shí)免費(fèi)

#AI深度推理大模型API

對(duì)比大模型API的邏輯推理準(zhǔn)確性、分析深度、可視化建議合理性

10個(gè)渠道
一鍵對(duì)比試用API 限時(shí)免費(fèi)