from abc import ABC, abstractmethod
from enum import Enum
import torch
import torch.nn as nn
class ParallelFeatureJoinMode(Enum):
    """
    Enumerates the ways a ``ParallelSplitHeadModule`` can merge the outputs of its two parallel feature extractors
    before they are passed to the head's forward function.
    """

    # The two feature tensors are combined via the head module's ``parallel_output_join`` implementation.
    CONCATENATE = "CONCATENATE"
    # The two feature tensors are combined by element-wise addition (``torch.add``).
    SUM = "SUM"
class ParallelSplitHeadModule(nn.Module, ABC):
    def __init__(self, mode: ParallelFeatureJoinMode) -> None:
        """
        This is a head module to be used as part of ``ParallelSplitModel`` type models. This module is responsible for
        merging inputs from two parallel feature extractors and acting on those inputs to produce a prediction.

        Args:
            mode (ParallelFeatureJoinMode): This determines **HOW** the head module is meant to combine the features
                produced by the extraction modules. Currently, there are two modes, concatenation or summation of the
                inputs before producing a prediction. The string value of a mode (e.g. ``"SUM"``) is also accepted
                and is converted to the corresponding ``ParallelFeatureJoinMode`` member.

        Raises:
            ValueError: If ``mode`` is neither a ``ParallelFeatureJoinMode`` member nor the value of one.
        """
        super().__init__()
        # Normalise to the enum member. Without this, a plain string such as "CONCATENATE" would fail the equality
        # check in ``forward`` and silently fall through to summation.
        self.mode = ParallelFeatureJoinMode(mode)

    @abstractmethod
    def parallel_output_join(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
        """
        Defines how the local and global feature tensors that are output by the preceding parallel feature extractors
        are meant to be joined together when ``self.mode`` is set to ``ParallelFeatureJoinMode.CONCATENATE``.

        Args:
            local_tensor (torch.Tensor): First tensor to be joined.
            global_tensor (torch.Tensor): Second tensor to be joined.

        Raises:
            NotImplementedError: Any implementing child class must provide this method if it is to be used.

        Returns:
            torch.Tensor: A single tensor with the two tensors joined together in some way.
        """
        raise NotImplementedError

    @abstractmethod
    def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward function for the head module, applied to the already-joined feature tensor.

        Args:
            input_tensor (torch.Tensor): Input tensor to be mapped.

        Raises:
            NotImplementedError: Must be implemented by any child class.

        Returns:
            torch.Tensor: Output of the head module from the given input.
        """
        raise NotImplementedError

    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward function for the head module of ``ParallelSplitModel``s. The inputs (``first_tensor``,
        ``second_tensor``) are either concatenated or added together, depending on ``self.mode``.
        The concatenation procedure is defined by ``parallel_output_join``; summation uses ``torch.add``. The
        joined tensor is then passed through ``head_forward``.

        Args:
            first_tensor (torch.Tensor): Output from one parallel module.
            second_tensor (torch.Tensor): Output from the other parallel module.

        Raises:
            ValueError: If ``self.mode`` is not a supported ``ParallelFeatureJoinMode`` member (e.g. it was
                reassigned to an arbitrary value after construction).

        Returns:
            torch.Tensor: Output from the head module.
        """
        if self.mode == ParallelFeatureJoinMode.CONCATENATE:
            head_input = self.parallel_output_join(first_tensor, second_tensor)
        elif self.mode == ParallelFeatureJoinMode.SUM:
            # Element-wise addition; the two tensors must be broadcast-compatible.
            head_input = torch.add(first_tensor, second_tensor)
        else:
            # Fail loudly instead of silently treating an unknown mode as summation.
            raise ValueError(
                f"Unsupported join mode {self.mode!r}; expected one of {[m.value for m in ParallelFeatureJoinMode]}."
            )
        return self.head_forward(head_input)
class ParallelSplitModel(nn.Module):
    def __init__(
        self,
        first_feature_extractor: nn.Module,
        second_feature_extractor: nn.Module,
        model_head: ParallelSplitHeadModule,
    ) -> None:
        """
        This defines a model that has been split into two parallel feature extractors. The outputs of these feature
        extractors are merged together and mapped to a prediction by a ``ParallelSplitHeadModule``. By default, no
        feature tensors are stored. Only a prediction tensor is produced.

        Args:
            first_feature_extractor (nn.Module): First parallel feature extractor.
            second_feature_extractor (nn.Module): Second parallel feature extractor.
            model_head (ParallelSplitHeadModule): Module responsible for taking the outputs of the two feature
                extractors and using them to produce a prediction.
        """
        super().__init__()
        self.first_feature_extractor = first_feature_extractor
        self.second_feature_extractor = second_feature_extractor
        self.model_head = model_head

    def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """
        Composite forward function. The input tensor is first passed through each of the two parallel feature
        extractors, and their outputs are then passed through the head model. The extractor outputs and the joining
        mechanism defined in the head model need to be compatible with the head model's input; this is left to the
        user to handle.

        Args:
            input (torch.Tensor): Input tensor to be passed through the set of forwards.

        Returns:
            tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: Prediction tensor from the head model, stored
            under the "prediction" key of the first dictionary. The second (feature) dictionary is empty by default.
        """
        # Call the submodules (rather than their ``.forward`` methods) so nn.Module.__call__ runs any registered
        # forward pre-hooks, forward hooks and backward hooks.
        first_output = self.first_feature_extractor(input)
        second_output = self.second_feature_extractor(input)
        preds = {"prediction": self.model_head(first_output, second_output)}
        # No features are returned in the vanilla ParallelSplitModel implementation
        return preds, {}