Describe the bug
W8A8 quantization of Llama-3.1-Nemotron-70B-Instruct-HF fails with ValueError: Failed to invert hessian due to numerical instability.

Expected behavior
A normal W8A8 model is produced and its output accuracy is close to the original model's.

Environment
Include all relevant environment information:

To Reproduce
My code follows the official vLLM INT8 quantization guide: https://docs.vllm.ai/en/latest/quantization/int8.html
from llmcompressor.transformers import SparseAutoModelForCausalLM
from transformers import AutoTokenizer
MODEL_ID = "./Llama-3.1-Nemotron-70B-Instruct-HF"
model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
from datasets import load_dataset
NUM_CALIBRATION_SAMPLES = 2048
MAX_SEQUENCE_LENGTH = 32768
# Load and preprocess the dataset
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
def preprocess(example):
    return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
ds = ds.map(preprocess)
def tokenize(sample):
    return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(tokenize, remove_columns=ds.column_names)
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
# Configure the quantization algorithms
recipe = [
    SmoothQuantModifier(smoothing_strength=0.5),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"], dampening_frac=0.1),
]
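# Note (added for clarity, not part of the original script): SmoothQuant folds
# activation outliers into the weights before quantization, and GPTQ then quantizes
# the weights using a Hessian accumulated from the calibration activations.
# dampening_frac controls the diagonal damping added to that Hessian before its
# Cholesky factorization, which is the step that fails in the traceback below.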
# Apply quantization
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Save the compressed model
SAVE_DIR = MODEL_ID + "/vllm-SMQ-W8A8-Dynamic-Per-Token-32k"
model.save_pretrained(SAVE_DIR, save_compressed=True)
print(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
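Once the script completes, I expect to be able to load the compressed checkpoint directly in vLLM. A minimal loading sketch (an illustration only, assuming vLLM is installed; the path is the SAVE_DIR printed by the script above, and tensor_parallel_size=8 is an assumption for a 70B model):

from vllm import LLM, SamplingParams

# Load the compressed W8A8 checkpoint saved by the script above.
# Adjust tensor_parallel_size to the number of available GPUs.
llm = LLM(
    model="./Llama-3.1-Nemotron-70B-Instruct-HF/vllm-SMQ-W8A8-Dynamic-Per-Token-32k",
    tensor_parallel_size=8,
)

# Quick smoke test: generate a short completion to confirm the checkpoint is usable.
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)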
Errors
2024-12-30T08:30:40.393240-0800 | compress_module | INFO - Compressing model.layers.3.model.layers.3.mlp.up_proj...
2024-12-30T08:30:42.984565-0800 | compress | METRIC - time 2.59
2024-12-30T08:30:42.986384-0800 | compress | METRIC - error 7.39
2024-12-30T08:30:42.986655-0800 | compress | METRIC - GPU 0 | usage: 33.10% | total memory: 85 GB
2024-12-30T08:30:42.986693-0800 | compress | METRIC - GPU 1 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986717-0800 | compress | METRIC - GPU 2 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986745-0800 | compress | METRIC - GPU 3 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986764-0800 | compress | METRIC - GPU 4 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986780-0800 | compress | METRIC - GPU 5 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986802-0800 | compress | METRIC - GPU 6 | usage: 23.62% | total memory: 85 GB
2024-12-30T08:30:42.986823-0800 | compress | METRIC - GPU 7 | usage: 15.54% | total memory: 85 GB
2024-12-30T08:30:42.986875-0800 | compress | METRIC - Compressed layer size: 469.848064 MB
2024-12-30T08:30:42.987115-0800 | compress_module | INFO - Compressing model.layers.3.model.layers.3.mlp.down_proj...
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py", line 176, in compress
self.H = torch.linalg.cholesky(self.H, upper=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._C._LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 28364 is not positive-definite).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/code/smooth_quant.py", line 43, in <module>
oneshot(
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/transformers/finetune/text_generation.py", line 82, in oneshot
main(model_args, data_args, training_args)
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/transformers/finetune/text_generation.py", line 381, in main
stage_runner.one_shot()
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/transformers/finetune/runner.py", line 166, in one_shot
self.trainer.one_shot(calibration_data=calib_data, stage=stage)
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/transformers/finetune/session_mixin.py", line 440, in one_shot
apply(
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/core/session_functions.py", line 184, in apply
return active_session().apply(
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/core/session.py", line 212, in apply
self.initialize(**kwargs)
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/core/session.py", line 158, in initialize
mod_data = self._lifecycle.initialize(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/core/lifecycle.py", line 126, in initialize
data = mod.initialize(state=self.state, **extras)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/stage.py", line 124, in initialize
modifier.initialize(state, **kwargs)
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/modifier.py", line 118, in initialize
initialized = self.on_initialize(state=state, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/quantization/gptq/base.py", line 204, in on_initialize
self.apply_compression(calibration_dataloader)
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/quantization/gptq/base.py", line 302, in apply_compression
layer_compressor.compress()
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/utils/layer_compressor.py", line 177, in compress
self.layer.apply(compress_module)
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1029, in apply
module.apply(fn)
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1029, in apply
module.apply(fn)
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1030, in apply
fn(self)
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/utils/layer_compressor.py", line 174, in compress_module
module.compress(**self.args)
File "/usr/local/lib/python3.12/dist-packages/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py", line 179, in compress
raise ValueError(
ValueError: Failed to invert hessian due to numerical instability. Consider increasing GPTQModifier.dampening_frac, increasing the number of calibration samples, or shuffling the calibration dataset
Additional context
I have already increased the number of calibration samples from 512 to 2048 and raised dampening_frac to 0.1, but I still get ValueError: Failed to invert hessian due to numerical instability. In addition, setting dampening_frac=0.1 causes a significant accuracy drop. Is there a solution to these two problems? Thanks.
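For reference, the error message suggests raising GPTQModifier.dampening_frac, adding calibration samples, or shuffling the calibration set. A sketch of the kind of variant this points to, reusing the API from the reproduction script (the seed and the dampening_frac=0.2 value are illustrative placeholders, not settings I have validated, and higher damping is exactly what degrades accuracy for me):

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

# Reshuffle the calibration set with a different seed (ds as prepared in the script above).
ds = ds.shuffle(seed=123)

recipe = [
    SmoothQuantModifier(smoothing_strength=0.5),
    # Larger dampening_frac stabilizes the Hessian factorization, at the cost of
    # the accuracy problems described above.
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"], dampening_frac=0.2),
]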