fl4health.clients.nnunet_client module
- class NnunetClient(device, dataset_id, fold, data_identifier=None, plans_identifier=None, compile=True, always_preprocess=False, max_grad_norm=12, n_dataload_processes=None, verbose=True, metrics=None, progress_bar=False, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, client_name=None, nnunet_trainer_class=<class 'nnunetv2.training.nnUNetTrainer.nnUNetTrainer.nnUNetTrainer'>, nnunet_trainer_class_kwargs={})
Bases: BasicClient
- __init__(device, dataset_id, fold, data_identifier=None, plans_identifier=None, compile=True, always_preprocess=False, max_grad_norm=12, n_dataload_processes=None, verbose=True, metrics=None, progress_bar=False, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, client_name=None, nnunet_trainer_class=<class 'nnunetv2.training.nnUNetTrainer.nnUNetTrainer.nnUNetTrainer'>, nnunet_trainer_class_kwargs={})
A client for training nnunet models. Requires the nnunet environment variables (nnUNet_raw, nnUNet_preprocessed and nnUNet_results) to be set. Also requires the following additional keys in the config sent from the server:
‘nnunet_plans’: (serialized dict)
‘nnunet_config’: (str)
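As an illustration of the two required config keys, the sketch below builds such a config on the server side and reads it back on the client. It assumes the plans dict is serialized with pickle (the text above only says "serialized dict"), and build_nnunet_config is a hypothetical helper, not part of fl4health.

```python
import pickle
from typing import Any


def build_nnunet_config(plans: dict[str, Any], nnunet_config: str) -> dict[str, Any]:
    """Hypothetical helper: package nnunet plans for a client config.

    Pickle serialization is an assumption; the docs only state that
    'nnunet_plans' is a serialized dict.
    """
    return {
        "nnunet_plans": pickle.dumps(plans),  # bytes are a valid Flower config Scalar
        "nnunet_config": nnunet_config,  # e.g. "2d" or "3d_fullres"
    }


# Client side: recover the plans dict from the received config.
config = build_nnunet_config({"plans_name": "nnUNetPlans"}, "3d_fullres")
plans = pickle.loads(config["nnunet_plans"])
print(plans["plans_name"], config["nnunet_config"])
```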
- Parameters:
device (torch.device) – Device indicator for where to send the model, batches, labels, etc. Often ‘cpu’, ‘cuda’ or ‘mps’.
dataset_id (int) – The nnunet dataset id for the local client dataset to use for training and validation.
fold (int | str) – Which fold of the local client dataset to use for validation. nnunet defaults to 5 folds (0 to 4). Can also be set to ‘all’ to use all the data for both training and validation.
data_identifier (str | None, optional) – The nnunet data identifier prefix to use. The final data identifier will be {data_identifier}_config, where ‘config’ is the nnunet config (e.g. 2d, 3d_fullres, etc.). If preprocessed data already exists, this can be used to specify which preprocessed data to use. By default, the plans_identifier is used as the data_identifier.
plans_identifier (str | None, optional) – Specifies what to save the client’s local copy of the plans file as. The client makes a local modified copy of the global source plans file sent by the server. If left as the default None, the plans identifier will be set as ‘FL-plansname-000local’, where 000 is the dataset_id and plansname is the ‘plans_name’ value of the source plans file. The original plans will be saved under the source_plans_name key in the modified plans file.
compile (bool, optional) – If True, the client will JIT-compile the PyTorch model. Compiling adds some overhead at the beginning of training, but results in faster training. Defaults to True.
always_preprocess (bool, optional) – If True, will preprocess the local client dataset even if the preprocessed data already seems to exist. Defaults to False. The existence of the preprocessed data is determined by matching the provided data_identifier with that of the preprocessed data already on the client.
max_grad_norm (float, optional) – The maximum gradient norm to use for gradient clipping. Defaults to 12, which is the nnunetv2 2.5.1 default.
n_dataload_processes (int | None, optional) – The number of processes to spawn for each nnunet dataloader. If left as None, the nnunetv2 version 2.5.1 defaults for each config are used.
verbose (bool, optional) – If True, the client will log some extra INFO logs. Defaults to True.
metrics (Sequence[Metric], optional) – Metrics to be computed based on the labels and predictions of the client model. Defaults to None.
progress_bar (bool, optional) – Whether to print a progress bar to stdout during training. Defaults to False.
loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.
checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.
reporters (Sequence[BaseReporter], optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.
nnunet_trainer_class (type[nnUNetTrainer]) – An nnUNetTrainer constructor. Useful for passing a custom nnUNetTrainer subclass. Defaults to the standard nnUNetTrainer class. Must match the nnunet_trainer_class passed to the NnunetServer.
nnunet_trainer_class_kwargs (dict[str, Any]) – Additional kwargs to pass to nnunet_trainer_class. Defaults to empty dictionary.
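The default naming rules described above for data_identifier and plans_identifier can be sketched as two small functions. Both helpers are hypothetical (not fl4health functions), and zero-padding the dataset id to three digits is an assumption based on the ‘000’ placeholder.

```python
from __future__ import annotations


def default_plans_identifier(plans_name: str, dataset_id: int) -> str:
    """Hypothetical helper mirroring the documented 'FL-plansname-000local' pattern."""
    # Zero-padding the id to three digits (as in nnunet dataset names) is an assumption.
    return f"FL-{plans_name}-{dataset_id:03d}local"


def final_data_identifier(
    nnunet_config: str, data_identifier: str | None, plans_identifier: str
) -> str:
    """Hypothetical helper: '{data_identifier}_{config}', falling back to the plans identifier."""
    prefix = data_identifier if data_identifier is not None else plans_identifier
    return f"{prefix}_{nnunet_config}"


plans_id = default_plans_identifier("nnUNetPlans", 5)
print(plans_id)  # FL-nnUNetPlans-005local
print(final_data_identifier("3d_fullres", None, plans_id))  # FL-nnUNetPlans-005local_3d_fullres
```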
- compute_loss_and_additional_losses(preds, features, target)
- Checks the pred and target types and computes the loss.
If the device type is cuda, the loss is computed in mixed precision.
- Parameters:
preds (TorchPredType) – Dictionary of model output tensors indexed by name
features (dict[str, torch.Tensor]) – Not used by this subclass
target (TorchTargetType) – The targets to evaluate the predictions with. If multiple prediction tensors are given, target must be a dictionary with the same number of tensors
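When deep supervision is on, preds and target each hold one tensor per output resolution. A minimal sketch of combining them into a single loss, modeled on nnunetv2's default deep supervision weighting (weights halve at each lower resolution, the lowest resolution gets weight 0, and the weights are normalized to sum to 1). Cross-entropy stands in for nnunet's Dice plus cross-entropy loss, and the ‘0’, ‘1’, and so on key naming is an assumption.

```python
import torch
import torch.nn.functional as F


def deep_supervision_weights(n_outputs: int) -> list[float]:
    """Weights modeled on nnunetv2's defaults: halve per level, zero the last, normalize."""
    if n_outputs == 1:
        return [1.0]
    weights = [1 / (2**i) for i in range(n_outputs)]
    weights[-1] = 0.0  # the lowest-resolution output does not contribute
    total = sum(weights)
    return [w / total for w in weights]


def deep_supervision_loss(
    preds: dict[str, torch.Tensor], targets: dict[str, torch.Tensor]
) -> torch.Tensor:
    """Weighted sum of per-output losses; keys '0', '1', ... are an assumed naming scheme."""
    if preds.keys() != targets.keys():
        raise ValueError("preds and targets must contain the same keys")
    keys = sorted(preds, key=int)  # '0' is assumed to be the highest-resolution output
    weights = deep_supervision_weights(len(keys))
    # Cross-entropy stands in for nnunet's default Dice + CE loss.
    losses = [w * F.cross_entropy(preds[k], targets[k]) for w, k in zip(weights, keys)]
    return torch.stack(losses).sum()
```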
- Returns:
- A tuple where the first element is the loss and the second element is an optional additional loss.
- Return type:
- create_plans(config)
Modifies the provided plans file to work with the local client dataset and then saves it to disk. Requires the local dataset_fingerprint.json to exist and the local dataset_name, plans_name, data_identifier and dataset_json to be set.
- The following fields are modified:
plans_name
dataset_name
original_median_shape_after_transp
original_median_spacing_after_transp
configurations.{config}.data_identifier
configurations.{config}.batch_size
configurations.{config}.median_image_size_in_voxels
foreground_intensity_properties_per_channel
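A simplified sketch of the renaming part of create_plans, assuming the plans file is loaded as a plain dict. localize_plans and the example dataset names are hypothetical; the fingerprint-derived fields in the list above (batch size, median shape and spacing, intensity properties) are deliberately not recomputed, and saving the result to disk is omitted.

```python
import copy
from typing import Any


def localize_plans(
    source_plans: dict[str, Any],
    plans_identifier: str,
    dataset_name: str,
    nnunet_config: str,
    data_identifier: str,
) -> dict[str, Any]:
    """Hypothetical sketch: copy the global plans and rewrite the name-like fields.

    Fields derived from the local dataset fingerprint (batch size, median
    shapes/spacings, intensity properties) are not recomputed here.
    """
    plans = copy.deepcopy(source_plans)  # leave the server's copy untouched
    plans["plans_name"] = plans_identifier
    plans["dataset_name"] = dataset_name
    config = plans["configurations"][nnunet_config]
    config["data_identifier"] = f"{data_identifier}_{nnunet_config}"
    return plans


source = {
    "plans_name": "nnUNetPlans",
    "dataset_name": "Dataset001_Global",  # hypothetical name
    "configurations": {"3d_fullres": {"data_identifier": "nnUNetPlans_3d_fullres"}},
}
local = localize_plans(
    source, "FL-nnUNetPlans-005local", "Dataset005_Local", "3d_fullres", "FL-nnUNetPlans-005local"
)
```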
- empty_cache()
Checks the torch device type and empties the corresponding cache before training to optimize VRAM usage.
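A minimal sketch of a device-aware cache flush, assuming it dispatches on torch.device.type; empty_device_cache is a hypothetical name and may differ from what the client actually does.

```python
import torch


def empty_device_cache(device: torch.device) -> None:
    """Sketch: release cached, unused allocator memory on the given device."""
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps":
        torch.mps.empty_cache()  # available in PyTorch 2.0+
    # Nothing to do for other device types (e.g. "cpu") in this sketch.


empty_device_cache(torch.device("cpu"))
```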
- Return type:
None
- get_client_specific_logs(current_round, current_epoch, logging_mode)
This function can be overridden to provide any client specific information to the basic client logging. For example, perhaps a client uses an LR scheduler and wants the LR to be logged each epoch. Called at the beginning and end of each server round or local epoch. Also called at the end of validation/testing.
- Parameters:
current_round (int | None) – The current FL round (i.e., current server round).
current_epoch (int | None) – The current epoch of local training.
logging_mode (LoggingMode) – The logging mode (Training, Validation, or Testing).
- Returns:
- str | None: A string to append to the header log string that typically announces the current server round and current epoch at the beginning of each round or local epoch.
- list[tuple[LogLevel, str]] | None: A list of tuples where the first element is a LogLevel as defined in fl4health.utils.typing and the second element is a string message. Each item in the list will be logged at the end of each server round or epoch. Elements will also be logged at the end of validation/testing.
- Return type:
tuple[str | None, list[tuple[LogLevel, str]] | None]
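A sketch of the kind of values an override might return, for example to append the current LR to the header as suggested above. It is written as a free function for brevity; the LogLevel enum here is a hypothetical stand-in for fl4health.utils.typing.LogLevel, whose actual definition may differ.

```python
from __future__ import annotations

import logging
from enum import Enum


class LogLevel(Enum):
    """Hypothetical stand-in for fl4health.utils.typing.LogLevel."""

    INFO = logging.INFO
    WARNING = logging.WARNING


def client_specific_logs(
    current_lr: float, current_epoch: int | None
) -> tuple[str | None, list[tuple[LogLevel, str]] | None]:
    """Sketch of what an override might return: a header suffix and extra log lines."""
    header = f" Current LR: {current_lr:.2e}"  # appended to the round/epoch header
    extra: list[tuple[LogLevel, str]] = []
    if current_epoch == 0:
        extra.append((LogLevel.INFO, "Starting the first local epoch of this round"))
    return header, extra or None  # None when there is nothing extra to log
```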
- get_client_specific_reports()
This function can be overridden by an inheriting client to report additional client-specific information to the wandb_reporter.
- get_criterion(config)
User-defined method that returns a PyTorch loss to train the model.
- Parameters:
config (Config) – The config from the server.
- Raises:
NotImplementedError – To be defined in child class.
- Return type:
_Loss
- get_data_loaders(**kwargs: Any) → Any
User-defined method that returns a PyTorch train DataLoader and a PyTorch validation DataLoader.
- Parameters:
config (Config) – The config from the server.
- Returns:
Tuple of length 2: the client train and validation loaders.
- Return type:
tuple[DataLoader, …]
- Raises:
NotImplementedError – To be defined in child class.
- get_lr_scheduler(optimizer_key, config)
Creates an LR scheduler similar to the nnunet default, except that max steps is set to the total number of training steps and the LR is updated every step. The initial and final LRs are the same as in nnunet. The difference is that nnunet sets max steps to the number of ‘epochs’, where an ‘epoch’ is defined as exactly 250 steps, so nnunet only updates the LR every 250 steps. Override this method to set your own LR scheduler.
- Parameters:
config (Config) – The server config. This method will look for the
- Returns:
The default nnunet LR Scheduler for nnunetv2 2.5.1
- Return type:
_LRScheduler
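To make the difference concrete, the sketch below builds a per-step polynomial decay with torch's LambdaLR and compares it with an LR that only changes once per 250-step nnunet ‘epoch’. It assumes nnunetv2's poly schedule, lr = initial_lr * (1 - step / max_steps) ** 0.9, with an initial LR of 1e-2; the helper names are hypothetical and this is not the client's actual scheduler class.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR


def per_step_poly_scheduler(
    optimizer: torch.optim.Optimizer, total_steps: int, exponent: float = 0.9
) -> LambdaLR:
    """Poly decay updated every step: lr = initial_lr * (1 - step / total_steps) ** exponent."""
    return LambdaLR(
        optimizer, lambda step: (1 - min(step, total_steps) / total_steps) ** exponent
    )


def nnunet_epoch_lr(
    initial_lr: float, step: int, total_steps: int, steps_per_epoch: int = 250, exponent: float = 0.9
) -> float:
    """The same decay when the LR only changes once per 250-step nnunet 'epoch'."""
    n_epochs = total_steps // steps_per_epoch
    epoch = step // steps_per_epoch
    return initial_lr * (1 - epoch / n_epochs) ** exponent


model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)  # nnunetv2's default initial LR
scheduler = per_step_poly_scheduler(optimizer, total_steps=1000)
for _ in range(249):
    optimizer.step()  # a real loop would compute a loss and backprop first
    scheduler.step()
print(optimizer.param_groups[0]["lr"], nnunet_epoch_lr(1e-2, 249, 1000))
```

After 249 steps the per-step schedule has already decayed, while the epoch-based LR is still at its initial value until step 250.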
- get_model(config)
User-defined method that returns a PyTorch model.
- Parameters:
config (Config) – The config from the server.
- Returns:
The client model.
- Return type:
nn.Module
- Raises:
NotImplementedError – To be defined in child class.
- get_optimizer(config)
Method to be defined by the user that returns the PyTorch optimizer used to train models locally. The return value can be a single torch optimizer or a dictionary mapping strings to torch optimizers. Returning multiple optimizers is useful in methods like APFL, which uses a different optimizer for the local and global models.
- Parameters:
config (Config) – The config sent from the server.
- Returns:
An optimizer or dictionary of optimizers to train model.
- Return type:
Optimizer | dict[str, Optimizer]
- Raises:
NotImplementedError – To be defined in child class.
- get_properties(**kwargs: Any) → Any
Returns the properties of the client (train and validation dataset sample counts).
- mask_data(pred, target)
Masks the pred and target tensors according to the nnunet ignore_label. The number of classes in the input tensors should be at least 3, corresponding to 2 classes for binary segmentation and 1 class for the ignore label. See the nnunet documentation for more info: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/ignore_label.md
- Parameters:
pred (torch.Tensor) – The one hot encoded predicted segmentation maps with shape (batch, classes, x, y(, z))
target (torch.Tensor) – The ground truth segmentation map with shape (batch, classes, x, y(, z))
- Returns:
torch.Tensor: The masked one hot encoded predicted segmentation maps.
torch.Tensor: The masked target segmentation maps.
- Return type:
tuple[torch.Tensor, torch.Tensor]
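A minimal sketch of ignore-label masking on one-hot tensors. It assumes the ignore class is the last channel (nnunet requires the ignore label to be the highest label value) and that the ignore channel is dropped from the returned tensors. mask_ignore_label is a hypothetical name, and the client's exact handling may differ.

```python
import torch


def mask_ignore_label(
    pred: torch.Tensor, target: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sketch: zero out ignored voxels and drop the ignore channel.

    Assumes one-hot tensors of shape (batch, classes, x, y(, z)) whose LAST
    channel is the ignore class.
    """
    if pred.shape[1] < 3 or target.shape[1] < 3:
        raise ValueError("Expected at least 3 channels, including the ignore channel")
    keep = 1 - target[:, -1:]  # 1 where a voxel is NOT ignored, broadcast over classes
    return pred[:, :-1] * keep, target[:, :-1] * keep
```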
- predict(input)
- Generates model outputs. Overridden because nnunet models output lists when deep supervision is on, so the output has to be reformatted into dicts. If the device type is cuda, the outputs are computed in mixed precision.
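The list-to-dict reformatting can be sketched as follows. The key scheme (‘0’, ‘1’, and so on for deep supervision outputs, ‘prediction’ for a single output) is an assumption for illustration, as are the helper names and the toy model.

```python
from __future__ import annotations

import contextlib

import torch
from torch import nn


def outputs_to_dict(outputs: torch.Tensor | list[torch.Tensor]) -> dict[str, torch.Tensor]:
    """Hypothetical key scheme: one key per deep supervision output, '0' first."""
    if isinstance(outputs, torch.Tensor):
        return {"prediction": outputs}  # single output (deep supervision off)
    return {str(i): out for i, out in enumerate(outputs)}


def predict_sketch(model: nn.Module, x: torch.Tensor, device: torch.device) -> dict[str, torch.Tensor]:
    # Mixed precision only on CUDA, mirroring the behavior described above.
    ctx = torch.autocast(device_type="cuda") if device.type == "cuda" else contextlib.nullcontext()
    with ctx:
        outputs = model(x)
    return outputs_to_dict(outputs)


class TwoScaleNet(nn.Module):
    """Toy model returning a list of outputs, like nnunet with deep supervision on."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        full = self.conv(x)
        return [full, nn.functional.avg_pool2d(full, 2)]


preds = predict_sketch(TwoScaleNet(), torch.randn(1, 1, 8, 8), torch.device("cpu"))
print({k: tuple(v.shape) for k, v in preds.items()})
```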
- setup_client(**kwargs: Any) → Any
Sets the dataloaders, optimizers, parameter exchangers and other attributes derived from these, then sets the initialized attribute to True.
- Parameters:
config (Config) – The config from the server.
- Return type:
None
- shutdown()
Shuts down the client. This involves shutting down the W&B reporter, if one exists.
- Return type:
None
- shutdown_dataloader(dataloader, dl_name=None)
The nnunet dataloader/augmenter uses multiprocessing under the hood, so this method terminates its child processes gracefully.
- train_step(input, target)
- Given a single batch of input and target data, generates predictions, computes the loss, updates the parameters and optionally updates metrics if they exist (i.e., backprop on a single batch of data). Assumes self.model is already in train mode.
- Overrides parent to include mixed precision training (autocasting and corresponding gradient scaling)
as per the original nnUNetTrainer.
- Parameters:
input (TorchInputType) – The input to be fed into the model.
target (TorchTargetType) – The target corresponding to the input.
- Returns:
- The losses object from the train step along with
a dictionary of any predictions produced by the model.
- Return type:
Tuple[TrainingLosses, TorchPredType]
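A sketch of a mixed precision training step with gradient clipping, following the pattern used by nnunetv2's nnUNetTrainer: autocast on CUDA only, scale the loss, unscale before clipping to a max norm of 12, then step through the scaler. train_step_sketch is a hypothetical free function, not the client's method, and it returns only the loss rather than a TrainingLosses object.

```python
import contextlib

import torch
from torch import nn


def train_step_sketch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    device: torch.device,
    grad_scaler=None,
    max_grad_norm: float = 12.0,
) -> torch.Tensor:
    optimizer.zero_grad(set_to_none=True)
    # Autocast (mixed precision) only on CUDA; plain precision elsewhere.
    ctx = torch.autocast(device_type="cuda") if device.type == "cuda" else contextlib.nullcontext()
    with ctx:
        loss = loss_fn(model(x), y)
    if grad_scaler is not None:
        grad_scaler.scale(loss).backward()
        grad_scaler.unscale_(optimizer)  # unscale first so clipping sees true gradient norms
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        grad_scaler.step(optimizer)
        grad_scaler.update()
    else:
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
    return loss.detach()
```

On a CUDA device, pass a scaler such as torch.amp.GradScaler("cuda") (PyTorch 2.3+) or torch.cuda.amp.GradScaler(); without a scaler the sketch falls back to a plain backward pass, as on CPU.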
- transform_gradients(losses)
Applies the gradient clipping performed by the default nnunet trainer. This is the default behavior for nnunet 2.5.1.
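In isolation, the clipping step amounts to a call to torch.nn.utils.clip_grad_norm_ over the model parameters, sketched below with the default max norm of 12. If a gradient scaler is used, the gradients are assumed to have been unscaled before this call; clip_gradients is a hypothetical name.

```python
import torch
from torch import nn


def clip_gradients(model: nn.Module, max_grad_norm: float = 12.0) -> torch.Tensor:
    """Sketch: rescale all gradients in place so their total L2 norm is at most max_grad_norm.

    Returns the total norm measured before clipping.
    """
    return nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)


model = nn.Linear(4, 2)
for p in model.parameters():
    p.grad = torch.full_like(p, 100.0)  # artificially large gradients
norm_before = clip_gradients(model)
norm_after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
print(float(norm_before), float(norm_after))
```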
- Return type:
None
- update_before_train(current_server_round)
Hook method called before training, with the current server round as its argument. NOTE: This method is called immediately AFTER the aggregated parameters are received from the server. For example, it is used by MOON and FENDA to save global modules after aggregation.
- update_metric_manager(preds, target, metric_manager)
Updates the metrics with preds and target. Overridden because the inputs might need to be manipulated due to deep supervision.
- Parameters:
preds (TorchPredType) – dictionary of model outputs
target (TorchTargetType) – the targets generated by the dataloader to evaluate the preds with
metric_manager (MetricManager) – the metric manager to update
- Return type:
None