# Source code for fl4health.metrics.metrics_utils

import torch


[docs] def compute_dice_on_count_tensors( true_positives: torch.Tensor, false_positives: torch.Tensor, false_negatives: torch.Tensor, zero_division: float | None, ) -> torch.Tensor: """ Given a set of count tensors representing true positives (TP), false positives (FP), and false negatives (FN), compute the Dice score as 2*TP/(2*TP + FP + FN) ELEMENTWISE. The zero division argument determines how to deal with examples with all true negatives, which implies that TP + FP + FN = 0 and an undefined value. Args: true_positives (torch.Tensor): count of true positives in each entry false_positives (torch.Tensor): count of false positives in each entry false_negatives (torch.Tensor): count of false negatives in each entry zero_division (float | None): How to deal with zero division. If None, the values with zero division are simply dropped. If a float is specified, this value is injected into each Dice score that would have been undefined. Returns: torch.Tensor: Dice scores computed for each element in the TP, FP, FN tensors computed ELEMENTWISE with replacement or dropping of undefined entries. The tensor returned is flattened to be 1D. """ # Compute union and intersection numerator = 2 * true_positives # Equivalent to 2 times the intersection denominator = 2 * true_positives + false_positives + false_negatives # Equivalent to the union # Remove or replace dice score that will be null due to zero division if zero_division is None: numerator = numerator[denominator != 0] denominator = denominator[denominator != 0] else: numerator[denominator == 0] = zero_division denominator[denominator == 0] = 1 # Return individual dice coefficients return numerator / denominator
[docs] def threshold_tensor(input: torch.Tensor, threshold: float | int) -> torch.Tensor: """ Converts continuous 'soft' tensors into categorical 'hard' ones. Args: input (torch.Tensor): The tensor to threshold. threshold (float | int): A float for thresholding values or an integer specifying the index of the label dimension. If a float is given, elements below the threshold are mapped to 0 and above are mapped to 1. If an integer is given, elements are thresholded based on the class with the highest prediction. Returns: torch.Tensor: Thresholded tensor """ if isinstance(threshold, float): thresholded_tensor = torch.zeros_like(input) mask_1 = input > threshold thresholded_tensor[mask_1] = 1 return thresholded_tensor elif isinstance(threshold, int): # Use argmax to get predicted class labels (hard_preds) and the one-hot-encode them. if threshold >= input.ndim: raise ValueError( f"Cannot apply argmax to Tensor of shape {input.shape}. " f"Label dimension of {threshold} is out of range of tensor with {input.ndim} dimensions." ) hard_input = input.argmax(threshold, keepdim=True) input = torch.zeros_like(input) input.scatter_(threshold, hard_input, 1) return input else: raise ValueError(f"Was expecting threshold argument to be either a float or an int. Got {type(threshold)}")