混淆矩陣(Confusion Matrix)是機器學習中評估分類模型性能的重要工具。通過混淆矩陣,可以直觀地了解模型在各個類別上的表現,包括正確分類和錯誤分類的樣本數量?;诨煜仃?,我們可以計算準確率、精確率、召回率、F1分數以及真正率和假正率等多個評估指標,用于評估分類模型的性能。

一、混淆矩陣

混淆矩陣(Confusion Matrix)是什么?混淆矩陣是一個表格,用于描述分類模型的預測結果與實際標簽之間的關系。

對于一個二分類問題,混淆矩陣是一個2×2的矩陣。

對于多分類問題,混淆矩陣的大小為類別數乘以類別數。

混淆矩陣的評估指標有哪些?混淆矩陣可用于計算準確率、精確率、召回率、F1分數以及真正率和假正率等多個評估指標,這些指標共同構成了評估分類模型性能的完整體系。

1. 準確率(Accuracy)

準確率是模型正確分類的樣本數占總樣本數的比例。

對于多分類問題,準確率同樣適用,只需將TP、TN、FP、FN替換為對應類別的數量總和。

2. 精確率(Precision)

精確率是針對預測為正類的樣本,模型預測正確的比例。

對于多分類問題,可以計算每個類別的精確率。

3. 召回率(Recall)

召回率是針對實際為正類的樣本,模型預測正確的比例。

同樣,對于多分類問題,可以計算每個類別的召回率。

4. F1分數(F1 Score)

F1分數是精確率和召回率的調和平均數,用于綜合評估模型的性能。

對于多分類問題,可以計算每個類別的F1分數,或者計算宏平均(Macro-average)和微平均(Micro-average)F1分數。

5. 真正率(True Positive Rate, TPR)和假正率(False Positive Rate, FPR)

真正率也稱為靈敏度(Sensitivity)或召回率(Recall)。

假正率也稱為1-特異度(1-Specificity)。

二、二分類問題

二分類問題的混淆矩陣是什么?對于二分類問題,混淆矩陣是一個2×2的表格,用于描述分類模型預測結果與實際標簽之間的關系,包括真正類(TP)、假正類(FP)、假負類(FN)和真負類(TN)四種情況。

Python中,使用sklearn.metrics中的confusion_matrix函數計算了實際標簽y_true與預測標簽y_pred之間的混淆矩陣,并利用seaborn庫的heatmap函數以及matplotlib.pyplot庫的相關函數對混淆矩陣進行了可視化展示。

from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# 假設y_true是實際標簽,y_pred是預測標簽
y_true = [0, 1, 1, 0, 1, 0, 1, 0, 0, 1]
y_pred = [0, 1, 0, 0, 1, 0, 1, 1, 0, 1]

# 計算混淆矩陣
cm = confusion_matrix(y_true, y_pred)

# 使用seaborn繪制混淆矩陣
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

三、多分類問題

多分類問題的混淆矩陣是什么?多分類問題的混淆矩陣是一個表格,其行表示實際類別,列表示預測類別,每個單元格的值表示實際類別與預測類別相匹配的樣本數量。

Python中,使用seaborn和matplotlib庫,基于給定的實際標簽數組y_true和預測標簽數組y_pred,生成并可視化了一個三分類問題的混淆矩陣熱力圖。

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# 假設y_true是實際標簽數組, y_pred是預測標簽數組
y_true = [0, 1, 2, 2, 0, 1, 0, 2, 1, 0] # 示例實際標簽
y_pred = [0, 2, 1, 2, 0, 0, 0, 1, 2, 0] # 示例預測標簽

# 生成混淆矩陣
conf_mat = confusion_matrix(y_true, y_pred)

# 使用seaborn繪制熱力圖
sns.heatmap(conf_mat, annot=True, cmap='Blues', xticklabels=['Class 0', 'Class 1', 'Class 2'], yticklabels=['Class 0', 'Class 1', 'Class 2'])
plt.xlabel('Predicted Class')
plt.ylabel('True Class')
plt.title('Confusion Matrix')
plt.show()

本文章轉載微信公眾號@架構師帶你玩轉AI

上一篇:

一文徹底搞懂機器學習 - 混淆矩陣(Confusion Matrix)

下一篇:

一文徹底搞懂機器學習 - 分類(Classification)
#你可能也喜歡這些API文章!

我們有何不同?

API服務商零注冊

多API并行試用

數據驅動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內容創意新穎性、情感共鳴力、商業轉化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費