fl4health.checkpointing.client_module module

class CheckpointMode(value)

Bases: Enum

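An enumeration of the two points at which client-side model checkpointing can occur: before (PRE_AGGREGATION) or after (POST_AGGREGATION) server-side aggregation.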

POST_AGGREGATION = 'post_aggregation'
PRE_AGGREGATION = 'pre_aggregation'
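For reference, the enumeration above can be reproduced as a standard Python Enum. The string lookup at the end is only an illustrative usage (for example, turning a config value into a mode) and is not something this module documents.

```python
from enum import Enum


class CheckpointMode(Enum):
    # Mirrors the two members documented above.
    PRE_AGGREGATION = "pre_aggregation"
    POST_AGGREGATION = "post_aggregation"


# Enum members can be recovered from their string values (illustrative usage).
mode = CheckpointMode("post_aggregation")
print(mode is CheckpointMode.POST_AGGREGATION)  # True
print(mode.value)  # post_aggregation
```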
class ClientCheckpointAndStateModule(pre_aggregation=None, post_aggregation=None, state_checkpointer=None)

Bases: object

__init__(pre_aggregation=None, post_aggregation=None, state_checkpointer=None)

This module is meant to hold up to three major components that determine how clients handle model and state checkpointing. State checkpointing is meant to allow clients to restart if FL training is interrupted. For model checkpointing, there are two distinct types:

  • The first type, if defined, is used to checkpoint local models BEFORE server-side aggregation, but after local training. NOTE: This is akin to “further fine-tuning” approaches for global models.

  • The second type, if defined, is used to checkpoint local models AFTER server-side aggregation, but before local training. NOTE: This is the “traditional” mechanism for global models.

As a final note, for some methods, such as Ditto or MR-MTL, these two checkpoints will actually be identical. That’s because the target model for these methods is never globally aggregated; it always remains local.
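Where each checkpoint type fires within a single client round can be sketched as follows. This is a simplified, hypothetical loop: the function and callback names are illustrative only, and in a Flower-based client these steps may be spread across separate fit and evaluate calls rather than living in one function.

```python
from typing import Callable


def run_client_round(
    receive_global: Callable[[], None],
    validate: Callable[[], float],
    train_locally: Callable[[], None],
    checkpoint: Callable[[str, float], None],
) -> None:
    """Hypothetical single client round showing where each checkpoint type fires."""
    # 1. Load the aggregated global weights sent by the server.
    receive_global()
    # 2. Post-aggregation: the model has been aggregated but not yet trained locally.
    checkpoint("post_aggregation", validate())
    # 3. Local training on the client's own data.
    train_locally()
    # 4. Pre-aggregation: the locally trained model, before it is sent back for aggregation.
    checkpoint("pre_aggregation", validate())


# Record the order of events with trivial callbacks.
events: list[str] = []
run_client_round(
    receive_global=lambda: events.append("receive"),
    validate=lambda: 0.5,
    train_locally=lambda: events.append("train"),
    checkpoint=lambda mode, loss: events.append(mode),
)
print(events)  # ['receive', 'post_aggregation', 'train', 'pre_aggregation']
```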

Parameters:
  • pre_aggregation (ModelCheckpointers, optional) – If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their validation metrics/losses BEFORE server-side aggregation. Defaults to None.

  • post_aggregation (ModelCheckpointers, optional) – If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their validation metrics/losses AFTER server-side aggregation. Defaults to None.

  • state_checkpointer (ClientStateCheckpointer | None, optional) – If defined, this checkpointer is used to preserve client state (not just models), in the event one wants to restart federated training. Defaults to None.
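Because pre_aggregation and post_aggregation accept either a single checkpointer or a sequence of them, an implementation has to normalize both forms (plus None). A minimal sketch of that normalization, assuming a hypothetical ModelCheckpointer protocol that only requires a maybe_checkpoint(model, loss, metrics) method; the library's actual checkpointer interface and the ModelCheckpointers alias may differ.

```python
from collections.abc import Sequence
from typing import Any, Optional, Protocol, Union


class ModelCheckpointer(Protocol):
    """Hypothetical minimal interface for a model checkpointer."""

    def maybe_checkpoint(self, model: Any, loss: float, metrics: dict[str, Any]) -> None: ...


# Hypothetical alias mirroring "a checkpointer or a sequence of checkpointers".
ModelCheckpointers = Union[ModelCheckpointer, Sequence[ModelCheckpointer]]


def as_checkpointer_list(checkpointers: Optional[ModelCheckpointers]) -> list[ModelCheckpointer]:
    """Normalize None, a single checkpointer, or a sequence into a list."""
    if checkpointers is None:
        return []
    if isinstance(checkpointers, Sequence):
        return list(checkpointers)
    return [checkpointers]


class PrintingCheckpointer:
    """Toy checkpointer used only to exercise the normalization."""

    def maybe_checkpoint(self, model: Any, loss: float, metrics: dict[str, Any]) -> None:
        print(f"considering checkpoint at loss={loss}")


print(len(as_checkpointer_list(None)))  # 0
print(len(as_checkpointer_list(PrintingCheckpointer())))  # 1
print(len(as_checkpointer_list([PrintingCheckpointer(), PrintingCheckpointer()])))  # 2
```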

maybe_checkpoint(model, loss, metrics, mode)

Performs model checkpointing for a particular mode (either pre- or post-aggregation) if any checkpointers are provided for that mode in this module. If present, each checkpointer decides whether or not to checkpoint based on its internal criterion and the loss/metrics provided.

Parameters:
  • model (nn.Module) – The model that might be checkpointed by the checkpointers.

  • loss (float) – The loss value obtained by the provided model. Used by the checkpointer(s) to decide whether to checkpoint the model.

  • metrics (dict[str, Scalar]) – The metrics obtained by the provided model. Potentially used by the checkpointer(s) to decide whether to checkpoint the model.

  • mode (CheckpointMode) – Determines which type of checkpointers to use. Currently, the only modes available are pre- and post-aggregation.

Raises:

ValueError – Thrown if the model checkpointing mode is not recognized.

Return type:

None
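The dispatch that maybe_checkpoint performs can be sketched as below. BestLossCheckpointer and CheckpointModuleSketch are hypothetical stand-ins (the first keeps an in-memory reference instead of writing a file, and the second accepts only lists of checkpointers), so this illustrates the control flow described above, not the library's implementation.

```python
from enum import Enum
from typing import Any, Optional


class CheckpointMode(Enum):
    PRE_AGGREGATION = "pre_aggregation"
    POST_AGGREGATION = "post_aggregation"


class BestLossCheckpointer:
    """Hypothetical checkpointer: keeps the model whenever the loss improves."""

    def __init__(self) -> None:
        self.best_loss = float("inf")
        self.saved: Any = None

    def maybe_checkpoint(self, model: Any, loss: float, metrics: dict[str, Any]) -> None:
        if loss < self.best_loss:
            self.best_loss = loss
            self.saved = model  # a real checkpointer would persist the model to disk


class CheckpointModuleSketch:
    """Hypothetical stand-in for the module; accepts lists of checkpointers only."""

    def __init__(self, pre_aggregation: Optional[list[Any]] = None,
                 post_aggregation: Optional[list[Any]] = None) -> None:
        self.pre_aggregation = list(pre_aggregation or [])
        self.post_aggregation = list(post_aggregation or [])

    def maybe_checkpoint(self, model: Any, loss: float, metrics: dict[str, Any],
                         mode: CheckpointMode) -> None:
        if mode == CheckpointMode.PRE_AGGREGATION:
            checkpointers = self.pre_aggregation
        elif mode == CheckpointMode.POST_AGGREGATION:
            checkpointers = self.post_aggregation
        else:
            raise ValueError(f"Unrecognized checkpointing mode: {mode}")
        # Each checkpointer applies its own criterion to decide whether to save.
        for checkpointer in checkpointers:
            checkpointer.maybe_checkpoint(model, loss, metrics)


pre = BestLossCheckpointer()
module = CheckpointModuleSketch(pre_aggregation=[pre])
module.maybe_checkpoint("model_v1", 0.9, {}, CheckpointMode.PRE_AGGREGATION)  # improves: saved
module.maybe_checkpoint("model_v2", 1.2, {}, CheckpointMode.PRE_AGGREGATION)  # worse: skipped
module.maybe_checkpoint("model_v3", 0.4, {}, CheckpointMode.POST_AGGREGATION)  # no post-aggregation checkpointers
print(pre.saved, pre.best_loss)  # model_v1 0.9
```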

maybe_load_state(client)

This function loads any pre-existing client state stored under the name checkpoint_name in the checkpoint_dir directory. If a state checkpoint exists at that path, it is loaded and its values are automatically written into the client’s attributes. If it doesn’t exist, False is returned.

Parameters:

client (BasicClient) – The client object into which state will be loaded if a checkpoint exists.

Raises:

ValueError – Raised if this function is called but no state checkpointer has been provided.

Returns:

If the state checkpoint exists and is loaded correctly, the client’s attributes are set to the loaded values and True is returned. Otherwise, False is returned (or an exception is raised).

Return type:

bool
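A minimal sketch of the load path, assuming the state is a pickled dict mapping attribute names to values, stored at checkpoint_dir/checkpoint_name. StateCheckpointerSketch and the pickle format are assumptions for illustration; the library's ClientStateCheckpointer may serialize and restore state differently.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace
from typing import Any, Optional


class StateCheckpointerSketch:
    """Hypothetical stand-in: client state lives in one pickled dict on disk."""

    def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None:
        self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)


def maybe_load_state(client: Any, state_checkpointer: Optional[StateCheckpointerSketch]) -> bool:
    if state_checkpointer is None:
        raise ValueError("Attempting to load state, but no state checkpointer is defined")
    if not os.path.exists(state_checkpointer.checkpoint_path):
        return False  # nothing saved yet, so training starts fresh
    with open(state_checkpointer.checkpoint_path, "rb") as f:
        state: dict[str, Any] = pickle.load(f)
    # Write each loaded value back onto the client as an attribute.
    for name, value in state.items():
        setattr(client, name, value)
    return True


with tempfile.TemporaryDirectory() as tmp:
    checkpointer = StateCheckpointerSketch(tmp, "client_state.pkl")
    client = SimpleNamespace(total_steps=0)
    print(maybe_load_state(client, checkpointer))  # False
    with open(checkpointer.checkpoint_path, "wb") as f:
        pickle.dump({"total_steps": 120}, f)
    print(maybe_load_state(client, checkpointer), client.total_steps)  # True 120
```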

save_state(client)

This function is meant to facilitate saving the state required to restart an FL process on the client side. It saves all the attributes listed in ClientStateCheckpointer.snapshot_attrs. It should only be called if a state_checkpointer exists in this module.

Parameters:

client (BasicClient) – The client object from which state will be saved.

Raises:

ValueError – Raised if this function is called but no state checkpointer has been provided.

Return type:

None
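The matching save path can be sketched the same way. Here snapshot_attrs is reduced to a plain list of attribute names and the state is written with pickle; both are simplifying assumptions, and the real ClientStateCheckpointer.snapshot_attrs may carry more information per attribute. StateCheckpointerSketch is hypothetical.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace
from typing import Any, Optional


class StateCheckpointerSketch:
    """Hypothetical stand-in for a client state checkpointer."""

    def __init__(self, checkpoint_dir: str, checkpoint_name: str, snapshot_attrs: list[str]) -> None:
        self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        # Names of the client attributes to persist (plays the role of snapshot_attrs).
        self.snapshot_attrs = snapshot_attrs


def save_state(client: Any, state_checkpointer: Optional[StateCheckpointerSketch]) -> None:
    if state_checkpointer is None:
        raise ValueError("Attempting to save state, but no state checkpointer is defined")
    # Collect only the listed attributes; everything else on the client is ignored.
    state = {name: getattr(client, name) for name in state_checkpointer.snapshot_attrs}
    with open(state_checkpointer.checkpoint_path, "wb") as f:
        pickle.dump(state, f)


with tempfile.TemporaryDirectory() as tmp:
    checkpointer = StateCheckpointerSketch(tmp, "client_state.pkl", ["total_steps", "lr"])
    client = SimpleNamespace(total_steps=120, lr=0.01, data_loader="not saved")
    save_state(client, checkpointer)
    with open(checkpointer.checkpoint_path, "rb") as f:
        print(pickle.load(f))  # {'total_steps': 120, 'lr': 0.01}
```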