Skip to content

Usage

Example code showing how to use the package's two model classes: `CrispNamModel` and `DeepHit`.

Using the `CrispNamModel` class

import torch

from crisp_nam.models import CrispNamModel

# Example usage and testing
if __name__ == "__main__":

    # Generate some reproducible test data: 100 samples with 5 features each
    torch.manual_seed(42)
    test_data = torch.randn(100, 5)

    # Create model with L2 normalized projections.
    # num_features is derived from the data so the two cannot drift apart.
    model = CrispNamModel(
        num_features=test_data.shape[1],
        num_competing_risks=3,
        hidden_sizes=[32, 32],
        normalize_projections=True,
    )

    # Forward pass. This is inference only, so gradients are not tracked.
    with torch.no_grad():
        risk_predictions = model(test_data)

    # NOTE(review): assumes the forward pass returns one row per sample, so
    # [:5] selects the first 5 samples -- confirm against CrispNamModel.forward.
    print("First 5 risk predictions:", risk_predictions[:5])

    # Calculate feature importance. The full result is printed (not sliced),
    # so the label does not claim "first 5".
    feature_importance = model.calculate_feature_importance(test_data)
    print("Feature importance values:", feature_importance)

    # Analyze projection weights and show the result instead of discarding it
    print('Analyzing projection weights...')
    projection_weights = model.analyze_projection_weights()
    print("Projection weights:", projection_weights)

Using the `DeepHit` model class

import torch

from crisp_nam.models.deephit_model import DeepHit

# Example usage and testing
if __name__ == "__main__":

    # Reproducible synthetic inputs: 100 rows of 5 covariates each
    torch.manual_seed(42)
    test_data = torch.randn(100, 5)

    # Problem dimensions handed to DeepHit; x_dim follows the data's column count
    num_covariates = test_data.shape[1]
    input_dims = dict(
        x_dim=num_covariates,
        num_Event=2,
        num_Category=100,
    )

    # Network hyperparameters. The settings expect a keep probability,
    # i.e. 1.0 minus the dropout rate.
    dropout_rate = 0.3
    network_settings = dict(
        h_dim_shared=128,
        h_dim_CS=32,
        num_layers_shared=1,
        num_layers_CS=2,
        active_fn='tanh',
        keep_prob=1.0 - dropout_rate,
    )

    # Build the model, run a prediction, and report the output's shape
    deephit = DeepHit(input_dims, network_settings)
    predictions = deephit.predict(test_data)
    print("Output shape:", predictions.shape)