fl4health.preprocessing.warmed_up_module module
- class WarmedUpModule(pretrained_model=None, pretrained_model_path=None, weights_mapping_path=None)
Bases: object
- __init__(pretrained_model=None, pretrained_model_path=None, weights_mapping_path=None)
This class loads the weights of a pretrained model into a target model.
Initialize the WarmedUpModule with the pretrained model states and the weights mapping dict.
- Parameters:
  - pretrained_model (torch.nn.Module | None) – Pretrained model. This is mutually exclusive with pretrained_model_path.
  - pretrained_model_path (Path | None) – Path to the pretrained model. This is mutually exclusive with pretrained_model.
  - weights_mapping_path (str | None, optional) – Path to the JSON file containing the weights mapping dict. If the pretrained and target models are not exactly the same, a weights mapping dict is needed to map the weights of the pretrained model to the target model.
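A minimal sketch of the two constructor constraints described above: the pretrained model and its path are mutually exclusive, and the weights mapping is read from a JSON file. It assumes the file holds a flat JSON object whose keys are target-model name prefixes and whose values are pretrained-model name prefixes. The helpers `check_pretrained_source` and `load_weights_mapping` are hypothetical illustrations, not part of the fl4health API; the real constructor may validate and load these inputs differently.

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path
from typing import Any


def check_pretrained_source(pretrained_model: Any | None, pretrained_model_path: Path | None) -> None:
    # Hypothetical check: the two sources are mutually exclusive, so reject calls that pass both.
    if pretrained_model is not None and pretrained_model_path is not None:
        raise ValueError("Pass either pretrained_model or pretrained_model_path, not both.")


def load_weights_mapping(weights_mapping_path: str | None) -> dict[str, str] | None:
    # No path means the target and pretrained models are assumed to share key names.
    if weights_mapping_path is None:
        return None
    with open(weights_mapping_path, "r", encoding="utf-8") as f:
        mapping = json.load(f)
    # Assumption: the file holds a flat JSON object of target prefix -> pretrained prefix.
    if not isinstance(mapping, dict):
        raise ValueError("The weights mapping file must contain a JSON object.")
    return {str(k): str(v) for k, v in mapping.items()}


# Demo: write a mapping file to a temporary directory and read it back.
with tempfile.TemporaryDirectory() as tmp_dir:
    mapping_file = Path(tmp_dir) / "weights_mapping.json"
    mapping_file.write_text(json.dumps({"model": "global_model"}), encoding="utf-8")
    mapping = load_weights_mapping(str(mapping_file))

print(mapping)  # {'model': 'global_model'}
```

Raising an error when both sources are given avoids silently preferring one of them; the sketch deliberately does not require that either source be present, since the description above only states that they are mutually exclusive.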
- get_matching_component(key)
Get the matching component of the key from the weights mapping dictionary. Since the provided mapping can contain partial key names, this function splits the key of the target model into its components, matches them against the partial keys in the mapping, and returns the complete name of the corresponding key in the pretrained model.
This allows users to provide one mapping for multiple states that share the same prefix. For example, if the mapping is {"model": "global_model"} and the input key of the target model is model.layer1.weight, then the returned matching component is global_model.layer1.weight.
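The prefix matching described above can be sketched in plain Python. This is an illustrative reimplementation, not the fl4health source: here `get_matching_component` is a free function that takes the mapping explicitly, `load_from_mapping` is a hypothetical helper showing how mapped names could be used to copy weights, plain dicts stand in for PyTorch state dicts, and trying the shortest prefix first is an assumption about how overlapping mapping entries are resolved.

```python
from __future__ import annotations


def get_matching_component(key: str, weights_mapping: dict[str, str] | None) -> str | None:
    # With no mapping, assume the target key exists unchanged in the pretrained model.
    if weights_mapping is None:
        return key
    components = key.split(".")
    # Try prefixes from shortest to longest, e.g. "model", "model.layer1", "model.layer1.weight".
    for i in range(1, len(components) + 1):
        prefix = ".".join(components[:i])
        if prefix in weights_mapping:
            # Replace the matched prefix and keep the remainder of the key unchanged.
            return weights_mapping[prefix] + key[len(prefix):]
    # No whole-component prefix of the key appears in the mapping.
    return None


def load_from_mapping(
    target_state: dict[str, list[float]],
    pretrained_state: dict[str, list[float]],
    weights_mapping: dict[str, str] | None,
) -> dict[str, list[float]]:
    # Hypothetical helper: plain dicts stand in for torch state_dicts.
    updated = dict(target_state)
    for key in target_state:
        source_key = get_matching_component(key, weights_mapping)
        # Copy only entries whose mapped name exists in the pretrained state; leave the rest as-is.
        if source_key is not None and source_key in pretrained_state:
            updated[key] = pretrained_state[source_key]
    return updated


mapping = {"model": "global_model"}
print(get_matching_component("model.layer1.weight", mapping))  # global_model.layer1.weight
print(get_matching_component("head.bias", mapping))  # None

pretrained = {"global_model.layer1.weight": [1.0, 2.0], "global_model.layer1.bias": [0.5]}
target = {"model.layer1.weight": [0.0, 0.0], "model.layer1.bias": [0.0], "head.bias": [0.0]}
warmed = load_from_mapping(target, pretrained, mapping)
print(warmed)  # {'model.layer1.weight': [1.0, 2.0], 'model.layer1.bias': [0.5], 'head.bias': [0.0]}
```

Because the sketch matches whole dot-separated components rather than raw string prefixes, the entry "model" does not match a key such as modelx.layer1.weight. A real loader working on tensors would likely also need to check that the copied weights have compatible shapes, which this sketch does not do.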