fl4health.metrics.compound_metrics module

class EmaMetric(metric, smoothing_factor=0.1, name=None)

Bases: Metric, Generic[T]

__init__(metric, smoothing_factor=0.1, name=None)

Exponential Moving Average (EMA) metric wrapper to apply EMA to the underlying metric.

NOTE: If the underlying metric accumulates batches during update, then updating this metric without clearing in between will result in previously seen inputs and targets being a part of subsequent computations. For example, if we use Accuracy from fl4health.metrics, which accumulates batches, we get the following behavior in the code block below.

import torch

from fl4health.metrics import Accuracy
from fl4health.metrics.compound_metrics import EmaMetric

ema = EmaMetric(Accuracy(), 0.1)

preds_1, targets_1 = torch.Tensor([1, 0, 1]), torch.Tensor([1, 1, 1])
ema.update(preds_1, targets_1)
ema.compute()  # accuracy 2/3 = 0.667, stored as the first score

preds_2, targets_2 = torch.Tensor([0, 0, 1]), torch.Tensor([1, 1, 1])

# Case 1: no clear before the update, so the new accuracy is computed over
# both preds_1 and preds_2 (3/6 = 0.5).
ema.update(preds_2, targets_2)
ema.compute()  # 0.9 * 0.667 + 0.1 * 0.5

# Case 2 (instead of case 1): clear before the update, so the new accuracy is
# computed over preds_2 only (1/3 = 0.333).
ema.clear()
ema.update(preds_2, targets_2)
ema.compute()  # 0.9 * 0.667 + 0.1 * 0.333
Parameters:
  • metric (T) – An FL4Health compatible metric

  • smoothing_factor (float, optional) – Smoothing factor in range [0, 1] for the EMA. Smaller values increase smoothing by weighting previous scores more heavily. Defaults to 0.1.

  • name (str | None, optional) – Name of the EmaMetric. If None, defaults to ‘EMA_{metric.name}’.

clear()

Resets metric.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None

compute(name=None)

Compute metric on state accumulated over updates. This computation considers the exponential moving average with respect to previous scores. For time step \(t\) and metric score \(m_t\), the EMA score is computed as

\[\text{smoothing_factor} \cdot m_t + (1-\text{smoothing_factor}) \cdot (m_{t-1}).\]

The very first score is stored as is.
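The recurrence above can be sketched in plain Python, without FL4Health or PyTorch. This is an illustrative sketch, not the library's implementation: ema_scores is a hypothetical helper, and it assumes the "previous score" \(m_{t-1}\) is the previously stored EMA value, which the description does not state explicitly. The raw scores reuse the accuracies from the class-level example (0.667, then 0.5 without a clear or 0.333 with one).

```python
from __future__ import annotations


def ema_scores(raw_scores: list[float], smoothing_factor: float = 0.1) -> list[float]:
    """Return the EMA-smoothed value after each raw score (hypothetical helper)."""
    smoothed: list[float] = []
    previous: float | None = None
    for score in raw_scores:
        if previous is None:
            # The very first score is stored as is.
            current = score
        else:
            # Assumption: the previous score is the last smoothed value.
            current = smoothing_factor * score + (1 - smoothing_factor) * previous
        smoothed.append(current)
        previous = current
    return smoothed


# Second raw accuracy computed over both batches (no clear): 0.5
print([round(s, 3) for s in ema_scores([2 / 3, 0.5])])    # [0.667, 0.65]
# Second raw accuracy computed over the second batch only (after clear): 0.333
print([round(s, 3) for s in ema_scores([2 / 3, 1 / 3])])  # [0.667, 0.633]
```

With a smoothing factor of 0.1, the new score contributes only a tenth of the result, so the two cases differ by less than 0.02 even though the raw accuracies differ by about 0.17.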

Parameters:

name (str | None, optional) – Optional name used in conjunction with class attribute name to define key in metrics dictionary. Defaults to None.

Returns:

A dictionary of string and Scalar representing the computed metric and its associated key.

Return type:

Metrics

update(input, target)

This method updates the state of the metric by appending the passed input and target pairing to their respective lists.

Parameters:
  • input (torch.Tensor) – The predictions of the model to be evaluated.

  • target (torch.Tensor) – The ground truth target to evaluate predictions against.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None

class TransformsMetric(metric, pred_transforms=None, target_transforms=None)

Bases: Metric, Generic[T]

__init__(metric, pred_transforms=None, target_transforms=None)

A thin wrapper class that allows transforms to be applied to preds and targets prior to calculating metrics. Transforms are applied in the order given.

Parameters:
  • metric (Metric) – An FL4Health-compatible metric

  • pred_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the model predictions before computing the metrics. Each callable must accept and return a torch.Tensor. Use functools.partial to set other arguments.

  • target_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the targets before computing the metrics. Each callable must accept and return a torch.Tensor. Use functools.partial to set other arguments.
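
The "applied in the order given" behavior and the use of partial to fix extra arguments can be sketched as follows. This is an illustration only, not the library's implementation: plain lists of floats stand in for torch.Tensor so the sketch runs without PyTorch, and threshold, flip, and apply_transforms are hypothetical names.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence
from functools import partial

# Stand-in for torch.Tensor (assumption made so the sketch has no dependencies).
Values = list
Transform = Callable[[Values], Values]


def threshold(values: Values, cutoff: float) -> Values:
    """Binarize scores: 1.0 where the score exceeds cutoff, else 0.0."""
    return [1.0 if v > cutoff else 0.0 for v in values]


def flip(values: Values) -> Values:
    """Map each value v to 1.0 - v (swaps binary labels)."""
    return [1.0 - v for v in values]


def apply_transforms(values: Values, transforms: Sequence[Transform] | None) -> Values:
    """Apply each transform in the order given; None means no transforms."""
    for transform in transforms or []:
        values = transform(values)
    return values


scores = [0.2, 0.55, 0.9]
# partial fixes the extra cutoff argument, leaving a one-argument callable.
pred_transforms = [partial(threshold, cutoff=0.6), flip]
transformed = apply_transforms(scores, pred_transforms)
print(transformed)  # [1.0, 1.0, 0.0]
```

Order matters here: flipping before thresholding the same scores gives [1.0, 0.0, 0.0] instead, so the sequence passed to the wrapper should list transforms in the order they are meant to run.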

clear()

Resets metric.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None

compute(name=None)

Compute metric on state accumulated over updates.

Parameters:

name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Returns:

A dictionary of string and Scalar representing the computed metric and its associated key.

Return type:

Metrics

update(pred, target)

This method updates the state of the metric by appending the passed pred and target pairing to their respective lists.

Parameters:
  • pred (torch.Tensor) – The predictions of the model to be evaluated.

  • target (torch.Tensor) – The ground truth target to evaluate predictions against.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None