fl4health.utils.metrics module

class Accuracy(name='accuracy')[source]

Bases: SimpleMetric

__init__(name='accuracy')[source]

Accuracy metric for classification tasks.

Parameters:

name (str) – The name of the metric.

class BalancedAccuracy(name='balanced_accuracy')[source]

Bases: SimpleMetric

__init__(name='balanced_accuracy')[source]

Balanced accuracy metric for classification tasks, suited to evaluating models on imbalanced datasets.

For more information: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html
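Balanced accuracy is the mean of per-class recall, so a classifier that always predicts the majority class is not rewarded the way plain accuracy rewards it. The plain-Python sketch below illustrates the difference. The helpers accuracy and balanced_accuracy are hypothetical, operate on integer label lists instead of the tensors the FL4Health metrics receive, and follow the scikit-learn definition (with adjusted=False) rather than FL4Health's own code.

```python
from collections import Counter


def accuracy(preds: list[int], targets: list[int]) -> float:
    """Fraction of predictions that exactly match the target label."""
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)


def balanced_accuracy(preds: list[int], targets: list[int]) -> float:
    """Mean of per-class recall over the classes present in targets."""
    support = Counter(targets)  # number of true examples per class
    hits = Counter(t for p, t in zip(preds, targets) if p == t)  # correct per class
    recalls = [hits[c] / n for c, n in support.items()]
    return sum(recalls) / len(recalls)


# Imbalanced toy data: 8 examples of class 0 and 2 of class 1.
targets = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
preds = [0] * 10  # a classifier that always predicts the majority class

print(accuracy(preds, targets))           # 0.8
print(balanced_accuracy(preds, targets))  # 0.5
```

Class 0 has recall 1.0 and class 1 has recall 0.0, so balanced accuracy exposes the useless classifier that plain accuracy scores at 0.8.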

class BinarySoftDiceCoefficient(name='BinarySoftDiceCoefficient', epsilon=1e-07, spatial_dimensions=(2, 3, 4), logits_threshold=0.5)[source]

Bases: SimpleMetric

__init__(name='BinarySoftDiceCoefficient', epsilon=1e-07, spatial_dimensions=(2, 3, 4), logits_threshold=0.5)[source]

Binary Dice coefficient metric with configurable spatial dimensions and logits threshold.

Parameters:
  • name (str) – Name of the metric.

  • epsilon (float) – Small float added to the denominator of the Dice calculation to avoid division by zero.

  • spatial_dimensions (tuple[int, ...]) – The spatial dimensions of the image within the prediction tensors. The default assumes 3D images, i.e. tensors of shape (batch_size, channel, spatial, spatial, spatial).

  • logits_threshold (float | None) – Threshold used to binarize predictions: values above it are mapped to 1 and values below it are mapped to 0. If the threshold is None, no thresholding is performed and a continuous or “soft” Dice coefficient is computed.
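The Dice coefficient is 2|P ∩ T| / (|P| + |T|). The following is a minimal sketch for a single flattened sample, assuming the epsilon is added only to the denominator as described above and that a value exactly equal to the threshold maps to 0 (the documentation does not specify this case). The function name binary_dice is hypothetical; the real metric works on tensors and reduces over spatial_dimensions, and how it aggregates across the batch is not covered here.

```python
from __future__ import annotations


def binary_dice(
    preds: list[float],
    targets: list[float],
    epsilon: float = 1e-7,
    logits_threshold: float | None = 0.5,
) -> float:
    """Dice coefficient for one flattened binary prediction/target pair."""
    if logits_threshold is not None:
        # Hard Dice: values above the threshold become 1, all others become 0.
        preds = [1.0 if p > logits_threshold else 0.0 for p in preds]
    # With logits_threshold=None the raw values are used, giving a "soft" Dice.
    intersection = sum(p * t for p, t in zip(preds, targets))
    denominator = sum(preds) + sum(targets) + epsilon  # epsilon guards against 0/0
    return 2.0 * intersection / denominator


preds = [0.9, 0.2, 0.7, 0.1]
targets = [1.0, 0.0, 0.0, 0.0]

print(binary_dice(preds, targets))                         # hard: about 0.667
print(binary_dice(preds, targets, logits_threshold=None))  # soft: about 0.621
```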

class F1(name='F1 score', average='weighted')[source]

Bases: SimpleMetric

__init__(name='F1 score', average='weighted')[source]

Computes the F1 score using the sklearn f1_score function. As such, the values of average correspond to those of that function.
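For intuition about average='weighted', the sketch below computes per-class F1 as 2TP / (2TP + FP + FN) and averages the scores weighted by each class's support (its number of true examples). The helper weighted_f1 is hypothetical and works on label lists; on this data it should agree with sklearn.metrics.f1_score(targets, preds, average='weighted'), which takes the ground truth first.

```python
from collections import Counter


def weighted_f1(preds: list[int], targets: list[int]) -> float:
    """Per-class F1 averaged with weights equal to each class's support."""
    support = Counter(targets)  # number of true examples per class
    total = 0.0
    for c in set(targets) | set(preds):
        tp = sum(p == c and t == c for p, t in zip(preds, targets))
        fp = sum(p == c and t != c for p, t in zip(preds, targets))
        fn = sum(p != c and t == c for p, t in zip(preds, targets))
        # c occurs in preds or targets, so the denominator is never zero.
        f1 = 2 * tp / (2 * tp + fp + fn)  # equals 2PR / (P + R)
        total += support[c] * f1  # classes absent from targets get weight 0
    return total / len(targets)


targets = [0, 0, 0, 1, 1, 2]
preds = [0, 0, 1, 1, 1, 2]
print(weighted_f1(preds, targets))  # about 0.833
```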

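Parameters:
  • name (str) – Name of the metric.

  • average – Averaging strategy passed to sklearn's f1_score (for example 'macro', 'micro', or the default 'weighted').
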
class Metric(name)[source]

Bases: ABC

__init__(name)[source]

Base abstract Metric class to extend for metric accumulation and computation.

Parameters:

name (str) – Name of the metric.

abstract clear()[source]

Resets metric.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None

abstract compute(name)[source]

Computes the metric on the inputs and targets accumulated over updates.

Parameters:

name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Returns:

A dictionary of string and Scalar representing the computed metric and its associated key.

Return type:

Metrics

abstract update(input, target)[source]

This method updates the state of the metric by appending the passed input and target pairing to their respective lists.

Parameters:
  • input (torch.Tensor) – The predictions of the model to be evaluated.

  • target (torch.Tensor) – The ground truth target to evaluate predictions against.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None
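To illustrate the contract that subclasses must fulfil, here is a self-contained sketch. MetricSketch is a hypothetical stand-in for the abstract class (so the example runs without FL4Health installed), and RunningAccuracy is a hypothetical subclass that keeps running counts rather than storing every batch. The key format used in compute is an assumption: the documentation only states that name is combined with the class attribute name.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class MetricSketch(ABC):
    """Hypothetical stand-in that mirrors the documented Metric interface."""

    def __init__(self, name: str) -> None:
        self.name = name

    @abstractmethod
    def update(self, input: list[int], target: list[int]) -> None: ...

    @abstractmethod
    def compute(self, name: str | None) -> dict[str, float]: ...

    @abstractmethod
    def clear(self) -> None: ...


class RunningAccuracy(MetricSketch):
    """Streams accuracy with two counters instead of storing every batch."""

    def __init__(self, name: str = "running_accuracy") -> None:
        super().__init__(name)
        self.correct = 0
        self.total = 0

    def update(self, input: list[int], target: list[int]) -> None:
        self.correct += sum(p == t for p, t in zip(input, target))
        self.total += len(target)

    def compute(self, name: str | None = None) -> dict[str, float]:
        # Assumed key format "<name> - <self.name>"; the docs only say the two are combined.
        key = f"{name} - {self.name}" if name is not None else self.name
        return {key: self.correct / self.total}

    def clear(self) -> None:
        self.correct = 0
        self.total = 0


metric = RunningAccuracy()
metric.update([1, 0, 1], [1, 1, 1])  # 2 of 3 correct
metric.update([0, 0], [0, 1])        # 1 of 2 correct
print(metric.compute("val"))         # {'val - running_accuracy': 0.6}
metric.clear()
```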

class MetricManager(metrics, metric_manager_name)[source]

Bases: object

__init__(metrics, metric_manager_name)[source]

Class to manage a set of metrics associated with a given prediction type.

Parameters:
  • metrics (Sequence[Metric]) – List of metrics to evaluate predictions on.

  • metric_manager_name (str) – Name of the metric manager (e.g., train, val, test).

check_target_prediction_keys_equal(preds, target)[source]

Return type:

None

clear()[source]

Clears metrics for each prediction type.

Return type:

None

compute()[source]

Computes set of metrics for each prediction type.

Returns:

Dictionary containing the computed metrics, keyed by string identifiers for each prediction type.

Return type:

Metrics

update(preds, target)[source]

Updates (or creates then updates) a list of metrics for each prediction type.

Parameters:
  • preds (TorchPredType) – A dictionary of predictions from the model.

  • target (TorchTargetType) – The ground truth labels for the data. If target is a dictionary with more than one item, each value in the preds dictionary is evaluated against the value with the same key in the target dictionary. If target has only one item or is a torch.Tensor, the same target is used for all predictions.

Return type:

None
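The target-matching rule described for update can be expressed as a small function. This is a sketch of the documented behaviour, not FL4Health's implementation: pair_preds_with_targets is a hypothetical name, and plain Python values stand in for tensors (any non-dictionary target is treated like a torch.Tensor).

```python
from __future__ import annotations

from typing import Any


def pair_preds_with_targets(
    preds: dict[str, Any], target: Any
) -> dict[str, tuple[Any, Any]]:
    """Return {pred_key: (prediction, target)} following the documented rule."""
    if isinstance(target, dict):
        if len(target) > 1:
            # Match each prediction with the target stored under the same key
            # (a missing key raises KeyError in this sketch).
            return {key: (pred, target[key]) for key, pred in preds.items()}
        # A single-item dictionary: reuse its only value for every prediction.
        (shared,) = target.values()
    else:
        # A bare tensor (here any non-dict object) is shared by all predictions.
        shared = target
    return {key: (pred, shared) for key, pred in preds.items()}


preds = {"prediction": "p_main", "auxiliary": "p_aux"}
print(pair_preds_with_targets(preds, "y"))
print(pair_preds_with_targets(preds, {"prediction": "y_main", "auxiliary": "y_aux"}))
```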

class MetricPrefix(value)[source]

Bases: Enum

An enumeration.

TEST_PREFIX = 'test -'
VAL_PREFIX = 'val -'

class ROC_AUC(name='ROC_AUC score')[source]

Bases: SimpleMetric

__init__(name='ROC_AUC score')[source]

Area under the Receiver Operating Characteristic curve (ROC AUC) metric for classification. For more information: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
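For the binary case, ROC AUC equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counting as one half (the normalized Mann-Whitney U statistic). The quadratic-time sketch below uses that identity; binary_roc_auc is a hypothetical helper, and how FL4Health handles multiclass inputs is not described here.

```python
def binary_roc_auc(scores: list[float], labels: list[int]) -> float:
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("ROC AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


print(binary_roc_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
```

Three of the four positive/negative pairs are ordered correctly (0.35 scores below the negative at 0.4), which gives 0.75.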

class SimpleMetric(name)[source]

Bases: Metric, ABC

__init__(name)[source]

Abstract metric class with base functionality to update, compute and clear metrics. Subclasses must define a __call__ method that returns the metric value given the inputs and targets.

Parameters:

name (str) – Name of the metric.

clear()[source]

Resets metrics by clearing input and target lists.

Return type:

None

compute(name=None)[source]

Computes the metric on the inputs and targets accumulated over updates.

Parameters:

name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.

Raises:

AssertionError – Input and target lists must be non-empty.

Returns:

A dictionary of string and Scalar representing the computed metric and its associated key.

Return type:

Metrics

update(input, target)[source]

This method updates the state of the metric by appending the passed input and target pairing to their respective lists.

Parameters:
  • input (torch.Tensor) – The predictions of the model to be evaluated.

  • target (torch.Tensor) – The ground truth target to evaluate predictions against.

Return type:

None
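The sketch below illustrates the SimpleMetric pattern under stated assumptions: update only stores each batch, compute concatenates the stored batches (torch.cat would play this role for tensors) and passes them to __call__, and the key format is assumed. SimpleMetricSketch and MeanAbsoluteError are hypothetical names, and lists stand in for tensors.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class SimpleMetricSketch(ABC):
    """Hypothetical mirror of the SimpleMetric pattern: subclasses only write __call__."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.inputs: list[list[float]] = []
        self.targets: list[list[float]] = []

    def update(self, input: list[float], target: list[float]) -> None:
        # Store each batch; nothing is computed until compute() is called.
        self.inputs.append(input)
        self.targets.append(target)

    def compute(self, name: str | None = None) -> dict[str, float]:
        assert len(self.inputs) > 0 and len(self.targets) > 0, "call update() first"
        # Concatenate all stored batches into one flat list each.
        all_inputs = [x for batch in self.inputs for x in batch]
        all_targets = [y for batch in self.targets for y in batch]
        key = f"{name} - {self.name}" if name is not None else self.name  # assumed format
        return {key: self(all_inputs, all_targets)}

    def clear(self) -> None:
        self.inputs = []
        self.targets = []

    @abstractmethod
    def __call__(self, input: list[float], target: list[float]) -> float: ...


class MeanAbsoluteError(SimpleMetricSketch):
    def __call__(self, input: list[float], target: list[float]) -> float:
        return sum(abs(p - t) for p, t in zip(input, target)) / len(target)


mae = MeanAbsoluteError("mae")
mae.update([1.0, 2.0], [1.5, 2.0])
mae.update([4.0], [3.0])
print(mae.compute())  # {'mae': 0.5}
```

Computing once over all stored samples means non-decomposable metrics such as F1 or ROC AUC are evaluated on the full set rather than averaged per batch, at the cost of keeping every batch in memory until clear is called.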

class TorchMetric(name, metric)[source]

Bases: Metric

__init__(name, metric)[source]

Thin wrapper around a torchmetrics metric that makes it compatible with the FL4Health Metric interface.

Parameters:
  • name (str) – The name of the metric.

  • metric (TMetric) – A metric object from the torchmetrics library.

clear()[source]

Resets metric.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None

compute(name)[source]

Computes the value of the underlying torchmetrics metric.

Parameters:

name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.

Returns:

A dictionary of string and Scalar representing the computed metric and its associated key.

Return type:

Metrics

update(input, target)[source]

Updates the state of the underlying TorchMetric.

Parameters:
  • input (torch.Tensor) – The predictions of the model to be evaluated.

  • target (torch.Tensor) – The ground truth target to evaluate predictions against.

Return type:

None
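TorchMetric acts as an adapter: the FL4Health update/compute/clear calls are forwarded to the wrapped object's update/compute/reset, which are the method names torchmetrics.Metric uses. In the sketch below FakeTorchmetric is a hypothetical stand-in so the example runs without torch; a real torchmetrics compute() returns a tensor, which would need converting (for example with .item()) to produce the Scalar values in Metrics. TorchMetricSketch and the key format are likewise assumptions.

```python
from __future__ import annotations


class FakeTorchmetric:
    """Hypothetical stand-in exposing torchmetrics-style update/compute/reset."""

    def __init__(self) -> None:
        self.reset()

    def update(self, preds: list[int], target: list[int]) -> None:
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total

    def reset(self) -> None:
        self.correct = 0
        self.total = 0


class TorchMetricSketch:
    """Hypothetical adapter from the Metric interface to a torchmetrics-style object."""

    def __init__(self, name: str, metric: FakeTorchmetric) -> None:
        self.name = name
        self.metric = metric  # all accumulated state lives inside the wrapped metric

    def update(self, input: list[int], target: list[int]) -> None:
        self.metric.update(input, target)

    def compute(self, name: str | None) -> dict[str, float]:
        key = f"{name} - {self.name}" if name is not None else self.name  # assumed format
        return {key: self.metric.compute()}

    def clear(self) -> None:
        self.metric.reset()  # torchmetrics calls this operation reset()


wrapped = TorchMetricSketch("accuracy", FakeTorchmetric())
wrapped.update([1, 2, 2], [1, 2, 0])
print(wrapped.compute("test"))  # {'test - accuracy': 0.6666666666666666}
wrapped.clear()
```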

class TransformsMetric(metric, pred_transforms=None, target_transforms=None)[source]

Bases: Metric

__init__(metric, pred_transforms=None, target_transforms=None)[source]

A thin wrapper class that applies transforms to predictions and targets before the metric is calculated. Transforms are applied in the order given.

Parameters:
  • metric (Metric) – An FL4Health-compatible metric.

  • pred_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the model predictions before computing the metrics. Each callable must accept and return a torch.Tensor. Use partial to set other arguments.

  • target_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the targets before computing the metrics. Each callable must accept and return a torch.Tensor. Use partial to set other arguments.
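The sketch below shows what "applied in the order given" means, using functools.partial to fix extra arguments as the parameter descriptions suggest. The transforms shift and threshold and the helper apply_transforms are hypothetical, and lists stand in for torch.Tensor; in TransformsMetric the transformed values would then be passed on to the wrapped metric.

```python
from __future__ import annotations

from functools import partial
from typing import Callable, Sequence

Transform = Callable[[list[float]], list[float]]


def shift(values: list[float], offset: float) -> list[float]:
    """Add a constant offset to every value."""
    return [v + offset for v in values]


def threshold(values: list[float], cutoff: float) -> list[float]:
    """Map values above cutoff to 1.0 and all others to 0.0."""
    return [1.0 if v > cutoff else 0.0 for v in values]


def apply_transforms(values: list[float], transforms: Sequence[Transform] | None) -> list[float]:
    # Apply each transform in sequence; None leaves the values unchanged.
    for transform in transforms or []:
        values = transform(values)
    return values


preds = [0.4, 0.1]
shift_then_threshold = [partial(shift, offset=0.3), partial(threshold, cutoff=0.5)]
threshold_then_shift = [partial(threshold, cutoff=0.5), partial(shift, offset=0.3)]

print(apply_transforms(preds, shift_then_threshold))  # [1.0, 0.0]
print(apply_transforms(preds, threshold_then_shift))  # [0.3, 0.3]
```

The same two transforms give different results in a different order, which is why the sequence passed as pred_transforms or target_transforms matters.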

clear()[source]

Resets metric.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None

compute(name)[source]

Computes the metric on the inputs and targets accumulated over updates.

Parameters:

name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Returns:

A dictionary of string and Scalar representing the computed metric and its associated key.

Return type:

Metrics

update(pred, target)[source]

This method updates the state of the metric by appending the passed input and target pairing to their respective lists.

Parameters:
  • pred (torch.Tensor) – The predictions of the model to be evaluated.

  • target (torch.Tensor) – The ground truth target to evaluate predictions against.

Raises:

NotImplementedError – To be defined in the classes extending this class.

Return type:

None