fl4health.clients.nnunet_client module

class NnunetClient(device, dataset_id, fold, data_identifier=None, plans_identifier=None, compile=True, always_preprocess=False, max_grad_norm=12, n_dataload_processes=None, verbose=True, metrics=None, progress_bar=False, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, client_name=None, nnunet_trainer_class=<class 'nnunetv2.training.nnUNetTrainer.nnUNetTrainer.nnUNetTrainer'>, nnunet_trainer_class_kwargs={})[source]

Bases: BasicClient

__init__(device, dataset_id, fold, data_identifier=None, plans_identifier=None, compile=True, always_preprocess=False, max_grad_norm=12, n_dataload_processes=None, verbose=True, metrics=None, progress_bar=False, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, client_name=None, nnunet_trainer_class=<class 'nnunetv2.training.nnUNetTrainer.nnUNetTrainer.nnUNetTrainer'>, nnunet_trainer_class_kwargs={})[source]

A client for training nnunet models. Requires the nnunet environment variables to be set. Also requires the following additional keys in the config sent from the server (a usage sketch follows the parameter list below):

  • ‘nnunet_plans’: (serialized dict)

  • ‘nnunet_config’: (str)

Parameters:
  • device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often ‘cpu’, ‘cuda’ or ‘mps’.

  • dataset_id (int) – The nnunet dataset id for the local client dataset to use for training and validation.

  • fold (int | str) – Which fold of the local client dataset to use for validation. nnunet defaults to 5 folds (0 to 4). Can also be set to ‘all’ to use all the data for both training and validation.

  • data_identifier (str | None, optional) – The nnunet data identifier prefix to use. The final data identifier will be {data_identifier}_config where ‘config’ is the nnunet config (eg. 2d, 3d_fullres, etc.). If preprocessed data already exists, this can be used to specify which preprocessed data to use. By default, the plans_identifier is used as the data_identifier.

  • plans_identifier (str | None, optional) – Specify what to save the client’s local copy of the plans file as. The client makes a local modified copy of the global source plans file sent by the server. If left as default None, the plans identifier will be set as ‘FL-plansname-000local’ where 000 is the dataset_id and plansname is the ‘plans_name’ value of the source plans file. The original plans will be saved under the source_plans_name key in the modified plans file.

  • compile (bool, optional) – If True, the client will JIT compile the PyTorch model. This requires some overhead time at the beginning of training to compile the model, but results in faster training times. Defaults to True.

  • always_preprocess (bool, optional) – If True, will preprocess the local client dataset even if the preprocessed data already seems to exist. Defaults to False. The existence of the preprocessed data is determined by matching the provided data_identifier with that of the preprocessed data already on the client.

  • max_grad_norm (float, optional) – The maximum gradient norm to use for gradient clipping. Defaults to 12, which is the nnunetv2 2.5.1 default.

  • n_dataload_processes (int | None, optional) – The number of processes to spawn for each nnunet dataloader. If left as None, the nnunetv2 2.5.1 defaults for each config are used.

  • verbose (bool, optional) – If True, the client will log some extra INFO logs. Defaults to True.

  • metrics (Sequence[Metric], optional) – Metrics to be computed based on the labels and predictions of the client model. Defaults to None.

  • progress_bar (bool, optional) – Whether or not to print a progress bar to stdout for training. Defaults to False.

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.

  • reporters (Sequence[BaseReporter], optional) – A sequence of FL4Health reporters which the client should send data to.

  • nnunet_trainer_class (type[nnUNetTrainer]) – An nnUNetTrainer constructor. Useful for passing a custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. Must match the nnunet_trainer_class passed to the NnunetServer.

  • nnunet_trainer_class_kwargs (dict[str, Any]) – Additional kwargs to pass to nnunet_trainer_class. Defaults to empty dictionary.
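
Example: a minimal usage sketch (not from the library's documentation). It assumes the nnunet environment variables (nnUNet_raw, nnUNet_preprocessed, nnUNet_results) are set and that a compatible NnunetServer is running; the dataset id and server address below are hypothetical.

    import flwr as fl
    import torch

    from fl4health.clients.nnunet_client import NnunetClient

    # Device selection: cuda if available, otherwise cpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # dataset_id and fold refer to the local client dataset (values are hypothetical)
    client = NnunetClient(device=device, dataset_id=5, fold=0)

    # Start the flwr client against a hypothetical server address
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())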

compute_loss_and_additional_losses(preds, features, target)[source]
Checks the pred and target types and computes the loss.

If the device type is cuda, the loss is computed in mixed precision.

Parameters:
  • preds (TorchPredType) – Dictionary of model output tensors indexed by name

  • features (dict[str, torch.Tensor]) – Not used by this subclass

  • target (TorchTargetType) – The targets to evaluate the predictions with. If multiple prediction tensors are given, target must be a dictionary with the same number of tensors

Returns:

A tuple where the first element is the loss and the second element is an optional additional loss.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor] | None]
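
Example: a minimal sketch of the mixed-precision pattern described above (illustrative, not the library's exact code).

    import torch

    def loss_with_autocast(criterion, pred, target, device):
        # On cuda devices, compute the loss under autocast (mixed precision)
        if device.type == "cuda":
            with torch.autocast(device_type="cuda"):
                return criterion(pred, target)
        return criterion(pred, target)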

create_plans(config)[source]

Modifies the provided plans file to work with the local client dataset and then saves it to disk. Requires the local dataset_fingerprint.json to exist, and the local dataset_name, plans_name, data_identifier and dataset_json to be set.

The following fields are modified:
  • plans_name

  • dataset_name

  • original_median_shape_after_transp

  • original_median_spacing_after_transp

  • configurations.{config}.data_identifier

  • configurations.{config}.batch_size

  • configurations.{config}.median_image_size_in_voxels

  • foreground_intensity_properties_per_channel

Parameters:

config (Config) – The config provided by the server. Expects the ‘nnunet_plans’ key with a pickled dictionary as the value.

Returns:

The modified nnunet plans for the client

Return type:

dict[str, Any]
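
Example: an illustrative sketch of the flow described above (the field values and function name are hypothetical; the library's exact code may differ). The server sends the source plans as a pickled dict under the ‘nnunet_plans’ config key, and the client saves a locally modified copy, recording the original name under source_plans_name.

    import pickle
    from typing import Any

    def localize_plans(config: dict[str, Any]) -> dict[str, Any]:
        plans = pickle.loads(config["nnunet_plans"])      # deserialize the server's plans
        plans["source_plans_name"] = plans["plans_name"]  # record the original plans name
        plans["plans_name"] = "FL-nnUNetPlans-005local"   # hypothetical local identifier
        plans["dataset_name"] = "Dataset005_Example"      # hypothetical local dataset name
        return plans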

empty_cache()[source]

Checks the torch device and empties the cache before training to optimize VRAM usage.

Return type:

None

get_client_specific_logs(current_round, current_epoch, logging_mode)[source]

This function can be overridden to provide any client specific information to the basic client logging. For example, perhaps a client uses an LR scheduler and wants the LR to be logged each epoch. Called at the beginning and end of each server round or local epoch. Also called at the end of validation/testing.

Parameters:
  • current_round (int | None) – The current FL round (i.e., current server round).

  • current_epoch (int | None) – The current epoch of local training.

  • logging_mode (LoggingMode) – The logging mode (Training, Validation, or Testing).

Returns:

  • str | None – A string to append to the header log string that typically announces the current server round and current epoch at the beginning of each round or local epoch.

  • list[tuple[LogLevel, str]] | None – A list of tuples where the first element is a LogLevel as defined in fl4health.utils.typing and the second element is a string message. Each item in the list will be logged at the end of each server round or epoch. Elements will also be logged at the end of validation/testing.

Return type:

tuple[str | None, list[tuple[LogLevel, str]] | None]
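
Example: a sketch of the override the description mentions, logging the LR each epoch. The "global" optimizer key and LogLevel.INFO member are assumptions about how the optimizer and log levels are stored.

    from fl4health.clients.nnunet_client import NnunetClient
    from fl4health.utils.typing import LogLevel

    class LrLoggingClient(NnunetClient):
        def get_client_specific_logs(self, current_round, current_epoch, logging_mode):
            # Read the current LR from the optimizer (the "global" key is hypothetical)
            lr = self.optimizers["global"].param_groups[0]["lr"]
            return "", [(LogLevel.INFO, f"Current LR: {lr}")]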

get_client_specific_reports()[source]

This function can be overridden by an inheriting client to report additional client specific information to the wandb_reporter.

Returns:

A dictionary of things to report

Return type:

dict[str, Any]

get_criterion(config)[source]

User defined method that returns the PyTorch loss used to train the model.

Parameters:

config (Config) – The config from the server.

Raises:

NotImplementedError – To be defined in child class.

Return type:

_Loss
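
Example: a sketch of a user-defined override (illustrative; nnunet normally constructs its own compound segmentation loss internally).

    import torch.nn as nn

    from fl4health.clients.nnunet_client import NnunetClient

    class MyNnunetClient(NnunetClient):
        def get_criterion(self, config):
            # Return the torch loss used for local training
            return nn.CrossEntropyLoss()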

get_data_loaders(**kwargs: Any) → Any

User defined method that returns a PyTorch train DataLoader and a PyTorch validation DataLoader.

Parameters:

config (Config) – The config from the server.

Returns:

A tuple of length 2 containing the client train and validation dataloaders.

Return type:

tuple[DataLoader, …]

Raises:

NotImplementedError – To be defined in child class.

get_lr_scheduler(optimizer_key, config)[source]

Creates an LR scheduler similar to the nnunet default, except that max steps is set to the total number of steps and the LR is updated every step. The initial and final LR are the same as nnunet's; the difference is that nnunet sets max steps to the number of ‘epochs’, where an ‘epoch’ is defined as exactly 250 steps, and therefore updates only every 250 steps. Override this method to set your own LR scheduler.

Parameters:
  • optimizer_key (str) – The key in self.optimizers of the optimizer for which the scheduler is being created.

  • config (Config) – The server config.

Returns:

The default nnunet LR scheduler for nnunetv2 2.5.1.

Return type:

_LRScheduler
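
Example: a sketch of overriding the scheduler with torch's PolynomialLR (an illustrative swap, not the library default). It assumes optimizers are stored in self.optimizers keyed by optimizer_key; the total step count is hypothetical and would normally be derived from the server config.

    from torch.optim.lr_scheduler import PolynomialLR

    from fl4health.clients.nnunet_client import NnunetClient

    class CustomSchedulerClient(NnunetClient):
        def get_lr_scheduler(self, optimizer_key, config):
            total_steps = 10_000  # hypothetical; derive from the server config in practice
            # Poly decay with exponent 0.9, stepped every training step
            return PolynomialLR(
                self.optimizers[optimizer_key], total_iters=total_steps, power=0.9
            )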

get_model(config)[source]

User defined method that returns the PyTorch model.

Parameters:

config (Config) – The config from the server.

Returns:

The client model.

Return type:

nn.Module

Raises:

NotImplementedError – To be defined in child class.

get_optimizer(config)[source]

Method to be defined by the user that returns the PyTorch optimizer used to train models locally. The return value can be a single torch optimizer or a dictionary of string and torch optimizer pairs. Returning multiple optimizers is useful in methods like APFL, which has a different optimizer for the local and global models (see the sketch after this entry).

Parameters:

config (Config) – The config sent from the server.

Returns:

An optimizer or dictionary of optimizers to train model.

Return type:

Optimizer | dict[str, Optimizer]

Raises:

NotImplementedError – To be defined in child class.
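
Example: a sketch of the single-optimizer case (hyperparameters are illustrative); a dictionary such as {"local": ..., "global": ...} could be returned instead for methods like APFL.

    import torch

    from fl4health.clients.nnunet_client import NnunetClient

    class SgdNnunetClient(NnunetClient):
        def get_optimizer(self, config):
            # A single optimizer for the client model
            return torch.optim.SGD(
                self.model.parameters(), lr=1e-2, momentum=0.99, nesterov=True
            )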

get_properties(**kwargs: Any) → Any

Return properties (train and validation dataset sample counts) of client.

Parameters:

config (Config) – The config from the server.

Returns:

A dictionary with two entries corresponding to the sample counts in the train and validation set.

Return type:

dict[str, Scalar]

mask_data(pred, target)[source]

Masks the pred and target tensors according to the nnunet ignore_label. The number of classes in the input tensors should be at least 3: 2 classes for binary segmentation and 1 ignore class specified by the ignore label. See the nnunet documentation for more info: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/ignore_label.md

Parameters:
  • pred (torch.Tensor) – The one hot encoded predicted segmentation maps with shape (batch, classes, x, y(, z))

  • target (torch.Tensor) – The ground truth segmentation map with shape (batch, classes, x, y(, z))

Returns:

  • torch.Tensor – The masked one hot encoded predicted segmentation maps.

  • torch.Tensor – The masked target segmentation maps.

Return type:

tuple[torch.Tensor, torch.Tensor]
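
Example: an illustrative sketch of ignore-label masking (an assumption about the exact mechanics; see the nnunet ignore_label documentation linked above). It assumes the last channel of the one hot target marks unannotated pixels, which are masked out of both tensors.

    import torch

    def mask_by_ignore_label(pred: torch.Tensor, target: torch.Tensor):
        # The last target channel is assumed to be the ignore class
        annotated = 1 - target[:, -1:]  # 1 where the pixel is annotated, 0 where ignored
        return pred[:, :-1] * annotated, target[:, :-1] * annotated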

maybe_extract_fingerprint(**kwargs: Any) → Any
Return type:

Any

maybe_preprocess(**kwargs: Any) → Any
Return type:

Any

predict(input)[source]
Generate model outputs. Overridden because nnunet models output lists when deep supervision is on, so the outputs must be reformatted into dicts. If the device type is cuda, outputs are computed in mixed precision.

Parameters:

input (TorchInputType) – The model inputs

Returns:

A tuple in which the first element contains the model outputs indexed by name. The second element is unused by this subclass and is therefore always an empty dict.

Return type:

tuple[TorchPredType, dict[str, torch.Tensor]]
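
Example: a sketch of the reformatting described above (key names are illustrative, not the library's exact naming). With deep supervision on, the model returns a list of outputs at several resolutions, which is repackaged into a dict.

    def outputs_to_pred_dict(outputs):
        # Deep supervision: a list/tuple of tensors, one per resolution
        if isinstance(outputs, (list, tuple)):
            return {f"prediction-{i}": out for i, out in enumerate(outputs)}
        # Otherwise a single output tensor
        return {"prediction": outputs}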

setup_client(**kwargs: Any) → Any

Set dataloaders, optimizers, parameter exchangers and other attributes derived from these. Then set the initialized attribute to True.

Parameters:

config (Config) – The config from the server.

Return type:

Any

shutdown()[source]

Shuts down the client. Involves shutting down W&B reporter if one exists.

Return type:

None

shutdown_dataloader(dataloader, dl_name=None)[source]

The nnunet dataloader/augmenter uses multiprocessing under the hood, so this method terminates the child processes gracefully.

Parameters:
  • dataloader (DataLoader) – The dataloader to shutdown

  • dl_name (str | None) – A string that identifies the dataloader to shutdown. Used for logging purposes. Defaults to None.

Return type:

None

train_step(input, target)[source]
Given a single batch of input and target data, generate predictions, compute the loss, update parameters, and optionally update metrics if they exist (i.e., backprop on a single batch of data). Assumes self.model is already in train mode.

Overrides the parent method to include mixed precision training (autocasting and corresponding gradient scaling) as per the original nnUNetTrainer.

Parameters:
  • input (TorchInputType) – The input to be fed into the model.

  • target (TorchTargetType) – The target corresponding to the input.

Returns:

The losses object from the train step along with a dictionary of any predictions produced by the model.

Return type:

tuple[TrainingLosses, TorchPredType]
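
Example: a sketch of the autocast plus gradient scaling pattern the description refers to (illustrative, not the library's exact code).

    import torch

    scaler = torch.cuda.amp.GradScaler()

    def amp_train_step(model, criterion, optimizer, inputs, target):
        optimizer.zero_grad()
        # Forward pass and loss under autocast (mixed precision)
        with torch.autocast(device_type="cuda"):
            loss = criterion(model(inputs), target)
        # Scale the loss to avoid gradient underflow, then step via the scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        return loss.detach()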

transform_gradients(losses)[source]

Apply the gradient clipping performed by the default nnunet trainer. This is the default behavior for nnunetv2 2.5.1.

Parameters:

losses (TrainingLosses) – The losses object from the train step.

Return type:

None
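
Example: a sketch of the clipping behavior (nnunetv2 2.5.1 clips gradients to a max norm of 12 by default; the helper below is hypothetical).

    import torch

    def clip_gradients(model: torch.nn.Module, max_grad_norm: float = 12.0) -> None:
        # clip_grad_norm_ rescales gradients in place if their total norm exceeds max_grad_norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)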

update_before_train(current_server_round)[source]

Hook method called before training with the number of current server rounds performed. NOTE: This method is called immediately AFTER the aggregated parameters are received from the server. For example, used by MOON and FENDA to save global modules after aggregation.

Parameters:

current_server_round (int) – The current server round number.

Return type:

None

update_metric_manager(preds, target, metric_manager)[source]

Update the metrics with preds and target. Overridden because inputs might need to be manipulated due to deep supervision.

Parameters:
  • preds (TorchPredType) – Dictionary of model outputs

  • target (TorchTargetType) – The targets generated by the dataloader with which to evaluate the preds

  • metric_manager (MetricManager) – The metric manager to update

Return type:

None