fl4health.metrics.compound_metrics module

class TransformsMetric(metric, pred_transforms=None, target_transforms=None)[source]

Bases: Metric, Generic[T]

__init__(metric, pred_transforms=None, target_transforms=None)[source]

A thin wrapper class that allows transforms to be applied to preds and targets prior to calculating metrics. Transforms are applied in the order given.

Parameters:
  • metric (Metric) – An FL4Health-compatible metric.

  • pred_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the model predictions before computing the metrics. Each callable must accept and return a torch.Tensor. Use functools.partial to bind any other arguments.

  • target_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the targets before computing the metrics. Each callable must accept and return a torch.Tensor. Use functools.partial to bind any other arguments.
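A minimal, library-free sketch of the wrapping pattern described above: each transform in a sequence is applied in order to the predictions (or targets) before the wrapped metric's update is called, and functools.partial binds extra arguments so each callable takes a single input. To keep the example self-contained, plain Python lists stand in for torch.Tensor, and the names ToyAccuracy, TransformsWrapperSketch, threshold, and scale are hypothetical; this is an illustration of the idea, not the FL4Health implementation.

```python
from functools import partial
from typing import Callable, List, Optional, Sequence

# Assumption: plain lists stand in for torch.Tensor in this sketch.
Vector = List[float]
TransformFn = Callable[[Vector], Vector]


class ToyAccuracy:
    """Hypothetical metric: fraction of positions where pred equals target."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, pred: Vector, target: Vector) -> None:
        self.correct += sum(p == t for p, t in zip(pred, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0

    def clear(self) -> None:
        self.correct = 0
        self.total = 0


class TransformsWrapperSketch:
    """Hypothetical wrapper: applies transforms in order, then delegates."""

    def __init__(
        self,
        metric: ToyAccuracy,
        pred_transforms: Optional[Sequence[TransformFn]] = None,
        target_transforms: Optional[Sequence[TransformFn]] = None,
    ) -> None:
        self.metric = metric
        self.pred_transforms = list(pred_transforms or [])
        self.target_transforms = list(target_transforms or [])

    def update(self, pred: Vector, target: Vector) -> None:
        # Each transform's output feeds the next, in the order given.
        for transform in self.pred_transforms:
            pred = transform(pred)
        for transform in self.target_transforms:
            target = transform(target)
        self.metric.update(pred, target)

    def compute(self) -> float:
        return self.metric.compute()

    def clear(self) -> None:
        self.metric.clear()


def threshold(x: Vector, cutoff: float) -> Vector:
    # Hypothetical transform: binarize scores at a cutoff.
    return [1.0 if v >= cutoff else 0.0 for v in x]


def scale(x: Vector, factor: float) -> Vector:
    # Hypothetical transform: multiply every element by a constant.
    return [v * factor for v in x]


# partial binds the extra keyword so each callable takes exactly one input.
wrapped = TransformsWrapperSketch(
    ToyAccuracy(),
    pred_transforms=[partial(scale, factor=2.0), partial(threshold, cutoff=1.0)],
)
wrapped.update([0.2, 0.6, 0.9, 0.4], [0.0, 1.0, 1.0, 0.0])
print(wrapped.compute())  # 1.0
```

Swapping the two pred transforms (thresholding before scaling) maps every prediction to 0.0 and yields 0.5 instead, which is why the order of the sequence matters.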

clear()[source]

Resets the metric.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None

compute(name=None)[source]

Compute metric on state accumulated over updates.

Parameters:

name (str | None) – An optional name used in conjunction with the class attribute name to define the key in the metrics dictionary.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Returns:

A dictionary mapping strings to Scalar values, representing the computed metric and its associated key.

Return type:

Metrics
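The name parameter is combined with the metric's own name attribute to form the dictionary key. The sketch below illustrates one way that composition could work; the " - " separator and the fallback to the bare metric name when name is None are assumptions for illustration, not confirmed FL4Health behavior. NamedMetricSketch is hypothetical, and Scalar and Metrics are defined locally as stand-ins for the type aliases named in the signature.

```python
from typing import Dict, Optional, Union

# Local stand-ins for the Scalar and Metrics type aliases (assumption).
Scalar = Union[bool, bytes, float, int, str]
Metrics = Dict[str, Scalar]


class NamedMetricSketch:
    """Hypothetical metric showing how a compute key might be composed."""

    def __init__(self, name: str) -> None:
        self.name = name  # attribute identifying the metric
        self.values: list = []

    def update(self, value: float) -> None:
        self.values.append(value)

    def compute(self, name: Optional[str] = None) -> Metrics:
        result = sum(self.values) / len(self.values) if self.values else 0.0
        # Assumed key format: "<name> - <metric name>", else the metric name.
        key = f"{name} - {self.name}" if name is not None else self.name
        return {key: result}


metric = NamedMetricSketch("accuracy")
metric.update(0.5)
metric.update(1.0)
print(metric.compute())  # {'accuracy': 0.75}
print(metric.compute("val"))  # {'val - accuracy': 0.75}
```

Prefixing the key this way lets the same metric report separate entries (for example, validation versus test) in one metrics dictionary.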

update(pred, target)[source]

This method updates the state of the metric by appending the passed pred and target pair to their respective lists.

Parameters:
  • pred (torch.Tensor) – The predictions of the model to be evaluated.

  • target (torch.Tensor) – The ground truth target to evaluate predictions against.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None
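The update description follows an accumulate-then-compute pattern: each call only appends the (pred, target) pair to per-argument lists, and the metric is evaluated over the full accumulated state when compute is called; clear then empties that state. A minimal sketch, assuming plain lists in place of tensors (for tensors, torch.cat would normally join the stored batches) and a hypothetical mean-absolute-error metric named AccumulatingMAESketch:

```python
from typing import List


class AccumulatingMAESketch:
    """Hypothetical metric that accumulates batches, then computes MAE."""

    def __init__(self) -> None:
        self.preds: List[List[float]] = []
        self.targets: List[List[float]] = []

    def update(self, pred: List[float], target: List[float]) -> None:
        # Append this batch to the respective lists; no metric math yet.
        self.preds.append(pred)
        self.targets.append(target)

    def compute(self) -> float:
        # Flatten all stored batches (torch.cat plays this role for tensors).
        all_preds = [p for batch in self.preds for p in batch]
        all_targets = [t for batch in self.targets for t in batch]
        if not all_preds:
            return 0.0
        errors = [abs(p - t) for p, t in zip(all_preds, all_targets)]
        return sum(errors) / len(errors)

    def clear(self) -> None:
        # Drop accumulated state, e.g. between evaluation rounds.
        self.preds = []
        self.targets = []


mae = AccumulatingMAESketch()
mae.update([1.0, 2.0], [1.5, 2.0])
mae.update([4.0], [3.0])
print(mae.compute())  # 0.5
mae.clear()
print(mae.compute())  # 0.0
```

Deferring the computation until compute is called means the result reflects every batch seen since the last clear, rather than an average of per-batch scores.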