CRISP-NAM: Interpretable survival prediction

Competing Risks Interpretable Survival Prediction with Neural Additive Models - A model that enhances interpretability in survival analysis.

AI In Healthcare

In the past decade, AI has made numerous advances in the field of healthcare. Sophisticated algorithms are available to analyze and interpret medical images such as X-rays and CT scans, accelerating disease diagnosis. AI’s ability to synthesize and understand large amounts of patient data (genetics, biomarkers, clinical history) has led to personalized treatment plans and individualized therapies. Drug discovery, disease risk prediction and assessment, and survival analysis are other areas where new AI models have been developed. These AI-driven systems promise greater efficiency and accuracy in the healthcare ecosystem. However, many challenges remain for their widespread deployment and adoption.

Challenges

Current AI models achieve very high accuracy in classification and prediction. However, they suffer from multiple limitations, most notably a lack of interpretability.

Our focus in this blog is on interpretable techniques for survival analysis, a robust framework for time-to-event prediction that is especially valuable in healthcare.

Survival Analysis: Primer

Survival analysis is a collection of statistical procedures for analyzing data where the outcome is time-to-event, meaning the time until an event of interest occurs. For example, clinicians use patient data (covariates) to predict the timing of hospital readmissions for specific medical conditions. A patient may also face multiple mutually exclusive events (such as a cardiovascular event, death from another cause, etc.); these event types are referred to as competing risks. Survival modeling learns functions that map a patient's data (covariates) to the time of occurrence of a single outcome (time-to-event). See the Appendix for details.

Competing risks: Multiple causes for the same outcome

Consider a clinician managing patients with advanced liver disease. The goal is to predict outcomes, but patients face more than one possible endpoint. Some will pass away while waiting for treatment; others will receive a liver transplant. These events compete with each other: once a patient receives a transplant, they exit the waiting list, and a patient who dies while waiting can no longer receive a transplant.

This is the problem of competing risks, and it appears throughout healthcare.

Traditional survival models that predict time-to-event were not designed to handle this complexity. While deep learning has introduced powerful approaches to survival analysis, most operate as black boxes, providing predictions without explanations.

CRISP-NAM addresses this gap: it handles competing risks while maintaining interpretability, allowing practitioners to examine exactly how each feature influences each risk. The figure below illustrates how competing risks are modeled in survival analysis.

Figure 1: Sketch of modeling competing risks in survival analysis.

Existing Models

Many statistical and machine learning techniques have been used to capture the relationships between covariates and time-to-event outcomes, ranging from classical statistical models such as Cox proportional hazards (see Appendix) to deep learning approaches such as DeepHit.

Limitations of existing approaches

Classical models such as Cox proportional hazards assume linear covariate effects and proportional hazards, limiting the relationships they can capture. Deep learning models such as DeepHit relax these assumptions and achieve strong performance, but they operate as black boxes, offering little insight into how individual covariates drive each risk.

Our solution: CRISP-NAM

To address these limitations, we developed CRISP-NAM (Competing Risks Interpretable Survival Prediction with Neural Additive Models), a novel approach for modeling competing risks. It extends the Neural Additive Model (NAM) architecture to enable flexible, interpretable, and non-linear survival modeling.

Figure 2: CRISP-NAM combines advantages from multiple modeling approaches.

CRISP-NAM - Background

CRISP-NAM belongs to the Generalized Additive Model (GAM) family. GAMs have been a statistical mainstay for decades. Rather than assuming linear covariate-outcome relationships, they allow each covariate to contribute through its own spline function. Spline functions capture non-linear relationships such as thresholds, curves, and saturation effects that linear models cannot represent. Because each covariate has its own function and the functions are simply added together, GAMs provide direct insight into how each feature affects the outcome.
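Concretely, a GAM expresses the (possibly link-transformed) expected outcome as a sum of one shape function per covariate:

g\big(\mathbb{E}[y \mid \mathbf{x}]\big) = \beta_0 + f_1(x_1) + f_2(x_2) + \cdots + f_p(x_p)

Because each f_j depends on a single covariate, plotting f_j directly shows how that covariate moves the prediction.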

Neural Networks as Shape Functions: NAMs replace traditional spline-based shape functions with small neural networks, one per feature. Each network takes a single feature as input and models its contribution to the outcome. Training NAMs with standard deep learning tooling brings several advantages, such as automatic weight learning, standard regularization techniques, and GPU acceleration. Although NAMs do not model covariate interactions, modeling each covariate's effect in isolation is precisely what makes them interpretable.
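To make the idea concrete, here is a minimal NAM sketch in PyTorch. The class names (FeatureNet, NAM) and layer sizes are illustrative choices, not the CRISP-NAM implementation:

    import torch
    import torch.nn as nn

    class FeatureNet(nn.Module):
        """One small MLP per feature: maps a scalar input to its additive contribution."""
        def __init__(self, hidden: int = 32):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(1, hidden), nn.ReLU(),
                nn.Linear(hidden, hidden), nn.ReLU(),
                nn.Linear(hidden, 1),
            )

        def forward(self, x):          # x: (batch, 1)
            return self.net(x)         # (batch, 1)

    class NAM(nn.Module):
        """Neural Additive Model: a sum of per-feature shape functions plus a bias."""
        def __init__(self, num_features: int, hidden: int = 32):
            super().__init__()
            self.feature_nets = nn.ModuleList(FeatureNet(hidden) for _ in range(num_features))
            self.bias = nn.Parameter(torch.zeros(1))

        def forward(self, x):          # x: (batch, num_features)
            contributions = [net(x[:, j:j + 1]) for j, net in enumerate(self.feature_nets)]
            return self.bias + torch.cat(contributions, dim=1).sum(dim=1)   # (batch,)

    model = NAM(num_features=5)
    print(model(torch.randn(4, 5)).shape)   # torch.Size([4])

Training proceeds with an ordinary loss and optimizer, which is what brings the regularization and GPU advantages mentioned above.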

Architecture

The CRISP-NAM architecture consists of four components:

  1. Per-covariate feature networks (shape functions), as in a standard NAM.
  2. Risk-specific projection vectors that turn each feature network's output into that covariate's contribution to each competing risk.
  3. An additive combination of these contributions into a cause-specific risk score \eta_k(\mathbf{x}) for each event type.
  4. A Breslow estimator of the cumulative baseline hazard for each risk, used to convert risk scores into cumulative incidence predictions (see Appendix).
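A minimal PyTorch sketch of how these pieces can fit together is shown below. It is an illustration of the structure described above, not the authors' implementation; the name CRNAMSketch, the hidden sizes, and the embedding dimension are assumptions. The Breslow step (component 4) happens after training and is sketched in the Appendix.

    import torch
    import torch.nn as nn

    class CRNAMSketch(nn.Module):
        """Illustrative competing-risks NAM: per-feature nets plus per-risk projections."""
        def __init__(self, num_features: int, num_risks: int, hidden: int = 32, embed: int = 8):
            super().__init__()
            # 1. One small MLP per feature, emitting an `embed`-dimensional representation.
            self.feature_nets = nn.ModuleList(
                nn.Sequential(nn.Linear(1, hidden), nn.ReLU(), nn.Linear(hidden, embed))
                for _ in range(num_features)
            )
            # 2. One projection vector per (feature, risk): embedding -> scalar contribution.
            self.risk_projections = nn.Parameter(0.01 * torch.randn(num_features, num_risks, embed))

        def forward(self, x):   # x: (batch, num_features)
            # Per-feature embeddings: (batch, num_features, embed)
            h = torch.stack([net(x[:, j:j + 1]) for j, net in enumerate(self.feature_nets)], dim=1)
            # 3. Per-feature, per-risk contributions: (batch, num_features, num_risks)
            contributions = torch.einsum("bfe,fke->bfk", h, self.risk_projections)
            # Summing over features gives cause-specific risk scores eta_k(x): (batch, num_risks)
            return contributions.sum(dim=1), contributions   # contributions kept for plots

    model = CRNAMSketch(num_features=10, num_risks=2)
    eta, per_feature = model(torch.randn(4, 10))
    print(eta.shape, per_feature.shape)    # torch.Size([4, 2]) torch.Size([4, 10, 2])

The forward pass returns both the cause-specific risk scores and the per-feature, per-risk contributions, since the latter are exactly what the feature importance and shape plots visualize.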

Advantages

CRISP-NAM keeps the benefits of the NAM family while extending them to competing risks: every covariate's contribution to every risk can be inspected directly through feature importance and shape plots, the model trains with standard deep learning tooling, and, as the evaluation below shows, its discrimination remains competitive with black-box deep survival models such as DeepHit.

Limitations

By design, the model excludes feature interactions. This is intentional: interactions make interpretation far more difficult. With p features, there are O(p^2) pairwise interactions (the 42 SUPPORT2 covariates alone would yield 861 pairwise terms), and visualization beyond pairwise terms becomes impractical. We accept some reduction in predictive power in exchange for guaranteed interpretability.

Evaluation

CRISP-NAM was evaluated using a mix of discriminative and calibration metrics.

  1. Time-Dependent Area Under the Curve (TD-AUC): An extension of the standard ROC-AUC designed for time-to-event (survival) data, where outcomes unfold over time. It measures how well the model ranks patients who experienced the event by time t above those who had not yet experienced it or were censored by that time.

    \widehat{\mathrm{AUC}}(t) = \frac{\sum_{i=1}^n \sum_{j=1}^n I(y_j > t) I(y_i \leq t) \omega_i I(\hat{f}(\mathbf{x}_j) \leq \hat{f}(\mathbf{x}_i))} {(\sum_{i=1}^n I(y_i > t)) (\sum_{i=1}^n I(y_i \leq t) \omega_i)}

    n: Number of patients
    x_i, x_j: Covariates of patient i or j
    y_i, y_j: Time-to-event response for patient i or j
    \hat{f}(\mathbf{x}_i), \hat{f}(\mathbf{x}_j): Risk score assigned by the model to patient i or j
    \omega_i, \omega_j: Inverse probability of censoring weights for patients i and j
    I(\cdot): Indicator function, equal to 1 if the condition inside holds and 0 otherwise

  2. Time-Dependent Concordance Index (TD-CI): The concordance index (CI) measures the proportion of comparable patient pairs for which the model correctly orders the predicted risks relative to the observed survival times. The time-dependent CI restricts this count to pairs that are comparable within a common observation time interval, making it a good metric for assessing model performance over time.

    \text{TD-CI} = \frac{\sum_{i \neq j} I\{\eta_i \leq \eta_j\}\, I\{T_i > T_j\}\, d_j}{\sum_{i \neq j} I\{T_i > T_j\}\, d_j}

    i, j: Patient indices
    \eta_i, \eta_j: Risk score assigned by the model to patient i or j
    T_i, T_j: Time-to-event response for patient i or j
    d_i, d_j: Indicates whether T_i (or T_j) has been fully observed (= 1) or censored (= 0)
    I\{\cdot\}: Indicator function, equal to 1 if the condition inside holds and 0 otherwise

  3. Brier score (BS): This score evaluates the accuracy of a predicted survival function at a given time t. It is the average squared distance between the observed survival status and the predicted survival probability, and is thus a combined measure of calibration and discrimination. A sketch of computing all three metrics follows this list.

    \text{Brier Score} = \frac{1}{N} \sum_{i=1}^{N} (\hat{y}_i - y_i)^2

    N: Number of subjects
    \hat{y}_i: Predicted survival probability by the model for the i-th subject
    y_i: Observed outcome for the i-th subject (1 for event, 0 for non-event)
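These metrics are implemented in scikit-survival. The following is a minimal single-risk sketch assuming that package; the arrays y_train, y_test, risk_scores, and surv_probs are toy placeholders, and the competing-risks experiments in the paper evaluate each risk cause-specifically.

    import numpy as np
    from sksurv.util import Surv
    from sksurv.metrics import cumulative_dynamic_auc, concordance_index_ipcw, brier_score

    # Toy placeholder data; in practice these come from a train/test split and the fitted model.
    y_train = Surv.from_arrays(
        event=[True, True, False, True, True, False, True, False, True, True],
        time=[1.0, 3.0, 5.0, 6.0, 8.0, 10.0, 12.0, 15.0, 18.0, 20.0],
    )
    y_test = Surv.from_arrays(
        event=[True, False, True, True, False, True],
        time=[2.0, 4.0, 7.0, 9.0, 11.0, 14.0],
    )
    eval_times = np.array([3.0, 6.0, 9.0])                    # evaluation horizons within follow-up
    risk_scores = np.array([0.8, 0.1, 0.6, 0.5, -0.2, 0.3])   # higher = higher predicted risk
    surv_probs = np.array([                                   # predicted S(t | x) at eval_times
        [0.90, 0.75, 0.60],
        [0.95, 0.90, 0.85],
        [0.85, 0.65, 0.45],
        [0.88, 0.70, 0.55],
        [0.97, 0.93, 0.90],
        [0.92, 0.80, 0.68],
    ])

    # Time-dependent AUC with inverse-probability-of-censoring weights (one value per horizon).
    auc_t, mean_auc = cumulative_dynamic_auc(y_train, y_test, risk_scores, eval_times)

    # Time-dependent concordance index (IPCW), truncated at the last evaluation horizon.
    td_ci = concordance_index_ipcw(y_train, y_test, risk_scores, tau=eval_times[-1])[0]

    # Brier score of the predicted survival probabilities at each horizon.
    _, bs = brier_score(y_train, y_test, surv_probs, eval_times)

    print(auc_t, mean_auc, td_ci, bs)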

SUPPORT2 dataset

The SUPPORT2 dataset originates from the Study to Understand Prognoses and Preferences for Outcomes and Risks of Treatments (SUPPORT), a comprehensive investigation conducted across five U.S. medical centers between 1989 and 1994. It contains 9,105 records of critically ill hospitalized adults, each characterized by 42 variables. This dataset was chosen because it contains two competing risks: death due to cancer and death due to other causes.

Results

Below are the performance metrics comparing CRISP-NAM and the state-of-the-art DeepHit model for the two competing risks: death due to cancer (Risk 1) and death due to other causes (Risk 2).

Table 1: Performance comparison between CRISP-NAM and DeepHit on the SUPPORT2 dataset. Best scores are highlighted in bold green.

CRISP-NAM performs better on Risk 1, and DeepHit performs better on Risk 2. Across additional datasets, such as the Framingham Heart Study and Primary Biliary Cholangitis (PBC), DeepHit generally achieves higher raw performance, while CRISP-NAM demonstrates competitive discrimination with the added benefit of interpretability. The key point is that CRISP-NAM achieves comparable performance to DeepHit while providing interpretability, which is crucial for clinical adoption.

Interpretability

CRISP-NAM offers two complementary views of how covariates drive each risk:

  1. Feature importance plots: These plots summarize a covariate's overall contribution to the model's output. In CRISP-NAM, the importance is the mean of the covariate's risk-projection vector for a particular risk; its sign indicates whether the covariate contributes positively or negatively to that risk.

Figure 4: Feature importance for both competing risks.

    NOTE: The feature importance plot provides a global view of how each covariate affects each risk. Physicians can validate these importances against their clinical knowledge. Together with the covariate-specific shape plots, they give an overall picture of how an individual covariate affects the risk as determined by the model.

  2. Shape plots: Plotted for every covariate, a shape plot shows that covariate's individual contribution to the overall probability of the risk. Beneath each curve, rug plots illustrate the empirical distribution of feature values, highlighting regions where data are sparse. In CRISP-NAM, the plotted quantity is the covariate's risk-projection output for each risk. This enables ranking features by their impact on each competing risk, providing valuable insight into risk-specific predictor importance.

Figure 5: Shape plots for individual features.

Note: The plots may differ from those shown in the paper, as different training runs of the model can produce different outputs.
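As an illustration of how such a shape plot can be produced, the sketch below sweeps a single feature over a grid, passes it through a per-feature network (a stand-in here, untrained), and adds a rug of observed values; feature_net and x_observed are placeholders for a trained network and real data.

    import numpy as np
    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt

    # Stand-ins for a trained per-feature network and the observed values of that feature.
    feature_net = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))
    x_observed = np.random.default_rng(0).normal(size=200)

    # Sweep the feature over a grid and record its additive contribution to the risk score.
    grid = torch.linspace(float(x_observed.min()), float(x_observed.max()), 200).unsqueeze(1)
    with torch.no_grad():
        contribution = feature_net(grid).squeeze(1).numpy()

    fig, ax = plt.subplots(figsize=(5, 3))
    ax.plot(grid.squeeze(1).numpy(), contribution)
    ax.set_xlabel("feature value")
    ax.set_ylabel("contribution to risk")
    # Rug plot: tick marks along the bottom showing where observed values actually lie.
    ax.plot(x_observed, np.full_like(x_observed, contribution.min()), "|", color="black", alpha=0.3)
    plt.tight_layout()
    plt.show()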

Conclusion

CRISP-NAM demonstrates that interpretability and competitive predictive performance are not mutually exclusive, even for complex problems like competing risks survival analysis. By extending Neural Additive Models to handle multiple event types, we provide clinicians and researchers with a tool that achieves strong discrimination while fully explaining its predictions.

In high-stakes healthcare applications, understanding why a model makes its predictions is not a luxury but a requirement. CRISP-NAM makes that understanding achievable.

The source code for reproducing all experiments and implementations described in this article is available in our CRISP-NAM repository.

Citation

@inproceedings{ramachandram2025crispnam,
  title={CRISP-NAM: Competing Risks Interpretable Survival Prediction with Neural Additive Models},
  author={Ramachandram, Dhanesh and Raval, Ananya},
  booktitle={EXPLIMED 2025 - Second Workshop on Explainable AI for the Medical Domain},
  year={2025}
}

Acknowledgments

We would like to thank the Vector Institute for supporting this research, and the open-source community for providing valuable datasets, tools and frameworks that made this work possible.

Appendix: Survival Analysis Primer

The Hazard Function

Central to survival analysis is the hazard function h(t), which represents the instantaneous risk of an event at time t, conditional on survival up to that point. It answers: "Given survival to time t, what is the instantaneous failure rate?"

The classic Cox Proportional Hazards model relates covariates to this hazard:

h(t \mid X) = h_0(t) \cdot \exp(\beta^\top X)

Here, h_0(t) is the baseline hazard common to all subjects, and \exp(\beta^\top X) is a subject-specific multiplier based on covariates. A coefficient \beta_i = 0.7 corresponds to a hazard ratio of \exp(0.7) \approx 2, indicating a doubling of risk.

The proportional hazards assumption: This model assumes that the hazard ratio between any two subjects remains constant over time. Covariates may double your risk relative to another subject, but that multiplicative effect persists regardless of the time horizon.
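For readers who want to experiment, a Cox model and its hazard ratios can be fitted with scikit-survival; the covariates and outcomes below are toy placeholders.

    import numpy as np
    from sksurv.linear_model import CoxPHSurvivalAnalysis
    from sksurv.util import Surv

    # Toy covariates (age, biomarker) and outcomes; replace with real patient data.
    X = np.array([[60, 1.2], [70, 0.8], [55, 2.0], [65, 1.5], [80, 0.5], [50, 1.0]])
    y = Surv.from_arrays(event=[True, True, False, True, False, True],
                         time=[5.0, 3.0, 10.0, 7.0, 12.0, 4.0])

    # A small ridge penalty (alpha) keeps the fit stable on this tiny toy sample.
    cox = CoxPHSurvivalAnalysis(alpha=0.1).fit(X, y)

    # Each coefficient beta corresponds to a hazard ratio exp(beta).
    for name, beta in zip(["age", "biomarker"], cox.coef_):
        print(f"{name}: beta = {beta:.3f}, hazard ratio = {np.exp(beta):.3f}")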

The Competing Risks Extension

With competing risks, we require a separate hazard function for each event type. For event type k, the cause-specific hazard is defined as:

\lambda_k(t|\mathbf{x}) = \lim_{\Delta t \to 0} \frac{P(t \leq T \leq t + \Delta t, E=k \mid T \geq t, \mathbf{x})}{\Delta t}

This represents the instantaneous rate of experiencing event k at time t, among subjects still at risk.

The key insight is that these hazards interact through the risk set. A high hazard for one event can effectively reduce the observed rate of another, not through a causal mechanism, but because experiencing one event removes subjects from the risk pool for competing events.

Cumulative Incidence Function

To compute absolute risk predictions, we require the cumulative incidence function (CIF):

F_k(t|\mathbf{x}) = \int_0^t S(u|\mathbf{x}) \cdot \lambda_k(u|\mathbf{x}) \, du

where the overall survival function is:

S(t|\mathbf{x}) = \exp\left(-\sum_{k=1}^{K} \int_0^t \lambda_k(u|\mathbf{x}) \, du\right)

In practice, we discretize time and compute:

\hat{F}_k(t|\mathbf{x}) \approx \sum_{t_m \leq t} \hat{S}(t_{m-1}|\mathbf{x}) \cdot \hat{\lambda}_k(t_m|\mathbf{x})
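This discretized computation is a few lines of numpy; the hazards array below is a toy placeholder for the cause-specific hazards \hat{\lambda}_k(t_m|\mathbf{x}) produced by a fitted model.

    import numpy as np

    # Placeholder: cause-specific hazards lambda_k(t_m | x) on a grid of M time points, K risks.
    hazards = np.array([[0.02, 0.03, 0.05, 0.04],    # risk 1
                        [0.01, 0.01, 0.02, 0.03]])   # risk 2

    # Overall survival S(t_m | x) = exp(-sum over risks of the cumulative hazard up to t_m).
    cum_hazard = np.cumsum(hazards.sum(axis=0))
    survival = np.exp(-cum_hazard)

    # S(t_{m-1} | x), with survival equal to 1 before the first time point.
    survival_prev = np.concatenate(([1.0], survival[:-1]))

    # Cumulative incidence F_k(t_m | x) = sum over t_j <= t_m of S(t_{j-1} | x) * lambda_k(t_j | x).
    cif = np.cumsum(survival_prev * hazards, axis=1)    # shape (K, M)

    print(cif)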

Breslow Estimator for Baseline Hazard

After training, we estimate the cumulative baseline hazard using:

\hat{\Lambda}_{0k}(t) = \sum_{n: T_n \leq t, E_n = k} \frac{1}{\sum_{j: T_j \geq T_n}\exp(\eta_k(\mathbf{x}_j))}

This non-parametric estimator makes no assumptions about the functional form of the baseline hazard.
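A direct numpy transcription of this estimator is shown below; times, events, causes, and eta_k are placeholders for the observed data and the model's cause-specific risk scores.

    import numpy as np

    def breslow_cumulative_baseline(times, events, causes, eta_k, cause):
        """Cumulative baseline hazard Lambda_0k(t) for one cause, at that cause's event times.

        times: observed times T_n; events: 1 if an event occurred, 0 if censored;
        causes: event type E_n (ignored where events == 0); eta_k: risk scores eta_k(x_n).
        Returns (sorted event times for this cause, cumulative baseline hazard at those times).
        """
        times = np.asarray(times, dtype=float)
        events = np.asarray(events, dtype=bool)
        causes = np.asarray(causes)
        exp_eta = np.exp(np.asarray(eta_k, dtype=float))

        event_times = np.sort(times[events & (causes == cause)])
        increments = []
        for t in event_times:
            at_risk = times >= t                      # risk set: subjects still under observation at t
            increments.append(1.0 / exp_eta[at_risk].sum())
        return event_times, np.cumsum(increments)

    # Toy example (placeholder data): 6 subjects, two competing causes.
    t, lam0 = breslow_cumulative_baseline(
        times=[2.0, 3.0, 3.5, 5.0, 6.0, 8.0],
        events=[1, 1, 0, 1, 1, 0],
        causes=[1, 2, 0, 1, 2, 0],
        eta_k=[0.4, -0.1, 0.2, 0.0, 0.3, -0.2],
        cause=1,
    )
    print(t, lam0)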