
[Feature]: Support Inflight quantization: load as 8bit quantization. #11655

Open · 1 task done
ShelterWFF opened this issue Dec 31, 2024 · 14 comments

@ShelterWFF

🚀 The feature, motivation and pitch

vLLM supports 4-bit in-flight quantization, but not 8-bit. 8-bit is faster than 4-bit, so please add support for it.

Alternatives

No response

Additional context

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@noooop (Contributor) commented Jan 1, 2025

PTAL

https://docs.vllm.ai/en/stable/quantization/fp8.html#quick-start-with-online-dynamic-quantization

Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying --quantization="fp8" in the command line or setting quantization="fp8" in the LLM constructor.
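
A minimal sketch of that online FP8 path (the model name here is only a placeholder):

    from vllm import LLM, SamplingParams

    # In-flight dynamic FP8 quantization: the BF16/FP16 checkpoint is quantized
    # to FP8 at load time, no calibration data required.
    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", quantization="fp8")

    out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(out[0].outputs[0].text)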

@ShelterWFF (Author)

> PTAL
>
> https://docs.vllm.ai/en/stable/quantization/fp8.html#quick-start-with-online-dynamic-quantization
>
> Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying --quantization="fp8" in the command line or setting quantization="fp8" in the LLM constructor.

FP8 is not supported on GPUs with a compute capability below 8. Additionally, I am quantizing my fine-tuned sequence classification model online.

@noooop (Contributor) commented Jan 3, 2025

There is no easy way to support 8-bit in-flight quantization.

https://github.com/vllm-project/vllm/blob/fd3a62a122fcbc9331d000b325e72687629ef1bd/vllm/model_executor/model_loader/loader.py#L804C1-L815C79

        if pre_quant:
            if load_8bit:              # load_8bit only supports the calibration (pre-quantized) version, not in-flight quantization
                return self._quantized_8bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
            else:
                return self._quantized_4bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict

        return self._unquantized_generator(hf_weights_files, use_safetensors,
                                           quant_state_dict), quant_state_dict

_unquantized_generator only supports 4-bit in-flight quantization, unfortunately.

https://github.com/vllm-project/vllm/blob/fd3a62a122fcbc9331d000b325e72687629ef1bd/vllm/model_executor/model_loader/loader.py#L970C1-L975C22

                with set_default_torch_dtype(torch.float32):
                    processed_weight, quant_state = quantize_4bit(
                        loaded_weight,
                        compress_statistics=True,
                        quant_type="nf4",
                    )

@noooop (Contributor) commented Jan 3, 2025

The root cause seems to be that in bitsandbytes, 4-bit and 8-bit use different algorithms and different interfaces.

bitsandbytes/functional.py only has quantize_4bit; there is no quantize_8bit.

https://github.com/bitsandbytes-foundation/bitsandbytes/blob/main/bitsandbytes/functional.py

Papers, related resources & how to cite

  • The case for 4-bit precision: k-bit Inference Scaling Laws (Dec 2022)
  • LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Nov 2022)
  • 8-bit Optimizers via Block-wise Quantization (Oct 2021)

┓( ´∀` )┏

@ShelterWFF

What is the specific model of your graphics card? Does it support Int8 tensor cores?

@ShelterWFF (Author)

> The root cause seems to be that in bitsandbytes, 4-bit and 8-bit use different algorithms and different interfaces.
>
> [...]
>
> What is the specific model of your graphics card? Does it support Int8 tensor cores?

bitsandbytes/functional.py has quantize_blockwise, which is the function that quantizes to 8 bits:

def quantize_blockwise(
    A: torch.Tensor,
    code: Optional[torch.Tensor] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=4096,
    nested=False,
) -> Tuple[torch.Tensor, QuantState]:
    """Quantize a tensor in blocks of values.

    The input tensor is quantized by dividing it into blocks of `blocksize` values.
    The absolute maximum value within these blocks is calculated for scaling
    the non-linear quantization.

    Args:
        A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
        code (`torch.Tensor`, *optional*):
            A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type.
            For more details, see  (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561].
        absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
        blocksize (`int`, *optional*):
            The size of the blocks. Defaults to 4096.
            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
        nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.

    Raises:
        ValueError: Raised when the input data type is not supported.

    Returns:
        `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results.
        - `torch.Tensor`: The quantized tensor.
        - [`QuantState`]: The state object used to undo the quantization.
    """

@noooop (Contributor) commented Jan 3, 2025

You can use hf_overrides

from vllm import LLM
import torch

model_id = "huggyllama/llama-7b"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
          quantization="bitsandbytes", load_format="bitsandbytes",
          hf_overrides={"quantization_config": {"load_in_8bit": True,
                                                "quant_method": "bitsandbytes"}})

The code then reaches this branch:

https://github.com/vllm-project/vllm/blob/fd3a62a122fcbc9331d000b325e72687629ef1bd/vllm/model_executor/model_loader/loader.py#L804C1-L815C79

        if pre_quant:
            if load_8bit:              # <- here: load_8bit only supports the calibration (pre-quantized) version, not in-flight quantization
                return self._quantized_8bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
            else:
                return self._quantized_4bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict

        return self._unquantized_generator(hf_weights_files, use_safetensors,
                                           quant_state_dict), quant_state_dict

We can try to implement _unquantized_8bit_generator

Change the logic to this:

        if pre_quant:
            if load_8bit:              # <- here: load_8bit only supports the calibration (pre-quantized) version, not in-flight quantization
                return self._quantized_8bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
            else:
                return self._quantized_4bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
        else:
            if load_8bit:
                return self._unquantized_8bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
            else:
                return self._unquantized_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict

I'm not familiar with bnb; you can try to implement _unquantized_8bit_generator.
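
Very roughly, it would mirror _unquantized_generator but call quantize_blockwise instead of quantize_4bit. A standalone sketch of the core loop; the function signature and helper names here are illustrative, not vLLM's actual internals:

    import torch
    from bitsandbytes.functional import quantize_blockwise


    def unquantized_8bit_generator(weight_iter, target_modules, quant_state_dict):
        """Illustrative only: quantize selected 16-bit weights to 8-bit at load time.

        weight_iter yields (name, tensor) pairs from the checkpoint;
        target_modules lists the linear-layer names that should be quantized.
        """
        for name, loaded_weight in weight_iter:
            if any(target in name for target in target_modules):
                # Same spot where the 4-bit path calls quantize_4bit (wrapped in
                # set_default_torch_dtype(torch.float32) in vLLM), but using
                # blockwise int8 instead of NF4.
                processed_weight, quant_state = quantize_blockwise(
                    loaded_weight.cuda(), nested=True)
                quant_state_dict[name] = quant_state
            else:
                processed_weight = loaded_weight
            yield name, processed_weight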

@ShelterWFF (Author)

> You can use hf_overrides
>
> [...]
>
> I'm not familiar with bnb; you can try to implement _unquantized_8bit_generator.

This only seems to support PP, while in-flight quantization supports TP.

@noooop (Contributor) commented Jan 3, 2025

> This only seems to support PP, while in-flight quantization supports TP.

yes

https://github.com/vllm-project/vllm/blob/e1a5c2f0a123835558b1b1c9895181161527c55e/vllm/model_executor/model_loader/loader.py#L1068C1-L1073C39

If you implement _unquantized_8bit_generator, it is not pre_quant, so it may support TP.

@ShelterWFF (Author) commented Jan 5, 2025

> > This only seems to support PP, while in-flight quantization supports TP.
>
> yes
>
> https://github.com/vllm-project/vllm/blob/e1a5c2f0a123835558b1b1c9895181161527c55e/vllm/model_executor/model_loader/loader.py#L1068C1-L1073C39
>
> If you implement _unquantized_8bit_generator, it is not pre_quant, so it may support TP.

                with set_default_torch_dtype(torch.float32):
                    # processed_weight, quant_state = quantize_4bit(
                    #     loaded_weight,
                    #     compress_statistics=True,
                    #     quant_type="nf4",
                    # )
                    processed_weight, quant_state = quantize_blockwise(
                        loaded_weight,
                        nested=True,
                    )

Still no luck. There's a problem here:

assert param_data.shape == loaded_weight.shape
param_data.shape = torch.Size([35389440, 1]), loaded_weight.shape = torch.Size([5120, 13824])

@noooop (Contributor) commented Jan 6, 2025

5120 * 13824 / 35389440 = 2.0

I guess it's because loaded_weight is 16-bit and param_data is 8-bit.

@ShelterWFF (Author)

> 5120 * 13824 / 35389440 = 2.0
>
> I guess it's because loaded_weight is 16-bit and param_data is 8-bit.

So it's not just a matter of implementing _unquantized_8bit_generator?

@noooop (Contributor) commented Jan 6, 2025

  • Is the load_in_8bit parameter in vLLM's quantization_config? Yes.
  • Is the bnb Linear 8-bit apply implemented? Yes.
  • 16-bit weight to 8-bit conversion (a quantizer)? Missing?

I don't know what the difficulty of implementing an 8-bit quantizer would be.

I think the missing vLLM Bnb8Bit quantizer piece could follow HfQuantizer:

https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_bnb_8bit.py

accelerate is not included in vLLM's requirements; vLLM does not install accelerate by default.

Bnb8BitHfQuantizer uses accelerate. I don't know whether that dependency matters here.
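
As far as I understand it, the 16-bit → 8-bit step that the transformers 8-bit path ultimately relies on looks roughly like this (a sketch of the bitsandbytes API as I understand it, not vLLM code):

    import torch
    import bitsandbytes as bnb

    # A 16-bit weight, as it would come out of an unquantized checkpoint.
    weight = torch.randn(4096, 11008, dtype=torch.float16)

    # Wrapping it in Int8Params and moving it to the GPU performs the LLM.int8()
    # quantization: .data becomes int8 and .SCB holds the row-wise scaling
    # statistics that Linear8bitLt uses at apply time.
    int8_weight = bnb.nn.Int8Params(
        weight, requires_grad=False, has_fp16_weights=False).cuda()

    print(int8_weight.dtype)      # torch.int8
    print(int8_weight.SCB.shape)  # row-wise scales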

Looking at the PR history, I'm fairly sure vLLM only supports reading bnb 8-bit pre-quantized models.

Refer to:

> Currently, vLLM only supports 4-bit for in-flight quantization, see: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader/loader.py#L997.
> vLLM's load_in_4bit/load_in_8bit args are used for pre-quantization (at least for now). In general, they are obtained from the model's configuration file, see: https://huggingface.co/openbmb/MiniCPM-V-2_6-int4/blob/main/config.json#L28

I believe "Support Inflight quantization: load as 8bit quantization." is not yet supported.

@jeejeelee

Is there any special difficulty in supporting 8-bit bnb in-flight quantization? Or is it just that no one has tried it before?

@noooop (Contributor) commented Jan 6, 2025

@ShelterWFF

How about we implement 8-bit in-flight quantization another way?

  1. See if the INT8 W8A8 kernel can be used on your device.

     INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper).

     https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_w8a8_int8

  2. Use a "one-shot" algorithm.

     https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modifiers/README.md

  3. Put the two pieces together (see the sketch below).

I guess this is how --quantization="fp8" is implemented.
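
A rough sketch of the offline one-shot W8A8 flow from the examples linked above; the identifiers follow that repo and may differ between llm-compressor versions, and the model/dataset names are only placeholders:

    from transformers import AutoModelForCausalLM

    from llmcompressor.modifiers.quantization import GPTQModifier
    from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
    from llmcompressor.transformers import oneshot

    MODEL_ID = "huggyllama/llama-7b"  # placeholder
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", device_map="auto")

    # SmoothQuant + GPTQ one-shot recipe producing INT8 weights and activations
    # (W8A8), which vLLM can then serve with its INT8 kernels.
    recipe = [
        SmoothQuantModifier(smoothing_strength=0.8),
        GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
    ]

    oneshot(
        model=model,
        dataset="open_platypus",        # small calibration dataset, placeholder
        recipe=recipe,
        max_seq_length=2048,
        num_calibration_samples=512,
    )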
