    def __init__(
        self,
        data_path: Path,
        metrics: Sequence[Metric],
        device: torch.device,
        loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
        reporters: Sequence[BaseReporter] | None = None,
        progress_bar: bool = False,
        client_name: str | None = None,
    ) -> None:
        """
        Client that clips updates being sent to the server, where noise is added. Used to obtain client-level
        differential privacy in the FL setting.

        Args:
            data_path (Path): path to the data to be used to load the data for client-side training
            metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client
                model
            device (torch.device): Device indicator for where to send the model, batches, labels etc. Often "cpu" or
                "cuda"
            loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over
                each batch. Defaults to ``LossMeterType.AVERAGE``.
            checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle
                both checkpointing and state saving. The module, and its underlying model and state checkpointing
                components will determine when and how to do checkpointing during client-side training.
                No checkpointing (state or model) is done if not provided. Defaults to None.
            reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client
                should send data to. Defaults to None.
            progress_bar (bool, optional): Whether or not to display a progress bar during client training and
                validation. Uses ``tqdm``. Defaults to False.
            client_name (str | None, optional): An optional client name that uniquely identifies a client.
                If not passed, a hash is randomly generated. Client state will use this as part of its state file
                name. Defaults to None.
        """
        super().__init__(
            data_path=data_path,
            metrics=metrics,
            device=device,
            loss_meter_type=loss_meter_type,
            checkpoint_and_state_module=checkpoint_and_state_module,
            reporters=reporters,
            progress_bar=progress_bar,
            client_name=client_name,
        )
        self.parameter_exchanger: FullParameterExchangerWithPacking[float]
        self.clipping_bound: float | None = None
        self.adaptive_clipping: bool | None = None
    def calculate_parameters_norm(self, parameters: NDArrays) -> float:
        """
        Given a set of parameters, compute the l2-norm of the parameters. This is a matrix (Frobenius) norm: the
        square root of the sum of the squares of all of the weights.

        Args:
            parameters (NDArrays): Tensor to measure with the norm

        Returns:
            float: l2-norm over all values in the NDArrays
        """
        layer_inner_products = [pow(linalg.norm(layer_weights), 2) for layer_weights in parameters]
        # network Frobenius norm
        return pow(sum(layer_inner_products), 0.5)
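# Illustration only (not part of the class): a minimal numpy sketch checking that the layer-wise computation used in
# calculate_parameters_norm matches the l2-norm of all weights flattened into a single vector. The layer shapes are
# arbitrary assumptions.
import numpy as np

example_parameters = [np.random.randn(3, 4), np.random.randn(4)]
layer_inner_products = [np.linalg.norm(layer) ** 2 for layer in example_parameters]
norm_by_layers = sum(layer_inner_products) ** 0.5
norm_flattened = np.linalg.norm(np.concatenate([layer.ravel() for layer in example_parameters]))
assert np.isclose(norm_by_layers, norm_flattened)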
    def clip_parameters(self, parameters: NDArrays) -> tuple[NDArrays, float]:
        """
        Performs "flat clipping" on the parameters according to

        .. math::
            \\text{parameters} \\cdot \\min \\left(1, \\frac{C}{\\Vert \\text{parameters} \\Vert_2} \\right)

        Args:
            parameters (NDArrays): Parameters to clip

        Returns:
            tuple[NDArrays, float]: Clipped parameters and the associated clipping bit indicating whether the norm
                was below ``self.clipping_bound``. If ``self.adaptive_clipping`` is false, this bit is always 0.0
        """
        assert self.clipping_bound is not None
        assert self.adaptive_clipping is not None
        # performs flat clipping (i.e. parameters * min(1, C/||parameters||_2))
        network_frobenius_norm = self.calculate_parameters_norm(parameters)
        log(INFO, f"Update norm: {network_frobenius_norm}, Clipping Bound: {self.clipping_bound}")
        if network_frobenius_norm <= self.clipping_bound:
            # if we're not adaptively clipping then don't send the true clipping bit info, as this would potentially
            # leak information
            clipping_bit = 1.0 if self.adaptive_clipping else 0.0
            return parameters, clipping_bit
        clip_scalar = min(1.0, self.clipping_bound / network_frobenius_norm)
        # parameters and clipping bit
        return [layer_weights * clip_scalar for layer_weights in parameters], 0.0
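# Illustration only (not part of the class): flat clipping scales every layer by the same factor
# min(1, C / ||parameters||_2), so the clipped update has norm at most C while preserving its direction. The
# clipping bound value below is an arbitrary assumption.
import numpy as np

clipping_bound = 1.0
update = [np.random.randn(3, 4), np.random.randn(4)]
update_norm = np.sqrt(sum(np.linalg.norm(layer) ** 2 for layer in update))
clip_scalar = min(1.0, clipping_bound / update_norm)
clipped_update = [layer * clip_scalar for layer in update]
clipped_norm = np.sqrt(sum(np.linalg.norm(layer) ** 2 for layer in clipped_update))
assert clipped_norm <= clipping_bound + 1e-8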
    def compute_weight_update_and_clip(self, parameters: NDArrays) -> tuple[NDArrays, float]:
        """
        Compute the weight update (i.e. new weights - old weights) and clip it according to ``self.clipping_bound``.

        Args:
            parameters (NDArrays): Updated parameters from which to compute the delta and clip thereafter

        Returns:
            tuple[NDArrays, float]: Clipped weight update (weight deltas) and the associated clipping bit
        """
        assert self.initial_weights is not None
        assert len(parameters) == len(self.initial_weights)
        weight_update: NDArrays = [
            new_layer_weights - old_layer_weights
            for old_layer_weights, new_layer_weights in zip(self.initial_weights, parameters)
        ]
        # return clipped parameters and clipping bit
        return self.clip_parameters(weight_update)
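# Illustration only (not part of the class): the quantity that gets clipped and sent to the server is the weight
# delta (new weights minus the weights received at the start of the round), not the weights themselves. The values
# below are arbitrary assumptions.
import numpy as np

initial_weights = [np.zeros((2, 2)), np.zeros(2)]
updated_weights = [np.ones((2, 2)), np.full(2, 0.5)]
weight_update = [new - old for old, new in zip(initial_weights, updated_weights)]
# weight_update would then be passed through clip_parameters before being packed for the server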
    def get_parameters(self, config: Config) -> NDArrays:
        """
        This function performs clipping through ``compute_weight_update_and_clip`` and stores the clipping bit as
        the last entry in the NDArrays.
        """
        if not self.initialized:
            log(INFO, "Setting up client and providing full model parameters to the server for initialization")
            # If initialized == False, the server is requesting model parameters from which to initialize all other
            # clients. As such, get_parameters is being called before fit or evaluate, so we must call
            # setup_client first.
            self.setup_client(config)
            # Need all parameters even if normally exchanging partial
            return FullParameterExchanger().push_parameters(self.model, config=config)
        else:
            assert self.model is not None and self.parameter_exchanger is not None
            model_weights = self.parameter_exchanger.push_parameters(self.model, config=config)
            clipped_weight_update, clipping_bit = self.compute_weight_update_and_clip(model_weights)
            return self.parameter_exchanger.pack_parameters(clipped_weight_update, clipping_bit)
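# Illustration only (not part of the class): conceptually, the clipping bit travels to the server as an extra scalar
# appended after the clipped weight update, and the receiving side splits it back off. This is a simplified sketch
# of that packing idea, not the FullParameterExchangerWithPacking implementation.
import numpy as np

clipped_weight_update = [np.random.randn(3, 4), np.random.randn(4)]
clipping_bit = 0.0
packed = clipped_weight_update + [np.array([clipping_bit])]
# On the receiving side, the last entry is peeled off to recover the scalar and the weight update
recovered_update, recovered_bit = packed[:-1], float(packed[-1][0])
assert recovered_bit == clipping_bit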
    def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
        """
        This function assumes that the parameters being passed contain model parameters followed by the last entry
        of the list being the new clipping bound. They are unpacked for the clients to use in training. If it is
        called in the first fitting round, we assume the full model is being initialized and use the
        ``FullParameterExchanger()`` to set all model weights.

        Args:
            parameters (NDArrays): Parameters holding the model state to be injected into the relevant client model,
                along with the clipping bound.
            config (Config): The config is sent by the FL server to allow for customization in the function if
                desired.
            fitting_round (bool): Boolean that indicates whether the current federated learning round is a fitting
                round or an evaluation round. This is used to help determine which parameter exchange should be used
                for pulling parameters. A full parameter exchanger is used if the current federated learning round
                is the very first fitting round.
        """
        assert self.model is not None and self.parameter_exchanger is not None
        # The last entry in the parameters list is assumed to be a clipping bound (even if we're evaluating)
        server_model_parameters, clipping_bound = self.parameter_exchanger.unpack_parameters(parameters)
        self.clipping_bound = clipping_bound
        current_server_round = narrow_dict_type(config, "current_server_round", int)
        if current_server_round == 1 and fitting_round:
            # Initialize all model weights as this is the first time things have been set
            self.initialize_all_model_weights(server_model_parameters, config)
            # Extract only the initial weights that we care about clipping and exchanging
            self.initial_weights = self.parameter_exchanger.push_parameters(self.model, config=config)
        else:
            # Store the starting parameters without clipping bound before client optimization steps
            self.initial_weights = server_model_parameters
            # Inject the server model parameters into the client model
            self.parameter_exchanger.pull_parameters(server_model_parameters, self.model, config)