Competing Risks Interpretable Survival Prediction with Neural Additive Models - A model that enhances interpretability in survival analysis.
AI In Healthcare
In the past decade, AI has made numerous advances in the field of healthcare. Sophisticated algorithms are available to analyze and interpret medical images such as X-rays and CT scans, accelerating disease diagnosis. AI’s ability to synthesize and understand large amounts of patient data (genetics, biomarkers, clinical history) has led to personalized treatment plans and individualized therapies. Drug discovery, disease risk prediction and assessment, and survival analysis are other areas where new AI models have been developed. These AI-driven systems promise greater efficiency and accuracy in the healthcare ecosystem. However, many challenges remain for their widespread deployment and adoption.
Challenges
Modern AI models achieve very high accuracy in classification and prediction tasks. However, they suffer from several limitations.
Interpretability: Deep learning models are black boxes, and it is challenging to determine how they make a particular prediction. In a high-stakes field such as healthcare, clinicians must understand the reasoning behind a model’s output. Misinterpretation can lead to incorrect treatment plans or adverse patient outcomes.
Data privacy: Many AI models are trained on vast amounts of patient data - EHRs, images, etc. - that contain personally identifiable information and must be protected. Regulatory safeguards may need to be in place before more widespread adoption.
Regulatory frameworks: Many policies have established explainability and interpretability constraints for high-risk AI systems. As an example, Food and Drug Administration (FDA) guidance requires clear documentation of AI/ML medical device decision-making.
Our focus in this blog is on developing interpretable techniques for survival analysis, a robust framework for time-to-event prediction that is especially important in healthcare.
Survival Analysis: Primer
Survival analysis is a collection of statistical procedures for analyzing data where the outcome is time-to-event, meaning the time until an event of interest occurs. For example, clinicians utilize patient data (covariates) to predict the timing of hospital readmissions for specific medical conditions. A patient can face multiple mutually exclusive events (such as a cardiovascular event, death, etc.); these events are referred to as competing risks. Survival modelling learns functions that map a patient's covariates to the time of occurrence of an outcome (time-to-event). See the Appendix for details.
Competing risks: Multiple causes for the same outcome
Consider a clinician managing patients with advanced liver disease. The goal is to predict outcomes, but patients face more than one possible endpoint. Some will pass away while waiting for treatment; others will receive a liver transplant. These events compete with each other: once a patient receives a transplant, they exit the waiting list, and vice versa.
This is the problem of competing risks, and it appears throughout the healthcare sector:
Cancer patients may die from their malignancy or from other causes
Heart disease patients face cardiovascular events or non-cardiac mortality
Hospitalized patients may be discharged or die during their stay
Traditional survival models that predict time-to-event were not designed to handle this complexity. While deep learning has introduced powerful approaches to survival analysis, most operate as black boxes, providing predictions without explanations.
CRISP-NAM addresses this gap by handling competing risks while maintaining interpretability, and allowing practitioners to examine exactly how each feature influences each risk. Below is a figure illustrating how competing risks are modeled in survival analysis.
Figure 1: Sketch of modeling competing risks in survival analysis.
Existing Models
Many statistical and machine learning techniques have been used to capture the interactions between covariates and time-to-event variables. The following are a few classes of models that exist in the literature:
Traditional: Cox Proportional Hazards (CoxPH) is a popular model. It linearly maps covariates to a time-to-event prediction, assuming a multiplicative effect of covariates on the hazard and that this effect is constant over time.
Deep Learning: This class of models uses a feed-forward neural network to map covariates to the risk of the outcome, capturing non-linear and high-dimensional relationships between the two. DeepSurv, CoxNAM, and SurvNAM are a few examples.
Models for competing risks: Such models either follow a cause-specific hazard framework, mapping covariates to each competing risk separately, or learn the joint distribution of survival times and event types directly. DeepHit and Neural Fine-Gray are a few examples.
Limitations of existing approaches
There are multiple limitations to existing approaches:
Restrictive assumptions: Many traditional models (e.g., CoxPH) are linear and fail to capture non-linear relationships between the covariates and time-to-event variables.
Lack of interpretability: Deep learning survival models (DeepSurv) are black-box and suffer from a lack of interpretability. This makes it difficult to understand how individual features contribute to predictions for different competing risks.
Require >1 model: Cause-specific modelling requires learning multiple models - one per outcome - which can be computationally expensive.
Our solution: CRISP-NAM
To address these limitations, we developed CRISP-NAM (Competing Risks Interpretable Survival Prediction with Neural Additive Models), a novel approach for modeling competing risks. It extends the NAM architecture to enable flexible, interpretable, and non-linear modeling.
Figure 2: CRISP-NAM combines advantages from multiple modeling approaches.
CRISP-NAM - Background
CRISP-NAM belongs to the Generalized Additive Model (GAM) family. GAMs have been a statistical mainstay for decades. Rather than assuming linear covariate-outcome relationships, they allow each covariate to make its own contribution via a separate spline function. Spline functions capture non-linear relationships such as thresholds, curves, and saturation effects that linear models cannot represent. Since each covariate has its own function, and the functions are added together, GAMs provide insight into how each feature affects the outcome.
Neural Networks as Shape Functions: NAMs replace traditional spline-based functions with small neural networks, one per feature. Each network takes a single feature and models its contribution to the outcome. Training NAMs is straightforward and offers advantages such as automatic weight learning, standard regularization techniques, and GPU acceleration. Even though NAMs don't model covariate interactions, their ability to model each covariate's effect in isolation is what makes them interpretable.
Architecture
The CRISP-NAM architecture consists of 4 components:
FeatureNets: Each input feature is processed by its own dedicated neural network called FeatureNet. This sub-network is designed to learn the feature’s non-linear contribution to the overall outcome. Since it isolates the effects of each feature, it is more interpretable than existing approaches.
Risk-Specific Projections: Each covariate and competing risk pair is then linearly transformed by a dedicated projection layer to measure its contribution to the outcome. A normalization constraint ensures that contribution scales are comparable across risks, enabling fair comparison of feature importance between different event types.
Additive Risk Aggregation: Individual risk contribution to the hazard function is computed as the sum of individual feature contributions, preserving the additive nature of the original NAM model.
Cause-specific Log Hazard Calculation: A cause-specific baseline hazard captures the underlying temporal pattern of each risk independently of the covariates. Combined with the additive feature contributions, it yields the absolute probability that a patient experiences each risk within a specified time interval.
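The additive forward pass behind these components can be sketched in a few lines of numpy. This is a toy illustration, not the authors' implementation: the network sizes are arbitrary, and random weights stand in for trained parameters.

```python
import numpy as np

rng = np.random.default_rng(0)

def feature_net(x, w1, b1, w2, b2):
    """Tiny per-feature MLP: scalar input -> hidden layer -> scalar contribution."""
    h = np.maximum(0.0, np.outer(x, w1) + b1)  # (n, hidden), ReLU activation
    return h @ w2 + b2                         # (n,) contribution per patient

n_patients, n_features, hidden, n_risks = 5, 3, 8, 2
X = rng.normal(size=(n_patients, n_features))

# one FeatureNet per covariate (random weights stand in for trained ones)
nets = [(rng.normal(size=hidden), rng.normal(size=hidden),
         rng.normal(size=hidden), rng.normal()) for _ in range(n_features)]

# shared per-feature contributions, shape (n_patients, n_features)
contrib = np.stack([feature_net(X[:, j], *nets[j])
                    for j in range(n_features)], axis=1)

# risk-specific linear projections: one scalar per (feature, risk) pair
proj = rng.normal(size=(n_features, n_risks))

# additive aggregation: cause-specific log-hazard term per patient and risk
log_hazard = contrib @ proj
print(log_hazard.shape)  # (5, 2): one score per patient per competing risk
```

Because the aggregation is a plain sum of per-feature terms, each entry of `contrib` can later be read off directly as that feature's contribution, which is what makes the shape plots below possible.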
Advantages
It retains the interpretability and feature-wise transparency of NAMs, allowing for flexible and non-linear modelling in competing risk scenarios. This enables us to visualize each feature’s contribution via shape plots, enhancing interpretability.
The baseline temporal pattern is learned independently of patient covariates. This provides absolute risk predictions at clinically relevant time horizons (1-year, 5-year) and facilitates better patient care.
Limitations
By design, the model excludes feature interactions. This is intentional, as interactions make interpretation exponentially more difficult. With p features, there are O(p^2) pairwise interactions, and visualization beyond pairwise terms becomes impractical. We accept some reduction in predictive power in exchange for guaranteed interpretability.
Evaluation
The model was evaluated using a mix of discriminative and calibration metrics.
Time-Dependent Area-Under-the-Curve (TD-AUC): An extension of the standard ROC-AUC designed specifically for time-to-event (survival) data, where outcomes unfold over time. It measures how well the model ranks patients who experienced an event by time t above those who had not yet experienced the event or were censored.
\widehat{\text{AUC}}(t) = \frac{\sum_{i=1}^{n}\sum_{j=1}^{n} I\{y_j > t\}\, I\{y_i \leq t\}\, \omega_i\, I\{\hat{f}(x_j) \leq \hat{f}(x_i)\}}{\left(\sum_{i=1}^{n} I\{y_i \leq t\}\, \omega_i\right)\left(\sum_{j=1}^{n} I\{y_j > t\}\right)}
n: number of patients
x_i, x_j: covariates of patient i or j
y_i, y_j: time-to-event response for patient i or j
\hat{f}(x_i), \hat{f}(x_j): risk score assigned by the model to patient i or j
\omega_i, \omega_j: inverse probability of censoring weight for patient i or j
I\{\cdot\}: indicator function, equal to 1 if the condition inside holds and 0 otherwise
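With the censoring weights ω precomputed, the TD-AUC at a horizon t can be computed directly. The following is an unoptimized numpy sketch; the toy data and uniform weights are hypothetical, purely for illustration.

```python
import numpy as np

def td_auc(risk, time, weight, t):
    """Time-dependent AUC at horizon t.

    risk:   model risk scores f_hat(x_i)
    time:   observed time-to-event y_i
    weight: inverse-probability-of-censoring weights omega_i
    """
    cases = time <= t      # patients who experienced the event by t
    controls = time > t    # patients still event-free after t
    # weighted count of correctly ordered (case, control) pairs
    num = sum(w_i * np.sum(risk[controls] <= r_i)
              for r_i, w_i in zip(risk[cases], weight[cases]))
    den = weight[cases].sum() * controls.sum()
    return num / den

# toy example: earlier events receive higher risk scores
time = np.array([1.0, 2.0, 6.0, 8.0])
risk = np.array([0.9, 0.8, 0.2, 0.1])
weight = np.ones(4)
print(td_auc(risk, time, weight, t=5.0))  # 1.0: perfect ranking at t = 5
```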
Time-Dependent Concordance Index (TD-CI): Concordance Index (CI) measures the proportion of pairs for which the model correctly orders the survival times, among all pairs for which this can be accurately determined. Time-dependent CI computes this count within a common observation time interval, making it a good metric for assessing model performance over time.
\text{TD-CI} = \frac{\sum_{i \neq j} I\{\eta_i \leq \eta_j\}\, I\{T_i > T_j\}\, d_j}{\sum_{i \neq j} I\{T_i > T_j\}\, d_j}
i, j: patient indices
\eta_i, \eta_j: risk score assigned by the model to patient i or j
T_i, T_j: time-to-event response for patient i or j
d_i, d_j: indicates whether T_i (or T_j) has been fully observed (= 1) or censored (= 0)
I\{\cdot\}: indicator function, equal to 1 if the condition inside holds and 0 otherwise
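The pair counting behind TD-CI can be sketched as a naive O(n²) loop (ties in risk scores are counted as concordant in this sketch; the toy data is hypothetical):

```python
import numpy as np

def td_concordance(eta, T, d):
    """Proportion of comparable pairs the model ranks correctly.

    A pair (i, j) is comparable when T_i > T_j and event j was observed
    (d_j = 1); it is concordant when the model assigns the earlier event
    at least as high a risk score (eta_i <= eta_j).
    """
    num = den = 0
    n = len(eta)
    for i in range(n):
        for j in range(n):
            if i != j and T[i] > T[j] and d[j] == 1:
                den += 1
                num += eta[i] <= eta[j]
    return num / den

# toy example: earlier events receive higher risk scores -> TD-CI = 1.0
T = np.array([5.0, 3.0, 1.0])
d = np.array([1, 1, 1])
eta = np.array([0.1, 0.5, 0.9])
print(td_concordance(eta, T, d))  # 1.0
```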
Brier score (BS): This score evaluates the accuracy of a survival function at a given time t. It is the average of squared distances between the observed survival status and the predicted survival probability, making it a combined measure of calibration and discrimination.
\text{BS}(t) = \frac{1}{N} \sum_{i=1}^{N} \left(\hat{y}_i - y_i\right)^2
N: number of subjects
\hat{y}_i: predicted survival probability at time t for subject i
y_i: observed survival status at time t for subject i (1 if event-free, 0 if the event has occurred)
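As a minimal sketch (ignoring the inverse-probability-of-censoring weighting used in practice to handle censored subjects), the score at a fixed horizon reduces to a mean squared error between the predicted survival probability and the observed event-free status:

```python
import numpy as np

def brier_score(surv_prob, event_free):
    """Unweighted Brier score at a fixed horizon t.

    surv_prob:  predicted probability of surviving past t
    event_free: 1 if the subject was event-free at t, else 0
    """
    surv_prob = np.asarray(surv_prob, dtype=float)
    event_free = np.asarray(event_free, dtype=float)
    return np.mean((surv_prob - event_free) ** 2)

# a well-calibrated model scores low; a confidently wrong model scores high
print(brier_score([0.9, 0.1], [1, 0]))  # 0.01
print(brier_score([0.1, 0.9], [1, 0]))  # 0.81
```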
Support 2 dataset
The SUPPORT2 dataset originates from the Study to Understand Prognoses and Preferences for Outcomes and Risks of Treatments (SUPPORT2), a comprehensive investigation conducted across five U.S. medical centers between 1989 and 1994. This dataset contains 9,105 records of critically ill hospitalized adults, each characterized by 42 variables spanning the following categories:
Demographics: age, death, sex, income
Physiological measurements: ph, glucose, urine
Disease severity indicators: diabetes, dementia
This dataset was chosen because it contains two competing risks: death due to cancer and death due to other causes.
Results
Below are the performance metrics comparing CRISP-NAM and the state-of-the-art DeepHit model for the two competing risks: death due to cancer (Risk 1) and death due to other causes (Risk 2).
Table 1: Performance comparison between CRISP-NAM and DeepHit on the SUPPORT2 dataset. Best scores are highlighted in bold green.
CRISP-NAM performs better on Risk 1, and DeepHit performs better on Risk 2. Across additional datasets, such as the Framingham Heart Study and Primary Biliary Cholangitis (PBC), DeepHit generally achieves higher performance, while CRISP-NAM demonstrates competitive discrimination with the added benefit of interpretability. The key takeaway is that CRISP-NAM achieves performance comparable to DeepHit while providing interpretability, which is crucial for clinical adoption.
Interpretability
Feature importance plots: These plots summarize a covariate's contribution to the model's final output. In CRISP-NAM, the importance is the mean of the covariate's risk-projection vector for a particular risk; its sign indicates whether the covariate contributes positively or negatively to that risk.
Figure 4: Feature Importance for Both Competing Risks.
NOTE: The feature importance plot provides a global view of how each covariate affects the risk, and physicians can validate it against their domain knowledge. Together with the covariate-specific shape plots, it gives an overall picture of how an individual covariate affects the risk as determined by the model.
Shape plots: Plotted for every covariate, a shape plot shows the covariate's individual contribution to the overall probability of the risk. Beneath each curve, rug plots illustrate the empirical distribution of feature values, highlighting regions where data is sparse. For CRISP-NAM, each shape plot is derived from the covariate's risk-projection output per risk. This enables ranking features by their impact on each competing risk, providing valuable insight into risk-specific predictor importance.
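To see where a shape plot's curve comes from, one can sweep a single feature across a grid of values and record the corresponding per-feature network output. The sketch below uses a hypothetical network with random weights purely for illustration; in a trained model these would be learned parameters.

```python
import numpy as np

rng = np.random.default_rng(1)

# hypothetical trained per-feature network (random weights for illustration)
w1, b1 = rng.normal(size=16), rng.normal(size=16)
w2, b2 = rng.normal(size=16), rng.normal()

def feature_contribution(x):
    """Evaluate one per-feature network on a 1-D grid of feature values."""
    h = np.maximum(0.0, np.outer(x, w1) + b1)  # ReLU hidden layer
    return h @ w2 + b2

grid = np.linspace(-3, 3, 200)            # range of, e.g., a scaled lab value
shape_curve = feature_contribution(grid)  # y-values of the shape plot
print(shape_curve.shape)  # (200,): one contribution per grid point
```

Plotting `shape_curve` against `grid` (with a rug plot of observed feature values underneath) reproduces the kind of figure shown above.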
Shape Plots for Individual Features
Figure 5: Shape plots for individual features. Select dataset, risk type, and feature to view the plot.
Note: The plots may differ from those shown in the paper, as different re-runs of the model can produce different outputs.
Conclusion
CRISP-NAM demonstrates that interpretability and competitive predictive performance are not mutually exclusive, even for complex problems like competing risks survival analysis. By extending Neural Additive Models to handle multiple event types, we provide clinicians and researchers with a tool that achieves strong discrimination while fully explaining its predictions.
In high-stakes healthcare applications, understanding why a model makes its predictions is not a luxury but a requirement. CRISP-NAM makes that understanding achievable.
The source code for reproducing all experiments and implementations described in this article is available in our CRISP-NAM repository.
Citation
@inproceedings{ramachandram2025crispnam,
title={CRISP-NAM: Competing Risks Interpretable Survival Prediction with Neural Additive Models},
author={Ramachandram, Dhanesh and Raval, Ananya},
booktitle={EXPLIMED 2025 - Second Workshop on Explainable AI for the Medical Domain},
year={2025}
}
Acknowledgments
We would like to thank the Vector Institute for supporting this research, and the open-source community for providing valuable datasets, tools and frameworks that made this work possible.
Survival analysis: Primer
The Hazard Function
Central to survival analysis is the hazard function h(t), which represents the instantaneous risk of an event at time t, conditional on survival up to that point. It answers: "Given survival to time t, what is the instantaneous failure rate?"
The classic Cox Proportional Hazards model relates covariates to this hazard:
h(t \mid X) = h_0(t) \cdot \exp(\beta^\top X)
Here, h_0(t) is the baseline hazard common to all subjects, and \exp(\beta^\top X) is a subject-specific multiplier based on covariates. A coefficient \beta_i = 0.7 corresponds to a hazard ratio of \exp(0.7) \approx 2, indicating a doubling of risk.
The proportional hazards assumption: This model assumes that the hazard ratio between any two subjects remains constant over time. Covariates may double your risk relative to another subject, but that multiplicative effect persists regardless of the time horizon.
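A small numpy sketch makes the multiplicative structure concrete; the coefficient and covariate values here are hypothetical, chosen only to mirror the \beta = 0.7 example above. Note that the baseline hazard h_0(t) cancels when comparing two subjects:

```python
import numpy as np

# hypothetical coefficients for two covariates
beta = np.array([0.7, -0.2])

x_a = np.array([1.0, 0.5])  # subject A's covariates
x_b = np.array([0.0, 0.5])  # subject B differs only in the first covariate

# h(t|x_a) / h(t|x_b) = exp(beta . (x_a - x_b)); h_0(t) cancels,
# so the ratio is the same at every time t (proportional hazards)
hazard_ratio = np.exp(beta @ (x_a - x_b))
print(hazard_ratio)  # exp(0.7) ≈ 2.01: subject A faces roughly double the risk
```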
The Competing Risks Extension
With competing risks, we require separate hazard functions for each event type. For event typek, the cause-specific hazard is defined as:
\lambda_k(t|\mathbf{x}) = \lim_{\Delta t \to 0} \frac{P(t \leq T \leq t + \Delta t, E=k \mid T \geq t, \mathbf{x})}{\Delta t}
This represents the instantaneous rate of experiencing eventkat timet, among subjects still at risk.
The key insight is that these hazards interact through the risk set. A high hazard for one event can effectively reduce the observed rate of another, not through a causal mechanism, but because experiencing one event removes subjects from the risk pool for competing events.
Cumulative Incidence Function
To compute absolute risk predictions, we require the cumulative incidence function (CIF):
F_k(t|\mathbf{x}) = \int_0^t S(u|\mathbf{x}) \cdot \lambda_k(u|\mathbf{x}) \, du
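With constant toy hazards for two competing causes, this integral can be approximated numerically. The hazard values below are hypothetical, not tied to any dataset; the example also illustrates that the cause-specific incidences partition all-cause failure, F_1 + F_2 = 1 - S.

```python
import numpy as np

# discrete time grid and constant cause-specific hazards for one subject
t = np.linspace(0.0, 5.0, 5001)
dt = t[1] - t[0]
lam1 = np.full_like(t, 0.10)  # hazard of cause 1
lam2 = np.full_like(t, 0.05)  # hazard of cause 2

# overall survival S(t) = exp(-integral of the total hazard up to t)
S = np.exp(-np.concatenate([[0.0], np.cumsum((lam1 + lam2)[:-1] * dt)]))

# cumulative incidence F_k(t) = integral of S(u) * lambda_k(u) du (Riemann sum)
F1 = np.cumsum(S * lam1 * dt)
F2 = np.cumsum(S * lam2 * dt)

# the incidences partition all-cause failure: F1 + F2 ≈ 1 - S
print(np.max(np.abs((F1 + F2) - (1.0 - S))))  # small discretization error
```

Because the hazards are constant here, F_1 is exactly twice F_2 at every time point, matching the 2:1 ratio of the cause-specific hazards.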