fl4health.metrics.metrics_utils module
- compute_dice_on_count_tensors(true_positives, false_positives, false_negatives, zero_division)
Given count tensors of true positives (TP), false positives (FP), and false negatives (FN), compute the Dice score ELEMENTWISE as 2*TP/(2*TP + FP + FN). The zero_division argument determines how to handle entries consisting only of true negatives, for which TP + FP + FN = 0 and the Dice score is therefore undefined.
- Parameters:
true_positives (torch.Tensor) – count of true positives in each entry
false_positives (torch.Tensor) – count of false positives in each entry
false_negatives (torch.Tensor) – count of false negatives in each entry
zero_division (float | None) – How to deal with zero division. If None, the values with zero division are simply dropped. If a float is specified, this value is injected into each Dice score that would have been undefined.
- Returns:
- Dice scores for each element of the TP, FP, and FN tensors, computed ELEMENTWISE, with undefined entries either
replaced or dropped. The returned tensor is flattened to 1D.
- Return type:
torch.Tensor
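The behavior described above can be illustrated with a minimal sketch. This is not the library's implementation: `dice_from_counts_sketch` is a hypothetical stand-in written only from the documented formula and the documented `zero_division` semantics, and the example count tensors are made up.

```python
from typing import Optional

import torch


def dice_from_counts_sketch(
    true_positives: torch.Tensor,
    false_positives: torch.Tensor,
    false_negatives: torch.Tensor,
    zero_division: Optional[float],
) -> torch.Tensor:
    # Hypothetical helper for illustration; not the fl4health implementation.
    # Flatten first so the result is 1D, as the documentation describes.
    tp = true_positives.flatten().to(torch.float32)
    fp = false_positives.flatten().to(torch.float32)
    fn = false_negatives.flatten().to(torch.float32)

    numerator = 2 * tp
    denominator = 2 * tp + fp + fn
    undefined = denominator == 0  # entries with TP + FP + FN = 0

    if zero_division is None:
        # Drop entries whose Dice score would be undefined.
        return numerator[~undefined] / denominator[~undefined]

    # Divide by 1 where the denominator is 0, then overwrite those entries.
    safe_denominator = torch.where(undefined, torch.ones_like(denominator), denominator)
    scores = numerator / safe_denominator
    return torch.where(undefined, torch.full_like(scores, zero_division), scores)


# Made-up counts; the second entry (0, 0, 0) has an undefined Dice score.
tp = torch.tensor([[3, 0], [1, 0]])
fp = torch.tensor([[1, 0], [0, 2]])
fn = torch.tensor([[1, 0], [1, 0]])

dropped = dice_from_counts_sketch(tp, fp, fn, zero_division=None)
replaced = dice_from_counts_sketch(tp, fp, fn, zero_division=1.0)
print(dropped)   # three scores: the undefined entry is removed
print(replaced)  # four scores: the undefined entry is set to 1.0
```

Note that with `zero_division=None` the output can be shorter than the flattened input, so its entries no longer line up one-to-one with the input positions.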
- threshold_tensor(input, threshold)
Converts continuous ‘soft’ tensors into categorical ‘hard’ ones.
- Parameters:
input (torch.Tensor) – The tensor to threshold.
threshold (float | int) – A float for thresholding values or an integer specifying the index of the label dimension. If a float is given, elements below the threshold are mapped to 0 and elements above it are mapped to 1. If an integer is given, elements are thresholded based on the class with the highest prediction along that dimension.
- Returns:
Thresholded tensor
- Return type:
torch.Tensor
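The two thresholding modes can be sketched as follows. This is an illustrative approximation, not the library's code: `threshold_tensor_sketch` is a hypothetical name, and two details are assumptions that the documentation does not state: that values exactly equal to a float threshold map to 0, and that the integer mode returns a one-hot tensor of the same shape as the input.

```python
from typing import Union

import torch


def threshold_tensor_sketch(input: torch.Tensor, threshold: Union[float, int]) -> torch.Tensor:
    # Hypothetical helper for illustration; not the fl4health implementation.
    if isinstance(threshold, float):
        # Assumption: values exactly equal to the threshold map to 0.
        return (input > threshold).to(input.dtype)

    # Integer case: treat `threshold` as the index of the label dimension.
    # Assumption: the result is one-hot along that dimension.
    winners = input.argmax(dim=threshold, keepdim=True)
    hard = torch.zeros_like(input)
    hard.scatter_(threshold, winners, 1.0)
    return hard


# Made-up soft predictions: 2 samples, 3 classes along dimension 1.
soft = torch.tensor([[0.40, 0.35, 0.25], [0.60, 0.30, 0.10]])

by_value = threshold_tensor_sketch(soft, 0.5)  # float: compare each element to 0.5
by_argmax = threshold_tensor_sketch(soft, 1)   # int: highest class along dim 1
print(by_value)
print(by_argmax)
```

In this example the first row has no element above 0.5, so the float mode yields an all-zero row, while the integer mode still selects the highest-scoring class.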