Federated Learning

[MLSys 2020] FedProx - (3) 본문

Federated Learning/Papers

[MLSys 2020] FedProx - (3)

pseudope 2022. 8. 1. 02:30
728x90

논문 제목: Federated Optimization in Heterogeneous Networks

출처: https://arxiv.org/abs/1812.06127

 

 이전 포스트에서, 우리는 non-convex한 Fk에 관한 FedProx의 convergence를 증명하였습니다. (이전 글 보기) 이번 포스트에서는 FedProx의 convergence rate를 구해보고, convex case와 γ가 device 별로, 또 round 별로 달라지는 case에 관해서도 살펴볼 것입니다.

 

6. non-convex case의 convergence rate

 

 Theorem 6 (Convergence rate: FedProx)

  ϵ>0이 주어졌을 때, B>Bϵ, μ, γ, 그리고 K에 대해서 Theorem 4의 assumption들이 FedProx algorithm의 매 iteration마다 성립한다고 가정하자. 그리고 f(w0)f=:Δ라고 정의하자. 이때, FedProx algorithm의 T:=O(Δρϵ)회 iteration이 진행되었다면, 1TT1t=0E[||f(wt)||2]ϵ이다.

 

 Theorem 6에 따르면, ϵ을 작게 잡을수록 수렴하는 데에 더 많은 iteration이 필요하고, 대신 더 정확한 결과가 도출된다는 것을 알 수 있습니다. 이는, overfitting이 바람직한 상황은 아니라는 것을 감안했을 때, 저자가 앞서 이야기한 "굳이 tight하게 ϵ을 잡을 필요 없다"는 주장과 일맥상통하는 부분입니다. 이를 증명하기 위해서는 추가적인 assumption이 한 가지 필요한데, Ek[||Fk(w)f(w)||2]σ2를 만족하는 σ가 존재한다, 즉, variation이 bounded되어 있다는 것입니다. 이는 앞서 살펴본 Assumption 1과 마찬가지로 크게 문제가 되는 가정은 아닙니다.

 

Corollary 10 (Bounded variance equivalance)

 Assumption 1이 성립하고, Ek[||Fk(w)f(w)||2]σ2를 만족하는 σ가 존재할 때, Bϵ1+σ2ϵ이다.

 

Proof

 Var[X]=E[X2]E[X]2임을 이용하면  Ek[||Fk(w)||2]||f(w)||2=Ek[||Fk(w)f(w)||2]σ2임을 알 수 있다.  이를 다시 정리하면 Ek[||Fk(w)||2]σ2+||f(w)||2로 나타낼 수 있으며, B의 정의에 따라서 Bϵ=Ek[||Fk(w)||2]||f(w)||2σ2+||f(w)||2||f(w)||2=1+σ2||f(w)||21+σ2ϵ이다.

 

Proof of Theorem 6

 Corollary 10의 결과를 Theorem 4ρ에 대입하면 알 수 있다.

 

7. convex case의 convergence

 

Corollary 7 (Convergence: Convex case)

 Theorem 4의 assumption이 모두 성립하고, Fk()들이 모두 convex하다고 가정하자. 더 나아가, 모든 kt에 대해서 γtk=0을 만족한다고 하자. (이는 곧 모든 local update가 exact하게 이루어짐을 뜻한다.) 이때, 만약 1B0.5K라면, 최적의 μμ6LB2이며, 따라서 ρ124LB이다.

 

Proof

 ρ=(1μγBμB(1+γ)2ˉμKLB(1+γ)ˉμμL(1+γ)2B22ˉμ2LB2(1+γ)2ˉμ2K(22K+2))임을 알고 있고, 가정에 따라 L=0, ˉμ=μ, γ=0이다. 따라서 식을 정리하면 ρ=(1μB2μKLBμ2LB22μ2LB2μ2K(22K+2))이다. 또한, B0.5K이므로 ρ(22)2μLB+(1+2)LB22μ2이다. 따라서,

 

ESt[f(wt+1)]f(wt)ρ||f(wt)||2f(wt)222μ||f(wt)||2+LB+(1+2)LB22μ2||f(wt)||2

 

이며, 더 나아가 \mu \approx 6LB^2로 잡을 경우, \mathbb{E}_{S_t} [f(w^{t+1})] \lessapprox f(w^t) - \frac {1} {24LB^2} ||\nabla f(w^t)||^2이다. \square

 

 convex case의 convergence는 최소 T := O(\frac {LB^2 \Delta} {\epsilon})회의 iteration 후에 진행되며, 이는 \text{Theorem 6}\text{Corollary 7}으로부터 자명합니다. 또한, bounded variance assumption에 의하여 O (\frac {L \Delta} {\epsilon} + \frac {L \Delta \sigma ^2} {\epsilon ^2})로도 표기할 수 있습니다. 이는 SGD의 시간복잡도와 정확히 일치하며, 저자들은 이러한 점을 근거로 해당 논문이 의미있다고 주장하고 있습니다. 물론, 이것이 FedProx가 distributed SGD보다 우위에 있다는 것을 의미하지는 않습니다. 그렇다기보다는, 연합학습 체계 중에서 convergence를 엄밀하게 증명해보인 첫 논문이라는 점, 그리고 그 속도가 SGD에 준한다는 점에서 해당 논문이 의미가 있다고 생각합니다.

 또한, \text{Corollary 7}을 통해서 알 수 있는 것이 한 가지 더 있는데, 비록 convex case로 한정지었지만, \mu를 어떻게 잡는지가 model의 convergence에 큰 영향을 준다는 것입니다. (물론, non-convex case에서도 똑같이 증명하기는 어렵겠지만 비슷한 결과가 나오리라 예상할 수 있고, 이 부분은 empirical test에서 확인하도록 하겠습니다.) 저는 처음에 이 부분을 명확하게 이해하기 어려웠는데, 분명 논문 초반부에서부터 계속해서 \gamma_k^t의 중요성을 언급하였음에도, 저자들은 \gamma가 아닌 \mu를 hyperparameter로 잡았기 때문입니다. (심지어, 이 \mu는 proximal term에서부터 정의도 되지 않은 채 쓰이고 있습니다.) 그렇다면, \gamma는 어디에서 사용되는 것이고, 또 hyperparameter가 아니라면 어디에서부터 주어지는 것일까요? 이에 관해서는 다음 포스트에서 ablation study와 함께 알아보도록 하고, 이에 관한 이야기를 하기 위해서 \gamma\gamma_k^t로 generalize하는 것으로 이번 포스트를 마무리짓겠습니다.

 

8. \gamma가 달라지는 경우의 convergence

 

\text{Corollary 9} (Convergence: variable \gamma's)

 local function F_k들이 모두 non-convex하고, L-Lipschitz smooth하다고 가정하자. 그리고 L_- > 0가 존재하여 \nabla^2 F_k \succeq -L_- \textbf{I}, \bar{\mu} := \mu - L_{-} > 0를 만족한다고 하자. w^t가 non-stationary한 solution이고 각 device의 local functions F_kB-dissimilar할 때 (즉, B(w^t) \leq B일 때), \mu, K, \gamma_k^t\rho^t := \left( \frac {1} {\mu} - \frac {\gamma^t B} {\mu} - \frac {B(1 + \gamma^t) \sqrt{2}} {\bar{\mu} \sqrt{K}} - \frac {LB(1 + \gamma^t)} {\bar{\mu} \mu} - \frac {L(1 + \gamma^t)^2 B^2} {2 \bar{\mu}^2} - \frac {LB^2 (1 + \gamma^t)^2} {\bar{\mu}^2 K} (2 \sqrt{2K} + 2) \right) > 0를 만족한다면, FedProx의 t번째 iteration에서 global objective function f의 expected decrease는 \mathbb{E}_{S_t} [f(w^{t + 1})] \leq f(w^t) - \rho^t ||\nabla f(w^t)||^2를 만족한다. (여기에서, S_tt번째 iteration에 참여하는 device들의 집합이고, \gamma^t := max_{k \in S_t} \; \gamma_k^t이다.)

 

\text{Proof}

 \gamma^t의 정의와 \text{Theorem 4}로부터 자명하다. \square

 

 이번 포스트에서는 FedProx의 convergence rate에 대한 증명과 convex case에서의 convergence 증명을 살펴보고, \gamma가 달라지는 case로 증명을 확장해보았습니다. 그리고 이 과정에서 hyperparameter \mu의 중요성도 확인해보았습니다. 다음 포스트에서는 도대체 \gamma가 무엇인지 (알아보자고 몇 번째 말하고 있는 것 같지만) 알아보고, 각종 ablation test를 살펴보겠습니다. (다음 글 보기)

'Federated Learning > Papers' 카테고리의 다른 글

[ICLR 2020] Convergence of FedAvg - (1)  (0) 2022.08.15
[MLSys 2020] FedProx - (4)  (0) 2022.08.13
[MLSys 2020] FedProx - (2)  (0) 2022.07.30
[MLSys 2020] FedProx - (1)  (0) 2022.07.25
[AISTATS 2017] FedSGD, FedAvg - (2)  (0) 2022.07.16
Comments