Mortality Prediction#

This notebook showcases mortality prediction on the MIMIC-IV dataset using CyclOps. The task is formulated as binary classification: predict whether a patient will die within N days of admission, using data from the first M days after admission. For example, with N = 14 and M = 1, we predict the risk of mortality within 14 days of admission using the first 24 hours of data after admission.

Import Libraries#

[1]:
"""Mortality Prediction."""

import copy
import shutil
from datetime import date

import cycquery.ops as qo
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from cycquery import MIMICIVQuerier
from datasets import Dataset
from datasets.features import ClassLabel
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder

from cyclops.data.aggregate import RESTRICT_TIMESTAMP, Aggregator
from cyclops.data.clean import normalize_names
from cyclops.data.df.feature import TabularFeatures
from cyclops.data.slicer import SliceSpec
from cyclops.evaluate.fairness import FairnessConfig  # noqa: E402
from cyclops.evaluate.metrics import create_metric
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
from cyclops.models.catalog import create_model
from cyclops.report import ModelCardReport
from cyclops.report.plot.classification import ClassificationPlotter
from cyclops.report.utils import flatten_results_dict
from cyclops.tasks import BinaryTabularClassificationTask
from cyclops.utils.common import add_years_approximate

CyclOps provides a package for documenting a model through a model report. The ModelCardReport class is used to populate the report and generate it as an HTML file. The model report has the following sections:

  • Overview: Provides a high-level overview of how the model is doing (a quick glance at important metrics) and how it is doing over time (performance across several metrics and subgroups over time).

  • Datasets: High-level statistics of the training data, including changes in distribution over time.

  • Quantitative Analysis: This section contains additional detailed performance metrics of the model for different sets of the data and subpopulations.

  • Fairness Analysis: This section contains the fairness metrics of the model.

  • Model Details: This section contains descriptive metadata about the model such as the owners, version, license, etc.

  • Model Parameters: This section contains the technical details of the model such as the model architecture, training parameters, etc.

  • Considerations: This section contains descriptions of the considerations involved in developing and using the model such as the intended use, limitations, etc.

We will use this to document the model development process as we go along and generate the model report at the end.

The model report tool is a work in progress and is subject to change.

[2]:
report = ModelCardReport()

Constants#

[3]:
M = 1
N = 14
NAN_THRESHOLD = 0.25
TRAIN_SIZE = 0.8
RANDOM_SEED = 12

Data Querying & Processing#

Compute mortality (labels)#

  1. Get encounters

  2. Filter out encounters shorter than M days

  3. Set label = 1 for encounters where deathtime is within N days after admission

  4. Get lab events

  5. Aggregate the lab values over the first M days by computing their mean, and merge the result with the encounter data
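Steps 2 and 3 reduce to a simple per-encounter rule. The sketch below restates the filters and the labeling rule of compute_mortality_outcome (defined in the next cell) in plain Python, assuming a hypothetical Encounter record; it is an illustration, not part of CyclOps. Like pandas' .dt.days, datetime.timedelta.days counts whole days only, so with M = 1 a 23-hour stay is excluded, and with N = 14 a death 14 days and 23 hours after admission still counts as positive.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional

M = 1   # minimum stay (and observation window) in days
N = 14  # outcome horizon in days after admission


@dataclass
class Encounter:
    """Hypothetical record holding the columns used for labeling."""

    admittime: datetime
    dischtime: datetime
    deathtime: Optional[datetime]
    hospital_expire_flag: int


def label_encounter(enc: Encounter) -> Optional[int]:
    """Return 1 or 0, or None if the encounter is excluded from the cohort."""
    # Exclude in-hospital deaths that have no death timestamp
    if enc.hospital_expire_flag == 1 and enc.deathtime is None:
        return None
    # Exclude stays shorter than M whole days (.days floors positive durations)
    if (enc.dischtime - enc.admittime).days < M:
        return None
    # Positive if death occurs within N whole days of admission
    if enc.deathtime is not None and (enc.deathtime - enc.admittime).days <= N:
        return 1
    return 0


admit = datetime(2150, 1, 1, 8, 0)
short_stay = Encounter(admit, admit + timedelta(hours=23), None, 0)
early_death_time = admit + timedelta(days=14, hours=23)
early_death = Encounter(admit, early_death_time, early_death_time, 1)
late_death_time = admit + timedelta(days=15)
late_death = Encounter(admit, late_death_time, late_death_time, 1)
print(label_encounter(short_stay), label_encounter(early_death), label_encounter(late_death))
# None 1 0
```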

[4]:
querier = MIMICIVQuerier(
    dbms="postgresql",
    port=5432,
    host="localhost",
    database="mimiciv-2.0",
    user="postgres",
    password="pwd",
)


def get_encounters():
    """Get encounters data."""
    patients = querier.patients()
    encounters = querier.mimiciv_hosp.admissions()
    drop_op = qo.Drop(
        ["language", "marital_status", "edregtime", "edouttime"],
    )
    encounters = encounters.ops(drop_op)
    patient_encounters = patients.join(encounters, on="subject_id")
    patient_encounters = patient_encounters.run()
    patient_encounters["age"] = (
        patient_encounters["admittime"].dt.year
        - patient_encounters["anchor_year"]
        + patient_encounters["anchor_age"]
    )
    for col in ["admittime", "dischtime", "deathtime"]:
        patient_encounters[col] = add_years_approximate(
            patient_encounters[col],
            patient_encounters["anchor_year_difference"],
        )

    return patient_encounters[
        [
            "hadm_id",
            "admittime",
            "dischtime",
            "deathtime",
            "anchor_age",
            "age",
            "gender",
            "anchor_year_difference",
            "admission_location",
            "admission_type",
            "insurance",
            "hospital_expire_flag",
        ]
    ]


def compute_mortality_outcome(patient_encounters):
    """Compute mortality outcome."""
    # Drop encounters ending in death which don't have a death timestamp
    invalid = (patient_encounters["hospital_expire_flag"] == 1) & (
        patient_encounters["deathtime"].isna()
    )
    patient_encounters = patient_encounters[~invalid]
    print(f"Encounters with death flag but no death timestamp: {invalid.sum()}")
    # Drop encounters which are shorter than M days
    invalid = (
        patient_encounters["dischtime"] - patient_encounters["admittime"]
    ).dt.days < M
    patient_encounters = patient_encounters[~invalid]
    print(f"Encounters shorter than {M} days: {invalid.sum()}")
    # Death timestamp is within (<=) N days of admission
    valid = (
        patient_encounters["deathtime"] - patient_encounters["admittime"]
    ).dt.days <= N
    print(f"Encounters with death timestamp within {N} days: {valid.sum()}")
    # Number of encounters remaining after filtering
    print(len(patient_encounters))
    patient_encounters["mortality_outcome"] = pd.Series(
        [0] * len(patient_encounters),
        index=patient_encounters.index,
        dtype="int64[pyarrow]",
    )
    patient_encounters.loc[valid, "mortality_outcome"] = 1
    print(
        f"Encounters with mortality outcome for the model: {patient_encounters['mortality_outcome'].sum()}",
    )

    return patient_encounters


def get_labevents(patient_encounters):
    """Get labevents data."""
    labevents = querier.labevents().run(index_col="hadm_id", batch_mode=True)

    def process_labevents(labevents, patient_encounters):
        """Process labevents before aggregation."""
        # Reverse deidentified dating
        labevents = pd.merge(
            patient_encounters[
                [
                    "hadm_id",
                    "anchor_year_difference",
                ]
            ],
            labevents,
            on="hadm_id",
        )
        labevents["charttime"] = add_years_approximate(
            labevents["charttime"],
            labevents["anchor_year_difference"],
        )
        labevents = labevents.drop("anchor_year_difference", axis=1)
        # Pre-processing
        labevents["label"] = normalize_names(labevents["label"])
        labevents["category"] = normalize_names(labevents["category"])

        return labevents

    start_timestamps = (
        patient_encounters[["hadm_id", "admittime"]]
        .set_index("hadm_id")
        .rename({"admittime": RESTRICT_TIMESTAMP}, axis=1)
    )
    mean_aggregator = Aggregator(
        aggfuncs={
            "valuenum": "mean",
        },
        window_duration=M * 24,
        window_start_time=start_timestamps,
        timestamp_col="charttime",
        time_by="hadm_id",
        agg_by=["hadm_id", "label"],
    )
    means_df = pd.DataFrame()
    for batch_num, labevents_batch in enumerate(labevents):
        labevents_batch = process_labevents(  # noqa: PLW2901
            labevents_batch,
            patient_encounters,
        )
        means = mean_aggregator.fit_transform(
            labevents_batch,
        )
        means = means.reset_index()
        means = means.pivot(index="hadm_id", columns="label", values="valuenum")
        means = means.add_prefix("lab_")
        means = pd.merge(
            patient_encounters[
                [
                    "hadm_id",
                    "mortality_outcome",
                    "age",
                    "gender",
                    "admission_location",
                ]
            ],
            means,
            on="hadm_id",
        )
        means_df = pd.concat([means_df, means])
        # Stop after the third batch (batch_num == 2) to keep the example fast
        if batch_num == 2:
            break
        print("Processing batch {}".format(batch_num + 1))

    return means_df


def run_query():
    """Run query."""
    cohort = get_encounters()
    cohort = compute_mortality_outcome(cohort)

    return get_labevents(cohort)


cohort = run_query()
2024-07-16 17:03:13,199 INFO cycquery.orm    - Database setup, ready to run queries!
2024-07-16 17:03:33,160 INFO cycquery.orm    - Query returned successfully!
2024-07-16 17:03:33,163 INFO cycquery.utils.profile - Finished executing function run_query in 17.703220 s
Encounters with death flag but no death timestamp: 13
Encounters shorter than 1 days: 105518
Encounters with death timestamp within 14 days: 5925
348793
Encounters with mortality outcome for the model: 5925
2024-07-16 17:05:57,315 INFO cycquery.orm    - Query returned successfully!
2024-07-16 17:05:57,317 INFO cycquery.utils.profile - Finished executing function run_query in 141.078313 s
Processing batch 1
Processing batch 2
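The aggregation in get_labevents keeps, for each admission, the lab measurements charted within the first M * 24 hours after admittime, averages them per lab test, and pivots the result into one lab_<label> column per test. The plain-Python sketch below mirrors that logic; the helper name is hypothetical, and the half-open window [admittime, admittime + M * 24 h) is an assumption about how Aggregator treats the window boundaries.

```python
from collections import defaultdict
from datetime import datetime, timedelta
from statistics import mean

M = 1  # window length in days, as in the notebook


def mean_labs_in_window(labevents, admit_times, window_hours=M * 24):
    """Hypothetical helper: average each lab per admission over the first window.

    labevents: iterable of (hadm_id, label, charttime, valuenum) tuples.
    admit_times: dict mapping hadm_id -> admittime (window start).
    Returns a dict mapping hadm_id -> {"lab_<label>": mean value}.
    """
    values = defaultdict(list)
    for hadm_id, label, charttime, valuenum in labevents:
        start = admit_times.get(hadm_id)
        # Skip admissions outside the cohort (like the inner merge) and missing values
        if start is None or valuenum is None:
            continue
        # Assumed half-open window [start, start + window_hours)
        if start <= charttime < start + timedelta(hours=window_hours):
            values[(hadm_id, label)].append(valuenum)
    # Pivot to one "lab_<label>" entry per admission, as in means.add_prefix("lab_")
    wide = defaultdict(dict)
    for (hadm_id, label), vals in values.items():
        wide[hadm_id][f"lab_{label}"] = mean(vals)
    return dict(wide)


admit = datetime(2150, 1, 1)
events = [
    (1, "sodium", admit + timedelta(hours=1), 140.0),
    (1, "sodium", admit + timedelta(hours=5), 138.0),
    (1, "sodium", admit + timedelta(hours=30), 150.0),  # outside the 24 h window
    (1, "glucose", admit + timedelta(hours=2), 100.0),
    (2, "sodium", admit + timedelta(hours=1), 135.0),  # admission not in cohort
]
print(mean_labs_in_window(events, {1: admit}))
# {1: {'lab_sodium': 139.0, 'lab_glucose': 100.0}}
```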

Data Inspection and Preprocessing#

Drop NaNs based on the NAN_THRESHOLD#

[5]:
null_counts = cohort.isnull().sum()
null_counts = null_counts[null_counts > 0]
fig = go.Figure(data=[go.Bar(x=null_counts.index, y=null_counts.values)])

fig.update_layout(
    title="Number of Null Values per Column",
    xaxis_title="Columns",
    yaxis_title="Number of Null Values",
    height=600,
)

fig.show()

Add the figure to the report

The log_plotly_figure method adds a figure to a given section of the report. The interactive parameter (default True) controls whether the figure is embedded as an interactive plot or as a static image. Interactive figures make the final report file larger than static ones.

[6]:
report.log_plotly_figure(
    fig=fig,
    caption="Number of Null Values per Column",
    section_name="datasets",
    interactive=True,
)
[7]:
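# Keep only columns with at least NAN_THRESHOLD * len(cohort) non-null values,
# i.e. drop columns that are roughly more than 75% missing
thresh_nan = int(NAN_THRESHOLD * len(cohort))
cohort = cohort.dropna(axis=1, thresh=thresh_nan)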
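pandas' dropna(axis=1, thresh=...) keeps a column only if it has at least thresh non-missing values. The stdlib sketch below reproduces that rule with a hypothetical columns_to_keep helper. With 8 rows, thresh = int(0.25 * 8) = 2, so a column with 2 observed values survives and one with a single observed value is dropped.

```python
import math


def is_missing(value):
    """Treat None and float NaN as missing, as pandas does for these values."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def columns_to_keep(columns, nan_threshold=0.25):
    """Hypothetical helper mirroring dropna(axis=1, thresh=int(nan_threshold * n_rows)).

    columns: dict mapping column name -> list of values (all the same length).
    """
    n_rows = len(next(iter(columns.values())))
    thresh = int(nan_threshold * n_rows)  # same formula as thresh_nan
    return [
        name
        for name, values in columns.items()
        if sum(not is_missing(v) for v in values) >= thresh
    ]


columns = {
    "age": [60.0] * 8,                        # 8 observed values
    "lab_a": [1.0, 2.0] + [None] * 6,         # 2 observed values
    "lab_b": [float("nan")] * 7 + [3.0],      # 1 observed value
}
print(columns_to_keep(columns))
# ['age', 'lab_a']
```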

Outcome distribution#

[8]:
fig = px.pie(cohort, names="mortality_outcome")
fig.update_traces(
    textinfo="percent+label",
    hovertemplate="Outcome: %{label}<br>Count: %{value}<br>Percent: %{percent}",
)
fig.update_layout(title_text="Outcome Distribution")
fig.show()

Add the figure to the report

[9]:
report.log_plotly_figure(
    fig=fig,
    caption="Outcome Distribution",
    section_name="datasets",
)
[10]:
# The data is heavily imbalanced: compute the ratio of negative to positive encounters
class_counts = cohort["mortality_outcome"].value_counts()
class_ratio = class_counts[0] / class_counts[1]
print(class_ratio, class_counts)
55.36758893280632 mortality_outcome
0    14008
1      253
Name: count, dtype: int64[pyarrow]

Gender distribution#

[11]:
fig = px.pie(cohort, names="gender")
fig.update_layout(
    title="Gender Distribution",
)
fig.show()

Add the figure to the report

[12]:
report.log_plotly_figure(
    fig=fig,
    caption="Gender Distribution",
    section_name="datasets",
)

Age distribution#

[13]:
fig = px.histogram(cohort, x="age")
fig.update_layout(
    title="Age Distribution",
    xaxis_title="Age",
    yaxis_title="Count",
    bargap=0.2,
)
fig.show()

Add the figure to the report

[14]:
report.log_plotly_figure(
    fig=fig,
    caption="Age Distribution",
    section_name="datasets",
)

Identifying feature types#

The CyclOps TabularFeatures class identifies the type of each feature (for example numeric, binary, or ordinal), an essential step before preprocessing. Knowing the feature types lets us apply the appropriate preprocessing steps to each type.

[15]:
features_list = set(cohort.columns.tolist()) - {"hadm_id", "mortality_outcome"}
features_list = sorted(features_list)
tab_features = TabularFeatures(
    data=cohort.reset_index(),
    features=features_list,
    by="hadm_id",
    targets="mortality_outcome",
)
print(tab_features.types)
{'lab_i': 'numeric', 'lab_bilirubin, total': 'numeric', 'lab_asparate aminotransferase (ast)': 'numeric', 'lab_red blood cells': 'numeric', 'mortality_outcome': 'binary', 'lab_h': 'numeric', 'gender': 'binary', 'lab_rdw': 'numeric', 'lab_glucose': 'numeric', 'admission_location': 'ordinal', 'lab_platelet count': 'numeric', 'lab_rdw-sd': 'numeric', 'lab_sodium': 'numeric', 'lab_mchc': 'numeric', 'lab_phosphate': 'numeric', 'lab_pt': 'numeric', 'lab_mch': 'numeric', 'lab_urea nitrogen': 'numeric', 'age': 'numeric', 'lab_hematocrit': 'numeric', 'lab_hemoglobin': 'numeric', 'lab_ph': 'numeric', 'lab_alkaline phosphatase': 'numeric', 'lab_anion gap': 'numeric', 'lab_potassium': 'numeric', 'lab_magnesium': 'numeric', 'lab_bicarbonate': 'numeric', 'lab_creatinine': 'numeric', 'lab_inr(pt)': 'numeric', 'lab_chloride': 'numeric', 'lab_calcium, total': 'numeric', 'lab_alanine aminotransferase (alt)': 'numeric', 'lab_ptt': 'numeric', 'lab_mcv': 'numeric', 'lab_l': 'numeric', 'lab_white blood cells': 'numeric'}

Creating data preprocessors#

We create a data preprocessor using sklearn's ColumnTransformer, which applies different preprocessing steps to different columns of the dataframe. Here, numeric features are mean-imputed and min-max scaled, while binary and ordinal features are one-hot encoded.

[16]:
numeric_transformer = Pipeline(
    steps=[("imputer", SimpleImputer(strategy="mean")), ("scaler", MinMaxScaler())],
)
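# Note: binary_transformer is not used by the ColumnTransformer below, which
# one-hot encodes the binary and ordinal features directly
binary_transformer = Pipeline(
    steps=[("imputer", SimpleImputer(strategy="most_frequent"))],
)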
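The numeric pipeline imputes missing values with the column mean and then rescales the column to [0, 1] with MinMaxScaler's default feature range. A plain-Python sketch of those two steps for a single column follows; the helper name is hypothetical. In the real pipeline the mean, minimum, and maximum are learned on the training split and reused on the test split.

```python
def impute_and_scale(column):
    """Hypothetical helper: mean-impute None values, then min-max scale to [0, 1]."""
    observed = [v for v in column if v is not None]
    fill = sum(observed) / len(observed)           # like SimpleImputer(strategy="mean")
    filled = [fill if v is None else v for v in column]
    lo, hi = min(filled), max(filled)
    span = hi - lo
    # A constant column maps to 0.0, matching MinMaxScaler's handling of zero range
    return [0.0 if span == 0 else (v - lo) / span for v in filled]


print(impute_and_scale([2.0, None, 6.0]))
# [0.0, 0.5, 1.0]
```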
[17]:
numeric_features = sorted(tab_features.features_by_type("numeric"))
numeric_indices = [
    cohort[features_list].columns.get_loc(column) for column in numeric_features
]
print(numeric_features)
['age', 'lab_alanine aminotransferase (alt)', 'lab_alkaline phosphatase', 'lab_anion gap', 'lab_asparate aminotransferase (ast)', 'lab_bicarbonate', 'lab_bilirubin, total', 'lab_calcium, total', 'lab_chloride', 'lab_creatinine', 'lab_glucose', 'lab_h', 'lab_hematocrit', 'lab_hemoglobin', 'lab_i', 'lab_inr(pt)', 'lab_l', 'lab_magnesium', 'lab_mch', 'lab_mchc', 'lab_mcv', 'lab_ph', 'lab_phosphate', 'lab_platelet count', 'lab_potassium', 'lab_pt', 'lab_ptt', 'lab_rdw', 'lab_rdw-sd', 'lab_red blood cells', 'lab_sodium', 'lab_urea nitrogen', 'lab_white blood cells']
[18]:
binary_features = sorted(tab_features.features_by_type("binary"))
ordinal_features = sorted(tab_features.features_by_type("ordinal"))
binary_features.remove("mortality_outcome")
binary_indices = [
    cohort[features_list].columns.get_loc(column) for column in binary_features
]
ordinal_indices = [
    cohort[features_list].columns.get_loc(column) for column in ordinal_features
]
print(binary_features, ordinal_features)
['gender'] ['admission_location']
[19]:
preprocessor = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, numeric_indices),
        (
            "onehot",
            OneHotEncoder(handle_unknown="ignore"),
            binary_indices + ordinal_indices,
        ),
    ],
    remainder="passthrough",
)
preprocessor_pipeline = [
    ("preprocessor", preprocessor),
    ("oversampling", SMOTE(random_state=RANDOM_SEED)),
]
preprocessor_pipeline = ImbPipeline(preprocessor_pipeline)
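SMOTE, the oversampling step in the pipeline above, creates synthetic minority-class samples by picking a minority sample, choosing one of its k nearest minority-class neighbours, and placing a new point at a random position on the segment between them. The stdlib sketch below illustrates that idea on small 2-D points; it is not imblearn's implementation, and k=5 mirrors imblearn's default k_neighbors.

```python
import math
import random


def smote_sketch(minority, n_synthetic, k=5, seed=12):
    """Generate n_synthetic points by SMOTE-style interpolation.

    minority: list of equal-length numeric tuples (minority-class samples).
    """
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_synthetic):
        i = rng.randrange(len(minority))
        x = minority[i]
        # k nearest minority-class neighbours of x, excluding x itself
        others = [p for j, p in enumerate(minority) if j != i]
        neighbours = sorted(others, key=lambda p: math.dist(x, p))[:k]
        nn = rng.choice(neighbours)
        # New point at a random position on the segment from x to nn
        gap = rng.random()
        synthetic.append(tuple(xi + gap * (ni - xi) for xi, ni in zip(x, nn)))
    return synthetic


minority = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
points = smote_sketch(minority, n_synthetic=10, k=2)
print(len(points))
# 10
```

With imblearn's default sampling_strategy="auto", as used here, the minority class is oversampled until it matches the majority class count; given the roughly 55:1 class ratio computed earlier, that means roughly 54 synthetic samples per real positive encounter in the training split.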

Creating Hugging Face Dataset#

We convert the cohort dataframe into a Hugging Face dataset, a format that is compatible with CyclOps models and evaluation modules. The dataset is then split into train and test sets, stratified by the outcome so that both splits keep the same class balance.
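With stratify_by_column="mortality_outcome", each class is split separately so the rare positive class is represented in both splits in the same proportion. The stdlib sketch below shows the idea with a hypothetical stratified_split helper; it is not the Hugging Face implementation, and its rounding may differ from the library's by one example per class.

```python
import random
from collections import defaultdict


def stratified_split(labels, train_size=0.8, seed=12):
    """Hypothetical helper: return (train_indices, test_indices) split per class."""
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    rng = random.Random(seed)
    train, test = [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)
        n_train = round(train_size * len(idxs))  # same fraction within every class
        train.extend(idxs[:n_train])
        test.extend(idxs[n_train:])
    return sorted(train), sorted(test)


labels = [0] * 95 + [1] * 5  # 5% positives
train_idx, test_idx = stratified_split(labels)
print(sum(labels[i] for i in train_idx), sum(labels[i] for i in test_idx))
# 4 1
```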

[20]:
cohort = cohort.drop(columns=["hadm_id"])
dataset = Dataset.from_pandas(cohort)
dataset.cleanup_cache_files()
[20]:
0
[21]:
dataset = dataset.cast_column("mortality_outcome", ClassLabel(num_classes=2))
dataset = dataset.train_test_split(
    train_size=TRAIN_SIZE,
    stratify_by_column="mortality_outcome",
    seed=RANDOM_SEED,
)

Model Creation#

The CyclOps model registry allows for straightforward creation and selection of models. The registry maintains a list of pre-configured models, each of which can be instantiated with a single line of code. Here we use an XGBoost classifier: an ensemble of gradient-boosted decision trees trained with a logistic objective (binary:logistic, as shown in the model parameters below), not a logistic regression model. Model configuration is passed to create_model as keyword arguments accepted by XGBClassifier.
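Under the binary:logistic objective, XGBoost sums the leaf values of all trees (plus a base margin) into a log-odds score and passes it through the sigmoid to obtain a probability, which is then thresholded at 0.5 to produce a class label. The sketch below shows that final step with hypothetical tree outputs; the base margin of 0 is an assumption for illustration.

```python
import math


def boosted_probability(tree_outputs, base_margin=0.0):
    """Sum per-tree outputs into a log-odds margin, then apply the sigmoid."""
    margin = base_margin + sum(tree_outputs)
    return 1.0 / (1.0 + math.exp(-margin))


def predict_label(probability, threshold=0.5):
    """Convert a probability into a 0/1 class label."""
    return int(probability > threshold)


p = boosted_probability([0.3, -0.1, 0.8])  # hypothetical outputs of three trees
print(round(p, 4), predict_label(p))
# 0.7311 1
```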

[22]:
model_name = "xgb_classifier"
model = create_model(model_name, random_state=123)

Task Creation#

We use CyclOps tasks to define the model's task (in this case, BinaryTabularClassificationTask), train the model, make predictions, and evaluate performance. CyclOps task classes encapsulate the entire ML pipeline in a single, cohesive structure, which keeps the process easy to manage.

[23]:
mortality_task = BinaryTabularClassificationTask(
    {model_name: model},
    task_features=features_list,
    task_target="mortality_outcome",
)
mortality_task.list_models()
[23]:
['xgb_classifier']

Training#

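If best_model_params is passed to the train method, a hyperparameter search is run and the best-performing model is selected. Each key in best_model_params maps a hyperparameter to its candidate values, which together define the parameter grid; the "method" entry sets the search strategy (here, a random search over the grid).

Note that the data preprocessor must be passed to the task's methods if the Hugging Face dataset is not already preprocessed. Training uses preprocessor_pipeline, which also applies SMOTE oversampling, whereas prediction and evaluation below use preprocessor alone, so the test set is never resampled.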
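The grid below contains 3 × 2 × 2 × 3 × 3 × 4 = 432 combinations. A random search evaluates only a sample of them. The sketch below draws such a sample without replacement; the number of sampled configurations (n_iter=10) and the helper name are hypothetical, since the notebook does not state how many configurations CyclOps evaluates.

```python
import itertools
import random

# Candidate values copied from best_model_params (the "method" key is omitted)
GRID = {
    "n_estimators": [100, 250, 500],
    "learning_rate": [0.1, 0.01],
    "max_depth": [2, 5],
    "reg_lambda": [0, 1, 10],
    "colsample_bytree": [0.7, 0.8, 1],
    "gamma": [0, 1, 2, 10],
}


def sample_configurations(grid, n_iter=10, seed=12):
    """Hypothetical helper: draw n_iter distinct configurations from the full grid."""
    keys = sorted(grid)
    combos = list(itertools.product(*(grid[k] for k in keys)))
    rng = random.Random(seed)
    picked = rng.sample(combos, k=min(n_iter, len(combos)))
    return [dict(zip(keys, combo)) for combo in picked]


print(len(list(itertools.product(*GRID.values()))))
# 432
configs = sample_configurations(GRID)
```

Each sampled configuration would then be trained and scored, and the best one kept; exhaustively trying all 432 combinations is what a full grid search would do instead.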

[24]:
best_model_params = {
    "n_estimators": [100, 250, 500],
    "learning_rate": [0.1, 0.01],
    "max_depth": [2, 5],
    "reg_lambda": [0, 1, 10],
    "colsample_bytree": [0.7, 0.8, 1],
    "gamma": [0, 1, 2, 10],
    "method": "random",
}
mortality_task.train(
    dataset["train"],
    model_name=model_name,
    transforms=preprocessor_pipeline,
    best_model_params=best_model_params,
)
2024-07-16 17:15:47,880 INFO cyclops.models.wrappers.sk_model - Best reg_lambda: 0
2024-07-16 17:15:47,882 INFO cyclops.models.wrappers.sk_model - Best n_estimators: 500
2024-07-16 17:15:47,883 INFO cyclops.models.wrappers.sk_model - Best max_depth: 5
2024-07-16 17:15:47,885 INFO cyclops.models.wrappers.sk_model - Best learning_rate: 0.1
2024-07-16 17:15:47,886 INFO cyclops.models.wrappers.sk_model - Best gamma: 1
2024-07-16 17:15:47,887 INFO cyclops.models.wrappers.sk_model - Best colsample_bytree: 0.7
[24]:
XGBClassifier(base_score=None, booster=None, callbacks=None,
              colsample_bylevel=None, colsample_bynode=None,
              colsample_bytree=0.7, early_stopping_rounds=None,
              enable_categorical=False, eval_metric='logloss',
              feature_types=None, gamma=1, gpu_id=None, grow_policy=None,
              importance_type=None, interaction_constraints=None,
              learning_rate=0.1, max_bin=None, max_cat_threshold=None,
              max_cat_to_onehot=None, max_delta_step=None, max_depth=5,
              max_leaves=None, min_child_weight=3, missing=nan,
              monotone_constraints=None, n_estimators=500, n_jobs=None,
              num_parallel_tree=None, predictor=None, random_state=123, ...)
[25]:
model_params = mortality_task.list_models_params()[model_name]
print(model_params)
{'objective': 'binary:logistic', 'use_label_encoder': None, 'base_score': None, 'booster': None, 'callbacks': None, 'colsample_bylevel': None, 'colsample_bynode': None, 'colsample_bytree': 0.7, 'early_stopping_rounds': None, 'enable_categorical': False, 'eval_metric': 'logloss', 'feature_types': None, 'gamma': 1, 'gpu_id': None, 'grow_policy': None, 'importance_type': None, 'interaction_constraints': None, 'learning_rate': 0.1, 'max_bin': None, 'max_cat_threshold': None, 'max_cat_to_onehot': None, 'max_delta_step': None, 'max_depth': 5, 'max_leaves': None, 'min_child_weight': 3, 'missing': nan, 'monotone_constraints': None, 'n_estimators': 500, 'n_jobs': None, 'num_parallel_tree': None, 'predictor': None, 'random_state': 123, 'reg_alpha': None, 'reg_lambda': 0, 'sampling_method': None, 'scale_pos_weight': None, 'subsample': None, 'tree_method': None, 'validate_parameters': None, 'verbosity': None, 'seed': 123}

Log the model parameters to the report.

We can add model parameters to the model card using the log_model_parameters method.

[26]:
report.log_model_parameters(params=model_params)

Prediction#

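The predict method can return either the whole Hugging Face dataset with the prediction columns added to it, or only the predicted values. Here, only_predictions=True returns just the predictions, and proba=False returns class labels rather than probabilities.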

[27]:
y_pred = mortality_task.predict(
    dataset["test"],
    model_name=model_name,
    transforms=preprocessor,
    proba=False,
    only_predictions=True,
)
print(len(y_pred))
2853

Evaluation#

Evaluation uses several metrics that give different perspectives on the model's predictive ability: standard performance metrics and fairness metrics.

The standard performance metrics are created with create_metric and grouped in a MetricDict object.

[28]:
metric_names = [
    "binary_accuracy",
    "binary_precision",
    "binary_recall",
    "binary_f1_score",
    "binary_auroc",
    "binary_average_precision",
    "binary_roc_curve",
    "binary_precision_recall_curve",
]
metrics = [
    create_metric(metric_name, experimental=True) for metric_name in metric_names
]
metric_collection = MetricDict(metrics)
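Because only about 1 in 56 encounters is positive, accuracy alone can be misleading. The stdlib sketch below computes accuracy, precision, recall, and F1 from confusion-matrix counts; a model that always predicts "survives" reaches 95% accuracy on a 5%-positive sample while having zero recall. Returning 0.0 when a denominator is zero is an assumption of this sketch; CyclOps may handle that case differently.

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 for 0/1 labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == 1 and p == 1 for t, p in pairs)
    tn = sum(t == 0 and p == 0 for t, p in pairs)
    fp = sum(t == 0 and p == 1 for t, p in pairs)
    fn = sum(t == 1 and p == 0 for t, p in pairs)
    accuracy = (tp + tn) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


always_negative = binary_metrics([0] * 95 + [1] * 5, [0] * 100)
print(always_negative)
# {'accuracy': 0.95, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
```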

In addition to overall metrics, it might be interesting to see how the model performs on certain subpopulations. We can define these subpopulations using SliceSpec objects.

[29]:
spec_list = [
    {
        "age": {
            "min_value": 20,
            "max_value": 50,
            "min_inclusive": True,
            "max_inclusive": False,
        },
    },
    {
        "age": {
            "min_value": 50,
            "max_value": 80,
            "min_inclusive": True,
            "max_inclusive": False,
        },
    },
    {"gender": {"value": "M"}},
    {"gender": {"value": "F"}},
]
slice_spec = SliceSpec(spec_list)
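With min_inclusive=True and max_inclusive=False, each age slice is the half-open interval [min_value, max_value), so a 50-year-old falls in the second slice, not the first. The sketch below, using a hypothetical in_slice helper, spells out that membership rule; it illustrates the configuration rather than SliceSpec's own code.

```python
def in_slice(value, min_value, max_value, min_inclusive=True, max_inclusive=False):
    """Hypothetical helper: True if value lies in the configured range."""
    above = value >= min_value if min_inclusive else value > min_value
    below = value <= max_value if max_inclusive else value < max_value
    return above and below


ages = [18, 20, 35, 50, 79, 80]
print([a for a in ages if in_slice(a, 20, 50)])  # [20, 35]
print([a for a in ages if in_slice(a, 50, 80)])  # [50, 79]
```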

A MetricDict can also be defined for the fairness metrics.

[30]:
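specificity = create_metric(metric_name="binary_specificity", experimental=True)
sensitivity = create_metric(metric_name="binary_sensitivity", experimental=True)
# FPR = 1 - specificity and FNR = 1 - sensitivity, written as -x + 1
# because __rsub__ is not implemented for metrics
fpr = -specificity + 1
fnr = -sensitivity + 1
# Balanced error rate: the mean of the false positive and false negative rates
ber = (fpr + fnr) / 2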
fairness_metric_collection = MetricDict(
    {
        "Sensitivity": sensitivity,
        "Specificity": specificity,
        "BER": ber,
    },
)
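The balanced error rate combines the two error rates defined above, BER = ((1 - specificity) + (1 - sensitivity)) / 2, which also equals one minus the balanced accuracy. The sketch below computes it from hypothetical confusion-matrix counts: with 8 true positives, 2 false negatives, 80 true negatives and 10 false positives, sensitivity is 0.8, specificity is 80/90, and BER is (10/90 + 0.2) / 2 ≈ 0.156.

```python
def balanced_error_rate(tp, fn, tn, fp):
    """BER from confusion-matrix counts."""
    sensitivity = tp / (tp + fn)  # true positive rate
    specificity = tn / (tn + fp)  # true negative rate
    fpr = 1 - specificity         # false positive rate
    fnr = 1 - sensitivity         # false negative rate
    return (fpr + fnr) / 2


print(round(balanced_error_rate(tp=8, fn=2, tn=80, fp=10), 4))
# 0.1556
```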

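The FairnessConfig specifies how the fairness of the model predictions is evaluated: which metrics to compute, which groups to compare, and which prediction thresholds to use. With group_bins={"age": [20, 40]}, age is split into the intervals (-inf, 20], (20, 40] and (40, inf], which are combined with gender to form the evaluated groups; group_base_values gives the reference value for each group attribute (age 40, gender M).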

[31]:
fairness_config = FairnessConfig(
    metrics=fairness_metric_collection,
    dataset=None,  # dataset is passed from the evaluator
    target_columns=None,  # target columns are passed from the evaluator
    groups=["gender", "age"],
    group_bins={"age": [20, 40]},
    group_base_values={"age": 40, "gender": "M"},
    thresholds=[0.5],
)
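The age bin edges [20, 40] define three right-closed intervals, so an age of exactly 20 falls in the first bin and an age of exactly 40 in the second. The sketch below assigns values to such bins with the standard bisect module; the helper names are hypothetical, and right-closed intervals are an assumption about how CyclOps bins continuous group attributes.

```python
import math
from bisect import bisect_left


def bin_labels(edges):
    """Hypothetical helper: labels for right-closed bins (-inf, e0], (e0, e1], ..., (en, inf]."""
    bounds = [-math.inf, *(float(e) for e in edges), math.inf]
    return [f"({lo} - {hi}]" for lo, hi in zip(bounds, bounds[1:])]


def assign_bin(value, edges):
    """bisect_left places a value equal to an edge in the bin that edge closes."""
    return bin_labels(edges)[bisect_left(edges, value)]


edges = [20, 40]
print([assign_bin(age, edges) for age in (20, 35, 40, 65)])
# ['(-inf - 20.0]', '(20.0 - 40.0]', '(20.0 - 40.0]', '(40.0 - inf]']
```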

The evaluate method returns both the evaluation results and the Hugging Face dataset with the predictions added to it.

[32]:
results, dataset_with_preds = mortality_task.evaluate(
    dataset["test"],
    metric_collection,
    model_names=model_name,
    transforms=preprocessor,
    prediction_column_prefix="preds",
    slice_spec=slice_spec,
    batch_size=-1,
    fairness_config=fairness_config,
    override_fairness_metrics=False,
)

Filter -> gender:M&age:(40.0 - inf]: 0%| | 0/2853 [00:00<?, ? examples/s]

end{sphinxVerbatim}

Filter -> gender:M&age:(40.0 - inf]: 0%| | 0/2853 [00:00<?, ? examples/s]

Filter -&gt; gender:M&amp;age:(40.0 - inf]: 100%|██████████| 2853/2853 [00:00&lt;00:00, 77750.52 examples/s]

</pre>

Filter -> gender:M&age:(40.0 - inf]: 100%|██████████| 2853/2853 [00:00<00:00, 77750.52 examples/s]

end{sphinxVerbatim}

Filter -> gender:M&age:(40.0 - inf]: 100%|██████████| 2853/2853 [00:00<00:00, 77750.52 examples/s]


Filter -&gt; gender:F&amp;age:(-inf - 20.0]: 0%| | 0/2853 [00:00&lt;?, ? examples/s]

</pre>

Filter -> gender:F&age:(-inf - 20.0]: 0%| | 0/2853 [00:00<?, ? examples/s]

end{sphinxVerbatim}

Filter -> gender:F&age:(-inf - 20.0]: 0%| | 0/2853 [00:00<?, ? examples/s]

Filter -&gt; gender:F&amp;age:(-inf - 20.0]: 100%|██████████| 2853/2853 [00:00&lt;00:00, 80278.20 examples/s]

</pre>

Filter -> gender:F&age:(-inf - 20.0]: 100%|██████████| 2853/2853 [00:00<00:00, 80278.20 examples/s]

end{sphinxVerbatim}

Filter -> gender:F&age:(-inf - 20.0]: 100%|██████████| 2853/2853 [00:00<00:00, 80278.20 examples/s]


Filter -&gt; gender:F&amp;age:(20.0 - 40.0]: 0%| | 0/2853 [00:00&lt;?, ? examples/s]

</pre>

Filter -> gender:F&age:(20.0 - 40.0]: 0%| | 0/2853 [00:00<?, ? examples/s]

end{sphinxVerbatim}

Filter -> gender:F&age:(20.0 - 40.0]: 0%| | 0/2853 [00:00<?, ? examples/s]

Filter -&gt; gender:F&amp;age:(20.0 - 40.0]: 100%|██████████| 2853/2853 [00:00&lt;00:00, 80818.22 examples/s]

</pre>

Filter -> gender:F&age:(20.0 - 40.0]: 100%|██████████| 2853/2853 [00:00<00:00, 80818.22 examples/s]

end{sphinxVerbatim}

Filter -> gender:F&age:(20.0 - 40.0]: 100%|██████████| 2853/2853 [00:00<00:00, 80818.22 examples/s]


Filter -&gt; gender:F&amp;age:(40.0 - inf]: 0%| | 0/2853 [00:00&lt;?, ? examples/s]

</pre>

Filter -> gender:F&age:(40.0 - inf]: 0%| | 0/2853 [00:00<?, ? examples/s]

end{sphinxVerbatim}

Filter -> gender:F&age:(40.0 - inf]: 0%| | 0/2853 [00:00<?, ? examples/s]

Filter -&gt; gender:F&amp;age:(40.0 - inf]: 100%|██████████| 2853/2853 [00:00&lt;00:00, 78433.41 examples/s]

</pre>

Filter -> gender:F&age:(40.0 - inf]: 100%|██████████| 2853/2853 [00:00<00:00, 78433.41 examples/s]

end{sphinxVerbatim}

Filter -> gender:F&age:(40.0 - inf]: 100%|██████████| 2853/2853 [00:00<00:00, 78433.41 examples/s]


Log the performance metrics to the report.

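We can add performance metrics to the model card using the log_quantitative_analysis method, which logs one metric value for one slice per call. Each logged metric is also tested against a pass/fail threshold: here a metric passes if its value is at least 0.7.

We first need to process the evaluation results into a flat dictionary whose keys have the format slice_name/metric_name, for instance overall/BinaryAccuracy. The curve-valued metrics (BinaryROC and BinaryPrecisionRecallCurve) are dropped at this step, since they are plotted separately below.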
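CyclOps provides flatten_results_dict for this step. The sketch below is a rough, hypothetical pure-Python stand-in and not the library's implementation. It shows how the slice_name/metric_name keys can be produced while curve-valued metrics are dropped. It assumes results is nested as {model_name: {slice_name: {metric_name: value}}}, which is the layout the plotting cells further down index into. The helper name flatten_slice_metrics is made up for illustration.

```python
from typing import Any, Dict, Iterable, Mapping

# Curve-valued metrics that are plotted rather than logged as single numbers.
CURVE_METRICS = ("BinaryROC", "BinaryPrecisionRecallCurve")


def flatten_slice_metrics(
    results: Mapping[str, Mapping[str, Mapping[str, Any]]],
    model_name: str,
    remove_metrics: Iterable[str] = CURVE_METRICS,
) -> Dict[str, Any]:
    """Flatten {model: {slice: {metric: value}}} into {"slice/metric": value}."""
    skip = set(remove_metrics)
    flat: Dict[str, Any] = {}
    for slice_name, slice_results in results[model_name].items():
        for metric_name, value in slice_results.items():
            if metric_name not in skip:
                flat[f"{slice_name}/{metric_name}"] = value
    return flat


# Toy input with made-up values, only to show the key format.
toy_results = {
    "model_for_preds.toy": {
        "overall": {"BinaryAccuracy": 0.5, "BinaryROC": None},
        "gender:F": {"BinaryAccuracy": 0.25},
    },
}
print(flatten_slice_metrics(toy_results, "model_for_preds.toy"))
# {'overall/BinaryAccuracy': 0.5, 'gender:F/BinaryAccuracy': 0.25}
```

Splitting these keys back with key.split("/") assumes that slice names never contain a "/". That holds for the slices used in this notebook.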

[33]:
model_name = f"model_for_preds.{model_name}"
results_flat = flatten_results_dict(
    results=results,
    remove_metrics=["BinaryROC", "BinaryPrecisionRecallCurve"],
    model_name=model_name,
)
[34]:
# ruff: noqa: W505
# Human-readable descriptions for each scalar metric logged to the report.
descriptions = {
    "BinaryPrecision": "The proportion of predicted positive instances that are correctly predicted.",
    "BinaryRecall": "The proportion of actual positive instances that are correctly predicted. Also known as sensitivity or true positive rate.",
    "BinaryAccuracy": "The proportion of all instances that are correctly predicted.",
    "BinaryAUROC": "The area under the receiver operating characteristic curve (AUROC) is a measure of the performance of a binary classification model.",
    "BinaryAveragePrecision": "The area under the precision-recall curve (AUPRC) is a measure of the performance of a binary classification model.",
    "BinaryF1Score": "The harmonic mean of precision and recall.",
}
for key, metric in results_flat.items():
    # Keys have the form "<slice_name>/<metric_name>", e.g. "overall/BinaryAccuracy".
    slice_name, name = key.split("/")
    report.log_quantitative_analysis(
        "performance",
        name=name,
        value=metric.tolist(),
        description=descriptions[name],
        metric_slice=slice_name,
        # A metric passes its test if its value is at least 0.7.
        pass_fail_thresholds=0.7,
        pass_fail_threshold_fns=lambda x, threshold: bool(x >= threshold),
    )

We can also use the ClassificationPlotter to plot the performance metrics and then add each figure to the model card using the log_plotly_figure method.
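For intuition about what the ROC plots below compare: each curve traces the true positive rate against the false positive rate as the decision threshold is lowered. The AUROC reported for each slice is the area under that curve. The following is a minimal numpy sketch of that computation. The helper names roc_points and trapezoid_auc are hypothetical and not part of CyclOps, whose own implementation may differ in details such as threshold handling. The sketch assumes both classes are present in y_true.

```python
import numpy as np


def roc_points(y_true, y_score):
    """Return (fpr, tpr, thresholds) for binary labels and scores."""
    y_true = np.asarray(y_true, dtype=float)
    y_score = np.asarray(y_score, dtype=float)
    # Sort by descending score; a stable sort keeps ties in input order.
    order = np.argsort(-y_score, kind="mergesort")
    y_true, y_score = y_true[order], y_score[order]
    # Last index of each run of equal scores: one ROC point per distinct threshold.
    distinct = np.flatnonzero(np.diff(y_score))
    idx = np.r_[distinct, y_true.size - 1]
    tps = np.cumsum(y_true)[idx]  # true positives at each threshold
    fps = (idx + 1) - tps  # false positives at each threshold
    # Prepend the (0, 0) point that corresponds to an infinite threshold.
    tpr = np.r_[0.0, tps / tps[-1]]
    fpr = np.r_[0.0, fps / fps[-1]]
    thresholds = np.r_[np.inf, y_score[idx]]
    return fpr, tpr, thresholds


def trapezoid_auc(x, y):
    """Area under a piecewise-linear curve using the trapezoidal rule."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0))


fpr, tpr, _ = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(trapezoid_auc(fpr, tpr))  # 0.75
```

Tied scores are collapsed into a single ROC point. Interpolating linearly between points then gives tied predictions half credit, which is the usual convention for AUROC.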

[35]:
plotter = ClassificationPlotter(task_type="binary", class_names=["0", "1"])
plotter.set_template("plotly_white")
[36]:
# extracting the ROC curves and AUROC results for all the slices
roc_curves = {
    slice_name: slice_results["BinaryROC"]
    for slice_name, slice_results in results[model_name].items()
}
aurocs = {
    slice_name: slice_results["BinaryAUROC"]
    for slice_name, slice_results in results[model_name].items()
}
roc_curves.keys()
[36]:
dict_keys(['age:[20 - 50)', 'age:[50 - 80)', 'gender:M', 'gender:F', 'overall'])
[37]:
# extracting the precision-recall curves and average precision results for all the slices
pr_curves = {
    slice_name: slice_results["BinaryPrecisionRecallCurve"]
    for slice_name, slice_results in results[model_name].items()
}
average_precisions = {
    slice_name: slice_results["BinaryAveragePrecision"]
    for slice_name, slice_results in results[model_name].items()
}
pr_curves.keys()
[37]:
dict_keys(['age:[20 - 50)', 'age:[50 - 80)', 'gender:M', 'gender:F', 'overall'])
[38]:
# plotting the ROC curves for all the slices
roc_plot = plotter.roc_curve_comparison(roc_curves, aurocs=aurocs)
report.log_plotly_figure(
    fig=roc_plot,
    caption="ROC Curve Comparison",
    section_name="quantitative analysis",
)
roc_plot.show()
[39]:
# plotting the precision-recall curves for all the slices
pr_plot = plotter.precision_recall_curve_comparison(
    pr_curves,
    auprcs=average_precisions,
)
report.log_plotly_figure(
    fig=pr_plot,
    caption="Precision-Recall Curve Comparison",
    section_name="quantitative analysis",
)
pr_plot.show()
[40]:
# Extracting the overall classification metric values.
overall_performance = {
    metric_name: metric_value
    for metric_name, metric_value in results[model_name]["overall"].items()
    if metric_name not in ["BinaryROC", "BinaryPrecisionRecallCurve"]
}
[41]:
# Plotting the overall classification metric values.
overall_performance_plot = plotter.metrics_value(
    overall_performance,
    title="Overall Performance",
)
report.log_plotly_figure(
    fig=overall_performance_plot,
    caption="Overall Performance",
    section_name="quantitative analysis",
)
overall_performance_plot.show()
[42]:
# Extracting the metric values for all the slices.
slice_metrics = {
    slice_name: {
        metric_name: metric_value
        for metric_name, metric_value in slice_results.items()
        if metric_name not in ["BinaryROC", "BinaryPrecisionRecallCurve"]
    }
    for slice_name, slice_results in results[model_name].items()
}
[43]:
# Plotting the metric values for all the slices.
slice_metrics_plot = plotter.metrics_comparison_bar(slice_metrics)
report.log_plotly_figure(
    fig=slice_metrics_plot,
    caption="Slice Metric Comparison",
    section_name="quantitative analysis",
)
slice_metrics_plot.show()
[44]:
# Reformatting the fairness metrics
fairness_results = copy.deepcopy(results["fairness"])
fairness_metrics = {}
# remove the group size from the fairness results and add it to the slice name
for slice_name, slice_results in fairness_results.items():
    group_size = slice_results.pop("Group Size")
    fairness_metrics[f"{slice_name} (Size={group_size})"] = slice_results
[45]:
# Plotting the fairness metrics
fairness_plot = plotter.metrics_comparison_scatter(
    fairness_metrics,
    title="Fairness Metrics",
)
report.log_plotly_figure(
    fig=fairness_plot,
    caption="Fairness Metrics",
    section_name="fairness analysis",
)
fairness_plot.show()

Report Generation#

Before generating the model card, let us document some of the details of the model and some considerations involved in developing and using the model.

Let’s start with populating the model details section, which includes the following fields by default:

- description: A high-level description of the model and its usage for a general audience.
- version: The version of the model.
- owners: The individuals or organizations that own the model.
- license: The license under which the model is made available.
- citation: The citation for the model.
- references: Links to resources that are relevant to the model.
- path: The path to where the model is stored.
- regulatory_requirements: The regulatory requirements that are relevant to the model.

We can add additional fields to the model details section by passing a dictionary to the log_from_dict method and specifying the section name as model_details. You can also use the log_descriptor method to add a new field object with a description attribute to any section of the model card.

[46]:
report.log_from_dict(
    data={
        "name": "Mortality Prediction Model",
        "description": (
            "The model was trained on the MIMICIV dataset "
            "to predict risk of in-hospital mortality."
        ),
    },
    section_name="model_details",
)
report.log_version(
    version_str="0.0.1",
    date=str(date.today()),
    description="Initial Release",
)
report.log_owner(
    name="CyclOps Team",
    contact="vectorinstitute.github.io/cyclops/",
    email="cyclops@vectorinstitute.ai",
)
report.log_license(identifier="Apache-2.0")
report.log_reference(
    link="https://xgboost.readthedocs.io/en/stable/python/python_api.html",  # noqa: E501
)

Next, let’s populate the considerations section, which includes the following fields by default:

- users: The intended users of the model.
- use_cases: The use cases for the model. These could be primary, downstream or out-of-scope use cases.
- fairness_assessment: A description of the benefits and harms of the model for different groups as well as the steps taken to mitigate the harms.
- ethical_considerations: The risks associated with using the model and the steps taken to mitigate them. This can be populated using the log_risk method.

[47]:
report.log_from_dict(
    data={
        "users": [
            {"description": "Hospitals"},
            {"description": "Clinicians"},
        ],
    },
    section_name="considerations",
)
report.log_user(description="ML Engineers")
report.log_use_case(
    description="Predicting risk of in-hospital mortality",
    kind="primary",
)
report.log_fairness_assessment(
    affected_group="sex, age",
    benefit="Improved health outcomes for patients.",
    harm=(
        "Biased predictions for patients in certain groups (e.g. older patients) "
        "may lead to worse health outcomes."
    ),
    mitigation_strategy=(
        "We will monitor the performance of the model on these groups "
        "and retrain the model if the performance drops below a certain threshold."
    ),
)
report.log_risk(
    risk="The model may be used to make decisions that affect the health of patients.",
    mitigation_strategy=(
        "The model should be continuously monitored for performance "
        "and retrained if the performance drops below a certain threshold."
    ),
)

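Once the model card is populated, you can generate the report using the export method. The report is generated as an HTML file, along with a JSON file containing the model card data. By default, the files are saved in a folder named cyclops_report in the current working directory. You can change the location by passing an output_dir argument when instantiating the ModelCardReport class.

The cell below simulates a series of periodic reports by exporting the report five times with synthetic timestamps. Between exports, each performance metric is perturbed with Gaussian noise (standard deviation 0.1) and clipped to [0, 1]. Its pass/fail test is then re-evaluated against the 0.7 threshold. After each export, the HTML file is copied to the current working directory, and the cyclops_report folder is removed at the end.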
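The metric perturbation in the periodic-report cell can be sketched on its own, without CyclOps, as a minimal illustration. It assumes the same settings as that cell: noise with standard deviation 0.1, values clipped to [0, 1], and a pass threshold of 0.7. The function name simulate_periodic_metrics and the starting values are hypothetical. One deliberate difference is that the sketch uses a seeded numpy Generator so that runs are reproducible, whereas the notebook cell draws from the unseeded global np.random.

```python
import numpy as np


def simulate_periodic_metrics(values, n_periods, threshold=0.7, noise_std=0.1, seed=0):
    """Return one (values, passed) pair per period; period 0 is unperturbed.

    Noise accumulates: each period perturbs the previous period's values.
    """
    rng = np.random.default_rng(seed)  # seeded so the simulation is reproducible
    current = np.asarray(values, dtype=float)
    history = [(current.copy(), current >= threshold)]
    for _ in range(n_periods - 1):
        # Independent Gaussian noise per metric, clipped to stay a valid score.
        noise = rng.normal(0.0, noise_std, size=current.shape)
        current = np.clip(current + noise, 0.0, 1.0)
        history.append((current.copy(), current >= threshold))
    return history


# Toy starting values (made up), simulated over five reporting periods.
for period, (vals, passed) in enumerate(simulate_periodic_metrics([0.9, 0.65, 0.75], 5)):
    print(period, np.round(vals, 3), passed)
```

Because each period starts from the previous period's values, the simulated metrics drift like a random walk rather than fluctuating around the original values. The notebook cell behaves the same way, since it updates metric.value in place.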

[48]:
# Synthetic timestamps used to simulate a series of periodic reports.
synthetic_timestamps = [
    "2021-09-01",
    "2021-10-01",
    "2021-11-01",
    "2021-12-01",
    "2022-01-01",
]
# Export the first report with the metrics computed above.
report._model_card.overview = None
report_path = report.export(
    output_filename="mortality_report_periodic.html",
    synthetic_timestamp=synthetic_timestamps[0],
)
shutil.copy(report_path, ".")
# For each remaining timestamp, perturb every performance metric with Gaussian
# noise (clipped to [0, 1]), re-run its pass/fail test and export again.
for timestamp in synthetic_timestamps[1:]:
    report._model_card.overview = None
    for metric in report._model_card.quantitative_analysis.performance_metrics:
        metric.value = np.clip(
            metric.value + np.random.normal(0, 0.1),
            0,
            1,
        )
        metric.tests[0].passed = bool(metric.value >= 0.7)
    report_path = report.export(
        output_filename="mortality_report_periodic.html",
        synthetic_timestamp=timestamp,
    )
    shutil.copy(report_path, ".")
# Clean up the report directory created by export.
shutil.rmtree("./cyclops_report")

You can view the generated HTML report.