fl4health.model_bases.fedsimclr_base module
- class FedSimClrModel(encoder, projection_head=Identity(), prediction_head=None, pretrain=True)
Bases: Module
- __init__(encoder, projection_head=Identity(), prediction_head=None, pretrain=True)
Model base for training SimCLR (https://arxiv.org/pdf/2002.05709) in a federated manner, as presented in (https://arxiv.org/pdf/2207.09158). It can be used for pretraining and, optionally, for finetuning.
- Parameters:
encoder (nn.Module) – Encoder that extracts a feature vector given an input sample.
projection_head (nn.Module) – Projection head that maps the output of the encoder to the final representation used in the contrastive loss during the pretraining stage. Defaults to the identity transformation.
prediction_head (nn.Module | None) – Prediction head that maps the output of the encoder to a prediction in the finetuning stage. Defaults to None.
pretrain (bool) – Determines whether to use the projection_head (True) or the prediction_head (False). Defaults to True.
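The minimal PyTorch sketch below illustrates the routing described above. `SimClrRoutingSketch` is a hypothetical stand-in and is not the fl4health class. The toy encoder, the head sizes, and the choice to flatten inside each head are assumptions made for illustration.

```python
from typing import Optional

import torch
import torch.nn as nn


class SimClrRoutingSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented FedSimClrModel routing."""

    def __init__(
        self,
        encoder: nn.Module,
        projection_head: Optional[nn.Module] = None,
        prediction_head: Optional[nn.Module] = None,
        pretrain: bool = True,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        # Fall back to an identity projection, matching the documented default.
        self.projection_head = projection_head if projection_head is not None else nn.Identity()
        self.prediction_head = prediction_head
        self.pretrain = pretrain

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        features = self.encoder(input)
        if self.pretrain:
            # Pretraining: map encoder features to the contrastive representation.
            return self.projection_head(features)
        # Finetuning: route the same features through the prediction head instead.
        if self.prediction_head is None:
            raise ValueError("prediction_head is required when pretrain=False")
        return self.prediction_head(features)


# Toy encoder producing (N, 8, 1, 1) feature maps.
encoder = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
# Both heads flatten the feature map before their linear layers.
projection_head = nn.Sequential(nn.Flatten(), nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
prediction_head = nn.Sequential(nn.Flatten(), nn.Linear(8, 10))

model = SimClrRoutingSketch(encoder, projection_head, prediction_head, pretrain=True)
x = torch.randn(5, 3, 32, 32)
z = model(x)  # projection output, shape (5, 4)

model.pretrain = False  # switch to the finetuning route
logits = model(x)  # prediction output, shape (5, 10)
```

This sketch defaults `projection_head` to `None` and creates the `nn.Identity()` inside `__init__`. That way, separate models never share a single default module instance. The behavior otherwise matches the documented default of an identity projection.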
- forward(input)
Passes the input tensor through the encoder module. In the pretraining phase, the output of the encoder is flattened/projected for similarity computations. When finetuning the model, these latent features are instead passed through the provided prediction head.
- Parameters:
input (torch.Tensor) – Input to be mapped to either latent features or a final prediction depending on the training phase.
- Returns:
The output from either the projection_head module if pretraining or the prediction_head if finetuning.
- Return type:
torch.Tensor
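The "similarity computations" step can be made concrete with the NT-Xent (normalized temperature-scaled cross-entropy) loss from the SimCLR paper. The minimal sketch below applies it to the pretraining outputs of two augmented views of the same batch. The function name `nt_xent_loss` and the default temperature are assumptions for illustration. This is not fl4health's own loss implementation.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent over two views; z1 and z2 are (N, D) projections of the same N samples."""
    n = z1.shape[0]
    # Stack both views and L2-normalize so dot products are cosine similarities.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D)
    sim = z @ z.T / temperature  # (2N, 2N) scaled similarity matrix
    # Remove each sample's similarity to itself from the softmax denominator.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # Row i's positive is row i + N (first view) or row i - N (second view).
    targets = torch.cat(
        [torch.arange(n, 2 * n, device=z.device), torch.arange(0, n, device=z.device)]
    )
    # Cross-entropy averages the per-row losses over all 2N rows.
    return F.cross_entropy(sim, targets)


# Toy projections: identical views made of orthogonal unit vectors.
z1 = torch.eye(3)
z2 = torch.eye(3)
loss = nt_xent_loss(z1, z2)
```

In a pretraining loop, `z1` and `z2` would be the forward outputs for two augmentations of the same input batch. After masking the diagonal, each row compares its single positive pair against the remaining 2N - 2 negatives.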
- static load_pretrained_model(model_path)
Given a path, this function loads a model from the path, assuming it was of type FedSimClrModel. The proper components are then routed to form a new model with the pre-existing weights.
NOTE: Loaded models automatically set pretrain to False.
- Parameters:
model_path (Path) – Path to a FedSimClrModel object saved using torch.save.
- Returns:
A model with pre-existing weights loaded and pretrain set to False.
- Return type: