fl4health.checkpointing.state_checkpointer module
- class ClientStateCheckpointer(checkpoint_dir, checkpoint_name=None, snapshot_attrs=None)
Bases: StateCheckpointer
- __init__(checkpoint_dir, checkpoint_name=None, snapshot_attrs=None)
Class for saving and loading the state of a client’s attributes as specified in snapshot_attrs.
- Parameters:
checkpoint_dir (Path) – Directory to which checkpoints are saved. This can be modified later with set_checkpoint_path.
checkpoint_name (str | None, optional) – Name of the checkpoint to be saved. If None but checkpoint_dir is set, a default checkpoint_name based on the name of the client being checkpointed is set, of the form f"client_{client.client_name}_state.pt". This can be changed later with set_checkpoint_path. Defaults to None.
snapshot_attrs (dict[str, tuple[AbstractSnapshotter, Any]] | None, optional) – Attributes that need to be saved to allow training to be restarted. If None, a sensible default set of attributes and their associated snapshotters for an FL client is used. Defaults to None.
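The default naming rule described for checkpoint_name can be illustrated with a small, dependency-free sketch. The helper default_client_checkpoint_name, the directory name and the client name "site_a" are all hypothetical and are not part of fl4health; the sketch only mirrors the documented f-string pattern and assumes the final path is the directory joined with the name.

```python
from pathlib import Path


def default_client_checkpoint_name(client_name: str) -> str:
    # Hypothetical helper mirroring the documented pattern; not an fl4health function.
    return f"client_{client_name}_state.pt"


checkpoint_dir = Path("checkpoints")  # illustrative directory name
checkpoint_name = default_client_checkpoint_name("site_a")

# Assumption: the checkpoint path is the directory joined with the name.
checkpoint_path = checkpoint_dir / checkpoint_name
```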
- get_attribute(name)
Get the attribute from the client.
- Parameters:
name (str) – Name of the attribute.
- Returns:
The attribute value.
- Return type:
Any
- maybe_load_client_state(client, attributes=None)
Load the state into the client that is being provided.
- Parameters:
client (BasicClient) – Target client object into which state will be loaded
attributes (list[str] | None, optional) – List of attributes to load from the checkpoint. If None, all attributes specified in snapshot_attrs are loaded. Defaults to None.
- Returns:
True if a checkpoint is successfully loaded, False otherwise.
- Return type:
bool
- maybe_set_default_checkpoint_name()
Potentially sets a default name for the checkpoint to be saved. If checkpoint_dir is set but checkpoint_name is None, a default checkpoint_name based on the name of the client being checkpointed is set, of the form f"client_{self.client.client_name}_state.pt".
- Return type:
None
- save_client_state(client)
Save the state of the client that is provided.
- Parameters:
client (BasicClient) – Client object with state to be saved.
- Return type:
None
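The save and maybe-load contract documented for the client checkpointer can be sketched as follows. This is a minimal stand-in under stated assumptions, not fl4health's implementation: ToyClient and ToyClientStateCheckpointer are hypothetical classes, pickle replaces the snapshotter and torch-based serialization, and only one integer attribute is tracked. It shows the documented behaviour that loading returns False when no checkpoint exists and True once state has been saved.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Optional


class ToyClient:
    """Stand-in for a client object; not fl4health's BasicClient."""

    def __init__(self, client_name: str) -> None:
        self.client_name = client_name
        self.total_steps = 0


class ToyClientStateCheckpointer:
    """Illustrative only: mimics the documented save / maybe-load contract."""

    def __init__(self, checkpoint_dir: Path, attrs: list[str]) -> None:
        self.checkpoint_dir = checkpoint_dir
        self.attrs = attrs

    def _path(self, client: ToyClient) -> Path:
        # Follows the documented default name pattern.
        return self.checkpoint_dir / f"client_{client.client_name}_state.pt"

    def save_client_state(self, client: ToyClient) -> None:
        state = {name: getattr(client, name) for name in self.attrs}
        with open(self._path(client), "wb") as handle:
            pickle.dump(state, handle)

    def maybe_load_client_state(
        self, client: ToyClient, attributes: Optional[list[str]] = None
    ) -> bool:
        path = self._path(client)
        if not path.exists():
            return False  # no checkpoint to load from
        with open(path, "rb") as handle:
            state = pickle.load(handle)
        names = attributes if attributes is not None else self.attrs
        for name in names:
            setattr(client, name, state[name])
        return True


with tempfile.TemporaryDirectory() as tmp:
    checkpointer = ToyClientStateCheckpointer(Path(tmp), ["total_steps"])
    resumed = ToyClient("site_a")
    loaded_before_save = checkpointer.maybe_load_client_state(resumed)

    trained = ToyClient("site_a")
    trained.total_steps = 120
    checkpointer.save_client_state(trained)

    loaded_after_save = checkpointer.maybe_load_client_state(resumed)
    restored_steps = resumed.total_steps
```

Returning a boolean rather than raising lets a caller fall through to fresh initialization on the first run and resume on later runs with the same code path.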
- class NnUnetServerStateCheckpointer(checkpoint_dir, checkpoint_name=None)
Bases: ServerStateCheckpointer
- __init__(checkpoint_dir, checkpoint_name=None)
Class for saving and loading the state of the server’s attributes based on the snapshot_attrs defined specifically for the nnUNet server.
- Parameters:
checkpoint_dir (Path) – Directory to which checkpoints are saved. This can be modified later with set_checkpoint_path.
checkpoint_name (str | None, optional) – Name of the checkpoint to be saved. If None but checkpoint_dir is set, a default checkpoint_name based on the name of the server being checkpointed is set, of the form f"server_{self.server.server_name}_state.pt". This can be updated later with set_checkpoint_path. Defaults to None.
- class ServerStateCheckpointer(checkpoint_dir, checkpoint_name=None, snapshot_attrs=None)
Bases: StateCheckpointer
- __init__(checkpoint_dir, checkpoint_name=None, snapshot_attrs=None)
Class for saving and loading the state of a server’s attributes as specified in snapshot_attrs.
- Parameters:
checkpoint_dir (Path) – Directory to which checkpoints are saved. This can be modified later with set_checkpoint_path.
checkpoint_name (str | None, optional) – Name of the checkpoint to be saved. If None but checkpoint_dir is set, a default checkpoint_name based on the name of the server being checkpointed is set, of the form f"server_{self.server.server_name}_state.pt". This can be updated later with set_checkpoint_path. Defaults to None.
snapshot_attrs (dict[str, tuple[AbstractSnapshotter, Any]] | None, optional) – Attributes that need to be saved to allow training to be restarted. If None, a sensible default set of attributes and their associated snapshotters for an FL server is used. Defaults to None.
- get_attribute(name)
Get the attribute from the server.
- Parameters:
name (str) – Name of the attribute.
- Returns:
The attribute value.
- Return type:
Any
- maybe_load_server_state(server, model, attributes=None)
Load the state of the server from checkpoint.
- Parameters:
server (FlServer) – Server into which the attributes will be loaded.
model (nn.Module) – The model structure to be loaded as part of the server state.
attributes (list[str] | None) – List of attributes to load from the checkpoint. If None, all attributes specified in snapshot_attrs are loaded. Defaults to None.
- Returns:
Returns a model if a checkpoint exists to load from. Otherwise returns None
- Return type:
nn.Module | None
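Because maybe_load_server_state returns either a model or None, callers need a fallback branch. The sketch below shows that calling pattern only. ToyModel and maybe_load_model are hypothetical stand-ins (a plain list of weights instead of an nn.Module, and an in-memory dictionary instead of a checkpoint file); they are not fl4health APIs.

```python
from typing import Optional


class ToyModel:
    """Stand-in for an nn.Module; holds a plain list of weights."""

    def __init__(self, weights: list[float]) -> None:
        self.weights = weights


def maybe_load_model(stored_state: Optional[dict], model: ToyModel) -> Optional[ToyModel]:
    # Hypothetical loader: returns the populated model, or None when there is no checkpoint.
    if stored_state is None:
        return None
    model.weights = list(stored_state["weights"])
    return model


fresh_model = ToyModel([0.0, 0.0])

# No checkpoint available: keep using the freshly initialized model.
loaded = maybe_load_model(None, ToyModel([0.0, 0.0]))
model_without_checkpoint = loaded if loaded is not None else fresh_model

# Checkpoint available: continue from the restored model.
loaded = maybe_load_model({"weights": [0.5, -1.0]}, ToyModel([0.0, 0.0]))
model_with_checkpoint = loaded if loaded is not None else fresh_model
```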
- maybe_set_default_checkpoint_name()
Potentially sets a default name for the checkpoint to be saved. If checkpoint_dir is set but checkpoint_name is None, a default checkpoint_name based on the name of the server being checkpointed is set, of the form f"server_{self.server.server_name}_state.pt".
- Return type:
None
- class StateCheckpointer(checkpoint_dir, checkpoint_name, snapshot_attrs)
Bases: ABC
- __init__(checkpoint_dir, checkpoint_name, snapshot_attrs)
Class for saving and loading the state of the client or server attributes. Attributes are stored in a dictionary for saving and are loaded back as a dictionary. Checkpointing can be done after a client or server round to facilitate restarting federated training if interrupted, or during the client’s training loop to facilitate early stopping.
Server and client state checkpointers will save to disk in the provided directory. A default name for the state checkpoint will be derived if the checkpoint name is still None at the time of saving.
- Parameters:
checkpoint_dir (Path) – Directory to which checkpoints are saved. This can be modified later with set_checkpoint_path.
checkpoint_name (str) – Name of the checkpoint to be saved. If None at the time of state saving, a default name will be given to the checkpoint. This can be changed later with set_checkpoint_path.
snapshot_attrs (dict[str, tuple[AbstractSnapshotter, Any]]) – Attributes that need to be saved to allow training to be restarted.
- add_to_snapshot_attr(name, snapshotter, input_type)
Add a new attribute to the default snapshot_attrs dictionary. For this, we need a snapshotter that provides functionality for loading and saving the state of the attribute based on the type of the attribute.
- Parameters:
name (str) – Name of the attribute to be added.
snapshotter (AbstractSnapshotter) – Snapshotter object to be used for saving and loading the attribute.
input_type (type[T]) – Expected type of the attribute.
- Return type:
None
- checkpoint_exists()
Check if a checkpoint exists at the checkpoint_path constructed as checkpoint_dir + checkpoint_name.
- Returns:
True if the checkpoint exists, False otherwise.
- Return type:
bool
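A minimal sketch of the existence check, assuming that "checkpoint_dir + checkpoint_name" means a filesystem path join. The free function below and the file name "server_demo_state.pt" are hypothetical; the real method takes no arguments and reads these values from the checkpointer itself.

```python
import tempfile
from pathlib import Path


def checkpoint_exists(checkpoint_dir: Path, checkpoint_name: str) -> bool:
    # Assumption: the checkpoint path is the directory joined with the name.
    return (checkpoint_dir / checkpoint_name).exists()


with tempfile.TemporaryDirectory() as tmp:
    directory = Path(tmp)
    name = "server_demo_state.pt"  # illustrative name only

    exists_before = checkpoint_exists(directory, name)
    (directory / name).write_bytes(b"")  # create an empty placeholder file
    exists_after = checkpoint_exists(directory, name)
```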
- delete_from_snapshot_attr(name)
Delete the attribute from the default snapshot_attrs dictionary. This is useful for removing attributes that are no longer needed or to avoid saving/loading them.
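The documented shape of snapshot_attrs, a dictionary from attribute name to a (snapshotter, expected type) tuple, suggests how adding and deleting entries might work. The sketch below is an illustration under that assumption. ToySnapshotter is a hypothetical placeholder for AbstractSnapshotter, the attribute names are invented, and the two functions are free functions rather than the real methods.

```python
from typing import Any


class ToySnapshotter:
    """Placeholder for an AbstractSnapshotter subclass; carries no logic here."""


# Maps attribute name -> (snapshotter, expected type), as in the documented signature.
snapshot_attrs: dict[str, tuple[ToySnapshotter, Any]] = {
    "total_steps": (ToySnapshotter(), int),
}


def add_to_snapshot_attr(name: str, snapshotter: ToySnapshotter, input_type: type) -> None:
    snapshot_attrs[name] = (snapshotter, input_type)


def delete_from_snapshot_attr(name: str) -> None:
    del snapshot_attrs[name]


add_to_snapshot_attr("best_loss", ToySnapshotter(), float)
delete_from_snapshot_attr("total_steps")
remaining = sorted(snapshot_attrs)
```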
- abstract get_attribute(name)
Get the attribute from the client or server.
- Parameters:
name (str) – Name of the attribute.
- Returns:
The attribute value.
- Return type:
Any
- load_checkpoint()
Load and return the checkpoint stored in checkpoint_dir under the checkpoint_name if it exists. If it does not exist, an assertion error is raised.
- load_state(attributes=None)
Load checkpointed state dictionary from the checkpoint, potentially restricting the attributes to load.
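One plausible reading of "restricting the attributes to load" is a simple filter over the loaded dictionary. The helper select_attributes below is hypothetical and the checkpoint contents are invented for illustration; how fl4health applies the restriction internally is not stated here.

```python
from typing import Any, Optional


def select_attributes(
    checkpoint: dict[str, Any], attributes: Optional[list[str]] = None
) -> dict[str, Any]:
    # With no restriction, every checkpointed attribute is returned.
    if attributes is None:
        return dict(checkpoint)
    # Otherwise keep only the requested names.
    return {name: checkpoint[name] for name in attributes}


checkpoint = {"total_steps": 120, "best_loss": 0.31, "current_round": 4}
everything = select_attributes(checkpoint)
only_round = select_attributes(checkpoint, ["current_round"])
```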
- save_checkpoint(checkpoint_dict)
Save checkpoint_dict to the checkpoint path defined by the checkpoint directory and checkpoint name.
- Parameters:
checkpoint_dict (dict[str, Any]) – A dictionary with string keys and values of type Any representing the state to be saved.
- Raises:
e – Will throw an error if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface the error with a try/except.
- Return type:
None
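The try/except pattern mentioned under Raises can be sketched as follows: log the failure, then re-raise it so the caller sees it. This is an illustration rather than the library's code. It uses pickle instead of torch.save to stay dependency-free, and the function signature (taking an explicit path) is an assumption made for the example.

```python
import logging
import pickle
import tempfile
from pathlib import Path
from typing import Any


def save_checkpoint(checkpoint_dict: dict[str, Any], checkpoint_path: Path) -> None:
    try:
        with open(checkpoint_path, "wb") as handle:
            pickle.dump(checkpoint_dict, handle)
    except Exception as e:
        # Log and re-raise so the failure is never silently ignored.
        logging.error("Failed to save checkpoint to %s: %s", checkpoint_path, e)
        raise e


with tempfile.TemporaryDirectory() as tmp:
    good_path = Path(tmp) / "state.pt"
    save_checkpoint({"current_round": 3}, good_path)
    saved_ok = good_path.exists()

    # The parent directory does not exist, so opening the file fails.
    bad_path = Path(tmp) / "missing_dir" / "state.pt"
    try:
        save_checkpoint({"current_round": 3}, bad_path)
        error_surfaced = False
    except OSError:
        error_surfaced = True
```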
- save_state()
Create a snapshot of the state as defined in self.snapshot_attrs. It is saved at self.checkpoint_path.
- Return type:
None