from abc import ABC, abstractmethod
from math import ceil
from fl4health.privacy.moments_accountant import (
FixedSamplingWithoutReplacement,
MomentsAccountant,
PoissonSampling,
SamplingStrategy,
)
[docs]
class FlInstanceLevelAccountant:
"""
This accountant should be used when applying FL and measuring instance-level privacy
NOTE: This class assumes that all sampling is done via Poisson sampling (client and data point level).
Further it assumes that the sampling ratio of clients and noise multiplier are fixed throughout training
"""
[docs]
def __init__(
self,
client_sampling_rate: float,
noise_multiplier: float,
epochs_per_round: int,
client_batch_sizes: list[int],
client_dataset_sizes: list[int],
moment_orders: list[float] | None = None,
) -> None:
"""
client_sampling_rate: probability that each client will be included in a round
noise_multiplier: multiplier of noise std. dev. on clipping bound
epochs_per_round: number of epochs each client will complete per server round
client_batch_sizes: batch size per client, if a single value it is assumed to be constant across clients
client_dataset_sizes: size of full dataset on a client, if a single value it is assumed to be constant
across clients.
"""
self.noise_multiplier = noise_multiplier
self.epochs_per_round = epochs_per_round
assert len(client_batch_sizes) == len(client_dataset_sizes)
self.num_batches_per_client = self._calculate_num_batches(client_batch_sizes, client_dataset_sizes)
client_batch_ratios = self._calculate_batch_ratios(client_batch_sizes, client_dataset_sizes)
self.sampling_strategies_per_client = [
PoissonSampling(client_sampling_rate * client_batch_ratio) for client_batch_ratio in client_batch_ratios
]
self.accountant = MomentsAccountant(moment_orders)
def _calculate_batch_ratios(self, client_batch_sizes: list[int], client_dataset_sizes: list[int]) -> list[float]:
return [batch / dataset for batch, dataset in zip(client_batch_sizes, client_dataset_sizes)]
def _calculate_num_batches(self, client_batch_sizes: list[int], client_dataset_sizes: list[int]) -> list[int]:
return [ceil(dataset / batch) for batch, dataset in zip(client_batch_sizes, client_dataset_sizes)]
[docs]
def get_epsilon(self, server_updates: int, delta: float) -> float:
"""server_updates: number of central server updates performed"""
epsilons = []
for num_batch, sampling_strategy in zip(self.num_batches_per_client, self.sampling_strategies_per_client):
# Round up because privacy loss is monotonic wrt total_updates
total_updates = ceil(server_updates * self.epochs_per_round * num_batch)
epsilon = self.accountant.get_epsilon(sampling_strategy, self.noise_multiplier, total_updates, delta)
epsilons.append(epsilon)
return max(epsilons)
[docs]
def get_delta(self, server_updates: int, epsilon: float) -> float:
"""server_updates: number of central server updates performed"""
deltas = []
for num_batch, sampling_strategy in zip(self.num_batches_per_client, self.sampling_strategies_per_client):
# Round up because privacy loss is monotonic wrt total_updates
total_updates = ceil(server_updates * self.epochs_per_round * num_batch)
delta = self.accountant.get_delta(sampling_strategy, self.noise_multiplier, total_updates, epsilon)
deltas.append(delta)
return max(deltas)
[docs]
class ClientLevelAccountant(ABC):
def __init__(self, noise_multiplier: float | list[float], moment_orders: list[float] | None = None) -> None:
self.noise_multiplier = noise_multiplier
self.accountant = MomentsAccountant(moment_orders)
[docs]
@abstractmethod
def get_epsilon(self, server_updates: int | list[int], delta: float) -> float:
pass
[docs]
@abstractmethod
def get_delta(self, server_updates: int | list[int], epsilon: float) -> float:
pass
def _validate_server_updates(self, server_updates: int | list[int]) -> None:
if isinstance(server_updates, list):
assert isinstance(self.noise_multiplier, list)
assert len(server_updates) == len(self.noise_multiplier)
else:
assert isinstance(self.noise_multiplier, float)
class FlClientLevelAccountantPoissonSampling(ClientLevelAccountant):
    """
    Client-level privacy accountant for FL in which clients are selected each round by Poisson sampling.
    """

    def __init__(
        self,
        client_sampling_rate: float | list[float],
        noise_multiplier: float | list[float],
        moment_orders: list[float] | None = None,
    ) -> None:
        """
        client_sampling_rate: probability that each client will be included in a round
        noise_multiplier: multiplier of noise std. dev. on clipping bound
        moment_orders: handed to the base class, which forwards it to MomentsAccountant

        NOTE: client_sampling_rate and noise_multiplier may be lists; they are then treated as consecutive phases
        of training, each run with the corresponding parameters.
        """
        super().__init__(noise_multiplier, moment_orders)
        # One PoissonSampling per training phase when given a list, otherwise a single strategy.
        self.sampling_strategy: SamplingStrategy | list[PoissonSampling] = (
            [PoissonSampling(rate) for rate in client_sampling_rate]
            if isinstance(client_sampling_rate, list)
            else PoissonSampling(client_sampling_rate)
        )

    def get_epsilon(self, server_updates: int | list[int], delta: float) -> float:
        """
        server_updates: number of central server updates performed (a list, one count per training phase)
        delta: target delta of the (epsilon, delta) guarantee
        """
        self._validate_server_updates(server_updates)
        strategy = self.sampling_strategy
        multiplier = self.noise_multiplier
        return self.accountant.get_epsilon(strategy, multiplier, server_updates, delta)

    def get_delta(self, server_updates: int | list[int], epsilon: float) -> float:
        """
        server_updates: number of central server updates performed (a list, one count per training phase)
        epsilon: target epsilon of the (epsilon, delta) guarantee
        """
        self._validate_server_updates(server_updates)
        strategy = self.sampling_strategy
        multiplier = self.noise_multiplier
        return self.accountant.get_delta(strategy, multiplier, server_updates, epsilon)
class FlClientLevelAccountantFixedSamplingNoReplacement(ClientLevelAccountant):
    """
    Client-level privacy accountant for FL in which a fixed number of clients is sampled each round
    without replacement.
    """

    def __init__(
        self,
        n_total_clients: int,
        n_clients_sampled: int | list[int],
        noise_multiplier: float | list[float],
        moment_orders: list[float] | None = None,
    ) -> None:
        """
        n_total_clients: total number of clients to be sampled from
        n_clients_sampled: number of clients sampled in each round
        noise_multiplier: multiplier of noise std. dev. on clipping bound
        moment_orders: handed to the base class, which forwards it to MomentsAccountant

        NOTE: n_clients_sampled and noise_multiplier may be lists; they are then treated as consecutive phases of
        training, each run with the corresponding parameters. n_total_clients is shared by every phase.
        """
        super().__init__(noise_multiplier, moment_orders)
        # One strategy per training phase when given a list, otherwise a single strategy.
        self.sampling_strategy: SamplingStrategy | list[FixedSamplingWithoutReplacement] = (
            [FixedSamplingWithoutReplacement(n_total_clients, sampled) for sampled in n_clients_sampled]
            if isinstance(n_clients_sampled, list)
            else FixedSamplingWithoutReplacement(n_total_clients, n_clients_sampled)
        )

    def get_epsilon(self, server_updates: int | list[int], delta: float) -> float:
        """
        server_updates: number of central server updates performed (a list, one count per training phase)
        delta: target delta of the (epsilon, delta) guarantee
        """
        self._validate_server_updates(server_updates)
        strategy = self.sampling_strategy
        multiplier = self.noise_multiplier
        return self.accountant.get_epsilon(strategy, multiplier, server_updates, delta)

    def get_delta(self, server_updates: int | list[int], epsilon: float) -> float:
        """
        server_updates: number of central server updates performed (a list, one count per training phase)
        epsilon: target epsilon of the (epsilon, delta) guarantee
        """
        self._validate_server_updates(server_updates)
        strategy = self.sampling_strategy
        multiplier = self.noise_multiplier
        return self.accountant.get_delta(strategy, multiplier, server_updates, epsilon)