from abc import ABC, abstractmethod
from enum import Enum
import torch
import torch.nn as nn
class ParallelFeatureJoinMode(Enum):
    """
    Strategies a ParallelSplitHeadModule can use to merge the outputs of its two parallel feature extractors.

    CONCATENATE: the merge is delegated to the head module's ``parallel_output_join`` implementation.
    SUM: the two feature tensors are added together with ``torch.add``.
    """

    CONCATENATE = "CONCATENATE"
    SUM = "SUM"
class ParallelSplitHeadModule(nn.Module, ABC):
    def __init__(self, mode: ParallelFeatureJoinMode) -> None:
        """
        This is a head module to be used as part of ParallelSplitModel type models. This module is responsible for
        merging inputs from two parallel feature extractors and acting on those inputs to produce a prediction.

        Args:
            mode (ParallelFeatureJoinMode): This determines HOW the head module is meant to combine the features
                produced by the extraction modules. Currently, there are two modes, concatenation or summation of the
                inputs before producing a prediction. The raw enum value (e.g. ``"SUM"``) is also accepted and is
                converted to the corresponding ParallelFeatureJoinMode member.

        Raises:
            ValueError: If ``mode`` does not correspond to any ParallelFeatureJoinMode member.
        """
        super().__init__()
        # Normalize and validate eagerly: an unrecognized mode now fails here instead of being silently treated as
        # SUM inside forward. Passing an existing enum member returns that same member, so enum callers are unaffected.
        self.mode = ParallelFeatureJoinMode(mode)

    @abstractmethod
    def parallel_output_join(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
        """
        Combine the two feature tensors. Only called by ``forward`` when ``self.mode`` is
        ``ParallelFeatureJoinMode.CONCATENATE``.

        ``forward`` passes its ``first_tensor`` positionally as ``local_tensor`` and its ``second_tensor`` as
        ``global_tensor``.

        Args:
            local_tensor (torch.Tensor): The first feature tensor handed to ``forward``.
            global_tensor (torch.Tensor): The second feature tensor handed to ``forward``.

        Returns:
            torch.Tensor: The joined tensor, which ``forward`` passes to ``head_forward``.
        """
        raise NotImplementedError

    @abstractmethod
    def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Map the merged features to the head's output.

        Args:
            input_tensor (torch.Tensor): The result of merging the two feature tensors according to ``self.mode``.

        Returns:
            torch.Tensor: The head's prediction.
        """
        raise NotImplementedError

    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
        """
        Merge the two feature tensors according to ``self.mode`` and run ``head_forward`` on the result.

        Args:
            first_tensor (torch.Tensor): Features from the first parallel feature extractor.
            second_tensor (torch.Tensor): Features from the second parallel feature extractor.

        Returns:
            torch.Tensor: The output of ``head_forward``.

        Raises:
            ValueError: If ``self.mode`` is not a recognized ParallelFeatureJoinMode (e.g. it was reassigned after
                construction).
        """
        if self.mode == ParallelFeatureJoinMode.CONCATENATE:
            head_input = self.parallel_output_join(first_tensor, second_tensor)
        elif self.mode == ParallelFeatureJoinMode.SUM:
            # torch.add broadcasts, so the two tensors only need broadcast-compatible shapes.
            head_input = torch.add(first_tensor, second_tensor)
        else:
            raise ValueError(f"Unsupported ParallelFeatureJoinMode: {self.mode!r}")
        return self.head_forward(head_input)
class ParallelSplitModel(nn.Module):
    def __init__(
        self,
        first_feature_extractor: nn.Module,
        second_feature_extractor: nn.Module,
        model_head: ParallelSplitHeadModule,
    ) -> None:
        """
        This defines a model that has been split into two parallel feature extractors. The outputs of these feature
        extractors are merged together and mapped to a prediction by a ParallelSplitHeadModule. By default, no
        feature tensors are stored. Only a prediction tensor is produced.

        Args:
            first_feature_extractor (nn.Module): First parallel feature extractor
            second_feature_extractor (nn.Module): Second parallel feature extractor
            model_head (ParallelSplitHeadModule): Module responsible for taking the outputs of the two feature
                extractors and using them to produce a prediction.
        """
        super().__init__()
        self.first_feature_extractor = first_feature_extractor
        self.second_feature_extractor = second_feature_extractor
        self.model_head = model_head

    def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """
        Run both feature extractors on the same input and let the head merge their outputs into a prediction.

        Args:
            input (torch.Tensor): Input passed, unchanged, to both feature extractors. (The name shadows the
                builtin but is kept because callers may pass it by keyword.)

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: A predictions dictionary holding the head's
            output under the key ``"prediction"``, and an empty features dictionary.
        """
        # Call the submodules themselves rather than their .forward() so nn.Module.__call__ runs any registered
        # forward pre-hooks / hooks; calling .forward() directly silently skips them.
        first_output = self.first_feature_extractor(input)
        second_output = self.second_feature_extractor(input)
        preds = {"prediction": self.model_head(first_output, second_output)}
        # No features are returned in the vanilla ParallelSplitModel implementation
        return preds, {}