深度生存学习:Cox PH与Deep Cox PH模型构建及多时间点SHAP深度可解释性分析

机器学习大模型算法

picture.image

✨ 欢迎关注Python机器学习AI ✨

本节介绍: Cox PH与Deep Cox PH模型构建及多时间点SHAP深度可解释性分析 ,数据采用模拟数据无任何现实意义 ,作者根据个人对机器学习的理解进行代码实现与图表输出,仅供参考。 完整 数据和代码将在稍后上传至交流群,成员可在交流群中获取下载。需要的朋友可关注公众文末提供的获取方式。文末提供高效的AI工具~!点赞、推荐参与文末包邮赠书~!

✨ 论文信息 ✨

picture.image

Cox比例风险模型(Cox PH)是经典的生存分析模型,通过假设各个特征对风险的影响是线性的比例关系,广泛应用于医学和社会科学领域。Deep Cox PH模型则是在传统Cox模型基础上引入深度神经网络,用非线性函数替代线性风险函数,使模型能够捕捉更复杂的特征交互和非线性关系,从而提升生存时间预测的精度和表现。两者的关系可以理解为:Deep Cox PH是对经典Cox PH模型的扩展与升级,兼具可解释性和更强的拟合能力,适用于复杂高维数据的生存分析任务,接下来,将分别实现Cox PH模型和Deep Cox PH模型,并利用多时间点的风险预测进行SHAP值计算,实现对模型在不同时间点的可解释性分析

✨ 基础代码 ✨

  
import pandas as pd  
import numpy as np  
import matplotlib.pyplot as plt  
plt.rcParams['font.family'] = 'Times New Roman'  
plt.rcParams['axes.unicode_minus'] = False  
import warnings  
# 忽略所有警告  
warnings.filterwarnings("ignore")  
  
features = pd.read_excel("2025-12-23公众号Python机器学习AI-features.xlsx")  
outcomes = pd.read_excel("2025-12-23公众号Python机器学习AI-outcomes.xlsx")  
  
from sklearn.model_selection import train_test_split  
  
# 识别分类特征(cat_feats)和连续特征(num_feats)  
cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']  # 分类特征列表  
num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',      # 连续特征列表  
             'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',   
             'glucose', 'bun', 'urine', 'adlp', 'adls']  
  
  
# 将数据集拆分为训练集、验证集和测试集  
# 第一次拆分:80%训练+验证,20%测试  
X_train_val, X_test, y_train_val, y_test = train_test_split(features, outcomes, test_size=0.2, random_state=1)  
  
# 第二次拆分:从训练+验证集中划分出25%作为验证集,剩余75%作为训练集  
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.25, random_state=1)   
  
from auton_survival.preprocessing import Preprocessor  
  
# 初始化预处理器  
# 对分类特征采用忽略策略(不填充),对数值特征采用均值填充  
preprocessor = Preprocessor(cat_feat_strat='ignore', num_feat_strat='mean')   
  
# 拟合预处理器(包括填充和标准化)到训练数据,并进行one-hot编码  
transformer = preprocessor.fit(X_train, cat_feats=cat_feats, num_feats=num_feats,   
                              one_hot=True, fill_value=-1)  
  
# 对训练集、验证集和测试集进行相同的预处理转换  
X_train = transformer.transform(X_train)  
X_val = transformer.transform(X_val)  
X_test = transformer.transform(X_test)

从包含特征变量的 features 和包含生存状态及时间(event, time)的 outcomes 两个Excel文件中读取数据,划分训练、验证和测试集,并针对分类和数值特征分别进行预处理(分类特征忽略缺失,数值特征用均值填充、标准化及one-hot编码),为后续生存分析模型训练做好数据准备

  
from auton_survival.estimators import SurvivalModel  
from auton_survival.metrics import survival_regression_metric  
from sklearn.model_selection import ParameterGrid  
  
# 定义模型调参的参数范围  
param_grid = {'l2': [1e-3, 1e-4]}  # L2正则化系数的取值列表  
params = ParameterGrid(param_grid)  # 生成所有参数组合的网格  
  
# 定义模型评估时的时间点,取训练集中有事件发生的时间的10%到100%分位点,共10个时间点  
times = np.quantile(y_train['time'][y_train['event'] == 1], np.linspace(0.1, 1, 10)).tolist()  
  
# 超参数调优:遍历所有参数组合训练模型并评估验证集表现  
models = []  
for param in params:  
    # 初始化Cox比例风险模型,设置随机种子和L2正则化强度  
    model = SurvivalModel('cph', random_seed=2, l2=param['l2'])  
  
    # 训练模型  
    model.fit(X_train, y_train)  
  
    # 预测验证集的生存概率  
    predictions_val = model.predict_survival(X_val, times)  
  
    # 计算验证集上的积分布雷尔分数(Integrated Brier Score,IBS),评价模型预测性能  
    metric_val = survival_regression_metric('ibs', y_val, predictions_val, times, y_train)  
  
    # 记录当前模型及其验证集指标  
    models.append([metric_val, model])  
  
# 从所有训练好的模型中选取验证集IBS最小的模型作为最终模型  
metric_vals = [i[0] for i in models]  
first_min_idx = metric_vals.index(min(metric_vals))  
model = models[first_min_idx][1]  
# 使用最终模型预测测试集的生存概率  
predictions_te = model.predict_survival(X_test, times)  
  
# 计算测试集上的Brier分数和时间依赖的C指数,用于评估模型性能  
results = dict()  
results['Brier Score'] = survival_regression_metric('brs', outcomes=y_test, predictions=predictions_te,   
                                                    times=times, outcomes_train=y_train)  
results['Concordance Index'] = survival_regression_metric('ctd', outcomes=y_test, predictions=predictions_te,   
                                                          times=times, outcomes_train=y_train)

picture.image

通过网格搜索调优Cox比例风险模型的L2正则化参数,选取验证集上积分布雷尔分数(IBS)最优的模型,最终用该模型预测测试集生存概率,并计算测试集的Brier分数和时间依赖的C指数来评估模型性能,可视化展示了Cox比例风险模型在不同时间点上的预测性能,左图为各时间点的Brier分数(预测误差——越小越好),右图为各时间点的C指数(预测一致性——越解决1越好)

  
# 定义模型调参的参数网格  
param_grid = {  
    'bs': [100, 200],                      # batch size 批量大小  
    'learning_rate': [1e-4, 1e-3],        # 学习率  
    'layers': [[100], [100, 100]]          # 神经网络层结构,隐藏层节点数列表  
}  
  
# 生成所有参数组合  
params = ParameterGrid(param_grid)  
  
# 定义用于调参和评估模型的时间点  
# 取训练集中事件发生时间的分位数作为时间点,分成10个等分  
times = np.quantile(y_train['time'][y_train['event'] == 1], np.linspace(0.1, 1, 10)).tolist()  
  
# 进行超参数调优  
models = []  
for param in params:  
    # 根据当前参数实例化生存模型,这里是Deep Cox PH模型  
    model = SurvivalModel(  
        'dcph',  
        random_seed=0,  
        bs=param['bs'],  
        learning_rate=param['learning_rate'],  
        layers=param['layers']  
    )  
  
    # 训练模型,使用训练集特征和标签  
    model.fit(X_train, y_train)  
  
    # 在验证集上预测生存概率,使用预定义的时间点  
    predictions_val = model.predict_survival(X_val, times)  
  
    # 计算验证集上的积分Brier分数(Integrated Brier Score),用于评估模型性能  
    metric_val = survival_regression_metric('ibs', y_val, predictions_val, times, y_train)  
  
    # 保存当前模型及对应的性能指标  
    models.append([metric_val, model])  
  
# 根据验证集的指标,选择性能最优的模型(指标值最小)  
metric_vals = [i[0] for i in models]  
first_min_idx = metric_vals.index(min(metric_vals))  
model = models[first_min_idx][1]

picture.image

通过遍历多个超参数组合训练Deep Cox PH模型,使用验证集上的积分Brier分数(IBS)评估性能,最终选出IBS最优的模型,可以发现在模拟数据集上 Deep Cox PH模型略优于 Cox比例风险模型,但是这里并没有进行k折交叉验证,接下来引入k折帮助提高模型性能稳定性

  
from auton_survival.experiments import SurvivalRegressionCV  
  
# 取训练集中事件发生时间的分位数作为时间点,分成10个等分  
times = np.quantile(y_train['time'][y_train['event'] == 1], np.linspace(0.1, 1, 10)).tolist()  
  
# 定义超参数搜索空间  
param_grid = {  
    'bs': [100, 200],                      # batch size  
    'learning_rate': [1e-4, 1e-3],        # 学习率  
    'layers': [[100], [100, 100]]          # 神经网络结构  
}  
  
# 创建交叉验证实验对象,设置 num_folds=5 表示5折交叉验证  
experiment = SurvivalRegressionCV(  
    model='dcph',                         # Deep Cox PH 模型  
    num_folds=5,                         # 5折交叉验证  
    hyperparam_grid=param_grid,          # 超参数空间  
    random_seed=0  
)  
  
# 训练并调参,metric 选择积分 Brier 分数 'ibs'  
best_model = experiment.fit(X_train, y_train, times, metric='ibs')  
  
# best_model 是在所有训练数据上用最优超参数训练好的模型

picture.image

使用5折交叉验证和积分Brier分数(IBS)作为指标,在给定超参数空间内自动调优Deep Cox PH模型返回在全部训练数据上用最优超参数训练得到的最佳模型,可以发现和上面未加入k折的 Deep Cox PH模型性能一致,说明二者使用的同一套超参数,值得说明的一点是

数据预处理通常是在数据集划分之后进行的:先划分训练集、验证集和测试集,然后仅基于训练集拟合预处理器,再将该预处理器应用于验证集和测试集,以避免数据泄露。然而,当后续采用k折交叉验证时,数据会被多次划分,每次训练时的预处理操作依然是基于剩余训练折进行的

这就带来了一个潜在的数据泄露问题——在每个fold中,理想且更规范的做法是对训练折单独拟合预处理器,并使用该预处理器转换对应验证折的数据,确保验证折数据在预处理阶段不被训练折信息污染。但是现实中很多实践和代码实现往往忽略了这一点,直接使用在整个训练集上拟合的预处理器,造成一定程度的信息泄露

虽然这种泄露在多数情况下对模型性能影响有限,常被忽略,但从严格的机器学习规范和结果可信度角度来看,建议在交叉验证中对每个fold独立进行数据预处理,以最大限度地避免潜在的信息泄露,提升模型评估的公正性和可信度

  
# 使用训练好的模型预测风险函数(risk):  
# 输入预处理后的特征 x 和指定时间点 times,  
# 输出每个样本在这些时间点对应的风险值(通常表示事件发生的概率或风险程度)  
out_risk = best_model.predict_risk(X_test, times)  
out_risk 
  
array([[0.03825413, 0.06818813, 0.10561338, ..., 0.35528396, 0.42902793,          0.59149506],  
       [0.0704159 , 0.12384284, 0.18856331, ..., 0.56032173, 0.64974893,          0.81286471],  
       [0.0362972 , 0.06475226, 0.10039574, ..., 0.34036598, 0.41210654,          0.57198463],  
       ...,  
       [0.04123487, 0.07341064, 0.11352281, ..., 0.37741653, 0.45393327,          0.61958686],  
       [0.05165692, 0.09156759, 0.14081987, ..., 0.44947155, 0.53329167,          0.70398948],  
       [0.11055979, 0.1911492 , 0.28485833, ..., 0.73246374, 0.81425301,          0.93206122]])

使用训练好的模型对测试集样本在指定时间点预测风险值,反映每个样本在这些时间点发生事件的风险程度

  
# 输出每个样本在这些时间点对应的生存概率(即事件未发生的概率)  
out_survival = best_model.predict_survival(X_test, times)  
out_survival
  
array([[0.96174587, 0.93181187, 0.89438662, ..., 0.64471604, 0.57097207,          0.40850494],  
       [0.9295841 , 0.87615716, 0.81143669, ..., 0.43967827, 0.35025107,          0.18713529],  
       [0.9637028 , 0.93524774, 0.89960426, ..., 0.65963402, 0.58789346,          0.42801537],  
       ...,  
       [0.95876513, 0.92658936, 0.88647719, ..., 0.62258347, 0.54606673,          0.38041314],  
       [0.94834308, 0.90843241, 0.85918013, ..., 0.55052845, 0.46670833,          0.29601052],  
       [0.88944021, 0.8088508 , 0.71514167, ..., 0.26753626, 0.18574699,          0.06793878]])

使用训练好的模型预测测试集中每个样本在指定时间点的生存概率,即事件尚未发生的可能性,当然也可以返回给个时间节点下的AUC值

  
For time 6.000:  
  Concordance Index: 0.7633  
  Brier Score: 0.0615  
  ROC AUC: 0.7678  
For time 10.000:  
  Concordance Index: 0.7480  
  Brier Score: 0.1003  
  ROC AUC: 0.7582  
For time 18.000:  
  Concordance Index: 0.7172  
  Brier Score: 0.1395  
  ROC AUC: 0.7349  
For time 31.000:  
  Concordance Index: 0.7015  
  Brier Score: 0.1715  
  ROC AUC: 0.7205  
For time 58.000:  
  Concordance Index: 0.6856  
  Brier Score: 0.1955  
  ROC AUC: 0.7082  
For time 106.000:  
  Concordance Index: 0.6842  
  Brier Score: 0.2069  
  ROC AUC: 0.7151  
For time 194.000:  
  Concordance Index: 0.6797  
  Brier Score: 0.2133  
  ROC AUC: 0.7179  
For time 333.600:  
  Concordance Index: 0.6763  
  Brier Score: 0.2114  
  ROC AUC: 0.7232  
For time 622.000:  
  Concordance Index: 0.6734  
  Brier Score: 0.2006  
  ROC AUC: 0.7318  
For time 1910.000:  
  Concordance Index: 0.6693  
  Brier Score: 0.1406  
  ROC AUC: 0.8096

输出展示了基于网格搜索和K折交叉验证训练的Deep Cox PH模型,在不同时间点上测试集的时间依赖性预测性能指标,包括C指数、Brier分数和ROC AUC,接下来就可以对这个模型进行解释

picture.image

picture.image

picture.image

picture.image

picture.image

这里基于模型的predict_risk函数对不同时间点的风险值进行SHAP解释(展示排名前5特征),解释对象是样本在各时间点的事件风险(风险程度),并通过从训练和测试集中各选取距离时间点最近的代表性样本共50个作为背景和解释数据,提升计算效率;当然,也可以改用predict_survival解释生存概率

可以发现各个时间节点的排名第一特征均为“age”(年龄),说明年龄是模型在所有时间点中对风险预测影响最大的变量,且其SHAP值表明年龄越大,个体的风险越高,对生存时间的预测作用显著(当然这个可以通过SHAP依赖图更明显的看出 已经计算出了SHAP值对于其它SHAP可视化参考公众号往期文章绘制即可 并不复杂 所以这里不一一给出)

在各个时间节点中,除了年龄(age)始终作为对模型风险预测影响最大的特征外,排名第二的重要特征主要包括“dzclass_Coma”(昏迷状态)和“ca_no”(无癌症状态)。这表明患者的昏迷状况和癌症状态在不同时间点均对生存风险具有显著影响,是继年龄之后的关键预测因素

✨ 往期生存学习 ✨

期刊复现:基于递归特征筛选的XGBoost、RSF、COX、GBSA与SSVM生存分析模型性能提升
期刊复现:机器学习预后ExtraSurvivalTrees+SHAP构建可解释生存模型

JAMA子刊复现:COX回归结合SHAP方法解析特征对生存预测的影响

当然,公众号中还有更多机器学习期刊实战技巧,您可以通过历史文章进行检索和阅读,关注公众号,点击“发信息”>“历史文章”即可搜索公众号所有文章信息

图片

✨ 该文章案例 ✨

在上传至交流群的文件中,像往期文章一样,将对案例进行逐步分析,确保读者能够达到最佳的学习效果。内容都经过详细解读,帮助读者深入理解模型的实现过程和数据分析步骤,从而最大化学习成果。

同时,结合提供的免费AI聚合网站进行学习,能够让读者在理论与实践之间实现融会贯通,更加全面地掌握核心概念。

✨ 介绍 ✨

本节介绍到此结束,有需要学习数据分析和Python机器学习相关的朋友欢迎到淘宝店铺:Python机器学习AI,下方提供淘宝店铺二维码获取作者的公众号合集。截至目前为止,合集已包含近300多篇文章,购买合集的同时,还将提供免费稳定的AI大模型使用。

更新的内容包含数据、代码、注释和参考资料。 作者仅分享案例项目,不提供额外的答疑服务。项目中将提供详细的代码注释和丰富的解读,帮助您理解每个步骤 。 获取 前请咨询,避免不必要的问题。

✨ 群友反馈 ✨

picture.image

✨ 淘宝店铺 ✨

picture.image

请大家打开淘宝扫描上方的二维码,进入店铺,获取更多Python机器学习和AI相关的内容 ,希望能为您的学习之路提供帮助!

✨ AI工具推荐 ✨

picture.image

✨ 赠书活动 ✨

picture.image 支持知识分享,畅享学习乐趣!特别感谢清华大学出版社 对本次赠书活动的鼎力支持!即日起,只需点赞、推荐、转发 此文章,作者将从后台随机抽取一位幸运儿,免费包邮赠送清华出版社提供的《Al Agent智能体智能体开发实践》这本精彩书籍📚!

💡 赶快参与,一键三连,说不定你就是那位幸运读者哦!

往期推荐

期刊配图:模型SHAP解释特征类别柱状图、饼图与蜂巢图的组合展示

期刊复现:基于自动机器学习的预测模型构建及其残差和部分依赖分析

期刊复现:SVM、RF、BDT、DT、Logit五大模型堆叠31种组合情况最优模型的SHAP解释

期刊复现:单变量特征降维与共线性分析结合RFE集成排名进行特征筛选下的组合拳流程

期刊复现:SVM、RF、BDT、DT、Logit五大模型堆叠31种组合情况优化与最优模型选择可视化

期刊复现:基于相关系数与AUC值优化特征选择剔除冗余特征精简模型(附代码)

期刊复现:如何正确使用LASSO进行二分类特征选择?避开常见误区,掌握实用技巧

期刊复现:融合聚类与多分类转二分类的亚型可解释SHAP机器学习模型构建

期刊复现:基于LightGBM、XGBoost与RF的Lasso回归堆叠模型在连续和分类特征上的模型解释

期刊复现:基于LightGBM、XGBoost与RF的堆叠模型贝叶斯优化调参与Lasso回归元模型,结合10倍交叉验证

picture.image

如果你对类似于这样的文章感兴趣。

欢迎关注、点赞、转发~

个人观点,仅供参考

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
火山引擎大规模机器学习平台架构设计与应用实践
围绕数据加速、模型分布式训练框架建设、大规模异构集群调度、模型开发过程标准化等AI工程化实践,全面分享如何以开发者的极致体验为核心,进行机器学习平台的设计与实现。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论