fl4health.metrics.efficient_metrics_base module

class BinaryClassificationMetric(name, label_dim=None, batch_dim=None, dtype=torch.float32, pos_label=1, threshold=None, discard=None)[source]

Bases: ClassificationMetric

__init__(name, label_dim=None, batch_dim=None, dtype=torch.float32, pos_label=1, threshold=None, discard=None)[source]

A base class for BINARY classification metrics that can be computed from the true positive (tp), false positive (fp), false negative (fn), and true negative (tn) counts. These counts are computed for each class independently. How they are combined to produce the metric is left to inheriting classes.

On each update, the true_positives, false_positives, false_negatives and true_negatives counts for the provided predictions and targets are accumulated into self.true_positives, self.false_positives, self.false_negatives and self.true_negatives, respectively, for each label type. This reduces the memory footprint required to compute metrics across rounds. The user needs to define the compute_from_counts method which returns a dictionary of Scalar metrics given the true_positives, false_positives, false_negatives, and true_negatives counts. The accumulated counts are reset by the clear method. If your subclass returns multiple metrics you may need to also override the __call__ method.

If the predictions provided are continuous in value, then the associated counts will also be continuous (“soft”). For example, with a target of 1, a prediction of 0.8 contributes 0.8 to the true_positives count and 0.2 to the false_negatives.
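To make the soft-count arithmetic concrete, here is a minimal sketch (an illustration of the stated behavior, not library code) of how a continuous prediction splits its mass between counts:

```python
import torch

# With targets t and predictions p in [0, 1], each element contributes
# p * t to true positives and (1 - p) * t to false negatives ("soft" counts).
p = torch.tensor([0.8])
t = torch.tensor([1.0])

tp = (p * t).sum()        # 0.8 contributed to true_positives
fn = ((1 - p) * t).sum()  # 0.2 contributed to false_negatives
```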

NOTE: For this class, the predictions and targets passed to the update function MUST have the same shape.

NOTE: Preds and targets are expected to have elements in the interval [0, 1], or to be binarized into that form via the threshold argument.

NOTE: For this class, only the counts for the positive label are accumulated.

Parameters:
  • name (str) – The name of the metric.

  • label_dim (int | None, optional) – Specifies which dimension in the provided tensors corresponds to the label dimension. During metric computation, this dimension must have size of AT MOST 2. If left as None, this class will assume that there is no label dimension and that each entry in the tensor corresponds to a prediction/target, with the positive class label indicated by pos_label. In both cases, only the counts for the positive class label are accumulated and any counts/predictions for the negative class label are discarded. Defaults to None.

  • batch_dim (int | None, optional) –

    If None, counts are aggregated across samples and the batch dimension is reduced. If specified, counts are computed along that dimension. That is, counts are maintained for each training sample INDIVIDUALLY. For example, if batch_dim = 1 and label_dim = 0, then

    p = torch.tensor([[[0, 0, 0, 1], [1, 1, 1, 1]]])  # Size([1, 2, 4])
    
    t = torch.tensor([[[0, 0, 1, 0], [1, 1, 1, 1]]])  # Size([1, 2, 4])
    
    self.tp = torch.Tensor([[0], [4]])  # Size([2, 1])
    
    self.tn = torch.Tensor([[2], [0]])  # Size([2, 1])
    
    self.fp = torch.Tensor([[1], [0]])  # Size([2, 1])
    
    self.fn = torch.Tensor([[1], [0]])  # Size([2, 1])
    

    NOTE: The resulting counts will always be presented batch dimension first, then label dimension, regardless of input shape. Defaults to None.

  • dtype (torch.dtype) – The dtype to store the counts as. If preds or targets can be continuous, specify a float type. Otherwise specify an integer type to prevent overflow. Defaults to torch.float32.

  • pos_label (int, optional) – The label relative to which to report the counts. Must be either 0 or 1. Defaults to 1.

  • threshold (float | int | None, optional) – A float for thresholding values or an integer specifying the index of the label dimension. If a float is given, predictions below the threshold are mapped to 0 and above are mapped to 1. If an integer is given, predictions are binarized based on the class with the highest prediction where the specified axis is assumed to contain a prediction for each class (where its index along that dimension is the class label). Value of None leaves preds unchanged. Defaults to None.

  • discard (set[ClassificationOutcome] | None, optional) – One or several of ClassificationOutcome values. Specified outcome counts will not be accumulated. Their associated attribute will remain as an empty pytorch tensor. Useful for reducing the memory footprint of metrics that do not use all of the counts in their computation. Defaults to None.
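The batch_dim example in the parameters above can be reproduced with ordinary tensor operations. The following is a sketch of the arithmetic (an assumption about the internals, not the library's implementation):

```python
import torch

p = torch.tensor([[[0., 0., 0., 1.], [1., 1., 1., 1.]]])  # Size([1, 2, 4])
t = torch.tensor([[[0., 0., 1., 0.], [1., 1., 1., 1.]]])  # Size([1, 2, 4])

# Reduce over every dimension except batch_dim=1, then append a label
# dimension of size 1, giving counts of Size([2, 1]): batch dimension first.
reduce_dims = (0, 2)
tp = (p * t).sum(dim=reduce_dims).unsqueeze(-1)
tn = ((1 - p) * (1 - t)).sum(dim=reduce_dims).unsqueeze(-1)
fp = (p * (1 - t)).sum(dim=reduce_dims).unsqueeze(-1)
fn = ((1 - p) * t).sum(dim=reduce_dims).unsqueeze(-1)
# tp, tn, fp, fn match the documented example values.
```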

compute_from_counts(true_positives, false_positives, true_negatives, false_negatives)[source]

Provided tensors associated with the various outcomes from predictions compared to targets in the form of true positives, false positives, true negatives, and false negatives, returns a dictionary of Scalar metrics. For example, one might compute recall as true_positives/(true_positives + false_negatives). The shapes of these tensors are specific to how this object is configured; see the class documentation above.

For this class it is assumed that all counts are presented relative to the class indicated by the pos_label index, counts for the negative label are discarded. Moreover, they are assumed to either have a shape (1,) or have shape (num_samples, 1) if batch_dim was specified. In the former, a single count is presented ACROSS all samples relative to the pos_label specified. In the latter, counts are computed WITHIN each sample, but held separate across samples. A concrete setting where this makes sense is binary image segmentation. You can have such counts summed for all pixels within an image, but separate per image. A metric could then be computed for each image and then averaged.
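As an illustration, a hypothetical subclass could compute recall from the accumulated counts. The helper below is a sketch of one possible compute_from_counts body (not the library's implementation) that handles both count shapes described above:

```python
import torch

def recall_from_counts(true_positives, false_negatives, eps=1e-8):
    # Shape (1,): global recall across all samples. Shape (num_samples, 1):
    # per-sample recall, averaged across samples (e.g. mean per-image recall
    # in binary image segmentation). eps guards against division by zero.
    recall = true_positives / (true_positives + false_negatives + eps)
    return recall.mean().item()
```

For example, `recall_from_counts(torch.tensor([4.0]), torch.tensor([1.0]))` yields 0.8 (up to eps).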

Parameters:
  • true_positives (torch.Tensor) – Counts associated with positive predictions and positive labels.

  • false_positives (torch.Tensor) – Counts associated with positive predictions and negative labels.

  • true_negatives (torch.Tensor) – Counts associated with negative predictions and negative labels.

  • false_negatives (torch.Tensor) – Counts associated with negative predictions and positive labels.

Raises:

NotImplementedError – Must be implemented by the inheriting class.

Returns:

Metrics computed from the provided outcome counts.

Return type:

Metrics

count_tp_fp_tn_fn(preds, targets)[source]

Given two tensors containing model predictions and targets, returns the number of true positives (tp), false positives (fp), true negatives (tn), and false negatives (fn).

This class overrides the base method to ensure that only the counts with respect to the positive label are returned.

The shape of these counts depends on whether self.batch_dim and self.label_dim are specified. If any of the true positive, false positive, true negative, or false negative counts was specified to be discarded during initialization of the class, that count will not be computed and an empty tensor will be returned in its place.

Parameters:
  • preds (torch.Tensor) – Tensor containing model predictions. Must be the same shape as targets.

  • targets (torch.Tensor) – Tensor containing prediction targets. Must be same shape as preds.

Returns:

Tensors containing the counts along the specified dimensions for each of true positives, false positives, true negatives, and false negatives, respectively. If self.batch_dim is not None, these tensors will have shape (batch_size, 1). Otherwise, they will have shape (1,). The counts will be relative to the index of self.pos_label.

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

class ClassificationMetric(name, label_dim, batch_dim, dtype, threshold, discard)[source]

Bases: Metric, ABC

__init__(name, label_dim, batch_dim, dtype, threshold, discard)[source]

A base class for efficiently computing classification metrics that can be calculated from the true positive (tp), false positive (fp), false negative (fn), and true negative (tn) counts.

How these values are counted is left to the inheriting class along with how they are composed together for the final metric score. There are two classes inheriting from this class to form the basis of efficient classification metrics: BinaryClassificationMetric and MultiClassificationMetric. These handle implementation of the count_tp_fp_tn_fn method.

On each update, the true_positives, false_positives, false_negatives and true_negatives counts for the provided predictions and targets are accumulated into self.true_positives, self.false_positives, self.false_negatives and self.true_negatives, respectively, and reduced along all unspecified dimensions. This reduces the memory footprint required to compute metrics across rounds. The user needs to define the compute_from_counts method which returns a dictionary of Scalar metrics given the true_positives, false_positives, false_negatives, and true_negatives counts. The accumulated counts are reset by the clear method. If your subclass returns multiple metrics you may need to also override the __call__ method.

If the predictions provided are continuous in value, then the associated counts will also be continuous (“soft”). For example, with a target of 1, a prediction of 0.8 contributes 0.8 to the true_positives count and 0.2 to the false_negatives.

NOTE: Preds and targets are expected to have elements in the interval [0, 1], or to be binarized into that form via the threshold argument of this class.

Parameters:
  • name (str) – The name of the metric.

  • label_dim (int | None, optional) –

    Specifies which dimension in the provided tensors corresponds to the label dimension. If None, counts are aggregated across all labels and the label dimension is reduced. If specified, counts are computed along that dimension. That is, counts are maintained for each output label INDIVIDUALLY.

    NOTE: If both label_dim and batch_dim are specified, then counts will be presented batch dimension first, then label dimension. If neither are specified, each count is a global scalar.

  • batch_dim (int | None, optional) –

    If None, counts are aggregated across samples and the batch dimension is reduced. If specified, counts are computed along that dimension. That is, counts are maintained for each training sample INDIVIDUALLY.

    NOTE: If both label_dim and batch_dim are specified, then counts will be presented batch dimension first, then label dimension. If neither are specified, each count is a global scalar.

  • dtype (torch.dtype) – The dtype to store the counts as. If preds or targets can be continuous, specify a float type. Otherwise specify an integer type to prevent overflow.

  • threshold (float | int | None, optional) – A float for thresholding values or an integer specifying the index of the label dimension. If a float is given, predictions below the threshold are mapped to 0 and above are mapped to 1. If an integer is given, predictions are binarized based on the class with the highest prediction where the specified axis is assumed to contain a prediction for each class (where its index along that dimension is the class label). Setting to None leaves preds unchanged.

  • discard (set[ClassificationOutcome] | None, optional) – One or several of ClassificationOutcome values. Specified counts will not be accumulated. Their associated attribute will remain as an empty pytorch tensor. Useful for reducing the memory footprint of metrics that do not use all of the counts in their computation.

clear()[source]

Reset the accumulated tp, fp, fn, and tn counts. They will be re-initialized with the correct shape on the next update.

Return type:

None

compute(name=None)[source]

Computes the metrics from the currently saved counts using the compute_from_counts function defined in inheriting classes.

Parameters:

name (str | None) – Optional name used in conjunction with class attribute name to define key in metrics dictionary.

Returns:

A dictionary of string and Scalar representing the computed metric and its associated key.

Return type:

Metrics

abstract compute_from_counts(true_positives, false_positives, true_negatives, false_negatives)[source]

Provided tensors associated with the various classification outcomes from predictions compared to targets in the form of true positives, false positives, true negatives, and false negatives, returns a dictionary of Scalar metrics. For example, one might compute recall as true_positives/(true_positives + false_negatives).

The count tensors will all have the same shape. This shape depends on whether batch_dim and or label_dim were specified. For example, if the batch and label dimensions have sizes 2 and 3 respectively, then:

  • Both batch_dim and label_dim are specified: Size([2, 3])

  • Only batch_dim is specified: Size([2])

  • Only label_dim is specified: Size([3])

  • Neither specified: Size([])

Inheriting classes may further modify the shapes of the count tensors that are provided as arguments depending on the kind of classification being done (e.g. multi-class vs. binary).
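The shape conventions listed above can be sketched with plain tensor reductions (an assumption about the internals, not library code):

```python
import torch

# 2 samples (dim 0), 3 labels (dim 1), 4 elements per label (dim 2).
p = torch.rand(2, 3, 4)
t = (torch.rand(2, 3, 4) > 0.5).float()

tp_both = (p * t).sum(dim=2)        # batch_dim and label_dim kept: Size([2, 3])
tp_batch = (p * t).sum(dim=(1, 2))  # only batch_dim kept: Size([2])
tp_label = (p * t).sum(dim=(0, 2))  # only label_dim kept: Size([3])
tp_none = (p * t).sum()             # neither kept: Size([])
```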

Parameters:
  • true_positives (torch.Tensor) – Counts associated with positive predictions of a class and true positives for that class.

  • false_positives (torch.Tensor) – Counts associated with positive predictions of a class and true negatives for that class.

  • true_negatives (torch.Tensor) – Counts associated with negative predictions of a class and true negatives for that class.

  • false_negatives (torch.Tensor) – Counts associated with negative predictions of a class and true positives for that class.

Raises:

NotImplementedError – Must be implemented by the inheriting class.

Returns:

Metrics computed from the provided outcome counts.

Return type:

Metrics

count_tp_fp_tn_fn(preds, targets)[source]

Given two tensors containing model predictions and targets, returns the number of true positives (tp), false positives (fp), true negatives (tn), and false negatives (fn).

The shape of these counts depends on whether self.batch_dim and self.label_dim are specified and on the implementation of the inheriting class. If the batch dimension appears after the label dimension in the input tensors, this method will transpose the count tensors to ensure the batch dimension comes first.

If any of the true positives, false positives, true negative, or false negative counts were specified to be discarded during initialization of the class, then that count will not be computed and an empty tensor will be returned in its place.

NOTE: Inheriting classes may implement additional functionality on top of this class. For example, any preprocessing that needs to be done to preds and targets should be done in the inheriting function. Any post processing should also be done there. See implementations in the BinaryClassificationMetric or MultiClassificationMetric class for examples.

Parameters:
  • preds (torch.Tensor) – Tensor containing model predictions. Must be the same shape as targets

  • targets (torch.Tensor) – Tensor containing prediction targets. Must be same shape as preds.

Returns:

Tensors containing the counts along the specified dimensions for each of true positives, false positives, true negatives, and false negatives, respectively. The output shape of these tensors depends on whether self.batch_dim and self.label_dim are specified. The batch dimension, if it exists in the output, will always come first. For example, if the batch and label dimensions have sizes 2 and 3 respectively:

  • Both batch_dim and label_dim are specified: Size([2, 3])

  • Only batch_dim is specified: Size([2])

  • Only label_dim is specified: Size([3])

  • Neither specified: Size([])

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

update(preds, targets)[source]

Updates the existing self.true_positives, self.false_positives, self.false_negatives and self.true_negatives counts with new counts computed from preds and targets.

NOTE: This function assumes that if self.batch_dim is not None, the counts are returned with shapes such that the batch dimension comes FIRST for the counts. If self.count_tp_fp_tn_fn is overridden it must ensure that this remains the case.

Parameters:
  • preds (torch.Tensor) – Predictions tensor.

  • targets (torch.Tensor) – Targets tensor.

Return type:

None
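The accumulation semantics that update implements can be sketched as follows (an illustration of the behavior described above, not the actual implementation):

```python
import torch

# Running totals start at zero and each batch's counts are added
# elementwise, so memory stays constant no matter how many batches
# are processed across rounds.
running_tp = torch.zeros(3)  # e.g. per-label totals for 3 labels
batch_counts = [torch.tensor([1., 0., 2.]), torch.tensor([0., 1., 1.])]
for batch_tp in batch_counts:
    running_tp += batch_tp
# running_tp is now tensor([1., 1., 3.]); clear() would reset it.
```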

class ClassificationOutcome(value)[source]

Bases: Enum

An enumeration of the possible classification outcomes (true/false positives and negatives) used to label accumulated counts.

FALSE_NEGATIVE = 'false_negative'
FALSE_POSITIVE = 'false_positive'
TRUE_NEGATIVE = 'true_negative'
TRUE_POSITIVE = 'true_positive'
MetricOutcome

alias of ClassificationOutcome

class MultiClassificationMetric(name, label_dim, batch_dim=None, dtype=torch.float32, threshold=None, ignore_background=None, discard=None)[source]

Bases: ClassificationMetric

__init__(name, label_dim, batch_dim=None, dtype=torch.float32, threshold=None, ignore_background=None, discard=None)[source]

A base class for multi-class, multi-label classification metrics that can be computed from the true positive (tp), false positive (fp), false negative (fn), and true negative (tn) counts. These counts are computed for each class independently. How they are combined to produce the metric is left to inheriting classes.

On each update, the true_positives, false_positives, false_negatives and true_negatives counts for the provided predictions and targets are accumulated into self.true_positives, self.false_positives, self.false_negatives and self.true_negatives, respectively, for each label type. This reduces the memory footprint required to compute metrics across rounds. The user needs to define the compute_from_counts method which returns a dictionary of Scalar metrics given the true_positives, false_positives, false_negatives, and true_negatives counts. The accumulated counts are reset by the clear method. If your subclass returns multiple metrics you may need to also override the __call__ method.

If the predictions provided are continuous in value, then the associated counts will also be continuous (“soft”). For example, with a target of 1, a prediction of 0.8 contributes 0.8 to the true_positives count and 0.2 to the false_negatives.

NOTE: Preds and targets are expected to have elements in the interval [0, 1], or to be binarized into that form via the threshold argument.

NOTE: If preds and targets passed to update method have different shapes, or end up with different shapes after thresholding, this class will attempt to align the shapes by one-hot-encoding one of the tensors in the label dimension, if possible.

Parameters:
  • name (str) – The name of the metric.

  • label_dim (int) – Specifies which dimension in the provided tensors corresponds to the label dimension. During metric computation, this dimension must have size of AT LEAST 2. Counts are always computed along the label dimension. That is, counts are maintained for each output label INDIVIDUALLY.

  • batch_dim (int | None, optional) –

    If None, counts are aggregated across samples and the batch dimension is reduced. If specified, counts are computed along that dimension. That is, counts are maintained for each training sample INDIVIDUALLY.

    NOTE: If batch_dim is specified, then counts will be presented batch dimension first, then label dimension. For example, if batch_dim = 1 and label_dim = 0, then

    p = torch.tensor([[[1.0, 1.0, 1.0, 0.0]], [[0.0, 0.0, 0.0, 1.0]]])  # Size([2, 1, 4])
    
    t = torch.tensor([[[1.0, 1.0, 0.0, 0.0]], [[0.0, 0.0, 1.0, 1.0]]])  # Size([2, 1, 4])
    
self.tp = torch.Tensor([[2, 1]])  # Size([1, 2])
    
    self.tn = torch.Tensor([[1, 2]])  # Size([1, 2])
    
    self.fp = torch.Tensor([[1, 0]])  # Size([1, 2])
    
    self.fn = torch.Tensor([[0, 1]])  # Size([1, 2])
    

    NOTE: The resulting counts will always be presented batch dimension first, then label dimension, regardless of input shape. Defaults to None.

  • dtype (torch.dtype) – The dtype to store the counts as. If preds or targets can be continuous, specify a float type. Otherwise specify an integer type to prevent overflow. Defaults to torch.float32.

  • threshold (float | int | None, optional) – A float for thresholding values or an integer specifying the index of the label dimension. If a float is given, predictions below the threshold are mapped to 0 and above are mapped to 1. If an integer is given, predictions are binarized based on the class with the highest prediction where the specified axis is assumed to contain a prediction for each class (where its index along that dimension is the class label). Value of None leaves preds unchanged. Defaults to None.

  • ignore_background (int | None) – If specified, the FIRST channel of the specified axis is removed prior to computing the counts. Useful for removing background classes. Defaults to None.

  • discard (set[ClassificationOutcome] | None, optional) – One or several of ClassificationOutcome values. Specified outcome counts will not be accumulated. Their associated attribute will remain as an empty pytorch tensor. Useful for reducing the memory footprint of metrics that do not use all of the counts in their computation.
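The batch_dim example in the parameters above can likewise be reproduced with ordinary tensor operations (a sketch of the arithmetic, not the library's implementation):

```python
import torch

p = torch.tensor([[[1., 1., 1., 0.]], [[0., 0., 0., 1.]]])  # Size([2, 1, 4])
t = torch.tensor([[[1., 1., 0., 0.]], [[0., 0., 1., 1.]]])  # Size([2, 1, 4])

# Reduce over every dimension except label_dim=0 and batch_dim=1, then
# transpose so the batch dimension comes first: Size([1, 2]).
tp = (p * t).sum(dim=2).T
tn = ((1 - p) * (1 - t)).sum(dim=2).T
fp = (p * (1 - t)).sum(dim=2).T
fn = ((1 - p) * t).sum(dim=2).T
# tp, tn, fp, fn match the documented example values.
```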

compute_from_counts(true_positives, false_positives, true_negatives, false_negatives)[source]

Provided tensors associated with the various outcomes from predictions compared to targets in the form of true positives, false positives, true negatives, and false negatives, returns a dictionary of Scalar metrics. For example, one might compute recall as true_positives/(true_positives + false_negatives). The shapes of these tensors are specific to how this object is configured; see the class documentation above.

For this class, counts are assumed to have shape (num_labels,) or (num_samples, num_labels). In the former, counts have been aggregated ACROSS samples into a single count value for each possible label. In the latter, counts have been aggregated WITHIN each sample and remain separate across samples. A concrete setting where this makes sense is image segmentation. You can have such counts summed over all pixels within an image, but kept separate per image. A metric could then be computed for each image and then averaged.

NOTE: A user can implement further reduction along the label dimension (summing TPs across labels for example), if desired. It just needs to be handled in the implementation of this function.
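As a concrete example of the per-sample reduction described above, a hypothetical compute_from_counts could compute a Dice score per image and per label, then average the results into one scalar. This is a sketch under those assumptions, not the library's implementation:

```python
import torch

def mean_dice_from_counts(tp, fp, fn, eps=1e-8):
    # tp, fp, fn of shape (num_samples, num_labels): Dice per image per
    # label, then averaged into a single scalar. eps avoids division by zero.
    dice = 2 * tp / (2 * tp + fp + fn + eps)
    return dice.mean().item()
```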

Parameters:
  • true_positives (torch.Tensor) – Counts associated with positive predictions of a class and true positives for that class.

  • false_positives (torch.Tensor) – Counts associated with positive predictions of a class and true negatives for that class.

  • true_negatives (torch.Tensor) – Counts associated with negative predictions of a class and true negatives for that class.

  • false_negatives (torch.Tensor) – Counts associated with negative predictions of a class and true positives for that class.

Raises:

NotImplementedError – Must be implemented by the inheriting class.

Returns:

Metrics computed from the provided outcome counts.

Return type:

Metrics