fl4health.checkpointing.checkpointer module¶
- class BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name)[source]¶
Bases:
FunctionTorchModuleCheckpointer
- __init__(checkpoint_dir, checkpoint_name)[source]¶
This checkpointer uses only the loss value provided to the maybe_checkpoint function to determine whether a checkpoint should be saved. We are always attempting to minimize the loss, so maximize is always set to False.
- maybe_checkpoint(model, loss, metrics)[source]¶
This function will decide whether to checkpoint the provided model based on the loss argument. If the provided loss is better than any previous losses seen by this checkpointer, the model will be saved.
- Parameters:
model (nn.Module) – Model that might be persisted if the scoring function determines it should be.
loss (float) – Loss associated with the provided model. This value is used to determine whether to save the model or not.
metrics (dict[str, Scalar]) – Metrics associated with the provided model. Will not be used by this checkpointer.
- Raises:
e – Raised if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface them with a try/except.
- Return type:
None
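The decision logic described above can be sketched without the library itself. The class below is a hypothetical stand-in for BestLossTorchModuleCheckpointer that uses pickle in place of torch.save and a plain dict in place of an nn.Module; only the best-loss comparison mirrors the documented behavior.

```python
import os
import pickle
import tempfile


class BestLossCheckpointer:
    """Illustrative sketch: save a model only when the reported loss improves
    (we are minimizing, so lower is better)."""

    def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None:
        # The directory must already exist; it is not created here.
        self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        self.best_loss: float | None = None

    def maybe_checkpoint(self, model: object, loss: float, metrics: dict) -> bool:
        # Checkpoint iff this loss beats every loss seen so far.
        if self.best_loss is None or loss < self.best_loss:
            self.best_loss = loss
            with open(self.checkpoint_path, "wb") as f:
                pickle.dump(model, f)
            return True
        return False


checkpoint_dir = tempfile.mkdtemp()
ckpt = BestLossCheckpointer(checkpoint_dir, "best_model.pkl")
saved = [
    ckpt.maybe_checkpoint({"weights": [i]}, loss, {})
    for i, loss in enumerate([0.9, 0.5, 0.7, 0.3])
]
print(saved)           # [True, True, False, True]
print(ckpt.best_loss)  # 0.3
```

Note that metrics are accepted but ignored, matching the documented behavior of this checkpointer.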
- class FunctionTorchModuleCheckpointer(checkpoint_dir, checkpoint_name, checkpoint_score_function, maximize=False)[source]¶
Bases:
TorchModuleCheckpointer
- __init__(checkpoint_dir, checkpoint_name, checkpoint_score_function, maximize=False)[source]¶
A general torch checkpointer base class that allows for flexible definition of when to checkpoint based on the loss and metrics provided. The score function computes a score from these values, and maximize specifies whether we hope to maximize or minimize that score.
- Parameters:
checkpoint_dir (str) – Directory to which the model is saved. This directory should already exist. The checkpointer will not create it if it does not.
checkpoint_name (str) – Name of the checkpoint to be saved.
checkpoint_score_function (CheckpointScoreFunctionType) – Function taking in a loss value and dictionary of metrics and produces a score based on these.
maximize (bool, optional) – Specifies whether we’re trying to minimize or maximize the score produced by the scoring function. Defaults to False.
- maybe_checkpoint(model, loss, metrics)[source]¶
Given the loss/metrics associated with the provided model, the checkpointer uses the scoring function to produce a score. This score will then be used to determine whether the model should be checkpointed or not.
- Parameters:
model (nn.Module) – Model that might be persisted if the scoring function determines it should be.
loss (float) – Loss associated with the provided model. Will potentially contribute to checkpointing decision, based on the score function.
metrics (dict[str, Scalar]) – Metrics associated with the provided model. Will potentially contribute to the checkpointing decision, based on the score function.
- Raises:
e – Raised if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface them with a try/except.
- Return type:
None
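A minimal sketch of the score-and-compare pattern described above, assuming only the documented shape of the score function (it takes a loss and a metrics dictionary and returns a score). The helper name make_decider and the val_acc metric key are illustrative, not part of the library.

```python
import math


def make_decider(score_fn, maximize=False):
    """Return a stateful closure that decides whether to checkpoint by
    comparing each new score against the best score seen so far."""
    best = -math.inf if maximize else math.inf

    def should_checkpoint(loss, metrics):
        nonlocal best
        score = score_fn(loss, metrics)
        # Direction of the comparison is controlled by the maximize flag.
        improved = score > best if maximize else score < best
        if improved:
            best = score
        return improved

    return should_checkpoint


# Maximize validation accuracy reported in the metrics dictionary.
decide = make_decider(lambda loss, metrics: metrics["val_acc"], maximize=True)
results = [decide(0.4, {"val_acc": acc}) for acc in (0.70, 0.68, 0.75)]
print(results)  # [True, False, True]
```

Scoring on a metric rather than the loss is the kind of flexibility this class is meant to provide; the best-loss checkpointer above is just the special case score_fn = loss with maximize=False.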
- class LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name)[source]¶
Bases:
FunctionTorchModuleCheckpointer
- __init__(checkpoint_dir, checkpoint_name)[source]¶
A checkpointer that always checkpoints the model, regardless of the loss/metrics provided. As such, the score function is essentially a dummy.
- maybe_checkpoint(model, loss, _)[source]¶
This function is essentially a pass-through, as this class always checkpoints the provided model.
- Parameters:
model (nn.Module) – Model to be checkpointed whenever this function is called
loss (float) – Loss associated with the provided model. NOT USED by this checkpointer.
metrics (dict[str, Scalar]) – Metrics associated with the provided model. NOT USED by this checkpointer.
- Raises:
e – Raised if there is an issue saving the model. torch.save seems to swallow errors in this context, so we explicitly surface them with a try/except.
- Return type:
None
- class PerRoundStateCheckpointer(checkpoint_dir)[source]¶
Bases:
object
- __init__(checkpoint_dir)[source]¶
Base class that provides a uniform interface for loading, saving, and checking whether checkpoints exist.
- Parameters:
checkpoint_dir (Path) – Base directory to store checkpoints. This checkpoint directory MUST already exist. It will not be created by this state checkpointer.
- checkpoint_exists(checkpoint_name, **kwargs)[source]¶
Checks if a checkpoint exists at the path formed by joining the checkpoint_dir provided at initialization with checkpoint_name.
- Returns:
Whether or not a checkpoint exists.
- Return type:
bool
- load_checkpoint(checkpoint_name)[source]¶
Loads and returns the checkpoint stored in checkpoint_dir under the provided name if it exists. If it doesn’t exist, an assertion error will be thrown.
- save_checkpoint(checkpoint_name, checkpoint_dict)[source]¶
Saves checkpoint_dict to the checkpoint path formed from this class’s checkpoint directory and the provided checkpoint name.
- Parameters:
checkpoint_name (str) – Name of the checkpoint to be saved.
checkpoint_dict (dict) – Dictionary of state to be saved at the checkpoint path.
- Raises:
e – Raised if there is an issue saving the model. Torch.save seems to swallow errors in this context, so we explicitly surface them with a try/except.
- Return type:
None
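The save/exists/load round trip this class exposes can be sketched with stdlib pieces. The class below is an illustrative stand-in for PerRoundStateCheckpointer, substituting pickle for torch.save; the method names mirror the interface documented above, but everything else is assumed.

```python
import os
import pickle
import tempfile


class StateCheckpointer:
    """Illustrative sketch: save and load arbitrary state dictionaries
    under a fixed base directory."""

    def __init__(self, checkpoint_dir: str) -> None:
        # The directory must already exist; it is not created here.
        self.checkpoint_dir = checkpoint_dir

    def _path(self, checkpoint_name: str) -> str:
        return os.path.join(self.checkpoint_dir, checkpoint_name)

    def checkpoint_exists(self, checkpoint_name: str) -> bool:
        return os.path.exists(self._path(checkpoint_name))

    def save_checkpoint(self, checkpoint_name: str, checkpoint_dict: dict) -> None:
        with open(self._path(checkpoint_name), "wb") as f:
            pickle.dump(checkpoint_dict, f)

    def load_checkpoint(self, checkpoint_name: str) -> dict:
        # Mirrors the documented behavior: loading a missing checkpoint
        # raises an assertion error.
        assert self.checkpoint_exists(checkpoint_name), "no such checkpoint"
        with open(self._path(checkpoint_name), "rb") as f:
            return pickle.load(f)


ckpt = StateCheckpointer(tempfile.mkdtemp())
print(ckpt.checkpoint_exists("round_3.pkl"))  # False
ckpt.save_checkpoint("round_3.pkl", {"round": 3, "lr": 0.01})
state = ckpt.load_checkpoint("round_3.pkl")
print(state["round"])  # 3
```

Checkpointing a dictionary of state per round, rather than a single model, is what lets a federated training run resume mid-experiment: the dict can bundle the model, optimizer state, and round counter together.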
- class TorchModuleCheckpointer(checkpoint_dir, checkpoint_name)[source]¶
Bases:
ABC
- __init__(checkpoint_dir, checkpoint_name)[source]¶
Basic abstract base class to handle checkpointing PyTorch models. Models are saved with torch.save by default.
- load_checkpoint(path_to_checkpoint=None)[source]¶
Checkpointer with the option to either specify a checkpoint path or fall back on the internal path of the checkpointer. The flexibility to specify a load path is useful, for example, if you are not overwriting checkpoints when saving and need to load a specific past checkpoint for whatever reason.
- Parameters:
path_to_checkpoint (str | None, optional) – If provided, the checkpoint will be loaded from this path. If not specified, the checkpointer will load from self.checkpoint_path. Defaults to None.
- Returns:
Returns a torch module loaded from the proper checkpoint path.
- Return type:
nn.Module