Federated Learning

[ICML 2019] Agnostic FL - (6) 본문

Federated Learning/Papers

[ICML 2019] Agnostic FL - (6)

pseudope 2022. 11. 17. 01:00
728x90

논문 제목: Agnostic Federated Learning

출처: https://arxiv.org/abs/1902.00146

 

 지난 포스트에서 STOCHASTIC-AFL 알고리즘의 convergence를 증명하였습니다. (이전 글 보기) convergence의 bound가 σ2wσ2λ에 dependent하다는 것을 확인하였는데, 이번 포스트에서는 이 부분에 관하여 조금 더 자세하게 알아보겠습니다.

 

9. Stochastic Gradients

 

 앞서 정의한 바에 따르면, 임의의 wW, λΛ, k[p]에 대하여, 우리의 objective function L(w,λ)은 다음과 같습니다:

L(w,λ):=pk=1λkLk(w), where Lk(w):=1mkmki=1(h(xk,i),yk,i)

여기에서, 편의 상 Lk,i(w):=(h(xk,i),yk,i)로 notate하면, 다음과 같이 정리됩니다:

L(w,λ)=pk=1λkmkmki=1Lk,i(w)

따라서,

wL(w,λ)=pk=1λkmkmki=1wLk,i(w),

[λL(w,λ)]k=1mkmki=1Lk,i(w)=Lk(w)

임을 알 수 있습니다.

 

 우선, δλL(w,λ)에 대해서 살펴보겠습니다. 특이한 점이 있다면, wL(w,λ)λ에 independent하다는 것입니다. 따라서 δλL(w,λ)를 구성할 때에는 단순히 uniform하게 sampling을 할 것입니다. 이에 관한 pseudo code는 오른쪽과 같습니다. (특정한 한 가지의 client K에 대해서 정확히 한 개의 data IK를 sampling합니다.) 그리고 다음 Lemma는 이러한 구성 하에서 σ2λ가 bounded된다는 것을 보장해줍니다.

 

※ 주의사항: 원 논문에 있는 Lemma 6 증명이 이해가 잘 되지 않아서 따로 증명을 하였습니다. 그런데 그 결과 pM2(p1)M2이 되어 약간의 gap이 생겼습니다. 어찌됐든 (p1)M2<pM2이므로 크게 문제될 부분은 아닌데, 저자들의 주장과 다소 차이가 생겼다는 점 참고 바랍니다. (혹시 아래의 증명 과정에서 문제되는 부분이 있다면 댓글 부탁드리겠습니다. 그리고 원래의 증명에 오류가 있었는지에 관해서 아시는 분도 댓글 부탁드립니다.)

 

Lemma 6

 

 δλL(w,λ)λL(w,λ)의 unbiased estimator이다. 더 나아가, 만약 loss function이 M으로 bounded되어 있다면, 다음이 성립한다:

σ2λ:=max

 

\text{Proof}

 unbiasedness는 정의 상 자명하므로, variance 부분을 살펴보도록 하자. uniform하게 sampling을 진행할 것이기 때문에, 특정 client k \in [p]가 sampling될 확률은 \frac {1} {p}, 그렇지 못할 확률은 \left( 1 - \frac {1} {p} \right)이다. 따라서, \text{Var} (X) = \mathbb{E} [X^2] - \left( \mathbb{E} [X] \right)^2임을 이용하면, client k에 대한 variance를 다음과 같이 나타낼 수 있다:

\begin{align*} \text{Var} \left( \left[ \delta_\lambda L (w, \lambda) \right]_k \right) &= \mathbb{E} \left[ \left[ \delta_\lambda L (w, \lambda)^2 \right]_k \right] - \left( \mathbb{E} \left[ \left[ \delta_\lambda L (w, \lambda) \right]_k \right] \right)^2 \\&= (1 - \frac {1} {p}) \times 0^2 + \frac {1} {p} \times \left( \frac {1} {m_k} \sum_{i=1}^{m_k} p L_{k, i} (w) \right)^2 - \left( \mathbb{E} \left[ \left[ \delta_\lambda L (w, \lambda) \right]_k \right] \right)^2 \\&= p \times \left( \frac {1} {m_k} \sum_{i=1}^{m_k} L_{k, i} (w) \right)^2 - \left( \mathbb{E} \left[ \left[ \delta_\lambda L (w, \lambda) \right]_k \right] \right)^2 \\&= p \times \left( L_k (w) \right)^2 - \left( L_k (w) \right)^2 \\&\quad (\because \mathbb{E} [\delta_\lambda L (w, \lambda)] = \nabla_\lambda L (w, \lambda) \text{, where } \left[ \nabla_\lambda L (w, \lambda) \right]_k = \frac {1} {m_k} \sum_{i=1}^{m_k} L_{k, i} (w) = L_k (w)) \\&= (p - 1) \times \left( L_k (w) \right)^2 \\&\leq (p - 1) M^2 \; (\because L_k (w) \leq M \text{ by assumption}) \\&< p M^2   \end{align*}

 client는 총 p개가 있고 서로 independent하므로, \text{Var} \left( \left[ \delta_\lambda L (w, \lambda) \right] \right) \leq p \times p M^2 = p^2 M^2이다. \square

 

 보다시피  \text{Var} \left[ \delta_\lambda L (w, \lambda) \right]p^2 term이 있기 때문에, 학습에 참여하는 client의 수가 많을수록 variance가 커질 수밖에 없습니다. 저자들은 만약 variance가 커지는 것이 걱정된다면, 마치 (중앙에서 학습을 진행할 때의) mini-batch gradient descent처럼 각 client마다 1개씩 sampling한 것을 가지고 학습을 진행할 것을 제안합니다.

 

 다음으로, \delta_w L (w, \lambda)에 대해서 살펴보겠습니다. 아쉽게도, \nabla_w L (w, \lambda)\lambda에도 dependent하고 w에도 dependent하므로, \nabla_\lambda L (w, \lambda)보다 고려해야 할 사항이 많습니다. 저자들은 \text{PerDomain} method와 \text{Weighted} method, 이렇게 두 가지를 제안하고 있으며, 이에 관한 pseudo code는 오른쪽과 같습니다. 전자는 \delta_\lambda L (w, \lambda)를 구성하듯이 client 별로 unifrom하게 한 개씩 sampling하는 방식이고, 후자는 \lambda의 distribution에 따라 한 개의 client를 sampling한 후, 그 안에서 한 가지 sample을 uniform하게 선택하는 방식입니다.

 

\text{Definition 7}

 

 w에 관한 intra-domain variance \sigma_I^2 (w)와 outer-domain variance \sigma_O^2 (w)를 다음과 같이 정의한다:

\sigma_I^2 (w):= \max_{\substack{w \in \mathcal{W} \\ k \in [p]}} \frac {1} {m_k} \sum_{j=1}^{m_k} \left[ \nabla_w L_{k, j} (w) - \nabla_w L_k (w) \right]^2

\sigma_O^2 (w):= \max_{\substack{w \in \mathcal{W} \\ \lambda \in \Lambda}} \sum_{k=1}^p \lambda_k \left[ \nabla_w L_k (w) - \nabla_w L (w, \lambda) \right]^2

그리고 이때, sample 한 개에 대하여 loss와 w에 관한 gradient를 계산하는 것의 시간복잡도를 U로 denote한다.

 

 위의 \text{Definition}과 함께, 아래의 두 \text{Lemma}는 이러한 구성 하에서 \sigma_w^2가 bounded된다는 것을 보장해줍니다.

 

\text{Lemma 8}

 

 \text{PerDomain} method를 이용하여 구한 stochastic gradient \delta_w L (w, \lambda)\nabla_w L (w, \lambda)의 unbiased estimaor이며, p U + \mathcal{O} (p \log m)의 시간복잡도를 갖는다. 더 나아가, \sigma_w^2 \leq R_\Lambda \sigma_I^2 (w)이다.

 

\text{Proof}

 unbiasedness와 시간복잡도는 정의 상 자명하므로, variance 부분을 살펴보도록 하자. \nabla_w L_{k, J_k} (w)\nabla_w L_k (w)의 unbiased estiamate이므로, 다음이 성립한다:

\begin{align*} \text{Var} [\delta_w] &= \text{Var} \left( \sum_{k=1}^p \lambda_k \nabla_w L_{k, J_k} (w) \right) =  \sum_{k=1}^p \lambda_k^2 \text{Var} \left(  \nabla_w L_{k, J_k} (w) \right) \\&= \sum_{k=1}^p \lambda_k^2 \mathbb{E} \left[  \nabla_w L_{k, J_k} (w) - \mathbb{E} [\nabla_w L_{k, J_k} (w)] \right]^2 = \sum_{k=1}^p \lambda_k^2 \mathbb{E} \left[  \nabla_w L_{k, J_k} (w) - \nabla_w L_k (w) \right]^2 \\&\leq \sum_{k=1}^p \lambda_k^2 \sigma_I^2 (w) \leq R_\Lambda \sigma_I^2 (w) \; (\because \text{Properties 1}) \; \square \end{align*}

 

\text{Lemma 9}

 

 \text{Weighted} method를 이용하여 구한 stochastic gradient \delta_w L (w, \lambda)\nabla_w L (w, \lambda)의 unbiased estimaor이며, U + \mathcal{O} (p + \log n)의 시간복잡도를 갖는다. 더 나아가, \sigma_w^2 \leq \sigma_I^2 (w) + \sigma_O^2 (w)이다.

 

\text{Proof}

 unbiasedness와 시간복잡도는 정의 상 자명하므로, variance 부분만 살펴보도록 하자.

\begin{align*} \text{Var} (\delta_w) &= \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_{k, j} (w) - \nabla_w L (w, \lambda) \right)^2 \\&= \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_{k, j} (w) - \nabla_w L_k (w) + \nabla_w L_k (w) - \nabla_w L (w, \lambda) \right)^2 \\&= \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_{k, j} (w) - \nabla_w L_k (w) \right)^2 + \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_k (w) - \nabla_w L (w, \lambda) \right)^2 \\&\quad + 2 \underbrace{\sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_{k, j} (w) - \nabla_w L_k (w) \right) \left( \nabla_w L_k (w) - \nabla_w L (w, \lambda) \right)}_{= \; 0} \\&= \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_{k, j} (w) - \nabla_w L_k (w) \right)^2 + \sum_{k=1}^p \lambda_k \left( L_k (w) - \nabla_w L (w, \lambda) \right)^2 \\&\leq \sigma_I^2 (w) + \sigma_O^2 (w) \; (\because \sum_{k=1}^p \lambda_k = 1) \; \square \end{align*}

 

9. \text{PerDomain} Vs. \text{Weighted}

 

 정의 상 R_\Lambda \leq 1이기 때문에, \text{Var} (\delta_w) term만 볼 경우 당연히 \text{PerDomain} method가 더 유리할 것입니다. 하지만 시간복잡도를 고려하면, \text{PerDomain} method에는 p term이 있어서 (p가 충분히 작지 않은 이상) \text{Weighted} method가 더 유리합니다. 따라서, 학습에 참여하는 client의 수가 그다지 많지 않은 상황에서는 \text{PerDomain} method를 우선적으로 고려하되, 그 수가 커질 때에는 유불리를 잘 비교하여 판단할 필요가 있습니다.  

 

 한편, p가 지나치게 커진다면 두 method의 variance를 직접적으로 비교하는 것은 적절치 못할 것입니다. 그 이유는 \text{PerDomain}p개의 data를, \text{Weighted}1개의 data를 이용하여 계산하기 때문입니다.  따라서 저자들은 p\text{ -Weighted} method를 정의하였는데, 이는 \text{Weighted} method를 independent한 p개의 sample에 대해서 수행한 값의 평균입니다. 따라서 이때의 분산은 \text{Var} \left( p\text{ -Weighted} \right) = \frac {\sigma_I^2 (w) + \sigma_O^2 (w)} {p}이며, 정의 상 R_\Lambda \leq \max_{\lambda \in \Lambda} ||\lambda||_2 = \max_{\lambda \in \Lambda} \sum_{k=1}^p \lambda_k^2 \leq \sum_{k=1}^p \left( \frac {1} {p} \right)^2 = \frac {1} {p}이므로, \text{Var} \left( \text{PerDomain} \right) \geq \frac {\sigma_I^2 (w)} {p}입니다. 이렇게 놓고 봤을 때, 만약 \lambda 값들이 비교적 일정하여서 R_\Lambda \approx \frac {1} {p}인 상황이라면, \sigma_O^2 (w) term이 커질 것이기 때문에 \text{PerDomain} method를 사용하는 것이 바람직할 것입니다. 반면, 만약 \sigma_O^2 (w) term이 충분히 작다면, \text{Weighted} method를 쓰는 것이 좋겠죠. 이것은 상황에 따라서 선택해야 하는 부분으로, 반드시 어느 method가 좋다고 이야기하기는 어려울 듯합니다.

 

 여기까지가 해당 paper에서 하고 싶었던 이야기들이고, 다음 포스트에서는 experiments를 살펴보도록 하겠습니다. \text{OPTIMISTIC STOCHASTIC-AFL}과 같은 variation 몇 가지가 논문의 appendix에 함께 실려 있는데, 이는 별도로 다루지 않을 계획입니다. 관심 있으신 분들은 해당 논문을 참조해주시기 바랍니다.

'Federated Learning > Papers' 카테고리의 다른 글

[ICLR 2020] q-FFL, q-FedAvg - (1)  (0) 2022.11.22
[ICML 2019] Agnostic FL - (7)  (1) 2022.11.18
[ICML 2019] Agnostic FL - (5)  (0) 2022.11.14
[ICML 2019] Agnostic FL - (4)  (0) 2022.11.08
[ICML 2019] Agnostic FL - (3)  (0) 2022.11.07