Introduction

Welcome to AI Pocket References: Federated Learning (FL) Collection. This compilation encapsulates core concepts as well as advanced methods for implementing FL — one of the main techniques for building AI models in a decentralized setting.

Be sure to check out our other collections of AI Pocket References!

Core Concepts in Federated Learning


Reading time: 1 min

In this chapter, we'll introduce several of the fundamental concepts for understanding federated learning (FL). We begin by discussing some of the different flavors of FL and why they constitute distinct subdomains each with their own applications, challenges, and research literature. Next, we briefly discuss three of the most important building blocks associated with FL pipelines: Clients, Servers, and Aggregation.



The Different Flavors of Federated Learning


Reading time: 6 min

Machine learning (ML) models are most commonly trained on a centralized pool of data, meaning that all training data is accessible to a single training process. Federated learning (FL) is used to train ML models on decentralized data, such that data is compartmentalized. The sites at which the data is held and trained are typically referred to as clients. Training data is most often decentralized when it cannot or should not be moved from its location. This might be the case for various reasons, including privacy regulations, security concerns, or resource constraints. Many industries are subject to strict privacy laws, compliance requirements, or data handling requirements, among other important considerations. As such, data centralization is often infeasible or ill-advised. On the other hand, it is well known that access to larger quantities of representative training data often leads to better ML models.1 Thus, in spite of the potential challenges associated with decentralized training, there is significant incentive to facilitate distributed model training.

There are many different flavors of FL. Covering the full set of variations is beyond the scope of these references. However, this reference will cover a few of the major types considered in practice.

Decentralized Datasets

Horizontal Vs. Vertical FL

One of the primary distinctions in FL methodologies is whether one is aiming to perform Horizontal or Vertical FL. The choice of methodological framework here is primarily driven by the kind of training data that exists and why you are doing FL in the first place.

Horizontal FL: More Data, Same Features

In Horizontal FL, it is assumed that models will be trained on a unified set of features and targets. That is, across the distributed datasets, each training point has the same set of features with the same set of interpretations, pre-processing steps, and ranges of potential values, for example. The goal in Horizontal FL is to facilitate access to additional data points during the training of a model. For more details, see Horizontal FL.

Horizontal FL
Feature spaces are shared between clients, enabling access to more unique training data points.

Vertical FL: More Features, Same Generators

While Horizontal FL is concerned with accessing more data points during training, Vertical FL aims to add additional predictive features to improve model predictions. In Vertical FL, there is a shared target or set of targets to be predicted across distributed datasets, and it is assumed that all datasets share a non-empty intersection of "data generators" that can be "linked" in some way. For example, the "data generators" might be individual customers of different retailers. Two retailers might want to collaboratively train a customer segmentation model to improve predictions for their shared customer base. Each retailer has unique information about these customers from its own interactions that, when combined, might improve prediction performance.

Vertical FL
"Data generators" are shared between clients with unique features.

To produce a useful distributed training dataset in Vertical FL, datasets are privately "aligned" such that only the intersection of "data generators" are considered in training. In most cases, the datasets are ordered to ensure that disparate features are meaningfully aligned by the underlying generator. Depending on the properties of the datasets, fewer individual data points may be available for training, but hopefully they have been enriched with additional important features. For more details, see Vertical FL.
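
As a toy illustration of this alignment step, the sketch below intersects two clients' records on a shared customer identifier. The identifiers and feature names are hypothetical, and real Vertical FL systems perform this step with privacy-preserving techniques, such as private set intersection, rather than in the clear.

```python
# Toy illustration of aligning two clients' datasets by a shared ID.
# Real Vertical FL uses private set intersection (PSI) so that raw IDs and
# non-intersecting records are never revealed; this sketch skips privacy.

# Hypothetical data: each retailer keys its own features by customer ID.
retailer_a = {"cust_1": {"spend": 120.0}, "cust_2": {"spend": 40.0}, "cust_3": {"spend": 87.5}}
retailer_b = {"cust_2": {"visits": 7}, "cust_3": {"visits": 2}, "cust_4": {"visits": 11}}

# Keep only the shared "data generators" (customers present at both retailers),
# ordered consistently so that rows line up by the underlying generator.
shared_ids = sorted(set(retailer_a) & set(retailer_b))

# Each client retains only its own feature columns; rows are aligned by ID.
aligned_a = [retailer_a[i] for i in shared_ids]  # features held by retailer A
aligned_b = [retailer_b[i] for i in shared_ids]  # features held by retailer B

print(shared_ids)  # ['cust_2', 'cust_3']
print(aligned_a)   # [{'spend': 40.0}, {'spend': 87.5}]
print(aligned_b)   # [{'visits': 7}, {'visits': 2}]
```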

Cross-Device Vs. Cross-Silo FL

An important distinction between standard ML training and decentralized model training is the presence of multiple, and potentially diverse, compute environments. Leaving aside settings with the possibility of significant resource disparities across data hosting environments, there are still many things to consider that influence the kinds of FL techniques to use. There are two main categories with general, but not firm, separating characteristics: Cross-Silo FL and Cross-Device FL. In the table below, key distinctions between the two types of FL are summarized.

| Type | Cross-Silo | Cross-Device |
|---|---|---|
| # of Participants | Small- to medium-sized pool of clients | Large pool of participants |
| Compute | Moderate to large compute | Limited compute resources |
| Dataset Size | Moderate to large datasets | Typically small datasets |
| Reliability | Stable connection and participation | Potentially unreliable participants |

A quintessential example of a cross-device setting is training a model using data housed on different cell-phones. There are potentially millions of devices participating in training, each with limited computing resources. At any given time, a phone may be switched off or disconnected from the internet. Alternatively, cross-silo settings might arise in training a model between companies or institutions, such as banks or hospitals. They likely have larger datasets at each site and access to more computational resources. There will be fewer participants in training, but they are more likely to reliably contribute to the training system.

Knowing which category of FL one is operating in helps inform design decisions and FL component choices. For example, the model being trained may need to be below a certain size or the memory/compute needs of an FL technique might be prohibitive. A good example of the latter is Ditto, which requires larger compute resources than many other methods.

One Model or a Model Zoo

The final distinction that is highlighted here is whether the model architecture to be trained is the same (homogeneous) across disparate sites or if it differs (heterogeneous). In many settings, the goal is to train a homogeneous model architecture across FL participants. In the context of Horizontal FL, this implies that each client has an identical copy of the architecture with shared feature and label dimensions, as in the figure below.

Homogeneous Architectures
Each client participating in Horizontal FL typically trains the same architecture.

Alternatively, there are FL techniques which aim to federally train collections of heterogeneous architectures across clients.2 That is, each participant in the FL system might be training a different model architecture. Such a setting may arise, for example, if participants would like to benefit from the expanded training data pool offered through Horizontal FL, but want to train their own, proprietary model architecture, rather than a shared model design across all clients. As another example, perhaps certain participants, facing compute constraints, aim to train a model of more manageable size given the resources at their disposal.

Heterogeneous Architectures
Model heterogeneous FL attempts to wrangle a zoo of model architectures across participants.

These pocket references primarily focus on the homogeneous architecture setting. However, there is significant research across each of the different flavors of FL discussed above.



The Role of Clients in Federated Learning


Reading time: 2 min

As discussed in The Different Flavors of Federated Learning, FL is a collection of methods that aim to facilitate training ML models on decentralized training datasets. The entities that house these datasets are often referred to as clients. Any procedures that involve working directly with raw data are typically the responsibility of the clients participating in the FL system. In addition, clients are only privy to their own local datasets and generally receive no raw data from other participants.

Some FL methods consider the use of related public or synthetic data, potentially modeled after local client data. However, there are often caveats to each of these settings. The former setting is restricted by the assumed existence of relevant public data and the level of "relatedness" can have notable implications in the FL process. In the latter setting, data synthesis has privacy implications that might undermine the goal of keeping data separate in the first place.

Because each client is canonically the only participant with access to its local dataset, it is predominantly responsible for model training, through some mechanism, on that data. In Horizontal FL, this often manifests as performing some form of gradient-based optimization targeting a local loss function incorporating local data. In Vertical FL, partial forward passes and gradients are constructed based on information from the partial (local) features held by each client.

Client
Visualization of some assets for FL clients.

The figure above is a simplified illustration of the various resources housed within an FL client. Each of these components needs to be considered to ensure that federated training proceeds smoothly. For example, given the size of the model to be trained and the desired training settings like batch size, will the client have enough memory to perform backpropagation? Will the training iterations complete in a reasonable amount of time? Is the network bandwidth going to be sufficient to facilitate efficient communication with other components of the FL system?

In subsequent chapters, we'll discuss the exact role clients play in FL, and how they interact with other components of the FL system.



Servers and FL Orchestration


Reading time: 3 min

In many FL workflows, a server plays a vital role in orchestrating client behavior, coordinating communication, facilitating information exchange, and synchronizing training results across the clients participating in the FL system. In many settings, for example, the server is responsible for the following (a minimal sketch appears after the list):

  1. Selecting clients to participate in federated training.
  2. Gathering the results of their local training processes.
  3. Combining these results into a single result for further federated training.
  4. Requesting model evaluations.
  5. Monitoring performance and model checkpointing.
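
To make these responsibilities concrete, here is a minimal, framework-agnostic sketch of such an orchestration loop. The `local_train`, `aggregate`, and `evaluate` callables are hypothetical placeholders rather than the API of any particular FL library.

```python
import random

def run_server(clients, initial_weights, num_rounds, clients_per_round,
               local_train, aggregate, evaluate):
    """Minimal orchestration loop; the helper callables are hypothetical placeholders."""
    weights = initial_weights
    best_score = float("-inf")
    for t in range(num_rounds):
        # 1. Select clients to participate in this round.
        selected = random.sample(clients, k=min(clients_per_round, len(clients)))

        # 2. Gather the results of their local training processes.
        #    (In a real deployment these calls run remotely and in parallel.)
        results = [local_train(client, weights) for client in selected]

        # 3. Combine these results into a single result for further training.
        weights = aggregate(results)

        # 4. Request model evaluations from the participating clients.
        scores = [evaluate(client, weights) for client in selected]

        # 5. Monitor performance and checkpoint the best model seen so far.
        mean_score = sum(scores) / len(scores)
        if mean_score > best_score:
            best_score, best_weights = mean_score, weights
    return best_weights
```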

While the server provides significant value through orchestration, it typically bears a reduced computational responsibility. That is, its processes are often less resource intensive compared to those of the FL clients. As such, it can be hosted in environments with lower compute or even collocated with clients. However, there are FL methods that also perform compute intensive procedures on the server-side of the FL system. The trade-offs associated with these methods should play a part in any system design choices.

The figure below provides a simple illustrative example of one role that a server typically plays in a Horizontal FL system. That is, after each client has trained a model on their local training data, they send the model weights to the server to be combined, in some way, into a single new set of weights, which are sent back to the clients for further training.

Exchange of Weights
Among many other roles, an FL server may receive and combine model weights from FL clients.

A fundamental tenet of FL is that raw data never leaves the local repositories of each client. As such, FL servers never receive raw training data from participating clients. However, the exchange of information is not necessarily restricted to model weights. For example, in adaptive forms of FedProx,1 the server is responsible for adjusting the proximal loss weight used in client training in response to a global view of the loss landscape across participating clients. This requires transmitting such adjustments to the clients at the appropriate times.
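
For intuition, the sketch below shows the client-side proximal penalty that FedProx adds to the local loss, with a weight `mu` that an adaptive server may adjust between rounds and transmit to the clients. The weight representation and usage here are assumptions for illustration, not a reference implementation.

```python
import numpy as np

def proximal_penalty(local_weights, global_weights, mu):
    """FedProx-style proximal term: (mu / 2) * ||w_local - w_global||^2.

    `mu` is the proximal weight; in adaptive FedProx the server adjusts it
    between rounds and communicates the new value to the clients.
    """
    diff = np.concatenate([(wl - wg).ravel()
                           for wl, wg in zip(local_weights, global_weights)])
    return 0.5 * mu * float(np.dot(diff, diff))

# Hypothetical usage inside a client's local training step:
# loss = task_loss + proximal_penalty(model_weights, weights_from_server, mu_from_server)
```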

In subsequent chapters, the exact role that the server plays in various forms of FL will be discussed in detail. In each setting, the compute burden of the server may vary and the role it plays may differ quite significantly. When deciding on an FL approach, the role of the server is also an important design consideration. For example, in certain settings, the server may not need to reside on a separate machine from the clients. In certain setups, one of the clients may also play the role of the server. In others, that responsibility may rotate between clients.



Aggregation Strategies


Reading time: 2 min

In FL workflows, servers are responsible for a number of crucial components, as discussed in Servers and FL Orchestration. One of these roles is that of aggregation and synchronization of the results of distributed client training processes. This is most prominent in Horizontal FL, where the server is responsible for executing, among other things, an aggregation strategy.

In most Horizontal FL algorithms, there is a concept of a server round wherein each decentralized client trains a model (or models) using local training data. After local training has concluded, each client sends the model weights back to the server. These model weights are combined into a single set of weights using an aggregation strategy. One of the earliest forms of such a strategy, and still one of the most widely used, is FedAvg.1 In FedAvg, client model weights are combined using a weighted averaging scheme. More details on this strategy can be found in FedAvg.
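Conceptually, the FedAvg aggregation step is just a weighted average of the clients' parameter arrays. Below is a minimal sketch assuming each client's weights arrive as a list of NumPy arrays (one per layer); this representation is an illustrative choice, not the API of any particular FL framework.

```python
import numpy as np

def fedavg_aggregate(client_weights, client_num_samples):
    """Weighted average of client model weights, as in FedAvg.

    client_weights: list over clients, each a list of NumPy arrays (one per layer).
    client_num_samples: number of local training samples held by each client.
    """
    total = sum(client_num_samples)
    coeffs = [n / total for n in client_num_samples]
    # For every layer, sum the clients' weights scaled by their share of the data.
    return [sum(c * layer for c, layer in zip(coeffs, layers))
            for layers in zip(*client_weights)]

# Toy example: two clients, a single two-parameter "layer".
w_a = [np.array([1.0, 0.0])]
w_b = [np.array([0.0, 1.0])]
print(fedavg_aggregate([w_a, w_b], client_num_samples=[30, 10]))  # [array([0.75, 0.25])]
```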

Other forms of FL, beyond Horizontal, incorporate aggregation strategies in various forms. For example, in Vertical FL, the clients must synthesize partial gradient information received from other clients in the system in order to properly perform gradient descent for their local model split in SplitNN algorithms.2 This process, however, isn't necessarily the responsibility of an FL server. Nevertheless, aggregation strategies are most prominently featured and the subject of significant research in Horizontal FL frameworks. As is seen in the sections of Horizontal Federated Learning, many variations and extensions of FedAvg have been proposed to improve convergence, deal with data heterogeneity challenges, stabilize training dynamics, and produce better models. We'll dive into many of these advances in subsequent chapters.



Horizontal Federated Learning


Reading time: 3 min

As outlined in The Different Flavors of Federated Learning, Horizontal FL considers the setting where \(i=1, \ldots, N\) clients each hold a distributed training dataset, \(D_{i}\), on their local compute environment. Each of the datasets share the same feature and label spaces. The goal of Horizontal FL is to train a high-performing model (or models) using all of the training data, \(\{D_i\}_{i=1}^N\), residing on each of the clients in the system.

Horizontal FL
Feature spaces are shared between clients, enabling access to more unique training data points.

In a Horizontal FL system, some fundamental elements are generally present. In most cases, communication and computation between the server and clients are broken into iterations known as server rounds. Typically, the number of such rounds is simply specified as a hyper-parameter, \(T > 0\). During each round, the server chooses a subset of all possible clients of size \(m \leq N \) to participate in that round. Note that one may choose to include all clients or a proper subset thereof. These clients perform some kind of training using their local datasets and send the results of that training back to the server. The contents of these "training results" vary depending on the method used, but often include the model parameters after local training.

After receiving the training results for the clients participating in the round, the server performs some kind of aggregation, combining the training results together. These combined results are returned to the clients for the next round of training. In most cases, the results are communicated to all clients, rather than just the subset that participated in the round.

This process skeleton is summarized in the algorithm below. How each of the high-level steps in the algorithm functions depends on the exact Horizontal FL algorithm being used. There are also variations of such algorithms that modify or add to the basic framework below.

Horizontal FL Algorithm Outline

This section of the book is organized as follows:

Each of the chapters covers a different aspect of Horizontal FL and provides deeper details on the inner workings of the various algorithms. In Vanilla FL, the foundational Horizontal FL algorithms are discussed. In Robust Global FL, extensions to these foundational algorithms are detailed. Such extensions aim to improve properties like convergence and robustness to the heterogeneous data challenges common in FL applications while still producing a single generalizable model. Finally, Personalized FL discusses robust and effective methods for training individual models per client that still benefit from the global perspective of other clients. The end result is a set of models, each optimized to perform well on its client's unique data distribution.



Foundational FL Techniques


Reading time: 1 min

In this section, FedSGD and FedAvg are detailed, both of which were first proposed in [1]. These methods fall under the category of Horizontal FL. Before detailing how each method works, let's first establish some notation that will be shared in describing both methods. First assume that there are \(N\) clients in the FL pool, each with a unique local training dataset, \(D_i\). Let

$$ D = \bigcup\limits_{k=1}^{N} D_k, $$

and denote \(\vert D \vert = n\). The end goal is to train a model parameterized by weights \(\mathbf{w}\) using all data in \(D\). Further, let \(\ell(\mathbf{w})\) be a loss function depending on \(\mathbf{w}\).

In standard FL, we aim to train a model by minimizing the loss over the dataset \(D\) of total size \(n\). This is written

$$ \begin{align*} \min_{\mathbf{w} \in \mathbf{R}^d} \ell(\mathbf{w}), \qquad \text{where} \qquad & \ell(\mathbf{w}) = \frac{1}{n} \sum_{i=1}^n \ell_i(\mathbf{w}), \end{align*} $$

and \(\ell_i(\mathbf{w})\) is the loss with respect to the \(i^{\text{th}}\) sample in the training dataset. Note that we have implicitly denoted the dimensionality of the model weights, in the equation above, as \(d\).



FedSGD


Reading time: 4 min

Given the Horizontal FL setup, the general idea of FedSGD1 is fairly straightforward.

  1. During each server round, participating clients compute a gradient based on their local loss function, using the current model weights, \(\mathbf{w}\), applied to their local dataset. These gradients are sent to the server.
  2. The server uses the client gradients to update the weights of a model.
  3. The server sends the updated model weights back to the clients, who proceed to compute a new gradient based on their data.

The math

Leveraging the notation set out in the previous section (Foundational FL Techniques), denote by \(P_k\) the indices of samples from client \(k\) in the total dataset \(D\), and denote \(n_k = \vert P_k \vert\). Then we can write the loss over the entire dataset as

$$ \begin{align*} \ell(\mathbf{w}) = \frac{1}{n} \sum_{k=1}^{N} \sum_{i \in P_k} \ell_i(\mathbf{w}), \end{align*} $$

recalling that \(\ell_i(\mathbf{w})\) is the loss function with respect to the \(i^{\text{th}}\) sample. In this equation it is the \(i^{\text{th}}\) sample drawn from the dataset \(D_k\).

For a server round \(t\) and current set of model weights, \(\mathbf{w}_t\), suppose a subset, \(C_t\), of \(m \leq N\) clients is selected from which to compute a weight update. The loss over all data points held by the clients in \(C_t\) is written

$$ \begin{align*} \ell_t(\mathbf{w}_t) = \frac{1}{n_s} \sum_{k \in C_t} \sum_{i \in P_k} \ell_i(\mathbf{w}_t), \end{align*} $$

where \(n_s = \displaystyle{\sum_{k \in C_t} n_k}\). For a given client \(k\), define

$$ \begin{align*} L_k(\mathbf{w_t}) = \frac{1}{n_k} \sum_{i \in P_k} \ell_i(\mathbf{w_t}). \end{align*} $$

Note that \(L_k(\mathbf{w}_t)\) is simply the local loss for client \(k\) across all data points in its dataset, \(D_k\). Then

$$ \begin{align*} \ell_t(\mathbf{w}_t) = \sum_{k \in C_t} \frac{n_k}{n_s} L_k(\mathbf{w}_t). \end{align*} $$

Observe that the global loss, \(\ell_t(\mathbf{w}_t)\), over the selected subset of clients, \(C_t\), is now written as a linearly weighted combination of the local losses of each client.

As the gradient is a linear operator, the gradient of the global loss is

$$ \begin{align*} \nabla \ell_t(\mathbf{w}_t) = \sum_{k \in C_t} \frac{n_k}{n_s} \nabla L_k(\mathbf{w}_t). \end{align*} $$

This implies that model weights \(\mathbf{w}_t\) are updated as

$$ \begin{align} \mathbf{w}_{t+1} = \mathbf{w}_t - \eta \sum_{k \in C_t} \frac{n_k}{n_s} \nabla L_k(\mathbf{w}_t), \tag{1} \end{align} $$

for some learning rate \(\eta > 0\). Because \(\nabla L_k(\mathbf{w}_t)\) is the gradient of the local loss function of client \(k\), this update is simply a linearly weighted combination of local gradients, which can be computed locally by each client.
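
As a minimal sketch of this server-side update (Equation (1)), assuming the clients have already sent back their full-batch local gradients as NumPy arrays along with their local sample counts:

```python
import numpy as np

def fedsgd_server_update(weights, client_grads, client_num_samples, lr):
    """One FedSGD step: w <- w - lr * sum_k (n_k / n_s) * grad_k, as in Equation (1)."""
    n_s = sum(client_num_samples)
    weighted_grad = sum((n_k / n_s) * g
                        for g, n_k in zip(client_grads, client_num_samples))
    return weights - lr * weighted_grad

# Toy example with two clients holding 3 and 1 samples, respectively.
w = np.zeros(2)
grads = [np.array([1.0, -2.0]), np.array([3.0, 0.0])]
print(fedsgd_server_update(w, grads, client_num_samples=[3, 1], lr=0.1))
# -> [-0.15  0.15]
```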

FedSGD is just Large Batch SGD

One of the important properties of FedSGD is that the update in Equation (1) is mathematically equivalent to

$$ \begin{align*} \mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla \ell_t(\mathbf{w}_t), \end{align*} $$

where \(\ell_t\) denotes the loss function over all data in each of the clients in \(C_t\). As such, FedSGD, in spite of leveraging gradients computed in a distributed fashion on each individual client, is equivalent to performing centralized batch SGD, where the batch is of size \(n_s = \displaystyle{\sum_{k \in C_t} n_k}\).

Among other implications, this means that convergence theory associated with standard batch SGD is directly applicable to the FedSGD procedure.
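
This equivalence can be checked numerically on a toy problem. The sketch below compares the weighted combination of per-client full-batch gradients against the gradient computed over the pooled data for a small least-squares loss; the data, model, and loss are made up purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def grad_mse(w, X, y):
    """Gradient of the mean squared error 0.5 * mean((Xw - y)^2)."""
    return X.T @ (X @ w - y) / len(y)

# Two toy clients with different numbers of samples but the same feature space.
X1, y1 = rng.normal(size=(6, 3)), rng.normal(size=6)
X2, y2 = rng.normal(size=(2, 3)), rng.normal(size=2)
w = rng.normal(size=3)

# FedSGD view: weighted combination of per-client full-batch gradients.
n1, n2 = len(y1), len(y2)
fed_grad = (n1 / (n1 + n2)) * grad_mse(w, X1, y1) + (n2 / (n1 + n2)) * grad_mse(w, X2, y2)

# Centralized view: gradient over the pooled dataset.
central_grad = grad_mse(w, np.vstack([X1, X2]), np.concatenate([y1, y2]))

print(np.allclose(fed_grad, central_grad))  # True
```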

The algorithm

We are now in a position to fill in the details of the general Horizontal FL algorithm presented in Horizontal Federated Learning. The full workflow of FedSGD is summarized in the algorithm below. Inputs are \(N\), the number of clients, \(T\), the number of server rounds to perform, \(\eta\), the learning rate, and \(\mathbf{w}\), the initial weights for the model to be trained. After the final server round is complete, each client receives the final model as described by the weights \(\mathbf{w}_T\).

FedSGD Algorithm

Communication overhead

The FedSGD algorithm has several benefits, including the mathematical equivalence discussed above. However, it has at least one significant drawback: communication between the clients and server occurs for every SGD step. That is, for each model update, participating clients are required to communicate their gradients, and the server must send updated model weights back. In most settings, the latency associated with communication between clients and the server is significantly higher than the cost of computing the local gradients or performing the weight updates. As such, communication overhead becomes a significant bottleneck and materially slows training. Reducing communication costs is the driving motivation behind the FedAvg approach.



FedAvg


Reading time: 5 min

The FedAvg algorithm1 builds on the same principles of FedSGD, but aims to reduce the communication costs incurred by the FedSGD approach. Recall that the major shortcoming of FedSGD was that it required clients to send local gradients for every training step in order to perform model updates. The FedAvg algorithm attempts to reduce this overhead by pushing additional computation onto the clients.

The math

Assume a fixed learning rate of \(\eta > 0\), and denote

$$ \begin{align} \mathbf{w}_{t+1}^k = \mathbf{w}_t - \eta \nabla L_k(\mathbf{w}_t). \tag{1} \end{align} $$

Note that \(\mathbf{w}_t - \eta \nabla L_k(\mathbf{w}_t)\) is just a local (full) gradient step on client \(k\). That is, \(\nabla L_k(\mathbf{w}_t)\) is the gradient with respect to all training data on client \(k\). So the weights \(\mathbf{w}_{t+1}^k\) represent a new model using only the data of client \(k\) to update the weights, \(\mathbf{w}_t\). Then we can rewrite the server update in FedSGD in terms of \(\mathbf{w}_{t+1}^k\) with a little algebra as

$$ \begin{align} \mathbf{w}_{t+1}&= \mathbf{w}_t - \eta \sum_{k \in C_t} \frac{n_k}{n_s} \nabla L_k(\mathbf{w}_t) \\ &= \sum_{k \in C_t} \frac{n_k}{n_s} \mathbf{w}_t - \eta \sum_{k \in C_t} \frac{n_k}{n_s} \nabla L_k(\mathbf{w}_t) \\ &= \sum_{k \in C_t} \frac{n_k}{n_s} \mathbf{w}_{t+1}^k. \tag{2} \end{align} $$

The final line of Equation (2) implies that the updated weights, \(\mathbf{w}_{t+1}\), in FedSGD can be rewritten as a linearly weighted average of local weight updates performed by the clients themselves. That is, \(\mathbf{w}_{t+1}\) is just a weighted average of locally updated weights, where the averaging coefficients are the proportion of data points on each client (\(n_k\)) relative to the total number of data points used to compute the update (\(n_s\)).

With this in hand, we can push responsibility for updating model weights onto the clients participating in a round of FL training. Only model weights are communicated back and forth, and the server need only average the locally updated weights to obtain the new model. This procedure remains mathematically equivalent to centralized large batch SGD, as is the case for FedSGD. The bad news is that we haven't saved any communication yet: updated weights must still be exchanged after every step, and since the model weights and their gradients have the same dimensionality, each message is just as large as before. So what can we do?

Rather than a full, local-gradient step on each client, as expressed in Equation (1), we can run multiple local batch SGD updates. For client \(k\), let \(B\) be a set of batches drawn from \(P_k\), the collection of training data points on client \(k\). For \(b \in B\), perform local updates of the form

$$ \begin{align*} \mathbf{w}^k = \mathbf{w}^k - \eta \frac{1}{\vert b \vert} \sum_{i \in b} \nabla \ell_i (\mathbf{w}^k). \end{align*} $$

This allows each client to perform multiple local batch SGD updates to the model weights. As in standard ML training, these updates can be performed for a certain number of epochs, iterating through each client's local data. Only after completing such iterations are the updated weights communicated to the server for aggregation using the same formula in Equation (2) on the server side. In this manner, we have decoupled model updates from communication with the server and are free to communicate as frequently or infrequently as we choose.
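
A minimal sketch of these local updates is given below, assuming the client's data is available as NumPy arrays and that a hypothetical `grad_fn` helper returns the mini-batch gradient of the local loss; a real client would more likely rely on an ML framework's optimizer.

```python
import numpy as np

def local_fedavg_update(w, X, y, grad_fn, lr, batch_size, num_epochs, rng):
    """Run `num_epochs` of local mini-batch SGD starting from the server weights `w`.

    grad_fn(w, X_batch, y_batch) is a hypothetical helper returning the
    mini-batch gradient of the local loss with respect to `w`.
    """
    w = w.copy()
    n = len(y)
    for _ in range(num_epochs):
        order = rng.permutation(n)  # reshuffle the local data every epoch
        for start in range(0, n, batch_size):
            batch = order[start:start + batch_size]
            w -= lr * grad_fn(w, X[batch], y[batch])
    return w  # sent back to the server for weighted averaging

# Example usage with a least-squares gradient on toy data:
rng = np.random.default_rng(0)
X, y = rng.normal(size=(32, 3)), rng.normal(size=32)
grad_mse = lambda w, Xb, yb: Xb.T @ (Xb @ w - yb) / len(yb)
w_local = local_fedavg_update(np.zeros(3), X, y, grad_mse,
                              lr=0.05, batch_size=8, num_epochs=2, rng=rng)
```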

The algorithm

With the new approach proposed in the previous section, the full FedAvg algorithm may be summarized in the algorithm below. Inputs to the algorithm are:

  • \(N\): The number of clients.
  • \(T\): The number of server rounds to perform.
  • \(\eta\): The learning rate to be used by each client.
  • \(n_b\): The batch size to be used for each local gradient step.
  • \(\mathbf{w}\): The initial weights for the model to be trained.
  • \(E\): The number of epochs for each client to perform.

After the final server round is complete, each client receives the final model as described by the weights \(\mathbf{w}_T\).

FedAvg Algorithm

Note that, in the algorithm above, the local updates are performed with standard batch SGD. There is nothing stopping us from using a different training procedure on the client side. For example, one might instead perform such updates using an AdamW optimizer.2 As with standard ML training, the type of optimizer that works best is problem dependent.
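
For instance, a client's local update might look roughly like the PyTorch-style sketch below, which swaps plain mini-batch SGD for `torch.optim.AdamW`; the model, data loader, and task loss are placeholders for whatever the client actually trains.

```python
import torch

def local_update_adamw(model, data_loader, num_epochs, lr=1e-3, weight_decay=0.01):
    """Local FedAvg-style update using AdamW instead of plain mini-batch SGD."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = torch.nn.CrossEntropyLoss()  # placeholder task loss
    model.train()
    for _ in range(num_epochs):
        for inputs, targets in data_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
    # The updated parameters are then sent to the server for aggregation.
    return [p.detach().clone() for p in model.parameters()]
```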

A broken equivalence can have consequences

Both theoretically and experimentally, FedAvg is a strong algorithm. The modifications to FedSGD can be used to substantially drive down communication costs while retaining many of the benefits of FedSGD in practice. Since its introduction, the FedAvg algorithm has been widely used to make ML model training on decentralized datasets a reality. However, the modifications that make FedAvg more communication efficient compared with FedSGD also break the mathematical equivalence to global large-batch SGD enjoyed by FedSGD.

When the training data spread across clients is independently and identically distributed (i.e., drawn independently from the same distribution), this loss of equivalence is generally less consequential. On the other hand, when client data distributions become more heterogeneous, the lack of true equivalence materially impacts the convergence properties of FedAvg and can lead to suboptimal performance. As such, many approaches have since been proposed to improve upon FedAvg while maintaining its desirable qualities, like communication efficiency.

