fl4health.clients.nnunet_client module
- class NnunetClient(device, dataset_id, fold, data_identifier=None, plans_identifier=None, compile=True, always_preprocess=False, max_grad_norm=12, n_dataload_processes=None, verbose=True, metrics=None, progress_bar=False, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, client_name=None, nnunet_trainer_class=<class 'nnunetv2.training.nnUNetTrainer.nnUNetTrainer.nnUNetTrainer'>, nnunet_trainer_class_kwargs={})
Bases: BasicClient
- __init__(device, dataset_id, fold, data_identifier=None, plans_identifier=None, compile=True, always_preprocess=False, max_grad_norm=12, n_dataload_processes=None, verbose=True, metrics=None, progress_bar=False, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, client_name=None, nnunet_trainer_class=<class 'nnunetv2.training.nnUNetTrainer.nnUNetTrainer.nnUNetTrainer'>, nnunet_trainer_class_kwargs={})
A client for training nnunet models. Requires the nnunet environment variables to be set. Also requires the following additional keys in the config sent from the server:
“nnunet_plans”: (serialized dict)
“nnunet_config”: (str)
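As a rough illustration of the two server-config keys listed above, the sketch below builds a config dict and reads the keys back. It assumes the plans dict is serialized with Python's pickle module; the serialization actually used by the server is not stated here, and the helper read_nnunet_keys and the trimmed plans contents are hypothetical.

```python
from __future__ import annotations

import pickle
from typing import Any

# Hypothetical, heavily trimmed plans dict; real nnunet plans files contain many more fields.
source_plans: dict[str, Any] = {
    "plans_name": "nnUNetPlans",
    "dataset_name": "Dataset001_Example",
    "configurations": {"2d": {"batch_size": 12}},
}

# Assumption: the plans dict travels as pickled bytes under "nnunet_plans".
server_config: dict[str, Any] = {
    "nnunet_plans": pickle.dumps(source_plans),
    "nnunet_config": "2d",
}


def read_nnunet_keys(config: dict[str, Any]) -> tuple[dict[str, Any], str]:
    """Hypothetical helper: validate and decode the nnunet-specific config keys."""
    missing = [key for key in ("nnunet_plans", "nnunet_config") if key not in config]
    if missing:
        raise KeyError(f"Server config is missing required keys: {missing}")
    plans = pickle.loads(config["nnunet_plans"])
    return plans, str(config["nnunet_config"])


plans, nnunet_config = read_nnunet_keys(server_config)
print(plans["plans_name"], nnunet_config)  # prints: nnUNetPlans 2d
```

Failing fast on missing keys makes a misconfigured server visible at setup time rather than partway through training.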
- Parameters:
device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often “cpu”, “cuda” or “mps”.
dataset_id (int) – The nnunet dataset id for the local client dataset to use for training and validation.
fold (int | str) – Which fold of the local client dataset to use for validation. nnunet defaults to 5 folds (0 to 4). Can also be set to “all” to use all the data for both training and validation.
data_identifier (str | None, optional) – The nnunet data identifier prefix to use. The final data identifier will be {data_identifier}_config, where “config” is the nnunet config (e.g. 2d, 3d_fullres, etc.). If preprocessed data already exists, this can be used to specify which preprocessed data to use. By default, the plans_identifier is used as the data_identifier.
plans_identifier (str | None, optional) – Specify what to save the client’s local copy of the plans file as. The client makes a local modified copy of the global source plans file sent by the server. If left as the default None, the plans identifier will be set as “FL-plansname-000local”, where 000 is the dataset_id and plansname is the “plans_name” value of the source plans file. The original plans will be saved under the source_plans_name key in the modified plans file.
compile (bool, optional) – If set to True, the client will Just-In-Time (JIT) compile the nnUNet model and perform optimizations at the start of training. This can significantly reduce the runtime for nnUNet models, especially for larger models or long-running jobs, at the cost of some extra time and computation during the initial step. It is recommended to keep this option enabled. Defaults to True.
always_preprocess (bool, optional) – If True, will preprocess the local client dataset even if the preprocessed data already seems to exist. The existence of the preprocessed data is determined by matching the provided data_identifier with that of the preprocessed data already on the client. Defaults to False.
max_grad_norm (float, optional) – The maximum gradient norm to use for gradient clipping. Defaults to 12, which is the nnunetv2 2.5.1 default.
n_dataload_processes (int | None, optional) – The number of processes to spawn for each nnunet dataloader. If left as None, the nnunetv2 version 2.5.1 defaults for each config are used.
verbose (bool, optional) – If True, the client will log some extra INFO logs. Defaults to False unless the log level is DEBUG or lower.
metrics (Sequence[Metric], optional) – Metrics to be computed based on the labels and predictions of the client model. Defaults to None.
progress_bar (bool, optional) – Whether or not to print a progress bar to stdout for training. Defaults to False.
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components, will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter], optional) – A sequence of FL4Health reporters which the client should send data to.
nnunet_trainer_class (type[nnUNetTrainer]) – The nnUNetTrainer class used to construct the trainer. Useful for passing a custom nnUNetTrainer subclass. Defaults to the standard nnUNetTrainer class. Must match the nnunet_trainer_class passed to the NnunetServer.
nnunet_trainer_class_kwargs (dict[str, Any]) – Additional kwargs to pass to nnunet_trainer_class. Defaults to an empty dictionary.
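The identifier defaults described for plans_identifier and data_identifier can be sketched as two small functions. Both function names are hypothetical, and zero-padding the dataset id to three digits is an assumption based on the “000” placeholder and nnunet's usual DatasetXXX naming.

```python
from __future__ import annotations


def default_plans_identifier(plans_name: str, dataset_id: int) -> str:
    # "FL-plansname-000local", assuming a three-digit zero-padded dataset id.
    return f"FL-{plans_name}-{dataset_id:03d}local"


def final_data_identifier(config: str, data_identifier: str | None, plans_identifier: str) -> str:
    # The data identifier prefix falls back to the plans identifier,
    # then the nnunet config name (e.g. "2d", "3d_fullres") is appended.
    prefix = data_identifier if data_identifier is not None else plans_identifier
    return f"{prefix}_{config}"


plans_id = default_plans_identifier("nnUNetPlans", 1)
print(final_data_identifier("3d_fullres", None, plans_id))  # prints: FL-nnUNetPlans-001local_3d_fullres
```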
- compute_loss_and_additional_losses(preds, features, target)
Checks the pred and target types and computes the loss. If the device type is cuda, the loss is computed in mixed precision.
- Parameters:
preds (TorchPredType) – Dictionary of model output tensors indexed by name
features (dict[str, torch.Tensor]) – Not used by this subclass
target (TorchTargetType) – The targets to evaluate the predictions with. If multiple prediction tensors are given, target must be a dictionary with the same number of tensors.
- Returns:
A tuple where the first element is the loss and the second element is an optional additional loss
- Return type:
- create_plans(config)
Modifies the provided plans file to work with the local client dataset and then saves it to disk. Requires the local dataset_fingerprint.json to exist, as well as the local dataset_name, plans_name, data_identifier and dataset_json.
The following fields are modified:
plans_name
dataset_name
original_median_shape_after_transp
original_median_spacing_after_transp
configurations.{config}.data_identifier
configurations.{config}.batch_size
configurations.{config}.median_image_size_in_voxels
foreground_intensity_properties_per_channel
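A minimal sketch of the modification step, assuming the new local values have already been computed. The function localize_plans and its local_values argument are hypothetical, and how each value is derived from dataset_fingerprint.json is deliberately not shown. The point illustrated is that only the listed fields change, and that the server's source plans are copied rather than mutated.

```python
from __future__ import annotations

import copy
from typing import Any

# Top-level and per-configuration fields listed above as being modified.
TOP_LEVEL_FIELDS = (
    "plans_name",
    "dataset_name",
    "original_median_shape_after_transp",
    "original_median_spacing_after_transp",
    "foreground_intensity_properties_per_channel",
)
CONFIG_FIELDS = ("data_identifier", "batch_size", "median_image_size_in_voxels")


def localize_plans(source_plans: dict[str, Any], config: str, local_values: dict[str, Any]) -> dict[str, Any]:
    """Return a modified copy of source_plans (hypothetical helper)."""
    plans = copy.deepcopy(source_plans)  # leave the server's global plans untouched
    for key in TOP_LEVEL_FIELDS:
        plans[key] = local_values[key]
    config_plans = plans["configurations"][config]  # e.g. plans["configurations"]["2d"]
    for key in CONFIG_FIELDS:
        config_plans[key] = local_values[key]
    return plans
```

Every other field, such as the patch size or network architecture, is carried over unchanged from the source plans in this sketch.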
- empty_cache()
Checks the torch device and empties the cache before training to optimize VRAM usage.
- Return type:
- get_client_specific_logs(current_round, current_epoch, logging_mode)
This function can be overridden to provide any client specific information to the basic client logging. For example, perhaps a client uses an LR scheduler and wants the LR to be logged each epoch. Called at the beginning and end of each server round or local epoch. Also called at the end of validation/testing.
- Parameters:
current_round (int | None) – The current FL round (i.e., current server round).
current_epoch (int | None) – The current epoch of local training.
logging_mode (LoggingMode) – The logging mode (Training, Validation, or Testing).
- Returns:
Tuple of:
A string to append to the header log string that typically announces the current server round and current epoch at the beginning of each round or local epoch.
A list of tuples where the first element is a LogLevel as defined in fl4health.utils.typing and the second element is a string message. Each item in the list will be logged at the end of each server round or epoch. Elements will also be logged at the end of validation/testing.
- Return type:
- get_client_specific_reports()
This function can be overridden by an inheriting client to report additional client specific information to the wandb_reporter.
- get_criterion(config)
User defined method that returns the PyTorch loss used to train the model.
- Parameters:
config (Config) – The config from the server.
- Raises:
NotImplementedError – To be defined in child class.
- Return type:
_Loss
- get_data_loaders(**kwargs: Any) → Any
User defined method that returns a PyTorch Train DataLoader and a PyTorch Validation DataLoader.
- Parameters:
config (Config) – The config from the server.
- Returns:
Tuple of length 2. The client train and validation loader.
- Return type:
tuple[DataLoader, …]
- Raises:
NotImplementedError – To be defined in child class.
- get_lr_scheduler(optimizer_key, config)
Creates an LR scheduler similar to the nnunet default, except that max steps is set to the total number of training steps and the LR is updated every step. The initial and final LR are the same as in nnunet. The difference is that nnunet sets max steps to the number of “epochs”, where an “epoch” is defined as exactly 250 steps, so nnunet only updates the LR every 250 steps. Override this method to set your own LR scheduler.
- Parameters:
config (Config) – The server config. This method will look for the
- Returns:
The default nnunet LR Scheduler for nnunetv2 2.5.1
- Return type:
_LRScheduler
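To make the per-step versus per-epoch difference concrete, the sketch below compares the two update patterns in plain Python. It assumes a polynomial decay of the form initial_lr * (1 - step / max_steps) ** 0.9 with an initial LR of 1e-2, which matches nnunetv2's PolyLRScheduler and nnUNetTrainer defaults; the function names are illustrative only.

```python
from __future__ import annotations


def poly_lr(step: int, max_steps: int, initial_lr: float = 1e-2, exponent: float = 0.9) -> float:
    """Polynomial decay: initial_lr * (1 - step / max_steps) ** exponent."""
    return initial_lr * (1 - step / max_steps) ** exponent


def per_step_schedule(total_steps: int, initial_lr: float = 1e-2) -> list[float]:
    # Client-style: max steps = total training steps, LR updated every step.
    return [poly_lr(s, total_steps, initial_lr) for s in range(total_steps)]


def per_epoch_schedule(total_steps: int, steps_per_epoch: int = 250, initial_lr: float = 1e-2) -> list[float]:
    # nnunet-style: max steps = number of 250-step "epochs", LR updated once per epoch.
    n_epochs = total_steps // steps_per_epoch
    return [poly_lr(s // steps_per_epoch, n_epochs, initial_lr) for s in range(total_steps)]


fine = per_step_schedule(1000)
coarse = per_epoch_schedule(1000)
```

The two schedules agree at every 250-step boundary. Between boundaries, the per-step schedule decays smoothly, while the nnunet-style schedule stays flat until the next “epoch”.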
- get_model(config)
User defined method that returns the PyTorch model.
- Parameters:
config (Config) – The config from the server.
- Returns:
The client model.
- Return type:
nn.Module
- Raises:
NotImplementedError – To be defined in child class.
- get_optimizer(config)
Method to be defined by the user that returns the PyTorch optimizer used to train models locally. The return value can be a single torch optimizer or a dictionary of string and torch optimizer. Returning multiple optimizers is useful in methods like APFL, which has a different optimizer for the local and global models.
- Parameters:
config (Config) – The config sent from the server.
- Returns:
An optimizer or dictionary of optimizers to train model.
- Return type:
- Raises:
NotImplementedError – To be defined in child class.
- get_properties(**kwargs: Any) → Any
Returns the properties (train and validation dataset sample counts) of the client.
- mask_data(pred, target)
Masks the pred and target tensors according to the nnunet ignore_label. The number of classes in the input tensors should be at least 3, corresponding to 2 classes for binary segmentation and 1 class which is the ignore class specified by the ignore label. See the nnunet documentation for more info: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/ignore_label.md
- Parameters:
pred (torch.Tensor) – The one hot encoded predicted segmentation maps with shape
(batch, classes, x, y(, z))
target (torch.Tensor) – The ground truth segmentation map with shape
(batch, classes, x, y(, z))
- Returns:
Tuple of:
torch.Tensor: The masked one hot encoded predicted segmentation maps
torch.Tensor: The masked target segmentation maps
- Return type:
tuple[torch.Tensor, torch.Tensor]
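One plausible way to apply an ignore class to one-hot maps is sketched below. It is not necessarily what mask_data does internally. It uses NumPy arrays as stand-ins for torch tensors and assumes the ignore class is the last channel, since nnunet requires the ignore label to be the highest label value. Voxels marked as ignored are zeroed in both maps, and the ignore channel is dropped.

```python
from __future__ import annotations

import numpy as np


def mask_ignore_channel(pred: np.ndarray, target: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper; inputs have shape (batch, classes, x, y(, z))."""
    # 1 where the voxel is NOT labelled as ignore, shape (batch, 1, x, y(, z)).
    keep = 1.0 - target[:, -1:]
    # Drop the ignore channel and zero out ignored voxels (broadcast over classes).
    masked_pred = pred[:, :-1] * keep
    masked_target = target[:, :-1] * keep
    return masked_pred, masked_target
```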
- predict(input)
Generate model outputs. Overridden because nnunet models output lists when deep supervision is on, so the output has to be reformatted into dicts.
Additionally, if the device type is cuda, the model outputs are computed in mixed precision.
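The list-to-dict reformatting can be sketched as below. With deep supervision on, nnunet returns one output per resolution, ordered from highest to lowest. The string-index keys and the "prediction" key for a single output are hypothetical choices for illustration, not necessarily the keys used by the client.

```python
from __future__ import annotations

from typing import Any


def outputs_to_pred_dict(outputs: Any) -> dict[str, Any]:
    """Hypothetical helper: wrap raw model outputs in a dict keyed by name."""
    if isinstance(outputs, (list, tuple)):
        # Deep supervision: "0" is the full-resolution head, "1", "2", ... are coarser heads.
        return {str(i): out for i, out in enumerate(outputs)}
    # Deep supervision off: a single output tensor.
    return {"prediction": outputs}
```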
- setup_client(**kwargs: Any) → Any
Set dataloaders, optimizers, parameter exchangers and other attributes derived from these. Then set initialized attribute to True.
- Parameters:
config (Config) – The config from the server.
- Return type:
- shutdown()
Shuts down the client. This involves shutting down the W&B reporter if one exists.
- Return type:
- shutdown_dataloader(dataloader, dl_name=None)
The nnunet dataloader/augmenter uses multiprocessing under the hood, so this method terminates the child processes gracefully.
- train_step(input, target)
Given a single batch of input and target data, generate predictions, compute loss, update parameters and optionally update metrics if they exist (i.e. backprop on a single batch of data). Assumes self.model is in train mode already.
Overrides the parent to include mixed precision training (autocasting and corresponding gradient scaling) as per the original nnUNetTrainer.
- Parameters:
input (TorchInputType) – The input to be fed into the model.
target (TorchTargetType) – The target corresponding to the input.
- Returns:
The losses object from the train step along with a dictionary of any predictions produced by the model.
- Return type:
Tuple[TrainingLosses, TorchPredType]
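The mixed precision step described above, combined with the gradient clipping from transform_gradients, follows the standard PyTorch AMP pattern. The sketch below is a minimal standalone version of that pattern, not the client's actual code. The function name amp_train_step is hypothetical, and autocasting only on CUDA mirrors the behavior described for this client.

```python
import torch
from torch import nn


def amp_train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: torch.cuda.amp.GradScaler,
    criterion: nn.Module,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    max_grad_norm: float = 12.0,
) -> torch.Tensor:
    device_type = inputs.device.type
    optimizer.zero_grad(set_to_none=True)
    # Autocast only on CUDA; on other devices the step runs in full precision.
    with torch.autocast(device_type=device_type, enabled=(device_type == "cuda")):
        preds = model(inputs)
        loss = criterion(preds, targets)
    # Scale the loss so small fp16 gradients do not underflow (no-op if the scaler is disabled).
    scaler.scale(loss).backward()
    # Unscale first so that clipping sees the true gradient norm.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # step() skips the update if inf/NaN gradients were found; update() adjusts the scale.
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()
```

A typical setup is scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available()). Newer PyTorch releases also expose torch.amp.GradScaler. Unscaling before clipping matters because otherwise the max_grad_norm threshold would be compared against scaled gradients.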
- transform_gradients(losses)
Apply the gradient clipping performed by the default nnunet trainer. This is the default behavior for nnunet 2.5.1.
- Parameters:
losses (TrainingLosses) – Not used for this transformation.
- Return type:
- update_before_train(current_server_round)
Hook method called before training with the number of current server rounds performed. NOTE: This method is called immediately AFTER the aggregated parameters are received from the server. For example, used by MOON and FENDA to save global modules after aggregation.
- update_metric_manager(preds, target, metric_manager)
Update the metrics with preds and target. Overridden because we might need to manipulate inputs due to deep supervision.
- Parameters:
preds (TorchPredType) – dictionary of model outputs
target (TorchTargetType) – the targets generated by the dataloader to evaluate the preds with
metric_manager (MetricManager) – the metric manager to update
- Return type: