Split Federated Learning 回顾
我们假设 SFL 系统是一个 central server 和 M 个 client 组成,记为 m∈{1,2,…,m}
在每一个 client 上,整个模型被划分成 client-side model fc(⋅;θc) 和 server-side model fs(⋅;θs) (毕竟背景是 split federated learning 嘛).这里的 θc/s 表示 client 和 server 的参数.整个模型的参数记为 θ=θc+θs
在 SFL 的体系下,我们从 client 本地的数据集 Dm 里采样 data point ξm=(xm,ym)∈Dm(这里的 xm 标记为粗体是表示 x 是张量).
接着执行 Forward Pass.首先 client-side model 计算 activation
zm=fc(xm;θc)
然后 client 与 server 之间进行通信,传递 activation.后续的前向传播由 server-side model 完成:
y^m=fs(zm;θs)
而后开始反向传播。先计算 loss
ℓ(θ;ξm):=ℓ(y^m,ym)
对于 client-side model,我们定义其 local objective function 为
Lm(θ):=Eξm∼Dm[ℓ(θ;ξm)]
所以全局需要优化的目标函数是
L(θ):=M1m=1∑MLm(θ)
【理论推导】为什么 client 和 server 可以使用不同的优化方法
根据上文提到的设定,loss 的本质是
ℓ(fs(fc(xm;θc);θs),ym)
但这里涉及到 server-side model 与 client-side model 耦合的问题,直接对 client/server-side 采取不同的优化方法的话并不会得到理论支撑.我们用 variable lifting 的方法,把 client-side forward pass 的 activation 视为条件,在此基础上 optimize server-side model.这一点是 make sense 的,因为在 backprop 传导到 client-side 之前,都可以视为 barely training server-side model.
mins.t.ℓ(fs(zm;θs),ym)zm=fc(xm;θc)
我们用拉格朗日乘数法转化为无约束优化问题:
minℓ(fs(zm;θs),ym)+λ⊤(fc(xm;θc)−zm)