Source code for fl4health.utils.privacy_utilities

from logging import INFO, WARNING
from typing import Any

import torch.nn as nn
from flwr.common.logger import log
from opacus import GradSampleModule
from opacus.grad_sample.utils import wrap_model
from opacus.validators import ModuleValidator


def privacy_validate_and_fix_modules(model: nn.Module) -> tuple[nn.Module, bool]:
    """
    This function runs Opacus model validation to ensure that the provided model's layers are compatible with the
    privacy mechanisms in Opacus. The function attempts to use Opacus to replace any incompatible layers if possible.
    For example BatchNormalization layers are not "DP Compliant" and will be replaced by compliant layers such as
    GroupNormalization with this function. Note that this uses the default "fix" functionality in Opacus. For more
    custom options, defining your own setup_opacus_objects function is required.

    Args:
        model (nn.Module): The model to be validated and potentially modified to be Opacus compliant.

    Returns:
        tuple[nn.Module, bool]: Returns a (possibly) modified pytorch model and a boolean indicating whether a
        reinitialization of any optimizers associated with the model will be required. Reinitialization of the
        optimizer parameters is required, for example, when the model layers are modified, yielding a mismatch in
        the optimizer parameters and the new model parameters.

    Raises:
        RuntimeError: If a round of ``ModuleValidator.fix`` leaves the set of reported validation errors unchanged.
            In that case further rounds cannot make progress, so the function fails loudly instead of looping
            forever.
    """
    errors = ModuleValidator.validate(model, strict=False)
    # If we need to make changes to the underlying model, we may need to reinitialize an optimizer
    reinitialize_optimizer = len(errors) > 0

    # The message is identical for every error, so build it once outside of the loops.
    opacus_warning = (
        "Opacus has found layers within your model that do not comply with DP training. "
        "These layers will automatically be replaced with DP compliant layers. "
        "If you would like to perform this replacement yourself, please adjust your model manually."
    )

    # Due to a bug in Opacus, it's possible that we need to run multiple rounds of module validator fix for
    # complex nested models to fully replace all layers within a model (for example, in the Fed-IXI model)
    while errors:
        for error in errors:
            log(WARNING, opacus_warning)
            log(WARNING, f"Opacus error: {error}")

        messages_before_fix = [str(error) for error in errors]
        model = ModuleValidator.fix(model)
        errors = ModuleValidator.validate(model, strict=False)

        # Guard against an infinite loop: if a fix round did not change the reported errors, then the remaining
        # issues are ones that the default Opacus fixers cannot resolve and repeating the fix will not help.
        if errors and [str(error) for error in errors] == messages_before_fix:
            unresolved = "; ".join(messages_before_fix)
            raise RuntimeError(
                "Opacus was unable to automatically fix all non-compliant layers in the provided model. "
                f"Remaining errors: {unresolved}"
            )

    return model, reinitialize_optimizer
def convert_model_to_opacus_model(
    model: nn.Module, grad_sample_mode: str = "hooks", *args: Any, **kwargs: Any
) -> GradSampleModule:
    """
    This function converts a standard pytorch model to an Opacus GradSampleModule, which Opacus uses to perform
    efficient DP-SGD operations. It uses the wrap_model functionality and mimics its defaults.

    Args:
        model (nn.Module): Pytorch model to be converted to an Opacus GradSampleModule
        grad_sample_mode (str, optional): This determines how Opacus performs the conversion under the hood. The
            standard mechanism is indicated by "hooks" but other approaches may be necessary depending on how the
            pytorch module is defined. Defaults to "hooks".
        *args (Any): Additional positional arguments forwarded unchanged to the Opacus wrap_model function.
        **kwargs (Any): Additional keyword arguments forwarded unchanged to the Opacus wrap_model function.

    Returns:
        GradSampleModule: The Opacus wrapped GradSampleModule. If the provided model is already a GradSampleModule,
        it is returned as is.
    """
    if isinstance(model, GradSampleModule):
        log(INFO, f"Provided model is already of type {type(model)}, skipping conversion to Opacus model type")
        return model

    # grad_sample_mode must be passed positionally here. Passing it by keyword ahead of *args means that the first
    # extra positional argument is bound to grad_sample_mode as well, raising a "multiple values" TypeError.
    return wrap_model(model, grad_sample_mode, *args, **kwargs)
def map_model_to_opacus_model(
    model: nn.Module, grad_sample_mode: str = "hooks", *args: Any, **kwargs: Any
) -> GradSampleModule:
    """
    Makes the provided pytorch model "Opacus Compliant" and wraps it as an Opacus GradSampleModule. Any validation
    and layer modifications are performed first through privacy_validate_and_fix_modules. The resulting model is
    then handed to convert_model_to_opacus_model for the conversion.

    Args:
        model (nn.Module): Pytorch model to be converted to an Opacus compliant GradSampleModule
        grad_sample_mode (str, optional): This determines how Opacus performs the conversion under the hood. The
            standard mechanism is indicated by "hooks" but other approaches may be necessary depending on how the
            pytorch module is defined. Defaults to "hooks".
        *args (Any): Additional positional arguments passed through to convert_model_to_opacus_model.
        **kwargs (Any): Additional keyword arguments passed through to convert_model_to_opacus_model.

    Returns:
        GradSampleModule: The Opacus-compliant, wrapped GradSampleModule
    """
    # Only the (possibly) modified model is needed here. The flag indicating whether an optimizer requires
    # reinitialization is intentionally discarded.
    validation_result = privacy_validate_and_fix_modules(model)
    compliant_model = validation_result[0]

    opacus_model = convert_model_to_opacus_model(compliant_model, grad_sample_mode, *args, **kwargs)
    return opacus_model