fl4health.utils.losses module
- class EvaluationLosses(checkpoint, additional_losses=None)
Bases:
Losses
- __init__(checkpoint, additional_losses=None)
A class to store the checkpoint and additional_losses of a model along with a method to return a dictionary representation.
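The container described above can be illustrated with a minimal, self-contained sketch. It is not the fl4health source. Plain `float` values stand in for `torch.Tensor`. This page does not name the method that returns the dictionary representation, so the method name `as_dict` and the `"checkpoint"` key are hypothetical. So is the class name `EvaluationLossesSketch`.

```python
from __future__ import annotations


class EvaluationLossesSketch:
    """Hypothetical stand-in for EvaluationLosses; floats replace torch.Tensor."""

    def __init__(self, checkpoint: float, additional_losses: dict[str, float] | None = None) -> None:
        # Assumed: the loss consulted when deciding whether to checkpoint the model.
        self.checkpoint = checkpoint
        # Default to an empty dict so callers never have to handle None.
        self.additional_losses = dict(additional_losses) if additional_losses else {}

    def as_dict(self) -> dict[str, float]:
        # Hypothetical method name: flatten every stored loss into one name -> value mapping.
        losses = dict(self.additional_losses)
        losses["checkpoint"] = self.checkpoint
        return losses


losses = EvaluationLossesSketch(checkpoint=0.25, additional_losses={"l2": 0.5})
print(losses.as_dict())  # {'l2': 0.5, 'checkpoint': 0.25}
```

Defaulting `additional_losses` to an empty dictionary keeps the dictionary representation well defined even when only the checkpoint loss is tracked.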
- static aggregate(loss_meter)
Aggregates the losses in the given LossMeter into an instance of EvaluationLosses.
- Parameters:
loss_meter (LossMeter[EvaluationLosses]) – The loss meter object with the collected evaluation losses.
- Returns:
An instance of EvaluationLosses with the aggregated losses.
- Return type:
EvaluationLosses
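The aggregation can be sketched as follows, under two assumptions this page does not state. First, ACCUMULATION sums the collected values and AVERAGE divides that sum by the number of collected entries. Second, every entry carries the same additional-loss keys. The names `MeterType`, `EvalEntry`, and `aggregate_evaluation` are hypothetical, and a plain list replaces the LossMeter argument.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum


class MeterType(Enum):
    # Local stand-in mirroring the two documented LossMeterType members.
    ACCUMULATION = "ACCUMULATION"
    AVERAGE = "AVERAGE"


@dataclass
class EvalEntry:
    checkpoint: float
    additional_losses: dict[str, float] = field(default_factory=dict)


def _combine(values: list[float], meter_type: MeterType) -> float:
    # Assumed semantics: sum for ACCUMULATION, arithmetic mean for AVERAGE.
    total = sum(values)
    return total / len(values) if meter_type is MeterType.AVERAGE else total


def aggregate_evaluation(entries: list[EvalEntry], meter_type: MeterType) -> EvalEntry:
    checkpoint = _combine([e.checkpoint for e in entries], meter_type)
    # Assumes every entry has the same additional-loss keys as the first one.
    additional = {
        key: _combine([e.additional_losses[key] for e in entries], meter_type)
        for key in entries[0].additional_losses
    }
    return EvalEntry(checkpoint, additional)


batches = [EvalEntry(1.0, {"l2": 0.5}), EvalEntry(2.0, {"l2": 1.5}), EvalEntry(3.0, {"l2": 1.0})]
print(aggregate_evaluation(batches, MeterType.AVERAGE))
# EvalEntry(checkpoint=2.0, additional_losses={'l2': 1.0})
```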
- class LossMeter(loss_meter_type, losses_type)
Bases:
Generic[LossesType]
- __init__(loss_meter_type, losses_type)
A meter to store a list of losses.
- Parameters:
loss_meter_type (LossMeterType) – The type of this loss meter.
losses_type (type[Losses]) – The type of the losses that will be stored. Should be one of the subclasses of Losses.
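A possible shape for the meter is sketched below. It assumes the meter keeps the collected `Losses` objects in a list and delegates aggregation to the static `aggregate` method of `losses_type`, which is consistent with the `aggregate(loss_meter)` signatures documented on this page. The methods `update` and `compute`, the attribute `losses_list`, the classes `MeterSketch` and `SumLosses`, and the type variable `LossesT` are hypothetical names chosen for illustration.

```python
from __future__ import annotations

from enum import Enum
from typing import Generic, TypeVar


class MeterType(Enum):
    # Local stand-in for LossMeterType.
    ACCUMULATION = "ACCUMULATION"
    AVERAGE = "AVERAGE"


# Hypothetical type variable playing the role of LossesType.
LossesT = TypeVar("LossesT")


class MeterSketch(Generic[LossesT]):
    """Hypothetical LossMeter stand-in that keeps every update in a list."""

    def __init__(self, meter_type: MeterType, losses_type: type[LossesT]) -> None:
        self.meter_type = meter_type
        self.losses_type = losses_type
        self.losses_list: list[LossesT] = []

    def update(self, losses: LossesT) -> None:
        # Record the losses from one batch.
        self.losses_list.append(losses)

    def compute(self) -> LossesT:
        # Delegate to the static aggregate(meter) of the stored losses type.
        return self.losses_type.aggregate(self)


class SumLosses:
    """Toy losses type wrapping one float, standing in for a Losses subclass."""

    def __init__(self, value: float) -> None:
        self.value = value

    @staticmethod
    def aggregate(meter: MeterSketch[SumLosses]) -> SumLosses:
        total = sum(item.value for item in meter.losses_list)
        if meter.meter_type is MeterType.AVERAGE:
            total /= len(meter.losses_list)
        return SumLosses(total)


meter = MeterSketch(MeterType.AVERAGE, SumLosses)
for value in (1.0, 2.0, 3.0):
    meter.update(SumLosses(value))
print(meter.compute().value)  # 2.0
```

Keeping the aggregation rule on the losses type lets one meter class serve both training and evaluation losses.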
- static aggregate_losses_dict(loss_list, loss_meter_type)
Aggregates a list of loss dictionaries into a single dictionary according to the aggregation type of the loss meter.
- Parameters:
loss_list (list[dict[str, torch.Tensor]]) – A list of loss dictionaries.
loss_meter_type (LossMeterType) – The type of the loss meter, which determines how the losses are aggregated.
- Returns:
A single dictionary with the losses aggregated according to the given loss meter type.
- Return type:
dict[str, torch.Tensor]
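The per-key aggregation can be sketched as below. As the enum names suggest, the sketch assumes that ACCUMULATION sums each key across the list and AVERAGE takes the mean. It also assumes every dictionary has the same keys. Floats stand in for `torch.Tensor`, and `aggregate_dicts` is a hypothetical name. The empty-list guard is an extra choice of this sketch, not documented behavior.

```python
from __future__ import annotations

from enum import Enum


class MeterType(Enum):
    # Local stand-in for LossMeterType.
    ACCUMULATION = "ACCUMULATION"
    AVERAGE = "AVERAGE"


def aggregate_dicts(loss_list: list[dict[str, float]], meter_type: MeterType) -> dict[str, float]:
    if not loss_list:
        # Sketch-specific guard: nothing collected yet, so avoid dividing by zero.
        return {}
    aggregated: dict[str, float] = {}
    # Keys are taken from the first dictionary; all others are assumed to match.
    for key in loss_list[0]:
        total = sum(losses[key] for losses in loss_list)
        aggregated[key] = total / len(loss_list) if meter_type is MeterType.AVERAGE else total
    return aggregated


batch_losses = [{"ce": 1.0, "l2": 0.25}, {"ce": 3.0, "l2": 0.75}]
print(aggregate_dicts(batch_losses, MeterType.ACCUMULATION))  # {'ce': 4.0, 'l2': 1.0}
print(aggregate_dicts(batch_losses, MeterType.AVERAGE))       # {'ce': 2.0, 'l2': 0.5}
```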
- class LossMeterType(value)
Bases:
Enum
An enumeration of the aggregation types a LossMeter can use to combine the losses it has collected.
- ACCUMULATION = 'ACCUMULATION'
- AVERAGE = 'AVERAGE'
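Each member's value equals its name. A meter type can therefore be recovered from a plain string, such as one read from a configuration file, by calling the enum with that value, as the `LossMeterType(value)` signature reflects. The sketch below re-declares the enum locally so it runs without fl4health. `parse_meter_type` is a hypothetical helper, and its case normalization is an extra convenience, not part of the documented class.

```python
from enum import Enum


class LossMeterType(Enum):
    # Local re-declaration mirroring the two documented members.
    ACCUMULATION = "ACCUMULATION"
    AVERAGE = "AVERAGE"


def parse_meter_type(raw: str) -> LossMeterType:
    # Enum lookup by value; unknown strings raise ValueError.
    return LossMeterType(raw.strip().upper())


print(parse_meter_type("average"))       # LossMeterType.AVERAGE
print([m.value for m in LossMeterType])  # ['ACCUMULATION', 'AVERAGE']
```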
- class Losses(additional_losses=None)
Bases:
ABC
- abstract static aggregate(loss_meter)
Aggregates the losses in the given LossMeter into an instance of Losses.
- Parameters:
loss_meter (LossMeter) – The loss meter object with the collected losses.
- Raises:
NotImplementedError – To be implemented by child classes.
- Return type:
Losses
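Because `aggregate` is an abstract static method, `Losses` cannot be instantiated directly, and each subclass must supply its own aggregation. A minimal sketch of that pattern with the standard `abc` module follows. `LossesBase` and `ScalarLosses` are hypothetical names, floats stand in for tensors, and a plain list replaces the LossMeter argument.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class LossesBase(ABC):
    def __init__(self, additional_losses: dict[str, float] | None = None) -> None:
        self.additional_losses = additional_losses or {}

    @staticmethod
    @abstractmethod
    def aggregate(collected: list[LossesBase]) -> LossesBase:
        # Only reached if a subclass explicitly calls the base implementation.
        raise NotImplementedError


class ScalarLosses(LossesBase):
    def __init__(self, value: float) -> None:
        super().__init__()
        self.value = value

    @staticmethod
    def aggregate(collected: list[ScalarLosses]) -> ScalarLosses:
        # Toy rule for this sketch: sum the stored values.
        return ScalarLosses(sum(c.value for c in collected))


try:
    LossesBase()
except TypeError as err:
    # The abstract aggregate method blocks direct instantiation.
    print("cannot instantiate:", err)

print(ScalarLosses.aggregate([ScalarLosses(1.0), ScalarLosses(2.0)]).value)  # 3.0
```

Stacking `@staticmethod` outside `@abstractmethod` is the order the `abc` documentation prescribes for abstract static methods.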
- class TrainingLosses(backward, additional_losses=None)
Bases:
Losses
- __init__(backward, additional_losses=None)
A class to store the backward and additional_losses of a model along with a method to return a dictionary representation.
- Parameters:
backward (torch.Tensor | dict[str, torch.Tensor]) – The backward loss or losses to optimize. In the normal case, backward is a Tensor corresponding to the loss of a model. In the case of an ensemble_model, backward is a dictionary of losses.
additional_losses (dict[str, torch.Tensor] | None) – Optional dictionary of additional losses.
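The two accepted forms of `backward` can be handled uniformly by normalizing them to name-to-loss pairs. The sketch below is hypothetical. Floats stand in for `torch.Tensor`, and the helper `backward_items`, the default key `"backward"` for a single loss, and the class name `TrainingLossesSketch` are assumptions, not fl4health API.

```python
from __future__ import annotations


class TrainingLossesSketch:
    """Hypothetical stand-in for TrainingLosses; floats replace torch.Tensor."""

    def __init__(
        self, backward: float | dict[str, float], additional_losses: dict[str, float] | None = None
    ) -> None:
        # A single loss for a normal model, or one loss per member for an ensemble model.
        self.backward = backward
        self.additional_losses = additional_losses or {}

    def backward_items(self) -> list[tuple[str, float]]:
        # Hypothetical helper: treat both forms as a list of (name, loss) pairs.
        if isinstance(self.backward, dict):
            return list(self.backward.items())
        return [("backward", self.backward)]


single = TrainingLossesSketch(backward=0.5)
ensemble = TrainingLossesSketch(backward={"model_a": 0.5, "model_b": 0.75})
print(single.backward_items())    # [('backward', 0.5)]
print(ensemble.backward_items())  # [('model_a', 0.5), ('model_b', 0.75)]
```

With real tensors, a training loop would presumably call `.backward()` on each loss returned this way. That step is omitted here because it requires PyTorch.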
- static aggregate(loss_meter)
Aggregates the losses in the given LossMeter into an instance of TrainingLosses.
- Parameters:
loss_meter (LossMeter[TrainingLosses]) – The loss meter object with the collected training losses.
- Returns:
An instance of TrainingLosses with the aggregated losses.
- Return type:
TrainingLosses
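Aggregating TrainingLosses has to cope with both forms of `backward`. One plausible approach is sketched below, under the same assumption that the two meter types mean sum and mean. A scalar `backward` is aggregated directly, and a dictionary `backward` is aggregated key by key. The function `aggregate_training` and the local `MeterType` are hypothetical, and a list of `backward` values replaces the LossMeter.

```python
from __future__ import annotations

from enum import Enum


class MeterType(Enum):
    # Local stand-in for LossMeterType.
    ACCUMULATION = "ACCUMULATION"
    AVERAGE = "AVERAGE"


def _combine(values: list[float], meter_type: MeterType) -> float:
    # Assumed semantics: sum for ACCUMULATION, arithmetic mean for AVERAGE.
    total = sum(values)
    return total / len(values) if meter_type is MeterType.AVERAGE else total


def aggregate_training(
    backwards: list[float | dict[str, float]], meter_type: MeterType
) -> float | dict[str, float]:
    first = backwards[0]
    if isinstance(first, dict):
        # Ensemble case: aggregate each member's loss separately.
        return {key: _combine([b[key] for b in backwards], meter_type) for key in first}
    # Single-model case: aggregate the scalar losses directly.
    return _combine(list(backwards), meter_type)


print(aggregate_training([1.0, 3.0], MeterType.AVERAGE))  # 2.0
print(aggregate_training([{"a": 1.0, "b": 0.5}, {"a": 2.0, "b": 1.5}], MeterType.ACCUMULATION))
# {'a': 3.0, 'b': 2.0}
```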