fl4health.model_bases.parallel_split_models module
- class ParallelFeatureJoinMode(value)
Bases: Enum
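An enumeration of the ways a ParallelSplitHeadModule can join the features produced by its two parallel feature extractors.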
- CONCATENATE = 'CONCATENATE'
- SUM = 'SUM'
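The choice of join mode determines the width of the tensor the head receives, which in turn fixes the input size of the head's first layer. The standard-library sketch below re-declares the enum locally and uses a hypothetical helper, `joined_feature_width`, to show that arithmetic. It assumes both extractors emit flat feature vectors of shape `(batch, width)` and that concatenation happens along the feature dimension; neither assumption is stated by the module itself.

```python
from enum import Enum


class ParallelFeatureJoinMode(Enum):
    # Local re-declaration mirroring the two documented members.
    CONCATENATE = "CONCATENATE"
    SUM = "SUM"


def joined_feature_width(first_width: int, second_width: int, mode: ParallelFeatureJoinMode) -> int:
    """Hypothetical helper: width of the joined features passed to the head.

    Assumes flat (batch, width) features, concatenated along the feature dimension.
    """
    if mode is ParallelFeatureJoinMode.CONCATENATE:
        # Concatenation places the two feature vectors side by side.
        return first_width + second_width
    if mode is ParallelFeatureJoinMode.SUM:
        # Element-wise summation only works when both widths match.
        if first_width != second_width:
            raise ValueError("SUM mode requires feature extractors with equal output widths")
        return first_width
    raise ValueError(f"Unknown join mode: {mode}")


print(joined_feature_width(64, 32, ParallelFeatureJoinMode.CONCATENATE))  # 96
print(joined_feature_width(64, 64, ParallelFeatureJoinMode.SUM))  # 64
# Members can also be looked up by their string value.
print(ParallelFeatureJoinMode("SUM") is ParallelFeatureJoinMode.SUM)  # True
```

In practice, SUM keeps the head small but forces both extractors to agree on output width, whereas CONCATENATE lets the widths differ at the cost of a wider head input.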
- class ParallelSplitHeadModule(mode)
Bases: Module, ABC
- __init__(mode)
This is a head module to be used as part of ParallelSplitModel-type models. It is responsible for merging the outputs of two parallel feature extractors and acting on the merged features to produce a prediction.
- Parameters:
mode (ParallelFeatureJoinMode) – Determines how the head module combines the features produced by the two feature extraction modules. Currently, two modes are supported: concatenation (CONCATENATE) or summation (SUM) of the inputs before a prediction is produced.
- forward(first_tensor, second_tensor)
Define the computation performed at every call.
Should be overridden by all subclasses.
Return type: Tensor
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
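As a rough illustration of what a concrete head might look like, the PyTorch sketch below defines a hypothetical `ToyParallelHead` that joins `first_tensor` and `second_tensor` according to the mode and maps the result through a linear layer. It subclasses `nn.Module` directly rather than the real `ParallelSplitHeadModule`, because any abstract hooks the real base class expects are not described here; the class name, `feature_width`, `num_classes`, and the choice of `dim=1` for concatenation are all assumptions.

```python
from enum import Enum

import torch
import torch.nn as nn


class ParallelFeatureJoinMode(Enum):
    CONCATENATE = "CONCATENATE"
    SUM = "SUM"


class ToyParallelHead(nn.Module):
    """Hypothetical stand-in for a concrete ParallelSplitHeadModule subclass."""

    def __init__(self, mode: ParallelFeatureJoinMode, feature_width: int, num_classes: int) -> None:
        super().__init__()
        self.mode = mode
        # Concatenation doubles the width seen by the linear layer; summation does not.
        in_features = 2 * feature_width if mode is ParallelFeatureJoinMode.CONCATENATE else feature_width
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
        if self.mode is ParallelFeatureJoinMode.CONCATENATE:
            # Join along the feature dimension (assumes (batch, width) inputs).
            joined = torch.cat([first_tensor, second_tensor], dim=1)
        else:
            # Element-wise sum; shapes must match.
            joined = first_tensor + second_tensor
        return self.fc(joined)


first = torch.randn(4, 16)
second = torch.randn(4, 16)
for mode in ParallelFeatureJoinMode:
    head = ToyParallelHead(mode, feature_width=16, num_classes=3)
    # Call the instance, not head.forward(...), so registered hooks run.
    print(mode.name, head(first, second).shape)  # torch.Size([4, 3]) in both modes
```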
- class ParallelSplitModel(first_feature_extractor, second_feature_extractor, model_head)
Bases: Module
- __init__(first_feature_extractor, second_feature_extractor, model_head)
This defines a model that has been split into two parallel feature extractors. The outputs of these feature extractors are merged together and mapped to a prediction by a ParallelSplitHeadModule. By default, no feature tensors are stored. Only a prediction tensor is produced.
- Parameters:
first_feature_extractor (nn.Module) – First parallel feature extractor.
second_feature_extractor (nn.Module) – Second parallel feature extractor.
model_head (ParallelSplitHeadModule) – Module responsible for taking the outputs of the two feature extractors and using them to produce a prediction.
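The composition described above can be sketched in plain PyTorch. The hypothetical `ToyParallelSplitModel` below runs both extractors on the same input, hands their outputs to the head, and returns a `(predictions, features)` pair of dictionaries with an empty feature dictionary, matching the documented default of storing no feature tensors. Feeding the same input to both extractors, the ordering of the returned pair, and the `"prediction"` key are assumptions for illustration, not confirmed details of `ParallelSplitModel`; `SumHead` is likewise a made-up head that always sums its inputs.

```python
import torch
import torch.nn as nn


class SumHead(nn.Module):
    """Hypothetical head: sums the two feature tensors, then applies a linear layer."""

    def __init__(self, feature_width: int, num_classes: int) -> None:
        super().__init__()
        self.fc = nn.Linear(feature_width, num_classes)

    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
        return self.fc(first_tensor + second_tensor)


class ToyParallelSplitModel(nn.Module):
    """Hypothetical re-creation of the parallel split structure described above."""

    def __init__(self, first_feature_extractor: nn.Module, second_feature_extractor: nn.Module,
                 model_head: nn.Module) -> None:
        super().__init__()
        self.first_feature_extractor = first_feature_extractor
        self.second_feature_extractor = second_feature_extractor
        self.model_head = model_head

    def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        # Assumption: both extractors receive the same input tensor.
        first_features = self.first_feature_extractor(input)
        second_features = self.second_feature_extractor(input)
        prediction = self.model_head(first_features, second_features)
        # Assumed key name "prediction"; no feature tensors are stored, so the second dict is empty.
        return {"prediction": prediction}, {}


model = ToyParallelSplitModel(
    nn.Sequential(nn.Linear(10, 8), nn.ReLU()),
    nn.Sequential(nn.Linear(10, 8), nn.ReLU()),
    SumHead(8, 2),
)
preds, features = model(torch.randn(5, 10))
print(preds["prediction"].shape)  # torch.Size([5, 2])
print(features)  # {}
```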
- forward(input)
Define the computation performed at every call.
Should be overridden by all subclasses.
Return type: tuple[dict[str, Tensor], dict[str, Tensor]]
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
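The note about hooks applies to any `nn.Module`, including both classes in this module. The short sketch below uses a plain `nn.Linear` layer, not an fl4health class, to show the difference: a forward hook registered with `register_forward_hook` fires when the module instance is called, but not when `forward` is invoked directly.

```python
import torch
import torch.nn as nn

calls = []


def record_call(module: nn.Module, args: tuple, output: torch.Tensor) -> None:
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(type(module).__name__)


layer = nn.Linear(4, 2)
layer.register_forward_hook(record_call)

x = torch.randn(3, 4)
layer(x)          # __call__ runs forward() and then the registered hook
layer.forward(x)  # calling forward() directly skips the hook

print(calls)  # ['Linear']
```

The same applies to a ParallelSplitModel: calling `model(input)` rather than `model.forward(input)` keeps any hooks attached to the model or its submodules' top-level call active.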