def fold_loss_dict_into_metrics(
    metrics: dict[str, Scalar], loss_dict: dict[str, float], logging_mode: LoggingMode
) -> None:
    """
    Fold the provided loss values into the metrics dictionary in place.

    Each loss key is prefixed according to the mode that produced it: validation losses receive
    ``MetricPrefix.VAL_PREFIX`` and any other mode receives ``MetricPrefix.TEST_PREFIX``.

    Args:
        metrics (dict[str, Scalar]): Metrics dictionary to be updated in place.
        loss_dict (dict[str, float]): Loss values keyed by loss name.
        logging_mode (LoggingMode): Mode from which the losses were generated.
    """
    # Choose the prefix once based on the mode, then fold every loss entry into metrics
    if logging_mode is LoggingMode.VALIDATION:
        prefix = MetricPrefix.VAL_PREFIX.value
    else:
        prefix = MetricPrefix.TEST_PREFIX.value
    for loss_name, loss_value in loss_dict.items():
        metrics[f"{prefix}{loss_name}"] = loss_value
def set_pack_losses_with_val_metrics(config: Config) -> bool:
    """
    Determine from the config whether validation losses should be packed into validation metrics.

    If the ``pack_losses_with_val_metrics`` key is absent or not a bool (``narrow_dict_type``
    raises ``ValueError``), this defaults to False.

    Args:
        config (Config): Configuration potentially holding the ``pack_losses_with_val_metrics`` flag.

    Returns:
        bool: True if validation losses should be packed into validation metrics.
    """
    try:
        should_pack = narrow_dict_type(config, "pack_losses_with_val_metrics", bool)
    except ValueError:
        # Key missing or wrong type: fall back to the default of not packing losses
        should_pack = False
    if should_pack:
        log(INFO, "As specified in the config, all validation losses will be packed into validation metrics")
    return should_pack
def move_data_to_device(data: T, device: torch.device) -> T:
    """
    Moves data to the target device.

    Args:
        data (T): The data to move to self.device. Can be a ``TorchInputType`` or a
            ``TorchTargetType``
        device (torch.device): Device indicator for where to send the model, batches, labels etc.
            Often 'cpu' or 'cuda'

    Raises:
        TypeError: Raised if data is not one of the types specified by ``TorchInputType`` or
            ``TorchTargetType``

    Returns:
        T: The data argument except now it's been moved to ``self.device``
    """
    # Currently we expect both inputs and targets to be either tensors
    # or dictionaries of tensors
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {name: tensor.to(device) for name, tensor in data.items()}
    raise TypeError(
        "data must be of type torch.Tensor or dict[str, torch.Tensor]. If definition of TorchInputType or "
        "TorchTargetType has changed this method might need to be updated or split into two."
    )
def check_if_batch_is_empty_and_verify_input(input: TorchInputType) -> bool:
    """
    This function checks whether the provided batch (input) is empty. If the input is a dictionary
    of inputs, it first verifies that the length of all inputs is the same, then checks if they
    are non-empty.

    **NOTE:** This function assumes the input is **BATCH FIRST**

    Args:
        input (TorchInputType): Input batch. input can be of type ``torch.Tensor`` or
            ``dict[str, torch.Tensor]``, and in the latter case, the batch is considered to be
            empty if the dictionary has no entries or all tensors in the dictionary have length
            zero.

    Raises:
        TypeError: Raised if input is not of type ``torch.Tensor`` or
            ``dict[str, torch.Tensor]``.
        ValueError: Raised if input has type ``dict[str, torch.Tensor]`` and not all tensors
            within the dictionary have the same size.

    Returns:
        bool: True if input is an empty batch.
    """
    if isinstance(input, torch.Tensor):
        return len(input) == 0
    if isinstance(input, dict):
        input_iter = iter(input.items())
        # Fix: an empty dict previously raised a bare StopIteration from next(); a dictionary
        # with no entries carries no data, so it is treated as an empty batch instead.
        first_item = next(input_iter, None)
        if first_item is None:
            return True
        _, first_val = first_item
        first_val_len = len(first_val)
        # All remaining tensors must match the length of the first one (batch-first assumption)
        if not all(len(val) == first_val_len for _, val in input_iter):
            raise ValueError("Not all tensors in the dictionary have the same size.")
        return first_val_len == 0
    raise TypeError("Input must be of type torch.Tensor or dict[str, torch.Tensor].")
def clone_and_freeze_model(model: nn.Module) -> nn.Module:
    """
    Creates a clone of the model with frozen weights to be used in loss calculations so the
    original model is preserved in its current state.

    Args:
        model (nn.Module): Model to clone and freeze

    Returns:
        nn.Module: Cloned and frozen model
    """
    # Deep copy so the original model and its parameters are left untouched
    frozen_clone = copy.deepcopy(model)
    for parameter in frozen_clone.parameters():
        parameter.requires_grad = False
    # Put the clone in eval mode so layers like dropout/batchnorm behave deterministically
    frozen_clone.eval()
    return frozen_clone
def maybe_progress_bar(iterable: Iterable, display_progress_bar: bool) -> Iterable:
    """
    Used to print progress bars during client training and validation. If
    ``display_progress_bar`` is false, just returns the original input iterable without
    modifying it.

    Args:
        iterable (Iterable): The iterable to wrap
        display_progress_bar (bool): Whether to wrap the iterable in a tqdm progress bar.

    Returns:
        Iterable: an iterator which acts exactly like the original iterable, but prints a
        dynamically updating progress bar every time a value is requested. Or the original
        iterable if ``display_progress_bar`` is False
    """
    if not display_progress_bar:
        return iterable
    # We can use the flwr console handler to format progress bar
    frame = currentframe()
    lineno = 0 if frame is None else getframeinfo(frame).lineno
    record = LogRecord(
        name=LOGGER_NAME,
        pathname=os.path.abspath(os.getcwd()),
        lineno=lineno,
        # Fix: LogRecord's ``args`` parameter is required; omitting it (it was commented out)
        # makes this constructor call raise TypeError. None means no %-interpolation of msg.
        args=None,
        exc_info=None,
        level=INFO,
        msg="{l_bar}{bar}{r_bar}",
    )
    # Renamed from ``format`` to avoid shadowing the builtin
    bar_format = console_handler.format(record)
    # Create a clean looking tqdm instance that matches the flwr logging
    kwargs: Any = {
        "leave": True,
        "ascii": " >=",
        "unit": "steps",
        "dynamic_ncols": True,
        "bar_format": bar_format,
    }
    return tqdm(iterable, **kwargs)
def process_and_check_validation_steps(config: Config, val_loader: DataLoader) -> int | None:
    """
    Extract and sanity-check ``num_validation_steps`` from the config, if present.

    When the key exists, its value must be a positive int and the validation dataloader must be
    non-empty; a warning is logged if the requested number of steps exceeds the dataloader
    length, since datapoints may then be evaluated more than once.

    Args:
        config (Config): Configuration possibly containing ``num_validation_steps``.
        val_loader (DataLoader): Validation dataloader whose length is checked against the
            requested number of steps.

    Returns:
        int | None: The configured number of validation steps, or None when the key is absent.
    """
    if "num_validation_steps" not in config:
        return None
    log(
        INFO,
        "num_validation_steps specified in config. Only a subset of batches will be processed from the validation "
        "set during evaluation. If num_validation_steps is greater than the number of batches in the validation "
        "dataloader, datapoints may be evaluated twice",
    )
    num_validation_steps = narrow_dict_type(config, "num_validation_steps", int)
    assert num_validation_steps > 0, "num_validation_steps must not be 0"
    val_dataloader_len = len(val_loader)
    assert val_dataloader_len > 0, "Dataloader must have length greater than 0."
    if num_validation_steps > val_dataloader_len:
        log(
            WARNING,
            f"num_validation_steps: {num_validation_steps} is larger than the length of the "
            f"validation dataloader: {val_dataloader_len}",
        )
    return num_validation_steps