[Feature]: Support Inflight quantization: load as 8bit quantization. #11655
PTAL https://docs.vllm.ai/en/stable/quantization/fp8.html#quick-start-with-online-dynamic-quantization
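For reference, the online dynamic FP8 path in those docs boils down to a single argument; a minimal sketch (the model name is just an example, and the hardware requirements from the docs still apply):

```python
from vllm import LLM

# Weights are quantized to FP8 at load time; activations are quantized
# dynamically per batch, so no calibration data is needed.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", quantization="fp8")
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)
```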
|
FP8 is not supported on GPUs with a compute capability below 8. Additionally, I am quantizing my fine-tuned sequence classification model online. |
There is no easy way to support 8-bit inflight quantization.
`_unquantized_generator` only supports 4-bit inflight quantization, very sad.
|
The root cause seems to be that bitsandbytes uses different algorithms and different interfaces for 4-bit and 8-bit: bitsandbytes/functional.py only has `quantize_4bit`, there is no `quantize_8bit`. https://github.com/bitsandbytes-foundation/bitsandbytes/blob/main/bitsandbytes/functional.py
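For context on that interface difference: the 8-bit (LLM.int8()) path in bitsandbytes is exposed through the `nn.Linear8bitLt` / `nn.Int8Params` module API rather than a standalone `quantize_8bit` function. A rough sketch of that module-level flow (sizes are arbitrary):

```python
import torch
import bitsandbytes as bnb

# A plain fp16 linear layer whose weights we want in int8.
fp16_linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.float16)

int8_linear = bnb.nn.Linear8bitLt(4096, 4096, bias=False, has_fp16_weights=False)
int8_linear.weight = bnb.nn.Int8Params(
    fp16_linear.weight.data, requires_grad=False, has_fp16_weights=False
)
int8_linear.cuda()  # int8 quantization happens when the module is moved to GPU
```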
┓( ´∀` )┏ What is the specific model of your graphics card? Does it support Int8 tensor cores? |
bitsandbytes/functional.py has `quantize_blockwise`, which is a function that does 8-bit quantization:

```python
def quantize_blockwise(
    A: torch.Tensor,
    code: Optional[torch.Tensor] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=4096,
    nested=False,
) -> Tuple[torch.Tensor, QuantState]:
    """Quantize a tensor in blocks of values.

    The input tensor is quantized by dividing it into blocks of `blocksize` values.
    The absolute maximum value within these blocks is calculated for scaling
    the non-linear quantization.

    Args:
        A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
        code (`torch.Tensor`, *optional*):
            A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type.
            For more details, see [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561).
        absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
        blocksize (`int`, *optional*):
            The size of the blocks. Defaults to 4096.
            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
        nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.

    Raises:
        ValueError: Raised when the input data type is not supported.

    Returns:
        `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results.
        - `torch.Tensor`: The quantized tensor.
        - [`QuantState`]: The state object used to undo the quantization.
    """
```
You can use hf_overrides
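A minimal sketch of that, assuming `hf_overrides` takes a dict of HF config fields to patch at load time (the model path and the field name here are hypothetical and depend on your checkpoint's config):

```python
from vllm import LLM

llm = LLM(
    model="path/to/my-seq-cls-model",  # hypothetical fine-tuned checkpoint
    hf_overrides={"num_labels": 2},    # assumed config field to override
)
```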
The code in question is here:
We can try to implement `_unquantized_8bit_generator` and change the logic to this:
I'm not familiar with bnb; you can try to implement `_unquantized_8bit_generator`. |
This only seems to support PP, while inflight quantization supports TP. |
Yes. If you implement `_unquantized_8bit_generator`, it is not pre_quant, so it may support TP.
```python
with set_default_torch_dtype(torch.float32):
    # processed_weight, quant_state = quantize_4bit(
    #     loaded_weight,
    #     compress_statistics=True,
    #     quant_type="nf4",
    # )
    processed_weight, quant_state = quantize_blockwise(
        loaded_weight,
        nested=True,
    )
```

Still no luck. There's a problem here: the assertion `assert param_data.shape == loaded_weight.shape` fails with param_data.shape=torch.Size([35389440, 1]) and loaded_weight.shape=torch.Size([5120, 13824]) (vllm/vllm/model_executor/layers/linear.py, line 1087 in e1a5c2f).
|
5120 * 13824 / 35389440 = 2.0, so I guess it's because loaded_weight is 16-bit and param_data is 8-bit.
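A sketch of where that factor of 2 likely comes from, assuming the parameter was sized for the packed 4-bit layout: `quantize_4bit` packs two 4-bit values per byte and flattens to `[n/2, 1]`, while `quantize_blockwise` keeps one byte per element and the original 2-D shape:

```python
import torch
from bitsandbytes.functional import quantize_4bit, quantize_blockwise

w = torch.randn(5120, 13824, dtype=torch.float16, device="cuda")

q4, _ = quantize_4bit(w, compress_statistics=True, quant_type="nf4")
q8, _ = quantize_blockwise(w.float(), nested=True)

print(q4.shape)  # torch.Size([35389440, 1]) -> 5120*13824/2 packed bytes
print(q8.shape)  # torch.Size([5120, 13824]) -> one uint8 per element
```
|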
So it's not just a matter of implementing `_unquantized_8bit_generator`? |
I don't know what the difficulty of implementing an 8-bit quantizer is. I think the missing vLLM Bnb8Bit quantizer could refer to HfQuantizer. accelerate is not included in vLLM's requirements, so vLLM does not install accelerate by default; Bnb8BitHfQuantizer uses accelerate, and I don't know whether that library is part of the problem. Looking at the pull history, I'm sure vLLM only supports reading bnb 8-bit pre-quantized models.
I believe "Support Inflight quantization: load as 8bit quantization" is not yet supported. Is there any special difficulty in supporting 8-bit bnb inflight quantization, or is it just that no one has tried it before?
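For comparison, the HfQuantizer path mentioned above is what transformers does for 8-bit loading (and it does pull in accelerate); a minimal sketch:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",                                        # example model
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # Bnb8BitHfQuantizer path
    device_map="auto",                                          # requires accelerate
)
```
|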
How about we implement 8-bit quantization in another way?
https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_w8a8_int8
https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modifiers/README.md
I guess this is how `--quantization="fp8"` is implemented.
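Roughly following the linked W8A8 example (module paths and arguments are from memory and may differ by llm-compressor version), the offline alternative would look something like:

```python
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot

recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

# Produces an INT8 W8A8 checkpoint that vLLM can load directly.
oneshot(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # example model
    dataset="open_platypus",                      # example calibration set
    recipe=recipe,
    output_dir="Meta-Llama-3-8B-Instruct-W8A8",
    max_seq_length=2048,
    num_calibration_samples=512,
)
```
|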
🚀 The feature, motivation and pitch
vLLM supports 4-bit inflight quantization (see the sketch below) but does not support 8-bit. 8-bit is faster than 4-bit, so please add support for it.
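For reference, this is how the existing 4-bit inflight path is invoked today; the request is an 8-bit analogue of the same flow (how the 8-bit variant would be selected is left open here):

```python
from vllm import LLM

# In-flight bitsandbytes quantization of an unquantized checkpoint (currently 4-bit NF4).
llm = LLM(
    model="huggyllama/llama-7b",   # example unquantized model
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)
```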
Alternatives
No response
Additional context
No response