fl4health.model_bases.parallel_split_models module
- class ParallelFeatureJoinMode(value)
Bases: Enum
Specifies how a ParallelSplitHeadModule combines the outputs of the two parallel feature extractors.
- CONCATENATE = 'CONCATENATE'
- SUM = 'SUM'
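To make the difference between the two modes concrete, the sketch below joins two hypothetical feature batches both ways. It assumes, for illustration only, that concatenation happens along the feature dimension (dim=1); in the library the exact concatenation is whatever a subclass's parallel_output_join defines. The enum is redeclared locally and join_features is a hypothetical helper, so the snippet runs without fl4health installed.

```python
from enum import Enum

import torch


# Local stand-in mirroring the two documented members of ParallelFeatureJoinMode.
class ParallelFeatureJoinMode(Enum):
    CONCATENATE = "CONCATENATE"
    SUM = "SUM"


def join_features(mode: ParallelFeatureJoinMode, first: torch.Tensor, second: torch.Tensor) -> torch.Tensor:
    """Hypothetical join: concatenate along dim=1 (an assumption) or add elementwise."""
    if mode == ParallelFeatureJoinMode.CONCATENATE:
        return torch.cat([first, second], dim=1)
    # SUM requires both tensors to have matching (or broadcastable) shapes.
    return first + second


first = torch.ones(4, 8)   # batch of 4, 8 features from the first extractor
second = torch.ones(4, 8)  # batch of 4, 8 features from the second extractor

concat_out = join_features(ParallelFeatureJoinMode.CONCATENATE, first, second)
sum_out = join_features(ParallelFeatureJoinMode.SUM, first, second)
print(concat_out.shape)  # torch.Size([4, 16])
print(sum_out.shape)     # torch.Size([4, 8])
```

In CONCATENATE mode the head receives a wider input (here 16 features instead of 8), whereas SUM keeps the width unchanged but requires the two extractors to produce tensors of matching shape.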
- class ParallelSplitHeadModule(mode)
Bases: Module, ABC
- __init__(mode)
This is a head module to be used as part of ParallelSplitModel type models. It is responsible for merging the outputs of two parallel feature extractors and acting on the merged result to produce a prediction.
- Parameters:
mode (ParallelFeatureJoinMode) – Determines HOW the head module combines the features produced by the extraction modules. Currently, there are two modes: concatenation or summation of the inputs before producing a prediction.
- forward(first_tensor, second_tensor)
Forward function for the head module of ParallelSplitModels. The inputs (first_tensor, second_tensor) are concatenated or added together depending on the mode specified in self.mode. The concatenation procedure is defined by parallel_output_join. The joined tensor is then passed through head_forward, the forward function of the head module.
- Parameters:
first_tensor (torch.Tensor) – Output from the first parallel feature extractor
second_tensor (torch.Tensor) – Output from the second parallel feature extractor
- Returns:
Output from the head module.
- Return type:
torch.Tensor
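The dispatch described above can be sketched as follows. This is a minimal stand-in written for illustration, not the fl4health source: it assumes that SUM mode uses plain elementwise addition and that forward always finishes by calling head_forward. HeadModuleSketch and IdentityHead are hypothetical names; IdentityHead concatenates along dim=1 and returns the joined tensor unchanged so that the effect of each mode can be inspected directly.

```python
from abc import ABC, abstractmethod
from enum import Enum

import torch
import torch.nn as nn


class ParallelFeatureJoinMode(Enum):
    CONCATENATE = "CONCATENATE"
    SUM = "SUM"


class HeadModuleSketch(nn.Module, ABC):
    """Stand-in mirroring the documented ParallelSplitHeadModule interface."""

    def __init__(self, mode: ParallelFeatureJoinMode) -> None:
        super().__init__()
        self.mode = mode

    @abstractmethod
    def parallel_output_join(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
        # Join according to self.mode, then hand the result to head_forward.
        if self.mode == ParallelFeatureJoinMode.CONCATENATE:
            joined = self.parallel_output_join(first_tensor, second_tensor)
        else:
            joined = first_tensor + second_tensor
        return self.head_forward(joined)


class IdentityHead(HeadModuleSketch):
    """Hypothetical head: concatenates along dim=1 and returns the joined tensor unchanged."""

    def parallel_output_join(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
        return torch.cat([local_tensor, global_tensor], dim=1)

    def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return input_tensor


a = torch.tensor([[1.0, 2.0]])
b = torch.tensor([[3.0, 4.0]])
concat_result = IdentityHead(ParallelFeatureJoinMode.CONCATENATE)(a, b)
sum_result = IdentityHead(ParallelFeatureJoinMode.SUM)(a, b)
print(concat_result)  # tensor([[1., 2., 3., 4.]])
print(sum_result)     # tensor([[4., 6.]])
```

Note that parallel_output_join is only consulted in CONCATENATE mode; in SUM mode the subclass's join is never called.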
- abstract head_forward(input_tensor)
Forward function for the head module.
- Parameters:
input_tensor (torch.Tensor) – Joined (concatenated or summed) feature tensor to be mapped to the head module's output
- Raises:
NotImplementedError – Must be implemented by any child class
- Returns:
Output of the head module from the given input.
- Return type:
torch.Tensor
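A head_forward body is often just a small network applied to the joined tensor. The sketch below shows one hypothetical choice, a single linear classifier, and the constraint it inherits from the join mode: assuming concatenation along dim=1, the layer must accept twice the per-extractor feature width, while with summation it accepts the per-extractor width. ClassifierHeadBody and its arguments are illustrative names; the class deliberately omits the fl4health base class so it runs standalone, whereas a real head would subclass ParallelSplitHeadModule and also implement parallel_output_join.

```python
import torch
import torch.nn as nn


class ClassifierHeadBody(nn.Module):
    """Hypothetical head_forward body: one linear layer mapping joined features to class logits."""

    def __init__(self, feature_dim: int, num_classes: int, concatenate: bool) -> None:
        super().__init__()
        # Concatenating two feature_dim-wide tensors along dim=1 doubles the width;
        # summing them keeps it unchanged.
        in_features = 2 * feature_dim if concatenate else feature_dim
        self.classifier = nn.Linear(in_features, num_classes)

    def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.classifier(input_tensor)


feature_dim, num_classes, batch = 8, 3, 5
concat_head = ClassifierHeadBody(feature_dim, num_classes, concatenate=True)
sum_head = ClassifierHeadBody(feature_dim, num_classes, concatenate=False)

concat_logits = concat_head.head_forward(torch.randn(batch, 2 * feature_dim))
sum_logits = sum_head.head_forward(torch.randn(batch, feature_dim))
print(concat_logits.shape, sum_logits.shape)  # torch.Size([5, 3]) torch.Size([5, 3])
```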
- abstract parallel_output_join(local_tensor, global_tensor)
Defines how the local and global feature tensors output by the preceding parallel feature extractors are joined together when self.mode is set to ParallelFeatureJoinMode.CONCATENATE.
- Parameters:
local_tensor (torch.Tensor) – First tensor to be joined
global_tensor (torch.Tensor) – Second tensor to be joined
- Raises:
NotImplementedError – Any child class must implement this method if it is to be used.
- Returns:
A single tensor in which the two input tensors are joined, in whatever manner the subclass defines
- Return type:
torch.Tensor
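Because the join is left to subclasses, its form depends on what the extractors emit. A minimal sketch, assuming the extractors may output tensors of different shapes (for example a convolutional feature map and a flat embedding): flatten each tensor to [batch, -1] and concatenate along dim=1. The function name flatten_and_concat is hypothetical; in a subclass this body would live in parallel_output_join.

```python
import torch


def flatten_and_concat(local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical parallel_output_join body: flatten all non-batch dims, then concatenate."""
    # torch.flatten(x, start_dim=1) keeps the batch dimension and merges the rest.
    local_flat = torch.flatten(local_tensor, start_dim=1)
    global_flat = torch.flatten(global_tensor, start_dim=1)
    return torch.cat([local_flat, global_flat], dim=1)


local_features = torch.randn(2, 4, 3, 3)  # e.g. a conv feature map: 4 * 3 * 3 = 36 features
global_features = torch.randn(2, 10)      # e.g. a flat embedding: 10 features
joined = flatten_and_concat(local_features, global_features)
print(joined.shape)  # torch.Size([2, 46])
```

Flattening first trades away spatial structure for flexibility; a head that needs to keep the spatial layout would instead concatenate feature maps of equal height and width along the channel dimension.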
- class ParallelSplitModel(first_feature_extractor, second_feature_extractor, model_head)
Bases: Module
- __init__(first_feature_extractor, second_feature_extractor, model_head)
This defines a model that has been split into two parallel feature extractors. The outputs of these feature extractors are merged together and mapped to a prediction by a ParallelSplitHeadModule. By default, no feature tensors are stored; only a prediction tensor is produced.
- Parameters:
first_feature_extractor (nn.Module) – First parallel feature extractor
second_feature_extractor (nn.Module) – Second parallel feature extractor
model_head (ParallelSplitHeadModule) – Module responsible for taking the outputs of the two feature extractors and using them to produce a prediction.
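The constructor only needs to store the three modules. The stand-in below sketches this, relying on standard nn.Module behaviour: assigning modules as attributes registers them as children, so model.parameters() and model.state_dict() cover both extractors and the head, and a single optimizer can train all three together. ModelSketch and the layer sizes are hypothetical, and the head is a plain nn.Linear placeholder rather than a ParallelSplitHeadModule so the snippet runs standalone.

```python
import torch.nn as nn


class ModelSketch(nn.Module):
    """Stand-in mirroring the documented ParallelSplitModel constructor."""

    def __init__(
        self,
        first_feature_extractor: nn.Module,
        second_feature_extractor: nn.Module,
        model_head: nn.Module,
    ) -> None:
        super().__init__()
        # Attribute assignment registers each module as a child of this model.
        self.first_feature_extractor = first_feature_extractor
        self.second_feature_extractor = second_feature_extractor
        self.model_head = model_head


first = nn.Linear(10, 8)   # 10 * 8 + 8 = 88 parameters
second = nn.Linear(10, 8)  # 88 parameters
head = nn.Linear(16, 3)    # placeholder head: 16 * 3 + 3 = 51 parameters
model = ModelSketch(first, second, head)

total = sum(p.numel() for p in model.parameters())
print(total)  # 227
```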
- forward(input)
Composite forward function. The input tensor is first passed through the two parallel feature extractors, and their outputs are then passed through the head module. The extractor outputs, together with the joining mechanism defined in the head module, must produce a tensor compatible with the input the head module expects. Ensuring this compatibility is left to the user.
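The composite forward, including the return value documented under Returns, can be sketched with standalone stand-ins. This is an illustration, not the fl4health source: ModelSketch and ConcatLinearHead are hypothetical names, the head concatenates along dim=1 by assumption, and the extractor widths (8 and 6) are arbitrary. The head's in_features is where the user-side compatibility requirement shows up: it must equal the width of the joined tensor.

```python
import torch
import torch.nn as nn


class ConcatLinearHead(nn.Module):
    """Hypothetical head: concatenates the two feature tensors along dim=1, then applies a linear layer."""

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
        return self.linear(torch.cat([first_tensor, second_tensor], dim=1))


class ModelSketch(nn.Module):
    """Stand-in mirroring the documented composite forward of ParallelSplitModel."""

    def __init__(self, first: nn.Module, second: nn.Module, head: nn.Module) -> None:
        super().__init__()
        self.first_feature_extractor = first
        self.second_feature_extractor = second
        self.model_head = head

    def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        # Both extractors receive the same input tensor.
        first_output = self.first_feature_extractor(input)
        second_output = self.second_feature_extractor(input)
        preds = {"prediction": self.model_head(first_output, second_output)}
        # No intermediate features are returned by default.
        return preds, {}


model = ModelSketch(
    nn.Sequential(nn.Linear(10, 8), nn.ReLU()),
    nn.Sequential(nn.Linear(10, 6), nn.ReLU()),
    ConcatLinearHead(in_features=8 + 6, num_classes=3),  # must match the joined width
)
preds, features = model(torch.randn(4, 10))
print(preds["prediction"].shape, features)  # torch.Size([4, 3]) {}
```

Had in_features been left at 8, this forward pass would fail with a shape-mismatch RuntimeError inside nn.Linear, so running one forward pass on a dummy batch is a cheap way to check compatibility before training.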
- Parameters:
input (torch.Tensor) – Input tensor passed to both parallel feature extractors.
- Returns:
Two dictionaries. The first holds the prediction tensor from the head module under the “prediction” key; the second, the feature dictionary, is empty by default.
- Return type: