전체 개념 논문 구현 NLP PyTorch Transformer Recommender Graph Vision GAN MWP Automata SSM Mamba

State-Space Models towards an Ideal Language Model: From S4 to Mamba-2 (1)

2024.06.15
논문 NLP SSM Mamba

Transformer 는 여전히 강력하지만, 아키텍처 자체에 내재한 한계점들이 존재한다. 특히 이전 포스팅에서도 지적했던 recurrence 문제는 Transformer의 강력함과 한계를 동시에 보여준다고 볼 수 있다. 그동안 수많은 모델들이 이러한 Transformer 를 대체하고자 하였지만, 아직까지 완전하게 그 우위를 입증한 모델은 없다.

분명 그 이유에는 아키텍처 자체의 강력함도 상당 부분을 차지하겠지만, 그만큼 이 Transformer라는 모델에 최적화된 학습 방법론이 많이 연구되었기 때문도 있을 것이다. Mamba가 실제로 이러한 Transformer를 대체할 수 있는지의 여부를 떠나, 그리고 리뷰어 퀄리티에 대한 논란을 떠나, 결국 ICLR 2024에서 좋은 평가를 받지 못한 이유도 바로 이와 같이 Transformer에 최적화된 실험 조건하에서 새로운 모델의 잠재력을 완전히 실험으로서 보여주기가 어려웠기 때문이라고 생각한다.

최근 AI분야의 논문 경쟁은 확실히 이러한 딜레마를 갖게 만든다. 이론 연구와 달리 AI는 아무것도 완벽하게 증명할수가 없고 그 단서를 실험적으로 입증해야 하는데, 현실적으로 모든 것을 실험하고 보여주기는 어렵다. 그 중에서도 특히 언어는 이러한 증명과 가장 동떨어져 있다. 애초에 우리 언어의 분포가 어떻게 이루어져있는지에 대한 조건을 전혀 알 수가 없기 때문이다.

결국 커다란 패러다임의 변화를 위해서는 새로운 코어 이론의 제시뿐만 아니라, 이를 보조하는 다양한 후속 연구들이 필요하다. 하지만 리뷰어의 입장에서는 매년 제시되는 수만개의 새로운 연구들 중에서 실험적으로 가장 잘 드러난 연구들에 대해 좋은 평가를 줄 수밖에 없을 것이다.

물론 저자들은 최대한 많은 것들을 보여주고자 하였다. 하지만 모든 것을 입증할 수는 없었다. 딥 러닝이 그러했듯, 결국 패러다임의 변화는 결국 많은 사람들의 꾸준한 노력과 시간이 동반되어야 할 것이다.

하지만, 결국 다음과 같은 의문이 든다.

그렇다면 논문의 완성도 여부를 떠나서, Mamba 시리즈는 언어 모델로서 정말로 Transformer를 넘어설 수 있는 것일까?

위 질문에 대답하고자 이 글을 작성하게 되었다. 며칠 전, Mamba-2 논문이 arXiv에 공개되는 것을 보고 이 포스팅을 작성할 생각이 들었는데, 원래는 Mamba와 Mamba-2에 대해서만 다루려고 했으나, 설명을 적다보니 결국 분량이 길어지게 되었다.

Mamba는 Albert GuTri Dao 두 교수님의 공동 저작품이다. 저자들을 이야기를 하는 이유는, Mamba에는 State Space Model (SSM) 라는 다소 생소한 개념이 등장하는데, 바로 이 분들이 주축이 되어 지난 몇년간 연구해온 새로운 방향의 모형이다.

결국 Mamba에 대한 설명을 위해서는 이 SSM을 짚고 넘어가는 수밖에 없어서 각 연구들에 대해 필요한 관점만 간단하게 살펴보도록 한다. 이 포스팅에서 다룰 논문들은 다음과 같다.

  1. S4: Gu et al. Efficiently Modeling Long Sequences with Structured State Spaces. (2021). ICLR 2022. arXiv
  2. H3: Fu et al. Hungry Hungry Hippos: Towards Language Modeling with State Space Models. (2022). ICLR 2023. arXiv
  3. Mamba: Gu, Albert, and Tri Dao. Mamba: Linear-Time Sequence Modeling with Selective State Spaces. (2023). arXiv
  4. Mamba-2: Dao, Tri, and Albert Gu. Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. (2024). arXiv

먼저 이 SSM 기반의 두 모델 S4와 H3에 대해 간단히 살펴본 후, 본격적으로 Mamba와 Mamba-2에 대해 이야기할 것이다.

Structured State-Space Sequence (S4)

Structured State Space Sequence (S4) 모델은 저자들이 SSM을 본격적으로 sequence modeling에 도입한 첫 모델이다. 여기에 사용한 모델은 사실 거의 이전 논문 Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers (NeurIPS 2021)의 Linear State-Space Layer (LSSL)에 기반하고 있는데, 여기에서는 따로 구분 없이 Mamba를 위해 필요한 S4의 내용들만 간단히 짚어보도록 한다.

모델 구조

S4 모델에 대해 간략히 요약하자면 다음과 같은 블럭을 가지는 모델이다.

\[h_t = \overline{\mathbf{H}} h_{t-1} + W_x x_t \\ y_t = W_yh_t .\]

직관적인 이해의 편의를 위해 notation은 원래 논문과는 많이 다르게 작성하였다.

$h_t$는 현재 time-step $t$에서의 hidden state, $x_t$는 $t$에서의 입력 벡터, $y_t$는 $t$에서의 출력벡터를 의미하고, $\overline{\mathbf{H}}$는 HiPPO Matrix 라는, 저자들이 이전 연구에서 제시했던 행렬을 활용하여 초기화시킨 weight 인데, 간단히 말하면 멀리 떨어진 state도 잘 기억하도록 가중치값이 계산된 행렬이다. 각 블럭간에는 pre-normalization과, residual connection이 적용된다.

위 식을 보면 식이 RNN과 거의 흡사해 보인다. 하지만 HiPPO Matrix를 사용한다는 점, 그리고 activation function을 사용하지 않았다는 점이 바로 차별화의 핵심이라고 볼 수 있다.

이 블럭의 원형이 바로 기존 제어이론에서 continuous-time state-space를 정의하는 모델에서 파생되었기에 SSM이라고 부른다. 저자들은 이 SSM 중 Linear Time-Invariant (LTI) 한 기본 모델을 가지고 continous 형태를 discrete 하게 변형시켜 사용한다. 이 정의에 대해서는 state-space representation 위키를 참조하면 된다.

LSSL 논문 제목과 같이 이 SSM 블럭은 recurrent, convolutional, continuous-time 3가지 관점으로 해석할 수 있는 것이 특징인데, 이 중 recurrentcontinuous-time 관점은 위의 SSM 식과, 이를 discretize 하기 전의 원형 식이 각각 나타내고 있는 형태라고 이해하면 된다.

여기에서 주목해야 하는 것은 하는 것은 convolution 연산을 통해 병렬연산을 만들어낸 것인데, 이 부분이 바로 이 모델의 가장 독창적인 특징이라고 생각한다. 기존 RNN에서 불가능했던 병렬 계산이 activation을 제거하면서 가능하게 되었는데, 구체적으로 아래와 같다.

참고로 이후에 나온 Encoding Recurrence into Transformers (ICLR 2023) 논문 또한 유사한 관점에서 RNN을 linear activation을 적용하여 병렬 연산을 유도하였는데, 서로 인용은 없었지만 개인적으로 Mamba와 위 논문의 RSA 블럭이 접근하는 관점이 꽤 유사한 측면이 있어 보였다.

Convolution for LTI-SSM

Recurrent 모델은 이전 time-step 의 결과를 구해야 다음의 state를 계산할 수 있고, 이는 결국 sequence 길이가 길어질수록 학습이 상당히 느려지게 된다. 반면에 LSSL은 activation을 제거하면서 각 time-step 에 들어올 입력을 미리 알고 있다는 전제 하에, 전체 sequence에 대한 state를 한번에 구해낼 수 있다. 이는 근본적으로 이 모델이 Linear Time-Invariant (LTI) SSM, 즉 time-step $t$에 관계없이 항상 동일한 weight를 적용하는 형태를 전제했기 때문이다.

학습시에는 모델은 전체 입력을 미리 알고있으므로, 이 convolution 연산을 수행할 수 있다. 물론 완벽히 recurrent 방식처럼 무한한 길이에 대해 수행할 수 있는 것은 아니고, 대신 학습을 위해 모델을 fixed window size 의 convolution network로 근사한다고 이해해야 한다.

예를 들어 window size가 3인 경우, 다음과 같이 크기 3짜리 convolution kernel $\mathbf K_3$을 만들 수가 있는데,

\[\mathbf K_3 = [W_y\overline{\mathbf{H}}^2W_x, W_y\overline{\mathbf{H}}W_x, W_yW_x] ,\]

이로부터 $y_3$는 다음과 같이 계산될 수 있다.

\[y_3 = \mathbf K_3 \cdot [x_3, x_2, x_1] = \mathbf K_3X_{:3} .\]

학습 시에는 모든 입력 $x_1, x_2, x_3$ 을 이미 알고있기 때문에, 이렇게 전체 크기의 $\overline{\mathbf K}$를 미리 구해놓으면 병렬연산을 통해 모든 $y_t$에 대한 값을 한번에 계산할 수 있는 것이다.

물론 여기에는 이 커널 벡터를 한번에 계산한다는 것이 전제가 된다. $\overline{\mathbf{H}}$가 Hippo로 고정된 값일 경우, 이 값을 미리 계산해놓으면 되지만, S4에서는 $\overline{\mathbf{H}}$도 결국 학습되기 때문에 sequence의 길이 $n$만큼의 $\overline{\mathbf{H}}^n$ 를 매번 계산해야 하는 문제가 있고, 실제 학습시에는 이 논문에서 새로 제시한 알고리즘을 통해 근사시킨 $\overline{\mathbf K}$를 사용한다.

$\overline{\mathbf{H}}$를 고정시킨 경우와 학습시킨 경우의 차이는 저자의 LSSL 논문에서 자세히 비교하고 있는데, 결국 이 근사 알고리즘은 $\overline{\mathbf{H}}$가 Hippo Matrix 라는 것을 전제로 유도되었기 때문에 학습을 통해 업데이트된 행렬에 대한 근사와 일치한다는 보장은 없다. 이 커널 근사 방식은 저자들의 매 연구마다 업데이트되고 있다.

의의

이 S4모델의 저자들이 모델의 연산효율성을 상당히 중요하게 생각하고 있는 것이 보이는데, 저자 중 한명인 Tri Dao가 FlashAttention 논문의 주저자인 것을 생각해보면 저자들이 추구하는 방향이 좀 더 쉽게 이해가 된다.

저자들은 결국 학습과 추론 두 방향에서 효과적인 모델을 만들어내고자 했는데, 그 결과 이와 같은 3가지 모드의 전환이 이 모델의 재미있는 포인트이다. 학습 시에는 모든 입력을 미리 알고 있으니 병렬연산을 수행하고, 추론시에는 다음 토큰만 예측하면 되니 전체 state를 볼 필요 없이 이전 state만으로 다음 결과를 예측하는 것이다. 쉽게 이야기하면, 동일한 weight 를 가지고 학습모드와 추론모드에 맞게 변신한 모델이 작동하는 것이다. 심지어 연속 데이터에서는 continuous-time 모드로도 동작시킬 수도 있다.

이 S4는 원래 언어뿐만 아니라 다양한 sequential 데이터에 다재다능하게 적용할 수 있다는 것이 특징이지만, 여기에서는 언어모델 관점에 대해서만 살펴보면 된다. 여기에서 주목할만한 성과는 바로 Long Range Arena (LRA) 벤치마크에서의 압도적인 성능이다. 특히 이 LRA 성능은 논문이 인정받은 주요 기여점 중 하나라고 할 수 있는데, 일단 이 점을 염두하고 일단 아래 논문을 살펴보도록 한다.

Hungry Hungry Hippos (H3)

H3 모델은 Tri Dao의 Hungry Hungry Hippos: Towards Language Modeling with State (ICLR 2023) 논문에서의 모델로, Mamba의 기본 조각이 되는 모델이다. Hippo와 Mamba까지, 정말 저자들의 작명 스타일이 매우 일관적인 것이 느껴진다.

H3은 S4를 좀 더 언어모델로 특화시켜 발전시키면서 그 구조가 많이 변화하였다. 그 중 가장 큰 특징은 S4와 달리, Transformer처럼 다음 토큰을 추론할 때 전체 입력을 살펴본다는 것이다. 하지만 이로 인해 S4만큼의 추론 효율성은 없어지게 된다. 그렇다면, H3은 효율성을 포기하고 무엇을 얻고자 했을까?

모델 구조

H3 모델은 Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention (ICML 2020) 논문의 Linear Attention을 SSM 블록을 통해 구현한 모델이라고 할 수 있다.

Linear Attention은 기존의 Attention 값을 계산할 때 사용하는 $\mathrm{softmax}(QK^\top)V$ 내부의 $\exp({q\cdot k})$ 부분을 일종의 각 토큰간 similarity 를 표현하는 activation 함수로 보고, 이 유사도를 어떤 non-linear row-wise 함수 $\phi:\R^d\to \R^d$를 통해 $\phi(q)\cdot\phi(k)$ 형태로 근사시킨다.

길이가 $n$인 입력을 projection 시킨 $Q, K, V \in \R^{n\times d}$에 대한 $i$번째 토큰의 출력 feature $y_i$를 구해야 하는 경우, causal mask가 포함된 linear self-attention은 다음과 같은 식으로 표현된다.

\[S_{i} = S_{i-1} + \phi(k_i)v_i^\top \in \R^{d\times d},\\ z_i = z_{i-1} + \phi(k_i) \in \R^d,\\ y_i = \frac{\phi(q_i)^\top S_{i}}{\phi(q_i)\cdot z_i} .\]

설명의 편의를 위하여 점화식 형태로만 표현하였다. Linear Attention은 이렇게 점화식 형태로 attention을 계산하기 때문에 논문의 제목처럼 Transformer를 RNN 형태로 표현할 수 있게 되고, 계산 복잡도가 $\mathcal O(n)$이 된다.

여기에서 핵심은 $S_i$인데, 결국 query $Q$와의 attention을 구하기 위한 이전 sequence의 상태를 표현한 행렬이라고 볼 수 있다. 여기서 분모는 결국 단일 값이 되기 때문에, 일종의 normalization term으로 보고 생략해도 무방하다.

H3에서는 이 Linear Attention의 SSM을 사용하여 다음과 같이 $\phi$ 함수를 대체한다.

\[\mathrm{H3}(X)=Q \circ \mathrm{SSM_{diag}}(\mathrm{SSM_{shift}}(K)\circ V) .\]

식을 보면 오히려 큰 골격은 Transformer 레이어이고, 그 내부 구조에서 SSM을 활용하는 형태이다. 위 Linear Attention의 식과 비교해보면 $z_i$를 생략하고 $S_i$는 $\mathrm{SSM_{diag}}(\cdot)$로, $S_i$안의 $\phi(\cdot)$는 $\mathrm{SSM_{shift}}(\cdot)$로 대체하였다.

여기에서 두 $\mathrm{SSM_{(diag, shift)}}$은 $\overline{\mathbf{H}}$의 초기값이 다른 SSM 레이어로, $\mathbf H^\mathrm{(diag)}$는 S4의 HiPPO Matrix의 diagonal 버전1이고, shift 는 이전 토큰만 보는 행렬 ${\mathbf H}^\mathrm{(shift)}_{i,j}=\mathbf{1}(i-1=j)$로 대체한 것이다. 즉, $\mathbf H^\mathrm{(diag)}$는 전체 맥락에 대한 정보, $\mathbf H^\mathrm{(shift)}$는 이전 상태만을 참고하려는 SSM이라고 볼 수 있다.

논문에서는 이 모델 결과를 정의해 놓고 해석하는 것에만 초점을 두지, 왜 이런 방식으로 설계했는지에 대한 설명은 충분하지 않다. 특히, $\phi(k_i)$는 현재 상태를 반영하는데 반해, 왜 이 부분이 이전 state를 반영하는 $\mathrm{SSM_{shift}}$가 되어야 하는 지는 불분명하다.
다만 저자들의 이렇게 두 종류의 SSM을 둔 이유는, 결국 많은 시도를 했지만 $\mathrm{SSM_{diag}}$ 만으로는 모델이 충분한 성능을 내지 못했기 때문일 것으로 볼 수 있다. 아래의 실험 결과에서도 이야기하겠지만, H3는 아직 완전히 좋다고 이야기하기는 어려운 결과를 보여준다.

결국 이 H3도 Transformer와 같이 전체 sequence가 여러 단계의 레이어 블록을 쌓는 형태가 되는데, Transformer처럼 각각의 블록 역시 multi-head로 계산될 뿐만 아니라, Attention 이후에는 FFN (혹은 MLP) 레이어도 똑같이 포함된다. 이렇게 계산되는 H3의 계산 복잡도는 $\mathcal O(n\log n)$이다. 논문에서는 역시나 저자들 답게 계산 효율을 위해서 FlashConv 라는 새로운 IO 병목 우회 로직도 제시하였지만, 이에 대한 설명은 생략한다.

의의

논문에서는 H3를 SSM의 발전된 계보로 보고 기존 S4에서 잘 하지 못했던 태스크, 특히 Transformer가 강점을 보이는 in-context learning에 관한 태스크들을 잘한다고 이야기하였지만, 이는 어찌보면 당연하다. S4는 분명 Transformer와는 전혀 다른 방향의 모델이었지만, H3는 사실 Transformer에서 attention 계산만 변형시킨 수많은 모델 중 하나에 가깝기 때문이다. 심지어 논문에서의 대부분 실험에서 보여주는 Hybrid 모델은, 결국 일부 레이어는 그냥 Transformer Layer를 사용한 모델이다. 즉, H3는 SSM의 진화형태보다는 Transformer 구조에 SSM을 활용한 모델로서 이해하는 것이 더 명확할 것이다.

물론 그렇다고 H3의 의미가 적다고 생각하지는 않는다. 동일한 성능을 내는 언어 모델을 $\mathcal O(n\log n)$으로 계산효율을 단축시킬 수 있다면, 그것 또한 엄청난 개선이기 때문이다. 저자들은 이를 보여주기 위해 심지어 대규모의 corpus로, 2.7B 사이즈까지 직접 pre-train까지 시키고 GPT와 비교를 하였다.

LLM의 시대가 온 후, 기업단위가 아니고서는 이러한 연구를, 실험조차 하기 쉽지 않은 것을 생각하면 많은 공이 들어간 연구라고 생각한다. 물론 H3이 그만큼 하드웨어 디테일까지 고려한 설계였기 때문일 수도 있지만, 그래도 개인 혹은 연구실 단위로 이러한 실험을 한 것만으로도 쉽지 않은 작업인 것은 분명하다.

그러나, 이렇게 pre-train 시킨 것에 비하면 그렇게 다양한 결과를 보여주지 못하였고, 작은 모델들을 가지고 fine-tuning 등 다른 방식으로 먼저 접근해보았으면 어땠을까에 대한 아쉬움이 남는다. Perplexity 비교는 너무나 볼 수 있는 것이 제한적이고, zero-shot 결과는 분명 성과는 있지만, 그 성능 편차가 GPT에 비해 매우 심하다. 저자들이 리뷰 이후에 추가한 LRA 실험에서 역시 결국 S4D의 성능에 미치지 못한 것을 보아도, SSM의 장점을 다 살렸다고 보기는 어렵다.

결론적으로 H3은 계산효율은 물론 성능에서도 꽤 강점이 드러나는 부분도 있었다. 하지만 그 구조의 타당성도, 실험적인 다양성도, 모델의 특징에 대한 더 자세한 고찰도, 무언가 더 보여줄 수 있을 것 같았지만 많이 아쉬움이 남는 설명으로 마무리된 연구이다.

이제 Mamba에 대해 이야기할 수 있도록 모든 기초 연구를 다루었다. S4는 SSM의 아주 기본적인 모델로서 다루었고, H3은 이 SSM을 Transformer 내부에서 작동하는 역할로 활용하였다. 이 과정에서도 물론 생략된 논문들도 있지만, 대부분 거의 큰 틀에서 중복되는 내용들이 많기에 굳이 따로 다루지 않았고, SSM의 언어모델 활용성을 기준으로 특징적인 연구들만 선정하여 (최대한) 간단히 알아보았다. 이를 바탕으로 본격적으로 Mamba와 Mamba-2에 대해 이야기할텐데, 분량이 많이 길어져 다음 포스팅에서 이어서 진행한다.

  1. Gu et al. On the parameterization and initialization of diagonal state space models. (NeurIPS 2022). link