fl4health.clients.constrained_fenda_client module

class ConstrainedFendaClient(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, loss_container=None)[source]

Bases: FendaClient

__init__(data_path, metrics, device, loss_meter_type=LossMeterType.AVERAGE, checkpoint_and_state_module=None, reporters=None, progress_bar=False, client_name=None, loss_container=None)[source]

This class extends the functionality of FENDA training to include various kinds of constraints applied during the client-side training of FENDA models.

Parameters:
  • data_path (Path) – path to the data to be used for client-side training

  • metrics (Sequence[Metric]) – Metrics to be computed based on the labels and predictions of the client model

  • device (torch.device) – Device indicator for where to send the model, batches, labels etc. Often ‘cpu’ or ‘cuda’

  • loss_meter_type (LossMeterType, optional) – Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE.

  • checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional) – A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None.

  • reporters (Sequence[BaseReporter] | None, optional) – A sequence of FL4Health reporters which the client should send data to. Defaults to None.

  • progress_bar (bool, optional) – Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False.

  • client_name (str | None, optional) – An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None.

  • loss_container (ConstrainedFendaLossContainer | None, optional) – Configuration that determines which losses will be applied during FENDA training. Defaults to None.
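
A minimal construction sketch follows; the subclass name, the Accuracy metric import and the omitted client hooks (get_model, get_data_loaders, get_optimizer, get_criterion) are assumptions for illustration only:

    # Hypothetical sketch of wiring up a ConstrainedFendaClient; the metric import
    # path and the omitted abstract hooks are assumptions.
    from pathlib import Path

    import torch

    from fl4health.clients.constrained_fenda_client import ConstrainedFendaClient
    from fl4health.utils.metrics import Accuracy  # assumed metric implementation

    class MyConstrainedFendaClient(ConstrainedFendaClient):
        # The usual client hooks (get_model, get_data_loaders, get_optimizer,
        # get_criterion) would also be implemented here; omitted for brevity.
        ...

    client = MyConstrainedFendaClient(
        data_path=Path("/path/to/client/data"),
        metrics=[Accuracy()],
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        progress_bar=True,
        client_name="client_0",  # optional; a random hash is generated if omitted
        loss_container=None,     # pass a ConstrainedFendaLossContainer to enable the extra losses
    )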

compute_evaluation_loss(preds, features, target)[source]

Computes evaluation loss given predictions of the model and ground truth data. Optionally computes additional loss components such as cosine_similarity_loss, contrastive_loss and perfcl_loss based on client attributes set from server config.

Parameters:
  • preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics.

  • features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.

  • target (torch.Tensor | dict[str, torch.Tensor]) – Ground truth data to evaluate predictions against.

Returns:

An instance of EvaluationLosses containing checkpoint loss and additional losses indexed by name. Additional losses may include cosine_similarity_loss, contrastive_loss and perfcl_loss.

Return type:

EvaluationLosses

compute_loss_and_additional_losses(preds, features, target)[source]

Computes the loss and any additional losses given predictions of the model and ground truth data. For FENDA, the returned loss is the total loss, and the additional losses are the loss, the total loss and, depending on client attributes set from the server config, the cosine similarity, contrastive and PerFCL losses.

Parameters:
  • preds (dict[str, torch.Tensor]) – Prediction(s) of the model(s) indexed by name.

  • features (dict[str, torch.Tensor]) – Feature(s) of the model(s) indexed by name.

  • target (torch.Tensor) – Ground truth data to evaluate predictions against.

Returns:

  • The tensor for the total loss

  • A dictionary with loss, total_loss and, based on client attributes set from server config, also cos_sim_loss, contrastive_loss, contrastive_loss_minimize and contrastive_loss_maximize keys and their respective calculated values.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]
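
As an illustration of the return shape only, continuing the hypothetical client sketch above (preds, features and target would come from a training batch; which extra keys appear depends on the configured loss container):

    # Illustrative only: unpacking the (total loss, additional losses) tuple.
    total_loss, additional_losses = client.compute_loss_and_additional_losses(preds, features, target)
    assert torch.is_tensor(total_loss)
    additional_losses["loss"]        # criterion loss without constraint terms
    additional_losses["total_loss"]  # same value as the first tuple element
    # Constraint terms such as "cos_sim_loss" or "contrastive_loss" only appear
    # when the corresponding constraint is configured in the loss container.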

get_parameter_exchanger(config)[source]

Returns a full parameter exchanger by default. Subclasses that require custom parameter exchangers can override this method.

Parameters:

config (Config) – The config from the server.

Returns:

Used to exchange parameters between server and client.

Return type:

ParameterExchanger
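
A hedged sketch of overriding this method in a subclass is shown below, continuing the hypothetical subclass above; the FullParameterExchanger and ParameterExchanger import paths are assumptions and simply mirror the default full-exchange behaviour:

    # Sketch: overriding get_parameter_exchanger; import paths are assumptions.
    from flwr.common.typing import Config

    from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
    from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger

    class MyExchangerClient(MyConstrainedFendaClient):
        def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
            # Exchange all model parameters with the server (the default behaviour).
            return FullParameterExchanger()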

predict(input)[source]

Computes the prediction(s) and features of the model(s) given the input.

Parameters:
  • input (TorchInputType) – Inputs to be fed into the model. TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor].

Returns:

A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations indexed by name. Specifically, the features of the model, the features of the global model and the features of the old model are returned. All predictions included in the dictionary will be used to compute metrics.

Return type:

tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
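
A short illustrative call, continuing the sketch above; the exact key names in the two returned dictionaries depend on the FENDA model definition, so none are assumed here:

    # Illustrative only: predict returns (predictions by name, features by name).
    batch_input = torch.randn(8, 32)  # input shape depends on the model in use
    preds, features = client.predict(batch_input)
    for name, tensor in preds.items():
        print("prediction:", name, tuple(tensor.shape))
    for name, tensor in features.items():
        print("feature:", name, tuple(tensor.shape))  # local, global and "old" model features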

update_after_train(local_steps, loss_dict, config)[source]

This function is called after client-side training concludes. If a contrastive or PerFCL loss function has been defined, this hook saves the local and global feature extraction weights/modules to be used in the next round of client-side training.

Parameters:
  • local_steps (int) – Number of steps performed during training

  • loss_dict (dict[str, float]) – Losses computed during training.

  • config (Config) – The config from the server.

Return type:

None

update_before_train(current_server_round)[source]

This function is called prior to the start of client-side training, but after the server parameters have been received and injected into the model. If a PerFCL loss function has been defined, this hook saves the aggregated global feature extractor weights/module representing the initial state of this module BEFORE this iteration of client-side training but AFTER server-side aggregation.

Parameters:

current_server_round (int) – Current server round being performed.

Return type:

None
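
When these hooks are overridden in a subclass, calling super() preserves the bookkeeping described above (saving the old local and global feature extractor states). A hedged sketch, continuing the hypothetical subclass from earlier; the logging is purely illustrative:

    # Sketch: extending the training lifecycle hooks while keeping the
    # constrained-FENDA bookkeeping performed by the parent implementations.
    from flwr.common.typing import Config

    class LoggingConstrainedFendaClient(MyConstrainedFendaClient):
        def update_before_train(self, current_server_round: int) -> None:
            super().update_before_train(current_server_round)
            print(f"Starting local training for server round {current_server_round}")

        def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None:
            super().update_after_train(local_steps, loss_dict, config)
            print(f"Finished {local_steps} local steps with losses {loss_dict}")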