Source code for fl4health.client_managers.poisson_sampling_manager
from logging import WARNING
import numpy as np
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager
[docs]
class PoissonSamplingClientManager(BaseFractionSamplingManager):
"""Overrides the Simple Client Manager to Provide Poisson Sampling for Clients rather than
fixed without replacement sampling"""
def _poisson_sample(self, sampling_probability: float, available_cids: list[str]) -> list[str]:
poisson_trials = np.random.binomial(1, sampling_probability, len(available_cids))
poisson_mask = poisson_trials.astype(dtype=bool)
return list(np.array(available_cids)[poisson_mask])
[docs]
def sample_fraction(
self,
sample_fraction: float,
min_num_clients: int | None = None,
criterion: Criterion | None = None,
) -> list[ClientProxy]:
"""Poisson Sampling of Flower ClientProxy instances with a probability determine by sample_fraction."""
available_cids = self.wait_and_filter(min_num_clients, criterion)
n_available_cids = len(available_cids)
expected_clients_selected = sample_fraction * n_available_cids
if expected_clients_selected < 1:
log(
WARNING,
f"Sample fraction of {round(sample_fraction, 3)} from {n_available_cids} clients results "
f"in expected value of {round(expected_clients_selected, 3)} selected.",
)
sampled_cids = self._poisson_sample(sample_fraction, available_cids)
return [self.clients[cid] for cid in sampled_cids]