from abc import ABC, abstractmethod
import numpy as np
import torch
from flwr.common.typing import Metrics, Scalar
from sklearn import metrics as sklearn_metrics
from torchmetrics import Metric as TMetric
from fl4health.metrics.base_metrics import Metric

class TorchMetric(Metric):
    def __init__(self, name: str, metric: TMetric) -> None:
        """
        Thin wrapper on a ``torchmetrics`` metric (``TMetric``) to make it compatible with our ``Metric``
        interface.

        Args:
            name (str): The name of the metric.
            metric (TMetric): The ``TorchMetric``-based metric to wrap.
        """
        super().__init__(name)
        self.metric = metric
    def update(self, input: torch.Tensor, target: torch.Tensor) -> None:
        """
        Updates the state of the underlying ``TorchMetric``.

        Args:
            input (torch.Tensor): The predictions of the model to be evaluated.
            target (torch.Tensor): The ground truth target to evaluate predictions against.
        """
        self.metric.update(input, target.long())
    def compute(self, name: str | None = None) -> Metrics:
        """
        Compute the value of the underlying ``TorchMetric``.

        Args:
            name (str | None): Optional name used in conjunction with the class attribute name to define the key
                in the metrics dictionary.

        Returns:
            Metrics: A dictionary of string and ``Scalar`` representing the computed metric and its associated key.
        """
        result_key = f"{name} - {self.name}" if name is not None else self.name
        result = self.metric.compute().item()
        return {result_key: result}
    def clear(self) -> None:
        """
        Resets the underlying ``TorchMetric``.
        """
        self.metric.reset()
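

# Illustrative usage sketch (not part of the library API): wrapping a ``torchmetrics`` metric with
# ``TorchMetric``. This assumes ``torchmetrics`` provides ``BinaryAccuracy`` under
# ``torchmetrics.classification``; the tensors below are toy data.
def _example_torch_metric() -> None:
    from torchmetrics.classification import BinaryAccuracy

    metric = TorchMetric(name="torch_accuracy", metric=BinaryAccuracy())
    # Three probability predictions and their binary labels.
    metric.update(torch.tensor([0.9, 0.2, 0.7]), torch.tensor([1, 0, 1]))
    print(metric.compute("val"))  # {"val - torch_accuracy": 1.0}
    metric.clear()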

class SimpleMetric(Metric, ABC):
    def __init__(self, name: str) -> None:
        """
        Abstract metric class with base functionality to update, compute and clear metrics. Subclasses need to
        define the ``__call__`` method, which returns the metric value given inputs and targets.

        Args:
            name (str): Name of the metric.
        """
        super().__init__(name)
        self.accumulated_inputs: list[torch.Tensor] = []
        self.accumulated_targets: list[torch.Tensor] = []
    def update(self, input: torch.Tensor, target: torch.Tensor) -> None:
        """
        Updates the state of the metric by appending the passed input and target pairing to their respective
        accumulation lists.

        Args:
            input (torch.Tensor): The predictions of the model to be evaluated.
            target (torch.Tensor): The ground truth target to evaluate predictions against.
        """
        self.accumulated_inputs.append(input)
        self.accumulated_targets.append(target)
    def compute(self, name: str | None = None) -> Metrics:
        """
        Compute the metric on the inputs and targets accumulated over updates.

        Args:
            name (str | None): Optional name used in conjunction with the class attribute name to define the key
                in the metrics dictionary.

        Raises:
            AssertionError: If the input and target lists are empty.

        Returns:
            Metrics: A dictionary of string and ``Scalar`` representing the computed metric and its associated key.
        """
        assert len(self.accumulated_inputs) > 0 and len(self.accumulated_targets) > 0, (
            "Input and target lists must be non-empty. Has update been called?"
        )
        stacked_inputs = torch.cat(self.accumulated_inputs)
        stacked_targets = torch.cat(self.accumulated_targets)
        result = self.__call__(stacked_inputs, stacked_targets)
        result_key = f"{name} - {self.name}" if name is not None else self.name
        return {result_key: result}
    def clear(self) -> None:
        """
        Resets the metric by clearing the accumulated input and target lists.
        """
        self.accumulated_inputs = []
        self.accumulated_targets = []
    @abstractmethod
    def __call__(self, input: torch.Tensor, target: torch.Tensor) -> Scalar:
        """
        User-defined method that calculates the desired metric given the predictions and target.

        Raises:
            NotImplementedError: To be implemented by the subclass.
        """
        raise NotImplementedError
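

# Illustrative sketch: subclassing ``SimpleMetric`` only requires implementing ``__call__``; accumulation across
# ``update`` calls and key construction in ``compute`` are inherited. ``_MeanAbsoluteError`` is a hypothetical
# example metric, not part of the library.
class _MeanAbsoluteError(SimpleMetric):
    def __init__(self, name: str = "MAE") -> None:
        super().__init__(name)

    def __call__(self, input: torch.Tensor, target: torch.Tensor) -> Scalar:
        # Mean absolute difference between predictions and targets.
        return torch.mean(torch.abs(input - target)).item()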

class BinarySoftDiceCoefficient(SimpleMetric):
    def __init__(
        self,
        name: str = "BinarySoftDiceCoefficient",
        epsilon: float = 1.0e-7,
        spatial_dimensions: tuple[int, ...] = (2, 3, 4),
        logits_threshold: float | None = 0.5,
    ):
        """
        Binary DICE Coefficient Metric with configurable spatial dimensions and logits threshold.

        Args:
            name (str): Name of the metric.
            epsilon (float): Small float added to the denominator of the DICE calculation to avoid division by 0.
            spatial_dimensions (tuple[int, ...]): The spatial dimensions of the image within the prediction
                tensors. The default assumes that the images are 3D and have shape
                (``batch_size``, channel, spatial, spatial, spatial).
            logits_threshold (float | None): Threshold above which values are classified as 1 and below which
                they are mapped to 0. If the threshold is None, no thresholding is performed and a continuous, or
                "soft", DICE coefficient is computed.
        """
        self.epsilon = epsilon
        self.spatial_dimensions = spatial_dimensions
        self.logits_threshold = logits_threshold
        super().__init__(name)
    def __call__(self, logits: torch.Tensor, target: torch.Tensor) -> Scalar:
        # Assuming the logits are to be mapped to binary. Note that this assumes the logits have already been
        # constrained to [0, 1]. The metric still functions if not, but results will be unpredictable.
        # Compare against None explicitly so that a threshold of 0.0 is still applied.
        if self.logits_threshold is not None:
            y_pred = (logits > self.logits_threshold).int()
        else:
            y_pred = logits
        intersection = (y_pred * target).sum(dim=self.spatial_dimensions)
        union = (0.5 * (y_pred + target)).sum(dim=self.spatial_dimensions)
        dice = intersection / (union + self.epsilon)
        # If both prediction and target are empty, the DICE coefficient should equal 1.
        dice[union == 0] = 1
        return torch.mean(dice).item()
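

# Illustrative sketch: the DICE coefficient on a toy 3D volume of shape (batch, channel, depth, height, width).
# With the default threshold of 0.5, the predictions are binarized before overlap is measured. All values are
# toy data.
def _example_dice() -> None:
    dice = BinarySoftDiceCoefficient()
    preds = torch.zeros(1, 1, 2, 2, 2)
    preds[0, 0, 0] = 0.9  # predicted foreground: the whole first slice (4 voxels)
    targets = torch.zeros(1, 1, 2, 2, 2)
    targets[0, 0, 0, 0] = 1.0  # true foreground: first row of that slice (2 voxels)
    dice.update(preds, targets)
    # intersection = 2, union = 0.5 * (4 + 2) = 3, so DICE is roughly 2/3.
    print(dice.compute())  # {"BinarySoftDiceCoefficient": ~0.667}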

class Accuracy(SimpleMetric):
    def __init__(self, name: str = "accuracy"):
        """
        Accuracy metric for classification tasks.

        Args:
            name (str): The name of the metric.
        """
        super().__init__(name)
    def __call__(self, logits: torch.Tensor, target: torch.Tensor) -> Scalar:
        # Assuming batch-first tensors.
        assert logits.shape[0] == target.shape[0]
        # Single value output, assume binary logits thresholded at 0.5.
        if len(logits.shape) == 1 or logits.shape[1] == 1:
            preds = (logits > 0.5).int()
        else:
            preds = torch.argmax(logits, 1)
        target = target.cpu().detach()
        preds = preds.cpu().detach()
        return sklearn_metrics.accuracy_score(target, preds)
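

# Illustrative sketch: ``Accuracy`` handles both binary logits (a single column, thresholded at 0.5) and
# multi-class logits (argmax over dimension 1). Toy data.
def _example_accuracy() -> None:
    accuracy = Accuracy()
    logits = torch.tensor([[2.0, 0.5], [0.1, 1.5], [3.0, 0.2]])  # 3 samples, 2 classes
    targets = torch.tensor([0, 1, 0])
    accuracy.update(logits, targets)
    print(accuracy.compute())  # {"accuracy": 1.0}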

class BalancedAccuracy(SimpleMetric):
    def __init__(self, name: str = "balanced_accuracy"):
        """
        Balanced accuracy metric for classification tasks. Used for the evaluation of imbalanced datasets. For
        more information:

        https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html

        Args:
            name (str): The name of the metric.
        """
        super().__init__(name)
    def __call__(self, logits: torch.Tensor, target: torch.Tensor) -> Scalar:
        # Assuming batch-first tensors.
        assert logits.shape[0] == target.shape[0]
        target = target.cpu().detach()
        logits = logits.cpu().detach()
        y_true = target.reshape(-1)
        preds = np.argmax(logits, axis=1)
        return sklearn_metrics.balanced_accuracy_score(y_true, preds)
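

# Illustrative sketch: on an imbalanced toy batch, balanced accuracy averages per-class recall, so performance
# on the rare class counts as much as performance on the common class. Toy data.
def _example_balanced_accuracy() -> None:
    balanced_accuracy = BalancedAccuracy()
    logits = torch.tensor([[2.0, 0.0], [1.5, 0.2], [0.3, 1.0], [0.1, 1.2]])
    targets = torch.tensor([0, 0, 0, 1])  # predictions are [0, 0, 1, 1]
    balanced_accuracy.update(logits, targets)
    # Recall is 2/3 for class 0 and 1 for class 1: (2/3 + 1) / 2, roughly 0.833,
    # versus a plain accuracy of 0.75.
    print(balanced_accuracy.compute())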

class ROC_AUC(SimpleMetric):
    def __init__(self, name: str = "ROC_AUC score"):
        """
        Area under the Receiver Operating Characteristic curve (AUC-ROC) metric for classification. For more
        information:

        https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html

        Args:
            name (str): The name of the metric.
        """
        super().__init__(name)
    def __call__(self, logits: torch.Tensor, target: torch.Tensor) -> Scalar:
        # Assuming batch-first tensors.
        assert logits.shape[0] == target.shape[0]
        prob = torch.nn.functional.softmax(logits, dim=1)
        prob = prob.cpu().detach()
        target = target.cpu().detach()
        y_true = target.reshape(-1)
        return sklearn_metrics.roc_auc_score(y_true, prob, average="weighted", multi_class="ovr")
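

# Illustrative sketch: ``ROC_AUC`` applies a softmax over dimension 1 and scores the resulting class
# probabilities one-vs-rest with weighted averaging. A small 3-class toy batch:
def _example_roc_auc() -> None:
    roc_auc = ROC_AUC()
    logits = torch.tensor([[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]])
    targets = torch.tensor([0, 1, 2])
    roc_auc.update(logits, targets)
    print(roc_auc.compute())  # {"ROC_AUC score": 1.0}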

class F1(SimpleMetric):
    def __init__(
        self,
        name: str = "F1 score",
        average: str | None = "weighted",
    ):
        """
        Computes the F1 score using the sklearn ``f1_score`` function. As such, the values of ``average``
        correspond to those of that function.

        Args:
            name (str, optional): Name of the metric. Defaults to "F1 score".
            average (str | None, optional): Whether to perform averaging of the F1 scores and how. The values of
                this string correspond to those of the sklearn ``f1_score`` function. See:

                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html

                Defaults to "weighted".
        """
        super().__init__(name)
        self.average = average
    def __call__(self, logits: torch.Tensor, target: torch.Tensor) -> Scalar:
        # Assuming batch-first tensors.
        assert logits.shape[0] == target.shape[0]
        target = target.cpu().detach()
        logits = logits.cpu().detach()
        y_true = target.reshape(-1)
        preds = np.argmax(logits, axis=1)
        return sklearn_metrics.f1_score(y_true, preds, average=self.average)
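

# Illustrative sketch: ``F1`` with the default weighted average on a small multi-class toy batch, including a
# prefixed key via the optional ``name`` argument to ``compute``.
def _example_f1() -> None:
    f1 = F1()
    logits = torch.tensor([[2.0, 0.1, 0.0], [0.0, 1.5, 0.3], [0.2, 0.0, 1.0]])
    targets = torch.tensor([0, 1, 2])
    f1.update(logits, targets)
    print(f1.compute("test"))  # {"test - F1 score": 1.0}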