Commit

patch
HanGuo97 committed Aug 2, 2024
1 parent d01ff81 commit a33f6fc
Showing 6 changed files with 57 additions and 31 deletions.
2 changes: 1 addition & 1 deletion flute/__init__.py
@@ -7,7 +7,7 @@
from . import _C
from . import ops

__version__ = "0.0.5"
__version__ = "0.0.6"

QGEMM_SIMPLE_TYPE = Callable[
[
39 changes: 23 additions & 16 deletions flute/integrations/base.py
@@ -10,7 +10,7 @@
from accelerate.hooks import (
ModelHook,
add_hook_to_module)
from typing import Optional
from typing import Optional, Dict

import flute
import flute.utils
@@ -36,17 +36,22 @@ def get_accelerate_hook(name: str, module: torch.nn.Module, allow: bool) -> Opti
# 2/4
@torch.no_grad()
def prepare_model_flute(
name: str,
module: torch.nn.Module,
num_bits: int,
group_size: int,
fake: bool,
handle_hooks: bool = False,
custom_scales_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None:

warnings.warn(f"Quantization always happen on 1st GPU")

def _replace_linear(_module: torch.nn.Module) -> None:
for name, child in _module.named_children():
def _replace_linear(_name: str, _module: torch.nn.Module) -> None:
for child_name, child in _module.named_children():

child_full_name = f"{_name}.{child_name}"

if isinstance(child, torch.nn.Linear):

if child.weight.dtype not in [torch.float16, torch.bfloat16]:
@@ -78,11 +83,11 @@ def _replace_linear(_module: torch.nn.Module) -> None:
raise NotImplementedError

# the replacement will remove the accelerate hooks
maybe_hook = get_accelerate_hook(name, child, allow=True)
maybe_hook = get_accelerate_hook(child_name, child, allow=True)

setattr(
_module,
name,
child_name,
FluteLinear(
in_features=child.in_features,
out_features=child.out_features,
@@ -92,22 +97,22 @@ def _replace_linear(_module: torch.nn.Module) -> None:
device=child.weight.device,
dtype=child.weight.dtype))

template_id = flute.TEMPLATE_TUNED_WITHOUT_M_CONFIGS[(
flute.NUM_SMS,
num_bits,
group_size,
child.out_features, # N
child.in_features)] # K
new_child = getattr(_module, name)
if custom_scales_dict is not None:
custom_scales = custom_scales_dict[child_full_name]
else:
custom_scales = None

_, _Q, scales, qmap = flute.nf_utils.nf_quantize(
W=child.weight.to(device="cuda"),
num_bits=num_bits,
group_size=group_size)
group_size=group_size,
custom_scales=custom_scales)
Q = flute.utils.pack(
_Q.T.contiguous(),
num_bits=num_bits,
template_ids=[template_id])
group_size=group_size)

new_child = getattr(_module, child_name)
scales = scales.view(new_child.scales.shape)
scales = scales.to(dtype=new_child.scales.dtype)
qmap = qmap.to(dtype=new_child.tables.dtype)
@@ -126,9 +131,9 @@ def _replace_linear(_module: torch.nn.Module) -> None:
hook=maybe_hook)

else:
_replace_linear(child)
_replace_linear(child_full_name, child)

_replace_linear(module)
_replace_linear(name, module)


class FluteLinear(torch.nn.Module):
@@ -221,6 +226,7 @@ def quantize_hf_model(

if isinstance(model, (LlamaForCausalLM, Gemma2ForCausalLM)):
prepare_model_flute(
name="model.model.layers",
module=model.model.layers,
num_bits=num_bits,
group_size=group_size,
@@ -233,6 +239,7 @@

# save the config
config = {
"version": flute.__version__,
"num_sms": flute.NUM_SMS,
"num_bits": num_bits,
"group_size": group_size,
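Usage sketch (not part of the commit): with this change, prepare_model_flute takes the module's dotted prefix as name and, optionally, a dict of precomputed per-layer scales keyed by the full dotted child name built above (the child_full_name, e.g. name + ".0.self_attn.q_proj"). The checkpoint id and quantization settings below are illustrative assumptions, not values taken from the repository.

    import torch
    from transformers import LlamaForCausalLM
    from flute.integrations.base import prepare_model_flute

    # any fp16/bf16 Llama-style checkpoint works; this id is just an example
    model = LlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16)

    # optional: per-layer scale tensors keyed by the dotted path that
    # _replace_linear builds (child_full_name); None falls back to the
    # default absmax scales computed inside nf_quantize
    custom_scales_dict = None

    prepare_model_flute(
        name="model.model.layers",
        module=model.model.layers,
        num_bits=4,
        group_size=64,
        fake=False,
        custom_scales_dict=custom_scales_dict)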
8 changes: 1 addition & 7 deletions flute/integrations/vllm_utils.py
@@ -311,16 +311,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
Q_unpacked = torch.cat(Qs_unpacked, dim=0)

# re-pack the tensors
template_id = flute.TEMPLATE_TUNED_WITHOUT_M_CONFIGS[(
flute.NUM_SMS,
layer.num_bits,
layer.group_size,
Q_unpacked.shape[0], # N
Q_unpacked.shape[1])] # K
Q_repacked = flute.utils.pack(
Q_unpacked.T.contiguous().to(device="cpu"),
num_bits=layer.num_bits,
template_ids=[template_id]).to(device=layer.weight.device)
group_size=layer.group_size).to(device=layer.weight.device)

if not all([
Q_repacked.shape == layer.weight.shape,
4 changes: 3 additions & 1 deletion flute/nf_utils.py
@@ -1,5 +1,5 @@
import torch
from typing import Tuple
from typing import Tuple, Optional

DTYPE = torch.float32

@@ -53,10 +53,12 @@ def nf_quantize(
W: torch.Tensor,
num_bits: int,
group_size: int,
custom_scales: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
values, pivots = get_values_pivots(num_bits, False)
W_dequantized, W_quantized, absmax = manual_nf4(
W,
absmax=custom_scales,
bits=num_bits,
blocksize=group_size,
values=values,
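For reference, a minimal sketch of the extended nf_quantize call. The weight shape is an illustrative assumption; when custom_scales is None the per-group absmax scales are computed as before. Passing the returned scales back in assumes the third return value matches the absmax format that manual_nf4 expects.

    import torch
    from flute.nf_utils import nf_quantize

    # mirrors the call site in base.py, which moves the weight to CUDA first
    W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

    # default path: scales derived from the per-group absmax of W
    W_dq, W_q, scales, qmap = nf_quantize(W=W, num_bits=4, group_size=64)

    # custom path: reuse precomputed scales instead of recomputing absmax
    W_dq2, W_q2, scales2, qmap2 = nf_quantize(
        W=W, num_bits=4, group_size=64, custom_scales=scales)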
15 changes: 9 additions & 6 deletions flute/utils.py
@@ -269,12 +269,15 @@ def pack(
raise ValueError("Either `group_size` or `template_ids` must be provided")

K, N = W.shape
template_id = TEMPLATE_TUNED_WITHOUT_M_CONFIGS[(
NUM_SMS,
num_bits,
group_size,
N, K)]
template_ids = [template_id]
template_ids = []
for dtype in [torch.float16, torch.bfloat16]:
template_id = TEMPLATE_TUNED_WITHOUT_M_CONFIGS[(
NUM_SMS,
num_bits,
group_size,
N, K,
str(dtype))]
template_ids.append(template_id)

# the packing is specialized to `tile_P`, which could
# be different for different templates. We check that
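A usage sketch of the simplified packing entry point (the shape here is illustrative and assumed to be among the tuned configurations): callers now pass group_size instead of looking up a template_id themselves, and pack resolves the tuned templates for both float16 and bfloat16 internally.

    import torch
    import flute.utils

    num_bits, group_size = 4, 64
    N, K = 4096, 4096  # out_features, in_features

    # placeholder integer codes standing in for nf_quantize's output
    _Q = torch.randint(0, 2 ** num_bits, (N, K), dtype=torch.uint8)

    # same K x N layout as the call sites above (_Q.T.contiguous())
    Q_packed = flute.utils.pack(
        _Q.T.contiguous(),
        num_bits=num_bits,
        group_size=group_size)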
20 changes: 20 additions & 0 deletions tests/kernel.py
@@ -1,3 +1,4 @@
import copy
import click
import torch
import argparse
@@ -161,9 +162,28 @@ def run_tests(num: int) -> None:
identity=False)


def test_configs() -> None:
for k0 in flute.TEMPLATE_TUNED_WITHOUT_M_CONFIGS.keys():
if k0[-1] == "torch.bfloat16":
continue
if k0[-1] != "torch.float16":
raise ValueError
k1 = list(copy.deepcopy(k0))
k1[-1] = "torch.bfloat16"
k1 = tuple(k1)

tid0 = flute.TEMPLATE_TUNED_WITHOUT_M_CONFIGS[k0]
tid1 = flute.TEMPLATE_TUNED_WITHOUT_M_CONFIGS[k1]
c0 = flute.utils.get_template_config(num_bits=k0[1], template_id=tid0)
c1 = flute.utils.get_template_config(num_bits=k1[1], template_id=tid1)
if c0["tileP"] != c1["tileP"]:
raise ValueError


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num", type=int, default=10)
args = parser.parse_args()

test_configs()
run_tests(num=args.num)
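
With this addition, test_configs() checks that the float16 and bfloat16 tuned templates agree on tileP before the randomized kernel tests run, e.g. (assuming a CUDA build of flute is installed):

    python tests/kernel.py --num 10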
