fl4health.model_bases.parallel_split_models module

class ParallelFeatureJoinMode(value)

Bases: Enum

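Enumeration of the strategies a ParallelSplitHeadModule can use to combine the outputs of the two parallel feature extractors.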

CONCATENATE = 'CONCATENATE'
SUM = 'SUM'
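
A minimal sketch of what the two members imply for feature shapes. It assumes that both extractors emit (batch, features) tensors and that concatenation happens along the feature dimension (dim=1); both are assumptions, since the join itself is left to each head subclass. The helper join_features is hypothetical and not part of fl4health, and the enum is a local stand-in that mirrors the documented members so the snippet runs without fl4health installed.

```python
from enum import Enum

import torch


class ParallelFeatureJoinMode(Enum):
    # Local stand-in mirroring the documented members; with fl4health installed, import
    # ParallelFeatureJoinMode from fl4health.model_bases.parallel_split_models instead.
    CONCATENATE = "CONCATENATE"
    SUM = "SUM"


def join_features(first: torch.Tensor, second: torch.Tensor, mode: ParallelFeatureJoinMode) -> torch.Tensor:
    """Hypothetical helper: combine two (batch, features) tensors according to mode."""
    if mode == ParallelFeatureJoinMode.CONCATENATE:
        # Feature widths add up: (B, d1) and (B, d2) -> (B, d1 + d2)
        return torch.cat([first, second], dim=1)
    if mode == ParallelFeatureJoinMode.SUM:
        # Element-wise sum: both tensors must have the same shape (B, d)
        if first.shape != second.shape:
            raise ValueError(f"SUM requires matching shapes, got {tuple(first.shape)} and {tuple(second.shape)}")
        return first + second
    raise ValueError(f"Unsupported mode: {mode}")


first = torch.ones(4, 8)
second = torch.ones(4, 8)
print(join_features(first, second, ParallelFeatureJoinMode.CONCATENATE).shape)  # torch.Size([4, 16])
print(join_features(first, second, ParallelFeatureJoinMode.SUM).shape)  # torch.Size([4, 8])
```

Under these assumptions, a concatenating head needs a first layer that accepts d1 + d2 inputs, whereas a summing head requires both extractors to produce features of the same width.
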

class ParallelSplitHeadModule(mode)

Bases: Module, ABC

__init__(mode)

This is a head module to be used as part of ParallelSplitModel-type models. It is responsible for merging the outputs of two parallel feature extractors and acting on the merged features to produce a prediction.

Parameters:

mode (ParallelFeatureJoinMode) – Determines how the head module combines the features produced by the two extraction modules. Currently, there are two modes: concatenation (CONCATENATE) or summation (SUM) of the inputs before producing a prediction.

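forward(first_tensor, second_tensor)

Merge first_tensor and second_tensor, the outputs of the two parallel feature extractors, and map the merged features to a prediction. Concrete subclasses must implement the abstract methods head_forward and parallel_output_join.

Return type:

Tensor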

Note

Although the recipe for the forward pass is defined within this function, call the Module instance (for example, head(first_tensor, second_tensor)) rather than forward directly: calling the instance runs any registered hooks, while calling forward silently ignores them.
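
The difference is easy to observe with a plain PyTorch forward hook. This sketch uses nn.Linear as a stand-in for any of the modules documented here; the hook only records that it ran.

```python
import torch
import torch.nn as nn

calls: list[str] = []

layer = nn.Linear(3, 2)
# Forward hooks fire only when the module is invoked through __call__, i.e. layer(x)
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(1, 3)
layer(x)  # runs forward and then the registered hook
layer.forward(x)  # runs forward only; the hook is skipped
print(calls)  # ['hook']
```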

abstract head_forward(input_tensor)

Map the merged feature tensor to a prediction.

Return type:

Tensor

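abstract parallel_output_join(local_tensor, global_tensor)

Join the outputs of the two parallel feature extractors into a single tensor (for example, by concatenation).

Return type: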

Tensor
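
A self-contained sketch of a concrete head that implements both abstract methods. To run without fl4health, it defines HeadBase, a hypothetical stand-in with the same method names as ParallelSplitHeadModule; the forward logic in HeadBase (join first, then apply head_forward) is an assumption and not the library's code. ConcatLinearHead is likewise hypothetical. With fl4health installed, ConcatLinearHead would subclass ParallelSplitHeadModule directly.

```python
from abc import ABC, abstractmethod
from enum import Enum

import torch
import torch.nn as nn


class ParallelFeatureJoinMode(Enum):
    # Local stand-in mirroring the documented enum members
    CONCATENATE = "CONCATENATE"
    SUM = "SUM"


class HeadBase(nn.Module, ABC):
    """Hypothetical stand-in for ParallelSplitHeadModule (same method names, assumed forward logic)."""

    def __init__(self, mode: ParallelFeatureJoinMode) -> None:
        super().__init__()
        self.mode = mode

    @abstractmethod
    def parallel_output_join(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor: ...

    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
        # Assumption: join the two feature tensors, then map the result to a prediction
        return self.head_forward(self.parallel_output_join(first_tensor, second_tensor))


class ConcatLinearHead(HeadBase):
    """Hypothetical concrete head: concatenate features along dim=1, then apply a linear layer."""

    def __init__(self, first_dim: int, second_dim: int, num_classes: int) -> None:
        super().__init__(ParallelFeatureJoinMode.CONCATENATE)
        self.fc = nn.Linear(first_dim + second_dim, num_classes)

    def parallel_output_join(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
        return torch.cat([local_tensor, global_tensor], dim=1)

    def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.fc(input_tensor)


head = ConcatLinearHead(first_dim=8, second_dim=4, num_classes=3)
out = head(torch.randn(5, 8), torch.randn(5, 4))
print(out.shape)  # torch.Size([5, 3])
```

Because the linear layer's input width is first_dim + second_dim, this head only fits the CONCATENATE mode; a summing head would instead size its layer to the shared feature width.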

class ParallelSplitModel(first_feature_extractor, second_feature_extractor, model_head)

Bases: Module

__init__(first_feature_extractor, second_feature_extractor, model_head)

This defines a model that has been split into two parallel feature extractors. The outputs of these feature extractors are merged and mapped to a prediction by a ParallelSplitHeadModule. By default, no feature tensors are stored; only a prediction tensor is produced.

Parameters:
  • first_feature_extractor (nn.Module) – First parallel feature extractor.

  • second_feature_extractor (nn.Module) – Second parallel feature extractor.

  • model_head (ParallelSplitHeadModule) – Module responsible for taking the outputs of the two feature extractors and using them to produce a prediction.

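forward(input)

Pass input through both feature extractors and let model_head merge their outputs into a prediction. Returns a tuple of two dictionaries: the first contains the prediction tensor, and the second contains feature tensors, which is empty by default.

Return type:

tuple[dict[str, Tensor], dict[str, Tensor]]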

Note

Although the recipe for the forward pass is defined within this function, call the Module instance (for example, model(input)) rather than forward directly: calling the instance runs any registered hooks, while calling forward silently ignores them.
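
A self-contained sketch of how the constructor arguments and the return type fit together, using plain PyTorch stand-ins so it runs without fl4health. ParallelSplitSketch and SumLinearHead are hypothetical names, and storing the prediction under the key "prediction" is an assumption for illustration. With fl4health installed, ParallelSplitModel and a concrete ParallelSplitHeadModule subclass would take their place.

```python
import torch
import torch.nn as nn


class SumLinearHead(nn.Module):
    """Hypothetical head stand-in: element-wise sum of the two feature tensors, then a linear layer."""

    def __init__(self, feature_dim: int, num_classes: int) -> None:
        super().__init__()
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
        return self.fc(first_tensor + second_tensor)


class ParallelSplitSketch(nn.Module):
    """Hypothetical stand-in mirroring ParallelSplitModel's constructor arguments."""

    def __init__(
        self,
        first_feature_extractor: nn.Module,
        second_feature_extractor: nn.Module,
        model_head: nn.Module,
    ) -> None:
        super().__init__()
        self.first_feature_extractor = first_feature_extractor
        self.second_feature_extractor = second_feature_extractor
        self.model_head = model_head

    def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        # Both extractors receive the same input: they run in parallel, not in sequence
        first_features = self.first_feature_extractor(input)
        second_features = self.second_feature_extractor(input)
        # Assumption: the prediction is stored under "prediction"; no feature tensors are returned
        return {"prediction": self.model_head(first_features, second_features)}, {}


model = ParallelSplitSketch(
    first_feature_extractor=nn.Sequential(nn.Linear(10, 16), nn.ReLU()),
    second_feature_extractor=nn.Sequential(nn.Linear(10, 16), nn.ReLU()),
    model_head=SumLinearHead(feature_dim=16, num_classes=2),
)
preds, features = model(torch.randn(4, 10))
print(preds["prediction"].shape, features)  # torch.Size([4, 2]) {}
```

Since the head sums the two feature tensors, both extractors in this sketch must produce 16-dimensional features; swapping in a concatenating head would relax that constraint.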