def __init__(
    self,
    pre_aggregation: ModelCheckpointers = None,
    post_aggregation: ModelCheckpointers = None,
    state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
    """
    This module is meant to hold the three major components that determine how clients handle model and state
    checkpointing, where state checkpointing is meant to allow clients to restart if FL training is interrupted.

    For model checkpointing, there are two distinct types.

    - The first type, if defined, is used to checkpoint local models **BEFORE** server-side aggregation, but
      after local training. **NOTE**: This is akin to "further fine-tuning" approaches for global models.
    - The second type, if defined, is used to checkpoint local models **AFTER** server-side aggregation, but
      before local training. **NOTE**: This is the "traditional" mechanism for global models.

    As a final note, for some methods, such as Ditto or MR-MTL, these checkpoints will actually be identical.
    That's because the target model for these methods is never globally aggregated. That is, it remains local.

    Args:
        pre_aggregation (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
            checkpointers) is used to checkpoint models based on their validation metrics/losses **BEFORE**
            server-side aggregation. Defaults to None.
        post_aggregation (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
            checkpointers) is used to checkpoint models based on their validation metrics/losses **AFTER**
            server-side aggregation. Defaults to None.
        state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer is used
            to preserve client state (not just models), in the event one wants to restart federated training.
            Defaults to None.
    """
    self.pre_aggregation = (
        [pre_aggregation] if isinstance(pre_aggregation, TorchModuleCheckpointer) else pre_aggregation
    )
    self.post_aggregation = (
        [post_aggregation] if isinstance(post_aggregation, TorchModuleCheckpointer) else post_aggregation
    )
    self._check_if_shared_checkpoint_names()
    self.state_checkpointer = state_checkpointer
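# Usage sketch: constructing this module. The class name (ClientCheckpointAndStateModule)
# and the concrete checkpointer classes below are assumptions for illustration; any
# TorchModuleCheckpointer subclasses with unique checkpoint paths would work the same way.
#
#     from pathlib import Path
#
#     pre_checkpointer = BestLossTorchModuleCheckpointer("ckpts", "pre_agg_best.pkl")  # hypothetical class/paths
#     post_checkpointer = BestLossTorchModuleCheckpointer("ckpts", "post_agg_best.pkl")
#     checkpoint_module = ClientCheckpointAndStateModule(
#         pre_aggregation=pre_checkpointer,  # a single checkpointer is wrapped into a list internally
#         post_aggregation=[post_checkpointer],  # a sequence of checkpointers may also be passed directly
#         state_checkpointer=PerRoundStateCheckpointer(Path("ckpts/state")),  # constructor args assumed
#     )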
def _check_if_shared_checkpoint_names(self) -> None:
    """
    This function checks whether there is overlap in the paths to which the checkpointers of this module are
    supposed to write. This ensures that checkpoints are not accidentally overwritten.

    Raises:
        ValueError: If any of the pre- or post-aggregation model checkpointer paths are not unique.
    """
    pre_aggregation_paths = (
        [checkpointer.checkpoint_path for checkpointer in self.pre_aggregation] if self.pre_aggregation else []
    )
    post_aggregation_paths = (
        [checkpointer.checkpoint_path for checkpointer in self.post_aggregation] if self.post_aggregation else []
    )
    all_paths = pre_aggregation_paths + post_aggregation_paths
    unique_paths = set(all_paths)
    if len(unique_paths) != len(all_paths):
        formatted_all_paths = "\n".join(all_paths)
        raise ValueError(
            "The paths of all of your checkpointers should be unique, otherwise overwrites are possible and "
            f"data will be lost. The current paths are:\n{formatted_all_paths}"
        )
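# Sketch of the failure mode this check guards against (checkpointer class and
# module name are hypothetical, as above): two checkpointers pointed at the same
# file would silently overwrite each other's checkpoints, so construction raises.
#
#     shared = BestLossTorchModuleCheckpointer("ckpts", "best.pkl")  # hypothetical class
#     ClientCheckpointAndStateModule(pre_aggregation=shared, post_aggregation=shared)
#     # -> ValueError: The paths of all of your checkpointers should be unique, ...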
def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: dict[str, Scalar], mode: CheckpointMode) -> None:
    """
    Performs model checkpointing for a particular mode (either pre- or post-aggregation) if any checkpointers
    are provided for that particular mode in this module. If present, the various checkpointers will decide
    whether or not to checkpoint based on their internal criterion and the loss/metrics provided.

    Args:
        model (nn.Module): The model that might be checkpointed by the checkpointers.
        loss (float): The loss value obtained by the provided model. Used by the checkpointer(s) to decide
            whether to checkpoint the model.
        metrics (dict[str, Scalar]): The metrics obtained by the provided model. Potentially used by the
            checkpointer(s) to decide whether to checkpoint the model.
        mode (CheckpointMode): Determines which of the types of checkpointers to use. Currently, the only modes
            available are pre- and post-aggregation.

    Raises:
        ValueError: Thrown if the model checkpointing mode is not recognized.
    """
    if mode == CheckpointMode.PRE_AGGREGATION:
        if self.pre_aggregation is not None:
            for checkpointer in self.pre_aggregation:
                checkpointer.maybe_checkpoint(model, loss, metrics)
        else:
            log(INFO, "No pre-aggregation checkpoint specified. Skipping.")
    elif mode == CheckpointMode.POST_AGGREGATION:
        if self.post_aggregation is not None:
            for checkpointer in self.post_aggregation:
                checkpointer.maybe_checkpoint(model, loss, metrics)
        else:
            log(INFO, "No post-aggregation checkpoint specified. Skipping.")
    else:
        raise ValueError(f"Unrecognized mode for checkpointing: {str(mode)}")
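# Sketch of how a client might invoke this around aggregation (variable names are
# illustrative): the same call is made for both modes, and each registered
# checkpointer applies its own criterion to the provided loss/metrics.
#
#     val_loss, val_metrics = 0.42, {"accuracy": 0.91}  # illustrative validation results
#     checkpoint_module.maybe_checkpoint(model, val_loss, val_metrics, CheckpointMode.PRE_AGGREGATION)
#     # ... weights are sent to the server, aggregated, and received back ...
#     checkpoint_module.maybe_checkpoint(model, val_loss, val_metrics, CheckpointMode.POST_AGGREGATION)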
def save_state(self, state_checkpoint_name: str, state: dict[str, Any]) -> None:
    """
    This function is meant to facilitate saving the state required to restart an FL process on the client side.
    It will simply save whatever information is passed in the ``state`` variable using the file name in
    ``state_checkpoint_name``. This function should only be called if a ``state_checkpointer`` exists in this
    module.

    Args:
        state_checkpoint_name (str): Name of the state checkpoint file. The checkpointer itself will have a
            directory to which state will be saved.
        state (dict[str, Any]): State to be saved so that training might be resumed on the client if federated
            training is interrupted. For example, this might contain things like optimizer states, learning
            rate scheduler states, etc.

    Raises:
        ValueError: Thrown if this function is called but no state checkpointer has been provided.
    """
    if self.state_checkpointer is not None:
        self.state_checkpointer.save_checkpoint(state_checkpoint_name, state)
    else:
        raise ValueError("Attempting to save state but no state checkpointer is specified")
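# Sketch of saving resumable client state at the end of a round. The checkpoint
# name and dictionary keys are illustrative, not a required schema; the docstring
# only requires a dict[str, Any].
#
#     checkpoint_module.save_state(
#         "client_0_state.pt",
#         {
#             "model": model.state_dict(),
#             "optimizer": optimizer.state_dict(),
#             "lr_scheduler": scheduler.state_dict(),
#             "current_round": server_round,
#         },
#     )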
def maybe_load_state(self, state_checkpoint_name: str) -> dict[str, Any] | None:
    """
    This function facilitates loading any pre-existing state (with the name ``state_checkpoint_name``) from the
    directory of the ``state_checkpointer``. If the state exists at the proper path, it is loaded and returned.
    If it doesn't exist, None is returned.

    Args:
        state_checkpoint_name (str): Name of the state checkpoint file. The checkpointer itself will have a
            directory from which state will be loaded (if it exists).

    Raises:
        ValueError: Thrown if this function is called but no state checkpointer has been provided.

    Returns:
        dict[str, Any] | None: If the state checkpoint exists and is loaded correctly, this dictionary carries
        that state. Otherwise, None is returned (or an exception is thrown).
    """
    if self.state_checkpointer is not None:
        if self.state_checkpointer.checkpoint_exists(state_checkpoint_name):
            return self.state_checkpointer.load_checkpoint(state_checkpoint_name)
        else:
            log(INFO, "State checkpointer is defined but no state checkpoint exists.")
            return None
    else:
        raise ValueError("Attempting to load state, but no state checkpointer is specified")
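# Sketch of resuming on client startup (mirrors the save_state sketch above; the
# checkpoint name and keys are illustrative): None simply means there is nothing
# to restore, so training starts fresh.
#
#     state = checkpoint_module.maybe_load_state("client_0_state.pt")
#     if state is not None:  # a previous run was interrupted; restore and resume
#         model.load_state_dict(state["model"])
#         optimizer.load_state_dict(state["optimizer"])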