Federated Learning

[ICML 2019] Agnostic FL - (6) 본문

Federated Learning/Papers

[ICML 2019] Agnostic FL - (6)

pseudope 2022. 11. 17. 01:00
728x90

논문 제목: Agnostic Federated Learning

출처: https://arxiv.org/abs/1902.00146

 

 지난 포스트에서 $\text{STOCHASTIC-AFL}$ 알고리즘의 convergence를 증명하였습니다. (이전 글 보기) convergence의 bound가 $\sigma_w^2$와 $\sigma_\lambda^2$에 dependent하다는 것을 확인하였는데, 이번 포스트에서는 이 부분에 관하여 조금 더 자세하게 알아보겠습니다.

 

9. Stochastic Gradients

 

 앞서 정의한 바에 따르면, 임의의 $w \in \mathcal{W}$, $\lambda \in \Lambda$, $k \in [p]$에 대하여, 우리의 objective function $L (w, \lambda)$은 다음과 같습니다:

$$L (w, \lambda) := \sum_{k=1}^p \lambda_k L_k (w), \text{ where } L_k (w) := \frac {1} {m_k} \sum_{i=1}^{m_k} \ell (h(x_{k,i}), y_{k,i})$$

여기에서, 편의 상 $L_{k, i} (w) := \ell (h(x_{k,i}), y_{k,i})$로 notate하면, 다음과 같이 정리됩니다:

$$L (w, \lambda) = \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{i=1}^{m_k} L_{k, i} (w) $$

따라서,

$$\nabla_w L (w, \lambda) = \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{i=1}^{m_k} \nabla_w L_{k, i} (w),$$

$$\left[ \nabla_\lambda L (w, \lambda) \right]_k = \frac {1} {m_k} \sum_{i=1}^{m_k} L_{k, i} (w) = L_k (w)$$

임을 알 수 있습니다.

 

 우선, $\delta_\lambda L (w, \lambda)$에 대해서 살펴보겠습니다. 특이한 점이 있다면, $\nabla_w L (w, \lambda)$는$\lambda$에 independent하다는 것입니다. 따라서 $\delta_\lambda L (w, \lambda)$를 구성할 때에는 단순히 uniform하게 sampling을 할 것입니다. 이에 관한 pseudo code는 오른쪽과 같습니다. (특정한 한 가지의 client $K$에 대해서 정확히 한 개의 data $I_K$를 sampling합니다.) 그리고 다음 $\text{Lemma}$는 이러한 구성 하에서 $\sigma_\lambda^2$가 bounded된다는 것을 보장해줍니다.

 

※ 주의사항: 원 논문에 있는 $\text{Lemma 6}$ 증명이 이해가 잘 되지 않아서 따로 증명을 하였습니다. 그런데 그 결과 $p M^2$이 $(p - 1) M^2$이 되어 약간의 gap이 생겼습니다. 어찌됐든 $(p - 1) M^2 < p M^2$이므로 크게 문제될 부분은 아닌데, 저자들의 주장과 다소 차이가 생겼다는 점 참고 바랍니다. (혹시 아래의 증명 과정에서 문제되는 부분이 있다면 댓글 부탁드리겠습니다. 그리고 원래의 증명에 오류가 있었는지에 관해서 아시는 분도 댓글 부탁드립니다.)

 

$\text{Lemma 6}$

 

 $\delta_\lambda L (w, \lambda)$는 $\nabla_\lambda L (w, \lambda)$의 unbiased estimator이다. 더 나아가, 만약 loss function이 $M$으로 bounded되어 있다면, 다음이 성립한다:

$$\sigma_\lambda^2 := \max_{\substack{w \in \mathcal{W} \\ \lambda \in \Lambda}} \text{Var} (\delta_\lambda L (w, \lambda)) \leq p^2 M^2$$

 

$\text{Proof}$

 unbiasedness는 정의 상 자명하므로, variance 부분을 살펴보도록 하자. uniform하게 sampling을 진행할 것이기 때문에, 특정 client $k \in [p]$가 sampling될 확률은 $\frac {1} {p}$, 그렇지 못할 확률은 $\left( 1 - \frac {1} {p} \right)$이다. 따라서, $\text{Var} (X) = \mathbb{E} [X^2] - \left( \mathbb{E} [X] \right)^2$임을 이용하면, client $k$에 대한 variance를 다음과 같이 나타낼 수 있다:

\begin{align*} \text{Var} \left( \left[ \delta_\lambda L (w, \lambda) \right]_k \right) &= \mathbb{E} \left[ \left[ \delta_\lambda L (w, \lambda)^2 \right]_k \right] - \left( \mathbb{E} \left[ \left[ \delta_\lambda L (w, \lambda) \right]_k \right] \right)^2 \\&= (1 - \frac {1} {p}) \times 0^2 + \frac {1} {p} \times \left( \frac {1} {m_k} \sum_{i=1}^{m_k} p L_{k, i} (w) \right)^2 - \left( \mathbb{E} \left[ \left[ \delta_\lambda L (w, \lambda) \right]_k \right] \right)^2 \\&= p \times \left( \frac {1} {m_k} \sum_{i=1}^{m_k} L_{k, i} (w) \right)^2 - \left( \mathbb{E} \left[ \left[ \delta_\lambda L (w, \lambda) \right]_k \right] \right)^2 \\&= p \times \left( L_k (w) \right)^2 - \left( L_k (w) \right)^2 \\&\quad (\because \mathbb{E} [\delta_\lambda L (w, \lambda)] = \nabla_\lambda L (w, \lambda) \text{, where } \left[ \nabla_\lambda L (w, \lambda) \right]_k = \frac {1} {m_k} \sum_{i=1}^{m_k} L_{k, i} (w) = L_k (w)) \\&= (p - 1) \times \left( L_k (w) \right)^2 \\&\leq (p - 1) M^2 \; (\because L_k (w) \leq M \text{ by assumption}) \\&< p M^2   \end{align*}

 client는 총 p개가 있고 서로 independent하므로, $\text{Var} \left( \left[ \delta_\lambda L (w, \lambda) \right] \right) \leq p \times p M^2 = p^2 M^2$이다. $\square$

 

 보다시피  $\text{Var} \left[ \delta_\lambda L (w, \lambda) \right]$에 $p^2$ term이 있기 때문에, 학습에 참여하는 client의 수가 많을수록 variance가 커질 수밖에 없습니다. 저자들은 만약 variance가 커지는 것이 걱정된다면, 마치 (중앙에서 학습을 진행할 때의) mini-batch gradient descent처럼 각 client마다 1개씩 sampling한 것을 가지고 학습을 진행할 것을 제안합니다.

 

 다음으로, $\delta_w L (w, \lambda)$에 대해서 살펴보겠습니다. 아쉽게도, $\nabla_w L (w, \lambda)$는 $\lambda$에도 dependent하고 $w$에도 dependent하므로, $\nabla_\lambda L (w, \lambda)$보다 고려해야 할 사항이 많습니다. 저자들은 $\text{PerDomain}$ method와 $\text{Weighted}$ method, 이렇게 두 가지를 제안하고 있으며, 이에 관한 pseudo code는 오른쪽과 같습니다. 전자는 $\delta_\lambda L (w, \lambda)$를 구성하듯이 client 별로 unifrom하게 한 개씩 sampling하는 방식이고, 후자는 $\lambda$의 distribution에 따라 한 개의 client를 sampling한 후, 그 안에서 한 가지 sample을 uniform하게 선택하는 방식입니다.

 

$\text{Definition 7}$

 

 $w$에 관한 intra-domain variance $\sigma_I^2 (w)$와 outer-domain variance $\sigma_O^2 (w)$를 다음과 같이 정의한다:

$$\sigma_I^2 (w):= \max_{\substack{w \in \mathcal{W} \\ k \in [p]}} \frac {1} {m_k} \sum_{j=1}^{m_k} \left[ \nabla_w L_{k, j} (w) - \nabla_w L_k (w) \right]^2$$

$$\sigma_O^2 (w):= \max_{\substack{w \in \mathcal{W} \\ \lambda \in \Lambda}} \sum_{k=1}^p \lambda_k \left[ \nabla_w L_k (w) - \nabla_w L (w, \lambda) \right]^2$$

그리고 이때, sample 한 개에 대하여 loss와 $w$에 관한 gradient를 계산하는 것의 시간복잡도를 $U$로 denote한다.

 

 위의 $\text{Definition}$과 함께, 아래의 두 $\text{Lemma}$는 이러한 구성 하에서 $\sigma_w^2$가 bounded된다는 것을 보장해줍니다.

 

$\text{Lemma 8}$

 

 $\text{PerDomain}$ method를 이용하여 구한 stochastic gradient $\delta_w L (w, \lambda)$는$\nabla_w L (w, \lambda)$의 unbiased estimaor이며, $p U + \mathcal{O} (p \log m)$의 시간복잡도를 갖는다. 더 나아가, $\sigma_w^2 \leq R_\Lambda \sigma_I^2 (w)$이다.

 

$\text{Proof}$

 unbiasedness와 시간복잡도는 정의 상 자명하므로, variance 부분을 살펴보도록 하자. $\nabla_w L_{k, J_k} (w)$는 $\nabla_w L_k (w)$의 unbiased estiamate이므로, 다음이 성립한다:

\begin{align*} \text{Var} [\delta_w] &= \text{Var} \left( \sum_{k=1}^p \lambda_k \nabla_w L_{k, J_k} (w) \right) =  \sum_{k=1}^p \lambda_k^2 \text{Var} \left(  \nabla_w L_{k, J_k} (w) \right) \\&= \sum_{k=1}^p \lambda_k^2 \mathbb{E} \left[  \nabla_w L_{k, J_k} (w) - \mathbb{E} [\nabla_w L_{k, J_k} (w)] \right]^2 = \sum_{k=1}^p \lambda_k^2 \mathbb{E} \left[  \nabla_w L_{k, J_k} (w) - \nabla_w L_k (w) \right]^2 \\&\leq \sum_{k=1}^p \lambda_k^2 \sigma_I^2 (w) \leq R_\Lambda \sigma_I^2 (w) \; (\because \text{Properties 1}) \; \square \end{align*}

 

$\text{Lemma 9}$

 

 $\text{Weighted}$ method를 이용하여 구한 stochastic gradient $\delta_w L (w, \lambda)$는$\nabla_w L (w, \lambda)$의 unbiased estimaor이며, $U + \mathcal{O} (p + \log n)$의 시간복잡도를 갖는다. 더 나아가, $\sigma_w^2 \leq \sigma_I^2 (w) + \sigma_O^2 (w)$이다.

 

$\text{Proof}$

 unbiasedness와 시간복잡도는 정의 상 자명하므로, variance 부분만 살펴보도록 하자.

\begin{align*} \text{Var} (\delta_w) &= \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_{k, j} (w) - \nabla_w L (w, \lambda) \right)^2 \\&= \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_{k, j} (w) - \nabla_w L_k (w) + \nabla_w L_k (w) - \nabla_w L (w, \lambda) \right)^2 \\&= \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_{k, j} (w) - \nabla_w L_k (w) \right)^2 + \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_k (w) - \nabla_w L (w, \lambda) \right)^2 \\&\quad + 2 \underbrace{\sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_{k, j} (w) - \nabla_w L_k (w) \right) \left( \nabla_w L_k (w) - \nabla_w L (w, \lambda) \right)}_{= \; 0} \\&= \sum_{k=1}^p \frac {\lambda_k} {m_k} \sum_{j=1}^{m_k} \left( \nabla_w L_{k, j} (w) - \nabla_w L_k (w) \right)^2 + \sum_{k=1}^p \lambda_k \left( L_k (w) - \nabla_w L (w, \lambda) \right)^2 \\&\leq \sigma_I^2 (w) + \sigma_O^2 (w) \; (\because \sum_{k=1}^p \lambda_k = 1) \; \square \end{align*}

 

9. $\text{PerDomain}$ Vs. $\text{Weighted}$

 

 정의 상 $R_\Lambda \leq 1$이기 때문에, $\text{Var} (\delta_w)$ term만 볼 경우 당연히 $\text{PerDomain}$ method가 더 유리할 것입니다. 하지만 시간복잡도를 고려하면, $\text{PerDomain}$ method에는 $p$ term이 있어서 ($p$가 충분히 작지 않은 이상) $\text{Weighted}$ method가 더 유리합니다. 따라서, 학습에 참여하는 client의 수가 그다지 많지 않은 상황에서는 $\text{PerDomain}$ method를 우선적으로 고려하되, 그 수가 커질 때에는 유불리를 잘 비교하여 판단할 필요가 있습니다.  

 

 한편, $p$가 지나치게 커진다면 두 method의 variance를 직접적으로 비교하는 것은 적절치 못할 것입니다. 그 이유는 $\text{PerDomain}$은 $p$개의 data를, $\text{Weighted}$는 $1$개의 data를 이용하여 계산하기 때문입니다.  따라서 저자들은 $p\text{ -Weighted}$ method를 정의하였는데, 이는 $\text{Weighted}$ method를 independent한 $p$개의 sample에 대해서 수행한 값의 평균입니다. 따라서 이때의 분산은 $\text{Var} \left( p\text{ -Weighted} \right) = \frac {\sigma_I^2 (w) + \sigma_O^2 (w)} {p}$이며, 정의 상 $R_\Lambda \leq \max_{\lambda \in \Lambda} ||\lambda||_2 = \max_{\lambda \in \Lambda} \sum_{k=1}^p \lambda_k^2 \leq \sum_{k=1}^p \left( \frac {1} {p} \right)^2 = \frac {1} {p}$이므로, $\text{Var} \left( \text{PerDomain} \right) \geq \frac {\sigma_I^2 (w)} {p}$입니다. 이렇게 놓고 봤을 때, 만약 $\lambda$ 값들이 비교적 일정하여서 $R_\Lambda \approx \frac {1} {p}$인 상황이라면, $\sigma_O^2 (w)$ term이 커질 것이기 때문에 $\text{PerDomain}$ method를 사용하는 것이 바람직할 것입니다. 반면, 만약 $\sigma_O^2 (w)$ term이 충분히 작다면, $\text{Weighted}$ method를 쓰는 것이 좋겠죠. 이것은 상황에 따라서 선택해야 하는 부분으로, 반드시 어느 method가 좋다고 이야기하기는 어려울 듯합니다.

 

 여기까지가 해당 paper에서 하고 싶었던 이야기들이고, 다음 포스트에서는 experiments를 살펴보도록 하겠습니다. $\text{OPTIMISTIC STOCHASTIC-AFL}$과 같은 variation 몇 가지가 논문의 appendix에 함께 실려 있는데, 이는 별도로 다루지 않을 계획입니다. 관심 있으신 분들은 해당 논문을 참조해주시기 바랍니다.

'Federated Learning > Papers' 카테고리의 다른 글

[ICLR 2020] q-FFL, q-FedAvg - (1)  (0) 2022.11.22
[ICML 2019] Agnostic FL - (7)  (1) 2022.11.18
[ICML 2019] Agnostic FL - (5)  (0) 2022.11.14
[ICML 2019] Agnostic FL - (4)  (0) 2022.11.08
[ICML 2019] Agnostic FL - (3)  (0) 2022.11.07
Comments