import io
import os
import signal
import sys
import warnings
from collections.abc import Callable, Sequence
from enum import Enum
from importlib import reload
from logging import DEBUG, INFO, WARN, Logger
from math import ceil
from typing import Any, no_type_check
import numpy as np
import torch
from flwr.common.logger import log
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from fl4health.utils.typing import LogLevel
with warnings.catch_warnings():
# silences a bunch of deprecation warnings related to scipy.ndimage
# Raised an issue with nnunet. https://github.com/MIC-DKFZ/nnUNet/issues/2370
warnings.filterwarnings("ignore", category=DeprecationWarning)
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
class NnunetConfig(Enum):
"""
The possible nnunet model configs as of nnunetv2 version 2.5.1.
See https://github.com/MIC-DKFZ/nnUNet/tree/v2.5.1
"""
_2D = "2d"
_3D_FULLRES = "3d_fullres"
_3D_CASCADE = "3d_cascade_fullres"
_3D_LOWRES = "3d_lowres"
NNUNET_DEFAULT_NP = { # Nnunet's default number of processes for each config
NnunetConfig._2D: 8,
NnunetConfig._3D_FULLRES: 4,
NnunetConfig._3D_CASCADE: 4,
NnunetConfig._3D_LOWRES: 8,
}
NNUNET_N_SPATIAL_DIMS = { # The number of spatial dims for each config
NnunetConfig._2D: 2,
NnunetConfig._3D_FULLRES: 3,
NnunetConfig._3D_CASCADE: 3,
NnunetConfig._3D_LOWRES: 3,
}
def use_default_signal_handlers(fn: Callable) -> Callable:
"""
A decorator that resets the SIGINT and SIGTERM signal handlers back to the Python
defaults for the duration of the wrapped function call.
flwr 1.9.0 overrides the default signal handlers with handlers that raise an error
on any interruption or termination. Since nnunet spawns child processes which
inherit these handlers, when those subprocesses are terminated (which is expected
behavior), the flwr signal handlers raise an error (which we don't want).
Flwr is expected to fix this in the next release. See the following issue:
https://github.com/adap/flower/issues/3837
"""
def new_fn(*args: Any, **kwargs: Any) -> Any:
# Set SIGINT and SIGTERM back to defaults. Method returns previous handler
sigint_old = signal.signal(signal.SIGINT, signal.default_int_handler)
sigterm_old = signal.signal(signal.SIGTERM, signal.SIG_DFL)
# Execute function
output = fn(*args, **kwargs)
# Reset handlers back to what they were before function call
signal.signal(signal.SIGINT, sigint_old)
signal.signal(signal.SIGTERM, sigterm_old)
return output
return new_fn
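# Illustrative sketch (the function below is hypothetical, not part of fl4health): any
# helper that spawns and terminates child processes can be decorated so that the
# default SIGINT/SIGTERM handlers are in effect while it runs.
@use_default_signal_handlers
def _example_spawn_and_join_workers() -> None:
    # Placeholder body; in practice this would start and stop nnunet worker processes.
    pass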
def reload_modules(packages: Sequence[str]) -> None:
"""
Given the names of one or more packages, subpackages or modules, reloads every
already imported module within the scope of each package, or the module itself if a
module name was given.
Args:
packages (Sequence[str]): The absolute names of the packages, subpackages or
modules to reload. The entire import hierarchy must be specified, e.g.
'package.subpackage' to reload all modules in subpackage, not just
'subpackage'. Packages are reloaded in the order they are given.
"""
for m_name, module in list(sys.modules.items()):
for package in packages:
if m_name.startswith(package):
try:
reload(module)
except Exception as e:
log(DEBUG, f"Failed to reload module {m_name}: {e}")
def set_nnunet_env(verbose: bool = False, **kwargs: str) -> None:
"""
For each keyword argument, sets the environment variable with the same name to the
given value and then reloads nnunet. Values must be strings. This is necessary
because nnunet reads some environment variables on import, so it must be imported
or reloaded after they are set.
"""
# Set environment variables
for key, val in kwargs.items():
os.environ[key] = val
if verbose:
log(INFO, f"Resetting env var '{key}' to '{val}'")
# It's necessary to reload nnunetv2.paths first, then the other modules that read env vars
reload_modules(["nnunetv2.paths"])
reload_modules(["nnunetv2.default_n_proc_DA", "nnunetv2.configuration"])
# Reload whatever depends on nnunetv2 environment variables
# Be careful. If you reload something with an enum in it, things get messed up.
reload_modules(["nnunetv2", "fl4health.clients.nnunet_client"])
# The two convert deep supervision methods below are necessary because fl4health requires
# predictions, targets and inputs to be single torch.Tensors or dicts of torch.Tensors
def convert_deep_supervision_list_to_dict(
tensor_list: list[torch.Tensor] | tuple[torch.Tensor, ...], num_spatial_dims: int
) -> dict[str, torch.Tensor]:
"""
Converts a list of torch.Tensors to a dictionary. Names the keys for
each tensor based on the spatial resolution of the tensor and its
index in the list. Useful for nnUNet models with deep supervision where
model outputs and targets loaded by the dataloader are lists. Assumes the
spatial dimensions of the tensors are last.
Args:
tensor_list (list[torch.Tensor] | tuple[torch.Tensor, ...]): A list or tuple of tensors, usually either
nnunet model outputs or targets, to be converted into a dictionary
num_spatial_dims (int): The number of spatial dimensions. Assumes the
spatial dimensions are last
Returns:
dict[str, torch.Tensor]: A dictionary containing the tensors as
values where the keys are 'i-XxYxZ' where i was the tensor's index
in the list and X,Y,Z are the spatial dimensions of the tensor
"""
# Convert list of targets into a dictionary
tensors = {}
for i, tensor in enumerate(tensor_list):
# generate a key based on the spatial dimension and index
key = str(i) + "-" + "x".join([str(s) for s in tensor.shape[-num_spatial_dims:]])
tensors[key] = tensor
return tensors
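# Illustrative sketch: two deep supervision targets at different resolutions become a
# dict keyed by list index and spatial shape, e.g. {'0-64x64': ..., '1-32x32': ...}.
def _example_list_to_dict() -> dict[str, torch.Tensor]:
    targets = [torch.zeros(1, 1, 64, 64), torch.zeros(1, 1, 32, 32)]
    return convert_deep_supervision_list_to_dict(targets, num_spatial_dims=2)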
def convert_deep_supervision_dict_to_list(tensor_dict: dict[str, torch.Tensor]) -> list[torch.Tensor]:
"""
Converts a dictionary of tensors back into a list so that it can be used
by nnunet deep supervision loss functions
Args:
tensor_dict (dict[str, torch.Tensor]): Dictionary containing
torch.Tensors. The keys must start with 'X-' where X is an
integer representing the index at which the tensor should be placed
in the output list
Returns:
list[torch.Tensor]: A list of torch.Tensors
"""
sorted_list = sorted(tensor_dict.items(), key=lambda x: int(x[0].split("-")[0]))
return [tensor for key, tensor in sorted_list]
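# Illustrative sketch: keys are sorted by their integer prefix, so converting back to a
# list restores the original ordering regardless of dict insertion order.
def _example_dict_to_list() -> list[torch.Tensor]:
    tensor_dict = {"1-32x32": torch.zeros(1, 1, 32, 32), "0-64x64": torch.zeros(1, 1, 64, 64)}
    return convert_deep_supervision_dict_to_list(tensor_dict)  # 64x64 tensor comes first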
def get_segs_from_probs(preds: torch.Tensor, has_regions: bool = False, threshold: float = 0.5) -> torch.Tensor:
"""
Converts the nnunet model output probabilities to predicted segmentations
Args:
preds (torch.Tensor): The one hot encoded model output probabilities
with shape (batch, classes, *additional_dims). The background should be a separate class
has_regions (bool, optional): If True, predicted segmentations can be
multiple classes at once. The exception is the background class
which is assumed to be the first class (class 0). If False, each
value in predicted segmentations has only a single class. Defaults to
False.
threshold (float): When has_regions is True, this is the threshold
value used to determine whether or not an output is a part of a
class. Defaults to 0.5.
Returns:
torch.Tensor: tensor containing the predicted segmentations as a one hot encoded
binary tensor of 64-bit integers.
"""
if has_regions:
pred_segs = preds > threshold
# Mask is the inverse of the background class. Ensures that values
# classified as background are not part of another class
mask = ~pred_segs[:, 0]
return pred_segs * mask
else:
pred_segs = preds.argmax(1)[:, None] # shape (batch, 1, *additional_dims)
# one hot encode (OHE) predicted segmentations again
# WARNING: Note the '_' after scatter. scatter_ and scatter are both
# functions with different functionality. It is easy to introduce a bug
# here by using the wrong one
pred_segs_one_hot = torch.zeros(preds.shape, device=preds.device, dtype=torch.float32)
pred_segs_one_hot.scatter_(1, pred_segs, 1) # ohe -> One Hot Encoded
# convert output preds to long since it is binary
return pred_segs_one_hot.long()
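# Illustrative sketch: turning softmax probabilities for a 3 class 2d problem into one
# hot encoded segmentations. The shapes here are placeholders.
def _example_get_segs_from_probs() -> torch.Tensor:
    probs = torch.softmax(torch.randn(2, 3, 16, 16), dim=1)  # (batch, classes, H, W)
    return get_segs_from_probs(probs)  # (2, 3, 16, 16) one hot tensor of 64-bit integers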
def collapse_one_hot_tensor(input: torch.Tensor, dim: int = 0) -> torch.Tensor:
"""
Collapses a one hot encoded tensor so that they are no longer one hot encoded.
Args:
input (torch.Tensor): The binary one hot encoded tensor
dim (int, optional): The dimension along which to collapse, i.e. the one hot class dimension. Defaults to 0.
Returns:
torch.Tensor: Integer tensor with the specified dim collapsed
"""
return torch.argmax(input.long(), dim=dim).to(input.device)
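# Illustrative sketch: collapsing one hot rows back into integer class labels along the
# class dimension.
def _example_collapse_one_hot() -> torch.Tensor:
    one_hot = torch.tensor([[0, 1, 0], [1, 0, 0]])
    return collapse_one_hot_tensor(one_hot, dim=1)  # tensor([1, 0])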
def get_dataset_n_voxels(source_plans: dict, n_cases: int) -> float:
"""
Determines the total number of voxels in the dataset. Used by NnunetClient to
determine the maximum batch size.
Args:
source_plans (dict): The nnunet plans dict that is being modified
n_cases (int): The number of cases in the dataset
Returns:
float: The total number of voxels in the local client dataset
"""
# Need to determine input dimensionality
if NnunetConfig._3D_FULLRES.value in source_plans["configurations"]:
cfg = source_plans["configurations"][NnunetConfig._3D_FULLRES.value]
else:
cfg = source_plans["configurations"][NnunetConfig._2D.value]
# Get total number of voxels in dataset
image_shape = cfg["median_image_size_in_voxels"]
approx_n_voxels = float(np.prod(image_shape, dtype=np.float64) * n_cases)
return approx_n_voxels
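# Illustrative sketch: a minimal plans dict containing only the fields this function
# reads. Real nnunet plans files contain many more entries.
def _example_get_dataset_n_voxels() -> float:
    plans = {"configurations": {"3d_fullres": {"median_image_size_in_voxels": [128, 128, 64]}}}
    return get_dataset_n_voxels(plans, n_cases=10)  # 128 * 128 * 64 * 10 voxels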
def prepare_loss_arg(tensor: torch.Tensor | dict[str, torch.Tensor]) -> torch.Tensor | list[torch.Tensor]:
"""
Converts pred and target tensors into the proper data type to be passed to the nnunet loss functions.
Args:
tensor (torch.Tensor | dict[str, torch.Tensor]): The input tensor
Returns:
torch.Tensor | list[torch.Tensor]: The tensor ready to be passed to the loss
function. A single tensor if not using deep supervision and a list of
tensors if deep supervision is on.
"""
# TODO: Consider tracking whether deep supervision is on with an explicit boolean instead of inferring it from the input type
if isinstance(tensor, torch.Tensor):
return tensor # If input is a tensor then no changes required
elif isinstance(tensor, dict):
if len(tensor) > 1: # Assume deep supervision is on and return a list
return convert_deep_supervision_dict_to_list(tensor)
else: # If dict has only one item, assume deep supervision is off
return list(tensor.values())[0] # return the torch.Tensor
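# Illustrative sketch: a bare tensor passes through unchanged, while a dict with more
# than one entry (deep supervision on) is converted back into an ordered list.
def _example_prepare_loss_arg() -> None:
    single = prepare_loss_arg(torch.zeros(1, 2, 8, 8))  # returns the tensor as is
    multi = prepare_loss_arg({"0-8x8": torch.zeros(1, 2, 8, 8), "1-4x4": torch.zeros(1, 2, 4, 4)})
    assert isinstance(single, torch.Tensor) and isinstance(multi, list)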
class nnUNetDataLoaderWrapper(DataLoader):
def __init__(
self,
nnunet_augmenter: SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter,
nnunet_config: NnunetConfig | str,
infinite: bool = False,
) -> None:
"""
Wraps nnunet dataloader classes using the pytorch dataloader to make them pytorch compatible. Also handles
some unique stuff specific to nnunet such as deep supervision and infinite dataloaders. The nnunet dataloaders
should only be used for training and validation, not final testing.
Args:
nnunet_augmenter (SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter): The
augmenter, which wraps the dataloader used by nnunet
nnunet_config (NnunetConfig | str): The nnunet config. Enum type helps ensure that the nnunet config is valid
infinite (bool, optional): Whether or not to treat the dataset as infinite. The dataloaders sample data
with replacement either way. The only difference is that if set to False, a StopIteration is
raised after num_samples/batch_size steps. Defaults to False.
"""
# The augmenter is a wrapper on the nnunet dataloader
self.nnunet_augmenter = nnunet_augmenter
if isinstance(self.nnunet_augmenter, SingleThreadedAugmenter):
self.nnunet_dataloader = self.nnunet_augmenter.data_loader
else:
self.nnunet_dataloader = self.nnunet_augmenter.generator
# Figure out if dataloader is 2d or 3d
self.num_spatial_dims = NNUNET_N_SPATIAL_DIMS[NnunetConfig(nnunet_config)]
# nnUNetDataloaders store their datasets under the self.data attribute
self.dataset: nnUNetDataset = self.nnunet_dataloader._data
super().__init__(dataset=self.dataset, batch_size=self.nnunet_dataloader.batch_size)
# nnunet dataloaders are infinite by default so we have to track steps to stop iteration
self.current_step = 0
self.infinite = infinite
def __next__(self) -> tuple[torch.Tensor, torch.Tensor | dict[str, torch.Tensor]]:
if not self.infinite and self.current_step == self.__len__():
self.reset()
raise StopIteration # Raise stop iteration after epoch has completed
else:
self.current_step += 1
batch = next(self.nnunet_augmenter) # This returns a dictionary
# Note: When deep supervision is on, target is a list of ground truth
# segmentations at various spatial scales/resolutions
# nnUNet has a wrapper for loss functions to enable deep supervision
inputs: torch.Tensor = batch["data"]
targets: torch.Tensor | list[torch.Tensor] = batch["target"]
if isinstance(targets, list):
target_dict = convert_deep_supervision_list_to_dict(targets, self.num_spatial_dims)
return inputs, target_dict
elif isinstance(targets, torch.Tensor):
return inputs, targets
else:
raise TypeError(
"Was expecting the target generated by the nnunet dataloader to be a list or a torch.Tensor"
)
def __len__(self) -> int:
"""
nnunetv2 v2.5.1 hardcodes an 'epoch' as 250 steps. We could set the len to
n_samples/batch_size, but this gets complicated as nnunet models operate on
patches of the input images, and therefore can have batch sizes larger than the
dataset. We would then have epochs with only 1 step!
Here we go through the hassle of computing the ratio between the number of
voxels in a sample and the number of voxels in a patch and then using that
factor to scale n_samples. This is particularly important for training 2d
models on 3d data.
"""
sample, _, _ = self.dataset.load_case(self.nnunet_dataloader.indices[0])
n_image_voxels = np.prod(sample.shape)
n_patch_voxels = np.prod(self.nnunet_dataloader.final_patch_size)
# Scale factor is at least one to prevent shrinking the dataset. We can have a
# larger patch size sometimes because nnunet will do padding
scale = max(n_image_voxels / n_patch_voxels, 1)
# Scale n_samples and then divide by batch size to get n_steps per epoch
return round((len(self.dataset) * scale) / self.nnunet_dataloader.batch_size)
def reset(self) -> None:
self.current_step = 0
def __iter__(self) -> DataLoader: # type: ignore
# mypy gets angry that the return type is different
return self
def shutdown(self) -> None:
"""
The multithreaded augmenters used by nnunet need to be shut down gracefully to
avoid errors
"""
if isinstance(self.nnunet_augmenter, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)):
self.nnunet_augmenter._finish()
else:
del self.nnunet_augmenter
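# Illustrative sketch (requires a real nnunet augmenter to run): the wrapper is iterated
# like any pytorch DataLoader and should be shut down once training is finished.
def _example_wrap_nnunet_loader(augmenter: SingleThreadedAugmenter) -> None:
    loader = nnUNetDataLoaderWrapper(augmenter, NnunetConfig._2D)
    for inputs, targets in loader:  # targets is a dict when deep supervision is on
        pass  # one epoch of batches, as determined by __len__
    loader.shutdown()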
class Module2LossWrapper(_Loss):
"""Converts a nn.Module subclass to a _Loss subclass"""
def __init__(self, loss: nn.Module, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.loss = loss
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return self.loss(pred, target)
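# Illustrative sketch: wrapping a plain nn.Module criterion so it can be used where a
# _Loss instance is expected.
def _example_module_to_loss() -> torch.Tensor:
    criterion = Module2LossWrapper(nn.CrossEntropyLoss())
    logits, labels = torch.randn(4, 3), torch.randint(0, 3, (4,))
    return criterion(logits, labels)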
class StreamToLogger(io.StringIO):
def __init__(self, logger: Logger, level: LogLevel | int) -> None:
"""
File-like stream object that redirects writes to a logger. Useful for redirecting stdout to a logger.
Args:
logger (Logger): The logger to redirect writes to
level (LogLevel | int): The log level at which to redirect the writes
"""
self.logger = logger
self.level = level if isinstance(level, int) else level.value
self.linebuf = "" # idk why this is needed. Got this class from stack overflow
def write(self, buf: str) -> int:
char_count = 0
for line in buf.rstrip().splitlines():
self.logger.log(self.level, line.rstrip())
char_count += len(line.rstrip())
return char_count
def flush(self) -> None:
pass
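# Illustrative sketch: redirecting anything printed to stdout into a logger at INFO
# level, e.g. to capture the console output of nnunet helpers.
def _example_stream_to_logger() -> None:
    import contextlib
    import logging

    with contextlib.redirect_stdout(StreamToLogger(logging.getLogger("example"), INFO)):
        print("this line is logged rather than printed")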
class PolyLRSchedulerWrapper(_LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
initial_lr: float,
max_steps: int,
exponent: float = 0.9,
steps_per_lr: int = 250,
) -> None:
"""
Learning rate (LR) scheduler with polynomial decay across fixed windows of size steps_per_lr.
Args:
optimizer (Optimizer): The optimizer to apply LR scheduler to.
initial_lr (float): The initial learning rate of the optimizer.
max_steps (int): The maximum total number of steps across all FL rounds.
exponent (float): Controls how quickly LR decreases over time. Higher values
lead to more rapid descent. Defaults to 0.9.
steps_per_lr (int): The number of steps per LR before decaying
(i.e. 10 means the LR will be constant for 10 steps prior to being decreased to the subsequent value).
Defaults to 250, as that is the nnunet default (LR decays once per epoch and an epoch is 250 steps).
"""
self.optimizer = optimizer
self.initial_lr = initial_lr
self.max_steps = max_steps
self.exponent = exponent
self.steps_per_lr = steps_per_lr
# Number of windows with constant LR across training
self.num_windows = ceil(max_steps / self.steps_per_lr)
self._step_count: int
super().__init__(optimizer, -1, False)
# mypy incorrectly infers get_lr returns a float
# Documented issue https://github.com/pytorch/pytorch/issues/100804
@no_type_check
def get_lr(self) -> Sequence[float]:
"""
Get the current LR of the scheduler.
Returns:
Sequence[float]: A uniform sequence of LR for each of the parameter groups in the optimizer.
"""
if self._step_count - 1 == self.max_steps + 1:
log(
WARN,
f"Current LR step of {self._step_count} reached Max Steps of {self.max_steps}. LR will remain fixed.",
)
# Subtract 1 from step count since it starts at 1 (imposed by PyTorch)
curr_step = min(self._step_count - 1, self.max_steps)
curr_window = int(curr_step / self.steps_per_lr)
new_lr = self.initial_lr * (1 - curr_window / self.num_windows) ** self.exponent
if curr_step % self.steps_per_lr == 0 and curr_step != 0 and curr_step != self.max_steps:
log(INFO, f"Decaying LR of optimizer to {new_lr} at step {curr_step}")
return [new_lr] * len(self.optimizer.param_groups)
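# Illustrative sketch: a scheduler that decays the LR every 250 steps over 1000 total
# steps, attached to a throwaway SGD optimizer. After 250 steps the LR becomes
# 0.01 * (1 - 1 / 4) ** 0.9.
def _example_poly_lr_scheduler() -> None:
    optimizer = torch.optim.SGD([nn.Parameter(torch.zeros(1))], lr=0.01)
    scheduler = PolyLRSchedulerWrapper(optimizer, initial_lr=0.01, max_steps=1000, steps_per_lr=250)
    for _ in range(250):
        optimizer.step()
        scheduler.step()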