加載Iris數(shù)據(jù)集,對類別標(biāo)簽進(jìn)行編碼,劃分訓(xùn)練集和測試集,并對特征進(jìn)行歸一化處理,為接下來構(gòu)建模型做準(zhǔn)備
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
# 初始化并訓(xùn)練決策樹模型
clf = DecisionTreeClassifier(random_state=42)
clf.fit(train_X, train_y)
# 預(yù)測
predictions = clf.predict(test_X)
# 輸出分類報(bào)告
report = classification_report(test_y, predictions, target_names=unique_classes)
print(report)
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
# 可視化決策樹
plt.figure(figsize=(20, 20), dpi=200)
plot_tree(
clf,
feature_names=df.columns[0:4], # 使用 df 的列名作為特征名稱
class_names=unique_classes,
filled=True
)
plt.savefig('決策樹.png')
plt.show()
這里針對鳶尾花分類任務(wù)訓(xùn)練得到的決策樹,可視化展示了決策樹的“玻璃箱”結(jié)構(gòu),按照慣例,在一個(gè)節(jié)點(diǎn)中滿足條件的子節(jié)點(diǎn)位于左側(cè)分支
解釋:樹的根節(jié)點(diǎn)是通過對”sepal length “特征進(jìn)行條件判定,如果其值小于等于0.233,則進(jìn)入左子樹,否則進(jìn)入右子樹,其它節(jié)點(diǎn)類似,最后的葉節(jié)點(diǎn)表示決策樹的最終分類結(jié)果,每個(gè)葉節(jié)點(diǎn)包含樣本的基尼不純度(gini)和樣本數(shù)量(samples),以及每個(gè)類別的樣本數(shù)量和類別(class)
from sklearn.ensemble import RandomForestClassifier
# 初始化并訓(xùn)練隨機(jī)森林模型
rf_clf = RandomForestClassifier(random_state=42, n_estimators=20)
rf_clf.fit(train_X, train_y)
# 從隨機(jī)森林中提取一棵決策樹
estimator = rf_clf.estimators_[0] # 提取第一棵樹
# 預(yù)測
predictions = rf_clf.predict(test_X)
# 輸出分類報(bào)告
report = classification_report(test_y, predictions, target_names=unique_classes)
print(report)
在隨機(jī)森林中,n_estimators參數(shù)表示隨機(jī)森林中包含的決策樹的數(shù)量,這里為10表示存在10棵決策樹,增加其值通常可以提高模型的預(yù)測性能,因?yàn)楦嗟臉淇梢愿玫夭蹲綌?shù)據(jù)的復(fù)雜模式,性能提升會逐漸減小,但是請?zhí)岱滥P瓦^擬合
plt.figure(figsize=(20, 20), dpi=200)
plot_tree(
estimator,
feature_names=df.columns[0:4], # 使用 df 的列名作為特征名稱
class_names=unique_classes,
filled=True
)
plt.show()
這里的可視化為第一棵決策樹,接下來繪制隨機(jī)森林里面的所有決策樹
# 可視化所有提取的決策樹
fig, axes = plt.subplots(nrows=5, ncols=4, figsize=(40, 50), dpi=200)
for i in range(20):
ax = axes[i // 4, i % 4]
plot_tree(
rf_clf.estimators_[i],
feature_names=df.columns[0:7], # 使用 df 的列名作為特征名稱
class_names=unique_classes,
filled=True,
ax=ax
)
ax.set_title(f'Tree {i+1}')
plt.savefig('隨機(jī)森林.png')
plt.tight_layout()
plt.show()
這就是這個(gè)隨機(jī)森林存在的20棵決策樹的決策過程,每棵樹都會對輸入樣本進(jìn)行預(yù)測,并根據(jù)多數(shù)投票的結(jié)果來確定最終樣本的分類
文章轉(zhuǎn)自微信公眾號@Python機(jī)器學(xué)習(xí)AI