Source code for fl4health.metrics.base_metrics
from abc import ABC, abstractmethod
from enum import Enum
import torch
from flwr.common.typing import Metrics
TEST_NUM_EXAMPLES_KEY = f"{MetricPrefix.TEST_PREFIX.value} num_examples"
TEST_LOSS_KEY = f"{MetricPrefix.TEST_PREFIX.value} checkpoint"
[docs]
class Metric(ABC):
[docs]
def __init__(self, name: str) -> None:
"""
Base abstract Metric class to extend for metric accumulation and computation.
Args:
name (str): Name of the metric.
"""
self.name = name
[docs]
@abstractmethod
def update(self, input: torch.Tensor, target: torch.Tensor) -> None:
"""
This method updates the state of the metric by appending the passed input and target pairing to their
respective list.
Args:
input (torch.Tensor): The predictions of the model to be evaluated.
target (torch.Tensor): The ground truth target to evaluate predictions against.
Raises:
NotImplementedError: To be defined in the classes extending this class.
"""
raise NotImplementedError
[docs]
@abstractmethod
def compute(self, name: str | None = None) -> Metrics:
"""
Compute metric on state accumulated over updates.
Args:
name (str | None): Optional name used in conjunction with class attribute name to define key in metrics
dictionary.
Raises:
NotImplementedError: To be defined in the classes extending this class.
Returns:
Metrics: A dictionary of string and ``Scalar`` representing the computed metric and its associated key.
"""
raise NotImplementedError
[docs]
@abstractmethod
def clear(self) -> None:
"""
Resets metric.
Raises:
NotImplementedError: To be defined in the classes expending this class.
"""
raise NotImplementedError