Federated Learning
The problem federated learning sets out to solve is: how to use the compute of edge devices to train a model and share it, without ever uploading the devices' data.
Federated learning differs from conventional distributed learning in that:
- the user nodes act as workers that compute on their own data, and the data never has to be shipped to the server or to other nodes;
- the user-node workers are unreliable, and their data may be non-IID: the data distribution can be skewed differently across nodes (see the partition sketch after this list);
- for the user-node workers, the cost of communicating with the server far outweighs the cost of computation.
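To make the non-IID point concrete, experiments often simulate it by partitioning a labelled dataset across clients with a Dirichlet prior over per-class proportions. The sketch below is illustrative only and not part of any algorithm in these notes; the `dirichlet_partition` helper and the `alpha` knob are assumptions for the example (smaller `alpha` means more skew per client).

```python
import numpy as np

def dirichlet_partition(labels, num_clients=10, alpha=0.5, seed=0):
    """Split sample indices across clients so each client's label mix is skewed."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])
        # Each client's share of class c is drawn from a Dirichlet distribution.
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client_id, part in enumerate(np.split(idx, cuts)):
            client_indices[client_id].extend(part.tolist())
    return [np.array(ix) for ix in client_indices]

# Example: 1000 samples over 10 classes, split across 5 skewed clients.
fake_labels = np.random.default_rng(1).integers(0, 10, size=1000)
parts = dirichlet_partition(fake_labels, num_clients=5, alpha=0.3)
print([len(p) for p in parts])
```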
Communication-Efficiency
The core idea behind improving communication efficiency is to do more computation and less communication.
Privacy
User data must be trained on locally, staying on the user's device, and it must not be possible to reverse-engineer that data from the gradients or parameters of the trained model.
FedAVG
In FedAVG, each edge-node worker trains on its local data for 1~5 epochs, then sends the updated weights to the server; the server computes a weighted average and sends the result back to the edge-node workers.
Objective (minimize the empirical loss): $\min_w{\sum_{k=1}^K\frac{n_k}{n}F_k(w)}$, where $F_k(w)=\frac{1}{n_k}L_k(w)$
Server executes:
initialize $w_0$
for each round $t = 1, 2, . . .$ do
$m ← max(C · K, 1)$ // a fixed set of $K$ clients, a random fraction $C$ of clients is selected
$S_t ← (random\;set\;of\;m\;clients)$
for each client $k ∈ S_t$ in parallel do
$w_{t+1}^k ← ClientUpdate(k, w_t)$
$w_{t+1}← \sum_{k=1}^K(\frac{n_k}{n}w_{t+1}^k)$
**ClientUpdate($k$, $w$):**
$B ← (split\;P_k\;into\;batches\;of\;size\;B)$
for each local epoch $i$ from $1$ to $E$ do
for batch $b ∈ B$ do
$w ← w − η·g(w,b)$
return $w$ to server
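Below is a minimal NumPy sketch of the FedAVG loop above on a toy least-squares problem. The toy data, the `client_update` helper, and every hyperparameter value are illustrative assumptions; the weighted average here is taken over the selected clients $S_t$, which is how the $\sum_{k=1}^K$ aggregation is typically realised when only a fraction $C$ of clients participates in a round.

```python
import numpy as np

rng = np.random.default_rng(0)
K, C, E, B, eta, rounds, d = 10, 0.5, 3, 16, 0.05, 20, 5
w_true = rng.normal(size=d)

# Each client k holds its own (X_k, y_k); the sizes n_k are deliberately unequal.
clients = []
for k in range(K):
    n_k = int(rng.integers(50, 150))
    X = rng.normal(size=(n_k, d))
    y = X @ w_true + 0.1 * rng.normal(size=n_k)
    clients.append((X, y))

def client_update(w, X, y):
    """ClientUpdate: E local epochs of mini-batch SGD on a squared loss."""
    w = w.copy()
    for _ in range(E):
        for start in range(0, len(y), B):
            Xb, yb = X[start:start + B], y[start:start + B]
            grad = 2.0 * Xb.T @ (Xb @ w - yb) / len(yb)
            w -= eta * grad
    return w

w = np.zeros(d)
for t in range(rounds):
    m = max(int(C * K), 1)
    selected = rng.choice(K, size=m, replace=False)   # S_t
    updates = [(len(clients[k][1]), client_update(w, *clients[k])) for k in selected]
    n_total = sum(n_k for n_k, _ in updates)
    # Server aggregation: weighted average of the returned client weights.
    w = sum(n_k * w_k for n_k, w_k in updates) / n_total

print("distance to w_true:", np.linalg.norm(w - w_true))
```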
FedProx
FedProx is a complement to FedAVG: it keeps the edge-node workers' updates from drifting too far from the initial global model, reducing the impact of non-IID data.
Objective (minimize the empirical loss): $\min_w{\sum_{k=1}^K\frac{n_k}{n}F_k(w)}$, where $F_k(w)=\frac{1}{n_k}L_k(w)+\frac{\mu}{2}\|w-w_t\|^2$
Server executes:
initialize $w_0$
for each round $t = 1, 2, . . .$ do
$m ← max(C · K, 1)$ // a fixed set of $K$ clients, a random fraction $C$ of clients is selected
$S_t ← (random\;set\;of\;m\;clients)$
for each client $k ∈ S_t$ in parallel do
$w_{t+1}^k ← ClientUpdate(k, w_t)$
$w_{t+1}← \sum_{k=1}^K(\frac{n_k}{n}w_{t+1}^k)$
**ClientUpdate($k$, $w$):**
$B ← (split\;P_k\;into\;batches\;of\;size\;B)$
for each local epoch $i$ from $1$ to $E$ do
for batch $b ∈ B$ do
$w ← \arg\min_w\left(\frac{1}{n_k}L_k(w)+\frac{\mu}{2}\|w-w_t\|^2\right)$
return $w$ to server
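As a sketch of how the FedProx client step differs from FedAVG: the proximal term $\frac{\mu}{2}\|w-w_t\|^2$ simply adds $\mu(w-w_t)$ to the local gradient, so the approximate $\arg\min$ can be run with the same mini-batch SGD loop. The toy squared loss and the hyperparameter values below are illustrative assumptions.

```python
import numpy as np

def fedprox_client_update(w_global, X, y, mu=0.1, eta=0.05, epochs=3, batch=16):
    """Approximately solve argmin_w (1/n_k) L_k(w) + (mu/2) ||w - w_t||^2 by SGD."""
    w = w_global.copy()
    for _ in range(epochs):
        for start in range(0, len(y), batch):
            Xb, yb = X[start:start + batch], y[start:start + batch]
            grad_loss = 2.0 * Xb.T @ (Xb @ w - yb) / len(yb)
            grad_prox = mu * (w - w_global)   # pulls w back toward the global model w_t
            w -= eta * (grad_loss + grad_prox)
    return w
```

Larger $\mu$ pulls the local solution harder toward the global model, trading local progress for stability under non-IID data.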
FedCurv
FedCurv adapts the EWC algorithm from continual learning to the federated-learning setting. For node $k$ on task (round) $t$, the loss function can be defined as:
$L_{t,k}(\theta) = L_k(\theta) + \lambda\sum_{j\in K\setminus\{k\}}(\theta-\hat{\theta}_{t-1,j})^T F_{t-1,j}(\theta-\hat{\theta}_{t-1,j})$
Server executes:
for each round $t = 1, 2, . . .$ do
$m ← max(C · K, 1)$ // a fixed set of $K$ clients, a random fraction $C$ of clients is selected
$S_t ← (random\;set\;of\;m\;clients)$
$\theta_{t}← \sum_{k=1}^K(\frac{n_k}{n}\theta_{t-1}^k)$
for each client $k ∈ S_t$ in parallel do
$\theta_{t+1}^k,F^k_{t+1} ← ClientUpdate(k, \theta_t,F_t)$
**ClientUpdate($k$, $\theta$, $F$):**
$B ← (split\;P_k\;into\;batches\;of\;size\;B)$
for each local epoch $i$ from $1$ to $E$ do
for batch $b ∈ B$ do
$\theta ← \arg\min_\theta\left(L_k(\theta) + \lambda\sum_{j\in K\setminus\{k\}}(\theta-\hat{\theta}_{t-1,j})^T F_{t-1,j}(\theta-\hat{\theta}_{t-1,j})\right)$
return $\theta_{t+1}^k,F^k_{t+1}$ to server
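The sketch below illustrates the FedCurv-style client objective under the same toy least-squares assumption: the local loss plus an EWC penalty weighted by the other clients' diagonal Fisher estimates from the previous round. The `diagonal_fisher` estimate and all names here are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

def diagonal_fisher(w, X, y):
    """Diagonal Fisher estimate: mean per-sample squared gradient of the toy squared loss."""
    residual = X @ w - y                            # shape (n,)
    per_sample_grad = 2.0 * X * residual[:, None]   # shape (n, d)
    return np.mean(per_sample_grad ** 2, axis=0)    # shape (d,)

def fedcurv_client_update(w_global, X, y, others, lam=1.0, eta=0.05, epochs=3, batch=16):
    """others: list of (theta_hat_j, F_j) pairs received for the other clients (round t-1)."""
    w = w_global.copy()
    for _ in range(epochs):
        for start in range(0, len(y), batch):
            Xb, yb = X[start:start + batch], y[start:start + batch]
            grad_loss = 2.0 * Xb.T @ (Xb @ w - yb) / len(yb)
            # Gradient of lambda * sum_j (w - theta_j)^T diag(F_j) (w - theta_j)
            grad_pen = 2.0 * lam * sum(F_j * (w - theta_j) for theta_j, F_j in others)
            w -= eta * (grad_loss + grad_pen)
    # The client returns both its updated weights and its own Fisher estimate.
    return w, diagonal_fisher(w, X, y)
```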
FedCL
FedCL is an improvement over FedCurv. The per-parameter importance weights $F_t^k$ from FedCurv are now computed on the server, estimated from a proxy dataset, which cuts the communication cost by at least 50%. The downside is that part of a dataset must be kept on the server to compute $F_t^k$, which is a limitation in privacy- or transmission-sensitive settings.
$L_{t,k}(\theta) = L_k(\theta) + \lambda\sum_{j\in K\setminus\{k\}}F_{t,k}(\theta-\hat{\theta}_{t-1,j})^2$, where $F=\frac{1}{|D_{proxy}|}\sum_{(x_k,y_k)\in D_{proxy}}\left|\frac{\partial L(\theta_g(x_k),y_k)}{\partial\theta_g}\right|^2$
Server executes:
initialize $\theta_0$
for each round $t = 1, 2, . . .$ do
$F←\frac{1}{|D_{proxy}|}\sum_{(x_k,y_k)\in D_{proxy}}\left|\frac{\partial L(\theta_g(x_k),y_k)}{\partial\theta_g}\right|^2$
$m ← max(C · K, 1)$ // a fixed set of $K$ clients, a random fraction $C$ of clients is selected
$S_t ← (random\;set\;of\;m\;clients)$
for each client $k ∈ S_t$ in parallel do
if $t\;mod\;N == 0$ do
$\theta_{t+1}^k ← ClientUpdate(\theta_t, F)$
else
$\theta_{t+1}^k ← ClientUpdate(\theta_t)$
$\theta_{t+1}← \sum_{k=1}^K(\frac{n_k}{n}\theta_{t+1}^k)$
**ClientUpdate($\theta$, $F$):**
if $F$ is not received then
$F←Identity\;Matrix$
$B ← (split\;P_k\;into\;batches\;of\;size\;B)$
for each local epoch $i$ from $1$ to $E$ do
for batch $b ∈ B$ do
$\theta ← \arg\min_\theta\left(L_k(\theta) + \lambda\sum_{j\in K\setminus\{k\}}F_{t,k}(\theta-\hat{\theta}_{t-1,j})^2\right)$
return $\theta_{t+1}^k$ to server
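To show where FedCL moves the work, here is a sketch of the server-side importance estimate on a proxy dataset and the corresponding client step. The proxy data, the `estimate_importance` name, and the fallback weighting are illustrative assumptions consistent with the pseudocode above (an all-ones weighting when no $F$ was broadcast this round), not the authors' code.

```python
import numpy as np

def estimate_importance(theta_g, proxy_X, proxy_y):
    """Server side: per-parameter importance as the mean squared per-sample gradient
    of the global model's loss on the proxy dataset (a diagonal-Fisher-style estimate)."""
    residual = proxy_X @ theta_g - proxy_y
    per_sample_grad = 2.0 * proxy_X * residual[:, None]
    return np.mean(per_sample_grad ** 2, axis=0)

def fedcl_client_update(theta_global, X, y, F=None, lam=1.0, eta=0.05, epochs=3, batch=16):
    """Client side: if no F was received this round, fall back to an all-ones (identity) weighting."""
    if F is None:
        F = np.ones_like(theta_global)
    theta = theta_global.copy()
    for _ in range(epochs):
        for start in range(0, len(y), batch):
            Xb, yb = X[start:start + batch], y[start:start + batch]
            grad_loss = 2.0 * Xb.T @ (Xb @ theta - yb) / len(yb)
            grad_pen = 2.0 * lam * F * (theta - theta_global)   # importance-weighted pull toward the broadcast model
            theta -= eta * (grad_loss + grad_pen)
    return theta   # only the weights go back; F never leaves the server
```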