分类模型混淆矩阵优化:添加每个类别的预测准确性

机器学习大数据数据库

picture.image

✨ 欢迎关注Python机器学习AI ✨

本节介绍: 混淆矩阵可视化优化 。数据采用模拟数据,作者根据个人对机器学习的理解进行代码实现与图表输出,细节并不保证与原文一定相同,仅供参考。 详细数据和代码将在稍后上传至交流群,付费成员可在交流群中获取下载。需要的朋友可关注公众文末提供的购买方式。 购买前请咨询,避免不必要的问题。 文末点赞、推荐、转发参与免费包邮赠书~

✨ 优化结果 ✨

picture.image

可以发现在这种混淆矩阵中,相较于直接输出的混淆矩阵,加入了每个类别的预测百分比,使得模型的表现更加直观。具体而言:

  • 对角线上的百分比 代表了每个类别的预测正确率
  • 非对角线上的百分比 显示了模型对其他类别的误分类情况,即错误预测的类别以及它们在总预测中的占比

这种优化的可视化方式能够帮助更容易地评估模型的表现,尤其是在哪些类别上表现较好,在哪些类别上可能存在较多的错误分类。因此,通过这种方式,可以更准确地判断模型在多分类(如 RandomForest)和二分类(如 XGBoost)任务中的性能

✨ 代码实现 ✨

  
import pandas as pd  
import numpy as np  
import matplotlib.pyplot as plt  
plt.rcParams['font.family'] = 'Arial'  
plt.rcParams['axes.unicode_minus'] = False  
import warnings  
# 忽略所有警告  
warnings.filterwarnings("ignore")  
  
path = r"2025-4-12公众号Python机器学习AI.xlsx"  
df = pd.read_excel(path)  
from sklearn.model_selection import train_test_split  
  
# 划分特征和目标变量  
X = df.drop(['Electrical_cardioversion'], axis=1)    
y = df['Electrical_cardioversion']    
# 划分训练集和测试集  
X_train, X_test, y_train, y_test = train_test_split(  
    X,    
    y,   
    test_size=0.3,   
    random_state=42,   
    stratify=df['Electrical_cardioversion']   
)  
from xgboost import XGBClassifier  
from sklearn.model_selection import GridSearchCV, KFold  
from sklearn.metrics import accuracy_score, roc_auc_score  
  
# 定义 XGBoost 分类模型  
model_xgb = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=8)  
  
# 定义参数网格  
param_grid = {  
    'n_estimators': [50, 100, 200],  
    'max_depth': [3, 5, 7],  
    'learning_rate': [0.01, 0.1, 0.2],  
    'subsample': [0.8, 1.0],  
    'colsample_bytree': [0.8, 1.0]  
}  
  
# 使用 K 折交叉验证  
kfold = KFold(n_splits=5, shuffle=True, random_state=8)  
  
# 使用网格搜索寻找最佳参数  
grid_search = GridSearchCV(estimator=model_xgb, param_grid=param_grid, scoring='accuracy',   
                           cv=kfold, verbose=1, n_jobs=-1)  
  
# 拟合模型  
grid_search.fit(X_train, y_train)  
# 使用最优参数训练模型  
xgboost = grid_search.best_estimator_

使用 XGBoost 分类器,结合网格搜索和 K 折交叉验证,自动调整模型的超参数,选择最佳的参数配置,并训练模型进行预测

  
import seaborn as sns  
from sklearn.metrics import confusion_matrix  
from matplotlib.colors import LinearSegmentedColormap  
  
# 绘制混淆矩阵的函数  
def plot_confusion_matrix(model, X, y, label, pdf_filename="confusion_matrix.pdf"):  
    """  
    绘制混淆矩阵的函数  
  
    Parameters:  
    model : 训练好的模型 (e.g., RandomForest, SVM, etc.)  
    X : 测试数据集的特征 (例如 X_test)  
    y : 测试数据集的标签 (例如 y_test)  
    label : 模型名称 (e.g., 'Random Forest')  
    pdf_filename : 保存的 PDF 文件名(默认为"confusion_matrix.pdf")  
    """  
    # 使用训练好的模型对测试集进行预测  
    y_pred = model.predict(X)  
  
    # 生成混淆矩阵  
    cm = confusion_matrix(y, y_pred)  
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]  # 标准化  
  
    # 创建标签  
    labels = ['NO', 'YES']  
  
    # 创建自定义颜色映射(从 #ff6600 到 #ffffcc)  
    cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#ffffcc", "#ff6600"])  
  
    # 绘制改进后的混淆矩阵  
    fig, ax = plt.subplots(figsize=(10, 8))  
  
    # 绘制热图  
    sns.heatmap(  
        cm_normalized,  
        annot=False,  # 先不添加文字  
        fmt="",  
        cmap=cmap,  
        xticklabels=labels,  
        yticklabels=labels,  
        cbar=True,  
        square=True,  
        linewidths=1.5,  
        linecolor="white",  
        ax=ax,  
    )  
  
    # 添加数值和百分比  
    for i in range(cm.shape[0]):  # 遍历行  
        for j in range(cm.shape[1]):  # 遍历列  
            value = cm[i, j]  
            percentage = cm_normalized[i, j] * 100  # 获取百分比  
  
            # 在每个单元格中显示百分比和数值  
            ax.text(j + 0.5, i + 0.65, f"{percentage:.1f}%", ha="center", va="center", fontsize=18, color="black")  # 百分比显示在上方,字体稍大  
            ax.text(j + 0.5, i + 0.35, f"{value}", ha="center", va="center", fontsize=18, color="black")  # 数值显示在下方  
  
    # 设置标题和轴标签  
    plt.title(f"Confusion Matrix for {label}", fontsize=20)  
    plt.xlabel("Predicted class", fontsize=16)  
    plt.ylabel("True class", fontsize=16)  
  
    # 增大坐标轴字体  
    plt.xticks(fontsize=14)  
    plt.yticks(fontsize=14)  
  
    # 旋转x轴标签并设置水平对齐  
    plt.xticks(rotation=45, ha='right')  
  
    # 保存为PDF文件  
    plt.tight_layout()  
    plt.savefig(pdf_filename, format='pdf', bbox_inches='tight')  
  
    # 显示混淆矩阵  
    plt.show()  
  
# 示例调用方法(模型和数据传入)  
# plot_confusion_matrix(best_models['RF'], X_test, y_test, 'Random Forest', pdf_filename="confusion_matrix.pdf")

自定义一个函数,用于绘制标准化后的混淆矩阵,并在每个单元格中显示真实数值和预测百分比

  
plot\_confusion\_matrix(xgboost, X\_test, y\_test, 'XGBoost', pdf\_filename= "confusion\_matrix.pdf")

调用 plot_confusion_matrix 函数,使用训练好的 XGBoost 模型在测试集上绘制混淆矩阵,并将结果保存为 "confusion_matrix.pdf" 文件,这是一个二分类模型

picture.image

  
plot\_confusion\_matrix(clf, X\_test, y\_test, 'RF', pdf\_filename= "B\_confusion\_matrix.pdf")

调用 plot_confusion_matrix 函数,使用训练好的随机森林(RF)模型在多分类测试集上绘制混淆矩阵,并将结果保存为 "B_confusion_matrix.pdf" 文件

picture.image

✨ 该文章案例 ✨

picture.image

在上传至交流群的文件中,像往期文章一样,将对案例进行逐步分析,确保读者能够达到最佳的学习效果。内容都经过详细解读,帮助读者深入理解模型的实现过程和数据分析步骤,从而最大化学习成果。

同时,结合提供的免费AI聚合网站进行学习,能够让读者在理论与实践之间实现融会贯通,更加全面地掌握核心概念。

✨ 购买介绍 ✨

本节介绍到此结束,有需要学习数据分析和Python机器学习相关的朋友欢迎到淘宝店铺:Python机器学习AI,或添加作者微信deep_ML联系,购买作者的公众号合集。截至目前为止,合集已包含200多篇文章,购买合集的同时,还将提供免费稳定的AI大模型使用,包括但不限于ChatGPT、Deepseek、Claude等。

更新的内容包含数据、代码、注释和参考资料。 作者仅分享案例项目,不提供额外的答疑服务。项目中将提供详细的代码注释和丰富的解读,帮助您理解每个步骤 。 购买前请咨询,避免不必要的问题。

✨ 群友反馈 ✨

picture.image

✨ 淘宝店铺 ✨

picture.image

请大家打开淘宝扫描上方的二维码,进入店铺,获取更多Python机器学习和AI相关的内容,或者添加作者微信deep_ML联系 避免淘宝客服漏掉信息 ,希望能为您的学习之路提供帮助!

✨ 免费赠书 ✨

picture.image

picture.image

支持知识分享,畅享学习乐趣!特别感谢清华大学出版社 对本次赠书活动的鼎力支持!即日起,只需点赞、推荐、转发 此文章,作者将从后台随机抽取一位幸运儿,免费包邮赠送清华出版社提供的《AI Agent开发与应用:基于大模型的智能体构建》这本精彩书籍📚!

💡 赶快参与,一键三连,说不定你就是那位幸运读者哦!

往期推荐

Frontiers in Oncology:利用生存机器学习RSF模型预测患者预后模拟实现

期刊配图:相关系数+统计显著性的饼图可视化 美无需多言

期刊配图:通过SHAP组图解读模型探索不同类型特征和分组对模型的影响

机器学习在临床数据分析中的应用:从数据预处理到Web应用实现的完整流程教学

期刊配图:一区SCI常用数据缺失率展示图可视化

Psychiatry Research基于SHAP可解释性的机器学习模型构建与评估:混淆矩阵、ROC曲线、DCA与校准曲线分析

因果推断:注册行为对后续消费影响的因果推断分析案例

nature communications:基于Light GBM与随机森林结合的多模型特征选择方法

因果推断与机器学习结合:探索酒店预订取消的影响因素

期刊配图:回归模型性能与数据分布(核密度)可视化

picture.image

如果你对类似于这样的文章感兴趣。

欢迎关注、点赞、转发~

个人观点,仅供参考

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
CV 技术在视频创作中的应用
本次演讲将介绍在拍摄、编辑等场景,我们如何利用 AI 技术赋能创作者;以及基于这些场景,字节跳动积累的领先技术能力。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论