def __init__(self, additional_losses: dict[str, torch.Tensor] | None = None) -> None:
    """
    An abstract class to store the losses.

    Args:
        additional_losses (dict[str, torch.Tensor] | None): Optional dictionary of additional losses. If it
            is ``None`` or otherwise falsy (e.g. an empty dict), a new empty dictionary is stored instead.
    """
    # Substitute a fresh dict so every instance owns its own mapping.
    if not additional_losses:
        additional_losses = {}
    self.additional_losses = additional_losses
def as_dict(self) -> dict[str, float]:
    """
    Produces a dictionary representation of the object with all of the losses.

    Returns:
        dict[str, float]: Maps each additional-loss name to its scalar value converted with
            ``float(tensor.item())``. Empty when there are no additional losses.
    """
    extras = self.additional_losses
    if extras is None:
        return {}
    return {name: float(tensor.item()) for name, tensor in extras.items()}
@staticmethod
@abstractmethod
def aggregate(loss_meter: LossMeter) -> Losses:
    """
    Aggregates the losses in the given ``LossMeter`` into an instance of Losses.

    This is the abstract hook; every concrete subclass must provide its own implementation.

    Args:
        loss_meter (LossMeter): The loss meter object with the collected losses.

    Returns:
        Losses: In subclasses, a new losses object holding the aggregated values. This base version never
            returns.

    Raises:
        NotImplementedError: Always; to be implemented by child classes.
    """
    raise NotImplementedError()
def __init__(self, checkpoint: torch.Tensor, additional_losses: dict[str, torch.Tensor] | None = None) -> None:
    """
    A class to store the checkpoint and ``additional_losses`` of a model along with a method to return a
    dictionary representation.

    Args:
        checkpoint (torch.Tensor): The loss used to checkpoint model (if checkpointing is enabled).
        additional_losses (dict[str, torch.Tensor] | None): Optional dictionary of additional losses, handed
            to the parent initializer.
    """
    # Let the base class set up ``additional_losses`` first, then record the checkpoint loss.
    super().__init__(additional_losses)
    self.checkpoint = checkpoint
def as_dict(self) -> dict[str, float]:
    """
    Produces a dictionary representation of the object with all of the losses.

    Returns:
        dict[str, float]: The parent's dictionary (additional losses), extended with a ``"checkpoint"`` entry
            holding the checkpoint loss as a float. The ``"checkpoint"`` entry overwrites an additional
            loss of the same name.
    """
    merged = super().as_dict()
    merged.update({"checkpoint": float(self.checkpoint.item())})
    return merged
@staticmethod
def aggregate(loss_meter: LossMeter[EvaluationLosses]) -> EvaluationLosses:
    """
    Aggregates the losses in the given ``LossMeter`` into an instance of ``EvaluationLosses``.

    The checkpoint losses are summed, then divided by the number of entries when the meter type is
    ``LossMeterType.AVERAGE``. Additional losses are combined key-by-key via
    ``LossMeter.aggregate_losses_dict``.

    Args:
        loss_meter (LossMeter[EvaluationLosses]): The loss meter object with the collected evaluation losses.

    Returns:
        EvaluationLosses: An instance of ``EvaluationLosses`` with the aggregated losses.
    """
    collected = loss_meter.losses_list

    checkpoint_values = [entry.checkpoint for entry in collected]
    checkpoint_loss = torch.FloatTensor(checkpoint_values).sum()  # type: ignore
    if loss_meter.loss_meter_type == LossMeterType.AVERAGE:
        checkpoint_loss /= len(collected)

    extras = LossMeter.aggregate_losses_dict(
        [entry.additional_losses for entry in collected],
        loss_meter.loss_meter_type,
    )
    return EvaluationLosses(checkpoint=checkpoint_loss, additional_losses=extras)
def __init__(
    self,
    backward: torch.Tensor | dict[str, torch.Tensor],
    additional_losses: dict[str, torch.Tensor] | None = None,
) -> None:
    """
    A class to store the backward and ``additional_losses`` of a model along with a method to return a
    dictionary representation.

    Args:
        backward (torch.Tensor | dict[str, torch.Tensor]): The backward loss or losses to optimize. In the
            normal case, backward is a Tensor corresponding to the loss of a model. In the case of an
            ``ensemble_model``, backward is dictionary of losses. A bare tensor is stored as
            ``{"backward": tensor}`` so ``self.backward`` is always a dictionary.
        additional_losses (dict[str, torch.Tensor] | None): Optional dictionary of additional losses.
    """
    super().__init__(additional_losses)
    if isinstance(backward, dict):
        self.backward = backward
    else:
        # Normalize a single tensor into a one-entry dict.
        self.backward = {"backward": backward}
def as_dict(self) -> dict[str, float]:
    """
    Produces a dictionary representation of the object with all of the losses.

    Returns:
        dict[str, float]: The parent's dictionary (additional losses) with one float entry added per
            backward loss. A backward key overwrites an additional loss of the same name.
    """
    result = super().as_dict()
    for name, value in self.backward.items():
        result[name] = float(value.item())
    return result
@staticmethod
def aggregate(loss_meter: LossMeter[TrainingLosses]) -> TrainingLosses:
    """
    Aggregates the losses in the given ``LossMeter`` into an instance of ``TrainingLosses``.

    Additional losses are always combined key-by-key. Backward losses are combined key-by-key when the first
    entry's ``backward`` is a dict. Otherwise they are summed, and averaged for ``LossMeterType.AVERAGE``.

    Args:
        loss_meter (LossMeter[TrainingLosses]): The loss meter object with the collected training losses.

    Returns:
        TrainingLosses: An instance of ``TrainingLosses`` with the aggregated losses.
    """
    meter_type = loss_meter.loss_meter_type
    collected = loss_meter.losses_list

    extras = LossMeter.aggregate_losses_dict(
        [entry.additional_losses for entry in collected],
        meter_type,
    )

    backward_values = [entry.backward for entry in collected]  # type: ignore
    if backward_values and isinstance(backward_values[0], dict):
        # Dict-valued backward losses (what TrainingLosses.__init__ always stores) aggregate per key.
        per_key = LossMeter.aggregate_losses_dict(backward_values, meter_type)
        return TrainingLosses(backward=per_key, additional_losses=extras)

    # Fallback for tensor-valued backward losses: reduce them into a single tensor.
    total = torch.FloatTensor(backward_values).sum()
    if meter_type == LossMeterType.AVERAGE:
        total /= len(collected)
    return TrainingLosses(backward=total, additional_losses=extras)
def __init__(self, loss_meter_type: LossMeterType, losses_type: type[LossesType]) -> None:
    """
    A meter to store a list of losses.

    Args:
        loss_meter_type (LossMeterType): The type of this loss meter; it selects how stored losses are
            aggregated.
        losses_type (type[Losses]): The type of the loss that will be stored. Should be one of the subclasses
            of Losses; its ``aggregate`` is what ``compute`` delegates to.
    """
    self.loss_meter_type = loss_meter_type
    self.losses_type = losses_type
    # Starts empty; filled by ``update`` and reset by ``clear``.
    self.losses_list: list[LossesType] = []
def update(self, losses: LossesType) -> None:
    """
    Appends loss to list of losses.

    Args:
        losses (LossesType): A losses object with checkpoint, backward and additional losses. It is stored
            as-is (not copied) at the end of ``losses_list``.
    """
    collected = self.losses_list
    collected.append(losses)
def clear(self) -> None:
    """
    Resets the meter by re-initializing ``losses_list`` to be empty.

    A new list object is bound, so any external reference to the previous list keeps its contents.
    """
    self.losses_list = list()
def compute(self) -> LossesType:
    """
    Computes the aggregation of current list of losses if non-empty.

    Delegates to ``self.losses_type.aggregate``, passing this meter.

    Returns:
        LossesType: New Losses object with the aggregation of losses in ``losses_list``.

    Raises:
        AssertionError: If ``losses_list`` is empty.
    """
    # An explicit check instead of ``assert`` so the guard is not stripped under ``python -O``. Without it,
    # an empty meter would fall through to aggregation (e.g. a 0/0 average). AssertionError is kept for
    # compatibility with callers of the original ``assert``.
    if not self.losses_list:
        raise AssertionError("Cannot compute losses from an empty LossMeter; call update() first.")
    return self.losses_type.aggregate(self)  # type: ignore
@staticmethod
def aggregate_losses_dict(
    loss_list: list[dict[str, torch.Tensor]],
    loss_meter_type: LossMeterType,
) -> dict[str, torch.Tensor]:
    """
    Aggregates a list of losses dictionaries into a single dictionary according to the loss meter
    aggregation type.

    Entries may have different keys. Each key in the union is reduced over only the entries that contain it:
    ``torch.mean`` for ``LossMeterType.AVERAGE``, ``torch.sum`` otherwise.

    Args:
        loss_list (list[dict[str, torch.Tensor]]): A list of loss dictionaries.
        loss_meter_type (LossMeterType): The type of the loss meter to perform the aggregation.

    Returns:
        dict[str, torch.Tensor]: A single dictionary with the aggregated losses according to the given loss
            meter type. Keys appear in first-seen order across ``loss_list``; an empty input gives ``{}``.
    """
    # We don't know the keys beforehand and entries can differ, so take the union over all entries.
    # ``dict.fromkeys`` (not ``set``) keeps first-seen order. A set of str would iterate in a hash-seed
    # dependent order, making the output key order differ between runs.
    all_keys = dict.fromkeys(key for entry in loss_list for key in entry)

    aggregated: dict[str, torch.Tensor] = {}
    for key in all_keys:
        values = torch.FloatTensor([entry[key] for entry in loss_list if key in entry])
        if loss_meter_type == LossMeterType.AVERAGE:
            aggregated[key] = torch.mean(values)
        else:
            aggregated[key] = torch.sum(values)
    return aggregated