Source code for fl4health.model_bases.fedsimclr_base

from __future__ import annotations

from pathlib import Path

import torch
import torch.nn as nn


[docs] class FedSimClrModel(nn.Module):
[docs] def __init__( self, encoder: nn.Module, projection_head: nn.Module = nn.Identity(), prediction_head: nn.Module | None = None, pretrain: bool = True, ) -> None: """ Model base to train SimCLR (https://arxiv.org/pdf/2002.05709) in a federated manner presented in (https://arxiv.org/pdf/2207.09158). Can be used in pretraining and optionally finetuning. Args: encoder (nn.Module): Encoder that extracts a feature vector. given an input sample. projection_head (nn.Module): Projection Head that maps output of encoder to final representation used in contrastive loss for pretraining stage. Defaults to identity transformation. prediction_head (nn.Module | None): Prediction head that maps output of encoder to prediction in the finetuning stage. Defaults to None. pretrain (bool): Determines whether or not to use the projection_head (True) or the prediction_head (False). Defaults to True. """ super().__init__() assert not ( prediction_head is None and not pretrain ), "Model with pretrain==False must have prediction head (ie not None)" self.encoder = encoder self.projection_head = projection_head self.prediction_head = prediction_head self.pretrain = pretrain
[docs] def forward(self, input: torch.Tensor) -> torch.Tensor: features = self.encoder(input) if self.pretrain: return self.projection_head(features) else: assert ( self.prediction_head is not None ), "Model with pretrain==False must have prediction_head (ie not None)" return self.prediction_head(features)
[docs] @staticmethod def load_pretrained_model(model_path: Path) -> FedSimClrModel: prev_model = torch.load(model_path) ssl_model = FedSimClrModel( encoder=prev_model.encoder, projection_head=prev_model.projection_head, prediction_head=prev_model.prediction_head, pretrain=False, ) return ssl_model