fl4health.utils.nnunet_utils module
- class Module2LossWrapper(loss, **kwargs)
Bases: _Loss
Converts an nn.Module subclass into a _Loss subclass.
- forward(pred, target)
Define the computation performed at every call. Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class NnunetConfig(value)
Bases: Enum
The possible nnunet model configs as of nnunetv2 version 2.5.1. See https://github.com/MIC-DKFZ/nnUNet/tree/v2.5.1
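For illustration, an Enum of this kind can be sketched as below. The member names are hypothetical and may differ from the real NnunetConfig; the string values are the configuration identifiers used by nnunetv2. Looking a member up by value is what makes the Enum useful for validation: an unknown configuration string raises a ValueError instead of propagating silently.

```python
from enum import Enum


class NnunetConfigSketch(Enum):
    # Hypothetical member names; values are nnunetv2 configuration identifiers
    TWO_D = "2d"
    THREE_D_FULLRES = "3d_fullres"
    THREE_D_LOWRES = "3d_lowres"
    THREE_D_CASCADE_FULLRES = "3d_cascade_fullres"


def parse_config(name: str) -> NnunetConfigSketch:
    # Enum lookup by value raises ValueError for unknown configuration names
    return NnunetConfigSketch(name)
```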
- class PolyLRSchedulerWrapper(optimizer, initial_lr, max_steps, exponent=0.9, steps_per_lr=250)
Bases: _LRScheduler
- __init__(optimizer, initial_lr, max_steps, exponent=0.9, steps_per_lr=250)
Learning rate (LR) scheduler with polynomial decay across fixed windows of size steps_per_lr.
- Parameters:
optimizer (Optimizer) – The optimizer to apply the LR scheduler to.
initial_lr (float) – The initial learning rate of the optimizer.
max_steps (int) – The maximum total number of steps across all FL rounds.
exponent (float) – Controls how quickly the LR decreases over time. Higher values lead to more rapid decay. Defaults to 0.9.
steps_per_lr (int) – The number of steps for which each LR value is held before decaying (i.e., 10 means the LR stays constant for 10 steps before decreasing to the next value). Defaults to 250, the nnunet default (nnunet decays the LR once per epoch, and an epoch is 250 steps).
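The windowed decay can be sketched as a plain function. This assumes the scheduler applies nnunet's polynomial rule, lr = initial_lr * (1 - t / T) ** exponent, to the index of the current window rather than to the raw step, so the LR is piecewise constant over each block of steps_per_lr steps. The function name is hypothetical; this is not the class's actual get_lr implementation.

```python
def windowed_poly_lr(
    step: int,
    initial_lr: float,
    max_steps: int,
    exponent: float = 0.9,
    steps_per_lr: int = 250,
) -> float:
    """Hypothetical sketch of polynomial LR decay over fixed windows of steps."""
    # Never decay past max_steps; this also keeps the base of the power non-negative
    step = min(step, max_steps)
    # Total number of windows and the index of the window containing `step`
    num_windows = max_steps / steps_per_lr
    window = step // steps_per_lr
    return initial_lr * (1 - window / num_windows) ** exponent


# LR at the first step of each of the first three windows
lrs = [windowed_poly_lr(s, 0.01, 1000) for s in (0, 250, 500)]
```

With max_steps=1000 and the default steps_per_lr=250 there are four windows, so the LR takes four distinct non-zero values before reaching zero at max_steps.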
- class StreamToLogger(logger, level)
Bases: StringIO
- __init__(logger, level)
File-like stream object that redirects writes to a logger. Useful for redirecting stdout to a logger.
- Parameters:
logger (Logger) – The logger to redirect writes to.
level (LogLevel) – The log level at which to redirect the writes.
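A minimal sketch of such a stream, assuming each write is split into lines and every non-empty line is logged at the given level. The class name is hypothetical, and level is taken as a standard logging integer level rather than the LogLevel type above.

```python
import io
import logging


class StreamToLoggerSketch(io.StringIO):
    """Hypothetical file-like object that forwards writes to a logger."""

    def __init__(self, logger: logging.Logger, level: int) -> None:
        super().__init__()
        self.logger = logger
        self.level = level

    def write(self, buf: str) -> int:
        # Log each non-empty line; print() writes its trailing newline separately
        for line in buf.splitlines():
            if line.strip():
                self.logger.log(self.level, line.rstrip())
        return len(buf)

    def flush(self) -> None:
        # Nothing is buffered, so there is nothing to flush
        pass
```

Passing an instance to contextlib.redirect_stdout routes print output to the logger instead of the console.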
- collapse_one_hot_tensor(input, dim=0)
Collapses a one-hot encoded tensor so that it is no longer one-hot encoded.
- Parameters:
input (torch.Tensor) – The binary one-hot encoded tensor.
dim (int) – The dimension along which the tensor is one-hot encoded. Defaults to 0.
- Returns:
Integer tensor with the specified dim collapsed.
- Return type:
torch.Tensor
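For a one-hot tensor, collapsing amounts to taking the index of the 1 along the encoded dimension. A minimal sketch assuming argmax semantics; the function name is hypothetical, and the parameter is called one_hot here to avoid shadowing the Python builtin input.

```python
import torch


def collapse_one_hot_sketch(one_hot: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Hypothetical sketch: turn a binary one-hot tensor into integer class labels."""
    # Cast first so boolean inputs work too; argmax returns int64 indices
    return one_hot.long().argmax(dim=dim)


# Labels [2, 0, 1] one-hot encoded along dim 0 -> shape (3 classes, 3 voxels)
encoded = torch.nn.functional.one_hot(torch.tensor([2, 0, 1]), num_classes=3).T
labels = collapse_one_hot_sketch(encoded, dim=0)
```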
- convert_deep_supervision_dict_to_list(tensor_dict)
Converts a dictionary of tensors back into a list so that it can be used by nnunet deep supervision loss functions.
- convert_deep_supervision_list_to_dict(tensor_list, num_spatial_dims)
Converts a list of torch.Tensors to a dictionary. Names the key for each tensor based on the spatial resolution of the tensor and its index in the list. Useful for nnUNet models with deep supervision, where the model outputs and the targets loaded by the dataloader are lists. Assumes the spatial dimensions of the tensors are last.
- Parameters:
tensor_list (list[torch.Tensor]) – The list of tensors to convert.
num_spatial_dims (int) – The number of trailing spatial dimensions in each tensor.
- Returns:
A dictionary containing the tensors as values, where the keys are ‘i-XxYxZ’, i being the tensor’s index in the list and X, Y, Z the spatial dimensions of the tensor.
- Return type:
dict[str, torch.Tensor]
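The two deep supervision conversions are inverses of each other. A minimal sketch, assuming the ‘i-XxYxZ’ key format described above and that list order is recovered from the index prefix of each key; both function names are hypothetical.

```python
import torch


def list_to_dict_sketch(
    tensor_list: list[torch.Tensor], num_spatial_dims: int
) -> dict[str, torch.Tensor]:
    """Hypothetical sketch: key each tensor as 'i-XxYxZ' using its trailing spatial dims."""
    tensor_dict = {}
    for i, tensor in enumerate(tensor_list):
        spatial_shape = tensor.shape[-num_spatial_dims:]
        key = f"{i}-" + "x".join(str(size) for size in spatial_shape)
        tensor_dict[key] = tensor
    return tensor_dict


def dict_to_list_sketch(tensor_dict: dict[str, torch.Tensor]) -> list[torch.Tensor]:
    """Hypothetical sketch: rebuild the list ordered by the index prefix of each key."""
    ordered_keys = sorted(tensor_dict, key=lambda key: int(key.split("-")[0]))
    return [tensor_dict[key] for key in ordered_keys]


# Two deep supervision outputs of a 2D model: full and half resolution
outputs = [torch.rand(2, 3, 64, 64), torch.rand(2, 3, 32, 32)]
as_dict = list_to_dict_sketch(outputs, num_spatial_dims=2)
```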
- get_dataset_n_voxels(source_plans, n_cases)
Determines the total number of voxels in the dataset. Used by NnunetClient to determine the maximum batch size.
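One plausible way to estimate this from an nnunet plans dictionary is to multiply the median image size by the number of cases. This is a sketch under assumptions: it assumes the plans store configurations['3d_fullres'] or configurations['2d'] with a 'median_image_size_in_voxels' entry, as nnunetv2 plans files do as far as we know. The helper name and the 3D-then-2D fallback are hypothetical, and the real function may differ.

```python
import math


def approx_dataset_voxels(plans: dict, n_cases: int) -> float:
    """Hypothetical sketch: estimate total voxels as median image size times case count."""
    configs = plans["configurations"]
    # Assumption: prefer the 3D full resolution config, fall back to 2D-only datasets
    config = configs["3d_fullres"] if "3d_fullres" in configs else configs["2d"]
    voxels_per_case = math.prod(config["median_image_size_in_voxels"])
    return float(voxels_per_case * n_cases)
```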
- get_segs_from_probs(preds, has_regions=False, threshold=0.5)
Converts the nnunet model output probabilities to predicted segmentations.
- Parameters:
preds (torch.Tensor) – The model output probabilities, with one channel per class and shape (batch, classes, *additional_dims). The background should be a separate class.
has_regions (bool, optional) – If True, predicted segmentations can contain multiple classes at once. The exception is the background class, which is assumed to be the first class (class 0). If False, each value in the predicted segmentations belongs to a single class. Defaults to False.
threshold (float) – When has_regions is True, the probability threshold used to decide whether an output belongs to a class. Defaults to 0.5.
- Returns:
A tensor containing the predicted segmentations as a one-hot encoded binary tensor of 64-bit integers.
- Return type:
torch.Tensor
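The two modes can be sketched as follows. The sketch assumes that, without regions, each voxel takes its argmax class, and that, with regions, every foreground channel is thresholded independently while background is set wherever no foreground channel exceeds the threshold. That background rule and the function name are assumptions, not necessarily what the library does.

```python
import torch
import torch.nn.functional as F


def segs_from_probs_sketch(
    preds: torch.Tensor, has_regions: bool = False, threshold: float = 0.5
) -> torch.Tensor:
    """Hypothetical sketch: probabilities (batch, classes, *dims) -> one-hot int64 segmentation."""
    if has_regions:
        # Threshold each foreground channel independently (channel 0 is background)
        foreground = (preds[:, 1:] > threshold).long()
        # Assumption: a voxel is background only if no foreground channel is active
        background = (foreground.sum(dim=1, keepdim=True) == 0).long()
        return torch.cat([background, foreground], dim=1)
    # Single-label case: pick the most probable class per voxel, then one-hot encode it
    labels = preds.argmax(dim=1)  # (batch, *dims)
    one_hot = F.one_hot(labels, num_classes=preds.shape[1])  # (batch, *dims, classes)
    return one_hot.movedim(-1, 1)  # (batch, classes, *dims)
```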
- class nnUNetDataLoaderWrapper(nnunet_augmenter, nnunet_config, infinite=False)
Bases: DataLoader
- __init__(nnunet_augmenter, nnunet_config, infinite=False)
Wraps nnunet dataloader classes in a pytorch DataLoader to make them pytorch compatible. Also handles some behavior specific to nnunet, such as deep supervision and infinite dataloaders. The nnunet dataloaders should only be used for training and validation, not for final testing.
- Parameters:
nnunet_augmenter (SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter) – The dataloader used by nnunet.
nnunet_config (NnunetConfig) – The nnunet config. The Enum type helps ensure that the nnunet config is valid.
infinite (bool, optional) – Whether to treat the dataset as infinite. The dataloaders sample data with replacement either way; the only difference is that, if set to False, a StopIteration is raised after num_samples/batch_size steps. Defaults to False.
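The difference between the finite and infinite modes can be sketched independently of nnunet. Here sample_batch is a hypothetical callable standing in for the nnunet augmenter, which samples with replacement, and rounding num_samples / batch_size down is an assumption. The class name is also hypothetical.

```python
from typing import Any, Callable, Iterator


class FiniteOrInfiniteLoaderSketch:
    """Hypothetical sketch: a sampler that stops after one 'epoch' unless infinite."""

    def __init__(
        self,
        sample_batch: Callable[[], Any],
        num_samples: int,
        batch_size: int,
        infinite: bool = False,
    ) -> None:
        self.sample_batch = sample_batch
        self.num_steps = num_samples // batch_size  # assumption: round down
        self.infinite = infinite
        self.current_step = 0

    def __iter__(self) -> Iterator[Any]:
        # Each new iteration starts a fresh "epoch"
        self.current_step = 0
        return self

    def __next__(self) -> Any:
        # Sampling is with replacement either way; only the stopping rule differs
        if not self.infinite and self.current_step >= self.num_steps:
            raise StopIteration
        self.current_step += 1
        return self.sample_batch()
```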
- prepare_loss_arg(tensor)
Converts pred and target tensors into the proper data type to be passed to the nnunet loss functions.
- reload_modules(packages)
Given the names of one or more packages, subpackages or modules, reloads all the modules within the scope of each package, or the modules themselves if a module was specified.
- Parameters:
packages (Sequence[str]) – The absolute names of the packages, subpackages or modules to reload. The entire import hierarchy must be specified, e.g. ‘package.subpackage’ to reload all modules in subpackage, not just ‘subpackage’. Packages are reloaded in the order they are given.
- Return type:
None
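A minimal sketch of this behavior with importlib.reload, assuming "modules within the scope of a package" means every already-imported entry in sys.modules whose name is the package itself or starts with the package name followed by a dot. The function name is hypothetical, and reloading parents before children (via sorting) is a choice made for the sketch.

```python
import importlib
import sys
from typing import Sequence


def reload_loaded_modules(packages: Sequence[str]) -> None:
    """Hypothetical sketch: reload every imported module under each given package."""
    for package in packages:  # packages are processed in the order given
        # Snapshot matching names first, since reloading may import new modules
        names = sorted(
            name
            for name in list(sys.modules)
            if name == package or name.startswith(package + ".")
        )
        for name in names:
            module = sys.modules.get(name)
            if module is not None:
                importlib.reload(module)
```

The prefix match also illustrates why the full dotted name is required: ‘subpackage’ alone never matches ‘package.subpackage.module’.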
- set_nnunet_env(verbose=False, **kwargs)
For each keyword argument, sets the environment variable with that name to the given value and then reloads nnunet. Values must be strings. This is necessary because nnunet reads some environment variables on import, so it must be imported or reloaded after they are set.
- Return type:
None
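The pattern can be sketched with os.environ and importlib.reload. In this hypothetical helper, module_name stands in for whichever nnunet modules need reloading, whereas the function above reloads nnunet as a whole. In nnunetv2, for example, nnunetv2.paths reads the nnUNet_raw, nnUNet_preprocessed and nnUNet_results variables at import time, so changing them after import has no effect without a reload.

```python
import importlib
import os
import sys


def set_env_and_reload(module_name: str, verbose: bool = False, **kwargs: str) -> None:
    """Hypothetical sketch: set environment variables, then reload an imported module."""
    for name, value in kwargs.items():
        # os.environ only accepts strings, so reject other types explicitly
        if not isinstance(value, str):
            raise TypeError(f"Value for {name} must be a str, got {type(value).__name__}")
        os.environ[name] = value
        if verbose:
            print(f"Set {name}={value}")
    # The module reads the variables on import, so reload it if it is already loaded
    module = sys.modules.get(module_name)
    if module is not None:
        importlib.reload(module)
```

Reloading a single module is not enough when other modules copied the values with `from ... import ...`, which is one reason to reload the whole package instead.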
- use_default_signal_handlers(fn)
A decorator that resets the SIGINT and SIGTERM signal handlers to the Python defaults while the decorated function executes.
flwr 1.9.0 overrides the default signal handlers with handlers that raise an error on any interruption or termination. Since nnunet spawns child processes that inherit these handlers, the flwr signal handlers raise an error when those subprocesses are terminated, even though termination is expected behavior.
Flwr is expected to fix this in the next release. See the following issue: https://github.com/adap/flower/issues/3837
- Return type:
Callable
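A minimal sketch of such a decorator. It assumes the Python defaults are signal.default_int_handler for SIGINT (which raises KeyboardInterrupt) and signal.SIG_DFL for SIGTERM, and that the previous handlers are restored afterwards. The name is hypothetical. Note that signal.signal can only be called from the main thread.

```python
import functools
import signal
from typing import Any, Callable


def with_default_signal_handlers(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical sketch: run fn with Python's default SIGINT/SIGTERM handlers."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Remember whatever handlers are installed (e.g. ones set by a framework)
        old_int = signal.getsignal(signal.SIGINT)
        old_term = signal.getsignal(signal.SIGTERM)
        signal.signal(signal.SIGINT, signal.default_int_handler)
        signal.signal(signal.SIGTERM, signal.SIG_DFL)
        try:
            return fn(*args, **kwargs)
        finally:
            # getsignal returns None for handlers not installed from Python; leave those alone
            if old_int is not None:
                signal.signal(signal.SIGINT, old_int)
            if old_term is not None:
                signal.signal(signal.SIGTERM, old_term)

    return wrapper
```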