Quickstart
Installation
First, we need to install the fl4health package. The easiest and recommended way to do this is via pip:

```bash
pip install fl4health
```
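To confirm that the package was installed into the active environment, you can use a quick check such as the one below. It is a minimal sketch that relies only on the standard library (`importlib.metadata`). The helper name `installed_version` is hypothetical and is not part of fl4health.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version string of a distribution, or None if it is missing."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version("fl4health")
    if found is None:
        print("fl4health is not installed in this environment; run `pip install fl4health`.")
    else:
        print(f"fl4health {found} is installed.")
```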
A simple FL task
With federated learning, the model is trained collaboratively by a set of distributed nodes called clients. This collaboration is facilitated by another node, namely the server. To set up an FL task, we need to define our Client as well as our Server in the scripts client.py and server.py, respectively.
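The toy simulation below makes the client/server division of labour concrete. It mimics federated rounds without any networking:

- the server broadcasts a global model;
- each client improves that model on its own private data;
- the server averages the returned models, weighting each client by the size of its dataset.

This is a pure-Python sketch with a one-parameter linear model. The helpers `local_update` and `run_rounds` are hypothetical. The sketch is not how fl4health or Flower are implemented; it only illustrates the idea.

```python
def local_update(w: float, data: list[tuple[float, float]], lr: float = 0.01, epochs: int = 5) -> float:
    """Run a few epochs of gradient descent on the squared error of y ~ w * x, using only this client's data."""
    for _ in range(epochs):
        # Gradient of mean((w * x - y) ** 2) with respect to w
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w


def run_rounds(client_datasets: list[list[tuple[float, float]]], num_rounds: int = 20) -> float:
    """Server loop: broadcast the global weight, collect local updates, average them weighted by dataset size."""
    global_w = 0.0
    total = sum(len(d) for d in client_datasets)
    for _ in range(num_rounds):
        # Each client starts from the current global model; raw data never leaves the client
        local_ws = [local_update(global_w, data) for data in client_datasets]
        # The server only ever sees the returned model weights
        global_w = sum(len(d) * w for d, w in zip(client_datasets, local_ws)) / total
    return global_w


if __name__ == "__main__":
    # Two clients whose private data both follow y = 3x
    client_a = [(1.0, 3.0), (2.0, 6.0)]
    client_b = [(3.0, 9.0), (4.0, 12.0), (5.0, 15.0)]
    print(run_rounds([client_a, client_b]))  # approaches 3.0
```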
client.py
```python
import argparse
from pathlib import Path

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.cnn_model import Net
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.load_data import load_cifar10_data
from fl4health.utils.metrics import Accuracy


class CifarClient(BasicClient):
    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
        # Load (or download) CIFAR-10 from data_path; the third return value is not needed here
        train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size=64)
        return train_loader, val_loader

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()

    def get_optimizer(self, config: Config) -> Optimizer:
        return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

    def get_model(self, config: Config) -> nn.Module:
        return Net().to(self.device)


def main(dataset_path: str) -> None:
    client = CifarClient(data_path=Path(dataset_path), metrics=[Accuracy("accuracy")], device=torch.device("cpu"))
    # Connect to the server defined in server.py and take part in FL until the server finishes
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
    client.shutdown()


if __name__ == "__main__":
    # The dataset path is passed as the single positional command-line argument
    parser = argparse.ArgumentParser(description="FL client")
    parser.add_argument("dataset_path", type=str, help="Path to the local dataset")
    args = parser.parse_args()
    main(args.dataset_path)
```
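The client above only overrides four factory methods and leaves it to the base class to decide when to call them. Note that `get_optimizer` reads `self.model`, so the model must already exist when the optimizer is built.

The toy class below illustrates this template-method style of design. It assumes that setup happens once and builds the model before the optimizer. `ToyBasicClient`, `DemoClient` and `setup` are hypothetical names written only for illustration. They are not fl4health's `BasicClient`, whose actual lifecycle and internal method names may differ.

```python
from abc import ABC, abstractmethod
from typing import Any

Config = dict[str, Any]  # stand-in for flwr.common.typing.Config


class ToyBasicClient(ABC):
    """Hypothetical base class: subclasses supply components, the base class wires them together."""

    def __init__(self) -> None:
        self.initialized = False

    @abstractmethod
    def get_model(self, config: Config) -> Any: ...

    @abstractmethod
    def get_data_loaders(self, config: Config) -> tuple[Any, Any]: ...

    @abstractmethod
    def get_optimizer(self, config: Config) -> Any: ...

    @abstractmethod
    def get_criterion(self, config: Config) -> Any: ...

    def setup(self, config: Config) -> None:
        # The model is built first because the optimizer needs its parameters
        self.model = self.get_model(config)
        self.train_loader, self.val_loader = self.get_data_loaders(config)
        self.optimizer = self.get_optimizer(config)
        self.criterion = self.get_criterion(config)
        self.initialized = True


class DemoClient(ToyBasicClient):
    def get_model(self, config: Config) -> Any:
        return "model"

    def get_data_loaders(self, config: Config) -> tuple[Any, Any]:
        return ["train batch"] * config["batch_size"], ["val batch"]

    def get_optimizer(self, config: Config) -> Any:
        # Safe to use self.model here because setup() created it first
        return f"SGD over {self.model}"

    def get_criterion(self, config: Config) -> Any:
        return "cross-entropy"
```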
server.py
```python
from functools import partial

import flwr as fl
from flwr.common.typing import Config
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.models.cnn_model import Net
from fl4health.servers.base_server import FlServer
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters


def fit_config(current_server_round: int) -> Config:
    # Configuration sent to the clients at the start of every round
    return {"local_epochs": 3, "batch_size": 64, "current_server_round": current_server_round}


def main() -> None:
    fit_config_fn = partial(fit_config)
    # Initial global model whose weights are sent to the clients in the first round
    model = Net()
    strategy = FedAvg(
        min_fit_clients=2,
        min_evaluate_clients=2,
        # Server waits for min_available_clients before starting FL rounds
        min_available_clients=2,
        on_fit_config_fn=fit_config_fn,
        # We use the same fit config function, as nothing changes for eval
        on_evaluate_config_fn=fit_config_fn,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=get_all_model_parameters(model),
    )
    server = FlServer(SimpleClientManager(), {}, strategy)
    fl.server.start_server(
        server=server,
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=20),
    )


if __name__ == "__main__":
    main()
```
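The FedAvg strategy configured above combines the clients' returned weights into a new global model. It does this by averaging them, with each client weighted by the number of training examples it reports. A minimal pure-Python sketch of that aggregation step follows. It operates on flat lists of floats rather than Flower's `Parameters` objects, and the function name `fedavg_aggregate` is hypothetical.

```python
def fedavg_aggregate(results: list[tuple[list[float], int]]) -> list[float]:
    """Example-weighted average of client parameter vectors.

    Each entry of results is (client_parameters, num_examples).
    """
    total_examples = sum(n for _, n in results)
    if total_examples == 0:
        raise ValueError("at least one client must report training examples")
    num_params = len(results[0][0])
    aggregated = [0.0] * num_params
    for params, n in results:
        if len(params) != num_params:
            raise ValueError("all clients must return the same number of parameters")
        for i, p in enumerate(params):
            aggregated[i] += p * n / total_examples
    return aggregated


if __name__ == "__main__":
    # Client 1 trained on 100 examples, client 2 on 300, so client 2 counts three times as much
    print(fedavg_aggregate([([1.0, 2.0], 100), ([5.0, 6.0], 300)]))  # [4.0, 5.0]
```

Flower applies the same weighting layer by layer to NumPy arrays. Flat float lists are used here only to keep the sketch free of dependencies.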
Running the FL task
Now that we have our server and clients defined, we can run the FL system!
Starting Server
The next step is to start the server. From the directory that contains the examples package, run

```bash
python -m examples.basic_example.server
```
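Before launching the clients, you may want to confirm that something is accepting connections on the server port. The helper below is a standard-library sketch; `is_port_open` is a hypothetical name and is not part of fl4health or Flower. It only checks that a TCP connection succeeds on port 8080, the port used in server.py. It says nothing about whether the FL server is ready for clients, so the “FL starting” log line remains the authoritative signal.

```python
import socket


def is_port_open(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to (host, port) succeeds within timeout seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


if __name__ == "__main__":
    # server.py listens on 0.0.0.0:8080, which is reachable locally via 127.0.0.1
    print(is_port_open("127.0.0.1", 8080))
```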
Starting Clients
Once the server has started and logged “FL starting”, start the two clients, each in its own terminal (remember to activate your environment in each one), by running

```bash
python -m examples.basic_example.client /path/to/data
```
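Opening one terminal per process by hand is the approach described above. As an alternative, the server and clients could be started from a single Python script with `subprocess`. The sketch below assumes it is run from the same directory as the commands above, with the environment already active. `build_commands` and `launch` are hypothetical helpers, and the fixed startup delay is a crude stand-in for watching the server log for “FL starting”.

```python
import subprocess
import sys
import time


def build_commands(data_path: str, num_clients: int = 2) -> list[list[str]]:
    """Commands for one server followed by num_clients clients, using the current interpreter."""
    server = [sys.executable, "-m", "examples.basic_example.server"]
    clients = [[sys.executable, "-m", "examples.basic_example.client", data_path] for _ in range(num_clients)]
    return [server] + clients


def launch(data_path: str, num_clients: int = 2, server_startup_delay: float = 10.0) -> int:
    """Start the server, give it time to come up, start the clients, then wait for everything to exit."""
    server_cmd, *client_cmds = build_commands(data_path, num_clients)
    server = subprocess.Popen(server_cmd)
    # Crude stand-in for waiting until the server logs "FL starting"
    time.sleep(server_startup_delay)
    clients = [subprocess.Popen(cmd) for cmd in client_cmds]
    for proc in clients:
        proc.wait()
    # Return the server's exit code once federated training has finished
    return server.wait()
```

Calling `launch("/path/to/data")` would start one server and two clients and block until all three processes exit.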
NOTE: The dataset_path argument serves two purposes, depending on whether the dataset already exists locally. If the dataset is already present at the specified path, it is loaded from there. Otherwise, it is automatically downloaded to that path and then used in the run.
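The load-or-download behaviour of dataset_path follows a common pattern: fetch the data only when it is not already present. A minimal sketch follows. `ensure_dataset` and the `download` callable are hypothetical stand-ins. Treating any non-empty directory as "present" is a simplification; a real dataset loader would typically verify the expected files, for example by checksum.

```python
import tempfile
from pathlib import Path
from typing import Callable


def ensure_dataset(dataset_path: Path, download: Callable[[Path], None]) -> Path:
    """Download into dataset_path only if it is missing or empty, then return it."""
    already_present = dataset_path.is_dir() and any(dataset_path.iterdir())
    if not already_present:
        dataset_path.mkdir(parents=True, exist_ok=True)
        download(dataset_path)
    return dataset_path


if __name__ == "__main__":
    def fake_download(target: Path) -> None:
        # Stand-in for a real download: write a placeholder file
        (target / "placeholder.bin").write_bytes(b"\x00")
        print(f"downloaded into {target}")

    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "cifar10"
        ensure_dataset(path, fake_download)  # downloads
        ensure_dataset(path, fake_download)  # reuses the existing files, no second download
```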
After both clients have started, federated learning should commence. The server does not begin the first round until min_available_clients=2 clients have connected.
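The toy registry below reproduces the server's behaviour of waiting until enough clients are connected (min_available_clients=2 in server.py), using a `threading.Condition`. `ToyClientRegistry` and its methods are hypothetical. They illustrate the waiting logic only and are not Flower's actual client manager.

```python
import threading


class ToyClientRegistry:
    """Tracks connected clients and lets the server block until enough have joined."""

    def __init__(self) -> None:
        self._clients: set[str] = set()
        self._cv = threading.Condition()

    def register(self, client_id: str) -> None:
        with self._cv:
            self._clients.add(client_id)
            # Wake up any thread blocked in wait_for()
            self._cv.notify_all()

    def wait_for(self, num_clients: int, timeout: float) -> bool:
        """Block until num_clients are registered; return False if timeout seconds pass first."""
        with self._cv:
            return self._cv.wait_for(lambda: len(self._clients) >= num_clients, timeout=timeout)


if __name__ == "__main__":
    registry = ToyClientRegistry()
    # Simulate two clients connecting from separate threads, like two terminals
    for name in ("client-1", "client-2"):
        threading.Thread(target=registry.register, args=(name,)).start()
    print(registry.wait_for(num_clients=2, timeout=5.0))  # True once both have registered
```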