fl4health.checkpointing.opacus_checkpointer module

class BestLossOpacusCheckpointer(checkpoint_dir, checkpoint_name)[source]

Bases: OpacusCheckpointer

__init__(checkpoint_dir, checkpoint_name)[source]

This checkpointer uses only the loss value provided to the maybe_checkpoint function to determine whether a checkpoint should be saved. We always attempt to minimize the loss, so maximize is always set to False.

Parameters:
  • checkpoint_dir (str) – Directory to which the model is saved. This directory should already exist. The checkpointer will not create it if it does not.

  • checkpoint_name (str) – Name of the checkpoint to be saved.

maybe_checkpoint(model, loss, metrics)[source]

Overrides the checkpointing strategy of the FunctionTorchModuleCheckpointer to save model state dictionaries instead of using the torch.save workflow.

Parameters:
  • model (nn.Module) – Model to be potentially saved (should be an Opacus-wrapped model).

  • loss (float) – Loss value associated with the model to be used in checkpointing decisions.

  • metrics (dict[str, Scalar]) – Metrics associated with the model to be used in checkpointing decisions.

Return type:

None
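The minimize-the-loss decision rule above can be sketched in plain Python. `BestLossTracker` and its `maybe_checkpoint` are hypothetical stand-ins that illustrate only the comparison logic, not the actual fl4health class:

```python
class BestLossTracker:
    """Illustrative stand-in: checkpoint only when loss improves."""

    def __init__(self):
        # We always minimize loss, so "maximize" is effectively False.
        self.best_loss = float("inf")

    def maybe_checkpoint(self, loss):
        """Return True when a new checkpoint should be written."""
        if loss < self.best_loss:
            self.best_loss = loss
            return True
        return False


tracker = BestLossTracker()
decisions = [tracker.maybe_checkpoint(loss) for loss in [0.9, 0.7, 0.8, 0.5]]
print(decisions)  # [True, True, False, True]
```

Only the second and fourth losses improve on the running best, so only those calls would trigger a save.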

class LatestOpacusCheckpointer(checkpoint_dir, checkpoint_name)[source]

Bases: OpacusCheckpointer

__init__(checkpoint_dir, checkpoint_name)[source]

This class implements a checkpointer that always saves the model state when called. It uses a placeholder scoring function and maximize argument.

Parameters:
  • checkpoint_dir (str) – Directory to save checkpoint state to

  • checkpoint_name (str) – Name of the file to which state is to be saved.

maybe_checkpoint(model, loss, _)[source]

Overrides the checkpointing strategy of the FunctionTorchModuleCheckpointer to save model state dictionaries instead of using the torch.save workflow.

Parameters:
  • model (nn.Module) – Model to be potentially saved (should be an Opacus-wrapped model).

  • loss (float) – Loss value associated with the model to be used in checkpointing decisions.

  • metrics (dict[str, Scalar]) – Metrics associated with the model. Not used in the checkpointing decision, since this checkpointer always saves.

Return type:

None
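The always-save behavior amounts to overwriting the same file on every call, so the checkpoint on disk always holds the most recent state. A minimal sketch, using a hypothetical `save_latest` helper and JSON in place of torch state dictionaries:

```python
import json
import os
import tempfile


def save_latest(checkpoint_dir, checkpoint_name, state):
    # Every call overwrites the same file, so the checkpoint on disk
    # always reflects the latest call.
    path = os.path.join(checkpoint_dir, checkpoint_name)
    with open(path, "w") as f:
        json.dump(state, f)
    return path


with tempfile.TemporaryDirectory() as d:
    for step in range(3):
        path = save_latest(d, "latest.json", {"step": step})
    with open(path) as f:
        latest_state = json.load(f)

print(latest_state)  # {'step': 2}
```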

class OpacusCheckpointer(checkpoint_dir, checkpoint_name, checkpoint_score_function, maximize=False)[source]

Bases: FunctionTorchModuleCheckpointer

This is a specific type of checkpointer used to save models trained with Opacus for differential privacy. Certain layers within Opacus-wrapped models do not interact well with the torch.save functionality; this checkpointer fixes that issue by saving model state dictionaries instead.
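The constructor's checkpoint_score_function and maximize arguments can be sketched in plain Python. `ScoreTracker` below is a hypothetical stand-in for the scoring mechanism only, not the real base class: the score function maps the loss and metrics to a single score, and maximize controls the comparison direction.

```python
class ScoreTracker:
    """Illustrative stand-in for score-function-based checkpointing."""

    def __init__(self, score_function, maximize=False):
        self.score_function = score_function
        self.maximize = maximize
        self.best = None

    def maybe_checkpoint(self, loss, metrics):
        # Compute a scalar score, then compare against the best seen so
        # far in the direction dictated by maximize.
        score = self.score_function(loss, metrics)
        better = self.best is None or (
            score > self.best if self.maximize else score < self.best
        )
        if better:
            self.best = score
        return better


# Example: maximize validation accuracy instead of minimizing loss.
tracker = ScoreTracker(lambda loss, metrics: metrics["accuracy"], maximize=True)
first = tracker.maybe_checkpoint(0.4, {"accuracy": 0.80})   # True
second = tracker.maybe_checkpoint(0.3, {"accuracy": 0.75})  # False
```

Note the second call would be a save under a best-loss rule (0.3 < 0.4) but not under the accuracy-maximizing score function.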

load_best_checkpoint_into_model(target_model, target_is_grad_sample_module=False)[source]

State-dictionary loading requires a model to be provided (unlike the torch.save mechanism), so this function requires the user to supply a model into which the state dictionary is loaded.

Parameters:
  • target_model (nn.Module) – Target model for loading state into.

  • target_is_grad_sample_module (bool, optional) – Whether the target_model that the state_dict is being loaded into is an Opacus module or just a vanilla PyTorch module. Defaults to False.

Return type:

None
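The key-remapping issue behind the target_is_grad_sample_module flag can be sketched with plain dicts. Assuming an Opacus GradSampleModule, which stores the wrapped model under a `_module` attribute so that its state_dict keys carry a `_module.` prefix, that prefix must be stripped before loading into a vanilla PyTorch module. `strip_opacus_prefix` is a hypothetical helper, not part of the fl4health API:

```python
def strip_opacus_prefix(state_dict, prefix="_module."):
    # Drop the GradSampleModule wrapper prefix so keys line up with a
    # plain PyTorch module's state_dict keys; plain lists stand in for
    # real tensors here.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


wrapped = {"_module.linear.weight": [1.0, 2.0], "_module.linear.bias": [0.0]}
plain = strip_opacus_prefix(wrapped)
print(plain)  # {'linear.weight': [1.0, 2.0], 'linear.bias': [0.0]}
```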

load_checkpoint(path_to_checkpoint=None)[source]

Loads a checkpoint, with the option to either specify a checkpoint path or fall back on the internal path of the checkpointer. The flexibility to specify a load path is useful, for example, if you are not overwriting checkpoints when saving and need to load a specific past checkpoint.

Parameters:

path_to_checkpoint (str | None, optional) – If provided, the checkpoint will be loaded from this path. If not specified, the checkpointer will load from self.checkpoint_path. Defaults to None.

Returns:

Returns a torch module loaded from the proper checkpoint path.

Return type:

nn.Module
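The path-fallback logic can be sketched as follows. `TinyCheckpointer` is a hypothetical stand-in that loads JSON rather than a torch module; only the explicit-path-versus-internal-path choice mirrors the documented behavior:

```python
import json
import os
import tempfile


class TinyCheckpointer:
    """Illustrative stand-in; the real method returns a torch nn.Module."""

    def __init__(self, checkpoint_path):
        self.checkpoint_path = checkpoint_path

    def load_checkpoint(self, path_to_checkpoint=None):
        # Use the explicit path when given; otherwise fall back on the
        # checkpointer's own internal path.
        path = (
            path_to_checkpoint
            if path_to_checkpoint is not None
            else self.checkpoint_path
        )
        with open(path) as f:
            return json.load(f)


with tempfile.TemporaryDirectory() as d:
    default_path = os.path.join(d, "latest.json")
    older_path = os.path.join(d, "round_2.json")
    with open(default_path, "w") as f:
        json.dump({"round": 5}, f)
    with open(older_path, "w") as f:
        json.dump({"round": 2}, f)

    ckpt = TinyCheckpointer(default_path)
    from_default = ckpt.load_checkpoint()          # falls back on internal path
    from_older = ckpt.load_checkpoint(older_path)  # explicit past checkpoint

print(from_default, from_older)  # {'round': 5} {'round': 2}
```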

maybe_checkpoint(model, loss, metrics)[source]

Overrides the checkpointing strategy of the FunctionTorchModuleCheckpointer to save model state dictionaries instead of using the torch.save workflow.

Parameters:
  • model (nn.Module) – Model to be potentially saved (should be an Opacus-wrapped model).

  • loss (float) – Loss value associated with the model to be used in checkpointing decisions.

  • metrics (dict[str, Scalar]) – Metrics associated with the model to be used in checkpointing decisions.

Return type:

None