fl4health.metrics.compound_metrics module
- class TransformsMetric(metric, pred_transforms=None, target_transforms=None)
- __init__(metric, pred_transforms=None, target_transforms=None)
A thin wrapper class that applies transforms to the predictions and targets before they are passed to the wrapped metric. Transforms are applied in the order given.
- Parameters:
metric (Metric) – An FL4Health-compatible metric.
pred_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the model predictions before computing the metrics. Each callable must accept and return a torch.Tensor. Use functools.partial to bind any other arguments.
target_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the targets before computing the metrics. Each callable must accept and return a torch.Tensor. Use functools.partial to bind any other arguments.
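The transform semantics described above can be sketched in plain Python. This is a minimal sketch, not the FL4Health implementation: `MeanAbsoluteError`, `TransformsMetricSketch`, `scale`, and `binarize` are hypothetical names, and lists of floats stand in for `torch.Tensor` so the example runs without PyTorch. It shows transforms being applied in sequence order, with `functools.partial` binding the extra arguments (`factor`, `threshold`) so that each callable takes and returns a single value.

```python
from functools import partial
from typing import Callable, List, Optional, Sequence

# Hypothetical stand-in for torch.Tensor: a plain list of floats.
Values = List[float]
TransformFn = Callable[[Values], Values]


class MeanAbsoluteError:
    """Hypothetical wrapped metric: accumulates preds/targets and computes MAE."""

    def __init__(self) -> None:
        self.preds: List[float] = []
        self.targets: List[float] = []

    def update(self, pred: Values, target: Values) -> None:
        self.preds.extend(pred)
        self.targets.extend(target)

    def compute(self) -> float:
        diffs = [abs(p - t) for p, t in zip(self.preds, self.targets)]
        return sum(diffs) / len(diffs)

    def clear(self) -> None:
        self.preds, self.targets = [], []


class TransformsMetricSketch:
    """Applies transforms, in the order given, before updating the wrapped metric."""

    def __init__(
        self,
        metric: MeanAbsoluteError,
        pred_transforms: Optional[Sequence[TransformFn]] = None,
        target_transforms: Optional[Sequence[TransformFn]] = None,
    ) -> None:
        self.metric = metric
        self.pred_transforms = list(pred_transforms or [])
        self.target_transforms = list(target_transforms or [])

    def update(self, pred: Values, target: Values) -> None:
        # Each transform receives the output of the previous one.
        for transform in self.pred_transforms:
            pred = transform(pred)
        for transform in self.target_transforms:
            target = transform(target)
        self.metric.update(pred, target)

    def compute(self) -> float:
        return self.metric.compute()

    def clear(self) -> None:
        self.metric.clear()


def scale(values: Values, factor: float) -> Values:
    return [v * factor for v in values]


def binarize(values: Values, threshold: float) -> Values:
    return [1.0 if v >= threshold else 0.0 for v in values]


# Scale first, then binarize: [0.1, 0.3] -> [0.2, 0.6] -> [0.0, 1.0]
in_order = TransformsMetricSketch(
    MeanAbsoluteError(),
    pred_transforms=[partial(scale, factor=2.0), partial(binarize, threshold=0.5)],
)
in_order.update([0.1, 0.3], [0.0, 1.0])
print(in_order.compute())  # 0.0

# Binarize first, then scale: [0.1, 0.3] -> [0.0, 0.0] -> [0.0, 0.0]
reversed_order = TransformsMetricSketch(
    MeanAbsoluteError(),
    pred_transforms=[partial(binarize, threshold=0.5), partial(scale, factor=2.0)],
)
reversed_order.update([0.1, 0.3], [0.0, 1.0])
print(reversed_order.compute())  # 0.5
```

The two results differ only because the same transforms run in a different order, which is why the order of the sequences passed to the constructor matters.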
- clear()
Resets the metric.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None
- compute(name=None)
Computes the metric on the state accumulated over updates.
- Parameters:
name (str | None) – Optional name used in conjunction with the class attribute name to define the key in the metrics dictionary.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Returns:
A dictionary mapping the metric's string key to its computed Scalar value.
- Return type:
Metrics
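The shape of the dictionary returned by `compute` can be sketched as follows. This is a hypothetical illustration, not FL4Health code: `NamedAccuracySketch` is an invented class, the `Scalar` and `Metrics` aliases are defined locally (assumed to mirror Flower's `Scalar` union), and the `" - "` separator used to join the optional `name` with the class attribute `name` is an assumption made for the example; check the FL4Health source for the exact key format.

```python
from typing import Dict, List, Optional, Union

# Local aliases for the documented return type (assumed to mirror Flower's Scalar).
Scalar = Union[bool, bytes, float, int, str]
Metrics = Dict[str, Scalar]


class NamedAccuracySketch:
    """Hypothetical metric showing how an optional name could prefix the metric key."""

    def __init__(self, name: str) -> None:
        self.name = name  # attribute combined with the optional compute() name
        self.correct = 0
        self.total = 0

    def update(self, pred: List[int], target: List[int]) -> None:
        self.correct += sum(p == t for p, t in zip(pred, target))
        self.total += len(pred)

    def compute(self, name: Optional[str] = None) -> Metrics:
        value = self.correct / self.total
        # Assumed key format: "<name> - <self.name>" when a name is supplied.
        key = f"{name} - {self.name}" if name is not None else self.name
        return {key: value}


metric = NamedAccuracySketch("accuracy")
metric.update([1, 0], [1, 1])
print(metric.compute())       # {'accuracy': 0.5}
print(metric.compute("val"))  # {'val - accuracy': 0.5}
```

Prefixing the key with a caller-supplied name lets the same metric instance report under distinct keys (for example, validation versus test) in a single metrics dictionary.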
- update(pred, target)
This method updates the state of the metric by appending the passed prediction and target pairing to their respective lists.
- Parameters:
pred (torch.Tensor) – The predictions of the model to be evaluated.
target (torch.Tensor) – The ground truth target to evaluate predictions against.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None
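Because `update` only appends each prediction and target pairing and `compute` works on the accumulated state, the result reflects every sample seen since the last reset, not an average of per-batch scores. The minimal sketch below illustrates this with a hypothetical `AccumulatingAccuracy` class (not part of FL4Health), using lists of integers in place of `torch.Tensor`.

```python
from typing import List


class AccumulatingAccuracy:
    """Hypothetical metric: update() appends each pairing, compute() scores them all."""

    def __init__(self) -> None:
        self.preds: List[List[int]] = []
        self.targets: List[List[int]] = []

    def update(self, pred: List[int], target: List[int]) -> None:
        # Append the batch to the respective lists; nothing is computed yet.
        self.preds.append(pred)
        self.targets.append(target)

    def compute(self) -> float:
        # Flatten every accumulated batch and score all samples together.
        all_preds = [p for batch in self.preds for p in batch]
        all_targets = [t for batch in self.targets for t in batch]
        correct = sum(p == t for p, t in zip(all_preds, all_targets))
        return correct / len(all_preds)

    def clear(self) -> None:
        # Reset the accumulated state, e.g. between evaluation rounds.
        self.preds = []
        self.targets = []


metric = AccumulatingAccuracy()
metric.update([1], [1])              # batch 1: 1 of 1 correct
metric.update([0, 1, 1], [0, 0, 0])  # batch 2: 1 of 3 correct
print(metric.compute())  # 0.5, i.e. 2 of 4 samples (the mean of batch accuracies would be about 0.667)
metric.clear()
print(metric.preds)  # []
```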