
Federated Learning



The problem federated learning sets out to solve: how to use the compute of edge devices to train and share a model, without ever uploading the raw data.

Federated learning differs from traditional distributed learning in three ways:

  1. User nodes act as workers and compute on their own data; the data never has to be shipped to the server or to other nodes;
  2. Worker nodes are unreliable, and their data may be non-independent and identically distributed (non-IID): the data distribution can be skewed differently on each node;
  3. The cost of communication between workers and the server far exceeds the cost of computation.

Communication-Efficiency

The core idea behind improving communication efficiency is "do more computation, less communication": spend more local compute per round so that fewer rounds of communication are needed.

Privacy

The user data used for training must stay on the user's device, and it must not be possible to reconstruct that data by inverting the trained model's gradients or parameters.

FedAVG

In FedAVG, each edge worker trains on its local data for 1-5 epochs and sends the updated weights to the server; the server computes a weighted average and sends the result back to the workers.

Objective (minimize the empirical loss): $\min_w \sum_{k=1}^K\frac{n_k}{n}F_k(w)$, where $F_k(w)=\frac{1}{n_k}L_k(w)$


Server executes:

  initialize $w_0$
  for each round $t = 1, 2, \dots$ do
    $m ← \max(C · K, 1)$ // a random fraction $C$ of the fixed set of $K$ clients is selected
    $S_t ←$ (random set of $m$ clients)
    for each client $k ∈ S_t$ in parallel do
      $w_{t+1}^k ← ClientUpdate(k, w_t)$
    $w_{t+1} ← \sum_{k=1}^K \frac{n_k}{n} w_{t+1}^k$

**ClientUpdate($k$, $w$):**

  $\mathcal{B} ←$ (split $P_k$ into batches of size $B$)
  for each local epoch $i$ from $1$ to $E$ do
    for batch $b ∈ \mathcal{B}$ do
      $w ← w − η · g(w, b)$
  return $w$ to server
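As a concrete reference, here is a minimal PyTorch sketch of one FedAVG round. It assumes a classification model with float-only state, represents each client as a hypothetical `(data_loader, n_k)` pair, and averages over the sampled clients only (weighting by their share of the sampled data), a common implementation of the update above; the names `client_update` and `fedavg_round` are illustrative, not from the paper.

```python
import copy
import random
import torch

def client_update(model, loader, epochs, lr):
    """ClientUpdate(k, w): run E local epochs of SGD on the client's data."""
    model = copy.deepcopy(model)  # start from the current global weights
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
    return model.state_dict()

def fedavg_round(global_model, clients, C=0.1, epochs=5, lr=0.01):
    """One server round: sample m clients, average their weights by n_k / n."""
    m = max(int(C * len(clients)), 1)
    selected = random.sample(range(len(clients)), m)
    n = sum(clients[k][1] for k in selected)  # samples held by the sampled clients
    avg = {name: torch.zeros_like(p, dtype=torch.float32)
           for name, p in global_model.state_dict().items()}
    for k in selected:
        loader, n_k = clients[k]
        local = client_update(global_model, loader, epochs, lr)
        for name in avg:
            avg[name] += (n_k / n) * local[name].float()  # weighted average
    global_model.load_state_dict(avg)
    return global_model
```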


FedProx

FedProx refines FedAVG: it keeps each worker's local update from drifting too far from the initial global model, which reduces the impact of non-IID data.

Objective (minimize the empirical loss): $\min_w \sum_{k=1}^K\frac{n_k}{n}F_k(w)$, where $F_k(w)=\frac{1}{n_k}L_k(w)+\frac{\mu}{2}\|w-w_t\|^2$

Server executes:

  initialize $w_0$
  for each round $t = 1, 2, \dots$ do
    $m ← \max(C · K, 1)$ // a random fraction $C$ of the fixed set of $K$ clients is selected
    $S_t ←$ (random set of $m$ clients)
    for each client $k ∈ S_t$ in parallel do
      $w_{t+1}^k ← ClientUpdate(k, w_t)$
    $w_{t+1} ← \sum_{k=1}^K \frac{n_k}{n} w_{t+1}^k$

**ClientUpdate($k$, $w$):**

  $\mathcal{B} ←$ (split $P_k$ into batches of size $B$)
  for each local epoch $i$ from $1$ to $E$ do
    for batch $b ∈ \mathcal{B}$ do
      $w ← \arg\min_w\left(\frac{1}{n_k}L_k(w)+\frac{\mu}{2}\|w-w_t\|^2\right)$
  return $w$ to server
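The only thing that changes relative to FedAVG is the client step: instead of plain SGD on the local loss, each step descends the proximal objective, whose gradient adds $\mu(w - w_t)$. A minimal sketch of that modified client update, under the same hypothetical setup as the FedAVG example above:

```python
import copy
import torch

def fedprox_client_update(model, loader, epochs, lr, mu):
    """Approximately solve argmin_w L_k(w)/n_k + mu/2 * ||w - w_t||^2 by SGD."""
    global_params = [p.detach().clone() for p in model.parameters()]  # w_t, frozen
    model = copy.deepcopy(model)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in loader:
            opt.zero_grad()
            loss = loss_fn(model(x), y)
            # proximal term keeps the local model close to the global model w_t
            prox = sum(((p - g) ** 2).sum()
                       for p, g in zip(model.parameters(), global_params))
            (loss + 0.5 * mu * prox).backward()
            opt.step()
    return model.state_dict()
```

In practice the inner argmin is only solved approximately, as here; FedProx's analysis explicitly allows such inexact local solutions.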


FedCurv

FedCurv carries the EWC algorithm from continual learning over to the federated setting. For the $t$-th task on the $k$-th node, the loss function is defined as:

$L_{t,k}(\theta) = L'_k(\theta) + \lambda\sum_{j\in K\setminus\{k\}}(\theta-\hat{\theta}_{t-1,j})^T F_{t-1,j}\,(\theta-\hat{\theta}_{t-1,j})$


Server executes:

  initialize $\theta_0$ and $F_0$
  for each round $t = 1, 2, \dots$ do
    $m ← \max(C · K, 1)$ // a random fraction $C$ of the fixed set of $K$ clients is selected
    $S_t ←$ (random set of $m$ clients)
    $\theta_t ← \sum_{k=1}^K \frac{n_k}{n}\theta_{t-1}^k$
    for each client $k ∈ S_t$ in parallel do
      $\theta_{t+1}^k, F_{t+1}^k ← ClientUpdate(k, \theta_t, F_t)$

**ClientUpdate($k$, $\theta$, $F$):**

  $\mathcal{B} ←$ (split $P_k$ into batches of size $B$)
  for each local epoch $i$ from $1$ to $E$ do
    for batch $b ∈ \mathcal{B}$ do
      $\theta ← \arg\min_\theta\left(L'_k(\theta) + \lambda\sum_{j\in K\setminus\{k\}}(\theta-\hat{\theta}_{t-1,j})^T F_{t-1,j}\,(\theta-\hat{\theta}_{t-1,j})\right)$
  return $\theta_{t+1}^k, F_{t+1}^k$ to server
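A sketch of the two pieces a FedCurv client needs, assuming the model's parameters are flattened into a single vector `theta` and that `others` holds the $(\hat{\theta}_{t-1,j}, F_{t-1,j})$ pairs received from the server (all names are illustrative, not from the paper):

```python
import torch

def fedcurv_penalty(theta, others, lam):
    """lambda * sum_j (theta - theta_hat_j)^T diag(F_j) (theta - theta_hat_j)."""
    total = theta.new_zeros(())
    for theta_hat_j, fisher_j in others:
        diff = theta - theta_hat_j
        total = total + (fisher_j * diff * diff).sum()  # diagonal quadratic form
    return lam * total

def diagonal_fisher(model, loader, loss_fn):
    """Estimate diag(F) as the average element-wise squared gradient."""
    params = [p for p in model.parameters() if p.requires_grad]
    fisher = [torch.zeros_like(p) for p in params]
    batches = 0
    for x, y in loader:
        model.zero_grad()
        loss_fn(model(x), y).backward()
        for f, p in zip(fisher, params):
            f += p.grad.detach() ** 2
        batches += 1
    return [f / max(batches, 1) for f in fisher]
```

During local training the client minimizes `task_loss + fedcurv_penalty(theta, others, lam)`, and after its final epoch it sends back both its weights and its `diagonal_fisher` estimate; returning $F$ alongside $\theta$ roughly doubles the uplink payload, which is what FedCL later saves.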


FedCL

FedCL improves on FedCurv by moving the computation of the per-parameter importance weights $F_t^k$ to the server, where they are estimated on a proxy dataset; this cuts communication cost by at least 50%. Its drawback is that part of a dataset must be kept on the server to compute $F_t^k$, which limits its use in privacy- and transmission-sensitive scenarios.

$L_{t,k}(\theta) = L'_k(\theta) + \lambda\sum_{j\in K\setminus\{k\}}F_{t,k}(\theta-\hat{\theta}_{t-1,j})^2$, where $F=\frac{1}{|D_{proxy}|}\sum_{(x_k,y_k)\in D_{proxy}}\left(\frac{\partial L(\theta_g(x_k),y_k)}{\partial\theta_g}\right)^2$

Server executes:

  initialize $\theta_0$
  for each round $t = 1, 2, \dots$ do
    $F ← \frac{1}{|D_{proxy}|}\sum_{(x_k,y_k)\in D_{proxy}}\left(\frac{\partial L(\theta_g(x_k),y_k)}{\partial\theta_g}\right)^2$
    $m ← \max(C · K, 1)$ // a random fraction $C$ of the fixed set of $K$ clients is selected
    $S_t ←$ (random set of $m$ clients)
    for each client $k ∈ S_t$ in parallel do
      if $t \bmod N = 0$ then
        $\theta_{t+1}^k ← ClientUpdate(\theta_t, F)$
      else
        $\theta_{t+1}^k ← ClientUpdate(\theta_t)$
    $\theta_{t+1} ← \sum_{k=1}^K \frac{n_k}{n}\theta_{t+1}^k$

**ClientUpdate($\theta$, $F$):**

  if $F$ is not received then
    $F ←$ identity matrix
  $\mathcal{B} ←$ (split $P_k$ into batches of size $B$)
  for each local epoch $i$ from $1$ to $E$ do
    for batch $b ∈ \mathcal{B}$ do
      $\theta ← \arg\min_\theta\left(L'_k(\theta) + \lambda\sum_{j\in K\setminus\{k\}}F\,(\theta-\hat{\theta}_{t-1,j})^2\right)$
  return $\theta_{t+1}^k$ to server
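A sketch of the server side of FedCL, reusing the hypothetical `diagonal_fisher` helper from the FedCurv sketch but evaluating it on the server's proxy data; the client penalty is shown for a single reference model $\hat{\theta}$, with the sum over the other clients applying the same quadratic term per $j$ (all names are illustrative):

```python
import torch

def server_broadcast(global_model, proxy_loader, loss_fn, t, N):
    """Round-t broadcast: weights always, importance F only every N rounds."""
    state = global_model.state_dict()
    if t % N == 0:
        # F is estimated on the server's proxy dataset (diagonal_fisher as in
        # the FedCurv sketch above), so clients never upload F themselves.
        fisher = diagonal_fisher(global_model, proxy_loader, loss_fn)
        return state, fisher
    return state, None  # cheap payload: weights only

def fedcl_penalty(theta, theta_hat, fisher, lam):
    """lambda * F * (theta - theta_hat)^2, element-wise over flattened params."""
    if fisher is None:
        fisher = torch.ones_like(theta)  # identity fallback from the pseudocode
    diff = theta - theta_hat
    return lam * (fisher * diff * diff).sum()
```

Since $F$ has the same size as $\theta$, dropping it from the uplink (and from most downlinks) is exactly where the claimed communication savings of at least 50% come from.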