fl4health.checkpointing.state_checkpointer module

class ClientStateCheckpointer(checkpoint_dir, checkpoint_name=None, snapshot_attrs=None)

Bases: StateCheckpointer

__init__(checkpoint_dir, checkpoint_name=None, snapshot_attrs=None)

Class for saving and loading the state of a client’s attributes as specified in snapshot_attrs.

Parameters:
  • checkpoint_dir (Path) – Directory to which checkpoints are saved. This can be modified later with set_checkpoint_path.

  • checkpoint_name (str | None, optional) – Name of the checkpoint to be saved. If None and checkpoint_dir is set, a default checkpoint_name is derived from the name of the client being checkpointed, of the form f"client_{client.client_name}_state.pt". This can be changed later with set_checkpoint_path. Defaults to None.

  • snapshot_attrs (dict[str, tuple[AbstractSnapshotter, Any]] | None, optional) – Attributes to save so that training can be restarted. If None, a sensible default set of attributes and their associated snapshotters for an FL client is used. Defaults to None.

get_attribute(name)

Get the attribute from the client.

Parameters:

name (str) – Name of the attribute.

Returns:

The attribute value.

Return type:

Any

maybe_load_client_state(client, attributes=None)

Load the checkpointed state, if a checkpoint exists, into the provided client.

Parameters:
  • client (BasicClient) – Target client object into which state will be loaded.

  • attributes (list[str] | None, optional) – List of attributes to load from the checkpoint. If None, all attributes specified in snapshot_attrs are loaded. Defaults to None.

Returns:

True if a checkpoint was successfully loaded, False otherwise.

Return type:

bool
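The "maybe" in maybe_load_client_state means a missing checkpoint is not an error: the method reports False and the client keeps its freshly initialized state. A minimal sketch of that contract, assuming a hypothetical ToyClient and a free function in place of the method, with pickle standing in for torch.load and the attributes argument left out:

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


class ToyClient:
    """Hypothetical stand-in for a BasicClient with two checkpointable attributes."""

    def __init__(self) -> None:
        self.total_steps = 0
        self.total_epochs = 0


def maybe_load_client_state(client: ToyClient, checkpoint_path: Path) -> bool:
    # No checkpoint (e.g. the very first run): leave the client untouched.
    if not checkpoint_path.exists():
        return False
    with open(checkpoint_path, "rb") as f:
        state: dict[str, Any] = pickle.load(f)
    # Copy every saved attribute back onto the client.
    for name, value in state.items():
        setattr(client, name, value)
    return True


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "client_site_a_state.pkl"
    client = ToyClient()
    print(maybe_load_client_state(client, path))  # False
    with open(path, "wb") as f:
        pickle.dump({"total_steps": 120, "total_epochs": 3}, f)
    print(maybe_load_client_state(client, path), client.total_steps)  # True 120
```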

maybe_set_default_checkpoint_name()

Sets a default name for the checkpoint to be saved, if needed. If checkpoint_dir is set but checkpoint_name is None, a default checkpoint_name is derived from the name of the client being checkpointed, of the form f"client_{self.client.client_name}_state.pt".

Return type:

None
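The naming rule above can be sketched as a standalone function. NamedClient and maybe_default_checkpoint_name are hypothetical names; the real method is an instance method that takes no arguments and reads the client from the checkpointer itself.

```python
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path


@dataclass
class NamedClient:
    # Hypothetical stand-in for BasicClient; only client_name matters here.
    client_name: str


def maybe_default_checkpoint_name(
    checkpoint_dir: Path | None, checkpoint_name: str | None, client: NamedClient
) -> str | None:
    # Derive a name only when a directory is configured and no name was given.
    if checkpoint_dir is not None and checkpoint_name is None:
        return f"client_{client.client_name}_state.pt"
    # An explicitly provided name is kept unchanged.
    return checkpoint_name


print(maybe_default_checkpoint_name(Path("checkpoints"), None, NamedClient("site_a")))
# client_site_a_state.pt
```

ServerStateCheckpointer applies the same rule with the prefix server_ and server.server_name.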

save_client_state(client)

Save the state of the provided client.

Parameters:

client (BasicClient) – Client object with state to be saved.

Return type:

None

set_attribute(name, value)

Set the attribute on the client.

Parameters:
  • name (str) – Name of the attribute.

  • value (Any) – Value to set for the attribute.

Return type:

None

class NnUnetServerStateCheckpointer(checkpoint_dir, checkpoint_name=None)

Bases: ServerStateCheckpointer

__init__(checkpoint_dir, checkpoint_name=None)

Class for saving and loading the state of the server’s attributes based on the snapshot_attrs defined specifically for the nnUNet server.

Parameters:
  • checkpoint_dir (Path) – Directory to which checkpoints are saved. This can be modified later with set_checkpoint_path.

  • checkpoint_name (str | None, optional) – Name of the checkpoint to be saved. If None and checkpoint_dir is set, a default checkpoint_name is derived from the name of the server being checkpointed, of the form f"server_{self.server.server_name}_state.pt". This can be updated later with set_checkpoint_path. Defaults to None.

class ServerStateCheckpointer(checkpoint_dir, checkpoint_name=None, snapshot_attrs=None)

Bases: StateCheckpointer

__init__(checkpoint_dir, checkpoint_name=None, snapshot_attrs=None)

Class for saving and loading the state of a server’s attributes as specified in snapshot_attrs.

Parameters:
  • checkpoint_dir (Path) – Directory to which checkpoints are saved. This can be modified later with set_checkpoint_path.

  • checkpoint_name (str | None, optional) – Name of the checkpoint to be saved. If None and checkpoint_dir is set, a default checkpoint_name is derived from the name of the server being checkpointed, of the form f"server_{self.server.server_name}_state.pt". This can be updated later with set_checkpoint_path. Defaults to None.

  • snapshot_attrs (dict[str, tuple[AbstractSnapshotter, Any]] | None, optional) – Attributes to save so that training can be restarted. If None, a sensible default set of attributes and their associated snapshotters for an FL server is used. Defaults to None.

get_attribute(name)

Get the attribute from the server.

Parameters:

name (str) – Name of the attribute.

Returns:

The attribute value.

Return type:

Any

maybe_load_server_state(server, model, attributes=None)

Load the state of the server from the checkpoint, if one exists.

Parameters:
  • server (FlServer) – Server into which the attributes will be loaded.

  • model (nn.Module) – The model structure to be loaded as part of the server state.

  • attributes (list[str] | None, optional) – List of attributes to load from the checkpoint. If None, all attributes specified in snapshot_attrs are loaded. Defaults to None.

Returns:

The model if a checkpoint exists to load from, otherwise None.

Return type:

nn.Module | None

maybe_set_default_checkpoint_name()

Sets a default name for the checkpoint to be saved, if needed. If checkpoint_dir is set but checkpoint_name is None, a default checkpoint_name is derived from the name of the server being checkpointed, of the form f"server_{self.server.server_name}_state.pt".

Return type:

None

save_server_state(server, model)

Save the state of the server together with a PyTorch model. The model is passed in separately because it is not a required component of the server class.

Parameters:
  • server (FlServer) – Server with state to be saved.

  • model (nn.Module) – The model to be saved as part of the server state.

Return type:

None

set_attribute(name, value)

Set the attribute on the server.

Parameters:
  • name (str) – Name of the attribute.

  • value (Any) – Value to set for the attribute.

Return type:

None

class StateCheckpointer(checkpoint_dir, checkpoint_name, snapshot_attrs)

Bases: ABC

__init__(checkpoint_dir, checkpoint_name, snapshot_attrs)

Class for saving and loading the state of client or server attributes. Attributes are gathered into a dictionary for saving and are loaded back as a dictionary. Checkpointing can be done after a client or server round to facilitate restarting federated training if it is interrupted, or during the client’s training loop to facilitate early stopping.

Server and client state checkpointers save to disk in the provided directory. If checkpoint_name is still None at the time of saving, a default name for the state checkpoint is derived.

Parameters:
  • checkpoint_dir (Path) – Directory to which checkpoints are saved. This can be modified later with set_checkpoint_path.

  • checkpoint_name (str | None) – Name of the checkpoint to be saved. If None at the time of state saving, a default name is given to the checkpoint. This can be changed later with set_checkpoint_path.

  • snapshot_attrs (dict[str, tuple[AbstractSnapshotter, Any]]) – Attributes to save so that training can be restarted.
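A minimal sketch of the save/load pattern described above, under simplifying assumptions: MiniStateCheckpointer, Trainer and TrainerCheckpointer are hypothetical names, pickle stands in for torch.save/torch.load, and snapshot_attrs maps each name to an expected type only (no snapshotter). It illustrates the division of labour (the base class handles the dictionary and the file, subclasses supply get_attribute and set_attribute) rather than reproducing the fl4health code.

```python
import pickle
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any


class MiniStateCheckpointer(ABC):
    """Hypothetical, simplified analogue of StateCheckpointer (pickle instead of torch)."""

    def __init__(self, checkpoint_dir: Path, checkpoint_name: str, snapshot_attrs: dict[str, type]) -> None:
        self.checkpoint_path = checkpoint_dir / checkpoint_name
        # Attribute name -> expected type. The real class also pairs each name
        # with an AbstractSnapshotter; that part is omitted here.
        self.snapshot_attrs = snapshot_attrs

    @abstractmethod
    def get_attribute(self, name: str) -> Any:
        """Read an attribute from the wrapped object."""

    @abstractmethod
    def set_attribute(self, name: str, value: Any) -> None:
        """Write an attribute onto the wrapped object."""

    def save_state(self) -> None:
        # Gather every tracked attribute into one dictionary and write it to disk.
        state = {name: self.get_attribute(name) for name in self.snapshot_attrs}
        with open(self.checkpoint_path, "wb") as f:
            pickle.dump(state, f)

    def load_state(self) -> None:
        # Read the dictionary back and restore each attribute after a type check.
        with open(self.checkpoint_path, "rb") as f:
            state = pickle.load(f)
        for name, expected_type in self.snapshot_attrs.items():
            value = state[name]
            if not isinstance(value, expected_type):
                raise TypeError(f"{name}: expected {expected_type.__name__}, got {type(value).__name__}")
            self.set_attribute(name, value)


class Trainer:
    """Hypothetical object whose state should survive a restart."""

    def __init__(self) -> None:
        self.current_round = 0
        self.best_loss = float("inf")


class TrainerCheckpointer(MiniStateCheckpointer):
    """Concrete subclass: decides where the attributes live on the wrapped object."""

    def __init__(self, trainer: Trainer, checkpoint_dir: Path) -> None:
        super().__init__(checkpoint_dir, "trainer_state.pkl", {"current_round": int, "best_loss": float})
        self.trainer = trainer

    def get_attribute(self, name: str) -> Any:
        return getattr(self.trainer, name)

    def set_attribute(self, name: str, value: Any) -> None:
        setattr(self.trainer, name, value)


with tempfile.TemporaryDirectory() as tmp:
    trainer = Trainer()
    trainer.current_round, trainer.best_loss = 3, 0.25
    TrainerCheckpointer(trainer, Path(tmp)).save_state()

    # A fresh object picks up where the saved one left off.
    restored = Trainer()
    TrainerCheckpointer(restored, Path(tmp)).load_state()
    print(restored.current_round, restored.best_loss)  # 3 0.25
```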

add_to_snapshot_attr(name, snapshotter, input_type)

Add a new attribute to the default snapshot_attrs dictionary. This requires a snapshotter that provides the functionality for saving and loading the attribute's state, based on the attribute's type.

Parameters:
  • name (str) – Name of the attribute to be added.

  • snapshotter (AbstractSnapshotter) – Snapshotter object to be used for saving and loading the attribute.

  • input_type (type[T]) – Expected type of the attribute.

Return type:

None
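Each snapshot_attrs entry pairs a snapshotter with the attribute's expected type, so adding an attribute amounts to inserting one more such pair. A minimal sketch, assuming a hypothetical PlaceholderSnapshotter (the real AbstractSnapshotter interface is not reproduced here) and a free function in place of the method:

```python
from typing import Any


class PlaceholderSnapshotter:
    """Hypothetical stand-in for an AbstractSnapshotter subclass."""


def add_to_snapshot_attr(
    snapshot_attrs: dict[str, tuple[Any, type]], name: str, snapshotter: Any, input_type: type
) -> None:
    # Pair the snapshotter that saves/loads this attribute with its expected type.
    snapshot_attrs[name] = (snapshotter, input_type)


snapshot_attrs: dict[str, tuple[Any, type]] = {"current_round": (PlaceholderSnapshotter(), int)}
add_to_snapshot_attr(snapshot_attrs, "best_val_loss", PlaceholderSnapshotter(), float)
print(sorted(snapshot_attrs))  # ['best_val_loss', 'current_round']
```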

checkpoint_exists()

Check if a checkpoint exists at checkpoint_path, which is constructed from checkpoint_dir and checkpoint_name.

Returns:

True if the checkpoint exists, otherwise False.

Return type:

bool

delete_from_snapshot_attr(name)

Delete the attribute from the default snapshot_attrs dictionary. This is useful for removing attributes that are no longer needed or to avoid saving/loading them.

Parameters:

name (str) – Name of the attribute to be removed from the snapshot_attrs dictionary.

Return type:

None

abstract get_attribute(name)

Get the attribute from the client or server.

Parameters:

name (str) – Name of the attribute.

Returns:

The attribute value.

Return type:

Any

load_checkpoint()

Load and return the checkpoint stored in checkpoint_dir under checkpoint_name, if it exists. If it does not exist, an AssertionError is raised.

Returns:

A dictionary representing the checkpointed state, as loaded by torch.load.

Return type:

dict[str, Any]
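checkpoint_exists and load_checkpoint work together: the path is the directory joined with the file name, and loading asserts that the file is there. A minimal sketch using hypothetical free functions of the same names (not the library's methods), with pickle standing in for torch.load:

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


def checkpoint_exists(checkpoint_dir: Path, checkpoint_name: str) -> bool:
    # The checkpoint path is the directory joined with the file name.
    return (checkpoint_dir / checkpoint_name).exists()


def load_checkpoint(checkpoint_dir: Path, checkpoint_name: str) -> dict[str, Any]:
    # Fail loudly if there is nothing to load, as the documentation describes.
    assert checkpoint_exists(checkpoint_dir, checkpoint_name), "No checkpoint to load"
    with open(checkpoint_dir / checkpoint_name, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    d = Path(tmp)
    print(checkpoint_exists(d, "state.pkl"))  # False
    with open(d / "state.pkl", "wb") as f:
        pickle.dump({"current_round": 2}, f)
    print(load_checkpoint(d, "state.pkl"))  # {'current_round': 2}
```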

load_state(attributes=None)

Load the checkpointed state dictionary from the checkpoint, optionally restricting which attributes are loaded.

Parameters:

attributes (list[str] | None, optional) – List of attributes to load from the checkpoint. If None, all attributes specified in snapshot_attrs are loaded. Defaults to None.

Return type:

None
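The attributes argument (shared with maybe_load_client_state and maybe_load_server_state) picks a subset of the checkpoint; None means every attribute registered in snapshot_attrs. A minimal sketch of only that selection step, with select_state as a hypothetical helper; in the real class each selected value would then be restored through its snapshotter, which is omitted here.

```python
from __future__ import annotations

from typing import Any


def select_state(
    state: dict[str, Any], snapshot_attrs: dict[str, Any], attributes: list[str] | None = None
) -> dict[str, Any]:
    # None means "every attribute registered in snapshot_attrs".
    names = list(snapshot_attrs) if attributes is None else attributes
    # Only the selected names are read from the loaded checkpoint dictionary.
    return {name: state[name] for name in names}


checkpoint = {"current_round": 4, "best_val_loss": 0.31, "lr": 0.01}
registered = {"current_round": int, "best_val_loss": float, "lr": float}
print(select_state(checkpoint, registered, ["current_round"]))  # {'current_round': 4}
```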

save_checkpoint(checkpoint_dict)

Save checkpoint_dict to the checkpoint path defined by checkpoint_dir and checkpoint_name.

Parameters:

checkpoint_dict (dict[str, Any]) – A dictionary with string keys and values of type Any representing the state to be saved.

Raises:

e – Raised if there is an issue saving the checkpoint. torch.save seems to swallow errors in this context, so the error is explicitly surfaced with a try/except.

Return type:

None
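The error-surfacing pattern described for save_checkpoint can be sketched with pickle standing in for torch.save. save_checkpoint here is a hypothetical free function and the log message is illustrative, not the library's:

```python
import logging
import pickle
import tempfile
from pathlib import Path
from typing import Any

log = logging.getLogger(__name__)


def save_checkpoint(checkpoint_path: Path, checkpoint_dict: dict[str, Any]) -> None:
    try:
        with open(checkpoint_path, "wb") as f:
            pickle.dump(checkpoint_dict, f)
    except Exception as e:
        # Log the failure, then re-raise it so the caller cannot miss it.
        log.error("Failed to save checkpoint to %s: %s", checkpoint_path, e)
        raise e


with tempfile.TemporaryDirectory() as tmp:
    save_checkpoint(Path(tmp) / "state.pkl", {"current_round": 1})  # succeeds
    try:
        # The parent directory does not exist, so open() fails.
        save_checkpoint(Path(tmp) / "missing" / "state.pkl", {"current_round": 1})
    except FileNotFoundError as err:
        print("surfaced:", type(err).__name__)  # surfaced: FileNotFoundError
```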

save_state()

Create a snapshot of the state as defined in self.snapshot_attrs and save it at self.checkpoint_path.

Return type:

None

abstract set_attribute(name, value)

Set the attribute on the client or server.

Parameters:
  • name (str) – Name of the attribute.

  • value (Any) – Value to set for the attribute.

Return type:

None

set_checkpoint_path(checkpoint_dir, checkpoint_name)

Set or update the checkpoint path based on the provided checkpoint name and directory.

Parameters:
  • checkpoint_dir (Path) – The directory where the checkpoint will be saved.

  • checkpoint_name (str) – The name of the checkpoint file.

Return type:

None
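set_checkpoint_path lets the destination change after construction, for example to fill in a name that was left as None. A minimal sketch assuming the checkpointer keeps the directory, the name and the joined path as attributes; CheckpointLocation is a hypothetical class, not part of fl4health:

```python
from __future__ import annotations

from pathlib import Path


class CheckpointLocation:
    """Hypothetical holder for the directory/name pair a checkpointer writes to."""

    def __init__(self, checkpoint_dir: Path, checkpoint_name: str | None = None) -> None:
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_name = checkpoint_name
        # The full path is only known once a name is available.
        self.checkpoint_path: Path | None = None
        if checkpoint_name is not None:
            self.set_checkpoint_path(checkpoint_dir, checkpoint_name)

    def set_checkpoint_path(self, checkpoint_dir: Path, checkpoint_name: str) -> None:
        # Keep the two parts and the joined path in sync.
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_name = checkpoint_name
        self.checkpoint_path = checkpoint_dir / checkpoint_name


loc = CheckpointLocation(Path("runs/a"), "state.pt")
loc.set_checkpoint_path(Path("runs/b"), "client_site_a_state.pt")
print(loc.checkpoint_path)
```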