什么是混淆矩阵热力图
在机器学习分类任务中,评估模型性能是至关重要的一步,其中一种直观有效的评估方法就是使用混淆矩阵,而为了更直观地展示混淆矩阵中的数据,我们可以将其可视化为热力图,那么什么是混淆矩阵热力图呢?
混淆矩阵
混淆矩阵是一个方阵,用于描述分类模型的性能,它通过展示实际分类和预测分类之间的对比来显示模型的正确和错误分类情况,混淆矩阵的每一行代表一个实际类,每一列代表一个预测类,典型的混淆矩阵如下:
- 真正类(TP): 模型正确预测的实例数量
- 假阴性(FN): 实际为正类但被错误预测为负类的实例数量
- 假阳性(FP): 实际为负类但被错误预测为正类的实例数量
- 真正负类(TN): 模型正确预测为负类的实例数量(通常不在标准的混淆矩阵中展示)
通过混淆矩阵,我们可以计算出许多评价指标,如准确率、精确率、召回率和F1分数,这些指标可以帮助我们全面了解模型的分类性能
混淆矩阵热力图
虽然混淆矩阵提供了详尽的分类结果,但其表格形式在数据量较大时不够直观,混淆矩阵热力图通过颜色深浅来表示数值大小,使得数据的分布和趋势一目了然
代码实现
数据处理
import pandas as pd
from sklearn.datasets import load_iris
# 加载鸢尾花数据集
iris = load_iris()
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['species'] = iris.target
from sklearn.model_selection import train_test_split
X = iris_df.drop(['species'], axis=1)
y = iris_df['species']
X_train, X_test, y_train, y_test = train_test_split(
X, # 特征矩阵
y, # 目标变量
test_size=0.3, # 测试集所占比例,0.3 表示 30% 的数据用于测试集
stratify=iris_df['species'], # 按 'species' 列进行分层采样,确保训练集和测试集中各类标签的比例与原数据集一致
random_state=42 # 随机种子,确保结果可重复
)
加载鸢尾花数据集,将其分割为训练集和测试集,并确保测试集中各类标签的比例与原数据集一致,以便进行后续的机器学习模型训练和评估
模型构建
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
# 初始化KNN分类器,设定k值为3
knn = KNeighborsClassifier(n_neighbors=3)
# 训练KNN分类器
knn.fit(X_train, y_train)
# 预测测试集
y_pred = knn.predict(X_test)
# 输出模型报告, 查看评价指标
print(classification_report(y_test, y_pred))
初始化并训练一个KNN分类器,使用测试集进行预测,并输出分类性能报告
混淆矩阵热力图
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# 输出混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
# 自定义标签
species_names = iris.target_names
# 绘制热力图
plt.figure(figsize=(10, 7))
sns.heatmap(conf_matrix, annot=True, annot_kws={'size':15}, fmt='d', cmap='YlGnBu',
xticklabels=species_names, yticklabels=species_names, cbar_kws={'shrink': 0.75})
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.title('Confusion Matrix of KNN Classifier on Iris Dataset', fontsize=15)
plt.show()
热力图展示了KNN分类器在鸢尾花数据集上的分类效果,其中setosa和versicolor类别被100%正确分类,而virginica类别有2个样本被错误预测为versicolor
往期推荐
小白轻松上手:一键生成SHAP解释图的GUI应用,支持多种梯度提升模型选择
特征工程进阶:暴力特征字典的构建与应用 实现模型精度质的飞跃
快速选择最佳模型:轻松上手LightGBM、XGBoost、CatBoost和NGBoost!
无网络限制!同步官网所有功能!让编程小白也能轻松上手进行代码编写!!!
微信号|deep_ML
欢迎添加作者微信进入Python、ChatGPT群
进群请备注Python或AI进入相关群
无需科学上网、同步官网所有功能、使用无限制
如果你对类似于这样的文章感兴趣。
欢迎关注、点赞、转发~
个人观点,仅供参考