fl4health.utils.nnunet_utils module¶
- class Module2LossWrapper(loss, **kwargs)[source]¶
Bases: _Loss

Converts an ``nn.Module`` subclass to a ``_Loss`` subclass.

- forward(pred, target)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class NnunetConfig(value)[source]¶
Bases: Enum
The possible nnunet model configs as of nnunetv2 version 2.5.1.
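For illustration, nnunetv2 2.5.1 supports the configuration strings "2d", "3d_fullres", "3d_lowres", and "3d_cascade_fullres"; a sketch of such an enum (member names here are assumptions, not the actual fl4health names):

```python
from enum import Enum

# Hypothetical member names; the values are the standard nnunetv2 configs.
class NnunetConfigSketch(Enum):
    TWO_D = "2d"
    THREE_D_FULLRES = "3d_fullres"
    THREE_D_LOWRES = "3d_lowres"
    THREE_D_CASCADE_FULLRES = "3d_cascade_fullres"

# An Enum validates config strings on lookup: an invalid one raises ValueError.
print(NnunetConfigSketch("3d_fullres").name)  # THREE_D_FULLRES
```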
- class PolyLRSchedulerWrapper(optimizer, initial_lr, max_steps, exponent=0.9, steps_per_lr=250)[source]¶
Bases: _LRScheduler
- __init__(optimizer, initial_lr, max_steps, exponent=0.9, steps_per_lr=250)[source]¶
Learning rate (LR) scheduler with polynomial decay across fixed windows of size steps_per_lr.
- Parameters:
optimizer (Optimizer) – The optimizer to apply LR scheduler to.
initial_lr (float) – The initial learning rate of the optimizer.
max_steps (int) – The maximum total number of steps across all FL rounds.
exponent (float) – Controls how quickly LR decreases over time. Higher values lead to more rapid descent. Defaults to 0.9.
steps_per_lr (int) – The number of steps per LR before decaying (i.e., 10 means the LR is held constant for 10 steps before being decreased to the subsequent value). Defaults to 250, since that is the nnunet default (the LR is decayed once per epoch, and an epoch is 250 steps).
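The schedule can be sketched as follows, assuming the standard nnunet polynomial formula initial_lr * (1 - step / max_steps) ** exponent, held constant within each window of steps_per_lr steps (a sketch, not the fl4health implementation):

```python
def poly_lr(step: int, initial_lr: float, max_steps: int,
            exponent: float = 0.9, steps_per_lr: int = 250) -> float:
    # Hold the LR constant within each window of steps_per_lr steps, then
    # decay it polynomially based on the window's starting step.
    window_start = (step // steps_per_lr) * steps_per_lr
    return initial_lr * (1 - window_start / max_steps) ** exponent

print(poly_lr(0, 0.01, 1000))    # 0.01 for the whole first window
print(poly_lr(249, 0.01, 1000))  # still 0.01
print(poly_lr(250, 0.01, 1000))  # first decayed value
```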
- class StreamToLogger(logger, level)[source]¶
Bases: StringIO
- __init__(logger, level)[source]¶
File-like stream object that redirects writes to a logger. Useful for redirecting stdout to a logger.
- Parameters:
logger (Logger) – The logger to redirect writes to
level (LogLevel) – The log level at which to redirect the writes
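A sketch of such a stream, assuming a `write` method that forwards non-empty writes to the logger (the real StreamToLogger may differ in detail):

```python
import io
import logging
from contextlib import redirect_stdout

# Hypothetical sketch of a file-like stream that forwards writes to a logger.
class LoggerStream(io.StringIO):
    def __init__(self, logger: logging.Logger, level: int) -> None:
        super().__init__()
        self.logger = logger
        self.level = level

    def write(self, msg: str) -> int:
        stripped = msg.strip()
        if stripped:  # skip the bare newlines that print() emits
            self.logger.log(self.level, stripped)
        return len(msg)

logger = logging.getLogger("stdout_redirect")
logger.setLevel(logging.INFO)
with redirect_stdout(LoggerStream(logger, logging.INFO)):
    print("hello")  # goes to the logger, not the real stdout
```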
- collapse_one_hot_tensor(input, dim=0)[source]¶
Collapses a one hot encoded tensor so that it is no longer one hot encoded.
- Parameters:
input (torch.Tensor) – The binary one hot encoded tensor
- Returns:
Integer tensor with the specified dim collapsed
- Return type:
torch.Tensor
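Functionally, this presumably amounts to an argmax over the one hot dimension; the equivalent core operation is:

```python
import torch

# A one hot tensor with the encoded classes along dim=1.
one_hot = torch.tensor([[1, 0, 0],
                        [0, 0, 1],
                        [0, 1, 0]])
collapsed = torch.argmax(one_hot, dim=1)  # collapse dim=1 to class indices
print(collapsed.tolist())  # [0, 2, 1]
```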
- convert_deep_supervision_dict_to_list(tensor_dict)[source]¶
Converts a dictionary of tensors back into a list so that it can be used by nnunet deep supervision loss functions.
- convert_deep_supervision_list_to_dict(tensor_list, num_spatial_dims)[source]¶
Converts a list of torch.Tensors to a dictionary. Names the keys for each tensor based on the spatial resolution of the tensor and its index in the list. Useful for nnUNet models with deep supervision where model outputs and targets loaded by the dataloader are lists. Assumes the spatial dimensions of the tensors are last.
- Parameters:
- Returns:
A dictionary containing the tensors as values where the keys are ‘i-XxYxZ’ where i was the tensor’s index in the list and X,Y,Z are the spatial dimensions of the tensor
- Return type:
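Given the documented key format 'i-XxYxZ', the conversion can be sketched with a hypothetical helper (not the fl4health implementation):

```python
import torch

def ds_list_to_dict(tensor_list, num_spatial_dims):
    # Key each tensor by its list index and its trailing spatial dimensions.
    out = {}
    for i, t in enumerate(tensor_list):
        spatial = "x".join(str(s) for s in t.shape[-num_spatial_dims:])
        out[f"{i}-{spatial}"] = t
    return out

# Two deep supervision outputs at different spatial resolutions.
outputs = [torch.zeros(2, 3, 64, 64), torch.zeros(2, 3, 32, 32)]
print(list(ds_list_to_dict(outputs, num_spatial_dims=2)))  # ['0-64x64', '1-32x32']
```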
- get_dataset_n_voxels(source_plans, n_cases)[source]¶
Determines the total number of voxels in the dataset. Used by NnunetClient to determine the maximum batch size.
- get_segs_from_probs(preds, has_regions=False, threshold=0.5)[source]¶
Converts the nnunet model output probabilities to predicted segmentations.
- Parameters:
preds (torch.Tensor) – The one hot encoded model output probabilities with shape (batch, classes, *additional_dims). The background should be a separate class
has_regions (bool, optional) – If True, predicted segmentations can be multiple classes at once. The exception is the background class which is assumed to be the first class (class 0). If False, each value in predicted segmentations has only a single class. Defaults to False.
threshold (float) – When has_regions is True, this is the threshold value used to determine whether or not an output is a part of a class.
- Returns:
Tensor containing the predicted segmentations as a one hot encoded binary tensor of 64-bit integers.
- Return type:
torch.Tensor
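When has_regions is False, the described conversion amounts to an argmax over the class dimension followed by re-encoding as one hot 64-bit integers; a sketch (hypothetical helper, not the fl4health function):

```python
import torch

def probs_to_one_hot_segs(preds: torch.Tensor) -> torch.Tensor:
    # Pick the most probable class per voxel, then re-encode as one hot int64.
    segs = torch.zeros_like(preds, dtype=torch.long)
    segs.scatter_(1, preds.argmax(dim=1, keepdim=True), 1)
    return segs

# (batch=1, classes=3, voxels=2): class 0 wins at voxel 0, class 2 at voxel 1.
probs = torch.tensor([[[0.7, 0.1], [0.2, 0.2], [0.1, 0.7]]])
print(probs_to_one_hot_segs(probs).tolist())  # [[[1, 0], [0, 0], [0, 1]]]
```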
- class nnUNetDataLoaderWrapper(nnunet_augmenter, nnunet_config, infinite=False, set_len=None, ref_image_shape=None)[source]¶
Bases: DataLoader
- __init__(nnunet_augmenter, nnunet_config, infinite=False, set_len=None, ref_image_shape=None)[source]¶
Wraps nnunet dataloader classes using the PyTorch dataloader to make them PyTorch compatible. Also handles behavior specific to nnunet, such as deep supervision and infinite dataloaders. The nnunet dataloaders should only be used for training and validation, not final testing.
- Parameters:
nnunet_augmenter (SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter) – The dataloader used by nnunet.
nnunet_config (NnUNetConfig) – The nnunet config. Enum type helps ensure that nnunet config is valid
infinite (bool, optional) – Whether or not to treat the dataset as infinite. The dataloaders sample data with replacement either way. The only difference is that if set to False, a StopIteration is raised after num_samples / batch_size steps. Defaults to False.
set_len (int | None) – If specified, overrides the dataloader's estimate of its own length with the provided value. A StopIteration will be raised after set_len steps. If not specified, the length is determined by scaling the number of samples by the ratio of image size to the network's input patch size.
ref_image_shape (Sequence | None) – The image shape to use when computing the scaling factor used in determining the length of the dataloader. Should be representative of the median or average image size in the dataset. If not specified, a random image is loaded and its shape is used in the calculation of the scaling factor.
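The effect of the infinite flag on iteration can be sketched with a toy loader (a hypothetical class, not the fl4health wrapper):

```python
class ToyLoader:
    """Samples with replacement; only raises StopIteration when finite."""

    def __init__(self, num_samples: int, batch_size: int, infinite: bool = False):
        self.num_batches = num_samples // batch_size
        self.infinite = infinite

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self) -> str:
        if not self.infinite and self.step >= self.num_batches:
            raise StopIteration  # finite mode stops after num_samples / batch_size
        self.step += 1
        return f"batch {self.step}"

finite = ToyLoader(num_samples=10, batch_size=2)
print(len(list(finite)))  # 5 batches, then StopIteration
```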
- prepare_loss_arg(tensor)[source]¶
Converts pred and target tensors into the proper data type to be passed to the nnunet loss functions.
- reload_modules(packages)[source]¶
Given the names of one or more packages, subpackages or modules, reloads all the modules within the scope of each package or the modules themselves if a module was specified.
- Parameters:
packages (Sequence[str]) – The absolute names of the packages, subpackages or modules to reload. The entire import hierarchy must be specified, e.g. package.subpackage to reload all modules in subpackage, not just subpackage. Packages are reloaded in the order they are given.
- Return type:
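A sketch of reloading every already-imported module under a package, using the standard importlib machinery (a hypothetical helper; the fl4health function may behave differently):

```python
import importlib
import sys

def reload_package_modules(name: str) -> None:
    # Snapshot sys.modules first: reloading can add entries mid-iteration.
    for mod_name, mod in sorted(sys.modules.copy().items()):
        if mod is not None and (mod_name == name or mod_name.startswith(name + ".")):
            importlib.reload(mod)

import json
reload_package_modules("json")  # reloads json and its imported submodules
print(json.loads('{"ok": true}'))  # {'ok': True}
```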
- set_nnunet_env(verbose=False, **kwargs)[source]¶
For each keyword argument, sets the environment variable of the same name to the given value and then reloads nnunet. Values must be strings. This is necessary because nnunet checks some environment variables on import, and therefore it must be imported or reloaded after they are set.
- Return type:
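The pattern can be sketched as follows; the nnunet reload step is stubbed out here, and the helper name is hypothetical (nnUNet_raw is one of the environment variables nnunet reads at import time):

```python
import os

def set_env_then_reload(**kwargs: str) -> None:
    # Export each keyword argument as an environment variable of the same name.
    for name, value in kwargs.items():
        if not isinstance(value, str):
            raise TypeError("environment variable values must be strings")
        os.environ[name] = value
    # fl4health would follow this by reloading the nnunet modules, since
    # nnunet checks variables such as nnUNet_raw at import time.

set_env_then_reload(nnUNet_raw="/tmp/nnunet_raw")
print(os.environ["nnUNet_raw"])  # /tmp/nnunet_raw
```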
- use_default_signal_handlers(fn)[source]¶
This is a decorator that resets the SIGINT and SIGTERM signal handlers back to the Python defaults for the execution of the decorated method.

flwr 1.9.0 overrides the default signal handlers with handlers that raise an error on any interruption or termination. Since nnunet spawns child processes which inherit these handlers, when those subprocesses are terminated (which is expected behavior), the flwr signal handlers raise an error (which we don't want).
Flwr is expected to fix this in the next release. See the following issue:
https://github.com/adap/flower/issues/3837
- Return type:
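A sketch of such a decorator, assuming the Python defaults are default_int_handler for SIGINT and SIG_DFL for SIGTERM, with the previous handlers restored afterwards (an illustration, not the fl4health implementation):

```python
import functools
import signal

def use_default_signal_handlers(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Swap in the Python default handlers, remembering the old ones.
        old_int = signal.signal(signal.SIGINT, signal.default_int_handler)
        old_term = signal.signal(signal.SIGTERM, signal.SIG_DFL)
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore whatever handlers were installed before the call.
            signal.signal(signal.SIGINT, old_int)
            signal.signal(signal.SIGTERM, old_term)
    return wrapper

@use_default_signal_handlers
def run_subprocess_work() -> str:
    return "done"

print(run_subprocess_work())  # done
```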