Federated Learning

[MLSys 2020] FedProx - (1) 본문

Federated Learning/Papers

[MLSys 2020] FedProx - (1)

pseudope 2022. 7. 25. 03:00
728x90

논문 제목: Federated Optimization in Heterogeneous Networks

출처: https://arxiv.org/abs/1812.06127

 

 지난 포스트에서, 우리는 연합학습의 시초라고 이야기할 수 있는 FedSGD, FedAvg에 관하여 알아보았습니다. (이전 글 보기) 처음 제안하는 알고리즘이기 때문에 다소 미흡한 부분도 존재한다는 점을 포스트 말미에 잠시 언급하였는데, 두 번째 paper review에서는 그중 한 가지를 해결하고자 노력한 FedProx 알고리즘에 관하여 알아보려고 합니다.

 

1. 연구 배경

 

 FedSGD, FedAvg의 ablation study를 자세히 보면, model이 어느 정도 수렴하는 것 같다면, 마치 learning rate를 decay하는 것처럼 round 당 local computation 횟수를 ($E$를 줄이거나 $B$를 키우면서) 감소시키는 것이 의미있을 수 있다는 이야기를 하고 있습니다. (이전 포스트에서 이에 관하여 언급하였습니다.) 여기에서 한발짝 더 나아가서, FedProx의 저자들은 각 iteration마다, 그리고 각 device마다 local epoch 횟수를 조절하는 것이 필요하다고 주장하고 있으며, 이에 대한 근거로 heterogeneous한 system 구성을 들고 있습니다. (여기에서 heterogeneous하다는 것은, data의 구성이 device 별로 non-IID하다는 이야기가 아니라, device의 사양이 상이하다는 이야기입니다.) 즉, device마다 동일한 local epoch를 수행한다면, 느리게 학습이 진행되는 device를 끝까지 기다리거나, 그러한 device를 버리고 제한 시간 안에 학습이 완료된 device로부터만 학습 결과를 받거나, 혹은 아직 학습이 안 끝난 device의 경우 진행된 부분까지만 그 결과를 받아야 하는데, 셋 모두 이상적인 상황이 아니라는 것이 저자들의 생각입니다. 저자들은 해당 논문을 통해 왜 이러한 주장이 정당화될 수 있는지, 또 어떻게 이러한 aggregation 방식의 convergence를 보장할 수 있는지에 관하여 언급하고 있습니다.

 

2. Proximal Term / $\gamma$-inexact solution / $\gamma_k^t$-inexact solution

 

 저자들은 client 단의 학습 과정에서 local function $F_k(\cdot)$ 대신 proximal term $\frac {\mu} {2}||w - w^t||^2$이 추가된 objective function $h_k(w; w^t) = F_k(w) + \frac {\mu} {2}||w - w^t||^2$를 minimize할 것을 주장합니다. 이 proximal term은 local model $w^t$가 global model $w$로부터 과도하게 멀어지는 것을 자동적으로 막아줍니다. (proximal term의 기능이 한 가지 더 있는데, 이는 다음 포스트에서 확인할 수 있습니다.) 그리고 이러한 함수 $h_k$에 대해서 다음을 정의합니다.

 

 $\text{Definition 2}$ ($\gamma_k^t$-inexact solution)

 함수 $h(w; w_t) = F_k(w) + \frac {\mu} {2} ||w - w_t||^2$와 $\gamma \in [0, 1]$에 대하여, 만약 $w^*$가 $||\nabla h(w^*; w_t)|| \leq \gamma_k^t ||\nabla h(w_t; w_t)||$를 만족한다면, 이러한 $w^*$를 $\gamma_k^t$-inexact solution이라고 정의한다. 이때, $h_k$의 정의 상  $\nabla h(w; w_t) = \nabla F(w) + \mu (w - w_t)$이며, $\gamma_k^t$가 $0$에 가까울수록 accuracy가 더 높다는 것을 의미한다.

 

 다음은 $\gamma$가 unifrom할 경우의 정의이며, 저자들은 해당 case에서의 내용을 우선 전개한 뒤, $\gamma$가 variant한 경우(즉, 위의 정의에 해당하는 경우)로 그 주장을 확장하는 방식으로 논문을 작성하였습니다.

 

 $\text{Definition 1}$ ($\gamma$-inexact solution)

 함수 $h(w; w_0) = F_k(w) + \frac {\mu} {2} ||w - w_0||^2$와 $\gamma \in [0, 1]$에 대하여, 만약 $w^*$가 $||\nabla h(w^*; w_0)|| \leq \gamma ||\nabla h(w_0; w_0)||$를 만족한다면, 이러한 $w^*$를 $\gamma$-inexact solution이라고 정의한다. 이때, $h$의 정의 상  $\nabla h(w; w_0) = \nabla F(w) + \mu (w - w_0)$이며, $\gamma$가 $0$에 가까울수록 accuracy가 더 높다는 것을 의미한다.

 

3. FedProx

 저자들이 제안하는 FedProx 알고리즘은 FedAvg에서 아주 조금 수정된 것입니다. 따라서 큰 틀에서는 두 알고리즘이 차이를 보이지 않으며, 유사한 방식으로 작동됩니다. 하지만 $\text{Input}$ 부분을 보면, FedAvg의 $\eta$와 $E$ 대신 FedProx에는 $\mu$와 $\gamma$가 있는 것을 알 수 있는데, 앞서 살펴본 정의에 등장하는 그 $\mu$와 $\gamma$가 맞으며, 우리는 이 두 가지가 FedProx 알고리즘에서 어떠한 역할을 하는지 확인할 필요가 있습니다. 또한, FedAvg에서는 학습에 참여하는 모든 device들이 정해진 epoch 수 만큼 학습을 마친 후, 그 결과를 서버로 반환하여야 비로소 model aggregation이 진행되는데, FedProx에서는 각 device 별로, round 별로 정해진 $\gamma_k^t$ 값에 따라 $\gamma_k^t$-inexact solution인 $w_k^{t + 1}$을 찾으면 학습 결과를 서버로 반환합니다. 이 점이 결정적으로 두 알고리즘이 차이를 보이는 부분입니다.

 

 이어지는 포스트에서, 우리는 FedProx의 Convergence 증명 과정을 자세하게 확인해볼 것입니다. 그리고 그 과정에서  $\mu$와 $\gamma$의 정체에 관해서도 알아볼 것입니다. (다음 글 보기)

'Federated Learning > Papers' 카테고리의 다른 글

[MLSys 2020] FedProx - (4)  (0) 2022.08.13
[MLSys 2020] FedProx - (3)  (0) 2022.08.01
[MLSys 2020] FedProx - (2)  (0) 2022.07.30
[AISTATS 2017] FedSGD, FedAvg - (2)  (0) 2022.07.16
[AISTATS 2017] FedSGD, FedAvg - (1)  (0) 2022.07.11
Comments