fl4health.model_bases.ensemble_base module

class EnsembleAggregationMode(value)

Bases: Enum

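Enumeration of the supported strategies for aggregating the predictions of the individual models in an EnsembleModel.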

AVERAGE = 'AVERAGE'
VOTE = 'VOTE'
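Because each member's value is the same string as its name, a mode stored in a configuration file can be turned back into an enum member by value lookup. The snippet below is a minimal sketch that uses a local mirror of the two documented members (so it runs without fl4health installed); the helper parse_aggregation_mode is hypothetical and not part of the library.

```python
from enum import Enum


# Local mirror of the documented enum (same member names and values) so the
# snippet runs without fl4health installed
class EnsembleAggregationMode(Enum):
    AVERAGE = "AVERAGE"
    VOTE = "VOTE"


def parse_aggregation_mode(name: str) -> EnsembleAggregationMode:
    # Hypothetical helper: look a member up by its string value, e.g. read from a config file
    try:
        return EnsembleAggregationMode(name.upper())
    except ValueError:
        valid = [m.value for m in EnsembleAggregationMode]
        raise ValueError(f"Unknown aggregation mode {name!r}; expected one of {valid}") from None


print(parse_aggregation_mode("vote"))  # EnsembleAggregationMode.VOTE
```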
class EnsembleModel(ensemble_models, aggregation_mode=EnsembleAggregationMode.AVERAGE)

Bases: Module

__init__(ensemble_models, aggregation_mode=EnsembleAggregationMode.AVERAGE)

Class that acts as a wrapper around an ensemble of models to be trained in a federated manner, with support for aggregating the predictions of the individual models by either voting or averaging.

Parameters:
  • ensemble_models (dict[str, nn.Module]) – A dictionary of models that make up the ensemble.

  • aggregation_mode (EnsembleAggregationMode | None) – The mode in which to aggregate the predictions of the individual models. Defaults to EnsembleAggregationMode.AVERAGE.
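For a federated client to train and exchange the weights of every ensemble member, each member has to be registered as a submodule of the wrapper. The sketch below shows one way to do that with torch.nn.ModuleDict; the class name EnsembleSketch, the attribute name models, and the local AggregationMode stand-in are assumptions for illustration, not the library's internals.

```python
from enum import Enum

import torch.nn as nn


class AggregationMode(Enum):  # local stand-in for EnsembleAggregationMode
    AVERAGE = "AVERAGE"
    VOTE = "VOTE"


class EnsembleSketch(nn.Module):  # hypothetical, minimal wrapper
    def __init__(self, ensemble_models, aggregation_mode=AggregationMode.AVERAGE):
        super().__init__()
        # ModuleDict registers each member as a submodule, so its parameters
        # appear in parameters() and state_dict() under the member's key
        self.models = nn.ModuleDict(ensemble_models)
        self.aggregation_mode = aggregation_mode


ensemble = EnsembleSketch({"model_0": nn.Linear(4, 3), "model_1": nn.Linear(4, 3)})
print(sorted(ensemble.state_dict().keys()))
# ['models.model_0.bias', 'models.model_0.weight', 'models.model_1.bias', 'models.model_1.weight']
```

Storing the members in a plain Python dict instead would hide their parameters from parameters() and state_dict(), so they would be neither optimized nor exchanged with the server.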

ensemble_average(preds_list)

Produces the aggregated prediction of the ensemble via averaging.

Parameters:

preds_list (list[torch.Tensor]) – A list of predictions of the models in the ensemble.

Returns:

The average prediction of the ensemble.

Return type:

torch.Tensor
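Averaging can be sketched as stacking the member predictions along a new leading axis and taking the element-wise mean over that axis. This is a minimal standalone sketch under the assumption that all predictions share the same shape (torch.stack requires it); the function name ensemble_average_sketch is hypothetical.

```python
import torch


def ensemble_average_sketch(preds_list: list[torch.Tensor]) -> torch.Tensor:
    # Stack to shape (num_models, *pred_shape), then average over the model axis
    return torch.stack(preds_list).mean(dim=0)


preds = [torch.tensor([[0.2, 0.8]]), torch.tensor([[0.6, 0.4]])]
avg = ensemble_average_sketch(preds)  # approximately [[0.4, 0.6]]
```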

ensemble_vote(preds_list)

Produces the aggregated prediction of the ensemble via voting. Expects each prediction to be laid out so that axis 0 indexes the samples and axis -1 indexes the classes.

Parameters:

preds_list (list[torch.Tensor]) – A list of predictions of the models in the ensemble.

Returns:

The prediction of the ensemble obtained by voting.

Return type:

torch.Tensor
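Given the layout described above (axis 0 for samples, axis -1 for classes), voting can be sketched as a plurality vote over each member's argmax class. This is a minimal sketch, not the library's implementation: the function name ensemble_vote_sketch, returning a one-hot tensor shaped like each input prediction, and breaking ties in favour of the lowest class index (torch.argmax returns the first maximal index) are all assumptions.

```python
import torch
import torch.nn.functional as F


def ensemble_vote_sketch(preds_list: list[torch.Tensor]) -> torch.Tensor:
    num_classes = preds_list[0].shape[-1]
    # One-hot encode each member's predicted class: shape (..., num_classes)
    votes = [F.one_hot(torch.argmax(p, dim=-1), num_classes) for p in preds_list]
    # Sum across members to get the vote count of every class for every sample
    counts = torch.stack(votes).sum(dim=0)
    # Winning class per sample; ties resolve to the lowest class index
    winners = torch.argmax(counts, dim=-1)
    return F.one_hot(winners, num_classes)


preds = [
    torch.tensor([[0.90, 0.05, 0.05], [0.1, 0.8, 0.1]]),  # votes: class 0, class 1
    torch.tensor([[0.20, 0.70, 0.10], [0.1, 0.2, 0.7]]),  # votes: class 1, class 2
    torch.tensor([[0.60, 0.30, 0.10], [0.3, 0.6, 0.1]]),  # votes: class 0, class 1
]
print(ensemble_vote_sketch(preds).tolist())  # [[1, 0, 0], [0, 1, 0]]
```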

forward(input)

Produces the predictions of the ensemble models given input data.

Parameters:

input (torch.Tensor) – A batch of input data.

Returns:

A dictionary containing the predictions of the individual ensemble models as well as the prediction of the ensemble as a whole.

Return type:

dict[str, torch.Tensor]
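Putting the pieces together, a forward pass runs every member on the same batch, keeps each member's output under its own key, and adds the aggregated prediction under one extra key. The self-contained sketch below illustrates that flow; the class EnsembleSketch, the AggregationMode stand-in, and the key name "ensemble-pred" are hypothetical choices, not confirmed details of EnsembleModel.

```python
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F


class AggregationMode(Enum):  # local stand-in for EnsembleAggregationMode
    AVERAGE = "AVERAGE"
    VOTE = "VOTE"


class EnsembleSketch(nn.Module):  # hypothetical, minimal wrapper
    def __init__(self, ensemble_models, aggregation_mode=AggregationMode.AVERAGE):
        super().__init__()
        self.models = nn.ModuleDict(ensemble_models)
        self.aggregation_mode = aggregation_mode

    def forward(self, input: torch.Tensor) -> dict[str, torch.Tensor]:
        # Run every member on the same batch, keyed by member name
        preds = {name: model(input) for name, model in self.models.items()}
        preds_list = list(preds.values())
        if self.aggregation_mode == AggregationMode.AVERAGE:
            # Element-wise mean over the members
            ensemble_pred = torch.stack(preds_list).mean(dim=0)
        else:
            # Plurality vote over each member's argmax class, returned one-hot
            num_classes = preds_list[0].shape[-1]
            votes = [F.one_hot(p.argmax(dim=-1), num_classes) for p in preds_list]
            counts = torch.stack(votes).sum(dim=0)
            ensemble_pred = F.one_hot(counts.argmax(dim=-1), num_classes)
        # "ensemble-pred" is a hypothetical key for the aggregated prediction
        preds["ensemble-pred"] = ensemble_pred
        return preds


torch.manual_seed(0)
x = torch.randn(5, 4)
ensemble = EnsembleSketch({"a": nn.Linear(4, 3), "b": nn.Linear(4, 3)}, AggregationMode.VOTE)
out = ensemble(x)
print(sorted(out))  # ['a', 'b', 'ensemble-pred']
print(out["ensemble-pred"].shape)  # torch.Size([5, 3])
```

Returning the member outputs alongside the aggregate lets a client compute a separate loss or metric for each member as well as for the ensemble; note that in this sketch a member named "ensemble-pred" would be overwritten.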