Source code for fl4health.metrics.base_metrics

from abc import ABC, abstractmethod
from enum import Enum

import torch
from flwr.common.typing import Metrics


[docs] class MetricPrefix(Enum): TEST_PREFIX = "test -" VAL_PREFIX = "val -"
TEST_NUM_EXAMPLES_KEY = f"{MetricPrefix.TEST_PREFIX.value} num_examples" TEST_LOSS_KEY = f"{MetricPrefix.TEST_PREFIX.value} checkpoint"
[docs] class Metric(ABC):
[docs] def __init__(self, name: str) -> None: """ Base abstract Metric class to extend for metric accumulation and computation. Args: name (str): Name of the metric. """ self.name = name
[docs] @abstractmethod def update(self, input: torch.Tensor, target: torch.Tensor) -> None: """ This method updates the state of the metric by appending the passed input and target pairing to their respective list. Args: input (torch.Tensor): The predictions of the model to be evaluated. target (torch.Tensor): The ground truth target to evaluate predictions against. Raises: NotImplementedError: To be defined in the classes extending this class. """ raise NotImplementedError
[docs] @abstractmethod def compute(self, name: str | None = None) -> Metrics: """ Compute metric on state accumulated over updates. Args: name (str | None): Optional name used in conjunction with class attribute name to define key in metrics dictionary. Raises: NotImplementedError: To be defined in the classes extending this class. Returns: Metrics: A dictionary of string and ``Scalar`` representing the computed metric and its associated key. """ raise NotImplementedError
[docs] @abstractmethod def clear(self) -> None: """ Resets metric. Raises: NotImplementedError: To be defined in the classes expending this class. """ raise NotImplementedError