fl4health.model_bases.ensemble_base module
- class EnsembleAggregationMode(value)
Bases: Enum
Enumeration of the strategies for aggregating the predictions of the individual models in an ensemble.
- AVERAGE = 'AVERAGE'
- VOTE = 'VOTE'
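Because the member values are the strings shown above, a mode can be looked up from its name with the standard `Enum` value lookup, which is convenient when the mode comes from a configuration string. The sketch below uses a local mirror of the enum rather than importing fl4health, and the `parse_mode` helper (including its case-insensitive handling) is a hypothetical addition, not part of the library.

```python
from enum import Enum


class EnsembleAggregationMode(Enum):
    # Local mirror of the documented members; not imported from fl4health.
    AVERAGE = "AVERAGE"
    VOTE = "VOTE"


def parse_mode(name: str) -> EnsembleAggregationMode:
    # Hypothetical helper: Enum lookup by value, raising ValueError for unknown names.
    return EnsembleAggregationMode(name.upper())


mode = parse_mode("vote")  # -> EnsembleAggregationMode.VOTE
```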
- class EnsembleModel(ensemble_models, aggregation_mode=EnsembleAggregationMode.AVERAGE)
Bases: Module
- __init__(ensemble_models, aggregation_mode=EnsembleAggregationMode.AVERAGE)
Class that acts as a wrapper around an ensemble of models to be trained in a federated manner, with support for aggregating the predictions of the individual models by either voting or averaging.
- Parameters:
ensemble_models (dict[str, nn.Module]) – A dictionary, keyed by name, of the models that make up the ensemble.
aggregation_mode (EnsembleAggregationMode | None) – The mode used to aggregate the predictions of the individual models. Defaults to EnsembleAggregationMode.AVERAGE.
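The wrapper pattern described above can be sketched in plain Python. This is a hypothetical illustration, not the fl4health implementation: `EnsembleSketch`, `AggregationMode`, `mean_aggregator`, and the choice to return both the per-member and the combined predictions are assumptions. The members are plain callables rather than `nn.Module` instances so that the sketch runs without PyTorch.

```python
from enum import Enum
from typing import Any, Callable

Preds = list[list[float]]  # one row of class scores per sample
Member = Callable[[Any], Preds]
Aggregator = Callable[[list[Preds]], Preds]


class AggregationMode(Enum):
    # Hypothetical local mirror of EnsembleAggregationMode.
    AVERAGE = "AVERAGE"
    VOTE = "VOTE"


class EnsembleSketch:
    """Hypothetical wrapper: runs every named member, then aggregates."""

    def __init__(
        self,
        models: dict[str, Member],
        aggregators: dict[AggregationMode, Aggregator],
        mode: AggregationMode = AggregationMode.AVERAGE,
    ) -> None:
        if mode not in aggregators:
            raise ValueError(f"no aggregator registered for {mode}")
        self.models = models
        self.aggregators = aggregators
        self.mode = mode

    def __call__(self, x: Any) -> tuple[dict[str, Preds], Preds]:
        # Keep each member's prediction under the member's name...
        member_preds = {name: model(x) for name, model in self.models.items()}
        # ...then combine them with the function registered for the chosen mode.
        combined = self.aggregators[self.mode](list(member_preds.values()))
        return member_preds, combined


def mean_aggregator(preds_list: list[Preds]) -> Preds:
    # Element-wise mean across members.
    n = len(preds_list)
    return [[sum(v) / n for v in zip(*rows)] for rows in zip(*preds_list)]


ensemble = EnsembleSketch(
    models={
        "a": lambda batch: [[0.2, 0.8] for _ in batch],
        "b": lambda batch: [[0.6, 0.4] for _ in batch],
    },
    aggregators={AggregationMode.AVERAGE: mean_aggregator},
)
members, combined = ensemble(["sample-0"])
```

Keying members by name keeps each model's individual output available alongside the aggregated one, while the mode only decides how the outputs are combined.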
- ensemble_average(preds_list)
Produces the aggregated prediction of the ensemble via averaging.
- Parameters:
preds_list (list[torch.Tensor]) – A list of predictions of the models in the ensemble.
- Returns:
The average prediction of the ensemble.
- Return type:
torch.Tensor
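A minimal plain-Python sketch of ensemble averaging, assuming each prediction is a two-dimensional [sample][class] nested list of scores with the same shape for every model. The real method takes `torch.Tensor` objects, and this sketch is not its source code.

```python
Preds = list[list[float]]  # [sample][class] scores


def ensemble_average(preds_list: list[Preds]) -> Preds:
    """Element-wise mean of the members' predictions (plain-Python sketch)."""
    if not preds_list:
        raise ValueError("preds_list must contain at least one prediction")
    n_models = len(preds_list)
    # zip(*preds_list) groups the same sample across models;
    # zip(*rows) then groups the same class score across models.
    return [
        [sum(scores) / n_models for scores in zip(*rows)]
        for rows in zip(*preds_list)
    ]


preds_list = [
    [[0.9, 0.1], [0.2, 0.8]],  # model A, two samples
    [[0.5, 0.5], [0.4, 0.6]],  # model B, same two samples
]
avg = ensemble_average(preds_list)  # approximately [[0.7, 0.3], [0.3, 0.7]]
```

With PyTorch tensors of identical shape, the same element-wise mean can be written as `torch.stack(preds_list).mean(dim=0)`, which also handles inputs with more than two dimensions.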
- ensemble_vote(preds_list)
Produces the aggregated prediction of the ensemble via voting. Expects each prediction to have the sample index on axis 0 and the class dimension on axis -1.
- Parameters:
preds_list (list[torch.Tensor]) – A list of predictions of the models in the ensemble.
- Returns:
The voted prediction of the ensemble.
- Return type:
torch.Tensor
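A plain-Python sketch of one way to vote, following the axis convention above (axis 0 = sample, axis -1 = class) for two-dimensional inputs: each model votes for its highest-scoring class per sample, and the class with the most votes wins. Breaking ties toward the lowest class index and returning the winner as a one-hot row are assumptions made for this illustration; the fl4health method may differ on both points.

```python
Preds = list[list[float]]  # [sample][class] scores


def argmax(row: list[float]) -> int:
    # Index of the largest value; max() keeps the first index on ties.
    return max(range(len(row)), key=lambda i: row[i])


def ensemble_vote(preds_list: list[Preds]) -> Preds:
    """Majority vote over the class axis (plain-Python sketch)."""
    if not preds_list:
        raise ValueError("preds_list must contain at least one prediction")
    n_classes = len(preds_list[0][0])
    result = []
    for rows in zip(*preds_list):  # the same sample as seen by every model
        counts = [0] * n_classes
        for row in rows:
            counts[argmax(row)] += 1  # each model casts one vote
        winner = argmax(counts)  # assumption: ties go to the lowest class index
        result.append([1.0 if c == winner else 0.0 for c in range(n_classes)])
    return result


preds_list = [
    [[0.9, 0.1], [0.2, 0.8]],  # model A votes class 0, then class 1
    [[0.6, 0.4], [0.7, 0.3]],  # model B votes class 0, then class 0
    [[0.3, 0.7], [0.1, 0.9]],  # model C votes class 1, then class 1
]
print(ensemble_vote(preds_list))  # [[1.0, 0.0], [0.0, 1.0]]
```

Unlike averaging, voting ignores how confident each model is: a model that scores its top class at 0.51 casts the same vote as one that scores it at 0.99.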