fl4health.metrics.compound_metrics module
- class EmaMetric(metric, smoothing_factor=0.1, name=None)
- __init__(metric, smoothing_factor=0.1, name=None)
Exponential Moving Average (EMA) metric wrapper to apply EMA to the underlying metric.
NOTE: If the underlying metric accumulates batches during update, then updating this metric without clearing in between will result in previously seen inputs and targets being a part of subsequent computations. For example, if we use Accuracy from fl4health.metrics, which accumulates batches, we get the following behavior in the code block below.
    import torch
    from fl4health.metrics import Accuracy
    from fl4health.metrics.compound_metrics import EmaMetric

    ema = EmaMetric(Accuracy(), 0.1)

    preds_1 = torch.Tensor([1, 0, 1])
    targets_1 = torch.Tensor([1, 1, 1])
    ema.update(preds_1, targets_1)
    ema.compute()  # score: 0.667 (the very first score is stored as is)

    preds_2 = torch.Tensor([0, 0, 1])
    targets_2 = torch.Tensor([1, 1, 1])

    # If no clear before update, the new accuracy is computed using both
    # preds_1 and preds_2 (3 of 6 correct, i.e. 0.5).
    ema.update(preds_2, targets_2)
    ema.compute()  # score: 0.9 * 0.667 + 0.1 * 0.5

    # If, instead of the two lines above, there were a clear before update,
    # the new accuracy is computed using preds_2 only (1 of 3 correct, i.e. 0.333).
    ema.clear()
    ema.update(preds_2, targets_2)
    ema.compute()  # score: 0.9 * 0.667 + 0.1 * 0.333
- Parameters:
metric (T) – An FL4Health compatible metric
smoothing_factor (float, optional) – Smoothing factor in range [0, 1] for the EMA. Smaller values increase smoothing by weighting previous scores more heavily. Defaults to 0.1.
name (str | None, optional) – Name of the EmaMetric. If left as None, it defaults to ‘EMA_{metric.name}’.
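The accumulation behavior described in the note above can be reproduced without FL4Health. The following is a minimal sketch in plain Python: `ToyAccuracy` is a hypothetical stand-in written only for this illustration, not the library's `Accuracy` class, and it is assumed to accumulate every batch it sees until `clear()` is called.

```python
class ToyAccuracy:
    """Hypothetical stand-in for a metric that accumulates batches across updates."""

    def __init__(self) -> None:
        self.preds: list[int] = []
        self.targets: list[int] = []

    def update(self, preds: list[int], targets: list[int]) -> None:
        # Append the new batch to everything seen since the last clear().
        self.preds.extend(preds)
        self.targets.extend(targets)

    def compute(self) -> float:
        # Fraction of accumulated predictions that match their targets.
        correct = sum(p == t for p, t in zip(self.preds, self.targets))
        return correct / len(self.targets)

    def clear(self) -> None:
        self.preds, self.targets = [], []


batch_1 = ([1, 0, 1], [1, 1, 1])
batch_2 = ([0, 0, 1], [1, 1, 1])

# Without clear(): the second score covers both batches (3 correct out of 6).
metric = ToyAccuracy()
metric.update(*batch_1)
first_score = metric.compute()
metric.update(*batch_2)
score_without_clear = metric.compute()

# With clear(): the second score covers only the second batch (1 correct out of 3).
metric = ToyAccuracy()
metric.update(*batch_1)
metric.compute()
metric.clear()
metric.update(*batch_2)
score_with_clear = metric.compute()

print(first_score, score_without_clear, score_with_clear)
```

The two second-step scores differ only because of the accumulated state, and that difference is what then feeds into the moving average.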
- clear()
Resets metric.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None
- compute(name=None)
Compute metric on state accumulated over updates. This computation considers the exponential moving average with respect to previous scores. For time step \(t\) and metric score \(m_t\), the EMA score is computed as
\[\text{smoothing_factor} \cdot m_t + (1-\text{smoothing_factor}) \cdot m_{t-1}.\]
The very first score is stored as is.
- Parameters:
name (str | None, optional) – Optional name used in conjunction with class attribute name to define key in metrics dictionary. Defaults to None.
- Returns:
A dictionary of string and Scalar representing the computed metric and its associated key.
- Return type:
Metrics
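As a worked illustration of the EMA update used by compute, here is a minimal sketch in plain Python. The helper `ema_scores` is hypothetical and not part of FL4Health. It assumes the "previous score" term is the previously smoothed value, as in a standard EMA; for the two-step sequences below this coincides with the formula as written, because the first score is stored unchanged.

```python
def ema_scores(raw_scores: list[float], smoothing_factor: float = 0.1) -> list[float]:
    """Apply an exponential moving average to a sequence of raw metric scores."""
    smoothed: list[float] = []
    for score in raw_scores:
        if not smoothed:
            # The very first score is stored as is.
            smoothed.append(score)
        else:
            # Weight the new score by smoothing_factor and the previous
            # smoothed score by (1 - smoothing_factor).
            previous = smoothed[-1]
            smoothed.append(smoothing_factor * score + (1 - smoothing_factor) * previous)
    return smoothed


# Raw scores taken from the class-level example: 0.667 followed by 0.5 or 0.333.
without_clear = ema_scores([0.667, 0.5])
with_clear = ema_scores([0.667, 0.333])

print(without_clear[1])  # 0.9 * 0.667 + 0.1 * 0.5, approximately 0.6503
print(with_clear[1])     # 0.9 * 0.667 + 0.1 * 0.333, approximately 0.6336
```

With a smoothing factor of 0.1, the new raw score contributes only a tenth of the result, so a drop from 0.667 to 0.333 moves the reported value by about 0.033.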
- update(input, target)
This method updates the state of the metric by appending the passed input and target pairing to their respective list.
- Parameters:
input (torch.Tensor) – The predictions of the model to be evaluated.
target (torch.Tensor) – The ground truth target to evaluate predictions against.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None
- class TransformsMetric(metric, pred_transforms=None, target_transforms=None)
- __init__(metric, pred_transforms=None, target_transforms=None)
A thin wrapper class to allow transforms to be applied to preds and targets prior to calculating metrics. Transforms are applied in the order given.
- Parameters:
metric (Metric) – An FL4Health compatible metric
pred_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the model predictions before computing the metrics. Each callable must accept and return a torch.Tensor. Use partial to set other arguments.
target_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the targets before computing the metrics. Each callable must accept and return a torch.Tensor. Use partial to set other arguments.
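To illustrate the kind of callables the transform parameters expect, here is a minimal sketch assuming PyTorch is installed. The helpers `binarize` and `apply_in_order` are hypothetical names introduced for this example. `apply_in_order` only mimics the documented "applied in the order given" behavior and is not the library's implementation.

```python
from functools import partial
from typing import Callable, Sequence

import torch


def binarize(x: torch.Tensor, cutoff: float) -> torch.Tensor:
    """Hypothetical transform: map values at or above cutoff to 1.0, others to 0.0."""
    return (x >= cutoff).to(torch.float32)


def apply_in_order(
    x: torch.Tensor, transforms: Sequence[Callable[[torch.Tensor], torch.Tensor]]
) -> torch.Tensor:
    """Hypothetical helper: apply each transform to the running result, in sequence."""
    for transform in transforms:
        x = transform(x)
    return x


# Each entry takes and returns a torch.Tensor. partial fixes the extra
# `cutoff` argument so that binarize fits the single-argument shape.
pred_transforms = [torch.sigmoid, partial(binarize, cutoff=0.5)]

logits = torch.tensor([2.0, -1.0, 0.5])
transformed = apply_in_order(logits, pred_transforms)
print(transformed)

# Based on the signature documented above, the same list would be passed as
# TransformsMetric(Accuracy(), pred_transforms=pred_transforms).
```

Order matters here: applying `binarize` before `torch.sigmoid` would threshold the raw logits instead of the probabilities and give a different result.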
- clear()
Resets metric.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None
- compute(name=None)
Compute metric on state accumulated over updates.
- Parameters:
name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Returns:
A dictionary of string and Scalar representing the computed metric and its associated key.
- Return type:
Metrics
- update(pred, target)
This method updates the state of the metric by appending the passed input and target pairing to their respective list.
- Parameters:
pred (torch.Tensor) – The predictions of the model to be evaluated.
target (torch.Tensor) – The ground truth target to evaluate predictions against.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None