
[Bug]: Llama-3.1-Nemotron-70B-Instruct-HF W8A8 has ValueError: Failed to invert hessian due to numerical instability #1019

Open
fan-niu opened this issue Dec 31, 2024 · 2 comments


fan-niu commented Dec 31, 2024

Describe the bug
Quantizing Llama-3.1-Nemotron-70B-Instruct-HF to W8A8 fails with ValueError: Failed to invert hessian due to numerical instability.

Expected behavior
Quantization completes, producing a working W8A8 model with normal output accuracy.

Environment

Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.12.7 (main, Oct  1 2024, 08:52:12) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-124-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 550.90.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

vllm                              0.6.6.post1
llmcompressor                     0.3.1
accelerate                        1.1.1
torch                             2.5.1
transformers                      4.46.2
xformers                          0.0.28.post3

nvidia-cublas-cu12                12.4.5.8
nvidia-cuda-cupti-cu12            12.4.127
nvidia-cuda-nvrtc-cu12            12.4.127
nvidia-cuda-runtime-cu12          12.4.127
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.1.3
nvidia-curand-cu12                10.3.5.147
nvidia-cusolver-cu12              11.6.1.9
nvidia-cusparse-cu12              12.3.1.170
nvidia-ml-py                      12.560.30
nvidia-nccl-cu12                  2.21.5
nvidia-nvjitlink-cu12             12.4.127
nvidia-nvtx-cu12                  12.4.127

To Reproduce
My code follows the official vLLM INT8 quantization example: https://docs.vllm.ai/en/latest/quantization/int8.html

from llmcompressor.transformers import SparseAutoModelForCausalLM
from transformers import AutoTokenizer
MODEL_ID = "./Llama-3.1-Nemotron-70B-Instruct-HF"
# Load the model across available GPUs with its checkpoint dtype
model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

from datasets import load_dataset

NUM_CALIBRATION_SAMPLES = 2048
MAX_SEQUENCE_LENGTH = 32768

# Load and preprocess the dataset
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

# Render each conversation with the model's chat template
def preprocess(example):
    return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
ds = ds.map(preprocess)

# Tokenize without padding, truncating to MAX_SEQUENCE_LENGTH
def tokenize(sample):
    return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(tokenize, remove_columns=ds.column_names)

from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

# Configure the quantization algorithms
recipe = [
    SmoothQuantModifier(smoothing_strength=0.5),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"], dampening_frac=0.1),
]

# Apply quantization
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save the compressed model
SAVE_DIR = MODEL_ID + "/vllm-SMQ-W8A8-Dynamic-Per-Token-32k"
model.save_pretrained(SAVE_DIR, save_compressed=True)
print(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

Errors

2024-12-30T08:30:40.393240-0800 | compress_module | INFO - Compressing model.layers.3.model.layers.3.mlp.up_proj...
2024-12-30T08:30:42.984565-0800 | compress | METRIC - time 2.59
2024-12-30T08:30:42.986384-0800 | compress | METRIC - error 7.39
2024-12-30T08:30:42.986655-0800 | compress | METRIC - GPU 0 | usage: 33.10% | total memory: 85 GB
2024-12-30T08:30:42.986693-0800 | compress | METRIC - GPU 1 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986717-0800 | compress | METRIC - GPU 2 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986745-0800 | compress | METRIC - GPU 3 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986764-0800 | compress | METRIC - GPU 4 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986780-0800 | compress | METRIC - GPU 5 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986802-0800 | compress | METRIC - GPU 6 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986823-0800 | compress | METRIC - GPU 7 | usage: 15.54% | total memory: 85 GB
2024-12-30T08:30:42.986875-0800 | compress | METRIC - Compressed layer size: 469.848064 MB
2024-12-30T08:30:42.987115-0800 | compress_module | INFO - Compressing model.layers.3.model.layers.3.mlp.down_proj...
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py", line 176, in compress
    self.H = torch.linalg.cholesky(self.H, upper=True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._C._LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 28364 is not positive-definite).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/code/smooth_quant.py", line 43, in <module>
    oneshot(
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/transformers/finetune/text_generation.py", line 82, in oneshot
    main(model_args, data_args, training_args)
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/transformers/finetune/text_generation.py", line 381, in main
    stage_runner.one_shot()
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/transformers/finetune/runner.py", line 166, in one_shot
    self.trainer.one_shot(calibration_data=calib_data, stage=stage)
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/transformers/finetune/session_mixin.py", line 440, in one_shot
    apply(
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/core/session_functions.py", line 184, in apply
    return active_session().apply(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/core/session.py", line 212, in apply
    self.initialize(**kwargs)
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/core/session.py", line 158, in initialize
    mod_data = self._lifecycle.initialize(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/core/lifecycle.py", line 126, in initialize
    data = mod.initialize(state=self.state, **extras)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/stage.py", line 124, in initialize
    modifier.initialize(state, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/modifier.py", line 118, in initialize
    initialized = self.on_initialize(state=state, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/quantization/gptq/base.py", line 204, in on_initialize
    self.apply_compression(calibration_dataloader)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/quantization/gptq/base.py", line 302, in apply_compression
    layer_compressor.compress()
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/utils/layer_compressor.py", line 177, in compress
    self.layer.apply(compress_module)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1029, in apply
    module.apply(fn)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1029, in apply
    module.apply(fn)
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1030, in apply
    fn(self)
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/utils/layer_compressor.py", line 174, in compress_module
    module.compress(**self.args)
  File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py", line 179, in compress
    raise ValueError(
ValueError: Failed to invert hessian due to numerical instability. Consider increasing GPTQModifier.dampening_frac, increasing the number of calibration samples, or shuffling the calibration dataset

Additional context
I increased the number of calibration samples from 512 to 2048 and set dampening_frac=0.1, but I still get ValueError: Failed to invert hessian due to numerical instability. In addition, setting dampening_frac=0.1 causes a large accuracy drop. Is there any solution to these two problems? Thanks

fan-niu added the bug (Something isn't working) label on Dec 31, 2024
dsikka self-assigned this on Jan 1, 2025

fan-niu commented Jan 2, 2025

@dsikka @robertgshaw2-neuralmagic @kylesayrs Hi, could you take a look at this issue? Thanks a lot!


dsikka commented Jan 3, 2025

Hi @fan-niu
Could you try the suggested changes listed here to see if they might help? A sketch of how they could be applied is below the list.

  1. increasing GPTQModifier.dampening_frac
  2. increasing the number of calibration samples
  3. shuffling the calibration dataset
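
For reference, a minimal sketch of how these could be combined with the recipe from the reproduction script above; the dampening_frac value of 0.2 and the new shuffle seed are illustrative guesses, not a configuration known to fix this model:

from datasets import load_dataset
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

NUM_CALIBRATION_SAMPLES = 2048  # (2) already raised from 512; could be raised further

# (3) reshuffle the calibration set with a different seed before selecting samples
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=123).select(range(NUM_CALIBRATION_SAMPLES))

# (1) raise dampening_frac beyond the 0.1 already tried (0.2 is an illustrative value)
recipe = [
    SmoothQuantModifier(smoothing_strength=0.5),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"], dampening_frac=0.2),
]

Higher dampening_frac trades some reconstruction accuracy for numerical stability of the Hessian inversion, so it is worth searching for the smallest value that still completes.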
