fl4health.model_bases.fedsimclr_base module
- class FedSimClrModel(encoder, projection_head=Identity(), prediction_head=None, pretrain=True)
Bases: torch.nn.Module
- __init__(encoder, projection_head=Identity(), prediction_head=None, pretrain=True)
Model base for training SimCLR (https://arxiv.org/pdf/2002.05709) in a federated manner, as presented in https://arxiv.org/pdf/2207.09158. It can be used for pretraining and, optionally, for finetuning.
- Parameters:
encoder (nn.Module) – Encoder that extracts a feature vector from an input sample.
projection_head (nn.Module) – Projection head that maps the output of the encoder to the final representation used in the contrastive loss during the pretraining stage. Defaults to the identity transformation.
prediction_head (nn.Module | None) – Prediction head that maps the output of the encoder to a prediction during the finetuning stage. Defaults to None.
pretrain (bool) – Determines whether to use the projection_head (True) or the prediction_head (False). Defaults to True.
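During pretraining, the projection_head output feeds a contrastive loss. This module does not define that loss itself; as a minimal sketch, assuming the NT-Xent (normalized temperature-scaled cross-entropy) loss from the SimCLR paper, a hypothetical helper `nt_xent_loss` with an illustrative default temperature of 0.5 could compute it over a batch of paired projections. Row i of `z1` and row i of `z2` are assumed to be the projections of two augmented views of the same sample; every other row in the combined batch acts as a negative.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Hypothetical NT-Xent helper (not part of fl4health).

    z1[i] and z2[i] are projections of two augmented views of sample i.
    """
    n = z1.shape[0]
    # Stack both views and L2-normalize so dot products become cosine similarities.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # shape (2N, d)
    sim = z @ z.T / temperature  # shape (2N, 2N)
    # A sample must not count as its own positive or negative: mask the diagonal.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row i is row i + N, and the positive for row i + N is row i.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Cross-entropy over each row treats the positive pair as the correct "class".
    return F.cross_entropy(sim, targets)
```

In a pretraining loop, `z1` and `z2` would plausibly be the outputs of a FedSimClrModel constructed with pretrain=True on two augmentations of the same batch; the augmentation pipeline and temperature are left as assumptions here.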
- forward(input)
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
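Going by the parameter descriptions, the forward pass routes encoder features through either the projection head or the prediction head depending on pretrain. The sketch below uses a hypothetical stand-in class, `SimClrHeadSwitchSketch` (not part of fl4health), that mirrors the documented constructor; the layer sizes are arbitrary, and raising ValueError when no prediction head is set is an assumption. It also illustrates the note above: a forward hook fires when the instance is called, but not when `forward` is invoked directly.

```python
from typing import Optional

import torch
from torch import nn


class SimClrHeadSwitchSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented FedSimClrModel constructor."""

    def __init__(
        self,
        encoder: nn.Module,
        projection_head: nn.Module = nn.Identity(),
        prediction_head: Optional[nn.Module] = None,
        pretrain: bool = True,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.projection_head = projection_head
        self.prediction_head = prediction_head
        self.pretrain = pretrain

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        features = self.encoder(input)
        if self.pretrain:
            # Pretraining: map features to the representation used by the contrastive loss.
            return self.projection_head(features)
        # Finetuning: a prediction head is required (assumed error behaviour).
        if self.prediction_head is None:
            raise ValueError("prediction_head must be set when pretrain=False")
        return self.prediction_head(features)


model = SimClrHeadSwitchSketch(
    nn.Linear(32, 16), projection_head=nn.Linear(16, 8), prediction_head=nn.Linear(16, 10)
)
calls = []
# The hook records output shapes; it returns None, so outputs are left unchanged.
model.register_forward_hook(lambda module, args, output: calls.append(tuple(output.shape)))

x = torch.randn(4, 32)
z = model(x)  # __call__ runs the hook and records (4, 8)
model.forward(x)  # direct call bypasses the hook, nothing is recorded
model.pretrain = False  # switch to the finetuning head
logits = model(x)  # hook records (4, 10)
```

Flipping `pretrain` on an existing object, as done here, lets the same encoder weights move from pretraining to finetuning; whether fl4health's class is meant to be toggled this way rather than reconstructed is an assumption.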