Skip to content

Python API Reference

This section documents the Python API for CRISP-NAM.

Model

crisp_nam.models.crisp_nam_model

CrispNamModel for competing-risks survival analysis.

PyTorch implementation of CrispNamModel for competing risks survival analysis with L2 normalized projection weights.

CrispNamModel

Bases: Module

Competing risks CoxNAM with L2 normalized projection weights.

Each feature contributes to each risk through a separate shape function. All projection weights are constrained to unit L2 norm.

Source code in crisp_nam/models/crisp_nam_model.py
class CrispNamModel(nn.Module):
    """Competing risks CoxNAM with L2 normalized projection weights.

    Each feature contributes to each risk through a separate shape function.
    All projection weights are constrained to unit L2 norm, which makes the
    per-feature risk contributions directly comparable.
    """

    def __init__(
        self,
        num_features: int,
        num_competing_risks: int,
        hidden_sizes: Sequence[int] = (64, 64),
        dropout_rate: float = 0.1,
        feature_dropout: float = 0.1,
        batch_norm: bool = False,
        normalize_projections: bool = True,
        eps: float = 1e-8,
    ):
        """Initialize the CrispNamModel.

        Parameters
        -----------
            num_features: Number of input features (one shape function each)
            num_competing_risks: Number of competing risks to model
            hidden_sizes: Hidden layer widths of each per-feature network
            dropout_rate: Dropout rate inside each feature network
            feature_dropout: Probability of zeroing a whole input feature
                during training
            batch_norm: Whether feature networks use batch normalization
            normalize_projections: If True, use L2 normalized projection
                layers; otherwise plain linear layers
            eps: Numerical-stability constant for the L2 normalization
        """
        super().__init__()
        self.num_features = num_features
        self.num_competing_risks = num_competing_risks
        self.batch_norm = batch_norm
        self.feature_dropout = feature_dropout
        self.normalize_projections = normalize_projections
        self.eps = eps

        # Create a FeatureNet for each input feature
        self.feature_nets = nn.ModuleList(
            [
                _FeatureNet(hidden_sizes, dropout_rate, feature_dropout, batch_norm)
                for _ in range(num_features)
            ]
        )

        # For each feature and risk type, create a projection layer
        if normalize_projections:
            self.risk_projections: nn.ModuleList = nn.ModuleList(
                [
                    nn.ModuleList(
                        [
                            _L2NormalizedLinear(hidden_sizes[-1], 1, bias=False, eps=eps)
                            for _ in range(num_competing_risks)
                        ]
                    )
                    for _ in range(num_features)
                ]
            )
        else:
            # Fallback to standard linear layers
            self.risk_projections = nn.ModuleList(
                [
                    nn.ModuleList(
                        [
                            nn.Linear(hidden_sizes[-1], 1, bias=False)
                            for _ in range(num_competing_risks)
                        ]
                    )
                    for _ in range(num_features)
                ]
            )

    def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Forward pass to compute risk scores for all competing risks.

        Parameters
        -----------
            x: Tensor of shape (batch_size, num_features)

        Returns
        -------
            risk_scores: List of (batch_size, 1) Tensors, one per risk
            feature_outputs: List of (batch_size, hidden) Tensors, one per feature
        """
        # ensure float32
        x = x.to(dtype=torch.float32)
        batch_size, _ = x.shape
        device = x.device

        # one-shot feature dropout: zero out entire input features at once
        if self.training and self.feature_dropout > 0:
            # bernoulli_ is in-place, fast
            mask = torch.empty_like(x).bernoulli_(1.0 - self.feature_dropout)
            x = x * mask

        # pre-allocate combined scores [batch, num_risks]
        combined = torch.zeros(batch_size, self.num_competing_risks, device=device)
        feature_outputs = []

        # loop features
        for feat_idx, fnet in enumerate(self.feature_nets):
            # take one column and compute its hidden representation
            col = x[:, feat_idx].unsqueeze(1)  # [batch, 1]
            rep = fnet(col)  # [batch, hidden]
            feature_outputs.append(rep)

            # project into each risk channel; the projection layer applies
            # L2 normalization automatically when normalize_projections=True
            for risk_idx, proj in enumerate(self.risk_projections[feat_idx]):
                combined[:, risk_idx] += proj(rep).view(-1)

        # split back into list of [batch, 1]
        risk_scores = [
            combined[:, r].unsqueeze(1) for r in range(self.num_competing_risks)
        ]

        return risk_scores, feature_outputs

    def get_projection_norms(self) -> dict:
        """Get the L2 norms of all projection weights (should be ~1.0 if normalized).

        Returns
        -------
            Dictionary of weight norms keyed by "feature_{i}_risk_{j}"
        """
        norms = {}

        for feat_idx in range(self.num_features):
            for risk_idx in range(self.num_competing_risks):
                proj = self.risk_projections[feat_idx][risk_idx]

                # Layers without a weight attribute are silently skipped
                if hasattr(proj, "weight"):
                    weight_norm = proj.weight.norm(p=2, dim=1).item()
                    norms[f"feature_{feat_idx}_risk_{risk_idx}"] = weight_norm

        return norms

    def get_normalized_projection_weights(self) -> dict:
        """Get the actual L2 normalized weights used in computation.

        Returns
        -------
            Dictionary of normalized weight arrays keyed by
            "feature_{i}_risk_{j}" (None for layers without weights)
        """
        normalized_weights = {}

        for feat_idx in range(self.num_features):
            for risk_idx in range(self.num_competing_risks):
                proj = self.risk_projections[feat_idx][risk_idx]

                if hasattr(proj, "get_normalized_weights"):
                    # L2NormalizedLinear layer exposes its unit-norm weights
                    weights = proj.get_normalized_weights().detach().cpu().numpy()
                elif hasattr(proj, "weight"):
                    # Standard linear layer - normalize manually
                    weights = (
                        F.normalize(proj.weight, p=2, dim=1).detach().cpu().numpy()
                    )
                else:
                    weights = None

                normalized_weights[f"feature_{feat_idx}_risk_{risk_idx}"] = weights

        return normalized_weights

    def calculate_feature_importance(
        self,
        x_data: torch.Tensor | np.ndarray,
        feature_idx: int | None = None,
    ) -> dict:
        """Calculate feature importance based on the magnitude of
        risk-specific projection outputs.

        With L2 normalized weights, this gives a fair
        comparison across features.

        Parameters
        -----------
            x_data: Input data tensor or numpy array
            feature_idx: Optional; if provided, only calculate
            importance for this feature

        Returns
        -------
            Dictionary of feature importances by risk type
        """

        # Switch to eval mode so dropout/batch-norm do not perturb the scores
        self.eval()
        device = next(self.parameters()).device

        # Convert to tensor if needed
        if not isinstance(x_data, torch.Tensor):
            x_data = torch.FloatTensor(x_data)
        x_data = x_data.to(device)

        feature_indices = (
            [feature_idx] if feature_idx is not None else range(self.num_features)
        )
        importance: dict = {
            f"risk_{j + 1}": {} for j in range(self.num_competing_risks)
        }

        for i in feature_indices:
            # Get feature values
            feature_values = x_data[:, i].view(-1, 1)

            with torch.no_grad():
                # Get the feature representation
                feature_repr = self.feature_nets[i](feature_values)

                # Calculate importance for each risk (mean absolute value)
                # With L2 normalized weights, this is comparable across features
                for j in range(self.num_competing_risks):
                    risk_specific_output = self.risk_projections[i][j](feature_repr)
                    abs_values = torch.abs(risk_specific_output).cpu().numpy()
                    importance[f"risk_{j + 1}"][f"feature_{i}"] = float(
                        np.mean(abs_values)
                    )

        return importance

    # Utility functions for model analysis
    def analyze_projection_weights(self) -> dict:
        """Print summary statistics of projection-weight L2 norms.

        Returns
        -------
            Dictionary of weight norms by feature and risk
            (same as ``get_projection_norms``)
        """
        print("Projection Weight Analysis:")
        print("=" * 50)

        # Get weight norms
        norms = self.get_projection_norms()
        norm_values = list(norms.values())

        print("Weight L2 Norms (should be ~1.0):")
        print(f"  Mean: {np.mean(norm_values):.6f}")
        print(f"  Std:  {np.std(norm_values):.6f}")
        print(f"  Min:  {np.min(norm_values):.6f}")
        print(f"  Max:  {np.max(norm_values):.6f}")

        # Show some individual norms
        print("\nSample individual norms:")
        for name, norm in list(norms.items())[:6]:
            print(f"  {name}: {norm:.6f}")

        return norms

__init__

__init__(num_features, num_competing_risks, hidden_sizes=(64, 64), dropout_rate=0.1, feature_dropout=0.1, batch_norm=False, normalize_projections=True, eps=1e-08)

Initialize the CrispNamModel.

Source code in crisp_nam/models/crisp_nam_model.py
def __init__(
    self,
    num_features: int,
    num_competing_risks: int,
    hidden_sizes: Sequence[int] = (64, 64),
    dropout_rate: float = 0.1,
    feature_dropout: float = 0.1,
    batch_norm: bool = False,
    normalize_projections: bool = True,
    eps: float = 1e-8,
):
    """Initialize the CrispNamModel."""
    super().__init__()
    self.num_features = num_features
    self.num_competing_risks = num_competing_risks
    self.batch_norm = batch_norm
    self.feature_dropout = feature_dropout
    self.normalize_projections = normalize_projections
    self.eps = eps

    # One shape-function network per input feature.
    self.feature_nets = nn.ModuleList(
        [
            _FeatureNet(hidden_sizes, dropout_rate, feature_dropout, batch_norm)
            for _ in range(num_features)
        ]
    )

    last_hidden = hidden_sizes[-1]

    def _make_projection() -> nn.Module:
        # Unit-norm projection when requested, plain linear fallback otherwise.
        if normalize_projections:
            return _L2NormalizedLinear(last_hidden, 1, bias=False, eps=eps)
        return nn.Linear(last_hidden, 1, bias=False)

    # One projection layer per (feature, risk) pair.
    self.risk_projections: nn.ModuleList = nn.ModuleList(
        [
            nn.ModuleList(
                [_make_projection() for _ in range(num_competing_risks)]
            )
            for _ in range(num_features)
        ]
    )

forward

forward(x)

Forward pass to compute risk scores for all competing risks.

Returns:

Type Description
risk_scores: List of (batch_size, 1) Tensors; feature_outputs: List of (batch_size, hidden) Tensors
Source code in crisp_nam/models/crisp_nam_model.py
def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Forward pass to compute risk scores for all competing risks.

    Parameters
    -----------
        x: Tensor of shape (batch_size, num_features)

    Returns
    -------
        risk_scores: List of (batch_size, 1) Tensors, one per risk
        feature_outputs: List of (batch_size, hidden) Tensors, one per feature
    """
    # ensure float32
    x = x.to(dtype=torch.float32)
    batch_size, _ = x.shape
    device = x.device

    # one-shot feature dropout: zero out entire input features at once
    if self.training and self.feature_dropout > 0:
        # bernoulli_ is in-place, fast
        mask = torch.empty_like(x).bernoulli_(1.0 - self.feature_dropout)
        x = x * mask

    # pre-allocate combined scores [batch, num_risks]
    combined = torch.zeros(batch_size, self.num_competing_risks, device=device)
    feature_outputs = []

    # loop features
    for feat_idx, fnet in enumerate(self.feature_nets):
        # take one column and compute its hidden representation
        col = x[:, feat_idx].unsqueeze(1)  # [batch, 1]
        rep = fnet(col)  # [batch, hidden]
        feature_outputs.append(rep)

        # project into each risk channel; the projection layer applies
        # L2 normalization automatically when normalize_projections=True
        for risk_idx, proj in enumerate(self.risk_projections[feat_idx]):
            combined[:, risk_idx] += proj(rep).view(-1)

    # split back into list of [batch, 1]
    risk_scores = [
        combined[:, r].unsqueeze(1) for r in range(self.num_competing_risks)
    ]

    return risk_scores, feature_outputs

get_projection_norms

get_projection_norms()

Get the L2 norms of all projection weights (should be ~1.0 if normalized).

Returns:

Type Description
Dictionary of weight norms by feature and risk
Source code in crisp_nam/models/crisp_nam_model.py
def get_projection_norms(self) -> dict:
    """Return the L2 norm of each projection weight vector.

    Norms should all be close to 1.0 when projections are normalized.

    Returns
    -------
        Dictionary mapping "feature_{i}_risk_{j}" to the weight L2 norm
    """
    norms: dict = {}

    for feat_idx in range(self.num_features):
        risk_row = self.risk_projections[feat_idx]
        for risk_idx in range(self.num_competing_risks):
            layer = risk_row[risk_idx]
            # Layers without weights (unexpected) are simply skipped.
            if not hasattr(layer, "weight"):
                continue
            key = f"feature_{feat_idx}_risk_{risk_idx}"
            norms[key] = layer.weight.norm(p=2, dim=1).item()

    return norms

get_normalized_projection_weights

get_normalized_projection_weights()

Get the actual L2 normalized weights used in computation.

Returns:

Type Description
Dictionary of normalized weights
Source code in crisp_nam/models/crisp_nam_model.py
def get_normalized_projection_weights(self) -> dict:
    """Return the effective (unit L2 norm) projection weights.

    Returns
    -------
        Dictionary mapping "feature_{i}_risk_{j}" to a numpy array of
        normalized weights, or None when a layer exposes no weights
    """
    result: dict = {}

    for feat_idx in range(self.num_features):
        for risk_idx in range(self.num_competing_risks):
            layer = self.risk_projections[feat_idx][risk_idx]

            if hasattr(layer, "get_normalized_weights"):
                # _L2NormalizedLinear exposes its unit-norm weights directly.
                arr = layer.get_normalized_weights().detach().cpu().numpy()
            elif hasattr(layer, "weight"):
                # Plain nn.Linear: normalize on the fly for a fair view.
                arr = F.normalize(layer.weight, p=2, dim=1).detach().cpu().numpy()
            else:
                arr = None

            result[f"feature_{feat_idx}_risk_{risk_idx}"] = arr

    return result

calculate_feature_importance

calculate_feature_importance(x_data, feature_idx=None)

Calculate feature importance based on the magnitude of risk-specific projection outputs.

With L2 normalized weights, this gives a fair comparison across features.

Returns:

Type Description
Dictionary of feature importances by risk type
Source code in crisp_nam/models/crisp_nam_model.py
def calculate_feature_importance(
    self,
    x_data: Optional[torch.Tensor | np.ndarray],
    feature_idx: Optional[int | None] = None,
) -> dict:
    """Calculate feature importance based on the magnitude of
    risk-specific projection outputs.

    With L2 normalized weights, this gives a fair
    comparison across features.

    Parameters
    -----------
        x_data: Input data tensor or numpy array
        feature_idx: Optional; if provided, only calculate
        importance for this feature

    Returns
    -------
        Dictionary of feature importances by risk type
    """

    self.eval()
    device = next(self.parameters()).device

    # Convert to tensor if needed
    if not isinstance(x_data, torch.Tensor):
        x_data = torch.FloatTensor(x_data)
    x_data = x_data.to(device)

    feature_indices = (
        [feature_idx] if feature_idx is not None else range(self.num_features)
    )
    importance: dict = {
        f"risk_{j + 1}": {} for j in range(self.num_competing_risks)
    }

    for i in feature_indices:
        # Get feature values
        feature_values = x_data[:, i].view(-1, 1)

        with torch.no_grad():
            # Get the feature representation
            feature_repr = self.feature_nets[i](feature_values)

            # Calculate importance for each risk (mean absolute value)
            # With L2 normalized weights, this is comparable across features
            for j in range(self.num_competing_risks):
                risk_specific_output = self.risk_projections[i][j](feature_repr)
                abs_values = torch.abs(risk_specific_output).cpu().numpy()
                importance[f"risk_{j + 1}"][f"feature_{i}"] = float(
                    np.mean(abs_values)
                )

    return importance

analyze_projection_weights

analyze_projection_weights()

Analyze the L2 norms and statistics of projection weights.

Returns:

Type Description
Dictionary of weight norms by feature and risk (same as get_projection_norms)
Source code in crisp_nam/models/crisp_nam_model.py
def analyze_projection_weights(self) -> dict:
    """Print summary statistics of projection-weight L2 norms.

    Returns
    -------
        Dictionary of weight norms by feature and risk
        (same as ``get_projection_norms``)
    """
    print("Projection Weight Analysis:")
    print("=" * 50)

    # Get weight norms
    norms = self.get_projection_norms()
    norm_values = list(norms.values())

    print("Weight L2 Norms (should be ~1.0):")
    print(f"  Mean: {np.mean(norm_values):.6f}")
    print(f"  Std:  {np.std(norm_values):.6f}")
    print(f"  Min:  {np.min(norm_values):.6f}")
    print(f"  Max:  {np.max(norm_values):.6f}")

    # Show some individual norms
    print("\nSample individual norms:")
    for name, norm in list(norms.items())[:6]:
        print(f"  {name}: {norm:.6f}")

    return norms

crisp_nam.models.deephit_model

PyTorch implementation of DeepHit for competing risks survival analysis.

DeepHit

Bases: Module

PyTorch implementation of DeepHit for competing risks survival analysis.

Source code in crisp_nam/models/deephit_model.py
class DeepHit(nn.Module):
    """PyTorch implementation of DeepHit for competing risks survival analysis."""

    def __init__(self, input_dims: dict, network_settings: dict):
        """Initialize the DeepHit model.

        Parameters
        ----------
            input_dims: dict with keys "x_dim", "num_Event", "num_Category"
            network_settings: dict with keys "h_dim_shared", "h_dim_CS",
            "num_layers_shared", "num_layers_CS", "active_fn", and
            optionally "keep_prob"
        """
        super(DeepHit, self).__init__()

        # Input dimensions
        self.x_dim = input_dims["x_dim"]
        self.num_Event = input_dims["num_Event"]
        self.num_Category = input_dims["num_Category"]

        # Network settings
        self.h_dim_shared = network_settings["h_dim_shared"]
        self.h_dim_CS = network_settings["h_dim_CS"]
        self.num_layers_shared = network_settings["num_layers_shared"]
        self.num_layers_CS = network_settings["num_layers_CS"]

        # Activation function (any unrecognized name falls back to ReLU)
        if network_settings["active_fn"] == "relu":
            self.active_fn = nn.ReLU()
        elif network_settings["active_fn"] == "elu":
            self.active_fn = nn.ELU()
        elif network_settings["active_fn"] == "tanh":
            self.active_fn = nn.Tanh()
        else:
            self.active_fn = nn.ReLU()

        # Regularization: keep_prob is a retention probability, so the
        # dropout rate used by torch is its complement
        self.keep_prob = network_settings.get("keep_prob", 0.5)
        self.dropout_rate = 1.0 - self.keep_prob

        # Initialize networks
        self._build_network()

    def _build_network(self) -> None:
        """Build the shared and cause-specific networks.

        Parameters
        ----------
            None

        Returns
        -------
            None
        """
        # Shared network
        self.shared_net = FCNet(
            in_dim=self.x_dim,
            num_layers=self.num_layers_shared,
            h_dim=self.h_dim_shared,
            activation=self.active_fn,
            dropout_rate=self.dropout_rate,
        )

        # Cause-specific networks (one per event type)
        self.cs_nets = nn.ModuleList(
            [
                FCNet(
                    in_dim=self.x_dim
                    + self.h_dim_shared,  # Concatenate input and shared output
                    num_layers=self.num_layers_CS,
                    h_dim=self.h_dim_CS,
                    activation=self.active_fn,
                    dropout_rate=self.dropout_rate,
                )
                for _ in range(self.num_Event)
            ]
        )

        # Output layer: maps all cause-specific representations jointly to
        # one logit per (event, time-category) cell
        self.output_layer = nn.Linear(
            self.num_Event * self.h_dim_CS, self.num_Event * self.num_Category
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
        """Forward pass through the network.

        Parameters
        ----------
            x: Tensor of input covariates, shape (batch_size, x_dim)

        Returns
        -------
            out: Tensor of shape (batch_size, num_Event, num_Category) with
            the joint probability mass over events and time categories
            feature_outputs: None (kept for interface compatibility with
            models that expose per-feature shape functions)
        """
        # Shared network
        shared_out = self.shared_net(x)

        # Concatenate input with shared output (residual-style skip)
        h = torch.cat([x, shared_out], dim=1)

        # Cause-specific networks
        cs_outputs = []
        for cs_net in self.cs_nets:
            cs_out = cs_net(h)
            cs_outputs.append(cs_out)

        # Stack outputs
        stacked_out = torch.stack(
            cs_outputs, dim=1
        )  # [batch_size, num_Event, h_dim_CS]
        reshaped_out = stacked_out.view(
            -1, self.num_Event * self.h_dim_CS
        )  # [batch_size, num_Event * h_dim_CS]

        # Final output layer
        logits = self.output_layer(
            F.dropout(reshaped_out, self.dropout_rate, self.training)
        )
        # Softmax over ALL (event, category) cells jointly so the output is
        # a single probability mass function over events and times
        out = F.softmax(logits.view(-1, self.num_Event * self.num_Category), dim=1)

        # Reshape to [batch_size, num_Event, num_Category]
        out = out.view(-1, self.num_Event, self.num_Category)

        # For compatibility with the training script, return both
        # raw risks and shape functions
        # In this model, we don't have separate shape functions, so just return None
        return out, None

    def _log_likelihood_loss(
        self,
        out: torch.Tensor,
        t: Optional[torch.Tensor | np.ndarray],
        k: Optional[torch.Tensor | np.ndarray],
        mask1: torch.Tensor,
        mask2: torch.Tensor,
    ) -> torch.Tensor:
        """Log-likelihood loss (including log-likelihood of censored subjects).

        Parameters
        ----------
            out: Torch.tensor
            t: Torch.tensor or numpy array
            k: Torch.tensor or numpy array
            mask1: Torch.tensor
            mask2: Torch.tensor

        Returns
        -------
            loss: Torch.tensor
        """
        # Convert to PyTorch tensors if necessary
        if not isinstance(k, torch.Tensor):
            k = torch.tensor(k, device=out.device)
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=out.device)

        # Indicator for uncensored subjects (event label k > 0)
        i_1 = (k > 0).float().view(-1, 1)

        # For uncensored: log P(T=t, K=k|x); 1e-8 guards log(0)
        tmp1 = torch.sum(torch.sum(mask1 * out, dim=2), dim=1, keepdim=True)
        tmp1 = i_1 * torch.log(tmp1 + 1e-8)

        # For censored: log ∑ P(T>t|x)
        tmp2 = torch.sum(
            torch.sum(mask2.unsqueeze(1) * out, dim=2), dim=1, keepdim=True
        )
        tmp2 = (1.0 - i_1) * torch.log(tmp2 + 1e-8)

        return -torch.mean(tmp1 + tmp2)

    def _ranking_loss(
        self,
        out: torch.Tensor,
        t: Optional[torch.Tensor | np.ndarray],
        k: Optional[torch.Tensor | np.ndarray],
        mask2: torch.Tensor,
    ) -> torch.Tensor:
        """Ranking loss (calculated only for acceptable pairs).

        Parameters
        ----------
            out: Torch.tensor
            t: Torch.tensor or numpy array
            k: Torch.tensor or numpy array
            mask2: Torch.tensor

        Returns
        -------
            loss: Torch.tensor
        """
        # Temperature of the exponential ranking penalty
        sigma1 = 0.1
        eta = []

        # Convert to PyTorch tensors if necessary
        if not isinstance(k, torch.Tensor):
            k = torch.tensor(k, device=out.device)
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=out.device)

        one_vector = torch.ones_like(t)

        for e in range(self.num_Event):
            # Indicator: subject experienced event e (labels are 1-based)
            i_2 = (k == e + 1).float()
            i_2_diag = torch.diag(i_2.squeeze())

            # Extract event-specific probabilities
            tmp_e = out[:, e, :]  # [batch_size, num_Category]

            # Calculate risk scores
            r = torch.matmul(tmp_e, mask2.transpose(0, 1))  # [batch_size, batch_size]
            diag_r = torch.diag(r).unsqueeze(1)  # [batch_size, 1]
            r = (
                torch.matmul(one_vector, diag_r.transpose(0, 1)) - r
            )  # [batch_size, batch_size]
            r = r.transpose(0, 1)  # Now R_ij = r_i(T_i) - r_j(T_i)

            # Time comparison matrix (1 where T_i < T_j, else 0)
            time = F.relu(
                torch.sign(
                    torch.matmul(one_vector, t.transpose(0, 1))
                    - torch.matmul(t, one_vector.transpose(0, 1))
                )
            )

            # Filter by event occurrence
            time = torch.matmul(i_2_diag, time)

            # Calculate ranking loss for current event
            tmp_eta = torch.mean(time * torch.exp(-r / sigma1), dim=1, keepdim=True)
            eta.append(tmp_eta)

        eta = torch.stack(eta, dim=1)  # [batch_size, num_Event]
        eta_mean = torch.mean(eta.reshape(-1, self.num_Event), dim=1, keepdim=True)

        return torch.sum(eta_mean)

    def _calibration_loss(
        self,
        out: torch.Tensor,
        t: Optional[torch.Tensor | np.ndarray],
        k: Optional[torch.Tensor | np.ndarray],
        mask2: torch.Tensor,
    ) -> torch.Tensor:
        """Calibration loss.

        Parameters
        ----------
            out: Torch.tensor
            t: Torch.tensor or numpy array (unused here; kept for a uniform
            loss-function signature)
            k: Torch.tensor or numpy array
            mask2: Torch.tensor

        Returns
        -------
            loss: Torch.tensor
        """
        eta = []

        # Convert to PyTorch tensors if necessary
        if not isinstance(k, torch.Tensor):
            k = torch.tensor(k, device=out.device)

        for e in range(self.num_Event):
            # Indicator for current event type
            i_2 = (k == e + 1).float()

            # Extract event-specific probabilities
            tmp_e = out[:, e, :]  # [batch_size, num_Category]

            # Calculate calibration loss (squared error between predicted
            # cumulative probability and the observed event indicator)
            r = torch.sum(tmp_e * mask2, dim=1)
            tmp_eta = torch.mean((r - i_2) ** 2, dim=0, keepdim=True)
            eta.append(tmp_eta)

        eta = torch.stack(eta, dim=1)  # [1, num_Event]
        eta_mean = torch.mean(eta.reshape(-1, self.num_Event), dim=1, keepdim=True)

        return torch.sum(eta_mean)

    def compute_loss(
        self,
        out: torch.Tensor,
        t: Optional[torch.Tensor | np.ndarray],
        k: Optional[torch.Tensor | np.ndarray],
        mask1: Optional[torch.Tensor | np.ndarray],
        mask2: torch.Tensor,
        alpha: float = 1.0,
        beta: float = 1.0,
        gamma: float = 1.0,
    ) -> torch.Tensor:
        """Compute total loss.

        Parameters
        ----------
            out: Torch.tensor
            t: Torch.tensor or numpy array
            k: Torch.tensor or numpy array
            mask1: Torch.tensor
            mask2: Torch.tensor
            alpha: float, weight for log-likelihood loss
            beta: float, weight for ranking loss
            gamma: float, weight for calibration loss

        Returns
        -------
            total_loss: Torch.tensor
        """
        loss1 = self._log_likelihood_loss(out, t, k, mask1, mask2)
        loss2 = self._ranking_loss(out, t, k, mask2)
        loss3 = self._calibration_loss(out, t, k, mask2)

        # L2 regularization is handled by optimizer (weight_decay)
        return alpha * loss1 + beta * loss2 + gamma * loss3

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Predict risk scores for input x.

        Parameters
        ----------
            x: Tensor of input covariates, shape (batch_size, x_dim)

        Returns
        -------
            out: Tensor of shape (batch_size, num_Event, num_Category)
        """
        self.eval()
        with torch.no_grad():
            out, _ = self.forward(x)
        return out

forward

forward(x)

Forward pass through the network.

Returns:

Type Description
out: Tensor of shape (batch_size, num_Event, num_Category); feature_outputs: None
Source code in crisp_nam/models/deephit_model.py
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
    """Forward pass through the network.

    Parameters
    ----------
        x: Tensor of input covariates, shape (batch_size, x_dim)

    Returns
    -------
        out: Tensor of shape (batch_size, num_Event, num_Category) with
        the joint probability mass over events and time categories
        feature_outputs: None (kept for interface compatibility)
    """
    # Shared network
    shared_out = self.shared_net(x)

    # Concatenate input with shared output (residual-style skip)
    h = torch.cat([x, shared_out], dim=1)

    # Cause-specific networks
    cs_outputs = []
    for cs_net in self.cs_nets:
        cs_out = cs_net(h)
        cs_outputs.append(cs_out)

    # Stack outputs
    stacked_out = torch.stack(
        cs_outputs, dim=1
    )  # [batch_size, num_Event, h_dim_CS]
    reshaped_out = stacked_out.view(
        -1, self.num_Event * self.h_dim_CS
    )  # [batch_size, num_Event * h_dim_CS]

    # Final output layer
    logits = self.output_layer(
        F.dropout(reshaped_out, self.dropout_rate, self.training)
    )
    # Softmax over ALL (event, category) cells jointly
    out = F.softmax(logits.view(-1, self.num_Event * self.num_Category), dim=1)

    # Reshape to [batch_size, num_Event, num_Category]
    out = out.view(-1, self.num_Event, self.num_Category)

    # For compatibility with the training script, return both
    # raw risks and shape functions
    # In this model, we don't have separate shape functions, so just return None
    return out, None

compute_loss

compute_loss(out, t, k, mask1, mask2, alpha=1.0, beta=1.0, gamma=1.0)

Compute total loss.

Returns:

Type Description
total_loss: Torch.tensor
Source code in crisp_nam/models/deephit_model.py
def compute_loss(
    self,
    out: torch.Tensor,
    t: Optional[torch.Tensor | np.ndarray],
    k: Optional[torch.Tensor | np.ndarray],
    mask1: Optional[torch.Tensor | np.ndarray],
    mask2: torch.Tensor,
    alpha: float = 1.0,
    beta: float = 1.0,
    gamma: float = 1.0,
) -> torch.Tensor:
    """Compute total loss.

    Parameters
    ----------
        out: Torch.tensor
        t: Torch.tensor or numpy array
        k: Torch.tensor or numpy array
        mask1: Torch.tensor
        mask2: Torch.tensor
        alpha: float, weight for log-likelihood loss
        beta: float, weight for ranking loss
        gamma: float, weight for calibration loss

    Returns
    -------
        total_loss: Torch.tensor
    """
    loss1 = self._log_likelihood_loss(out, t, k, mask1, mask2)
    loss2 = self._ranking_loss(out, t, k, mask2)
    loss3 = self._calibration_loss(out, t, k, mask2)

    # L2 regularization is handled by optimizer (weight_decay)
    return alpha * loss1 + beta * loss2 + gamma * loss3

predict

predict(x)

Predict risk scores for input x.

Returns:

Type Description
out: Tensor of shape (batch_size, num_Event, num_Category)
Source code in crisp_nam/models/deephit_model.py
def predict(self, x: torch.Tensor) -> torch.Tensor:
    """Predict risk scores for input x without tracking gradients.

    Parameters
    ----------
        x: Input tensor passed straight to ``forward``.
           NOTE(review): earlier docs stated the shape
           (batch_size, num_Event, num_Category), which matches the
           *output* — confirm the expected input shape against callers.

    Returns
    -------
        out: Tensor of shape (batch_size, num_Event, num_Category)
    """
    # Switch to inference mode (disables dropout / batch-norm updates).
    self.eval()
    with torch.no_grad():
        predictions, _ = self.forward(x)
    return predictions

## Metrics

crisp_nam.metrics.calibration

Calibration metrics for time-to-event models with competing risks.

This module contains functions to compute the Brier score and integrated Brier score for competing risks.

epsilon module-attribute

epsilon = 0.0001

estimate_ipcw

estimate_ipcw(km)

Estimate the inverse probability of censoring weights (IPCW) using a Kaplan-Meier estimator.

Parameters:

Name Type Description Default
km tuple or KaplanMeierFitter
required

Returns:

Name Type Description
kmf KaplanMeierFitter

A KaplanMeierFitter instance fitted to the provided data or the input instance if already fitted.

Source code in crisp_nam/metrics/ipcw.py
def estimate_ipcw(km: tuple | KaplanMeierFitter) -> KaplanMeierFitter:
    """Estimate the inverse probability of censoring weights (IPCW)
    using a Kaplan-Meier estimator.

    Parameters
    ----------
    km : tuple or KaplanMeierFitter
        Either ``(events, times)`` training data, or an already-prepared
        estimator that is passed through unchanged.

    Returns
    -------
    kmf : KaplanMeierFitter
        A KaplanMeierFitter instance fitted to the provided data or
        the input instance if already fitted.
    """
    if not isinstance(km, tuple):
        # Already an estimator (or None): nothing to fit.
        return km

    events, durations = km
    # The censoring distribution treats "censored" (label 0) as the event,
    # so the indicator is inverted before fitting.
    censoring_indicator = (events == 0).astype(int)
    fitted = KaplanMeierFitter()
    fitted.fit(durations, event_observed=censoring_indicator)
    return fitted

brier_score

brier_score(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=1)

Compute the corrected Brier score for a given competing risk.

This implementation is based on the work of Schoop et al. on quantifying the predictive accuracy of time-to-event models in the presence of competing risks.

Returns:

Type Description
brier (float): The corrected Brier score evaluated at time t.

km (object): Updated Kaplan–Meier estimator (if applicable).

Source code in crisp_nam/metrics/calibration.py
def brier_score(
    e_test: np.ndarray,
    t_test: np.ndarray,
    risk_predicted_test: np.ndarray,
    times: np.ndarray,
    t: float,
    km: object | None = None,
    primary_risk: int = 1,
) -> tuple[float, object]:
    """Compute the corrected Brier score for a given competing risk.

    This implementation is based on the work of Schoop et al. on quantifying the
    predictive accuracy of time-to-event models in the presence of competing risks.

    Parameters
    ----------
        e_test (ndarray): Array of event indicators (0 = censored;
        positive integers for different events).
        t_test (ndarray): Array of event/censoring times.
        risk_predicted_test (ndarray): Predicted risk matrix with
        shape (n_samples, n_times).
        times (ndarray): Array of time points corresponding to columns
        in risk_predicted_test.
        t (float): Time at which to evaluate the Brier score.
        km (object, optional): Kaplan–Meier estimator or data to estimate
        the censoring distribution.
        primary_risk (int, optional): The event label for which to compute the score.

    Returns
    -------
        brier (float): The corrected Brier score evaluated at time t.
        km (object): Updated Kaplan–Meier estimator (if applicable).
    """
    # Ground truth: did the event of interest occur by the horizon t?
    observed = (e_test == primary_risk) & (t_test <= t)
    # Column of the prediction matrix whose time grid point is closest to t.
    col = np.argmin(np.abs(times - t))
    km = estimate_ipcw(km)

    if not observed.any():
        # No events of interest by t: the score is undefined here.
        return np.nan, km

    squared_error = (observed - risk_predicted_test[:, col]) ** 2

    if km is None:
        # Without a censoring model, fall back to the unweighted score.
        return squared_error.mean(), km

    # IPCW correction: subjects censored before t get weight zero.
    ipcw = np.zeros(e_test.shape, dtype=float)
    uncensored_past = (t_test <= t) & (e_test != 0)
    ipcw[uncensored_past] = 1.0 / np.clip(
        km.survival_function_at_times(t_test[uncensored_past]), epsilon, None
    )
    # Subjects still at risk at t share a constant weight based on KM at t.
    still_at_risk = t_test > t
    ipcw[still_at_risk] = 1.0 / np.clip(
        km.survival_function_at_times(t), epsilon, None
    )

    return (ipcw * squared_error).mean(), km

integrated_brier_score

integrated_brier_score(e_test, t_test, risk_predicted_test, times, t_eval=None, km=None, primary_risk=1)

Compute the integrated Brier score for competing risks over a range of time points.

The integrated Brier score is computed by numerically integrating the Brier score over the evaluation times.

Returns:

Type Description
ibs (float): Integrated Brier score.

km (object): Updated Kaplan–Meier estimator.

Source code in crisp_nam/metrics/calibration.py
def integrated_brier_score(
    e_test: np.ndarray,
    t_test: np.ndarray,
    risk_predicted_test: np.ndarray,
    times: np.ndarray,
    t_eval: np.ndarray | None = None,
    km: object | None = None,
    primary_risk: int = 1,
) -> tuple[float, object]:
    """
    Compute the integrated Brier score for competing risks over a range of time points.

    The integrated Brier score is computed by numerically integrating the
    Brier score over the evaluation times.

    Parameters
    ----------
        e_test (ndarray): Event indicators.
        t_test (ndarray): Event/censoring times.
        risk_predicted_test (ndarray): Predicted risk matrix
        with shape (n_samples, n_times).
        times (ndarray): Array of time points corresponding to the predictions.
        t_eval (ndarray, optional): Specific time points at which to
        compute the score. Defaults to using 'times'.
        km (object, optional): Kaplan–Meier estimator for IPCW.
        primary_risk (int, optional): The event label for which to compute the score.

    Returns
    -------
        ibs (float): Integrated Brier score.
        km (object): Updated Kaplan–Meier estimator.
    """
    km = estimate_ipcw(km)
    eval_points = times if t_eval is None else t_eval

    # Pointwise Brier scores over the evaluation grid.
    scores = np.array(
        [
            brier_score(
                e_test, t_test, risk_predicted_test, times, tau, km, primary_risk
            )[0]
            for tau in eval_points
        ]
    )

    # Drop time points where the score is undefined (NaN).
    valid = ~np.isnan(scores)
    eval_points = eval_points[valid]
    scores = scores[valid]

    if eval_points.shape[0] < 2:
        raise ValueError("At least two time points must be provided for integration.")

    # Trapezoidal integration, normalized by the span of the grid.
    span = eval_points[-1] - eval_points[0]
    ibs = np.trapz(scores, eval_points) / span
    return ibs, km

crisp_nam.metrics.discrimination

Discrimination metrics for time-to-event models with competing risks.

This module contains functions to compute the cumulative and single time-dependent AUC and time-dependent C-index for evaluating competing risks.

epsilon module-attribute

epsilon = 1e-10

estimate_ipcw

estimate_ipcw(km)

Estimate the inverse probability of censoring weights (IPCW) using a Kaplan-Meier estimator.

Parameters:

Name Type Description Default
km tuple or KaplanMeierFitter
required

Returns:

Name Type Description
kmf KaplanMeierFitter

A KaplanMeierFitter instance fitted to the provided data or the input instance if already fitted.

Source code in crisp_nam/metrics/ipcw.py
def estimate_ipcw(km: tuple | KaplanMeierFitter) -> KaplanMeierFitter:
    """Estimate the inverse probability of censoring weights (IPCW)
    using a Kaplan-Meier estimator.

    Parameters
    ----------
    km : tuple or KaplanMeierFitter
        ``(events, times)`` training data, or a ready estimator.

    Returns
    -------
    kmf : KaplanMeierFitter
        A KaplanMeierFitter instance fitted to the provided data or
        the input instance if already fitted.
    """
    if isinstance(km, tuple):
        e_train, t_train = km
        # Fit the *censoring* distribution: flip the indicator so that
        # censored observations (label 0) count as events.
        kmf = KaplanMeierFitter()
        kmf.fit(t_train, event_observed=(e_train == 0).astype(int))
        return kmf
    # Anything else is assumed to already be a usable estimator.
    return km

auc_td

auc_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=1)

Compute the time-dependent AUC for a given competing risk using predicted CIFs.

Returns:

Type Description
auc_value : float
AUC estimate at time t (always between 0 and 1)

km : Updated Kaplan-Meier estimator

Source code in crisp_nam/metrics/discrimination.py
def auc_td(
    e_test: np.ndarray,
    t_test: np.ndarray,
    risk_predicted_test: np.ndarray,
    times: np.ndarray,
    t: float,
    km: object | None = None,
    primary_risk: int = 1,
) -> tuple[float, object]:
    """
    Compute the time-dependent AUC for a given competing risk using predicted CIFs.

    Parameters
    ----------
        e_test : ndarray of shape (n_samples,)
            Event indicator (0=censored, 1=event of interest, 2=competing event, etc.)
        t_test : ndarray of shape (n_samples,)
            Observed time to event or censoring.
        risk_predicted_test : ndarray of shape (n_samples, n_times)
            Predicted cumulative incidence for the event of interest across time.
        times : ndarray of shape (n_times,)
            Evaluation time grid (same axis as second dim of risk_predicted_test).
        t : float
            Specific evaluation time point.
        km : KaplanMeierFitter or tuple of (e_train, t_train), optional
            For IPCW adjustment; can be None to skip weighting.
        primary_risk : int
            The event label to treat as the event of interest.

    Returns
    -------
        auc_value : float
            AUC estimate at time t (always between 0 and 1)
        km : Updated Kaplan-Meier estimator
    """
    index = np.argmin(np.abs(times - t))
    preds = risk_predicted_test[:, index]

    # Define event group and at-risk group
    event_mask = (e_test == primary_risk) & (t_test <= t)
    control_mask = t_test > t  # those still at risk

    if event_mask.sum() == 0 or control_mask.sum() == 0:
        return np.nan, km

    event_scores = preds[event_mask]
    control_scores = preds[control_mask]

    # Compute IPCW weights
    if km is None:
        weights_event = np.ones_like(event_scores)
        weights_control = np.ones_like(control_scores)
    else:
        km = estimate_ipcw(km)
        weights_event = 1.0 / np.clip(km.predict(t_test[event_mask]), epsilon, 1.0)
        weights_control = 1.0 / np.clip(km.predict(t_test[control_mask]), epsilon, 1.0)

    # Compute pairwise AUC: compare each (event, control) pair
    auc_numerator = 0.0
    auc_denominator = 0.0
    for _i, (score_i, w_i) in enumerate(zip(event_scores, weights_event)):
        for _j, (score_j, w_j) in enumerate(zip(control_scores, weights_control)):
            weight = w_i * w_j
            auc_denominator += weight
            if score_i > score_j:
                auc_numerator += weight
            elif np.isclose(score_i, score_j):
                auc_numerator += 0.5 * weight
            # else: no increment

    auc_value = auc_numerator / auc_denominator if auc_denominator > 0 else np.nan
    return auc_value, km

cumulative_dynamic_auc

cumulative_dynamic_auc(e_test, t_test, risk_predicted_test, times, t_eval=None, km=None, primary_risk=1)

Compute the cumulative dynamic AUC by numerically integrating the time-dependent AUC over a range of time points.

Returns:

Type Description
auc_integral: float
The cumulative dynamic AUC.

km: object Updated Kaplan-Meier estimator.

Source code in crisp_nam/metrics/discrimination.py
def cumulative_dynamic_auc(
    e_test: np.ndarray,
    t_test: np.ndarray,
    risk_predicted_test: np.ndarray,
    times: np.ndarray,
    t_eval: np.ndarray | None = None,
    km: object | None = None,
    primary_risk: int = 1,
) -> tuple[float, object]:
    """Compute the cumulative dynamic AUC by numerically
    integrating the time-dependent AUC over a range of time points.

    Parameters
    ----------
        e_test, t_test, risk_predicted_test, times, km, primary_risk:
            Same as in auc_td.
        t_eval: ndarray, optional
            Specific time points to evaluate. If None, uses times.

    Returns
    -------
        auc_integral: float
            The cumulative dynamic AUC.
        km: object
            Updated Kaplan-Meier estimator.
    """
    km = estimate_ipcw(km)
    grid = times if t_eval is None else t_eval

    # Time-dependent AUC at every grid point.
    auc_values = np.array(
        [
            auc_td(e_test, t_test, risk_predicted_test, times, tau, km, primary_risk)[0]
            for tau in grid
        ]
    )

    # Discard grid points with an undefined (NaN) AUC before integrating.
    defined = ~np.isnan(auc_values)
    grid = grid[defined]
    auc_values = auc_values[defined]

    if grid.shape[0] < 2:
        raise ValueError("At least two time points must be given")

    # Trapezoidal rule, normalized by the length of the integration interval.
    return np.trapz(auc_values, grid) / (grid[-1] - grid[0]), km

truncated_concordance_td

truncated_concordance_td(e_test, t_test, risk_predicted_test, times, t, km=None, primary_risk=1, tied_tol=1e-08)

Compute the truncated time-dependent concordance index (C-index).

Returns:

Type Description
c_index : float

km : Updated km object

Source code in crisp_nam/metrics/discrimination.py
def truncated_concordance_td(
    e_test: np.ndarray,
    t_test: np.ndarray,
    risk_predicted_test: np.ndarray,
    times: np.ndarray,
    t: float,
    km: object | None = None,
    primary_risk: int = 1,
    tied_tol: float = 1e-8,
) -> tuple[float, object]:
    """
    Compute the truncated time-dependent concordance index (C-index).

    Parameters
    ----------
        e_test : ndarray
            Event indicator (0=censored, 1=event of interest, etc.)
        t_test : ndarray
            Time-to-event or censoring
        risk_predicted_test : ndarray
            Predicted cumulative incidence (n_samples, n_timepoints)
        times : ndarray
            Time grid
        t : float
            Specific evaluation time
        km : KaplanMeierFitter or (e_train, t_train), optional
            For IPCW weighting
        primary_risk : int
            Risk of interest
        tied_tol : float
            Tolerance to assign 0.5 score for ties

    Returns
    -------
        c_index : float
        km : Updated km object
    """
    epsilon = 1e-10
    # Column of the prediction matrix closest to the evaluation horizon t.
    index = np.argmin(np.abs(times - t))

    # IPCW
    # NOTE(review): weights_event holds the censoring-survival values G(.)
    # themselves (clipped away from zero), not their inverses; the inversion
    # happens below when the per-pair weights are formed.
    if km is not None:
        km = estimate_ipcw(km)
        weights_event = np.clip(km.predict(t_test), epsilon, None)
    else:
        weights_event = np.ones_like(t_test)

    # Event of interest occurred before t
    event_mask = (e_test == primary_risk) & (t_test <= t)
    if event_mask.sum() == 0:
        # Without any observed event of interest, no comparable pairs exist.
        return np.nan, km

    nominator = 0.0
    denominator = 0.0

    for i in np.where(event_mask)[0]:
        t_i = t_test[i]
        r_i = risk_predicted_test[i, index]
        w_i = weights_event[i]

        # Define other subjects at risk
        # after_mask: subjects whose observed time exceeds t_i (still at risk).
        after_mask = t_test > t_i
        # before_mask: subjects with a *competing* (non-censoring) event by t_i.
        before_mask = (t_test <= t_i) & (e_test != primary_risk) & (e_test != 0)

        # Pairs with a later observed time are weighted by 1 / G(t_i)^2.
        weights_after = weights_event[after_mask] / (w_i**2)
        # NOTE(review): this expression cancels algebraically to 1 / w_i for
        # every pair; if the intended weight was 1 / (G(t_i) * G(t_j)), the
        # numerator should not be weights_event[before_mask] — confirm
        # against the reference formula for the truncated C-index.
        weights_before = weights_event[before_mask] / (w_i * weights_event[before_mask])

        risks_after = risk_predicted_test[after_mask, index]
        risks_before = risk_predicted_test[before_mask, index]

        # A pair is concordant when the case i has the strictly higher
        # predicted risk than its comparator.
        concordant_after = (risks_after < r_i).astype(float)
        concordant_before = (risks_before < r_i).astype(float)

        # Ties (within tied_tol) receive half credit.
        concordant_after[np.abs(risks_after - r_i) <= tied_tol] = 0.5
        concordant_before[np.abs(risks_before - r_i) <= tied_tol] = 0.5

        nominator += (concordant_after * weights_after).sum()
        nominator += (concordant_before * weights_before).sum()

        denominator += weights_after.sum()
        denominator += weights_before.sum()

    if denominator == 0:
        # No admissible pairs at all: concordance is undefined.
        return np.nan, km

    c_index = nominator / denominator
    return c_index, km

crisp_nam.metrics.ipcw

IPCW estimation for time-to-event models with competing risks.

This module provides a function to estimate the inverse probability of censoring weights (IPCW) using a Kaplan-Meier estimator.

estimate_ipcw

estimate_ipcw(km)

Estimate the inverse probability of censoring weights (IPCW) using a Kaplan-Meier estimator.

Parameters:

Name Type Description Default
km tuple or KaplanMeierFitter
required

Returns:

Name Type Description
kmf KaplanMeierFitter

A KaplanMeierFitter instance fitted to the provided data or the input instance if already fitted.

Source code in crisp_nam/metrics/ipcw.py
def estimate_ipcw(km: tuple | KaplanMeierFitter) -> KaplanMeierFitter:
    """Estimate the inverse probability of censoring weights (IPCW)
    using a Kaplan-Meier estimator.

    Parameters
    ----------
    km : tuple or KaplanMeierFitter
        Raw ``(events, times)`` training data, or a prepared estimator.

    Returns
    -------
    kmf : KaplanMeierFitter
        A KaplanMeierFitter instance fitted to the provided data or
        the input instance if already fitted.
    """
    if not isinstance(km, tuple):
        # Not raw data: return the estimator untouched.
        return km

    observed_events, observed_times = km
    # Model the censoring distribution by inverting the event indicator:
    # label 0 (censored) becomes the "event" being estimated.
    censored = (observed_events == 0).astype(int)
    estimator = KaplanMeierFitter()
    estimator.fit(observed_times, event_observed=censored)
    return estimator

## Utilities

crisp_nam.utils.plotting

Utility functions for plotting.

This module provides functions to visualize feature importance and shape functions for both crisp-nam and deephit models.

plot_feature_importance

plot_feature_importance(model, x_data, feature_names=None, n_top=5, n_bottom=5, risk_idx=1, figsize=(8, 6), output_file='', color_positive='#2196F3', color_negative='#F44336')

Plot feature importance with both top positive and negative influences, handling both CPU and CUDA devices automatically.

Returns:

Type Description
- fig: Matplotlib figure object
- ax: Matplotlib axes object
- top_pos: List of top positive feature names
- top_neg: List of top negative feature names
Source code in crisp_nam/utils/plotting.py
def plot_feature_importance(
    model: torch.nn.Module,
    x_data: Union[np.ndarray, torch.Tensor],
    feature_names=None,
    n_top: int = 5,
    n_bottom: int = 5,
    risk_idx: int = 1,
    figsize: tuple = (8, 6),
    output_file: str = "",
    color_positive: str = "#2196F3",
    color_negative: str = "#F44336",
) -> tuple:
    """Plot feature importance with both top positive and negative influences,
    handling both CPU and CUDA devices automatically.

    Parameters
    ----------
    - model: A trained CoxNAM model (torch.nn.Module)
    - x_data: Input data (numpy array or torch tensor) to compute contributions
    - feature_names: Optional list of feature names (default: generic names)
    - n_top: Number of top positive features to display
    - n_bottom: Number of top negative features to display
    - risk_idx: 1-based index of the competing risk to analyze
    - figsize: Size of the plot (width, height)
    - output_file: Optional path to save the plot image (e.g., "feature_importance.png")
    - color_positive: Color for positive contributions (default: blue)
    - color_negative: Color for negative contributions (default: red)

    Returns
    -------
    - fig: Matplotlib figure object
    - ax: Matplotlib axes object
    - top_pos: List of top positive feature names
    - top_neg: List of top negative feature names
    """
    # Run everything on whatever device the model lives on.
    device = next(model.parameters()).device
    model.eval()

    # Prepare feature names. Fixed annotation: num_features is an int
    # count, not a torch.Tensor as previously annotated.
    num_features: int = model.num_features
    if feature_names is None:
        feature_names = [f"Feature {i + 1}" for i in range(num_features)]

    # Convert x_data to a tensor on the model device.
    if not isinstance(x_data, torch.Tensor):
        x = torch.tensor(x_data, dtype=torch.float32, device=device)
    else:
        x = x_data.to(device)

    feature_contribs = {}
    risk_idx0 = risk_idx - 1  # convert 1-based risk index to 0-based

    with torch.no_grad():
        for i in range(num_features):
            vals = x[:, i].unsqueeze(1)  # shape (N,1)
            # (Near-)constant features carry no signal; record zero.
            if torch.var(vals) <= 1e-8:
                feature_contribs[feature_names[i]] = 0.0
                continue

            # Forward through the feature net and its per-risk projection.
            # Fixed annotations: both results are tensors, not ModuleLists.
            rep: torch.Tensor = model.feature_nets[i](vals)
            proj: torch.Tensor = model.risk_projections[i][risk_idx0](rep)
            # Mean contribution as a Python float.
            contrib = proj.mean().item()
            feature_contribs[feature_names[i]] = contrib

    # Build a DataFrame sorted by absolute contribution magnitude.
    df = pd.DataFrame(
        {
            "feature": list(feature_contribs.keys()),
            "contribution": list(feature_contribs.values()),
        }
    )
    df["abs_contrib"] = df["contribution"].abs()
    df = df.sort_values("abs_contrib", ascending=False)

    # Largest positive and largest negative contributors.
    pos = df[df["contribution"] > 0].head(n_top).sort_values("contribution")
    neg = (
        df[df["contribution"] < 0]
        .head(n_bottom)
        .sort_values("contribution", ascending=False)
    )

    top_pos = pos["feature"].tolist()
    top_neg = neg["feature"].tolist()

    # Plotting.
    fig, ax = plt.subplots(figsize=figsize)
    ax.barh(pos["feature"], pos["contribution"], color=color_positive, alpha=0.8)
    ax.barh(neg["feature"], neg["contribution"], color=color_negative, alpha=0.8)
    ax.axvline(0, color="black", linestyle="-", alpha=0.3)

    ax.set_xlabel("Contribution to Risk Score")
    ax.set_title(
        f"Top {n_top} Positive & {n_bottom} Negative Features for risk_{risk_idx}"
    )
    ax.grid(axis="x", linestyle="--", alpha=0.5)
    plt.tight_layout()

    if output_file:
        plt.savefig(output_file, bbox_inches="tight", dpi=300)

    return fig, ax, top_pos, top_neg

plot_coxnam_shape_functions

plot_coxnam_shape_functions(model, X, risk_to_plot=1, feature_names=None, top_features=None, ncols=3, figsize=(12, 8), output_file='')

Plot shape functions for each feature in a CoxNAM model, automatically handling CPU vs CUDA inputs.

Returns:

Type Description
- fig: Matplotlib figure object
- axes: List of Matplotlib axes objects for each plotted feature
Source code in crisp_nam/utils/plotting.py
def plot_coxnam_shape_functions(
    model: torch.nn.Module,
    X: Union[np.ndarray, torch.Tensor],
    risk_to_plot: int = 1,
    feature_names: np.ndarray | None = None,
    top_features: List[str] | None = None,
    ncols: int = 3,
    figsize: tuple = (12, 8),
    output_file: str = "",
) -> tuple:
    """Plot shape functions for each feature in a CoxNAM model,
    automatically handling CPU vs CUDA inputs.

    Parameters
    ----------
    - model: A trained CoxNAM model (torch.nn.Module)
    - X: Input data (numpy array or torch tensor) to compute shape functions
    - risk_to_plot: 1-based index of the competing risk to visualize
    - feature_names: Optional list of feature names (default: generic names)
    - top_features: Optional list of feature names to restrict the plot to
    - ncols: Number of columns in the subplot grid
    - figsize: Size of the entire figure (width, height)
    - output_file: Optional path to save the plot image (e.g., "shape_functions.png")

    Returns
    -------
    - fig: Matplotlib figure object
    - axes: List of Matplotlib axes objects for each plotted feature
    """
    device = next(model.parameters()).device
    model.eval()
    risk_idx = risk_to_plot - 1  # convert 1-based risk index to 0-based

    # Ensure X is a numpy array.
    X_np = X.cpu().numpy() if isinstance(X, torch.Tensor) else np.array(X, dtype=float)

    # Derive the list of (index, name) pairs to plot.
    # (Removed leftover debug print statements from the original.)
    num_features = model.num_features
    if feature_names is None:
        feature_names = [f"Feature {i + 1}" for i in range(num_features)]
    if top_features is not None:
        # Map requested names back to column indices; silently drop unknowns.
        idx_map = {name: i for i, name in enumerate(feature_names)}
        selected = [(idx_map.get(name), name) for name in top_features]
        selected = [(i, name) for i, name in selected if i is not None]
    else:
        selected = list(zip(range(num_features), feature_names))

    n_selected = len(selected)
    nrows = int(np.ceil(n_selected / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    # Flatten to a 1-D array of axes regardless of grid shape.
    axes = np.array(axes).reshape(-1)

    with torch.no_grad():
        for ax, (f_idx, fname) in zip(axes, selected):
            vals = X_np[:, f_idx]
            if vals.size == 0:
                ax.text(0.5, 0.5, "no data", ha="center", va="center")
                continue

            # Choose evaluation points: discrete features get their unique
            # values, continuous ones a dense linear grid.
            if np.issubdtype(vals.dtype, np.integer) or len(np.unique(vals)) <= 10:
                pts = np.unique(vals)
            else:
                pts = np.linspace(vals.min(), vals.max(), 100)

            # Convert to a tensor on the correct device.
            t_pts = torch.tensor(pts, dtype=torch.float32, device=device).unsqueeze(1)

            # Compute shape values. Fixed annotations: both results are
            # tensors, not ModuleLists.
            rep: torch.Tensor = model.feature_nets[f_idx](t_pts)
            proj: torch.Tensor = model.risk_projections[f_idx][risk_idx](rep)
            shp = proj.squeeze(-1).cpu().numpy()

            # Plot the shape function plus a zero reference line.
            ax.plot(pts, shp, linewidth=2)
            ax.axhline(0, linestyle="--", alpha=0.5)
            ax.set_title(fname)
            ax.set_xlabel("Value")
            ax.set_ylabel("Contribution")
            # Rug plot showing the observed data distribution.
            ax.plot(vals, np.zeros_like(vals) - 0.1, "|", alpha=0.3)

    # Turn off any extra axes beyond the selected features.
    for ax in axes[n_selected:]:
        ax.axis("off")

    fig.suptitle(f"Shape Functions for Risk {risk_to_plot}", fontsize=14)
    plt.tight_layout(rect=(0, 0, 1, 0.96))
    if output_file:
        plt.savefig(output_file, dpi=300, bbox_inches="tight")

    return fig, axes[:n_selected]

crisp_nam.utils.risk_cif

Risk functions for evaluation.

This module provides functions to compute cumulative incidence functions (CIFs) and risk scores for competing risk models.

compute_baseline_cif

compute_baseline_cif(times, events, eval_times, event_type)

Compute baseline cumulative incidence function for a specific event type.

Args: times: Numpy array of event times events: Numpy array of event indicators (0=censored, 1...K=event types) eval_times: Times at which to evaluate the CIF event_type: Event type to compute CIF for (1...K)

Returns:

Type Description
Numpy array of baseline CIF values at eval_times
Source code in crisp_nam/utils/risk_cif.py
def compute_baseline_cif(
    times: np.ndarray, events: np.ndarray, eval_times: List[Any], event_type: int
) -> np.ndarray:
    """
    Compute baseline cumulative incidence function for a specific event type.

    Args:
        times: Numpy array of event times
        events: Numpy array of event indicators (0=censored, 1...K=event types)
        eval_times: Times at which to evaluate the CIF
        event_type: Event type to compute CIF for (1...K)
            (fixed annotation: this is an integer label, not an ndarray)

    Returns
    -------
        Numpy array of baseline CIF values at eval_times

    NOTE(review): this estimates F(t) as the raw fraction of subjects with
    the given event type by time t — an empirical proportion that does not
    account for censoring; the original comment called it a "simple
    Aalen-Johansen estimator", which should be confirmed against the intent.
    """
    n_samples = len(times)
    if n_samples == 0:
        # Match the original behavior: zero incidence with no data.
        return np.zeros(len(eval_times))

    # Sort times, keeping events aligned.
    sort_idx = np.argsort(times)
    sorted_times = times[sort_idx]
    sorted_events = events[sort_idx]

    # Keep only times of the requested event type (still sorted ascending),
    # then count, for each evaluation time t, how many occurred at or
    # before t via binary search — replaces the original per-time loop.
    relevant_times = sorted_times[sorted_events == event_type]
    counts = np.searchsorted(relevant_times, eval_times, side="right")

    return counts / n_samples

predict_cif

predict_cif(model, x, baseline_cif, times, event_of_interest)

Predict cumulative incidence function for a specific competing risk.

Returns:

Type Description
cif_pred: Array of shape (n_samples, len(times)) — predicted CIF per sample.
Source code in crisp_nam/utils/risk_cif.py
def predict_cif(
    model: torch.nn.Module,
    x: np.ndarray,
    baseline_cif: np.ndarray,
    times: np.ndarray,
    event_of_interest: int,
) -> np.ndarray:
    """
    Predict cumulative incidence function for a specific competing risk.

    Parameters
    ----------
        model: Trained  model.
        x: Input tensor of shape (n_samples, n_features).
        baseline_cif: Array of shape (len(times),) —
        estimated CIF for baseline (e.g. from compute_baseline_cif).
        times: Time points at which CIF is evaluated.
        event_of_interest: Integer, 0-based index of event of interest.

    Returns
    -------
        cif_pred: Array of shape (n_samples, len(times)) — predicted CIF per sample.
    """
    model.eval()
    with torch.no_grad():
        # The model yields one tensor of log-risk scores per competing risk.
        per_risk_outputs, _ = model(x)
        log_risk = per_risk_outputs[event_of_interest].squeeze(1).cpu().numpy()

    base = np.asarray(baseline_cif).reshape(1, -1)  # (1, T)
    proportional_risk = np.exp(log_risk).reshape(-1, 1)  # (N, 1)

    # Fine-Gray style CIF under the proportional-hazards assumption:
    # F(t | x) = 1 - (1 - F0(t)) ** exp(f(x))
    return 1.0 - np.power(1.0 - base, proportional_risk)  # shape (N, T)

predict_risk

predict_risk(model, x_input, device='cpu')

Predicts relative risk scores for each competing risk.

Args: model : Trained model. x_input (np.ndarray or torch.Tensor): Input features of shape (n_samples, n_features). device (str): Device to run the computation on.

Returns:

Type Description
np.ndarray: Array of shape (n_samples, num_risks) with relative risk scores.
Source code in crisp_nam/utils/risk_cif.py
def predict_risk(
    model: torch.nn.Module, x_input: np.ndarray, device: str = "cpu"
) -> np.ndarray:
    """
    Predicts relative risk scores for each competing risk.

    Args:
        model : Trained model.
        x_input (np.ndarray or torch.Tensor): Input features of
        shape (n_samples, n_features).
        device (str): Device to run the computation on.

    Returns
    -------
        np.ndarray: Array of shape (n_samples, num_risks) with relative risk scores.
    """
    model.eval()

    # Normalize the input to a float32 tensor on the requested device.
    x_tensor = (
        torch.from_numpy(x_input).float().to(device)
        if isinstance(x_input, np.ndarray)
        else x_input.to(device).float()
    )

    with torch.no_grad():
        per_risk_outputs, _ = model(x_tensor)  # list of (batch, 1) tensors
        stacked = torch.cat(per_risk_outputs, dim=1)  # (batch, num_risks)

    return stacked.cpu().numpy()

predict_absolute_risk

predict_absolute_risk(model, x_input, baseline_cifs, times, device='cpu')

Predict absolute risk (CIF) for each competing event by given time points.

Returns:

Type Description
np.ndarray: Shape (n_samples, num_events, n_times) with predicted CIFs.
Source code in crisp_nam/utils/risk_cif.py
def predict_absolute_risk(
    model: torch.nn.Module,
    x_input: np.ndarray,
    baseline_cifs: List[Any],
    times: List[Any],
    device: str = "cpu",
) -> np.ndarray:
    """
    Predict absolute risk (CIF) for each competing event by given time points.

    Parameters
    ----------
        model: Trained model.
        x_input (np.ndarray or Tensor): Input features, shape (n_samples, n_features).
        baseline_cifs: Baseline CIF arrays indexed by event (0-based), each of
            shape (n_times,). A list or an int-keyed mapping both work.
        times (np.ndarray): Time grid used for baseline_cifs.
        device: CPU or CUDA.

    Returns
    -------
        np.ndarray: Shape (n_samples, num_events, n_times) with predicted CIFs.
    """
    rel_risks = predict_risk(model, x_input, device)  # shape (n_samples, num_events)
    n_samples, num_events = rel_risks.shape
    n_times = len(times)

    abs_risks = np.zeros((n_samples, num_events, n_times))

    for k in range(num_events):
        # Clip so the power/complement below never hits exactly 0 or 1.
        base_cif = np.clip(baseline_cifs[k], 1e-10, 0.9999)  # (n_times,)
        # Fine-Gray style prediction under the PH assumption, vectorized over
        # samples via broadcasting: CIF_i(t) = 1 - (1 - F0(t)) ** exp(r_ik).
        exponents = np.exp(rel_risks[:, k]).reshape(-1, 1)  # (n_samples, 1)
        abs_risks[:, k, :] = 1.0 - np.power(1.0 - base_cif.reshape(1, -1), exponents)

    return abs_risks

crisp_nam.utils.loss

Loss functions for competing risks.

This module implements weighted and un-weighted negative log-likelihood losses and an L2 penalty function.

weighted_negative_log_likelihood_loss

weighted_negative_log_likelihood_loss(risk_scores, times, events, num_competing_risks, event_weights=None, sample_weights=None, eps=1e-08)

Compute the weighted negative log-likelihood loss for competing risks Cox model.

Returns:

Type Description
Weighted negative log partial likelihood loss
Source code in crisp_nam/utils/loss.py
def weighted_negative_log_likelihood_loss(
    risk_scores,
    times,
    events,
    num_competing_risks,
    event_weights=None,
    sample_weights=None,
    eps=1e-8,
) -> torch.Tensor:
    """
    Compute the weighted negative log-likelihood loss for competing risks Cox model.

    Parameters
    ----------
        risk_scores: List of tensors with shape (batch_size, 1) for each competing risk
        times: Event/censoring times (batch_size,)
        events: Event indicators (0=censored, 1...K=event types) (batch_size,)
        num_competing_risks: Number of competing risks
        event_weights: Tensor of weights for each competing risk type
        (size: num_competing_risks)
        sample_weights: Tensor of weights for each sample (size: batch_size)
        eps: Small constant for numerical stability (kept for API
        compatibility; logsumexp is already stable, so it is unused)

    Returns
    -------
        Weighted negative log partial likelihood loss
    """
    device = times.device
    batch_size = times.shape[0]

    # Initialize loss
    loss = torch.tensor(0.0, device=device)

    # Set default weights if not provided
    if event_weights is None:
        event_weights = torch.ones(num_competing_risks, device=device)
    if sample_weights is None:
        sample_weights = torch.ones(batch_size, device=device)

    # Count number of events; with none, the partial likelihood is empty.
    n_events = (events > 0).sum().item()
    if n_events == 0:
        return loss

    # Process each competing risk separately (cause-specific partial likelihoods)
    for k in range(1, num_competing_risks + 1):
        # Find samples with this event type
        event_mask = events == k
        n_events_k = event_mask.sum().item()

        if n_events_k == 0:
            continue

        # Get risk scores for this competing risk.
        # squeeze(-1) (not squeeze()) keeps a 1-D tensor even when
        # batch_size == 1, so boolean indexing below works for any batch.
        risk_k = risk_scores[k - 1].squeeze(-1)

        # Get weight for this event type
        event_weight = event_weights[k - 1]

        # For each event of type k
        for i in range(batch_size):
            if event_mask[i]:
                # Risk set: samples still at risk at this event time
                # (time >= event time)
                risk_set = times >= times[i]

                # log of the sum of exp(risk) over the risk set,
                # computed stably via logsumexp
                risk_set_scores = risk_k[risk_set]
                log_risk_sum = torch.logsumexp(risk_set_scores, dim=0)

                # Breslow-style contribution, scaled by the event-type and
                # per-sample weights
                individual_loss = log_risk_sum - risk_k[i]
                weighted_individual_loss = (
                    individual_loss * event_weight * sample_weights[i]
                )
                loss += weighted_individual_loss

    # Return average loss over all observed events
    return loss / max(n_events, 1)

negative_log_likelihood_loss

negative_log_likelihood_loss(risk_scores, times, events, num_competing_risks, eps=1e-08)

Compute the negative log-likelihood loss for competing risks Cox model.

Returns:

Type Description
Negative log partial likelihood loss
Source code in crisp_nam/utils/loss.py
def negative_log_likelihood_loss(
    risk_scores: list,
    times: torch.Tensor,
    events: torch.Tensor,
    num_competing_risks: int,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Compute the negative log-likelihood loss for competing risks Cox model.

    Parameters
    ----------
        risk_scores: List of tensors with shape (batch_size, 1) for each competing risk
        times: Event/censoring times (batch_size,)
        events: Event indicators (0=censored, 1...K=event types) (batch_size,)
        num_competing_risks: Number of competing risks
        eps: Small constant for numerical stability (kept for API
        compatibility; logsumexp is already stable, so it is unused)

    Returns
    -------
        Negative log partial likelihood loss
    """
    device = times.device
    batch_size = times.shape[0]

    # Initialize loss
    loss = torch.tensor(0.0, device=device)

    # Count number of events; with none, the partial likelihood is empty.
    n_events = (events > 0).sum().item()
    if n_events == 0:
        return loss

    # Process each competing risk separately (cause-specific partial likelihoods)
    for k in range(1, num_competing_risks + 1):
        # Find samples with this event type
        event_mask = events == k
        n_events_k = event_mask.sum().item()

        if n_events_k == 0:
            continue

        # Get risk scores for this competing risk.
        # squeeze(-1) (not squeeze()) keeps a 1-D tensor even when
        # batch_size == 1, so boolean indexing below works for any batch.
        risk_k = risk_scores[k - 1].squeeze(-1)

        # For each event of type k
        for i in range(batch_size):
            if event_mask[i]:
                # Risk set: samples still at risk at this event time
                # (time >= event time)
                risk_set = times >= times[i]

                # log of the sum of exp(risk) over the risk set,
                # computed stably via logsumexp
                risk_set_scores = risk_k[risk_set]
                log_risk_sum = torch.logsumexp(risk_set_scores, dim=0)

                # Breslow-style contribution of this event
                loss += log_risk_sum - risk_k[i]

    # Return average loss over all observed events
    return loss / max(n_events, 1)

compute_l2_penalty

compute_l2_penalty(model, include_bias=False)

Compute L2 regularization penalty on model parameters.

Returns:

Type Description
L2 penalty term
Source code in crisp_nam/utils/loss.py
def compute_l2_penalty(
    model: torch.nn.Module,
    include_bias: bool = False,
) -> torch.Tensor:
    """
    Compute L2 regularization penalty on model parameters.

    Parameters
    ----------
        model: Neural network model
        include_bias: Whether to include bias terms in regularization

    Returns
    -------
        L2 penalty term as a scalar tensor (0.0 when no trainable
        parameters qualify)
    """
    # Squared L2 norm of every trainable (and, optionally, non-bias) parameter.
    terms = [
        param.pow(2).sum()
        for name, param in model.named_parameters()
        if param.requires_grad and (include_bias or "bias" not in name)
    ]
    if not terms:
        # Previously returned the Python float 0.0, which contradicted the
        # annotated return type; return a tensor so callers can rely on it.
        return torch.tensor(0.0)
    return torch.stack(terms).sum()