fl4health.model_bases.fedsimclr_base module

class FedSimClrModel(encoder, projection_head=Identity(), prediction_head=None, pretrain=True)[source]

Bases: Module

__init__(encoder, projection_head=Identity(), prediction_head=None, pretrain=True)[source]

Model base for training SimCLR (https://arxiv.org/pdf/2002.05709) in a federated manner, as presented in (https://arxiv.org/pdf/2207.09158). Can be used for pretraining and, optionally, finetuning.

Parameters:
  • encoder (nn.Module) – Encoder that extracts a feature vector given an input sample.

  • projection_head (nn.Module) – Projection Head that maps output of encoder to final representation used in contrastive loss for pretraining stage. Defaults to identity transformation.

  • prediction_head (nn.Module | None) – Prediction head that maps output of encoder to prediction in the finetuning stage. Defaults to None.

  • pretrain (bool) – Determines whether to use the projection_head (True) or the prediction_head (False). Defaults to True.

forward(input)[source]

Passes the input tensor through the encoder module. If we’re in the pretraining phase, the output of the encoder is flattened/projected for similarity computations. If we’re fine-tuning the model, these latent features are passed through the provided prediction head.

Parameters:

input (torch.Tensor) – Input to be mapped to either latent features or a final prediction depending on the training phase.

Returns:

The output from either the projection_head module if pre-training or the prediction_head if fine-tuning.

Return type:

torch.Tensor
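
The switch between the two heads described above can be sketched with a small stand-in module. SketchSimClrModel is a hypothetical class written for illustration, not the library implementation. The flattening of the encoder output to (batch, features) and the toy layer sizes are assumptions.

```python
import torch
from torch import nn


class SketchSimClrModel(nn.Module):
    """Hypothetical stand-in mirroring the documented behaviour; not the library class."""

    def __init__(self, encoder, projection_head=None, prediction_head=None, pretrain=True):
        super().__init__()
        self.encoder = encoder
        # Default to an identity projection, as the parameter description states.
        self.projection_head = projection_head if projection_head is not None else nn.Identity()
        self.prediction_head = prediction_head
        self.pretrain = pretrain

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Assumption: encoder output is flattened to (batch, features) before either head.
        features = torch.flatten(self.encoder(input), start_dim=1)
        if self.pretrain:
            # Pretraining: representation used by the contrastive loss.
            return self.projection_head(features)
        if self.prediction_head is None:
            raise ValueError("prediction_head is required when pretrain is False")
        # Fine-tuning: map latent features to a prediction.
        return self.prediction_head(features)


# Toy encoder producing 4 features per sample (sizes chosen for illustration only).
encoder = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
model = SketchSimClrModel(encoder, projection_head=nn.Linear(4, 8), prediction_head=nn.Linear(4, 2))

batch = torch.randn(5, 3, 16, 16)
pretrain_out = model(batch)   # projection head output, shape (5, 8)
model.pretrain = False
finetune_out = model(batch)   # prediction head output, shape (5, 2)
```

In this sketch the same encoder serves both phases, and only the head applied to its features changes with the pretrain flag.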

static load_pretrained_model(model_path)[source]

Given a path, this function loads a model from the path, assuming it was of type FedSimClrModel. The proper components are then routed to form a new model with the pre-existing weights.

NOTE: Loaded models automatically set pretrain to False

Parameters:

model_path (Path) – Path to a FedSimClrModel object saved using torch.save

Returns:

A model with pre-existing weights loaded and pretrain set to False

Return type:

FedSimClrModel
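
The routing step can be illustrated roughly as follows. This is a sketch under assumptions: SketchSimClrModel and route_for_finetuning are hypothetical names, and the previously saved model is built in memory here rather than read from disk, so the actual loading step is not shown.

```python
import torch
from torch import nn


class SketchSimClrModel(nn.Module):
    """Hypothetical stand-in for illustration; not the library class."""

    def __init__(self, encoder, projection_head=None, prediction_head=None, pretrain=True):
        super().__init__()
        self.encoder = encoder
        self.projection_head = projection_head if projection_head is not None else nn.Identity()
        self.prediction_head = prediction_head
        self.pretrain = pretrain

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        features = self.encoder(input)
        head = self.projection_head if self.pretrain else self.prediction_head
        return head(features)


def route_for_finetuning(prev_model: SketchSimClrModel) -> SketchSimClrModel:
    # Reuse the existing components (and therefore their weights) in a new
    # model whose pretrain flag is False, matching the documented NOTE.
    return SketchSimClrModel(
        encoder=prev_model.encoder,
        projection_head=prev_model.projection_head,
        prediction_head=prev_model.prediction_head,
        pretrain=False,
    )


# Stand-in for a model obtained from a file (toy layer sizes are assumptions).
pretrained = SketchSimClrModel(nn.Linear(6, 4), nn.Linear(4, 8), nn.Linear(4, 2), pretrain=True)
finetune_model = route_for_finetuning(pretrained)
prediction = finetune_model(torch.randn(3, 6))  # uses the prediction head, shape (3, 2)
```

Because the components are passed by reference in this sketch, the new model shares the pretrained encoder weights rather than copying them.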