fl4health.utils.nnunet_utils module

class Module2LossWrapper(loss, **kwargs)[source]

Bases: _Loss

Converts an nn.Module subclass into a _Loss subclass.

forward(pred, target)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
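
Example

A minimal usage sketch. The SquaredError module below is hypothetical; any nn.Module whose forward accepts (pred, target) should wrap the same way.

    import torch
    import torch.nn as nn

    from fl4health.utils.nnunet_utils import Module2LossWrapper

    # Hypothetical nn.Module loss with the (pred, target) forward the wrapper expects
    class SquaredError(nn.Module):
        def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            return ((pred - target) ** 2).mean()

    loss_fn = Module2LossWrapper(SquaredError())
    loss = loss_fn(torch.randn(4, 3), torch.randn(4, 3))  # call the instance, not forward()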

class NnunetConfig(value)[source]

Bases: Enum

The possible nnunet model configs as of nnunetv2 version 2.5.1. See https://github.com/MIC-DKFZ/nnUNet/tree/v2.5.1
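
Example

A one-line sketch. The value "2d" is assumed to match nnunetv2's config names (e.g. "2d", "3d_fullres"); consult the enum members for the exact set.

    from fl4health.utils.nnunet_utils import NnunetConfig

    # Assumes the enum values mirror nnunetv2's config names
    config = NnunetConfig("2d")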

class PolyLRSchedulerWrapper(optimizer, initial_lr, max_steps, exponent=0.9, steps_per_lr=250)[source]

Bases: _LRScheduler

__init__(optimizer, initial_lr, max_steps, exponent=0.9, steps_per_lr=250)[source]

Learning rate (LR) scheduler with polynomial decay across fixed windows of size steps_per_lr.

Parameters:
  • optimizer (Optimizer) – The optimizer to apply LR scheduler to.

  • initial_lr (float) – The initial learning rate of the optimizer.

  • max_steps (int) – The maximum total number of steps across all FL rounds.

  • exponent (float) – Controls how quickly LR decreases over time. Higher values lead to more rapid descent. Defaults to 0.9.

  • steps_per_lr (int) – The number of steps between LR decays (i.e. a value of 10 means the LR is held constant for 10 steps before being decreased to the subsequent value). Defaults to 250, matching the nnunet default of decaying the LR once per epoch, where an epoch is 250 steps.

get_lr()[source]

Get the current LR of the scheduler.

Returns:

A sequence containing the current LR, repeated once for each of the parameter groups in the optimizer.

Return type:

Sequence[float]
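
Example

A minimal sketch of driving the scheduler, assuming the standard _LRScheduler step() protocol.

    import torch

    from fl4health.utils.nnunet_utils import PolyLRSchedulerWrapper

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # Polynomial decay over 1000 total steps, holding each LR for the default 250 steps
    scheduler = PolyLRSchedulerWrapper(optimizer, initial_lr=0.01, max_steps=1000)

    for _ in range(1000):
        optimizer.step()
        scheduler.step()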

class StreamToLogger(logger, level)[source]

Bases: StringIO

__init__(logger, level)[source]

File-like stream object that redirects writes to a logger. Useful for redirecting stdout to a logger.

Parameters:
  • logger (Logger) – The logger to redirect writes to

  • level (LogLevel) – The log level at which to redirect the writes

flush()[source]

Flush write buffers, if applicable.

This is not implemented for read-only and non-blocking streams.

Return type:

None

write(buf)[source]

Write string to file.

Returns the number of characters written, which is always equal to the length of the string.

Return type:

int
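
Example

A minimal sketch redirecting stdout to a logger. The stdlib logging level is used as a stand-in here; the documented parameter type is LogLevel, so treat the exact level argument as an assumption.

    import logging
    import sys

    from fl4health.utils.nnunet_utils import StreamToLogger

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("nnunet")
    sys.stdout = StreamToLogger(logger, logging.INFO)
    print("this message is routed through the logger")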

collapse_one_hot_tensor(input, dim=0)[source]

Collapses a one hot encoded tensor so that it is no longer one hot encoded.

Parameters:
  • input (torch.Tensor) – The binary one hot encoded tensor

  • dim (int) – The dimension along which the input tensor is one hot encoded. Defaults to 0.

Returns:

Integer tensor with the specified dim collapsed

Return type:

torch.Tensor
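
Example

A small sketch; the expected output follows from the description above.

    import torch
    import torch.nn.functional as F

    from fl4health.utils.nnunet_utils import collapse_one_hot_tensor

    # One hot encoding of labels [0, 2, 1], transposed so classes lie along dim 0
    one_hot = F.one_hot(torch.tensor([0, 2, 1])).T
    labels = collapse_one_hot_tensor(one_hot, dim=0)  # expected: tensor([0, 2, 1])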

convert_deep_supervision_dict_to_list(tensor_dict)[source]

Converts a dictionary of tensors back into a list so that it can be used by the nnunet deep supervision loss functions.

Parameters:

tensor_dict (dict[str, torch.Tensor]) – Dictionary containing torch.Tensors. The key values must start with ‘X-’ where X is an integer representing the index at which the tensor should be placed in the output list

Returns:

A list of torch.Tensors

Return type:

list[torch.Tensor]
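
Example

A short sketch, assuming keys of the form produced by convert_deep_supervision_list_to_dict below.

    import torch

    from fl4health.utils.nnunet_utils import convert_deep_supervision_dict_to_list

    tensor_dict = {
        "0-8x8": torch.zeros(1, 3, 8, 8),  # index 0 in the output list
        "1-4x4": torch.zeros(1, 3, 4, 4),  # index 1 in the output list
    }
    tensor_list = convert_deep_supervision_dict_to_list(tensor_dict)
    assert len(tensor_list) == 2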

convert_deep_supervision_list_to_dict(tensor_list, num_spatial_dims)[source]

Converts a list of torch.Tensors to a dictionary. Names the keys for each tensor based on the spatial resolution of the tensor and its index in the list. Useful for nnUNet models with deep supervision where model outputs and targets loaded by the dataloader are lists. Assumes the spatial dimensions of the tensors are last.

Parameters:
  • tensor_list (list[torch.Tensor]) – A list of tensors, usually either nnunet model outputs or targets, to be converted into a dictionary

  • num_spatial_dims (int) – The number of spatial dimensions. Assumes the spatial dimensions are last

Returns:

A dictionary containing the tensors as values, where the keys are ‘i-XxYxZ’, with i the tensor’s index in the list and X, Y, Z the spatial dimensions of the tensor

Return type:

dict[str, torch.Tensor]
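
Example

A short sketch with two deep supervision outputs and two spatial dimensions.

    import torch

    from fl4health.utils.nnunet_utils import convert_deep_supervision_list_to_dict

    outputs = [torch.zeros(1, 3, 8, 8), torch.zeros(1, 3, 4, 4)]
    out_dict = convert_deep_supervision_list_to_dict(outputs, num_spatial_dims=2)
    print(out_dict.keys())  # expected: keys of the form '0-8x8' and '1-4x4'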

get_dataset_n_voxels(source_plans, n_cases)[source]

Determines the total number of voxels in the dataset. Used by NnunetClient to determine the maximum batch size.

Parameters:
  • source_plans (Dict) – The nnunet plans dict that is being modified

  • n_cases (int) – The number of cases in the dataset

Returns:

The total number of voxels in the local client dataset

Return type:

float
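
Example

A sketch of the intended call pattern. The plans dict is normally the one nnunet generates during planning; the file name below is the nnunetv2 default but should be treated as an assumption.

    import json

    from fl4health.utils.nnunet_utils import get_dataset_n_voxels

    with open("nnUNetPlans.json") as f:  # plans JSON produced by nnunet planning
        plans = json.load(f)
    total_voxels = get_dataset_n_voxels(plans, n_cases=100)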

get_segs_from_probs(preds, has_regions=False, threshold=0.5)[source]

Converts the nnunet model output probabilities to predicted segmentations

Parameters:
  • preds (torch.Tensor) – The one hot encoded model output probabilities with shape (batch, classes, *additional_dims). The background should be a separate class

  • has_regions (bool, optional) – If True, predicted segmentations can be multiple classes at once. The exception is the background class which is assumed to be the first class (class 0). If False, each value in predicted segmentations has only a single class. Defaults to False.

  • threshold (float) – When has_regions is True, this is the threshold value used to determine whether or not an output is a part of a class. Defaults to 0.5.

Returns:

A tensor containing the predicted segmentations as a one hot encoded binary tensor of 64-bit integers.

Return type:

torch.Tensor
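
Example

A minimal sketch covering both modes.

    import torch

    from fl4health.utils.nnunet_utils import get_segs_from_probs

    # Batch of 1, 3 classes (class 0 assumed to be background), 4x4 spatial dims
    preds = torch.softmax(torch.randn(1, 3, 4, 4), dim=1)
    segs = get_segs_from_probs(preds)  # mutually exclusive classes
    region_segs = get_segs_from_probs(preds, has_regions=True, threshold=0.5)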

class nnUNetDataLoaderWrapper(nnunet_augmenter, nnunet_config, infinite=False)[source]

Bases: DataLoader

__init__(nnunet_augmenter, nnunet_config, infinite=False)[source]

Wraps the nnunet dataloader classes using the PyTorch DataLoader to make them PyTorch compatible. Also handles some nnunet-specific behavior, such as deep supervision and infinite dataloaders. The nnunet dataloaders should only be used for training and validation, not final testing.

Parameters:
  • nnunet_augmenter (SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter) – The dataloader used by nnunet

  • nnunet_config (NnunetConfig) – The nnunet config. The Enum type helps ensure that the nnunet config is valid

  • infinite (bool, optional) – Whether or not to treat the dataset as infinite. The dataloaders sample data with replacement either way. The only difference is that if set to False, a StopIteration is generated after num_samples/batch_size steps. Defaults to False.

reset()[source]

Return type:

None

shutdown()[source]

The multithreaded augmenters used by nnunet need to be shut down gracefully to avoid errors.

Return type:

None
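
Example

A sketch of the expected usage. Constructing a real nnunet augmenter is outside the scope of this page, so nnunet_augmenter below is assumed to come from nnunet's own dataloader machinery.

    from fl4health.utils.nnunet_utils import NnunetConfig, nnUNetDataLoaderWrapper

    # nnunet_augmenter: a SingleThreadedAugmenter, MultiThreadedAugmenter or
    # NonDetMultiThreadedAugmenter obtained from nnunet's training pipeline
    loader = nnUNetDataLoaderWrapper(nnunet_augmenter, NnunetConfig("2d"), infinite=False)
    for batch in loader:
        ...  # train or validate on the batch
    loader.shutdown()  # shut the multithreaded augmenter down gracefully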

prepare_loss_arg(tensor)[source]

Converts pred and target tensors into the proper data type to be passed to the nnunet loss functions.

Parameters:

tensor (torch.Tensor | dict[str, torch.Tensor]) – The input tensor

Returns:

The tensor ready to be passed to the loss function. A single tensor if not using deep supervision, and a list of tensors if deep supervision is on.

Return type:

torch.Tensor | list[torch.Tensor]
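
Example

A short sketch covering both input types.

    import torch

    from fl4health.utils.nnunet_utils import prepare_loss_arg

    # With deep supervision on, preds and targets arrive as dicts keyed as described above
    pred_dict = {"0-8x8": torch.zeros(1, 3, 8, 8), "1-4x4": torch.zeros(1, 3, 4, 4)}
    as_list = prepare_loss_arg(pred_dict)  # expected: a list of two tensors
    as_tensor = prepare_loss_arg(torch.zeros(1, 3, 8, 8))  # expected: a single tensor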

reload_modules(packages)[source]

Given the names of one or more packages, subpackages or modules, reloads all the modules within the scope of each package or the modules themselves if a module was specified.

Parameters:

packages (Sequence[str]) – The absolute names of the packages, subpackages or modules to reload. The entire import hierarchy must be specified, e.g. ‘package.subpackage’ to reload all modules in subpackage, not just ‘subpackage’. Packages are reloaded in the order they are given.

Return type:

None
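
Example

A one-line sketch reloading nnunetv2 and everything under it.

    from fl4health.utils.nnunet_utils import reload_modules

    # Reloads nnunetv2 and all of its already-imported submodules, in order
    reload_modules(["nnunetv2"])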

set_nnunet_env(verbose=False, **kwargs)[source]

For each keyword argument name and value, sets the environment variable of the same name to that value and then reloads nnunet. Values must be strings. This is necessary because nnunet checks some environment variables on import, so it must be imported or reloaded after they are set.

Return type:

None
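
Example

A sketch using the environment variables nnunetv2 reads on import; the paths are placeholders.

    from fl4health.utils.nnunet_utils import set_nnunet_env

    set_nnunet_env(
        verbose=True,
        nnUNet_raw="/data/nnUNet_raw",
        nnUNet_preprocessed="/data/nnUNet_preprocessed",
        nnUNet_results="/data/nnUNet_results",
    )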

use_default_signal_handlers(fn)[source]

This is a decorator that resets the SIGINT and SIGTERM signal handlers back to the Python defaults for the execution of the method.

flwr 1.9.0 overrides the default signal handlers with handlers that raise an error on any interruption or termination. Since nnunet spawns child processes which inherit these handlers, when those subprocesses are terminated (which is expected behavior), the flwr signal handlers raise an error (which we don’t want).

Flwr is expected to fix this in the next release. See the following issue: https://github.com/adap/flower/issues/3837

Return type:

Callable
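
Example

A minimal sketch decorating a function that spawns nnunet subprocesses.

    from fl4health.utils.nnunet_utils import use_default_signal_handlers

    @use_default_signal_handlers
    def run_preprocessing() -> None:
        # Child processes spawned here inherit the default SIGINT/SIGTERM handlers
        ...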