fl4health.checkpointing.client_module module

class CheckpointMode(value)[source]

Bases: Enum

Enumeration of the client-side model checkpointing modes: before (pre) or after (post) server-side aggregation.

POST_AGGREGATION = 'post_aggregation'
PRE_AGGREGATION = 'pre_aggregation'
class ClientCheckpointAndStateModule(pre_aggregation=None, post_aggregation=None, state_checkpointer=None)[source]

Bases: object

__init__(pre_aggregation=None, post_aggregation=None, state_checkpointer=None)[source]

This module is meant to hold up to three major components that determine how clients handle model and state checkpointing, where state checkpointing allows clients to restart if FL training is interrupted. For model checkpointing, there are two distinct types.

The first type, if defined, is used to checkpoint local models BEFORE server-side aggregation, but after local training. NOTE: This is akin to “further fine-tuning” approaches for global models.

The second type, if defined, is used to checkpoint local models AFTER server-side aggregation, but before local training. NOTE: This is the “traditional” checkpointing mechanism for global models.

As a final note, for some methods, such as Ditto or MR-MTL, these checkpoints will actually be identical, because the target model for these methods is never globally aggregated. That is, it remains local to the client. A construction sketch follows the parameter list below.

Parameters:
  • pre_aggregation (ModelCheckpointers, optional) – If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their validation metrics/losses BEFORE server-side aggregation. Defaults to None.

  • post_aggregation (ModelCheckpointers, optional) – If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their validation metrics/losses AFTER server-side aggregation. Defaults to None.

  • state_checkpointer (PerRoundStateCheckpointer | None, optional) – If defined, this checkpointer is used to preserve client state (not just models), in the event one wants to restart federated training. Defaults to None.
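
For illustration, a minimal construction sketch. The concrete checkpointer classes, their import paths, and their constructor arguments below are assumptions, not verified API; only ClientCheckpointAndStateModule and the PerRoundStateCheckpointer type are named on this page.

    from pathlib import Path

    from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule

    # Assumed import paths and class names: substitute the concrete
    # checkpointer classes your project actually uses.
    from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
    from fl4health.checkpointing.state_checkpointer import PerRoundStateCheckpointer

    checkpoint_module = ClientCheckpointAndStateModule(
        # Checkpointed after local training, before server-side aggregation
        # (the "further fine-tuning" flavor of global model).
        pre_aggregation=BestLossTorchModuleCheckpointer("checkpoints/", "pre_agg_best.pt"),
        # Checkpointed after aggregated weights are received (the "traditional" flavor).
        post_aggregation=BestLossTorchModuleCheckpointer("checkpoints/", "post_agg_best.pt"),
        # Optional: per-round client state so interrupted training can resume.
        state_checkpointer=PerRoundStateCheckpointer(Path("client_state/")),
    )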

maybe_checkpoint(model, loss, metrics, mode)[source]

Performs model checkpointing for a particular mode (either pre- or post-aggregation) if any checkpointers are provided for that particular mode in this module. If present, the various checkpointers will decide whether or not to checkpoint based on their internal criterion and the loss/metrics provided.

Parameters:
  • model (nn.Module) – The model that might be checkpointed by the checkpointers.

  • loss (float) – The loss value obtained by the provided model. Used by the checkpointer(s) to decide whether to checkpoint the model.

  • metrics (dict[str, Scalar]) – The metrics obtained by the provided model. Potentially used by the checkpointer(s) to decide whether to checkpoint the model.

  • mode (CheckpointMode) – Determines which of the types of checkpointers to use. Currently, the only modes available are pre- and post-aggregation.

Raises:

ValueError – Thrown if the model checkpointing mode is not recognized.

Return type:

None
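
For example, building on the construction sketch above, a client might call this once after local training and once after loading the aggregated weights. The model, loss, and metric values here are placeholders.

    import torch.nn as nn

    from fl4health.checkpointing.client_module import CheckpointMode

    model = nn.Linear(10, 2)  # placeholder for the client's actual model
    val_loss, val_metrics = 0.42, {"accuracy": 0.91}  # placeholder validation results

    # After local training, before the weights are sent for aggregation:
    checkpoint_module.maybe_checkpoint(model, val_loss, val_metrics, CheckpointMode.PRE_AGGREGATION)

    # After the aggregated weights have been received and loaded:
    checkpoint_module.maybe_checkpoint(model, val_loss, val_metrics, CheckpointMode.POST_AGGREGATION)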

maybe_load_state(state_checkpoint_name)[source]

This function facilitates loading any pre-existing state (with the name state_checkpoint_name) from the directory of the state_checkpointer. If a checkpoint exists at the proper path, the state is loaded and returned. If it does not, None is returned.

Parameters:

state_checkpoint_name (str) – Name of the state checkpoint file. The checkpointer itself will have a directory from which state will be loaded (if it exists).

Raises:

ValueError – Thrown if this function is called but no state checkpointer has been provided.

Returns:

If the state checkpoint exists and is loaded correctly, this dictionary carries that state. Otherwise, None is returned (or an exception is thrown).

Return type:

dict[str, Any] | None
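
A resumption sketch follows; the state dictionary keys (“model”, “optimizer”) are illustrative assumptions, mirroring the save_state example below.

    import torch

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # placeholder optimizer

    # Attempt to resume from a prior run; None means no checkpoint was found.
    state = checkpoint_module.maybe_load_state("client_0_state.pt")
    if state is not None:
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])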

save_state(state_checkpoint_name, state)[source]

This function is meant to facilitate saving the state required to restart an FL process on the client side. It simply saves whatever information is passed in the state variable under the file name in state_checkpoint_name. This function should only be called if a state_checkpointer exists in this module.

Parameters:
  • state_checkpoint_name (str) – Name of the state checkpoint file. The checkpointer itself will have a directory to which state will be saved.

  • state (dict[str, Any]) – State to be saved so that training might be resumed on the client if federated training is interrupted. For example, this might contain things like optimizer states, learning rate scheduler states, etc.

Raises:

ValueError – Thrown if this function is called but no state checkpointer has been provided.

Return type:

None
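
A saving sketch, building on the examples above; the contents of the state dictionary are illustrative assumptions and should contain whatever your client needs to resume.

    # Persist enough to resume later; the keys mirror the maybe_load_state
    # sketch above and are illustrative only.
    checkpoint_module.save_state(
        "client_0_state.pt",
        {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
    )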