fl4health.metrics.metrics_utils module

compute_dice_on_count_tensors(true_positives, false_positives, false_negatives, zero_division)

Given count tensors of true positives (TP), false positives (FP), and false negatives (FN), compute the Dice score 2*TP / (2*TP + FP + FN) elementwise. The zero_division argument determines how to handle entries that consist entirely of true negatives: for these, TP + FP + FN = 0, so the Dice score is undefined (0/0).
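For reference, the per-entry formula with a small worked example. The counts below are illustrative only and do not come from the library:

```latex
\mathrm{Dice} = \frac{2\,\mathrm{TP}}{2\,\mathrm{TP} + \mathrm{FP} + \mathrm{FN}},
\qquad
\mathrm{TP}=3,\ \mathrm{FP}=1,\ \mathrm{FN}=2:\quad
\frac{2 \cdot 3}{2 \cdot 3 + 1 + 2} = \frac{6}{9} \approx 0.667

\mathrm{TP}=\mathrm{FP}=\mathrm{FN}=0:\quad \frac{0}{0}\ \text{(undefined; handled by zero\_division)}
```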

Parameters:
  • true_positives (torch.Tensor) – count of true positives in each entry

  • false_positives (torch.Tensor) – count of false positives in each entry

  • false_negatives (torch.Tensor) – count of false negatives in each entry

  • zero_division (float | None) – How to handle zero division. If None, entries whose Dice score is undefined are dropped from the output. If a float is specified, that value is used in place of each Dice score that would otherwise be undefined.

Returns:

Dice scores computed elementwise from the TP, FP, and FN tensors, with undefined entries either replaced or dropped according to zero_division. The returned tensor is flattened to 1D.

Return type:

torch.Tensor
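A minimal sketch of the computation described above, assuming only PyTorch. The name `dice_from_counts_sketch` is hypothetical and this is not the fl4health implementation; it only illustrates the elementwise formula, the flattening to 1D, and the two zero_division modes (drop vs. fill).

```python
from __future__ import annotations

import torch


def dice_from_counts_sketch(
    true_positives: torch.Tensor,
    false_positives: torch.Tensor,
    false_negatives: torch.Tensor,
    zero_division: float | None,
) -> torch.Tensor:
    # Hypothetical illustration, not the library's code.
    # Flatten so the result is 1D regardless of the input shape.
    tp = true_positives.flatten().float()
    fp = false_positives.flatten().float()
    fn = false_negatives.flatten().float()

    numerator = 2 * tp
    denominator = 2 * tp + fp + fn
    # All-true-negative entries have TP + FP + FN == 0, so Dice is 0/0.
    defined = denominator != 0

    if zero_division is None:
        # Drop undefined entries; the output may be shorter than the input.
        return numerator[defined] / denominator[defined]

    # Otherwise start from the fallback value and fill in the defined entries.
    dice = torch.full_like(denominator, float(zero_division))
    dice[defined] = numerator[defined] / denominator[defined]
    return dice


tp = torch.tensor([3, 0, 1])
fp = torch.tensor([1, 0, 0])
fn = torch.tensor([2, 0, 1])
print(dice_from_counts_sketch(tp, fp, fn, None))  # two entries: the 0/0 entry is dropped
print(dice_from_counts_sketch(tp, fp, fn, 1.0))   # three entries: the 0/0 entry becomes 1.0
```

Filling with a float keeps the output aligned one-to-one with the input entries, whereas dropping changes the length, which matters if the scores are later averaged.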

threshold_tensor(input, threshold)

Converts continuous ‘soft’ tensors into categorical ‘hard’ ones.

Parameters:
  • input (torch.Tensor) – The tensor to threshold.

  • threshold (float | int) – A float for thresholding values, or an integer specifying the index of the label dimension. If a float is given, elements below the threshold are mapped to 0 and elements above it are mapped to 1. If an integer is given, elements are thresholded along that dimension based on the class with the highest prediction, so that only that class is marked 1 and the remaining classes 0.

Returns:

Thresholded tensor

Return type:

torch.Tensor
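A minimal sketch of the two thresholding modes, assuming only PyTorch. The name `threshold_sketch` is hypothetical and this is not the fl4health implementation. Two behaviors are assumptions not stated above: elements exactly equal to a float threshold are mapped to 0, and the integer mode is implemented as an argmax followed by one-hot encoding along the label dimension.

```python
from __future__ import annotations

import torch


def threshold_sketch(tensor: torch.Tensor, threshold: float | int) -> torch.Tensor:
    # Hypothetical illustration, not the library's code.
    if isinstance(threshold, float):
        # Elements above the threshold become 1, all others 0.
        # Assumption: ties at exactly the threshold map to 0.
        return (tensor > threshold).to(tensor.dtype)
    if isinstance(threshold, int):
        # Treat `threshold` as the label dimension: mark the highest-scoring
        # class with 1 and every other class with 0 (one-hot encoding).
        winners = torch.argmax(tensor, dim=threshold, keepdim=True)
        return torch.zeros_like(tensor).scatter_(threshold, winners, 1)
    raise TypeError("threshold must be a float or an int")


probs = torch.tensor([[0.2, 0.7], [0.9, 0.4]])
print(threshold_sketch(probs, 0.5))  # elementwise cut at 0.5

scores = torch.tensor([[0.1, 0.6, 0.3], [0.5, 0.2, 0.3]])
print(threshold_sketch(scores, 1))   # one-hot along dim 1 (the label dimension)
```

Because the type of `threshold` selects the mode, passing `1` (an int) and `1.0` (a float) behave very differently; callers need to be deliberate about which one they pass.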