전체 개념 논문 구현 NLP PyTorch Transformer Recommender Graph Vision GAN

ELECTRA: Pre-Training Text Encoders as Discriminators Rather Than Generators

2021.09.27
논문 NLP Transformer

이 논문은 개인적으로 정말 재미있게 읽은 논문이어서 포스팅을 작성하게 되었다. 최근 NLP에는 BERT 이후로 이를 개선한 ALBERT, RoBERTa, XLNet 등 많은 pre-training 모델이 등장하였다. BERT의 핵심적인 학습은 바로 Masked Language Modeling(MLM)인데, ALBERT를 제외하고 위 논문들은 대부분 이 MLM을 개선시켜서 성능을 향상시키고자 했지만, 크게 MLM의 형태에서 벗어나지는 않았다.

ELECTRA는 위 논문들과는 다르게 GAN으로 MLM을 대체하는 아이디어를 제시했다. 단순히 GAN을 이용한다는 발상까지야 해볼 수 있지만, 이 발상으로 기존 모델들 이상의 성능까지 끌어올린 점이 정말 대단하다고 생각한다. GAN 포스팅에서도 언급했지만, GAN 자체가 classification보다 훨씬 학습 수렴이 힘든데, vision 분야에 비해서 NLP에서는 비교적 GAN에 대해서 연구가 많이 되지는 않았기 때문이다.

ELECTRA에서는 GAN을 어떻게 활용하려고 했는지, 그리고 GAN을 학습과 수렴시키기 위해서 어떤 테크닉들을 사용했는지에 대한 관점에서 이 논문을 살펴보도록 하겠다.

Masked Language Modeling (MLM)

BERT의 성능에 가장 핵심적인 역할을 하는 학습이 바로 MLM이다. 큰 corpus 데이터셋을 가공하지 않고도 self-supervised learning을 통해서 토큰에 대한 임베딩을 학습시킬 수 있는 방식이다.

이 MLM 과정을 어떤 논문은 denoising autoencoder 라고 표현하기도 한다. 입력에 노이즈가 들어간 것을 원래대로 복구해내는 autoencoder라고 볼 수 있는 것이다.

이 논문에서 지적하는 MLM의 단점은 바로 [MASK] 토큰이다. Pre-training 단계에서는 노이즈를 주기 위해 입력에서 원래 토큰 대신에 [MASK] 토큰으로 가린다. 하지만, 실제 학습된 BERT 모델을 활용할 때는 입력에 [MASK]토큰이 들어올 일이 없다. 학습단계에서만 사용하는 vocabulary가 계속 유지되고 있는 것이다.

BERT논문에서도 이런 문제점을 인지하고 있다. 그래서 [MASK]토큰 대신에 일부를 랜덤토큰을 적용하는 등의 시도를 한다. 하지만 랜덤토큰은 결국에 원래 문장과 너무 상관없는 토큰이 들어오기 때문에, 알아채기가 쉽다.

그래서 이 논문에서는 Generator를 통해서 [MASK] 토큰도 아니고, 랜덤 토큰도 아니고, 그럴듯한 가짜 토큰을 생성한다. 기존의 [MASK] 토큰이나 랜덤 토큰보다도 더 난이도있는 입력을 줄 수 있는 것이다. 그러면서도 더 이상 [MASK]라는 가상의 토큰을 사용하지 않게 된다.

이 아이디어가 논문의 전부다. 새로운 발상이지만 아주 간단하다. 여기서 중요한 것은 GAN을 학습시키기 위해 어떤 테크닉들을 사용했냐에 대한 것이다. 위에서도 언급한 것처럼, GAN을 이용한다는 아이디어를 실제 결과로 도출해내기가 쉽지 않기 때문이다. 결과적으로 논문의 학습 방법들은 학습 속도면에서도 기존의 RoBERTa나 XLNet보다 훨씬 빨리 수렴에 도달하게 된다. 아래에서 논문에서 사용한 테크닉들에 대해 하나씩 알아보도록 한다.

ELECTRA

ELECTRA의 모델 구조는 아주 간단하다. 아래 그림이 전부다.

Electra model

여기서 ELECTRA 모델은 Discriminator 부분에 해당되고, Generator 모델은 이 Discriminator를 학습시키기 위한 입력을 생성하는 모델이라고 이해하면 된다. 각 모델에 대한 구체적인 내용은 다음과 같다.

Generator

Generator는 기존의 MLM처럼 [MASK] 토큰을 일반적인 단어 토큰으로 바꾼다. 그 결과가 바로 ELECTRA 모델을 학습시키는 입력이 되는 것이다. 이 output vocabulary는 input vocabulary와 다르게 [MASK] 토큰이 포함되어있지 않다는 것이 포인트다.1 이를 식으로 다음과 같이 표현할 수 있다.

\[p_G(x_t\vert x)=\frac{\mathbb{exp}(e(x_t)^Th_G(x)_t)}{\sum_{x'}\mathbb{exp}(e(x')^Th_G(x)_t)}\]

논문에서는 위와 같이 표현되었지만, 결국 softmax 식이라는 것을 알아볼 수 있다.

여기에 적용되는 토큰들에 대한 기호의 의미는 다음과 같다.

  • $\mathbf x=[x_1,x_2,…,x_n]$ : 최초의 입력 토큰.
  • $\mathbf m=[m_1,…m_k]$ : 마스크를 적용할 랜덤한 포지션 값.
    • 논문에서는 마스크의 개수 $k=[0.15n]$, 즉 전체 시퀀스의 15%를 마스크로 적용하였다.
  • $\mathbf{x^{masked}}=\rm REPLACE(\mathbf x,\mathbf m,[MASK])$ : 마스크가 씌워진 토큰. 이 값이 generator의 입력이 된다.
  • $\mathbf{x^{corrupt}}=\rm REPLACE(\mathbf x,\mathbf m,\mathbf{\hat x})$ : Generator로부터 변형된 출력. 이 값이 discriminator의 입력이 된다.
    • 여기서 $\hat x_i\sim p_G(x_i\vert\mathbf{x^{mask}})\ \mathrm{for}\ i\in m$, 즉 각 마스크토큰에 대해서 generator에서 추론된 값

Discriminator

ELECTRA 모델은 이 입력을 가지고 학습을 한다. MLM과 다르게, 이 학습의 output은 각각의 토큰에 대한 original/replaced 여부를 판별하는 것이다. 따라서 output sequence의 각 토큰마다 sigmoid 형태의 output이 되는 것이다. 이를 식으로 표현하면 다음과 같다.

\[D(\mathbf x,t)=\mathrm{sigmoid}(w^Th_D(\mathbf x)_t)\]

Loss

GAN 구조의 수렴에서 가장 중요한 부분이 바로 loss다. 이 논문에서는 generator와 disciminator 각각 별도의 loss를 구성한다. 이 것이 바로 GAN과의 가장 큰 차이점이고, 논문에서 학습 수렴을 위해 사용한 첫 번째 테크닉이다. 각각의 loss에 대해 구체적으로 소개한다.

Generator Loss

먼저 generator는 다음과 같은 loss 식을 사용한다.

\[\mathcal L_\mathrm{MLM}(\mathbf x,\theta_G)=\mathbb E\left( \sum_{i\in \mathbf m}-log\ p_G(x_i\vert\mathbf{ x^{masked}})\right)\]

위 식을 보면, 일반적인 GAN과는 전혀 다른 것이 보인다. 바로 discriminator 없이 단독으로 학습한다는 점이다.

GAN은 기본적으로 adversarial, 즉 적대적인 training이 기본이다. 따라서 일반적으로는 generator와 discriminator를 합쳐서 loss를 구성한다. 그리고 generator는 정답을 맞추는 것이 목적이 아니라, discriminator가 구분할 수 없는 output을 생성하는 것을 목적으로 가져야 한다.

하지만 이 논문의 appendix를 보면, 이 generator 학습에 대해서 이와 같은 방식을 적용할 수 없는 이유에 대해서 discrete sampling 때문에 backpropagation 자체가 일어나기 어렵다고 설명하고 있다. 따라서 논문에서는 이 부분을 대체하기 위해 여러가지 adversarial 실험뿐만 아니라, RL까지 도입하여 여러 loss들을 비교하였다. 여기에서 소개하는 loss가 바로 이 여러가지 방법들 중에서 최적의 결과인 것이다.

Discriminator Loss

ELECTRA의 discriminator loss는 각 토큰별로 real과 fake를 구분한다는 점이 일반적인 GAN과 다르다. 이는 Vision과 다른 NLP 데이터셋의 특성을 활용한 것이다. 이와 같이 discriminator loss에서는 NLP 데이터셋의 특징을 많이 활용한 트릭들을 사용했다. 토큰별로 나뉜다는 점을 제외하면 discriminator loss의 식은 다음과 같이 GAN에서 많이 사용하는 BCE(Binary Cross Entropy) loss를 적용하였다.

\[\mathcal L_\mathrm{Disc}(x,\theta_D)=\mathbb E\left(\sum_t^n \begin{cases} -log(D(\mathbf{x^{corrupt}},t)) & (x_t^{corrupt}=x_t) \\ -log(1-D(\mathbf{x^{corrupt}},t)) & (x_t^{corrupt}\neq x_t) \end{cases} \right)\]

여기서 GAN과의 또 하나의 큰 차이점이 보인다. 바로 fake 여부를 비교하는 기준이다. GAN은 정답이 실제 정답인지, 아니면 generator가 만들어낸 정답인지에 따라서 real과 fake를 구분해야 한다. 하지만 ELECTRA에서는 각 토큰별로 정답을 판별하고, generator가 만들어낸 정답이라도 실제 정답과 같으면, 즉 $x_t^{corrupt}=x_t$ 면 real로 구분하도록 한다.

이 부분은 Vision과 NLP의 차이점때문에 나타난다고 이해하면 쉽다. 이미지는 애초에 "비슷한" 그림은 생성하기 쉽지만, 픽셀의 RGB까지 동일한 경우 자체는 없다고 봐도 무방하다. 그렇기 때문에 generator에서 생성했는지의 여부만으로도 real과 fake를 결정지어도 되는 것이다. 하지만 토큰이라는 속성은 동일한 경우가 발생하기 쉽다. 그리고 동일한 토큰이 될 경우에는 real인지 generator인지 구분할 수 없는게 당연하기 때문에 이와 같이 정답과의 토큰 비교가 필요하다.

학습

위의 모델 부분에서도 설명했듯이, GAN 구조를 잘 학습시키기 위해서 특히 loss에서 많은 시도들을 통한 트릭들이 적용되었다. 이러한 loss뿐만 아니라 학습을 위한 hyperparameter 셋팅도 그만큼 중요하고, 여러 트릭들이 사용된다.

Smaller Generators

이 부분이 바로 논문에서 사용한 가장 중요한 트릭 중 하나라고 생각한다. Generator 모델은 BERT와 같은 MLM을 학습하는데, generator의 성능이 너무 좋아지면 입력에서 마스크들이 다 정답으로 바뀌게 된다. 입력의 대부분이 정답이 들어오면, 오히려 discriminator가 학습할 수 있는 입력들이 적어져서 학습이 어렵게 되는 것이다.

최적의 Discriminator를 만들어내기 위해서 논문에서는 generator 모델의 크기를 줄여서 밸런스를 조절하고자 한 것이다. 따라서 적합한 generator의 hidden size를 찾는 실험을 했고, 결과는 아래와 같았다.

generator size

이 결과에서도 나타나듯이, generator의 사이즈가 discriminator와 같거나 그 이상으로 좋아지게 되면 오히려 discriminator가 제대로 학습을 못해서 성능이 떨어지게 된다. Generator의 성능을 hidden size를 통해 조절하는 방법이 적절했던 것이다.

Weight Sharing

여기서 말하는 weight sharing은 generator와 discriminator간의 weight를 공유한다는 것이다. Generator와 discriminator는 위에서 언급한 것과 같이 크기가 다르다. 따라서 논문에서는 이렇게 다른 사이즈의 경우는 embedding에 대해서만 weight sharing을 하는 것이 가장 효율적이었다고 한다.

하지만 개인적으로 이 weight sharing부분은 의문이 든다. 임베딩에 대해서 weight sharing을 적용하게 되면, discriminator도 [MASK] 토큰에 대한 입력을 받아야 한다. 이 논문에서 지적하던 불필요한 [MASK] 토큰에 대해서 학습을 피한다는 장점이 사라진다는 것이다.

또한, 겨우 embedding weight만 sharing하는 것은 결국 대부분의 weight가 별도로 구성된다는 것이기 때문에, 이 weight sharing으로 학습에 큰 차이가 발생하기는 쉽지 않다. 따라서 정말 ELECTRA 모델의 성능에 이 weight sharing이 효과적이었는지에 대해서는 확인이 필요한 것이다.

하지만 논문에서는 이 weight sharing에 대한 실험결과가 따로 정리되어 있지는 않다. 논문에서 언급한 것은, Generator와 Discriminator의 크기가 같은 경우에 대한 성능이다. 그러나 이 경우는 결국 위의 결과처럼 Smaller Generator가 더 성능이 좋기 때문에, 사용하지 않는다. 따라서 weight sharing이 큰 의미를 갖는다고 얘기하기는 어렵다.

실험

ELECTRA 모델로 파라매터 크기 대비 학습 속도가 많이 향상되었고, GLUE나 SQuAD결과 또한 기존보다 나아졌다. 이 실험결과에 대해서는 따로 설명하지 않는다.

개인적으로 이 논문에서 눈여겨볼만한 실험으로, 마스크를 학습시키는 셋팅이 재밌어서 이에 대한 실험 내용만 소개하도록 한다. 다음과 같이 4가지 종류의 discriminator loss를 설정해서 마스크 학습에 대해 비교를 진행하는 실험이다.

  • ELECTRA

    논문에서 적용한 원래 모델을 말한다.

  • ELECTRA 15%

    ELECTRA와 동일한 셋팅에서 discriminator loss를 마스크 토큰에 해당되는 output에 대해서만 정확하게 판별했는지를 측정하는 것이다.

  • Replace MLM

    기존 BERT에서 MLM loss와 동일하다. BERT의 마스크 대신에 랜덤토큰이 적용된 MLM이라고 생각하면 된다.

  • All-Tokens MLM

    위의 Replace MLM을 모든 토큰에 적용한 버전. 즉, ELECTRA 15%와 원래 ELECTRA의 관계와 같다.

이 4가지의 결과는 당연하게도 다음과 같다.

모델 ELECTRA All-Tokens MLM Replace MLM ELECTRA 15% BERT
GLUE score 85.0 84.3 82.4 82.4 82.2

이 결과로부터 알 수 있는 사실은 다음과 같다.

  • 직관적으로도 당연히 모든 토큰에 대해서 정상적으로 real/fake를 구분하는 것이 효과적임을 알 수 있다.
  • All-Tokens MLM이 결과가 꽤 좋다. GAN구조의 여부보다, 전체토큰에 대한 loss를 설정하는 것만으로도 성능이 훨씬 좋아진다는 것이다.
  • ELECTRA 모델의 의의도 여기서 찾을 수 있다. 단순히 모든 토큰을 학습했기 때문에 결과가 좋은 것이 아니라, All-Tokens MLM과 비교했을 때도 성능향상이 있는 것으로 GAN구조가 더 효과적이라는 것을 설명할 수 있게 된다.

결론

ELECTRA 논문은 전체적으로 새로운 방법을 많이 시도했는데도, 많은 부분들을 깔끔하게 다듬고 결과를 낸 논문이라고 생각한다.

먼저, GAN의 수렴을 위해서 정말 여러가지 시도들을 했고, 성공까지 잘 이끌어냈다는 것이 느껴진다. 논문의 내용에서만 봐도 다양한 실험들이 존재하는데, 그 뒤에는 잘 안된 훨씬 많은 수렴 시도들이 있었을 것이 느껴진다.

그리고 기존 RoBERTa, XLNet등 경쟁 논문들이 많은데, 그 가운데서 논문의 당위성을 효과적으로 입증해냈다. 단순히 성능으로뿐만 아니라, GAN구조를 사용하는 방식의 효과를 입증하기 위해 정말 다양한 아이디어의 실험을 진행한 것이 느껴졌다.

논문의 아쉬운 부분들도 존재한다. 위에서도 언급했던 weight sharing에 대해 미흡한 점도 아쉬운 부분 중 하나이다.

하지만 가장 아쉬운 점은 GAN을 아직 100% 이용하지는 못했다는 것이다. Vision 분야 GAN에서는 generator의 diversity를 높이는 연구가 많고, 정말 다양한 이미지들을 많이 생성한다.

이에 비해 ELECTRA의 generator는 그 종류가 매우 한정되어 있고, 타겟 또한 original label이기 때문에, generator와 discriminator가 유기적으로 학습하는 원래의 GAN에 비해서 그 절반밖에 이용하지 못했다고 할 수 있다. 물론 이는 data augmentation에 더 가깝기 때문에, 다른 NLP data augmentation 논문들에서 다뤄지고 있다.

그래서 위 부분은 논문의 문제점보다는 그냥 아쉬운 점이고, 이와는 별개로 전체적으로 GAN을 활용한 아이디어들, 그리고 이를 풀어나가는 과정들에서 정말 독특한 발상들을 조리있게 풀어나가는 것이 재미있는 한 편의 논문이라고 생각한다.

  1. 단, 구현 상에서 이를 적용하냐의 여부는 다른 문제다. 구현에서 굳이 이를 제외하지 않더라도, 모델을 학습할 때 활용되지 않아서 임베딩에 영향을 끼치지 않는 것이다.