Federated Learning

[ICLR 2020] Convergence of FedAvg - (1) 본문

Federated Learning/Papers

[ICLR 2020] Convergence of FedAvg - (1)

pseudope 2022. 8. 15. 00:00
728x90

논문 제목: On the Convergence of FedAvg on Non-IID Data

출처: https://arxiv.org/abs/1907.02189

 

 지난 포스트에서, 우리는 model aggregation method의 convergence를 엄밀하게 증명한 첫 논문인 FedProx에 관하여 알아보았습니다. (이전 글 보기) heterogeneous한 구성에서 FedProx가 잘 작동한다는 것은 알 수 있었지만, FedAvg의 convergence에 관한 명확한 이야기를 들어볼 수 없었다는 점은 조금 아쉽기도 했습니다. 이번에 살펴볼 논문이 이에 대한 대답을 줄 수 있을 것 같은데, 한 번 확인해보도록 하겠습니다. 해당 논문은 (FedProx와 더불어) convergence analysis의 표본이 되는 논문으로, 관련 분야를 공부하신다면 꼭 유심히 살펴보시는 것을 권장드립니다.

 

0. Notation

 

 ※ 주의: FedAvg 논문에서는 $E$를 epoch 수로 사용하였는데, 해당 논문에서는 $E$를 1회의 global update를 위해 요구되는 local update의 수로 사용합니다. 즉, mini-batch 단위입니다. "$t$기"라는 표현도 mini-batch 단위로 계산합니다.

 

 $\xi_t^k$: $t$기에 $k^{th}$ device의 local dataset으로부터 uniformly random하게 추출된 dataset. (즉, mini-batch 1개)

 $w_t^k$: $t$기에 $k^{th}$ device가 가지고 있는 model parameter.

 $\eta_t$: $t$기의 learning rate. (혹은 step size)

 $v_{t+1}^k := w_t^k - \eta_t \nabla F_k(w_t^k, \xi_t^k)$ (즉, 1회의 local update 결과)

 $\mathcal{I}_E$: global synchronization step(즉, model aggregation round)의 집합. 다시 말해, $\mathcal{I}_E := \{nE\; | \; n = 1, 2, \dots \}$.

 $N$: device의 총 갯수.

 $K$: 해당 round에서 학습에 참여하는 client 수의 threshold. ($1 \leq K \leq N$)

 $\mathcal{S}_t$: $t$기에 참여하는 client들의 집합. (따라서, $|\mathcal{S}_t| = K$)

 $T$: 총 local update 횟수. (따라서, $\frac {T} {E}$는 communication round의 수)

 $n_k$: $k^{th}$ device가 지닌 data의 갯수.

 $p_k := \frac {n_k} {\sum_{k = 1}^N n_k} \geq 0$ (따라서, $\sum_{k = 1}^{N} p_k = 1$)

 $x_{k, j}$: $k^{th}$ device의 $j^{th}$ data.

 $\Gamma := F^* - \sum_{k=1}^N p_k F_k^*$ (Stochastic Heterogeniety를 나타내는 지표로, Non-IID case에서는 $\Gamma > 0$)

 

1. FedAvg 복습

 

 우선, 해당 논문의 notation에 따라서 FedAvg 알고리즘을 다시 정리해보겠습니다. 우리가 원하는 것은 global objective function $F$를 minimize하는 것, 즉, $\min_w \; \{ F(w) := \sum_{k = 1}^N p_k F_k(w) \}$입니다. 여기에서 $F_k$는 각 device마다 가지고 있는 local objective function이며, $F_k(w) := \frac {1} {n_k} \sum_{j=1}^{n_k} \ell (w; x_{k, j})$로 정의됩니다. (단, $\ell$은 각 device 고유의 loss function.)

 이러한 구성에서, $t$기의 학습을 마친 후, $t + 1$기로 넘어가기 위해 model을 update하는 과정을 살펴봅시다. 만약 $t + 1$기가 communication round라면, 즉, $t+1 \in \mathcal{I}_E$라면, $w_{t+1}^k$은 각 local model의 update 결과인 $v_{t+1}$의 aggregation이 될 것이고, 만약 그렇지 않다면 $w_{t+1}^k$가 곧 $v_{t+1}^k$일 것입니다. 이를 수식으로 정리하면 다음과 같습니다.

$$v_{t+1}^k = w_t^k - \eta_t \nabla F_k(w_t^k, \xi_t^k)$$

\begin{equation} w_{t + 1}^k = \begin{cases} v_{t+1}^k & \text{if $t+1 \notin \mathcal{I}_E$}\\ \sum_{k=1}^N p_k v_{t+1}^k & \text{if $t+1 \in \mathcal{I}_E$} \end{cases} \end{equation}

 다만, 위 식은 모든 device가 학습에 참여한다는 가정 하에 성립하는 것이며, 우리는 FedAvg에서 straggler의 학습 결과가 버려진다는 것을 알고 있습니다. 즉, 위 식을 그대로 따른다면 $w_{t+E} \leftarrow \sum_{k=1}^N p_k w_{t+E}^k$로 update가 진행되어야 하지만, partial device participation을 고려해야 하므로 $w_{t+E} \leftarrow \frac {N} {K} \sum_{k \in \mathcal{S}_t} p_k w_{t+E}^k$로 update가 진행됩니다. (확률적으로 가중치의 합이 $1$이 되도록 하기 위하여 $\frac {N} {K}$를 곱해준 것입니다.) 해당 논문에서는 full device participation case(즉, $K = N$)를 우선 증명한 후, partial case(즉, $K < N$)로 그 증명을 확장하는 서술 방식을 사용하고 있습니다.

 

 Sampling과 Averaging 부분이 기존 paper와 약간 다른데, 위의 table이 기존에 제시된 방법들과 저자들이 제시하는 방법 간의 차이를 보여줍니다. (위에서부터 순서대로 FedAvg, FedProx, 그리고 현재 보고 있는 논문 순입니다.) 우선 Sampling 부분을 살펴보면, FedProx에서 제시한 FedAvg의 경우 각 device마다 선택될 확률이 별도로 존재하기 때문에 $p$라는 별도의 parameter가 있으며, 따라서 sampling with replacement가 허용됩니다. 그러므로 이 경우에는 $\mathcal{S}_t$를 set이 아닌 family로 정의합니다. 그 외의 논문에서는 특정 device가 선택될 확률은 동일하다고 보고 있으며, sampling without replacement이기 때문에 $\mathcal{S}_t$는 set입니다. 다음으로 Averaging 부분을 살펴보면, 원 논문에서는 해당 round의 학습에 참여하지 않은 client에 대해서도 계산을 진행하였지만 이후의 논문들에서는 학습에 관여한 client에 대해서만 계산하는 모습을 볼 수 있습니다. 아래의 두 논문은 Sampling 과정에서 각 device가 선택될 확률이 일정한지 혹은 별도로 존재하는지에 따라 표현 방식이 다소 상이하지만, 근본적으로는 같은 방법을 사용하고 있습니다.

 

2. Assumptions - (1)

 

  임의의 model parameter $v, w$에 대하여, local device들의 함수 $F_1, F_2, \dots, F_N$이 다음 네 가지의 assumption을 만족한다고 가정합시다. $\text{Assumption 1 ~ 3}$은 FedProx 증명 과정에서도 사용한 가정이며, $\text{Assumption 4}$의 경우 $\text{Assumption 3}$을 사용한다는 전제 하에 크게 문제될 것이 없는 가정입니다.

 

$\text{Assumption 1}$ [$L$-smoothness]

 $F_1, F_2, \dots, F_N$은 모두 $L$-smooth하다. 즉,  $F_k(v) - F_k(w) \leq (v - w)^T \nabla F_k(w) + \frac {L} {2} ||v - w||_2^2$이다.

 

$\text{Assumption 2}$ [$\mu$-strong convexity]

 $F_1, F_2, \dots, F_N$은 모두 $\mu$-strong convex하다. 즉, $F_k(v) - F_k(w) \geq (v - w)^T \nabla F_k(w) + \frac {\mu} {2} ||v - w||_2^2$이다.

 

$\text{Assumption 3}$ [Bounded Variance]

 각 device가 가지는 stochastic gradeint의 variance는 bounded되어 있다. 즉, $\mathbb{E} [||\nabla F_k (w_t^k, \xi_t^k) - \nabla F_k (w_t^k)||^2] \leq \sigma_k^2$를 만족하는 $\sigma_k$가 존재한다.

 

$\text{Assumption 4}$ [Uniformly Bounded Squared Expectation]

 각 device가 가지는 stochastic gradeint의 expectation의 제곱은 uniformly bounded되어 있다. 즉, $k$에 상관 없이 $\mathbb{E} [||\nabla F_k (w_t^k, \xi_t^k) ||^2] \leq G^2$를  만족하는 $G$가 존재한다.

 

 이어지는 포스트에서, 우리는 full device participation case에 대한 FedAvg의 convergence analysis를 확인해볼 것입니다. 증명 과정이 다소 길어서, 두세 편으로 나누어 게재할 예정입니다. (여담이지만, FedProx의 경우 증명을 한 번에 하고 있어서 읽기 다소 번거로웠는데, 해당 논문의 경우 여러 Lemma로 증명 과정을 나누어 놓아서 비교적 읽기 편했습니다.)

'Federated Learning > Papers' 카테고리의 다른 글

[ICLR 2020] Convergence of FedAvg - (3)  (0) 2022.08.19
[ICLR 2020] Convergence of FedAvg - (2)  (0) 2022.08.16
[MLSys 2020] FedProx - (4)  (0) 2022.08.13
[MLSys 2020] FedProx - (3)  (0) 2022.08.01
[MLSys 2020] FedProx - (2)  (0) 2022.07.30
Comments