import pandas as pd
# Load the iris dataset
iris = load_iris()
# Create a DataFrame from the iris dataset
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['target'] = iris.target
from sklearn.model_selection import train_test_split
X = iris_df.drop(['target'],axis=1)
y = iris_df['target']
X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
# 然后將訓練集進一步劃分為訓練集和驗證集
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.125,stratify=y_temp, random_state=42) # 0.125 x 0.8 = 0.1
加載鳶尾花數據集,將鳶尾花數據集分割為訓練集、驗證集和測試集,具體過程是:從整個數據集中抽取20%作為測試集;剩余的80%數據中抽取12.5%作為驗證集,最終驗證集占整個數據集的10%,訓練集占整個數據集的70%,這一步驟為后續使用XGBoost進行多分類模型的訓練和評估奠定基礎
模型建立
import xgboost as xgb
# 更新后的多分類模型參數
params_xgb = {
'learning_rate': 0.02, # 學習率
'booster': 'gbtree', # 提升方法
'objective': 'multi:softprob', # 損失函數,多分類使用softmax
'num_class': 3, # 類別數,鳶尾花數據集有三類
'max_leaves': 127, # 每棵樹的葉子節點數量
'verbosity': 1, # 輸出信息的詳細程度
'seed': 42, # 隨機種子
'nthread': -1, # 并行運算的線程數量
'colsample_bytree': 0.6, # 每棵樹隨機選擇的特征比例
'subsample': 0.7, # 每次迭代時隨機選擇的樣本比例
'early_stopping_rounds': 100, # 早停輪數
'eval_metric': 'mlogloss' # 評估指標,多分類使用mlogloss
}
# 創建并訓練多分類模型
model_xgb = xgb.XGBClassifier(**params_xgb)
model_xgb.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
配置并訓練一個XGBoost多分類模型來預測鳶尾花數據集的類別,使用特定的參數設置和早停機制,比如這里是三分類,如果需要更改為其他多分類(如四分類),只需修改參數num_class,并相應調整其他參數以達到最優模型效果
評價報告
from sklearn.metrics import classification_report
# 預測測試集
y_pred = model_xgb.predict(X_test)
# 輸出模型報告, 查看評價指標
print(classification_report(y_test, y_pred))
混淆矩陣熱力圖
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# 輸出混淆矩陣
conf_matrix = confusion_matrix(y_test, y_pred)
# 繪制熱力圖
plt.figure(figsize=(10, 7))
sns.heatmap(conf_matrix, annot=True, annot_kws={'size':15}, fmt='d', cmap='YlGnBu')
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.title('Confusion matrix heat map', fontsize=15)
plt.show()
Shap實現
創建Shap解釋器
import shap
# 創建SHAP解釋器
explainer = shap.Explainer(model_xgb)
# 計算SHAP值
shap_values = explainer(X_test)
print("shap值維度;",shap_values.shape)
shap_values
可以看見針對測試集的shap值的維度為(30,4,3),也就是計算的每個類別的SHAP值,對于每個樣本,SHAP值將是一個矩陣,其中每個元素表示一個特征對某個類別的貢獻
繪制Shap解釋圖
# 特征標簽
labels = X_train.columns
# 設置 matplotlib 的全局字體配置
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'
plt.rcParams['font.size'] = 13
# 提取每個類別的 SHAP 值
shap_values_class_1 = shap_values.values[:, :, 0]
shap_values_class_2 = shap_values.values[:, :, 1]
shap_values_class_3 = shap_values.values[:, :, 2]
shap_values_class_1
# 繪制 SHAP 總結圖,使用viridis配色方案
plt.figure()
plt.title('class_1')
shap.summary_plot(shap_values_class_1, X_val, feature_names=labels, plot_type="dot", cmap="viridis")
plt.show()
這里針對鳶尾花的測試集第一個類別0進行shap解釋圖繪制
繪制Shap依賴圖
shap.dependence_plot('sepal length (cm)', shap_values_class_1, X_val, interaction_index='sepal width (cm)')
plt.show()
針對鳶尾花的測試集第一個類別0的特征sepal length (cm)、sepal width (cm)進行shap依賴圖繪制
繪制Shap力圖
# 選擇一個樣本索引進行解釋
sample_index = 1
expected_value = explainer.expected_value[0] # 需要指定個類別的基準值,這里是第一個類別
# 獲取單個樣本的 SHAP 值
sample_shap_values = shap_values_class_1[sample_index]
# 繪制 SHAP 解釋力圖 (Force Plot)
shap.force_plot(expected_value, sample_shap_values, X_val.iloc[sample_index], matplotlib=True)
# 顯示繪圖
plt.show()
shap力圖解釋同樣會選擇數據集以及類別,但是還會多一個應選擇的基準值,比如這里選擇的第一個類比那基準值也要選擇第一個類比的基準值
生成Shap交互作用圖
shap_interaction_values = explainer.shap_interaction_values(X_val)
# 提取每個類別的值
shap_interaction_values_class_1 = shap_interaction_values[:, :, :, 0] # 類別1
shap_interaction_values_class_2 = shap_interaction_values[:, :, :, 1] # 類別2
shap_interaction_values_class_3 = shap_interaction_values[:, :, :, 2] # 類別3
# 繪制 SHAP 交互值的總結圖
plt.figure()
shap.summary_plot(shap_interaction_values_class_1, X_val, feature_names=labels)
plt.show()
計算的交互值比Shap值維度多一,同理得提取每一個類比的交互值,具體怎么提取參考這個三分類代碼
生成Shap熱圖
expected_value = explainer.expected_value[0] # 需要指定個類別的基準值,這里是第一個類別
# 創建 shap.Explanation 對象
shap_explanation = shap.Explanation(values=shap_values_class_1[0:10, :],
base_values= expected_value,
data=X_val.iloc[0:10, :],
feature_names=X_val.columns)
# 繪制熱圖
plt.figure()
shap.plots.heatmap(shap_explanation)
plt.show()
代碼繪制的是第一個類比測試集前10個樣本的Shap熱圖,和力圖一樣得確定一個基準值,對哪一個類比做就采用哪一個類比的基準值
創建一個使用Tkinter庫構建的GUI應用程序,旨在通過按鈕、標簽、組合框和文本框等組件實現數據上傳、選擇目標特征、設置分類任務的類別數、選擇數據集、選擇顏色方案、選擇特征、輸入樣本索引、輸入樣本范圍等功能,從而對XGBoost分類模型進行訓練并生成相關的解釋圖,并確保將這些圖保存為高DPI的PDF文件,以保證可視化效果不受損失
本文章轉載微信公眾號@Python機器學習AI