fl4health.checkpointing.checkpointer module
- class BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name)
Bases: FunctionTorchModuleCheckpointer
- __init__(checkpoint_dir, checkpoint_name)
This checkpointer only uses the loss value provided to the maybe_checkpoint function to determine whether a checkpoint should be saved. We are always attempting to minimize the loss, so maximize is always set to False.
- maybe_checkpoint(model, loss, metrics)
This function will decide whether to checkpoint the provided model based on the loss argument. If the provided loss is better than any previous losses seen by this checkpointer, the model will be saved.
- Parameters:
model (nn.Module) – Model that might be persisted if the scoring function determines it should be.
loss (float) – Loss associated with the provided model. This value is used to determine whether to save the model or not.
metrics (dict[str, Scalar]) – Metrics associated with the provided model. Will not be used by this checkpointer.
- Raises:
e – Will throw an error if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface the error with a try/except.
- Return type:
None
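A minimal usage sketch (the model, loss values, and metric keys below are illustrative placeholders; the import path follows this module's name):

```python
import torch.nn as nn

from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer

# The target directory must already exist; the checkpointer will not create it.
checkpointer = BestLossTorchModuleCheckpointer("./checkpoints", "best_model.pt")

model = nn.Linear(10, 2)  # stand-in for a real model

# Only the second call saves: 0.42 improves on 0.61. Metrics are ignored.
checkpointer.maybe_checkpoint(model, loss=0.61, metrics={"accuracy": 0.71})
checkpointer.maybe_checkpoint(model, loss=0.42, metrics={"accuracy": 0.78})
```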
- class FunctionTorchModuleCheckpointer(checkpoint_dir, checkpoint_name, checkpoint_score_function, maximize=False)
Bases: TorchModuleCheckpointer
- __init__(checkpoint_dir, checkpoint_name, checkpoint_score_function, maximize=False)
A general torch checkpointer base class that allows for flexible definition of how to decide when to checkpoint based on the loss and metrics provided. The score function should compute a score from these values and maximize specifies whether we are hoping to maximize or minimize that score.
- Parameters:
checkpoint_dir (str) – Directory to which the model is saved. This directory should already exist. The checkpointer will not create it if it does not.
checkpoint_name (str) – Name of the checkpoint to be saved.
checkpoint_score_function (CheckpointScoreFunctionType) – Function that takes in a loss value and a dictionary of metrics and produces a score from them.
maximize (bool, optional) – Specifies whether we’re trying to minimize or maximize the score produced by the scoring function. Defaults to False.
- maybe_checkpoint(model, loss, metrics)
Given the loss/metrics associated with the provided model, the checkpointer uses the scoring function to produce a score. This score will then be used to determine whether the model should be checkpointed or not.
- Parameters:
model (nn.Module) – Model that might be persisted if the scoring function determines it should be.
loss (float) – Loss associated with the provided model. Will potentially contribute to the checkpointing decision, based on the score function.
metrics (dict[str, Scalar]) – Metrics associated with the provided model. Will potentially contribute to the checkpointing decision, based on the score function.
- Raises:
e – Will throw an error if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface the error with a try/except.
- Return type:
None
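As a sketch of defining a custom scoring rule, the function below checkpoints on best accuracy (the metric key and directory names are assumptions for illustration):

```python
import torch.nn as nn

from fl4health.checkpointing.checkpointer import FunctionTorchModuleCheckpointer

def accuracy_score(loss, metrics):
    # Score is the reported accuracy; the loss is ignored by this rule.
    return float(metrics["accuracy"])

checkpointer = FunctionTorchModuleCheckpointer(
    checkpoint_dir="./checkpoints",
    checkpoint_name="best_accuracy_model.pt",
    checkpoint_score_function=accuracy_score,
    maximize=True,  # higher accuracy is better, so we maximize the score
)

model = nn.Linear(10, 2)  # illustrative placeholder model
checkpointer.maybe_checkpoint(model, loss=0.5, metrics={"accuracy": 0.83})
```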
- class LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name)
Bases: FunctionTorchModuleCheckpointer
- __init__(checkpoint_dir, checkpoint_name)
A checkpointer that always checkpoints the model, regardless of the loss/metrics provided. As such, the score function is essentially a dummy.
- maybe_checkpoint(model, loss, _)
This function is essentially a pass-through, as this class always checkpoints the provided model.
- Parameters:
model (nn.Module) – Model to be checkpointed whenever this function is called.
loss (float) – Loss associated with the provided model. Not used by this checkpointer.
metrics (dict[str, Scalar]) – Metrics associated with the provided model. Not used by this checkpointer.
- Raises:
e – Will throw an error if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface the error with a try/except.
- Return type:
None
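A brief sketch of the always-save behavior, assuming the fixed checkpoint name means each call overwrites the previous file (names and loss values are illustrative):

```python
import torch.nn as nn

from fl4health.checkpointing.checkpointer import LatestTorchModuleCheckpointer

checkpointer = LatestTorchModuleCheckpointer("./checkpoints", "latest_model.pt")

model = nn.Linear(10, 2)  # illustrative placeholder
for round_loss in [0.9, 0.7, 0.8]:
    # Every call saves, regardless of whether the loss improved.
    checkpointer.maybe_checkpoint(model, round_loss, {})
```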
- class PerRoundStateCheckpointer(checkpoint_dir)
Bases: object
- __init__(checkpoint_dir)
Base class that provides a uniform interface for loading, saving, and checking whether checkpoints exist.
- Parameters:
checkpoint_dir (Path) – Base directory to store checkpoints. This checkpoint directory MUST already exist. It will not be created by this state checkpointer.
- checkpoint_exists(checkpoint_name, **kwargs)
Checks if a checkpoint exists at the checkpoint_dir constructed at initialization + checkpoint_name.
- Parameters:
checkpoint_name (str) – Name of the checkpoint for the existence test. The checkpoint directory is held internally as state by the checkpointer.
- Raises:
ValueError – Previously this function supported sending a path, but now requires checkpoint_name. Will raise an error if checkpoint_path is provided.
- Returns:
True if checkpoint exists, otherwise false.
- Return type:
bool
- load_checkpoint(checkpoint_name)
Loads and returns the checkpoint stored in checkpoint_dir under the provided name, if it exists. If it does not exist, an assertion error will be thrown.
- save_checkpoint(checkpoint_name, checkpoint_dict)
Saves checkpoint_dict to the checkpoint path formed from this class's checkpoint directory and the provided checkpoint name.
- Parameters:
checkpoint_name (str) – Name of the checkpoint to be saved.
checkpoint_dict – Dictionary of state to be saved under the provided checkpoint name.
- Raises:
e – Will throw an error if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface the error with a try/except.
- Return type:
None
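A sketch of saving and resuming per-round state (the checkpoint name and the keys in the state dictionary are illustrative; the class stores whatever dictionary is handed to save_checkpoint):

```python
from pathlib import Path

from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer

# The directory must already exist; it will not be created by the checkpointer.
state_checkpointer = PerRoundStateCheckpointer(Path("./state_checkpoints"))

state_checkpointer.save_checkpoint("round_5.pt", {"round": 5, "lr": 0.01})

if state_checkpointer.checkpoint_exists("round_5.pt"):
    state = state_checkpointer.load_checkpoint("round_5.pt")
    print(state["round"])  # 5
```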
- class TorchModuleCheckpointer(checkpoint_dir, checkpoint_name)
Bases: ABC
- __init__(checkpoint_dir, checkpoint_name)
Basic abstract base class to handle checkpointing PyTorch models. Models are saved with torch.save by default.
- load_checkpoint(path_to_checkpoint=None)
Loads a checkpoint, with the option to either specify a checkpoint path or fall back on the internal path of the checkpointer. The flexibility to specify a load path is useful, for example, if you are not overwriting checkpoints when saving and need to load a specific past checkpoint.
- Parameters:
path_to_checkpoint (str | None, optional) – If provided, the checkpoint will be loaded from this path. If not specified, the checkpointer will load from self.checkpoint_path. Defaults to None.
- Returns:
Returns a torch module loaded from the proper checkpoint path.
- Return type:
nn.Module
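Since the concrete checkpointers above inherit this method, reloading a saved model might look like the following sketch (the older checkpoint path is a hypothetical example):

```python
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer

checkpointer = BestLossTorchModuleCheckpointer("./checkpoints", "best_model.pt")

# Load from the checkpointer's internal path...
model = checkpointer.load_checkpoint()

# ...or from an explicit path to a specific past checkpoint.
older_model = checkpointer.load_checkpoint("./checkpoints/older_model.pt")
```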