fl4health.utils.client module

check_if_batch_is_empty_and_verify_input(input)

This function checks whether the provided batch (input) is empty. If input is a dictionary of tensors, it first verifies that all tensors have the same length and then checks whether they are empty. NOTE: This function assumes the input is BATCH FIRST.

Parameters:
  • input (TorchInputType) – Input batch. input can be of type torch.Tensor or dict[str, torch.Tensor]; in the latter case, the batch is considered to be empty if all tensors in the dictionary have length zero.

Raises:
  • TypeError – Raised if input is not of type torch.Tensor or dict[str, torch.Tensor].

  • ValueError – Raised if input has type dict[str, torch.Tensor] and not all tensors within the dictionary have the same size.

Returns:

True if input is an empty batch.

Return type:

bool
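A minimal sketch of the check described above, assuming batch-first tensors so that len() returns the batch size. The name is_empty_batch is hypothetical, and this is an illustration of the documented behavior rather than the fl4health source.

```python
from __future__ import annotations

import torch


def is_empty_batch(batch: torch.Tensor | dict[str, torch.Tensor]) -> bool:
    """Sketch: True if a batch-first tensor, or every tensor in a dict, has length zero."""
    if isinstance(batch, torch.Tensor):
        return len(batch) == 0  # batch first: len() is the batch size
    if isinstance(batch, dict):
        lengths = {len(tensor) for tensor in batch.values()}
        # All tensors must share one batch size before emptiness is checked.
        if len(lengths) > 1:
            raise ValueError("Not all tensors in the dictionary have the same size.")
        return all(length == 0 for length in lengths)
    raise TypeError("batch must be a torch.Tensor or dict[str, torch.Tensor].")


print(is_empty_batch(torch.empty(0, 3)))  # True
print(is_empty_batch({"x": torch.zeros(2, 3), "y": torch.zeros(2)}))  # False
```

Collecting the lengths into a set makes the "same size" check a single comparison: more than one distinct length means the dictionary is inconsistent.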

clone_and_freeze_model(model)

Creates a clone of the model with frozen weights to be used in loss calculations so the original model is preserved in its current state.

Parameters:

model (nn.Module) – Model to clone and freeze

Returns:

Cloned and frozen model

Return type:

nn.Module
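A minimal sketch of one way to clone and freeze a model, assuming that "frozen" means a deep copy with gradients disabled and the copy placed in eval mode. The name clone_and_freeze is hypothetical; this is not necessarily how fl4health implements clone_and_freeze_model.

```python
import copy

import torch
from torch import nn


def clone_and_freeze(model: nn.Module) -> nn.Module:
    """Sketch: deep-copy a model and disable gradients on the copy."""
    frozen = copy.deepcopy(model)  # independent copy; the original is untouched
    frozen.eval()  # assumption: the frozen copy is only used for inference
    for param in frozen.parameters():
        param.requires_grad_(False)  # exclude the copy's weights from autograd
    return frozen


model = nn.Linear(2, 1)
frozen = clone_and_freeze(model)
with torch.no_grad():
    model.weight.add_(1.0)  # update only the original
print(torch.equal(model.weight, frozen.weight))  # False
```

Because the copy shares no parameters with the original, later updates to model leave frozen unchanged, which is what allows it to act as a fixed reference in a loss term.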

fold_loss_dict_into_metrics(metrics, loss_dict, logging_mode)
Return type:

None

maybe_progress_bar(iterable, display_progress_bar)

Used to print progress bars during client training and validation. If display_progress_bar is False, returns the original input iterable without modifying it.

Parameters:
  • iterable (Iterable) – The iterable to wrap.

  • display_progress_bar (bool) – Whether to display a progress bar.

Returns:

An iterator that acts exactly like the original iterable but prints a dynamically updating progress bar every time a value is requested, or the original iterable if display_progress_bar is False.

Return type:

Iterable

move_data_to_device(data, device)

Moves data to the specified device.

Parameters:
  • data (T) – The data to move to device. Can be a TorchInputType or a TorchTargetType.

  • device (torch.device) – The device to which the data is moved. Often ‘cpu’ or ‘cuda’.

Raises:

TypeError – Raised if data is not one of the types specified by TorchInputType or TorchTargetType

Returns:

The data argument, now moved to device.

Return type:

T
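A minimal sketch of the device transfer, assuming that both TorchInputType and TorchTargetType are either a torch.Tensor or a dict[str, torch.Tensor] (the parameter description of check_if_batch_is_empty_and_verify_input confirms this for TorchInputType only). The name move_to_device is hypothetical.

```python
from __future__ import annotations

import torch


def move_to_device(
    data: torch.Tensor | dict[str, torch.Tensor], device: torch.device
) -> torch.Tensor | dict[str, torch.Tensor]:
    """Sketch: move a tensor, or every tensor in a dict, to `device`."""
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict) and all(isinstance(v, torch.Tensor) for v in data.values()):
        # Build a new dict; the input dict itself is not modified.
        return {key: value.to(device) for key, value in data.items()}
    raise TypeError("data must be a torch.Tensor or dict[str, torch.Tensor].")


# Use a GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {"image": torch.zeros(4, 3), "mask": torch.ones(4)}
moved = move_to_device(batch, device)
```

Checking every dictionary value up front means an unsupported type surfaces as the documented TypeError rather than an AttributeError from calling .to() on a non-tensor.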

set_pack_losses_with_val_metrics(config)
Return type:

bool