fl4health.utils.metrics module¶
- class Accuracy(name='accuracy')[source]¶
Bases:
SimpleMetric
- class BalancedAccuracy(name='balanced_accuracy')[source]¶
Bases:
SimpleMetric
- __init__(name='balanced_accuracy')[source]¶
- Balanced accuracy metric for classification tasks. Used for the evaluation of imbalanced datasets.
For more information: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html
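Example (an illustrative sketch of the shared update/compute/clear cycle; the tensor shapes and the resulting dictionary key are assumptions, not part of the documented API):

    import torch
    from fl4health.utils.metrics import BalancedAccuracy

    metric = BalancedAccuracy()
    # Accumulate predictions and targets batch by batch, then compute once.
    metric.update(torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]]), torch.tensor([0, 1, 1]))
    print(metric.compute(name="val"))  # e.g. {"val - balanced_accuracy": ...}
    metric.clear()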
- class BinarySoftDiceCoefficient(name='BinarySoftDiceCoefficient', epsilon=1e-07, spatial_dimensions=(2, 3, 4), logits_threshold=0.5)[source]¶
Bases:
SimpleMetric
- __init__(name='BinarySoftDiceCoefficient', epsilon=1e-07, spatial_dimensions=(2, 3, 4), logits_threshold=0.5)[source]¶
Binary DICE Coefficient Metric with configurable spatial dimensions and logits threshold.
- Parameters:
name (str) – Name of the metric.
epsilon (float) – Small float to add to denominator of DICE calculation to avoid divide by 0.
spatial_dimensions (tuple[int, ...]) – The spatial dimensions of the image within the prediction tensors. The default assumes that the images are 3D and have shape: batch_size, channel, spatial, spatial, spatial.
logits_threshold (float | None) – This is a threshold value where values above are classified as 1 and those below are mapped to 0. If the threshold is None, then no thresholding is performed and a continuous or “soft” DICE coefficient is computed.
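Example (an illustrative sketch contrasting the default thresholded DICE with the soft variant; the 5D shapes simply follow the documented default spatial_dimensions and are assumed, not required):

    import torch
    from fl4health.utils.metrics import BinarySoftDiceCoefficient

    hard_dice = BinarySoftDiceCoefficient()  # thresholds predictions at 0.5
    soft_dice = BinarySoftDiceCoefficient(name="SoftDice", logits_threshold=None)  # no thresholding

    preds = torch.rand(2, 1, 8, 8, 8)                      # (batch, channel, spatial, spatial, spatial)
    target = torch.randint(0, 2, (2, 1, 8, 8, 8)).float()  # binary ground truth masks
    hard_dice.update(preds, target)
    soft_dice.update(preds, target)
    print(hard_dice.compute(), soft_dice.compute())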
- class F1(name='F1 score', average='weighted')[source]¶
Bases:
SimpleMetric
- __init__(name='F1 score', average='weighted')[source]¶
Computes the F1 score using the sklearn f1_score function. As such, the values of average correspond to those of that function.
- Parameters:
name (str, optional) – Name of the metric. Defaults to “F1 score”.
average (str | None, optional) – Whether to perform averaging of the F1 scores and how. The values of this string correspond to those of the sklearn f1_score function. See: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html. Defaults to “weighted”.
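Example (a hedged sketch of switching the averaging mode; it assumes the metric accepts per-class scores of shape (batch, classes) with integer class targets, which should be confirmed against the metric's __call__):

    import torch
    from fl4health.utils.metrics import F1

    f1 = F1(name="F1 macro", average="macro")
    # Two batches accumulated before a single compute call.
    f1.update(torch.tensor([[2.0, 0.1], [0.2, 3.0]]), torch.tensor([0, 1]))
    f1.update(torch.tensor([[0.9, 1.5], [2.2, 0.3]]), torch.tensor([1, 0]))
    print(f1.compute())  # e.g. {"F1 macro": ...}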
- class Metric(name)[source]¶
Bases:
ABC
- __init__(name)[source]¶
Base abstract Metric class to extend for metric accumulation and computation.
- Parameters:
name (str) – Name of the metric.
- abstract clear()[source]¶
Resets metric.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None
- abstract compute(name)[source]¶
Compute metric on accumulated input and output over updates.
- Parameters:
name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Returns:
- A dictionary of string and Scalar representing the computed metric
and its associated key.
- Return type:
Metrics
- abstract update(input, target)[source]¶
This method updates the state of the metric by appending the passed input and target pairing to their respective lists.
- Parameters:
input (torch.Tensor) – The predictions of the model to be evaluated.
target (torch.Tensor) – The ground truth target to evaluate predictions against.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None
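Example (a sketch of a direct Metric subclass based only on the abstract interface above; the metric, the key construction, and the accumulation strategy are hypothetical):

    import torch
    from fl4health.utils.metrics import Metric

    class MeanAbsoluteError(Metric):
        # Hypothetical metric extending the abstract interface directly.
        def __init__(self, name: str = "MAE") -> None:
            super().__init__(name)
            self.inputs: list[torch.Tensor] = []
            self.targets: list[torch.Tensor] = []

        def update(self, input: torch.Tensor, target: torch.Tensor) -> None:
            # Append the batch predictions and targets to the accumulated state.
            self.inputs.append(input)
            self.targets.append(target)

        def compute(self, name: str | None = None) -> dict:
            # Aggregate everything accumulated so far into a single scalar.
            value = torch.mean(torch.abs(torch.cat(self.inputs) - torch.cat(self.targets))).item()
            key = f"{name} - {self.name}" if name is not None else self.name  # assumed key format
            return {key: value}

        def clear(self) -> None:
            # Reset state between evaluation rounds.
            self.inputs = []
            self.targets = []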
- class MetricManager(metrics, metric_manager_name)[source]¶
Bases:
object
- __init__(metrics, metric_manager_name)[source]¶
Class to manage a set of metrics associated with a given prediction type.
- compute()[source]¶
Computes set of metrics for each prediction type.
- Returns:
- dictionary containing computed metrics along with string identifiers
for each prediction type.
- Return type:
Metrics
- update(preds, target)[source]¶
Updates (or creates then updates) a list of metrics for each prediction type.
- Parameters:
preds (TorchPredType) – A dictionary of predictions from the model.
target (TorchTargetType) – The ground truth labels for the data. If target is a dictionary with more than one item, then each value in the preds dictionary is evaluated against the value with the same key in the target dictionary. If target has only one item or is a torch.Tensor, then the same target is used for all predictions.
- Return type:
None
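Example (a hedged sketch of the manager workflow; the prediction key, tensor shapes, and the format of the returned metric keys are illustrative assumptions):

    import torch
    from fl4health.utils.metrics import Accuracy, BalancedAccuracy, MetricManager

    manager = MetricManager(metrics=[Accuracy(), BalancedAccuracy()], metric_manager_name="train")

    # One prediction type under an assumed key; the single target is used for it.
    preds = {"prediction": torch.tensor([[2.0, 0.1], [0.2, 3.0]])}
    target = torch.tensor([0, 1])

    manager.update(preds, target)
    print(manager.compute())  # dictionary of computed metrics keyed per prediction type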
- class MetricPrefix(value)[source]¶
Bases:
Enum
An enumeration of prefixes used to label validation and test metrics.
- TEST_PREFIX = 'test -'¶
- VAL_PREFIX = 'val -'¶
- class ROC_AUC(name='ROC_AUC score')[source]¶
Bases:
SimpleMetric
- __init__(name='ROC_AUC score')[source]¶
Area under the Receiver Operating Characteristic curve (ROC AUC) metric for classification. For more information: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
- class SimpleMetric(name)[source]¶
Bases:
Metric
- __init__(name)[source]¶
Abstract metric class with base functionality to update, compute and clear metrics. Users need to define the __call__ method, which returns the metric value given inputs and targets.
- Parameters:
name (str) – Name of the metric.
- compute(name=None)[source]¶
Compute metric on accumulated input and output over updates.
- Parameters:
name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.
- Raises:
AssertionError – Input and target lists must be non empty.
- Returns:
- A dictionary of string and Scalar representing the computed metric
and its associated key.
- Return type:
Metrics
- update(input, target)[source]¶
This method updates the state of the metric by appending the passed input and target pairing to their respective lists.
- Parameters:
input (torch.Tensor) – The predictions of the model to be evaluated.
target (torch.Tensor) – The ground truth target to evaluate predictions against.
- Return type:
None
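Example (a sketch of a custom SimpleMetric: only __call__ is supplied, while update, compute and clear come from the base class; the metric and its __call__ signature are hypothetical and should be checked against SimpleMetric's actual expectations):

    import torch
    from fl4health.utils.metrics import SimpleMetric

    class MeanSquaredError(SimpleMetric):
        # Hypothetical metric: SimpleMetric accumulates batches; __call__ maps tensors to a scalar.
        def __init__(self, name: str = "MSE") -> None:
            super().__init__(name)

        def __call__(self, input: torch.Tensor, target: torch.Tensor) -> float:
            return torch.mean((input - target) ** 2).item()

    mse = MeanSquaredError()
    mse.update(torch.tensor([0.5, 0.8]), torch.tensor([0.0, 1.0]))
    print(mse.compute())  # e.g. {"MSE": ...}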
- class TorchMetric(name, metric)[source]¶
Bases:
Metric
- __init__(name, metric)[source]¶
Thin wrapper around a torchmetrics metric to make it compatible with our Metric interface.
- Parameters:
name (str) – The name of the metric.
metric (TMetric) – A torchmetrics-based metric to wrap.
- clear()[source]¶
Resets metric.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None
- compute(name)[source]¶
Compute value of underlying TorchMetric.
- Parameters:
name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.
- Returns:
- A dictionary of string and Scalar representing the computed metric
and its associated key.
- Return type:
Metrics
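Example (an illustrative wrapping of a torchmetrics metric; MulticlassAccuracy and the tensor shapes are example choices, not requirements of the wrapper):

    import torch
    from torchmetrics.classification import MulticlassAccuracy
    from fl4health.utils.metrics import TorchMetric

    wrapped = TorchMetric(name="accuracy", metric=MulticlassAccuracy(num_classes=3))
    wrapped.update(torch.tensor([[2.0, 0.1, 0.3], [0.2, 3.0, 0.1]]), torch.tensor([0, 1]))
    print(wrapped.compute(name="val"))
    wrapped.clear()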
- class TransformsMetric(metric, pred_transforms=None, target_transforms=None)[source]¶
Bases:
Metric
- __init__(metric, pred_transforms=None, target_transforms=None)[source]¶
A thin wrapper class to allow transforms to be applied to preds and targets prior to calculating metrics. Transforms are applied in the order given.
- Parameters:
metric (Metric) – A FL4Health compatible metric
pred_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the model predictions before computing the metrics. Each callable must accept and return a torch.Tensor. Use partial to set other arguments.
target_transforms (Sequence[TorchTransformFunction] | None, optional) – A list of transform functions to apply to the targets before computing the metrics. Each callable must accept and return a torch.Tensor. Use partial to set other arguments.
- clear()[source]¶
Resets metric.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None
- compute(name)[source]¶
Compute metric on accumulated input and output over updates.
- Parameters:
name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Returns:
- A dictionary of string and Scalar representing the computed metric
and its associated key.
- Return type:
Metrics
- update(pred, target)[source]¶
This method updates the state of the metric by appending the passed input and target pairing to their respective lists.
- Parameters:
pred (torch.Tensor) – The predictions of the model to be evaluated.
target (torch.Tensor) – The ground truth target to evaluate predictions against.
- Raises:
NotImplementedError – To be defined in the classes extending this class.
- Return type:
None
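Example (a hedged sketch of applying a prediction transform before a wrapped metric; using torch.sigmoid and segmentation-style tensor shapes is an assumption chosen to match BinarySoftDiceCoefficient's defaults):

    import torch
    from fl4health.utils.metrics import BinarySoftDiceCoefficient, TransformsMetric

    # The model emits raw logits; sigmoid is applied to predictions before the DICE metric sees them.
    dice_on_logits = TransformsMetric(
        metric=BinarySoftDiceCoefficient(),
        pred_transforms=[torch.sigmoid],
    )
    logits = torch.randn(2, 1, 8, 8, 8)
    labels = torch.randint(0, 2, (2, 1, 8, 8, 8)).float()
    dice_on_logits.update(logits, labels)
    print(dice_on_logits.compute(name="val"))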