fl4health.preprocessing.warmed_up_module module

class WarmedUpModule(pretrained_model=None, pretrained_model_path=None, weights_mapping_path=None)

Bases: object

This class loads the weights of a pretrained model into a target model, optionally using a weights mapping dict when the parameter names of the two models differ.

__init__(pretrained_model=None, pretrained_model_path=None, weights_mapping_path=None)

Initialize the WarmedUpModule with the pretrained model states and weights mapping dict.

Parameters:
  • pretrained_model (torch.nn.Module | None) – Pretrained model. This is mutually exclusive with pretrained_model_path.

  • pretrained_model_path (Path | None) – Path of the pretrained model. This is mutually exclusive with pretrained_model.

  • weights_mapping_path (str | None, optional) – Path to a JSON file containing the weights mapping dict. If the models are not exactly the same, a weights mapping dict is needed to map the weights of the pretrained model to the target model.
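The weights mapping file is described above only as "a JSON file of the weights mapping dict". The sketch below assumes it is a flat JSON object whose keys are target-model key prefixes and whose values are pretrained-model key prefixes, as in the {"model": "global_model"} example given for get_matching_component. The helpers write_weights_mapping, load_weights_mapping, and check_pretrained_source are hypothetical illustrations, not part of the fl4health API; the last one only mirrors the documented rule that pretrained_model and pretrained_model_path are mutually exclusive.

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path


def write_weights_mapping(mapping: dict[str, str], directory: Path) -> Path:
    """Serialize a weights mapping dict to a JSON file (hypothetical helper)."""
    path = directory / "weights_mapping.json"
    path.write_text(json.dumps(mapping, indent=2))
    return path


def load_weights_mapping(weights_mapping_path: str | None) -> dict[str, str] | None:
    """Read the mapping back; None is assumed to mean the key names already match."""
    if weights_mapping_path is None:
        return None
    with open(weights_mapping_path) as f:
        return json.load(f)


def check_pretrained_source(pretrained_model: object | None, pretrained_model_path: Path | None) -> None:
    """Reject calls that pass both a model object and a path (hypothetical check)."""
    if pretrained_model is not None and pretrained_model_path is not None:
        raise ValueError("pretrained_model and pretrained_model_path are mutually exclusive")


if __name__ == "__main__":
    # Assumed format: target-model prefix -> pretrained-model prefix.
    mapping = {"model": "global_model"}
    with tempfile.TemporaryDirectory() as tmp:
        mapping_file = write_weights_mapping(mapping, Path(tmp))
        print(mapping_file.read_text())
        assert load_weights_mapping(str(mapping_file)) == mapping
```

The path of the written file is what would be passed as weights_mapping_path.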

get_matching_component(key)

Get the matching component of the key from the weights mapping dictionary. Because the provided mapping can contain partial key names, this function splits the key of the target model into its components, matches them against the partial keys in the mapping, and returns the complete name of the corresponding key in the pretrained model.

This allows users to provide a single mapping entry for multiple states that share the same prefix. For example, if the mapping is {“model”: “global_model”} and the input key of the target model is “model.layer1.weight”, then the returned matching component is “global_model.layer1.weight”.

Parameters:

key (str) – Key of the target model to be matched in the pretrained model.

Returns:

If no weights mapping dict is provided, returns the key unchanged. Otherwise, if a partial name (prefix) of the key is in the weights mapping dict, returns the matching component of the key; if not, returns None.

Return type:

str | None
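The prefix matching described above can be sketched as a standalone function. This is an illustration under assumptions, not the library's code: it takes the mapping as an explicit argument instead of reading it from the instance, and it tries dot-separated prefixes from shortest to longest, returning on the first hit. The real method may order the search differently when several prefixes of one key appear in the mapping.

```python
from __future__ import annotations


def get_matching_component(key: str, weights_mapping: dict[str, str] | None) -> str | None:
    """Resolve a target-model key to a pretrained-model key via prefix mapping."""
    if weights_mapping is None:
        # No mapping: both models are assumed to use identical key names.
        return key
    components = key.split(".")
    # Try progressively longer prefixes: "model", "model.layer1", "model.layer1.weight".
    for i in range(1, len(components) + 1):
        prefix = ".".join(components[:i])
        if prefix in weights_mapping:
            # Replace the matched prefix and keep the rest of the key unchanged.
            return weights_mapping[prefix] + key[len(prefix):]
    # No prefix of the key appears in the mapping.
    return None


if __name__ == "__main__":
    mapping = {"model": "global_model"}
    print(get_matching_component("model.layer1.weight", mapping))
    print(get_matching_component("head.bias", mapping))
    print(get_matching_component("head.bias", None))
```

Because only the matched prefix is rewritten, one entry such as {“model”: “global_model”} covers every parameter under that submodule.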

load_from_pretrained(model)

Load the weights of the pretrained model into the target model and return the target model.

Parameters:

model (torch.nn.Module) – target model.

Return type:

Module
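A rough sketch of what loading could look like, assuming the method walks the target model's state dict, resolves each key with get_matching_component, and copies a pretrained value only when the resolved key exists and the shapes agree. The shape-mismatch policy (skip silently) is an assumption, not documented behavior. To keep the example free of PyTorch, state dicts are modeled as plain dicts mapping keys to lists of floats, with list length standing in for tensor shape; with real modules you would obtain them from model.state_dict() and write the result back with model.load_state_dict().

```python
from __future__ import annotations


def get_matching_component(key: str, weights_mapping: dict[str, str] | None) -> str | None:
    """Prefix-based key resolution (same sketch as for get_matching_component)."""
    if weights_mapping is None:
        return key
    components = key.split(".")
    for i in range(1, len(components) + 1):
        prefix = ".".join(components[:i])
        if prefix in weights_mapping:
            return weights_mapping[prefix] + key[len(prefix):]
    return None


def load_from_pretrained(
    target_state: dict[str, list[float]],
    pretrained_state: dict[str, list[float]],
    weights_mapping: dict[str, str] | None = None,
) -> tuple[dict[str, list[float]], list[str]]:
    """Return a copy of target_state with compatible pretrained values copied in."""
    updated = dict(target_state)
    loaded: list[str] = []
    for key, target_value in target_state.items():
        pretrained_key = get_matching_component(key, weights_mapping)
        if pretrained_key is None or pretrained_key not in pretrained_state:
            continue  # No counterpart in the pretrained model: keep the target's values.
        pretrained_value = pretrained_state[pretrained_key]
        if len(pretrained_value) != len(target_value):
            continue  # Assumed policy: skip mismatched shapes instead of failing.
        updated[key] = list(pretrained_value)
        loaded.append(key)
    return updated, loaded


if __name__ == "__main__":
    target = {"model.layer1.weight": [0.0, 0.0], "head.weight": [0.0]}
    pretrained = {"global_model.layer1.weight": [1.0, 2.0]}
    print(load_from_pretrained(target, pretrained, {"model": "global_model"}))
```

Returning the list of loaded keys is only a convenience of this sketch. The documented method returns the target model itself.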