본문 바로가기
논문

[BNN] Efficient and Scalable Bayesian Neural Nets with Rank-1 Factors

by 두재 2020. 9. 29.

아직 한국어 설명이 많지 않은 최신 논문들, 블루 오션을 타겟으로 한 포스팅입니다. 최신 논문에서 중점적으로 다루는 아이디어를 커버하려고 하고 기본적인 지식은 간략하게 다룰 예정입니다.

논문 정보

제목 : Efficient and Scalable Bayesian Neural Nets with Rank-1 Factors [1]

저자 : Michael W. Dusenberry 외

시기 : 14 Aug 2020

 

기초 지식

1. Bayesian Neural Network (BNN)

Bayesian Neural Network 에 대한 정말 디테일한 설명은 다른 웹사이트나 블로그에서 찾으실 수 있을 것입니다. 간략하게 핵심을 짚어보자면, weight가 하나의 숫자, 파라미터가 아니라 어떠한 확률 분포를 따른다는 것입니다. forward를 할 때에는 weight를 분포에서 sampling하는 것이고 때문에 매 시행마다 조금씩 결과가 다를 수 있습니다. 학습에 있어서는 weight 파라미터를 업데이트하는 것이 아니라, weight 파라미터의 분포를 조절할 수 있는 파라미터를 업데이트합니다.

예를 들어 만약 모든 weight가 정규분포를 따른다고 가정을 하면 (기존의 모델 같은 경우는 하나의 숫자, dirac-delta 함수를 따르는 것이었겠죠) 그 정규분포를 표현하기 위해서 평균표준 편차가 필요합니다. 즉 같은 모델 구조에서 만약 정규분포로 BNN을 모델링한다면 2배의 파라미터가 추가적으로 필요하겠죠.

여기서 저 weight의 파라미터가 어떠한 분포를 따를 것인지는 우리가 정해주어야 합니다. 더 정확히 말하자면 '평균 3, 표준 편차 1인 분포를 따르게 하자' 가 아니라 평균과 표준 편차는 학습을 통해서 알게 되겠지만 어쨌든 정규분포를 따르도록 하자 라는 가정이 있습니다. 여기서 정규 분포를 prior라고 부르며 이 prior는 우리가 알고 있는 간단한 정규분포를 선택합니다. 

BNN은 기본적으로 같은 구조를 사용하는 Deterministic한 모델에 비해서 파라미터의 개수가 몇 배가 들고, 확률 분포를 따르기 때문에 각 weight들이 매 시행마다 다르기 때문에 모델이 깊을 경우에는 학습이 잘 되지 않고, 성능도 잘 나오기 쉽지 않다고 알려져 있습니다. 물론 BNN이 가지고 있는 중요한 장점도 있습니다.

 

2. BatchEnsemble (Wen et al., 2020) [2]

Ensemble이라는 것을 아실 겁니다. Ensemble은 어떻게 만드냐에 따라 다를 수 있지만 일반적으로는, 같은 아키텍쳐를 가지는 여러 모델을 각각 학습 데이터의 순서, 파라미터 초기화 등을 다르게 하고 학습한 후 이 여러 모델의 결과를 평균을 내어 inference하는 방법입니다. 쉽게 생각하면 똑똑한 여러 명이서 의견을 낸 걸 적절히 합치면 더욱 좋은 결과가 나온다고 생각할 수 있습니다. Ensemble은 정말 널리 사용되고 있고, 웬만한 해커톤이나 챌린지처럼 아주 약간이라도 성능을 높여야 한다면 꼭 들어가는 방법 중 하나입니다.

그러나 Ensemble이 아주 약간의 성능 향상을 가져오지만, 문제는 시간과, 파라미터의 개수가 사용되는 모델의 총 개수에 선형적으로 비례하여 증가합니다. 최근 아키텍쳐들이 모두 커지고 층이 깊어지다보니 ensemble을 쓰지 않아도 학습이 버거운데, 추가적으로 ensemble을 쓰고자 한다면 시간과 computation이 정말 많이 요구됩니다. 다르게 말하면, parameter efficiency가 안 좋을 수 있다는 것이죠. 아래의 BatchEnsemble은 간단한 아이디어로 이를 해결합니다.

기존의 $W$ 가 weight matrix였고, n개의 모델로 ensemble을 쓴다면 $W_1, W_2, \cdots, W_n$ 만큼 생겨 n배로 늘어났다면, BatchEnsemble은 $W$라는 하나의 공통된 weight matrix을 공유하면서, 각각의 Ensemble 모델은 1행, 1열짜리 벡터를 추가적으로 가지고 위 그림, 아래의 수식을 통해서 각각의 weight matrix를 가집니다.

$$\overline{W_{i}}=W \circ F_{i}, \;where\; F_{i} = r_{i}{s_{i}}^{T}$$

1행, 1열짜리 벡터를 행렬곱을 하여 rank-1짜리 W와 크기가 똑같은 행렬을 만들 수 있고, Ensemble 전체가 공유하는 W 와 Ensemble 각자의 $r s^T$를 elementwise-multiplication (hadamard product)하여 새로운 $ W $를 가지게 됩니다. 이렇게 하면, ensemble이 늘어나더라도 1행, 1열짜리 파라미터만 추가되기 때문에 weight matrix가 늘어나는 것에 비해 parameter가 굉장히 적게 늘어납니다

이론적인 디테일이 더 있지만, 일단은 여기서 마치고 저 핵심 아이디어만 가지고 계시면 될 것 같습니다.

 

메인 논문 내용

그렇다면 이 포스트의 논문에서는 무엇을 했느냐? 바로 BNN과 BatchEnsemble을 합쳤습니다.

일단 BNN의 관점에서는 위의 weight matrix가 하나의 딱딱 떨어지는 숫자가 아니라 어떤 확률 분포를 따라야 할 것입니다. 그리고 위에서 설명드린 BatchEnsemble은 어쨌든 BNN을 사용하지 않은 parameter가 어떠한 숫자로 딱 나와있는 경우입니다.

이 두가지를 합쳐보자면, weigh matrix를 어떠한 확률 분포 prior를 따르게 만드는데 weight matrix가 BatchEnsemble내에서는 공통된 weight matrix와 $r$, $s$의 곱으로 이루어집니다. 

이 논문에서는 $ W $를 확률 분포를 따르도록 하기보다는 $r, s$를 확률 분포를 따르는 파라미터로 만들어 $ \overline{W_{i}} $가 확률 분포를 따르도록 만들었습니다. 일단 loss function을 보자면,

$$L = -\frac{N}{B} \sum _{b=1} ^{B} \mathbb{E}_{q(r)q(s)}[log p(y_{b} | x_{b}, W, r, s)] + KL(q(r)||p(r)) + KL(q(s)||p(s)) - log p(w) $$

와 같이 되는데, $ KL(q(r)||p(r)) + KL(q(s)||p(s)) - log p(w) $이 부분은 regularizer입니다. $-\frac{N}{B} \sum _{b=1} ^{B} \mathbb{E}_{q(r)q(s)}[log p(y_{b} | x_{b}, W, r, s)] $ 에서 보면 $x_b, W, r, s$가 있을 때 기댓값을 가장 높일 수 있는 $q(r)$과 $q(s)$를 찾는 것입니다. 

논문의 3페이지 Variational Inference를 보시면 학습을 어떤 식으로 진행했는지를 알 수 있는데, $r$ 과 $s$에 대해서는 EM 알고리즘을 통해서 posterior inference를 진행하였고, W는 point-estimate을 했다고 합니다. 즉, 논문에서는 shared weight $ W $는 BNN이 아닌 deterministic한 approach로 결정하되 $ r, s $ 가 확률 분포를 띄고 있기 때문에, BatchEnsemble의 수식에서 봤듯이 각 ensemble 모델의 $ W $는 확률 분포를 따르게 됩니다.  이 과정에서 Hierarchical Priors라는 표현이 사용됩니다. 

논문에는 다양한 Ablation study가 존재하는데 한 개만 가져와봤습니다. $r$ 과 $s$ 각각 하나씩만 prior를 따르게 했을 때와 두 개 모두 prior를 따르게 했을 경우가 나오고, 결과론적으로는 두 개 모두 prior를 쓰는 것이 성능이 잘 나옵니다.

Corrupted는 데이터셋에 아래와 같이 여러 가지 Corruption이 추가된 데이터셋에 대한 결과입니다. [3]

 

 

결과

실험은 CIFAR, ImageNet과 이들의 corrupted 버전에 대한 classification task와 MIMIC-III EHR mortality task라는 신호를 보고 binary classification으로 수행되었습니다. 그 중 ImageNet에 대해서 ResNet-50 기반으로 수행된 실험에 대해서 보여드리겠습니다.

Rank-1 BNN이 Deep Ensembles에 비해서 성능이 좋지 않은 metric도 있고 좋은 metric도 있는데, 중요한 점은 파라미터의 개수입니다. 논문에 따르면 Rank-1 BNN에서 10개의 ResNet-50 ensemble 모델을 사용하더라도 하나의 ResNet-50에 비해서 900%가 아니라 0.4%의 파라미터 수 증가밖에 없었다고 합니다. 이는 rank-1으로 ensemble을 구성하는 BatchEnsemble 덕분이겠죠. 파라미터가 그렇게 많이 늘지 않아서인지 이를 BNN으로 확장한 경우에서도 높은 성능을 보여줍니다.

 

결론

BatchEnsemble의 도움으로 굉장히 적은 파라미터의 증가로 굉장히 높은 성능을 가져올 수 있고, 이를 Hierarchical prior의 관점으로 BNN으로 확장시켜 BNN에서도 안정적으로 학습이 되고 좋은 성능을 보여줍니다.

 

 

참고문헌

[1] M. W. Dusenberry, G. Jerfel, Y. Wen, Y.-a. Ma, J. Snoek, K. Heller, B. Lakshminarayanan, and D. Tran. Efficient and scalable bayesian neural nets with rank-1 factors. arXiv preprint arXiv:2005.07186, 2020.

[2] Y. Wen, D. Tran, and J. Ba. Batchensemble: an alternative approach to efficient ensemble and lifelong learning. arXiv preprint arXiv:2002.06715, 2020.

[3] D. Hendrycks and T. Dietterich. Benchmarking neural network robustness to common corruptions and perturbations. In International Conference on Learning Representations, 2019.