fl4health.checkpointing.opacus_checkpointer module
- class BestLossOpacusCheckpointer(checkpoint_dir, checkpoint_name)
Bases: OpacusCheckpointer
- __init__(checkpoint_dir, checkpoint_name)
This checkpointer uses only the loss value passed to the maybe_checkpoint function to decide whether a checkpoint should be saved. The loss is always minimized, so maximize is always set to False.
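As a rough illustration of the minimize-only rule, the sketch below tracks the best loss seen so far and reports whether a new checkpoint would be saved. `BestLossTracker` and `should_checkpoint` are hypothetical names, not part of fl4health, and the strict `<` comparison (a tied loss does not trigger a save) is an assumption.

```python
from typing import Optional


class BestLossTracker:
    """Hypothetical helper mirroring the 'minimize the loss' rule; not fl4health's API."""

    def __init__(self) -> None:
        # No checkpoint has been saved yet, so there is no best loss to beat.
        self.best_loss: Optional[float] = None

    def should_checkpoint(self, loss: float) -> bool:
        # maximize=False: save only when the loss strictly improves (assumed strict).
        if self.best_loss is None or loss < self.best_loss:
            self.best_loss = loss
            return True
        return False


tracker = BestLossTracker()
decisions = [tracker.should_checkpoint(loss) for loss in [0.9, 0.7, 0.8, 0.5]]
print(decisions)  # [True, True, False, True]
```

The third loss (0.8) is worse than the best so far (0.7), so no checkpoint would be written for it.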
- class LatestOpacusCheckpointer(checkpoint_dir, checkpoint_name)
Bases: OpacusCheckpointer
- __init__(checkpoint_dir, checkpoint_name)
This class implements a checkpointer that always saves the model state when called. Because no score is used to decide whether to save, it supplies a placeholder scoring function and maximize argument.
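The always-save behaviour can be sketched as follows: the placeholder score is still computed so the interface matches a scored checkpointer, but it is never compared against anything. `placeholder_score`, `LatestTracker`, and `should_checkpoint` are hypothetical names, and how fl4health wires the placeholder internally is not stated here.

```python
from typing import Callable, Dict, List


def placeholder_score(loss: float, metrics: Dict[str, float]) -> float:
    # Hypothetical placeholder: the value is ignored, so any constant works.
    return 0.0


class LatestTracker:
    """Hypothetical sketch of 'always save'; not fl4health's implementation."""

    def __init__(self, score_function: Callable[[float, Dict[str, float]], float]) -> None:
        self.score_function = score_function
        self.saved_rounds: List[int] = []

    def should_checkpoint(self, round_idx: int, loss: float, metrics: Dict[str, float]) -> bool:
        # The score is computed only for interface compatibility and never compared,
        # so every call results in a save.
        _ = self.score_function(loss, metrics)
        self.saved_rounds.append(round_idx)
        return True


latest = LatestTracker(placeholder_score)
results = [latest.should_checkpoint(r, loss, {}) for r, loss in [(1, 0.5), (2, 0.9), (3, 0.7)]]
print(results, latest.saved_rounds)  # [True, True, True] [1, 2, 3]
```

Unlike the best-loss rule, a worse loss in round 2 still produces a save.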
- class OpacusCheckpointer(checkpoint_dir, checkpoint_name, checkpoint_score_function, maximize=False)
Bases: FunctionTorchModuleCheckpointer
This checkpointer is intended for saving models trained with Opacus for differential privacy. Certain layers within Opacus-wrapped models do not interact well with torch.save, and this checkpointer addresses that issue.
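The text does not say which layers are affected or why. The following is only a pure-Python analogy of the general pattern: serializing a whole object fails when it holds something that cannot be pickled (here a lambda standing in for a hook), while serializing just its plain-data state succeeds. `HookedModel` and its `hook` attribute are hypothetical and are not a claim about Opacus internals.

```python
import pickle
from typing import Dict, List


class HookedModel:
    """Hypothetical stand-in for a wrapped model that carries an unpicklable member."""

    def __init__(self) -> None:
        self.weights: List[float] = [0.5, -1.25]
        # A lambda stored on the instance cannot be pickled by reference.
        self.hook = lambda grad: grad

    def state_dict(self) -> Dict[str, List[float]]:
        # Only plain data, which is what a state-dict-style checkpoint stores.
        return {"weights": list(self.weights)}


model = HookedModel()

try:
    pickle.dumps(model)  # analogous to saving the whole module object
    whole_object_saved = True
except (pickle.PicklingError, AttributeError, TypeError):
    whole_object_saved = False

state_bytes = pickle.dumps(model.state_dict())  # saving only the state works
print(whole_object_saved, pickle.loads(state_bytes))  # False {'weights': [0.5, -1.25]}
```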
- load_best_checkpoint_into_model(target_model, target_is_grad_sample_module=False)
Loading a state dictionary requires an existing model to load it into (unlike loading a whole module saved with torch.save). This function therefore requires the user to provide target_model, into which the state dictionary is loaded.
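The target_is_grad_sample_module flag is not documented here. One plausible reason for it: Opacus' GradSampleModule keeps the wrapped model in a `_module` attribute, so its state-dictionary keys carry a `_module.` prefix that a plain model's keys lack. The sketch below assumes that is what the flag reconciles; `adapt_keys` and `TargetModel` are hypothetical, and whether fl4health saves keys with or without the prefix is not stated in this section.

```python
from typing import Dict

PREFIX = "_module."  # assumed key prefix added by a GradSampleModule wrapper


def adapt_keys(state: Dict[str, float], target_is_grad_sample_module: bool) -> Dict[str, float]:
    """Hypothetical helper: make checkpoint keys match the target model's naming."""
    if target_is_grad_sample_module:
        # Wrapped target: ensure every key carries the prefix.
        return {k if k.startswith(PREFIX) else PREFIX + k: v for k, v in state.items()}
    # Plain target: strip the wrapper prefix where present.
    return {k[len(PREFIX):] if k.startswith(PREFIX) else k: v for k, v in state.items()}


class TargetModel:
    """Hypothetical target: loading needs an existing model to receive the state."""

    def __init__(self) -> None:
        self.params: Dict[str, float] = {"fc.weight": 0.0, "fc.bias": 0.0}

    def load_state_dict(self, state: Dict[str, float]) -> None:
        # Strict matching: the key sets must be identical, otherwise nothing is loaded.
        if set(state) != set(self.params):
            raise KeyError(f"mismatched keys: {sorted(set(state) ^ set(self.params))}")
        self.params.update(state)


saved = {"_module.fc.weight": 0.3, "_module.fc.bias": -0.1}
target = TargetModel()
target.load_state_dict(adapt_keys(saved, target_is_grad_sample_module=False))
print(target.params)  # {'fc.weight': 0.3, 'fc.bias': -0.1}
```

Loading the prefixed keys directly into the plain target would raise a KeyError, which is the kind of mismatch such a flag would need to handle.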
- load_checkpoint(path_to_checkpoint=None)
Loads a checkpoint, either from a user-specified path or, by default, from the checkpointer's internal path. Specifying a load path is useful, for example, when checkpoints are not overwritten on save and a specific past checkpoint needs to be loaded.
- Parameters:
path_to_checkpoint (str | None, optional) – If provided, the checkpoint will be loaded from this path. If not specified, the checkpointer will load from self.checkpoint_path. Defaults to None.
- Returns:
A torch module loaded from the resolved checkpoint path.
- Return type:
nn.Module
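The path-selection rule for load_checkpoint reduces to a simple fallback, sketched below. `resolve_checkpoint_path` is a hypothetical helper, and building the default path by joining checkpoint_dir and checkpoint_name is an assumption about how self.checkpoint_path is formed; the file names are illustrative only.

```python
import os
from typing import Optional


def resolve_checkpoint_path(path_to_checkpoint: Optional[str], default_path: str) -> str:
    """Hypothetical helper mirroring the documented fallback rule."""
    # An explicitly provided path wins; otherwise use the checkpointer's own path.
    return path_to_checkpoint if path_to_checkpoint is not None else default_path


# Assumed: the internal path is checkpoint_dir joined with checkpoint_name.
default_path = os.path.join("checkpoints", "server_model.pkl")

print(resolve_checkpoint_path(None, default_path))
print(resolve_checkpoint_path(os.path.join("checkpoints", "round_3.pkl"), default_path))
```

The second call corresponds to the case described above, where checkpoints are not overwritten and an older one is loaded explicitly.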