03/2023 - Phil
Mainly summarizing how LLM.int8() works, plus some remarks on doing more with less.
Problem: the model in fp32 is 80GBs.
Remedy: quantize to int8 with LLM.int8().
Considerations:
Consider matrices \(\mathbf{X}_{f16}\in\mathbb{R}^{s\times h}\) and \(\mathbf{W}_{f16}\in\mathbb{R}^{h\times o}\) (in fp16). Then \(\frac{\mathbf{X}_{f16}}{\|\mathbf{X}_{f16}\|_\infty}\) has range \([-1,1]\) (note that \(\|\cdot\|_\infty=\max\{|X_{ij}|:i\in\{0,1,\ldots,s-1\},\ j\in\{0,1,\ldots,h-1\}\}\)). So, with scaling factor \(s_{xf16}=\frac{127}{\|\mathbf{X}_{f16}\|_\infty}\), the quantized matrix \(\mathbf{X}_{i8}=\text{round}[s_{xf16}\mathbf{X}_{f16}]\) has range \([-127, 127].\) Therefore we can approximate \(\mathbf{X}_{f16}\approx \frac{\mathbf{X}_{i8}}{s_{xf16}}.\)
Now we want to be able to perform matmul with int8:
\[\begin{align*} \mathbf{X_{f16}}\mathbf{W_{f16}}&\approx \frac{\mathbf{X_{i8}}\mathbf{W_{i8}}}{s_{xf16}s_{wf16}}\\ &=\frac{\mathbf{C_{i32}}}{s_{xf16}s_{wf16}}. \end{align*}\]This means we convert our fp16 matrices to int8, multiply them (accumulating into int32), then divide by the product of the fp16 scaling factors.
*Figure: Absmax quantization.*
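A minimal PyTorch sketch of absmax quantization as described above (plain CPU tensors in fp32 for the demo instead of fp16; the helper name is mine, not from any library):

```python
import torch

def absmax_quantize(x: torch.Tensor):
    """Absmax-quantize a tensor to int8 and return its scaling factor s = 127 / ||x||_inf."""
    s = 127.0 / x.abs().max()
    x_i8 = torch.round(s * x).to(torch.int8)
    return x_i8, s

# toy shapes: (s x h) activations, (h x o) weights
X = torch.randn(4, 8)
W = torch.randn(8, 3)

X_i8, s_x = absmax_quantize(X)
W_i8, s_w = absmax_quantize(W)

# int8 matmul accumulated in int32, then dequantized by the two scaling factors
C_i32 = X_i8.to(torch.int32) @ W_i8.to(torch.int32)
out = C_i32.to(torch.float32) / (s_x * s_w)

print((out - X @ W).abs().max())  # quantization error
```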
Zero point quantization is similar to absmax, but instead of mapping to \([-127, 127]\) we map to \([0, 255]\) by shifting our values so that our minimum value starts at zero. By doing this we use the full range of values from \([0, 255].\) With absmax quantization we are simply rescaling, and if, say, most of our values are positive we won't be using much of \([-127, 0),\) leading to more quantization error.
It works as follows:
\[\begin{align*} nd_{x16}&=\frac{255}{\max\mathbf{X_{f16}}-\min\mathbf{X_{f16}}}\\ zp_{xi16}&=\text{round}[nd_{x16}\min\mathbf{X_{f16}}]\\ \mathbf{X_{i8}}&=\text{round}[nd_{x16}(\mathbf{X_{f16}}-\min\mathbf{X_{f16}})]\\ \mathbf{C_{i32}}&=(\mathbf{X_{i8}}+zp_{xi16})(\mathbf{W_{i8}}+zp_{wi16})\\ \mathbf{X_{f16}}\mathbf{W_{f16}}&\approx\frac{\mathbf{C_{i32}}}{nd_{x16}nd_{w16}}. \end{align*}\]Note that our GPU needs to be able to perform this shifted int16 multiplication and accumulate into int32 for \(\mathbf{C_{i32}};\) otherwise we need to expand the product \((\mathbf{X_{i8}}+zp_{xi16})(\mathbf{W_{i8}}+zp_{wi16})\) and multiply each term separately.
*Figure: Zero point quantization.*
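And a corresponding sketch of zero point quantization (again fp32 CPU tensors for the demo; quantized values land in \([0, 255]\), so they are stored in int16 here before widening to int32 for the matmul):

```python
import torch

def zeropoint_quantize(x: torch.Tensor):
    """Asymmetric quantization: shift so the minimum maps to 0 and the maximum to 255."""
    nd = 255.0 / (x.max() - x.min())        # normalized dynamic range
    zp = torch.round(nd * x.min())          # zero point
    x_q = torch.round(nd * x - zp)          # quantized values in [0, 255]
    return x_q.to(torch.int16), zp.to(torch.int16), nd

X = torch.randn(4, 8)
W = torch.randn(8, 3)

X_q, zp_x, nd_x = zeropoint_quantize(X)
W_q, zp_w, nd_w = zeropoint_quantize(W)

# (X_q + zp_x) ~ nd_x * X, so the shifted product recovers nd_x * nd_w * (X @ W).
# If the hardware cannot do the shifted multiply directly, unroll it:
#   X_q@W_q + zp_w * rowsum(X_q) + zp_x * colsum(W_q) + h * zp_x * zp_w
C_i32 = (X_q.to(torch.int32) + zp_x.to(torch.int32)) @ (W_q.to(torch.int32) + zp_w.to(torch.int32))
out = C_i32.to(torch.float32) / (nd_x * nd_w)

print((out - X @ W).abs().max())  # quantization error
```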
When we have a massive matrix \(\mathbf{X_{f16}},\) scaling by the min or max is going to destroy a lot of information due to outliers. The basic remedy here is to selectively apply our quantization to different parts of the matrix.
One straightforward way is to quantize the rows of \(\mathbf{X_{f16}}\) and the columns of \(\mathbf{W_{f16}}\) separately. Then we end up with vectors \(\mathbf{c_{xf16}}\in\mathbb{R}^s, \mathbf{c_{wf16}}\in\mathbb{R}^o,\) which hold the scaling factors. After computing these vectors, our quantized multiplication becomes \(\mathbf{X_{f16}}\mathbf{W_{f16}}\approx\frac{1}{\mathbf{c_{xf16}}\otimes\mathbf{c_{wf16}}}\cdot\mathbf{C_{i32}},\) where \(\otimes\) is the outer product.
Doing this is a lot better than naive zeropoint or absmax quantization. However, there is still a large performance degradation.
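A sketch of this row/column-wise (vector-wise) scheme, with an absmax scale per row of \(\mathbf{X}\) and per column of \(\mathbf{W}\):

```python
import torch

X = torch.randn(4, 8)   # (s x h)
W = torch.randn(8, 3)   # (h x o)

# one absmax scaling factor per row of X and per column of W
c_x = 127.0 / X.abs().amax(dim=1)          # shape (s,)
c_w = 127.0 / W.abs().amax(dim=0)          # shape (o,)

X_i8 = torch.round(X * c_x[:, None]).to(torch.int8)
W_i8 = torch.round(W * c_w[None, :]).to(torch.int8)

C_i32 = X_i8.to(torch.int32) @ W_i8.to(torch.int32)

# dequantize with the outer product of the two scaling vectors
out = C_i32.to(torch.float32) / torch.outer(c_x, c_w)

print((out - X @ W).abs().max())
```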
There are a few things that are good to know about the outliers in the model:

- They emerge systematically once models get large (around 6.7B parameters and up).
- They are concentrated in a small number of hidden (column) dimensions of \(\mathbf{X_{f16}}.\)
- They matter: zeroing out the outlier features severely degrades the model's performance.
Using these facts, one remedy is mixed-precision decomposition. As most of the outliers sit in a small number of columns of \(\mathbf{X_{f16}},\) we can separate those columns out and multiply them with the corresponding rows of \(\mathbf{W_{f16}}\) using a normal fp16 matmul, and multiply the rest using our quantization methods. After decomposing, we merge the two partial results back into our output matrix. Since no more than 7 dimensions are affected for a 13B parameter model, 99.9% of our values are multiplied in 8-bit.
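A rough sketch of the decomposition idea (the 6.0 threshold follows the paper's default; the real kernel is of course far more careful and efficient):

```python
import torch

threshold = 6.0
X = torch.randn(4, 8)
W = torch.randn(8, 3)
X[:, 2] *= 30.0                                  # plant an "outlier" feature dimension

# columns of X where any activation magnitude exceeds the threshold stay in floating point
outlier_cols = (X.abs() > threshold).any(dim=0)

# outlier part: ordinary matmul over the few affected columns of X / rows of W
out_fp = X[:, outlier_cols] @ W[outlier_cols, :]

# everything else: vector-wise int8 quantization as before
X_r, W_r = X[:, ~outlier_cols], W[~outlier_cols, :]
c_x = 127.0 / X_r.abs().amax(dim=1)
c_w = 127.0 / W_r.abs().amax(dim=0)
C_i32 = torch.round(X_r * c_x[:, None]).to(torch.int8).to(torch.int32) \
        @ torch.round(W_r * c_w[None, :]).to(torch.int8).to(torch.int32)
out_i8 = C_i32.to(torch.float32) / torch.outer(c_x, c_w)

out = out_fp + out_i8                            # merge the two partial results
print((out - X @ W).abs().max())
```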
Using this decomposition method, the Dettmers paper is able to match 32-bit float perplexity while running in int8:

*Figure: perplexity results from the LLM.int8() paper.*
If your model is in PyTorch (and you have a CUDA GPU), you can probably use the `bitsandbytes` wrapper pretty straightforwardly. It's as simple as replacing `torch.nn.Linear(...)` with `bnb.nn.Linear8bitLt(...)` with the outlier `threshold=k` set.
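A sketch of that swap, assuming the bitsandbytes API as of early 2023 (the 6.0 threshold is the value used in the paper):

```python
import torch
import bitsandbytes as bnb

# the fp16 layer we want to serve in int8
linear_fp16 = torch.nn.Linear(4096, 4096).half()

# drop-in replacement: int8 weights, outliers above the threshold handled in fp16
linear_int8 = bnb.nn.Linear8bitLt(
    4096, 4096,
    has_fp16_weights=False,   # keep the weights in int8
    threshold=6.0,            # the outlier threshold k
)
linear_int8.load_state_dict(linear_fp16.state_dict())
linear_int8 = linear_int8.to("cuda")   # the int8 quantization happens on this .to("cuda") call
```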
A lot of people might still be CPU-memory bound, since we need to load the original model into DRAM first; one can create a swapfile to spill the extra memory to disk. I'm guessing one could skip some of this CPU memory usage in the future by precomputing the conversion and storing the int8 weights and scaling factors ahead of time.
Note that Hugging Face also already integrates with `bitsandbytes` pretty conveniently.
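For instance (assuming transformers with accelerate and bitsandbytes installed; the checkpoint name is just an example):

```python
from transformers import AutoModelForCausalLM

# load_in_8bit=True routes the linear layers through bitsandbytes' Linear8bitLt
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-3b",
    device_map="auto",
    load_in_8bit=True,
)
```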
We can use even fewer bits if we want. Empirically it seems that going down to int4 does not damage performance much, although we start seeing more degradation once we use int3. I would not be surprised if many commercial models are using int4 quantization.
Currently LLM.int8() does not support int4 quantization, although it might soon. I am planning to rush out a naive form of int4 (potentially a slow kernel, but it saves on memory).
There are a few reasons why GPU int4 isn't as widely available to developers currently (although I suspect it might be soon):

- PyTorch has no 4-bit dtype; its smallest tensor types are int8 or bool.
- The high-level libraries don't expose the hardware's int4 operations. Therefore we may need to move to the kernel level to support this.

There isn't theoretically anything about int4 that makes it that much harder than int8. There are just slightly more implementation details, as fewer people are comfortable working at the kernel level and the NVIDIA documentation isn't super transparent.
Basic plan of attack:

- Use NVIDIA's int4 tensor operations. This is a warp-level operation (WMMA), and it seems we need to pass in matrices of size 8x4 int32 to use it, where each int32 is interpreted as 8 packed int4s. Then we need to incorporate the warp-level operation into an actual GEMM kernel to be able to multiply matrices.
- Since we can't store int4 values in PyTorch directly, we pack the values the way the int4 WMMA wants us to: we pack them into an int32 tensor by bitshifting values together (see the sketch after this list).
- Wire it together: quantize and pack the fp16 input, run the int4 kernel, then dequantize the kernel output and return fp16 output.
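A sketch of the packing step on the PyTorch side (the helper is hypothetical, and the exact nibble order the WMMA fragment expects still has to be checked against NVIDIA's docs):

```python
import torch

def pack_int4(x_i4: torch.Tensor) -> torch.Tensor:
    """Pack int4 values (given in an int32 tensor, range [-8, 7]) into int32s, 8 per int32."""
    assert x_i4.shape[-1] % 8 == 0
    nibbles = x_i4.to(torch.int64) & 0xF                    # two's-complement low 4 bits
    nibbles = nibbles.reshape(*x_i4.shape[:-1], -1, 8)      # group 8 nibbles per output int
    shifts = torch.arange(0, 32, 4, dtype=torch.int64)
    packed = (nibbles << shifts).sum(dim=-1)                # value in [0, 2**32)
    packed = torch.where(packed >= 2**31, packed - 2**32, packed)  # reinterpret as signed
    return packed.to(torch.int32)

# an 8x32 tile of int4 values becomes the 8x4 int32 tile the WMMA fragment wants
tile = torch.randint(-8, 8, (8, 32), dtype=torch.int32)
print(pack_int4(tile).shape)   # torch.Size([8, 4])
```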
For example, a 13B parameter model in fp16 is going to be 26GBs, but an RTX 4090 (around $2000) has only 24GB of VRAM. With LLM.int8() one can make minor modifications to the code and now the entire model can fit into VRAM.
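Quick arithmetic behind that claim (roughly 4/2/1 bytes per parameter for fp32/fp16/int8):

```python
params = 13e9                                # 13B parameters
print(f"fp32: {params * 4 / 1e9:.0f} GB")    # ~52 GB
print(f"fp16: {params * 2 / 1e9:.0f} GB")    # ~26 GB
print(f"int8: {params * 1 / 1e9:.0f} GB")    # ~13 GB, fits comfortably in 24 GB of VRAM
```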