[Paper Review] LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale
논문링크(Published at NeurIPS 2022)
https://arxiv.org/abs/2208.07339
Introduction
최근에 사용하는 Large Langugage Models(LLMs)는 추론을 위해 상단한 메모리를 요구한다. 이와 관련해서 LoRA(LoRA: Low-Rank Adaptation of Large Language Models) 기술이 나왔다. 파라미터의 사이지를 줄이는 기술 중 하나로 파라미터를 더 낮은 비트로 양자화하고 낮은 비트 정확도 행렬곱을 사용한다. 이 논문에서는 트랜스포머 라이브러리의 8비트 양자화 방법대해서 개발했다. 이 방법을 통해서 성능은 유지하며 이전보다 더 낮은 메모리 할당량을 갖게하려고 한다. 이전 기술은 350M 파라미터보다 적은 경우엔 충분히 좋았지만, 350M 파라미터보다 많아짐에 따라서 성능하락이 눈에 띄게 보였으며, 이에 대해 해결을 위해 노력했다.
이 논문에서 Int8 양자화 방식을 통해서 성능 하락을 초래하지 않는 것을 보여준다. 2가지 핵심 문제들을 해결한다.
- 1B가 넘는 모델에서 더 높은 양자화 정확도
- 6.7B(67억) 파라미터 규모부터 시작하여 모든 Transformer 레이어에서 나타나기 시작하는, 희소하지만 체계적인 큰 크기의 이상치 특성이 생겨날 경우, 이는 양자화 정밀도를 망가뜨리게 되므로 이를 명시적으로 표현
6.7B이상의 모델에 대해 성능 하락 없이 스케일링하기 위해, 추론동안에 은닉층의 특징 차원에서의 극한의 이상치의 생성에 대해 이해하는 것이 중요하다. 6.7B(67억) 규모에서는 시퀀스당 15만 개의 이상치가 발생하지만, 이는 Transformer 전체에서 단 6개의 특징 차원(feature dimension)에만 집중되어 있다. 이러한 이상치를 0으로 변환하는 순간 다음과 같은 현상이 발생한다.
- top-1 attention softmax 20% 이상 감소한다.
- Perplexity가 600~1000% 약화된다.
다음은 결국 모델이 텍스트를 예측하는데 훨씬 어려움을 겪는다. 이상치는 모델에서 추론을 하는데 매우 중요하므로 0으로 바꾸면 안되며, 다른 방식의 접근이 필요하다.
그래서 논문에서는 다음과 같은 방식을 사용한다.
- 99.9%의 보통값에 대해서는 8bit 행렬곱을 사용
- 0.1%의 이상치에 대해서는 16bit 행렬곱을 사용
이러한 방식을 통해서 성능하락이 줄어들지 않고 이상치에 대한 새로운 시각과 LLMs에 양자화기법을 사용할 수 있도록 한다.
다음은 사용예시이며 간단하게 설명해보려고 한다.
- 이상치가 아닌 값(위) : 이상치가 아닌 값들을 추출하고 양자화하여 int8로 변환하고 행렬곱을 하고 양자화를 복구한다.
- 이상치 값 : 이상치값들을 추출한 뒤 곱한다. 곱한 값을 이상치가 아닌 값과 더한다.
이를 통해 얻는 점은 다음과 같다.
- 8bit 연산을 통한 일반 값에 대해 효율성이 증가
- 16bit 연산을 통한 이상치에 대한 정밀도 증가 및 성능 하락 방지
Background
여기서 양자화 기법에 대해 간단하게 설명하고 각각에 장단점을 적어보려고 한다.
Absmax quantization
입력데이터를 절대값이 최대인 값으로 나눈 뒤 8비트 범위에 맞게 스케일링하여 정보를 최대한으로 보존하여 데이터 크기를 줄이는 방법
- $X_{f16}$, $W_{f16}$: 16비트 부동소수점 입력 행렬과 가중치 행렬.
- $X_{int8}$, $W_{int8}$: 8비트로 양자화된 입력 행렬과 가중치 행렬.
- $C_x$, $C_w$: 각각 입력 행렬($X_{f16}$)의 행별 최대값과 가중치 행렬($W_{f16}$)의 열별 최대값.
- $Out_{int32}$: 8비트 양자화된 행렬들의 곱셈 결과(Int32).
- $Out_{f16}$: 디양자화된 최종 출력(16비트 부동소수점).
- $\otimes$: 두 스케일링 상수의 외적(outer product).
계산방식(접은 글 클릭)
Zeropoint quantization
정규화와 제로포인트 이동을 통해 입력데이터가 8비트의 전체 범위를 활용하도록 하여 비대칭분포에서도 양자화 오류를 줄이고 효율성을 극대화 하는 방식
- $nd_{x_{f16}}$ : 이 수식은 정규화 계수를 계산하는 식
- $zp_{x_{i16}}$ : 이 수식은 제로 포인트(zero point) 값을 계산
- $\mathbf{X}_{i8}$ : 이 수식은 데이터를 양자화(quantization)하여 8비트 정수 값 $\mathbf{X}_{i8}$로 변환하는 과정
- $C_{i32} = \text{multiply}{i16}(A{zp_{a_{i16}}}, B_{zp_{b_{i16}}}) = (A_{i8} + zp_{a_{i16}})(B_{i8} + zp_{b_{i16}})$
- 이 수식은 두 양자화된 값 $A_{i8}$와 $B_{i8}$을 곱할 때 사용하는 계산식
- 양자화된 값은 정수로 표현되며, 제로 포인트(zero point) $zp_{a_{i16}}$ 및 $zp_{b_{i16}}$를 포함
- $C_{i32} = A_{i8}B_{i8} + A_{i8}zp_{b_{i16}} + B_{i8}zp_{a_{i16}} + zp_{a_{i16}}zp_{b_{i16}}$
- 이 확장은 주로 GPU나 TPU와 같이 $\text{multiply}_{i16}$ 명령어를 사용할 수 없는 환경에서 활용
Int8 Matrix Multiplication with 16bit Float and Outputs
여기서 8비트 행렬곱과 16비트 행렬곱을 이용한 결과에 대해 설명하고 있다.
- $\mathbf{X}_{f16}$: 16비트 부동소수점 입력 행렬 (Hidden states)
- $\mathbf{W}_{f16}$: 16비트 부동소수점 가중치 행렬
- $\mathbf{C}_{f16}$: 계산 결과인 16비트 부동소수점 출력 행렬
- $c_{x_{f16}}$, $c_{w_{f16}}$: 각각 입력 행렬과 가중치 행렬의 텐서 단위 스케일링 상수
- $\mathbf{C}_{i32}$: 중간 계산에서 32비트 정수로 변환된 출력 행렬
- 이 수식은 부동소수점 곱셈을 정수 곱셈으로 변환하여 효율성을 높이는 과정입니다. 정규화를 위해 $c_{x_{f16}}$, $c_{w_{f16}}$으로 나누어 복원
- $S_{f16}$: 스케일링 상수로, $c_{x_{f16}} c_{w_{f16}}$의 역수에 해당
- $A_{i8}, B_{i8}$: 양자화된 8비트 정수로 변환된 입력 값
- $Q(\cdot)$: 양자화 함수로, 부동소수점 입력을 양자화 (제로 포인트 또는 absmax 방식을 사용)
이 수식은 양자화된 8비트 정수행렬을 사용하여 효율적인 곱셈 연산을 수행하고, 부동소수점 값으로 복원하는 과정을 설명한다.
Int8 Matrix Multiplication at Scale
이전 사용 방식 및 문제점은 모든 값에 대해 동일한 양자화 정확도를 사용하다보니 매우 정확도가 낮앗다. 특히 벡터 단위의 양자화가 아니라 블록 방식 양자화에 의해 생긴 문제 방삭이다. 그래서 여기서는 값에 따라 양자화를 진행하는 벡터 단위 양자화 방식을 사용한다.
- 99.9%의 보통값에 대해서는 8bit 행렬곱을 사용
- 0.1%의 이상치에 대해서는 16bit 행렬곱을 사용
Vector-wise Quantization
- $\mathbf{C}_{f16}$: 부동소수점 행렬 곱셈 결과
- $\mathbf{C}_{i32}$: 양자화된 정수 행렬 곱셈의 중간 결과
- $S$: 스케일링 상수
- $Q(\cdot)$: 양자화 함수
- $\mathbf{A}{i8}, \mathbf{B}{i8}$: 양자화된 8비트 정수 행렬
이 방법은 기존의 행렬 곱셈 양자화 방식보다 더 세밀한 조정이 가능하며, 특히 딥러닝 모델에서 양자화의 성능을 최적화하는 데 유리하다.
The Core of LLM.int8(): Mixed-precision Decomposition
이상치 특징(outlier features)을 고정밀도로 계산하면서, 나머지 대부분의 값(99.9%)은 메모리 효율적인 8비트 곱셈을 사용하여 계산
- $\mathbf{X}{f16}^h \mathbf{W}{f16}^h$: 이상치 차원 $O$에 대해 16비트 고정밀도로 곱셈 수행.
- $\mathbf{X}{i8}^h \mathbf{W}{i8}^h$: 나머지 차원($h \notin O$)에 대해 메모리 효율적인 8비트 정수 곱셈 수행.
- $S_{f16}$: 8비트 입력과 가중치 행렬을 디양자화(dequantization)하기 위한 정규화(term) 스케일링 상수.
이상치 특징 정의:
- 입력 행렬 $\mathbf{X}_{f16} \in \mathbb{R}^{s \times h}$에서 이상치는 거의 모든 **시퀀스 차원 $s$**에 걸쳐 체계적으로 발생하지만, 특정 **히든 차원 $h$**에 제한됩니다.
- 따라서, 혼합 정밀도 분해에서는 이상치 특징 차원을 분리:
- 이상치 집합 $O = {i | i \in \mathbb{Z}, 0 \leq i \leq h}$는 특징/히든 차원 중, 크기가 임계값 $\alpha$ 이상인 차원만 포함합니다.
- 실험 결과 $\alpha = 6.0$을 사용하면 트랜스포머 성능 저하가 거의 0에 가깝게 감소함.
실험과 관련해서는 설명을 생략한다. 이 내용의 핵심은 양자화를 동적으로 하는데 있으며 이 글을 통해 양자화를 어떻게 진행하는지 알게되었으면 좋겠다.