Split Federated Learning 回顾

我们假设 SFL 系统是一个 central server 和 MM 个 client 组成,记为 m{1,2,,m}m\in \{1,2,\ldots, m\}

在每一个 client 上,整个模型被划分成 client-side model fc(;θc)f_c(\cdot;\boldsymbol{\theta}_c) 和 server-side model fs(;θs)f_s(\cdot;\boldsymbol{\theta}_s) (毕竟背景是 split federated learning 嘛).这里的 θc/s\boldsymbol{\theta}_{c/s} 表示 client 和 server 的参数.整个模型的参数记为 θ=θc+θs\boldsymbol{\theta}=\boldsymbol{\theta}_c+\boldsymbol{\theta}_s

在 SFL 的体系下,我们从 client 本地的数据集 Dm\mathcal{D}_m 里采样 data point ξm=(xm,ym)Dm\xi_m=(\boldsymbol{x}_m, y_m) \in \mathcal{D}_m(这里的 xm\boldsymbol{x}_m 标记为粗体是表示 x\boldsymbol{x} 是张量).

接着执行 Forward Pass.首先 client-side model 计算 activation

zm=fc(xm;θc)\boldsymbol{z}_m=f_c(\boldsymbol{x}_m; \boldsymbol{\theta}_c)

然后 client 与 server 之间进行通信,传递 activation.后续的前向传播由 server-side model 完成:

y^m=fs(zm;θs)\hat{y}_m=f_s(\boldsymbol{z}_m; \boldsymbol{\theta}_s)

而后开始反向传播。先计算 loss

(θ;ξm):=(y^m,ym)\ell(\boldsymbol{\theta}; \xi_m) := \ell(\hat{y}_m, y_m)

对于 client-side model,我们定义其 local objective function 为

Lm(θ):=EξmDm[(θ;ξm)]\mathcal{L}_m(\boldsymbol{\theta}):=\mathbb{E}_{\xi_m \sim \mathcal{D}_m}\Big[\ell(\boldsymbol{\theta}; \xi_m)\Big]

所以全局需要优化的目标函数是

L(θ):=1Mm=1MLm(θ)\mathcal{L}(\boldsymbol{\theta}):=\frac{1}{M}\sum_{m=1}^M \mathcal{L}_m(\boldsymbol{\theta})


【理论推导】为什么 client 和 server 可以使用不同的优化方法

根据上文提到的设定,loss 的本质是

(fs(fc(xm;θc);θs),ym)\ell\bigg(f_s\Big(f_c(\boldsymbol{x}_m;\boldsymbol{\theta_c} ); \boldsymbol{\theta}_s\Big), y_m\bigg)

但这里涉及到 server-side model 与 client-side model 耦合的问题,直接对 client/server-side 采取不同的优化方法的话并不会得到理论支撑.我们用 variable lifting 的方法,把 client-side forward pass 的 activation 视为条件,在此基础上 optimize server-side model.这一点是 make sense 的,因为在 backprop 传导到 client-side 之前,都可以视为 barely training server-side model.

min(fs(zm;θs),ym)s.t.zm=fc(xm;θc)\begin{array}{rll} \min &\ell(f_s(\boldsymbol{z}_m; \boldsymbol{\theta}_s), y_m) \\ \text{s.t.} &\boldsymbol{z}_m=f_c(\boldsymbol{x}_m; \boldsymbol{\theta}_c) \end{array}

我们用拉格朗日乘数法转化为无约束优化问题:

min(fs(zm;θs),ym)+λ(fc(xm;θc)zm)\min \ell(f_s(\boldsymbol{z}_m; \boldsymbol{\theta}_s), y_m) + \boldsymbol{\lambda}^\top (f_c(\boldsymbol{x}_m; \boldsymbol{\theta}_c) - \boldsymbol{z}_m)