# Source code for fl4health.model_bases.fedsimclr_base

from __future__ import annotations

from pathlib import Path

import torch
import torch.nn as nn


class FedSimClrModel(nn.Module):
    def __init__(
        self,
        encoder: nn.Module,
        projection_head: nn.Module | None = None,
        prediction_head: nn.Module | None = None,
        pretrain: bool = True,
    ) -> None:
        """
        Model base to train SimCLR (https://arxiv.org/pdf/2002.05709) in a federated manner presented in
        (https://arxiv.org/pdf/2207.09158). Can be used in pretraining and optionally finetuning.

        Args:
            encoder (nn.Module): Encoder that extracts a feature vector given an input sample.
            projection_head (nn.Module | None): Projection head that maps the output of the encoder to the final
                representation used in the contrastive loss for the pretraining stage. Defaults to None, in which
                case an identity transformation (``nn.Identity``) is used.
            prediction_head (nn.Module | None): Prediction head that maps output of encoder to prediction in the
                finetuning stage. Defaults to None.
            pretrain (bool): Determines whether or not to use the ``projection_head`` (True) or the
                ``prediction_head`` (False). Defaults to True.
        """
        super().__init__()
        assert not (
            prediction_head is None and not pretrain
        ), "Model with pretrain==False must have prediction head (ie not None)"
        self.encoder = encoder
        # None sentinel instead of a default of nn.Identity() in the signature: a default argument is evaluated
        # once at import time, so every model would otherwise share a single module instance.
        self.projection_head: nn.Module = projection_head if projection_head is not None else nn.Identity()
        self.prediction_head = prediction_head
        self.pretrain = pretrain

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Passes the input tensor through the encoder module. If we're in the pretraining phase, the output of the
        encoder is flattened/projected for similarity computations. If we're fine-tuning the model, these latent
        features are passed through the provided prediction head.

        Args:
            input (torch.Tensor): Input to be mapped to either latent features or a final prediction depending
                on the training phase.

        Returns:
            torch.Tensor: The output from either the ``projection_head`` module if pre-training or the
            ``prediction_head`` if fine-tuning.
        """
        features = self.encoder(input)
        if self.pretrain:
            return self.projection_head(features)
        assert (
            self.prediction_head is not None
        ), "Model with pretrain==False must have prediction_head (ie not None)"
        return self.prediction_head(features)

    @staticmethod
    def load_pretrained_model(model_path: Path) -> FedSimClrModel:
        """
        Given a path, this function loads a model from the path, assuming it was of type ``FedSimClrModel``.
        The proper components are then routed to form a new model with the pre-existing weights.

        **NOTE**: Loaded models automatically set ``pretrain`` to False.

        Args:
            model_path (Path): Path to a ``FedSimClrModel`` object saved using ``torch.save``.

        Returns:
            FedSimClrModel: A model with pre-existing weights loaded and ``pretrain`` set to False.
        """
        # weights_only=False is required to unpickle a full module object: torch >= 2.6 defaults weights_only to
        # True, which would reject this checkpoint format.
        # WARNING: torch.load with weights_only=False executes arbitrary pickled code — only load checkpoints
        # from trusted sources.
        prev_model = torch.load(model_path, weights_only=False)
        ssl_model = FedSimClrModel(
            encoder=prev_model.encoder,
            projection_head=prev_model.projection_head,
            prediction_head=prev_model.prediction_head,
            pretrain=False,
        )
        return ssl_model