Source code for fl4health.metrics.metric_managers

import copy
from collections.abc import Sequence

import torch
from flwr.common.typing import Metrics

from fl4health.metrics.base_metrics import Metric
from fl4health.utils.typing import TorchPredType, TorchTargetType


[docs] class MetricManager:
[docs] def __init__(self, metrics: Sequence[Metric], metric_manager_name: str) -> None: """ Class to manage a set of metrics associated to a given prediction type. Args: metrics (Sequence[Metric]): List of metric to evaluate predictions on. metric_manager_name (str): Name of the metric manager (ie train, val, test) """ self.original_metrics = metrics self.metric_manager_name = metric_manager_name self.metrics_per_prediction_type: dict[str, Sequence[Metric]] = {}
[docs] def update(self, preds: TorchPredType, target: TorchTargetType) -> None: """ Updates (or creates then updates) a list of metrics for each prediction type. Args: preds (TorchPredType): A dictionary of preds from the model target (TorchTargetType): The ground truth labels for the data. If target is a dictionary with more than one item, then each value in the preds dictionary is evaluated with the value that has the same key in the target dictionary. If target has only one item or is a ``torch.Tensor``, then the same target is used for all predictions. """ if not self.metrics_per_prediction_type: self.metrics_per_prediction_type = {key: copy.deepcopy(self.original_metrics) for key in preds.keys()} # Check if there are multiple targets if isinstance(target, dict): if len(target.keys()) > 1: self.check_target_prediction_keys_equal(preds, target) else: # There is only one target, get tensor from dict target = list(target.values())[0] for prediction_key, pred in preds.items(): metrics_for_prediction_type = self.metrics_per_prediction_type[prediction_key] assert len(preds) == len(self.metrics_per_prediction_type) for metric_for_prediction_type in metrics_for_prediction_type: if isinstance(target, torch.Tensor): metric_for_prediction_type.update(pred, target) else: metric_for_prediction_type.update(pred, target[prediction_key])
[docs] def compute(self) -> Metrics: """ Computes set of metrics for each prediction type. Returns: Metrics: dictionary containing computed metrics along with string identifiers for each prediction type. """ all_results = {} for metrics_key, metrics in self.metrics_per_prediction_type.items(): for metric in metrics: result = metric.compute(f"{self.metric_manager_name} - {metrics_key}") all_results.update(result) return all_results
[docs] def clear(self) -> None: """ Clears data accumulated in each metric for each of the prediction type. """ for metrics_for_prediction_type in self.metrics_per_prediction_type.values(): for metric in metrics_for_prediction_type: metric.clear()
[docs] def reset(self) -> None: """ Resets the metrics to their initial state. """ # On next update, metrics will be recopied from self.original_metrics which are still in their initial state self.metrics_per_prediction_type = {}
[docs] def check_target_prediction_keys_equal( self, preds: dict[str, torch.Tensor], target: dict[str, torch.Tensor] ) -> None: assert target.keys() == preds.keys(), ( "Received a dict with multiple targets, but the keys of the " "targets do not match the keys of the predictions. Please pass a " "single target or ensure the keys between preds and target are the same" )