Source code for fl4health.clients.instance_level_dp_client

from collections.abc import Sequence
from pathlib import Path

import torch
from flwr.common.typing import Config
from opacus import PrivacyEngine

from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
from fl4health.clients.basic_client import BasicClient
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Metric
from fl4health.utils.privacy_utilities import privacy_validate_and_fix_modules


class InstanceLevelDpClient(BasicClient):
    def __init__(
        self,
        data_path: Path,
        metrics: Sequence[Metric],
        device: torch.device,
        loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
        reporters: Sequence[BaseReporter] | None = None,
        progress_bar: bool = False,
        client_name: str | None = None,
    ) -> None:
        """
        Client for instance/record-level differentially private federated averaging.

        Args:
            data_path (Path): Path to the data to be used to load the data for client-side training.
            metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the
                client model.
            device (torch.device): Device indicator for where to send the model, batches, labels etc. Often
                "cpu" or "cuda".
            loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over
                each batch. Defaults to ``LossMeterType.AVERAGE``.
            checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to
                handle both checkpointing and state saving. The module, and its underlying model and state
                checkpointing components, will determine when and how to do checkpointing during client-side
                training. No checkpointing (state or model) is done if not provided. Defaults to None.
            reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the
                client should send data to. Defaults to None.
            progress_bar (bool, optional): Whether or not to display a progress bar during client training and
                validation. Uses ``tqdm``. Defaults to False.
            client_name (str | None, optional): An optional client name that uniquely identifies a client.
                If not passed, a hash is randomly generated. Client state will use this as part of its state
                file name. Defaults to None.
        """
        super().__init__(
            data_path=data_path,
            metrics=metrics,
            device=device,
            loss_meter_type=loss_meter_type,
            checkpoint_and_state_module=checkpoint_and_state_module,
            reporters=reporters,
            progress_bar=progress_bar,
            client_name=client_name,
        )
        self.clipping_bound: float
        self.noise_multiplier: float

    def setup_client(self, config: Config) -> None:
        """
        Performs the same flow as ``BasicClient`` to set up a client, but additionally processes two
        configuration variables, ``self.clipping_bound`` and ``self.noise_multiplier``. The last step is to do
        some processing of the model and optimizers with Opacus to make them DP-compatible and to set up the
        privacy engine used for privacy accounting. This is done with the ``setup_opacus_objects`` function.

        Args:
            config (Config): Configuration sent by the server to allow for customization of this function's
                behavior.
        """
        # Ensure that the clipping bound and noise multiplier are present in the config.
        # Set attributes to be used when setting up DP training.
        self.clipping_bound = narrow_dict_type(config, "clipping_bound", float)
        self.noise_multiplier = narrow_dict_type(config, "noise_multiplier", float)

        # Do basic client setup
        super().setup_client(config)

        # Configure DP training
        self.setup_opacus_objects(config)

    def setup_opacus_objects(self, config: Config) -> None:
        """
        Validates, and potentially fixes, the client's PyTorch model to be compatible with Opacus and its
        privacy mechanisms, then sets up the Opacus privacy engine using the model, optimizer, dataloaders,
        etc. for DP training.

        Args:
            config (Config): Configuration sent by the server to allow for customization of this function's
                behavior.
        """
        # Validate that the model layers are compatible with the privacy mechanisms in Opacus and try to
        # replace the layers with compatible ones if necessary.
        self.model, reinitialize_optimizer = privacy_validate_and_fix_modules(self.model)

        # If we have fixed the model by swapping out layers (and therefore parameters), we need to update the
        # optimizer parameters to coincide with this fixed model. NOTE: This is not done in make_private!
        if reinitialize_optimizer:
            self.set_optimizer(config)

        # Create DP training objects. NOTE: Opacus' make_private is NOT idempotent.
        privacy_engine = PrivacyEngine()
        self.model, optimizer, self.train_loader = privacy_engine.make_private(
            module=self.model,
            optimizer=self.optimizers["global"],
            data_loader=self.train_loader,
            noise_multiplier=self.noise_multiplier,
            max_grad_norm=self.clipping_bound,
            clipping="flat",
        )
        self.optimizers = {"global": optimizer}
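
For orientation, here is a minimal, hypothetical usage sketch of this client. As with ``BasicClient``, the
data loader, model, optimizer, and criterion hooks are expected to be supplied by a subclass; the subclass
name, dataset path, server address, and ``Accuracy`` metric below are illustrative stand-ins, not anything
defined in this module. It also assumes ``BasicClient`` implements Flower's ``NumPyClient`` interface.

# Hypothetical usage sketch; hook bodies elided. Assumes BasicClient's overridable hooks
# (get_data_loaders, get_model, get_optimizer, get_criterion) and Flower's NumPyClient interface.
from pathlib import Path

import flwr as fl
import torch

from fl4health.clients.instance_level_dp_client import InstanceLevelDpClient
from fl4health.utils.metrics import Accuracy


class MnistDpClient(InstanceLevelDpClient):
    ...  # override get_data_loaders, get_model, get_optimizer, get_criterion here


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
client = MnistDpClient(data_path=Path("examples/datasets/mnist_data/"), metrics=[Accuracy()], device=device)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())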
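
Because ``setup_client`` narrows ``clipping_bound`` and ``noise_multiplier`` out of the server-supplied
config with ``narrow_dict_type``, the server's per-round fit config must carry those keys as floats. The
sketch below is illustrative only: the values are arbitrary, and the ``current_server_round`` key is an
assumption about what ``BasicClient`` otherwise expects, not something this module defines.

# Hypothetical server-side fit config supplying the two DP keys this client narrows each round.
from flwr.common.typing import Config


def fit_config(current_round: int) -> Config:
    return {
        "current_server_round": current_round,  # assumed BasicClient key; illustrative
        "clipping_bound": 1.0,  # flat per-sample gradient clipping norm
        "noise_multiplier": 1.1,  # Gaussian noise scale relative to the clipping bound
    }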
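
Since ``setup_opacus_objects`` defers the heavy lifting to Opacus, the self-contained sketch below (toy model
and data, all names illustrative) shows what the ``make_private`` call returns: the model wrapped in a
``GradSampleModule`` that computes per-sample gradients, the optimizer wrapped in a ``DPOptimizer`` that
applies flat clipping and Gaussian noise, and a data loader that switches to Poisson sampling, which the
privacy accountant assumes.

# Standalone sketch of Opacus' make_private with a toy model and random data.
import torch
from opacus import PrivacyEngine
from torch.utils.data import DataLoader, TensorDataset

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=8)

engine = PrivacyEngine()
model, optimizer, loader = engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=loader,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
    clipping="flat",
)
# Calling make_private again on these already-wrapped objects would wrap them a second
# time, which is why the method above notes that make_private is NOT idempotent.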