fl4health.model_bases.parallel_split_models module

class ParallelFeatureJoinMode(value)[source]

Bases: Enum

Enumeration of the ways a head module can join the outputs of the two parallel feature extractors.

CONCATENATE = 'CONCATENATE'
SUM = 'SUM'

class ParallelSplitHeadModule(mode)[source]

Bases: Module, ABC

__init__(mode)[source]

This is a head module to be used as part of ParallelSplitModel type models. It merges the outputs of two parallel feature extractors and maps the merged result to a prediction.

Parameters:

mode (ParallelFeatureJoinMode) – Determines how the head module combines the features produced by the extraction modules. Two modes are currently supported: concatenation or summation of the inputs before producing a prediction.
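
To make the two modes concrete, the following minimal sketch applies both joins to plain tensors. The tensor shapes are made up for illustration, and concatenating along dimension 1 is only one possible join, since the actual concatenation is left to parallel_output_join.

```python
import torch

# Hypothetical feature tensors: a batch of 4 samples with 8 features
# coming out of each parallel extractor.
first_features = torch.ones(4, 8)
second_features = torch.full((4, 8), 2.0)

# CONCATENATE (one possible join): stack along the feature dimension.
concatenated = torch.cat([first_features, second_features], dim=1)

# SUM: element-wise addition, which needs matching or broadcastable shapes.
summed = torch.add(first_features, second_features)

print(concatenated.shape)  # torch.Size([4, 16])
print(summed.shape)  # torch.Size([4, 8])
```

Under these assumptions, a head that consumes the concatenated tensor needs an input width of 16, while a head that consumes the summed tensor needs an input width of 8.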

forward(first_tensor, second_tensor)[source]

Forward function for the head module of ParallelSplitModels. The inputs (first_tensor, second_tensor) are concatenated or added together depending on the mode specified in self.mode. The concatenation procedure is defined by parallel_output_join. The joined tensor is then passed through head_forward.

Parameters:
  • first_tensor (torch.Tensor) – Output from the first parallel module

  • second_tensor (torch.Tensor) – Output from the second parallel module

Returns:

Output from the head module.

Return type:

torch.Tensor
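
The dispatch described for forward can be sketched with a stand-in class. This is not the library's source: JoinMode, SketchHeadModule, and LinearHead are hypothetical names that only mirror the documented behavior, and the linear head and feature sizes are assumptions.

```python
from abc import ABC, abstractmethod
from enum import Enum

import torch
import torch.nn as nn


class JoinMode(Enum):
    # Hypothetical stand-in for ParallelFeatureJoinMode.
    CONCATENATE = "CONCATENATE"
    SUM = "SUM"


class SketchHeadModule(nn.Module, ABC):
    # Hypothetical stand-in mirroring the documented head module interface.
    def __init__(self, mode: JoinMode) -> None:
        super().__init__()
        self.mode = mode

    @abstractmethod
    def parallel_output_join(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
        # Join according to the mode, then map the joined tensor to a prediction.
        if self.mode == JoinMode.CONCATENATE:
            head_input = self.parallel_output_join(first_tensor, second_tensor)
        else:
            head_input = torch.add(first_tensor, second_tensor)
        return self.head_forward(head_input)


class LinearHead(SketchHeadModule):
    # Assumed example head: a single linear layer sized for the chosen mode.
    def __init__(self, mode: JoinMode, feature_dim: int, n_classes: int) -> None:
        super().__init__(mode)
        in_dim = 2 * feature_dim if mode == JoinMode.CONCATENATE else feature_dim
        self.linear = nn.Linear(in_dim, n_classes)

    def parallel_output_join(self, local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
        return torch.cat([local_tensor, global_tensor], dim=1)

    def head_forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.linear(input_tensor)


features = torch.randn(4, 8)
concat_out = LinearHead(JoinMode.CONCATENATE, 8, 3)(features, features)
sum_out = LinearHead(JoinMode.SUM, 8, 3)(features, features)
```

Both calls yield a tensor of shape (4, 3). Only the width of the tensor handed to head_forward differs between the two modes.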

abstract head_forward(input_tensor)[source]

Forward function for the head module.

Parameters:

input_tensor (torch.Tensor) – Input tensor to be mapped

Raises:

NotImplementedError – Must be implemented by any child class

Returns:

Output of the head module from the given input.

Return type:

torch.Tensor

abstract parallel_output_join(local_tensor, global_tensor)[source]

Defines how the local and global feature tensors output by the preceding parallel feature extractors are joined together when self.mode is set to ParallelFeatureJoinMode.CONCATENATE.

Parameters:
  • local_tensor (torch.Tensor) – First tensor to be joined

  • global_tensor (torch.Tensor) – Second tensor to be joined

Raises:

NotImplementedError – Must be implemented by any child class

Returns:

A single tensor with the two tensors joined together in some way

Return type:

torch.Tensor
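
Because the join is left open ("joined together in some way"), an implementation may need to reshape its inputs first. The sketch below shows one possible join for convolutional feature maps. The function name and the shapes are assumptions for illustration, not part of the library.

```python
import torch


def flatten_and_concatenate(local_tensor: torch.Tensor, global_tensor: torch.Tensor) -> torch.Tensor:
    # Collapse every dimension except the batch dimension, then join the
    # flattened features along dimension 1.
    local_flat = torch.flatten(local_tensor, start_dim=1)
    global_flat = torch.flatten(global_tensor, start_dim=1)
    return torch.cat([local_flat, global_flat], dim=1)


# Hypothetical feature maps shaped (batch, channels, height, width).
local_maps = torch.zeros(2, 4, 3, 3)
global_maps = torch.zeros(2, 6, 3, 3)

joined = flatten_and_concatenate(local_maps, global_maps)
print(joined.shape)  # torch.Size([2, 90])
```

A child class could return the result of such a function from parallel_output_join, provided head_forward expects the resulting width (here 4·9 + 6·9 = 90).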

class ParallelSplitModel(first_feature_extractor, second_feature_extractor, model_head)[source]

Bases: Module

__init__(first_feature_extractor, second_feature_extractor, model_head)[source]

This defines a model that has been split into two parallel feature extractors. The outputs of these feature extractors are merged together and mapped to a prediction by a ParallelSplitHeadModule. By default, no feature tensors are stored. Only a prediction tensor is produced.

Parameters:
  • first_feature_extractor (nn.Module) – First parallel feature extractor

  • second_feature_extractor (nn.Module) – Second parallel feature extractor

  • model_head (ParallelSplitHeadModule) – Module responsible for taking the outputs of the two feature extractors and using them to produce a prediction.

forward(input)[source]

Composite forward function. The input tensor is first passed through the two parallel feature extractors, and their outputs are then passed through the head model. The feature extractor outputs and the joining mechanism defined in the head model must be compatible with the head model's input. Ensuring this is left to the user.

Parameters:

input (torch.Tensor) – Input tensor to be passed through the set of forwards.

Returns:

A tuple of two dictionaries. The first holds the prediction tensor from the head model under the “prediction” key. The second, the feature dictionary, is empty by default.

Return type:

tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
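
The composite forward pass and its return structure can be sketched as follows. SketchParallelSplitModel and ConcatHead are hypothetical stand-ins written for this example rather than the library classes, and the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn


class ConcatHead(nn.Module):
    # Assumed stand-in head: concatenates both inputs, then applies a linear layer.
    def __init__(self, in_dim: int, n_classes: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_dim, n_classes)

    def forward(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
        return self.linear(torch.cat([first_tensor, second_tensor], dim=1))


class SketchParallelSplitModel(nn.Module):
    # Hypothetical stand-in mirroring the documented composite forward.
    def __init__(self, first_feature_extractor: nn.Module, second_feature_extractor: nn.Module, model_head: nn.Module) -> None:
        super().__init__()
        self.first_feature_extractor = first_feature_extractor
        self.second_feature_extractor = second_feature_extractor
        self.model_head = model_head

    def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        # The same input goes through both extractors in parallel.
        first_output = self.first_feature_extractor(input)
        second_output = self.second_feature_extractor(input)
        # Predictions live under the "prediction" key; no features are stored.
        preds = {"prediction": self.model_head(first_output, second_output)}
        return preds, {}


model = SketchParallelSplitModel(nn.Linear(10, 8), nn.Linear(10, 8), ConcatHead(16, 3))
preds, features = model(torch.randn(4, 10))
print(preds["prediction"].shape)  # torch.Size([4, 3])
print(features)  # {}
```

Here each extractor emits 8 features, so the concatenating head is built with an input width of 16. Keeping those sizes consistent is the compatibility requirement that the forward description leaves to the user.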