Source code for flamby.strategies.scaffold
import warnings
from typing import List
import torch
from flamby.strategies.fed_avg import FedAvg
from flamby.strategies.utils import _Model
class Scaffold(FedAvg):
"""SCAFFOLD Strategy class
SCAFFOLD is a stateful algorithm which modifies the local update steps of FedAvg
in order to provably correct for data heterogeneity across clients. If the data
on each client is very different, their local updates via FedAvg will move in
    different directions. Each client maintains a 'correction' which estimates
    the difference between its own updates and the global average update. This
    correction is added to every local update on the client.
    This is a more efficient implementation of SCAFFOLD whose communication and
    computation requirements exactly match those of FedAvg.
The current implementation assumes that SGD is the local optimizer, and that all
clients participate every round.
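
    In terms of the quantities below, each round refreshes the stored client
    correction roughly as follows (a sketch of what perform_round computes):

        correction_i += (server_state - previous_client_state_i) / (
            server_learning_rate * learning_rate * num_updates
        )

    and this correction is then applied at every one of the num_updates local
    SGD steps on client i.
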
References
----------
https://arxiv.org/abs/1910.06378
Parameters
----------
training_dataloaders : List
The list of training dataloaders from multiple training centers.
model : torch.nn.Module
An initialized torch model.
loss : torch.nn.modules.loss._Loss
The loss to minimize between the predictions of the model and the
ground truth.
optimizer_class : torch.optim.Optimizer
The class of the torch model optimizer to use at each step.
It has to be SGD.
learning_rate : float
The learning rate to be given to the clients optimizer_class.
num_updates : int
The number of updates to do on each client at each round.
nrounds : int
The number of communication rounds to do.
dp_target_epsilon: float
        The target epsilon for the (epsilon, delta)-differential
        privacy guarantee. Defaults to None.
    dp_target_delta: float
        The target delta for the (epsilon, delta)-differential
        privacy guarantee. Defaults to None.
dp_max_grad_norm: float
The maximum L2 norm of per-sample gradients;
used to enforce differential privacy. Defaults to None.
server_learning_rate : float
The learning rate with which the server's updates are aggregated.
Defaults to 1.
log: bool
Whether or not to store logs in tensorboard. Defaults to False.
log_period: int
If log is True then log the loss every log_period batch updates.
        Defaults to 100.
bits_counting_function : Union[callable, None]
        A function making sure exchanges respect the rules; it can be obtained
        by decorating check_exchange_compliance in flamby.utils. It should
        have the signature List[Tensor] -> int.
Defaults to None.
logdir: str
        Where to store the logs. Defaults to ./runs.
log_basename: str
        The basename of the created log file. Defaults to scaffold.
"""
def __init__(
self,
training_dataloaders: List,
model: torch.nn.Module,
loss: torch.nn.modules.loss._Loss,
optimizer_class: torch.optim.Optimizer,
learning_rate: float,
num_updates: int,
nrounds: int,
dp_target_epsilon: float = None,
dp_target_delta: float = None,
dp_max_grad_norm: float = None,
server_learning_rate: float = 1,
log: bool = False,
log_period: int = 100,
bits_counting_function: callable = None,
logdir: str = "./runs",
log_basename: str = "scaffold",
):
"""Cf class docstring"""
assert (
optimizer_class == torch.optim.SGD
), "Only SGD for client optimizer with Scaffold"
super().__init__(
training_dataloaders,
model,
loss,
optimizer_class,
learning_rate,
num_updates,
nrounds,
dp_target_epsilon,
dp_target_delta,
dp_max_grad_norm,
log,
log_period,
bits_counting_function,
log_basename=log_basename,
logdir=logdir,
)
        # Warn the user if differential privacy is requested
if dp_target_epsilon is not None:
warnings.warn("Warning, the DP bounds passed are not valid for Scaffold.")
# initialize the previous state of each client
self.previous_client_state_list = [
_model._get_current_params() for _model in self.models_list
]
# initialize the corrections used by each client to 0s.
self.client_corrections_state_list = [
[torch.zeros_like(torch.from_numpy(p)) for p in _model._get_current_params()]
for _model in self.models_list
]
self.client_lr = learning_rate
self.server_lr = server_learning_rate
def _local_optimization(
self, _model: _Model, dataloader_with_memory, correction_state: List
):
"""Carry out the local optimization step.
Parameters
----------
_model: _Model
The model on the local device used by the optimization step.
dataloader_with_memory : dataloaderwithmemory
A dataloader that can be called infinitely using its get_samples()
method.
correction_state: List
Correction to be applied to the model state during every local update.
"""
_model._local_train_with_correction(
dataloader_with_memory, self.num_updates, correction_state
)
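    # Note: _local_train_with_correction is defined in flamby.strategies.utils;
    # per the class docstring it runs num_updates local SGD steps and applies
    # correction_state to every one of those updates.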
def perform_round(self):
"""Does a single federated averaging round. The following steps will be
performed:
- each model will be trained locally for num_updates batches.
- the parameter updates will be collected and averaged. Averages will be
weighted by the number of samples in each client
- the averaged updates willl be used to update the local model
"""
local_updates = list()
new_client_state_list = list()
new_correction_state_list = list()
for (
_model,
dataloader_with_memory,
size,
_previous_client_state,
_prev_correction_state,
) in zip(
self.models_list,
self.training_dataloaders_with_memory,
self.training_sizes,
self.previous_client_state_list,
self.client_corrections_state_list,
):
            # Update the correction as correction += (server_state -
            # previous_client_state) / (server_lr * client_lr * num_updates).
_server_state = _model._get_current_params()
_new_correction_state = [
c
+ torch.from_numpy(
(p - q) / (self.server_lr * self.client_lr * self.num_updates)
)
for c, p, q in zip(
_prev_correction_state, _server_state, _previous_client_state
)
]
new_correction_state_list.append(_new_correction_state)
del _previous_client_state
del _prev_correction_state
# Local Optimization
self._local_optimization(
_model, dataloader_with_memory, _new_correction_state
)
_local_next_state = _model._get_current_params()
# Scale local parameters by server_lr
_local_next_state = [
self.server_lr * new + (1 - self.server_lr) * old
for new, old in zip(_local_next_state, _server_state)
]
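            # Equivalently: scaled_state = _server_state + server_lr * (local - server),
            # i.e. the client's move away from the server state is damped by server_lr.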
new_client_state_list.append(_local_next_state)
# Recovering updates
updates = [new - old for new, old in zip(_local_next_state, _server_state)]
del _local_next_state
# Reset local model
for p_new, p_old in zip(_model.model.parameters(), _server_state):
p_new.data = torch.from_numpy(p_old).to(p_new.device)
del _server_state
if self.bits_counting_function is not None:
self.bits_counting_function(updates)
local_updates.append({"updates": updates, "n_samples": size})
# update previous client states
self.previous_client_state_list = new_client_state_list
self.client_corrections_state_list = new_correction_state_list
        # Aggregation step: sample-size-weighted average of the client updates
aggregated_delta_weights = [
None for _ in range(len(local_updates[0]["updates"]))
]
for idx_weight in range(len(local_updates[0]["updates"])):
aggregated_delta_weights[idx_weight] = sum(
[
local_updates[idx_client]["updates"][idx_weight]
* local_updates[idx_client]["n_samples"]
for idx_client in range(self.num_clients)
]
)
aggregated_delta_weights[idx_weight] /= float(self.total_number_of_samples)
# Update models
for _model in self.models_list:
_model._update_params(aggregated_delta_weights)
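

# ---------------------------------------------------------------------------
# Example usage (a minimal sketch, not part of the strategy itself). The model,
# loss and per-center dataloaders below are hypothetical placeholders, and we
# assume the run() method inherited from the FLamby strategy API, which trains
# for nrounds rounds and returns the final model(s):
#
#     import torch
#     from flamby.strategies.scaffold import Scaffold
#
#     strategy = Scaffold(
#         training_dataloaders=training_dataloaders,  # one DataLoader per center
#         model=my_model,                             # an initialized torch.nn.Module
#         loss=torch.nn.CrossEntropyLoss(),
#         optimizer_class=torch.optim.SGD,            # Scaffold requires SGD
#         learning_rate=0.01,
#         num_updates=50,
#         nrounds=20,
#         server_learning_rate=1.0,
#     )
#     trained_models = strategy.run()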