fl4health.checkpointing.checkpointer module

class BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name)[source]

Bases: FunctionTorchModuleCheckpointer

__init__(checkpoint_dir, checkpoint_name)[source]

This checkpointer only uses the loss value provided to the maybe_checkpoint function to determine whether a checkpoint should be saved. Since we are always attempting to minimize the loss, maximize is always set to False.

Parameters:
  • checkpoint_dir (str) – Directory to which the model is saved. This directory should already exist. The checkpointer will not create it if it does not.

  • checkpoint_name (str) – Name of the checkpoint to be saved.

maybe_checkpoint(model, loss, metrics)[source]

This function will decide whether to checkpoint the provided model based on the loss argument. If the provided loss is better than any previous losses seen by this checkpointer, the model will be saved.

Parameters:
  • model (nn.Module) – Model that might be persisted if the scoring function determines it should be.

  • loss (float) – Loss associated with the provided model. This value is used to determine whether to save the model or not.

  • metrics (dict[str, Scalar]) – Metrics associated with the provided model. Will not be used by this checkpointer.

Raises:

e – Raised if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface the error with a try/except.

Return type:

None
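
A minimal usage sketch (the directory, checkpoint name, model, and loss values below are illustrative):

    import torch.nn as nn

    from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer

    # The target directory must already exist; the checkpointer will not create it.
    checkpointer = BestLossTorchModuleCheckpointer("checkpoints", "best_model.pkl")

    model = nn.Linear(10, 2)  # stand-in for a real model
    for loss in [0.9, 0.7, 0.8]:
        # Expected to save on 0.9 (first loss seen) and 0.7 (an improvement),
        # and to skip 0.8 (worse than the best loss seen so far).
        checkpointer.maybe_checkpoint(model, loss, {})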

class FunctionTorchModuleCheckpointer(checkpoint_dir, checkpoint_name, checkpoint_score_function, maximize=False)[source]

Bases: TorchModuleCheckpointer

__init__(checkpoint_dir, checkpoint_name, checkpoint_score_function, maximize=False)[source]

A general torch checkpointer base class that allows for flexible definition of how to decide when to checkpoint, based on the loss and metrics provided. The score function should compute a score from these values, and maximize specifies whether we are hoping to maximize or minimize that score.

Parameters:
  • checkpoint_dir (str) – Directory to which the model is saved. This directory should already exist. The checkpointer will not create it if it does not.

  • checkpoint_name (str) – Name of the checkpoint to be saved.

  • checkpoint_score_function (CheckpointScoreFunctionType) – Function that takes in a loss value and a dictionary of metrics and produces a score from them.

  • maximize (bool, optional) – Specifies whether the score produced by the scoring function should be maximized (True) or minimized (False). Defaults to False.

maybe_checkpoint(model, loss, metrics)[source]

Given the loss/metrics associated with the provided model, the checkpointer uses the scoring function to produce a score. This score will then be used to determine whether the model should be checkpointed or not.

Parameters:
  • model (nn.Module) – Model that might be persisted if the scoring function determines it should be.

  • loss (float) – Loss associated with the provided model. Will potentially contribute to checkpointing decision, based on the score function.

  • metrics (dict[str, Scalar]) – Metrics associated with the provided model. Will potentially contribute to the checkpointing decision, based on the score function.

Raises:

e – Raised if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface the error with a try/except.

Return type:

None
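
As a sketch of how the scoring function hooks in, suppose we want to checkpoint on a hypothetical "accuracy" entry in the metrics dictionary (the score function, checkpoint name, and metric key below are illustrative, not part of the library):

    import torch.nn as nn

    from fl4health.checkpointing.checkpointer import FunctionTorchModuleCheckpointer

    # Hypothetical score function: ignore the loss and score by an assumed
    # "accuracy" key in the metrics dictionary passed to maybe_checkpoint.
    def accuracy_score(loss, metrics):
        return float(metrics["accuracy"])

    checkpointer = FunctionTorchModuleCheckpointer(
        "checkpoints", "best_accuracy_model.pkl", accuracy_score, maximize=True
    )

    model = nn.Linear(10, 2)  # stand-in for a real model
    checkpointer.maybe_checkpoint(model, 0.5, {"accuracy": 0.87})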

class LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name)[source]

Bases: FunctionTorchModuleCheckpointer

__init__(checkpoint_dir, checkpoint_name)[source]

A checkpointer that always checkpoints the model, regardless of the loss/metrics provided. As such, the score function is essentially a dummy.

Parameters:
  • checkpoint_dir (str) – Directory to which the model is saved. This directory should already exist. The checkpointer will not create it if it does not.

  • checkpoint_name (str) – Name of the checkpoint to be saved.

maybe_checkpoint(model, loss, _)[source]

This function is essentially a pass-through, as this class always checkpoints the provided model.

Parameters:
  • model (nn.Module) – Model to be checkpointed whenever this function is called

  • loss (float) – Loss associated with the provided model. NOT USED by this checkpointer.

  • metrics (dict[str, Scalar]) – Metrics associated with the provided model. NOT USED by this checkpointer.

Raises:

e – Raised if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface the error with a try/except.

Return type:

None
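
A minimal sketch pairing this always-save behavior with load_checkpoint from the base class (the directory and checkpoint name are illustrative):

    import torch.nn as nn

    from fl4health.checkpointing.checkpointer import LatestTorchModuleCheckpointer

    checkpointer = LatestTorchModuleCheckpointer("checkpoints", "latest_model.pkl")

    model = nn.Linear(10, 2)  # stand-in for a real model
    # Always saves, regardless of the loss and metrics values passed in.
    checkpointer.maybe_checkpoint(model, 1.3, {})

    # Reload the most recently saved model from the checkpointer's internal path.
    restored_model = checkpointer.load_checkpoint()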

class PerRoundStateCheckpointer(checkpoint_dir)[source]

Bases: object

__init__(checkpoint_dir)[source]

Base class that provides a uniform interface for loading, saving, and checking whether checkpoints exist.

Parameters:
  • checkpoint_dir (Path) – Base directory to store checkpoints. This checkpoint directory MUST already exist. It will not be created by this state checkpointer.

checkpoint_exists(checkpoint_name, **kwargs)[source]

Checks whether a checkpoint exists at the path formed from the checkpoint_dir provided at initialization and checkpoint_name.

Parameters:

checkpoint_name (str) – Name of the state checkpoint file to check for.

Returns:

Whether or not a checkpoint exists.

Return type:

bool

load_checkpoint(checkpoint_name)[source]

Loads and returns the checkpoint stored in checkpoint_dir under the provided name, if it exists. If it does not exist, an AssertionError will be raised.

Parameters:

checkpoint_name (str) – Name of the state checkpoint to be loaded.

Returns:

A dictionary representing the checkpointed state, as loaded by torch.load.

Return type:

dict[str, Any]

save_checkpoint(checkpoint_name, checkpoint_dict)[source]

Saves checkpoint_dict to the checkpoint path formed from this class's checkpoint directory and the provided checkpoint name.

Parameters:
  • checkpoint_name (str) – Name of the state checkpoint file.

  • checkpoint_dict (dict[str, Any]) – A dictionary with string keys and values of type Any representing the state to checkpoint.

Raises:

e – Raised if there is an issue saving the checkpoint. torch.save seems to swallow errors in this context, so we explicitly surface the error with a try/except.

Return type:

None
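
A round-trip sketch of the state checkpointing interface (the file name and state contents are illustrative); guarding the load with checkpoint_exists avoids the assertion error raised for missing checkpoints:

    from pathlib import Path

    from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer

    # The directory must already exist; it will not be created by the checkpointer.
    state_checkpointer = PerRoundStateCheckpointer(Path("checkpoints"))

    state = {"server_round": 3, "model_weights": None}  # illustrative state dict
    state_checkpointer.save_checkpoint("state_round_3.pt", state)

    if state_checkpointer.checkpoint_exists("state_round_3.pt"):
        restored_state = state_checkpointer.load_checkpoint("state_round_3.pt")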

class TorchModuleCheckpointer(checkpoint_dir, checkpoint_name)[source]

Bases: ABC

__init__(checkpoint_dir, checkpoint_name)[source]

Basic abstract base class to handle checkpointing PyTorch models. Models are saved with torch.save by default.

Parameters:
  • checkpoint_dir (str) – Directory to which the model is saved. This directory should already exist. The checkpointer will not create it if it does not.

  • checkpoint_name (str) – Name of the checkpoint to be saved.

load_checkpoint(path_to_checkpoint=None)[source]

Loads a model checkpoint, either from a specified checkpoint path or, if none is given, from the internal path of the checkpointer. The flexibility to specify a load path is useful, for example, if you are not overwriting checkpoints when saving and need to load a specific past checkpoint.

Parameters:

path_to_checkpoint (str | None, optional) – If provided, the checkpoint will be loaded from this path. If not specified, the checkpointer will load from self.checkpoint_path. Defaults to None.

Returns:

Returns a torch module loaded from the proper checkpoint path.

Return type:

nn.Module

abstract maybe_checkpoint(model, loss, metrics)[source]

Abstract method to be implemented by every TorchModuleCheckpointer subclass. Based on the loss and metrics provided, it should determine whether to produce a checkpoint AND save it if applicable.

Parameters:
  • model (nn.Module) – Model to potentially save via the checkpointer

  • loss (float) – Computed loss associated with the model.

  • metrics (dict[str, Scalar]) – Computed metrics associated with the model.

Raises:

NotImplementedError – Must be implemented by the checkpointer

Return type:

None
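
To illustrate the contract, here is a hypothetical minimal subclass. It relies on two details stated above: models are saved with torch.save by default, and the checkpointer's internal save path is self.checkpoint_path (referenced by load_checkpoint):

    import torch
    import torch.nn as nn

    from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer

    class EveryCallCheckpointer(TorchModuleCheckpointer):
        # Hypothetical subclass: persists the model unconditionally on each call.
        # A real implementation would decide based on the loss and metrics.
        def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: dict) -> None:
            torch.save(model, self.checkpoint_path)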