fl4health.model_bases.fedsimclr_base module

class FedSimClrModel(encoder, projection_head=Identity(), prediction_head=None, pretrain=True)

Bases: Module

__init__(encoder, projection_head=Identity(), prediction_head=None, pretrain=True)

Model base for training SimCLR (https://arxiv.org/pdf/2002.05709) in a federated manner, as presented in https://arxiv.org/pdf/2207.09158. It can be used for pretraining and, optionally, for finetuning.

Parameters:
  • encoder (nn.Module) – Encoder that extracts a feature vector given an input sample.

  • projection_head (nn.Module) – Projection head that maps the output of the encoder to the final representation used in the contrastive loss during the pretraining stage. Defaults to the identity transformation.

  • prediction_head (nn.Module | None) – Prediction head that maps the output of the encoder to a prediction during the finetuning stage. Defaults to None.

  • pretrain (bool) – Determines whether to use the projection_head (True) or the prediction_head (False). Defaults to True.
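A minimal sketch of how the parameters above fit together, assuming forward applies the encoder and then exactly one head selected by pretrain. SimClrSketch is a hypothetical stand-in written for illustration, not the fl4health source, and all layer sizes are arbitrary.

```python
from typing import Optional

import torch
from torch import nn


class SimClrSketch(nn.Module):
    """Hypothetical stand-in illustrating the documented head switching (not fl4health code)."""

    def __init__(
        self,
        encoder: nn.Module,
        projection_head: nn.Module = nn.Identity(),  # shared, parameter-free default
        prediction_head: Optional[nn.Module] = None,
        pretrain: bool = True,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.projection_head = projection_head
        self.prediction_head = prediction_head
        self.pretrain = pretrain

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        features = self.encoder(input)
        if self.pretrain:
            # Pretraining: project features into the space used by the contrastive loss.
            return self.projection_head(features)
        if self.prediction_head is None:
            raise ValueError("prediction_head is required when pretrain=False")
        # Finetuning: map features to task predictions.
        return self.prediction_head(features)


# Usage with arbitrary sizes: 28x28 single-channel inputs, 64-d features.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU())
model = SimClrSketch(
    encoder,
    projection_head=nn.Linear(64, 32),
    prediction_head=nn.Linear(64, 10),
)
x = torch.randn(8, 1, 28, 28)

z = model(x)  # pretrain=True: projection head output, shape (8, 32)
model.pretrain = False
logits = model(x)  # pretrain=False: prediction head output, shape (8, 10)

# With the default Identity projection head, encoder features pass through unchanged.
default_out = SimClrSketch(encoder)(x)  # shape (8, 64)
```

Keeping a single encoder and toggling a flag means the same feature extractor is shared between the contrastive and supervised stages; only the head on top changes.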

forward(input)

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
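The point of the Note can be checked directly. In this sketch a plain nn.Linear stands in for any nn.Module (a FedSimClrModel behaves the same way), and record_call is a hypothetical hook written for illustration: it fires when the module instance is called, but not when forward is invoked directly.

```python
import torch
from torch import nn

calls = []

model = nn.Linear(4, 2)


def record_call(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(type(module).__name__)


handle = model.register_forward_hook(record_call)
x = torch.randn(3, 4)

model(x)          # __call__ runs forward() and then the registered hook
model.forward(x)  # calling forward() directly skips the hook
print(calls)      # ['Linear']

handle.remove()   # detach the hook once it is no longer needed
```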

static load_pretrained_model(model_path)

Return type:

FedSimClrModel
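Only the signature and return type of load_pretrained_model are documented here, so its loading mechanism is not shown. As a hedged illustration of the pretraining-to-finetuning handoff the class is meant for, the sketch below uses the generic PyTorch state_dict pattern instead: pretrained encoder weights are saved, then restored into a fresh copy of the same architecture, to which a prediction head is attached. This is not the library's method; make_encoder, the file name, and all sizes are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn


def make_encoder() -> nn.Module:
    # Hypothetical encoder; pretraining and finetuning must build the same architecture.
    return nn.Sequential(nn.Linear(16, 8), nn.ReLU())


pretrained_encoder = make_encoder()
# (Contrastive pretraining of pretrained_encoder would happen here.)

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "encoder.pt")
    torch.save(pretrained_encoder.state_dict(), ckpt_path)

    # Finetuning stage: rebuild the architecture and restore the saved weights.
    finetune_encoder = make_encoder()
    finetune_encoder.load_state_dict(torch.load(ckpt_path, map_location="cpu"))

weights_match = all(
    torch.equal(a, b)
    for a, b in zip(
        pretrained_encoder.state_dict().values(),
        finetune_encoder.state_dict().values(),
    )
)

# Attach a task-specific head; in FedSimClrModel terms this corresponds to passing
# the encoder together with a prediction_head and pretrain=False.
classifier = nn.Sequential(finetune_encoder, nn.Linear(8, 3))
logits = classifier(torch.randn(4, 16))  # shape (4, 3)
```

Saving a state_dict rather than a pickled module keeps the checkpoint to plain tensors, at the cost of having to rebuild the architecture in code before loading.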