Source code for fl4health.metrics.efficient_metrics

from logging import WARNING

import torch
from flwr.common.logger import log
from flwr.common.typing import Metrics, Scalar

from fl4health.metrics.efficient_metrics_base import (
    BinaryClassificationMetric,
    MetricOutcome,
    MultiClassificationMetric,
)
from fl4health.metrics.metrics_utils import compute_dice_on_count_tensors
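

# A minimal sketch (illustrative only; ``_dice_from_counts_example`` is a hypothetical helper, not part of this
# module's API) of the count-based computation that both metrics below delegate to. Per the docstrings in this
# module, each Dice coefficient is 2*TP / (2*TP + FP + FN), and passing zero_division=None drops undefined entries.
def _dice_from_counts_example() -> None:
    tp = torch.tensor([2.0, 1.0])
    fp = torch.tensor([1.0, 0.0])
    fn = torch.tensor([0.0, 1.0])
    # With these counts, 2*TP/(2*TP + FP + FN) gives [2*2/(2*2+1+0), 2*1/(2*1+0+1)] = [0.8, 0.6667]
    # (assuming element-wise output, consistent with how the classes below average the result).
    dice = compute_dice_on_count_tensors(tp, fp, fn, None)
    print(dice)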


class MultiClassDice(MultiClassificationMetric):
    def __init__(
        self,
        batch_dim: int | None,
        label_dim: int,
        name: str = "MultiClassDice",
        dtype: torch.dtype = torch.float32,
        threshold: float | int | None = None,
        ignore_background: int | None = None,
        zero_division: float | None = None,
    ) -> None:
        """
        Computes the mean Dice coefficient between class predictions and targets with multiple classes.

        NOTE: The default behavior for Dice scores is to compute the mean over each SAMPLE of the dataset being
        measured. In the image domain, for example, this means that the Dice score is computed for each image
        separately and then averaged across images (then classes) to produce a single score. This is accomplished
        by specifying ``batch_dim`` here. If, however, you would like to compute the Dice score over ALL TPs, FPs,
        and FNs across all samples (then classes) as a single count, ``batch_dim = None`` is appropriate.

        NOTE: Preds and targets are expected to have elements in the interval [0, 1] or to be binarized as such
        via the ``threshold`` argument.

        NOTE: If the preds and targets passed to the update method have different shapes, this class will attempt
        to align the shapes by one-hot-encoding one (but not both) of the tensors, if possible.

        NOTE: In the case of BINARY predictions/targets with 2 labels, the result will be the AVERAGE of the Dice
        scores for the two labels. If you want a single score associated with one of the binary labels, use
        ``BinaryDice``.

        Args:
            batch_dim (int | None, optional): If None, counts are aggregated across the batch dimension. If
                specified, counts will be computed along the dimension provided. That is, counts are maintained
                for each training sample INDIVIDUALLY. For example, if ``batch_dim = 1`` and ``label_dim = 0``,
                then

                .. code-block:: python

                    p = torch.tensor(
                        [
                            [[1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
                            [[0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]],
                        ]
                    )
                    t = torch.tensor(
                        [
                            [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
                            [[0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]],
                        ]
                    )

                    self.tp = torch.Tensor([[2, 1], [0, 4]])
                    self.tn = torch.Tensor([[1, 2], [4, 0]])
                    self.fp = torch.Tensor([[1, 0], [0, 0]])
                    self.fn = torch.Tensor([[0, 1], [0, 0]])

                In computing the Dice score, we get scores for each (sample, label) pair as
                [[2*2/(2*2+1+0), 2*1/(2*1+0+1)], [2*0/(2*0+0+0), 2*4/(2*4+0+0)]]. Assuming
                ``zero_division = None``, the undefined calculation at (1, 0) is dropped and the remaining
                individual scores are averaged: (1/3)*(4/5 + 2/3 + 8/8) = 0.8222.
            label_dim (int): Specifies which dimension in the provided tensors corresponds to the label dimension.
                During metric computation, this dimension must have a size of AT LEAST 2 and is required for this
                class.
            name (str): Name of the metric. Defaults to "MultiClassDice".
            dtype (torch.dtype): The dtype used to store the counts. If preds or targets can be continuous,
                specify a float type. Otherwise, specify an integer type to prevent overflow. Defaults to
                ``torch.float32``.
            threshold (float | int | None, optional): A float for thresholding values or an integer specifying the
                index of the label dimension. If a float is given, predictions below the threshold are mapped to 0
                and those above are mapped to 1. If an integer is given, predictions are binarized based on the
                class with the highest prediction, where the specified axis is assumed to contain a prediction for
                each class (its index along that dimension being the class label). The default of None leaves
                preds unchanged.
            ignore_background (int | None): If specified, the FIRST channel of the specified axis is removed prior
                to computing the counts. Useful for removing background classes. Defaults to None.
            zero_division (float | None, optional): Sets what the individual Dice coefficients should be when
                there is a zero division (only true negatives present). How this argument affects the final Dice
                score will vary depending on the Dice scores for other labels. If left as None, the resulting Dice
                coefficients are excluded from the average/final Dice score.
        """
        super().__init__(
            name=name,
            batch_dim=batch_dim,
            label_dim=label_dim,
            dtype=dtype,
            threshold=threshold,
            ignore_background=ignore_background,
            discard={MetricOutcome.TRUE_NEGATIVE},
        )
        self.zero_division = zero_division

    def compute_from_counts(
        self,
        true_positives: torch.Tensor,
        false_positives: torch.Tensor,
        true_negatives: torch.Tensor,
        false_negatives: torch.Tensor,
    ) -> Metrics:
        """
        Computes a multi-class Dice score, defined to be the mean Dice score across all labels in the multi-class
        problem. This score is computed relative to the outcome counts provided in the form of true positives
        (TP), false positives (FP), and false negatives (FN). Because Dice scores don't factor in true negatives,
        that argument is unused.

        For a set of counts, the Dice score for a particular label is 2*TP/(2*TP + FP + FN). For this class,
        counts are assumed to have shape (num_labels,) or (num_samples, num_labels). In the former, a single Dice
        score is computed relative to the counts for each label and then AVERAGED. In the latter, an AVERAGE Dice
        score over both the samples AND labels is computed. The second setting is useful, for example, if you are
        computing the Dice score per image and then averaging. The first setting is useful, for example, if you
        want to treat all examples as a SINGLE image.

        Args:
            true_positives (torch.Tensor): Counts associated with positive predictions and positive labels.
            false_positives (torch.Tensor): Counts associated with positive predictions and negative labels.
            true_negatives (torch.Tensor): Counts associated with negative predictions and negative labels.
            false_negatives (torch.Tensor): Counts associated with negative predictions and positive labels.

        Returns:
            Metrics: A mean Dice score associated with the counts.
        """
        # Compute the Dice coefficients and return their mean.
        dice = compute_dice_on_count_tensors(true_positives, false_positives, false_negatives, self.zero_division)
        if dice.numel() == 0:
            log(WARNING, "Currently, the Dice score is undefined because only true negatives are present")
        return {self.name: torch.mean(dice).item()}

    def __call__(self, input: torch.Tensor, target: torch.Tensor) -> Scalar:
        """
        Computes the Dice score relative to the single input and target tensors provided.

        Args:
            input (torch.Tensor): Predictions tensor.
            target (torch.Tensor): Target tensor.

        Returns:
            Scalar: Mean Dice score for the provided tensors.
        """
        true_positives, false_positives, true_negatives, false_negatives = self.count_tp_fp_tn_fn(input, target)
        dice_metric = self.compute_from_counts(true_positives, false_positives, true_negatives, false_negatives)
        # Extract the scalar from the dictionary.
        return dice_metric[self.name]
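

# A minimal usage sketch (illustrative only; ``_multiclass_dice_example`` is a hypothetical helper, not part of
# this module's API). It reproduces the worked example from the MultiClassDice docstring above, where dim 0 is
# the label dimension and dim 1 is the batch dimension.
def _multiclass_dice_example() -> None:
    metric = MultiClassDice(batch_dim=1, label_dim=0)
    p = torch.tensor(
        [
            [[1.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
            [[0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]],
        ]
    )
    t = torch.tensor(
        [
            [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
            [[0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]],
        ]
    )
    # With zero_division=None (the default), the undefined (sample, label) entry is dropped and the remaining
    # per-(sample, label) scores are averaged: (1/3) * (4/5 + 2/3 + 8/8) ~= 0.8222.
    score = metric(p, t)
    print(score)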


class BinaryDice(BinaryClassificationMetric):
    def __init__(
        self,
        batch_dim: int | None,
        name: str = "BinaryDice",
        label_dim: int | None = None,
        dtype: torch.dtype = torch.float32,
        pos_label: int = 1,
        threshold: float | int | None = None,
        zero_division: float | None = None,
    ) -> None:
        """
        Computes the Dice coefficient between binary predictions and targets. These can be vector encoded or just
        single-element values with an implicit positive class. That is, a single prediction might be vectorized as
        a 2D vector [0.2, 0.8] or given as a float 0.8 (with the complement implied). Regardless of how the input
        is structured, the score will be reported with respect to the value of ``pos_label``, which defaults to 1
        (and can only take values in {0, 1}). That is, the reported score corresponds to the score from the
        perspective of the specified label. For additional documentation, see that of the parent class
        ``BinaryClassificationMetric`` and the function ``_post_process_count_tensor`` therein.

        NOTE: For this class, the predictions and targets passed to the update function MUST have the same shape.

        NOTE: The default behavior for Dice scores is to compute the mean over each SAMPLE of the dataset being
        measured. In the image domain, for example, this means that the Dice score is computed for each image
        separately and then averaged across images (then classes) to produce a single score. This is accomplished
        by specifying ``batch_dim`` here. If, however, you would like to compute the Dice score over ALL TPs, FPs,
        and FNs across all samples (then classes) as a single count, ``batch_dim = None`` is appropriate.

        NOTE: Preds and targets are expected to have elements in the interval [0, 1] or to be binarized as such
        via the ``threshold`` argument of this class.

        Args:
            batch_dim (int | None, optional): If None, counts are aggregated across the batch dimension. If
                specified, counts will be computed along the dimension provided. That is, counts are maintained
                for each training sample INDIVIDUALLY. For example, if ``batch_dim = 1`` and ``label_dim = 0``,
                then

                .. code-block:: python

                    predictions = torch.tensor([[[0, 0, 0, 1], [1, 1, 1, 1]]])  # shape (1, 2, 4)
                    targets = torch.tensor([[[0, 0, 1, 0], [1, 1, 1, 1]]])  # shape (1, 2, 4)

                    self.true_positives = torch.Tensor([[0], [4]])
                    self.true_negatives = torch.Tensor([[2], [0]])
                    self.false_positives = torch.Tensor([[1], [0]])
                    self.false_negatives = torch.Tensor([[1], [0]])

                In computing the Dice score, we get scores for each sample as [[2*0/(2*0+1+1)], [2*4/(2*4+0+0)]].
                These are then averaged to get 0.5.
            name (str): Name of the metric. Defaults to "BinaryDice".
            label_dim (int | None, optional): Specifies which dimension in the provided tensors corresponds to the
                label dimension. During metric computation, this dimension must have a size of AT MOST 2. If left
                as None, this class will assume that each entry in the tensor corresponds to a prediction/target,
                with the positive class indicated by predictions of 1. Defaults to None.
            dtype (torch.dtype): The dtype used to store the counts. If preds or targets can be continuous,
                specify a float type. Otherwise, specify an integer type to prevent overflow. Defaults to
                ``torch.float32``.
            pos_label (int, optional): The label relative to which to report the Dice score. Must be either 0
                or 1. Defaults to 1.
            threshold (float | int | None, optional): A float for thresholding values or an integer specifying the
                index of the label dimension. If a float is given, predictions below the threshold are mapped to 0
                and those above are mapped to 1. If an integer is given, predictions are binarized based on the
                class with the highest prediction, where the specified axis is assumed to contain a prediction for
                each class (its index along that dimension being the class label). A value of None leaves preds
                unchanged. Defaults to None.
            zero_division (float | None, optional): Sets what the individual Dice coefficients should be when
                there is a zero division (only true negatives present). If None, these entries will be dropped.
                If all components are only TNs, then NaN will be returned.
        """
        # The set of counts that can be ignored for the Dice computation depends on the label relative to which
        # we're reporting the score. If reporting relative to the positive label, we need not track true
        # negatives, as they don't factor into the standard Dice score. Conversely, if reporting relative to the
        # negative label, we need not keep true positives around, for the same reason.
        discard = {MetricOutcome.TRUE_NEGATIVE} if pos_label == 1 else {MetricOutcome.TRUE_POSITIVE}
        super().__init__(
            name=name,
            batch_dim=batch_dim,
            label_dim=label_dim,
            dtype=dtype,
            threshold=threshold,
            pos_label=pos_label,
            discard=discard,
        )
        self.zero_division = zero_division

    def compute_from_counts(
        self,
        true_positives: torch.Tensor,
        false_positives: torch.Tensor,
        true_negatives: torch.Tensor,
        false_negatives: torch.Tensor,
    ) -> Metrics:
        """
        Computes a binary Dice score associated with the outcome counts provided in the form of true positives
        (TP), false positives (FP), and false negatives (FN). Because Dice scores don't factor in true negatives,
        that argument is unused.

        For a set of counts, the binary Dice score is 2*TP/(2*TP + FP + FN). For this class, it is assumed that
        all counts are presented relative to the class indicated by the ``pos_label`` index. Moreover, they are
        assumed to either have a single entry or have shape (num_samples, 1). In the former, a single Dice score
        is computed relative to the counts. In the latter, a MEAN Dice score over the samples is computed. The
        second setting is useful, for example, if you are computing the Dice score per image and then averaging.
        The first setting is useful, for example, if you want to treat all examples as a SINGLE image.

        Args:
            true_positives (torch.Tensor): Counts associated with positive predictions and positive labels.
            false_positives (torch.Tensor): Counts associated with positive predictions and negative labels.
            true_negatives (torch.Tensor): Counts associated with negative predictions and negative labels.
            false_negatives (torch.Tensor): Counts associated with negative predictions and positive labels.

        Returns:
            Metrics: A mean Dice score associated with the counts.
        """
        # Compute the Dice coefficients and return their mean.
        dice = compute_dice_on_count_tensors(true_positives, false_positives, false_negatives, self.zero_division)
        if dice.numel() == 0:
            log(WARNING, "Currently, the Dice score is undefined because only true negatives are present")
        return {self.name: torch.mean(dice).item()}

    def __call__(self, input: torch.Tensor, target: torch.Tensor) -> Scalar:
        """
        Computes the Dice score relative to the single input and target tensors provided.

        Args:
            input (torch.Tensor): Predictions tensor.
            target (torch.Tensor): Target tensor.

        Returns:
            Scalar: Mean Dice score for the provided tensors.
        """
        true_positives, false_positives, true_negatives, false_negatives = self.count_tp_fp_tn_fn(input, target)
        dice_metric = self.compute_from_counts(true_positives, false_positives, true_negatives, false_negatives)
        # Extract the scalar from the dictionary.
        return dice_metric[self.name]
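

# A minimal usage sketch (illustrative only; ``_binary_dice_example`` is a hypothetical helper, not part of this
# module's API). It reproduces the worked example from the BinaryDice docstring above, where dim 0 is the label
# dimension (size 1, implicit positive class) and dim 1 is the batch dimension.
def _binary_dice_example() -> None:
    metric = BinaryDice(batch_dim=1, label_dim=0)
    predictions = torch.tensor([[[0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]])  # shape (1, 2, 4)
    targets = torch.tensor([[[0.0, 0.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]]])  # shape (1, 2, 4)
    # Per-sample Dice scores are [2*0/(2*0+1+1), 2*4/(2*4+0+0)] = [0.0, 1.0], which average to 0.5.
    score = metric(predictions, targets)
    print(score)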