Federated Learning

[AISTATS 2017] FedSGD, FedAvg - (1) 본문

Federated Learning/Papers

[AISTATS 2017] FedSGD, FedAvg - (1)

pseudope 2022. 7. 11. 23:00
728x90

논문 제목: Communication-Efficient Learning of Deep Networks from Decentralized Data

출처: https://arxiv.org/abs/1602.05629

 

 첫 paper review 포스트의 대상은 연합학습을 처음으로 언급한 논문인 『Communication-Efficient Learning of Deep Networks from Decentralized Data』입니다. 해당 논문에서 언급된 FedSGD, FedAvg 알고리즘은 지금도 연합학습 시스템의 baseline 알고리즘으로 계속해서 각종 논문에 등장하고 있습니다. 비록 나온 지 조금 지났지만, 근본적인 논문이기 때문에 영향력이 상당하여 안 짚고 넘어갈 수 없었습니다. (2017년에 논문이 처음 등장했다니... 참으로 신생 분야이죠...?) 참고로, 해당 paper는 convergence analysis를 다루지 않습니다. 엄밀한 증명이 필요하신 분은 이 글을 확인해보시기 바랍니다.

 

0. Notation

 

 해당 논문뿐만 아니라 연합학습 분야의 논문 전반이 가진 특징인데, 정말 많은 notation이 등장합니다. 그리고 아쉽게도 논문 별로 notation이 통일되지 않은 경우가 많아 논문을 읽는 데에 어려움이 발생하기도 합니다. 해당 논문에서는 다음과 같은 notation을 사용하니 참고 바랍니다. 이후 다른 paper review에서는 해당 논문의 notation을 우선으로 할지, 이번 포스트의 notation을 우선으로 할지 아직 고민 중입니다. 혹시 좋은 의견이 있으시다면 댓글로 남겨주시기 바랍니다.

 

$K$: 전체 client의 수

$C$: 매 round(1회의 model update)마다 참여할 client의 fraction ($0 \leq C \leq 1$)

$E$: 매 round마다 local에서 학습할 epoch 수

$B$: 매 epoch마다 학습에 사용할 local dataset의 mini-batch size

$f(w)$: 우리가 구해야 할 model $w$의 loss

$f_{i}(w)$: 각 device 안의 model $w$의 loss

$\mathcal{P}_k$: client $k$가 가지고 있는 data의 indices, $n_k := |\mathcal{P}_k|$

(쉽게 말해서, 모든 device의 data를 합친 결과 총 $n$개의 data가 있을 때, $n = \sum_{k = 1}^{K} n_k$)

$u_k := \frac {En_k} {B}$ (한 round 당 이루어지는 local model의 update 횟수) 

$u := \frac {En} {KB}$ ($u_k$와 동일한 의미이지만, 특정 client $k$의 관점이 아닌, 일반화된 표현)

$\eta$: 고정된 learning rate

$g_k = \nabla F_k(w_t)$: client $k$의 $t$기 model인 $w_t$에 대한 average local gradient

 

1. FedSGD

 

 저자들은 현존하는 많은 optimizer들이 결국 SGD에서 시작된 것이므로, SGD에 기반한 연합학습 체계 구축은 자연스러운 수순이었다고 이야기합니다. 사실 FedSGD는 잠시 후에 살펴볼 FedAvg의 특수한 케이스이긴 하지만, FedSGD가 비교적 이해하기 쉬우므로 FedSGD를 우선적으로 확인하고 나서 FedAvg를 확인해보도록 하겠습니다.

 

 FedSGD는 매 round마다 전체 client 중 $C$ 만큼의 비율에 해당하는 client가 가진 모든 data를 한 번에 학습합니다. 즉, 참여하는 client들이 가진 dataset full-batch의 GD를 1회 계산하며, 위의 notation을 빌리면 $E = 1$, $B = \infty$인 case를 이야기하는 셈입니다. ($C$는 hyperparameter로 남는데, FedAvg 부분에서 이에 관하여 자세히 언급하겠습니다.) 일단 각 client $k$가 저마다의 $g_k = \nabla F_k(w_t)$를 구하는 과정이 끝나면, $\sum_{k = 1}^{K} \frac {n_k} {n} g_k = \nabla f(w_t)$이므로, 각 client로부터 학습 결과를 수집한 중앙 서버는 $w_{t + 1} \leftarrow w_t - \eta \sum_{k = 1}^{K} \frac {n_k} {n} g_k$와 같이 model update를 진행합니다. 다르게 표현하면, 임의의 $k \in K$에 대하여 $w_{t + 1} ^ k \leftarrow w_t - \eta g_k$로 local에서 update를 진행한 후, update의 결과인 $w_{t + 1} ^ k$ 들을 가지고 다시 한 번 global에서 $w_{t + 1} \leftarrow \sum_{k = 1}^{K} \frac {n_k} {n} w_{t + 1}^{k}$로 update를 진행한다고도 볼 수 있습니다.

 

2. FedAvg

 

 FedSGD는 매우 간단한 아이디어이지만 크게 흠잡을 것 없는 알고리즘입니다. 다만, 한 가지 아쉬운 점이 있다면, communication cost가 중요한 연합학습 체계에서 한 round에 1번의 local update만 반영한다는 것은 cost 낭비라는 점이죠. (FedSGD의 경우 이외에도 privacy preserving에 대한 이슈가 존재합니다만, 이러한 지적이 나온 시점이 해당 논문 발표 이후이기는 합니다.) 저자들은 이를 해결하기 위하여 다음과 같은 해결책을 제시합니다.

$w_t ^ k \leftarrow w_t - \eta g_k$로 local update를 여러 번 진행한 후 global update를 수행하면 되지 않을까?

 

 FedAvg 알고리즘은 오른쪽 pseudo code와 같이 구성되어 있습니다. $m \leftarrow max(C \cdot K, 1)$에서 $max(\cdot)$ function이 사용된 이유는, $0 \leq C \leq 1$이므로 $C = 0$인 경우가 존재할 수 있기 때문입니다. $0$개의 client가 학습을 진행한다는 것은 결국 이전의 state와 동일하다는 의미이기 때문에, 이를 막기 위해 최소한 $1$개의 client를 사용한다는 것을 보장하는 장치인 셈이죠.

 $C$라는 hyperparameter가 도입된 이유는 efficiency 때문입니다. cost 관점에서도 그렇고, experiments를 볼 때 일정 수 이상의 client가 학습에 관여했을 때 한계효용이 감소한다는 점에서도 적절한 $C$를 결정하는 것은 중요합니다.

 또 다른 hyperparameter $B$는 각 client 별 local dataset에 대한 mini-batch의 size이며, 1 batch에 관한 학습이 끝날 때마다 local update가 진행됩니다. 그리고 이 작업이 총 $E$ epoch만큼 반복된 뒤에, 학습 결과는 global update를 위하여 중앙 서버로 전송됩니다.

 

 이어지는 포스트에서, 우리는 해당 논문의 experiments를 확인한 뒤, 의의 및 한계에 관하여 알아볼 것입니다. (다음 글 보기)

'Federated Learning > Papers' 카테고리의 다른 글

[MLSys 2020] FedProx - (4)  (0) 2022.08.13
[MLSys 2020] FedProx - (3)  (0) 2022.08.01
[MLSys 2020] FedProx - (2)  (0) 2022.07.30
[MLSys 2020] FedProx - (1)  (0) 2022.07.25
[AISTATS 2017] FedSGD, FedAvg - (2)  (0) 2022.07.16
Comments