Add a new kernel for fusing the dequantization in fused-moe gemm #4841

Open · wants to merge 4 commits into base: main
Changes from all commits
252 changes: 248 additions & 4 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -163,6 +163,150 @@ def fused_moe_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)


@triton.jit
def fused_moe_kernel_fp8(
a_ptr,
b_ptr,
c_ptr,
scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
N,
K,
EM,
num_valid_tokens,
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
bf16_exponent_bits: tl.constexpr,
bf16_mantisa_bits: tl.constexpr,
fp8_exponent_bits: tl.constexpr,
fp8_mantisa_bits: tl.constexpr,
_sign_mask: tl.constexpr,
_exponent_mask: tl.constexpr,
_mantisa_mask: tl.constexpr,
quantization_group_size: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.

Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens

offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)

offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
off_experts = tl.load(expert_ids_ptr + pid_m)
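# One scale is stored per quantization group, so offsets into the scale
# tensor are the corresponding weight offsets divided by the group size
# (the K offset is divided the same way when the scale is loaded below).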
b_ptrs_offset = off_experts * (stride_be // quantization_group_size) + offs_bn[None, :] * (stride_bn // quantization_group_size)
b_ptrs = b_ptr + off_experts * stride_be + offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

_exp_bias = (1 << (bf16_exponent_bits - 1)) - (1 << (fp8_exponent_bits - 1))
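# e.g. E4M3 (bias 7) -> bf16 (bias 127): _exp_bias = 128 - 8 = 120, which is
# added to the fp8 exponent field to re-bias it for bf16.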
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
scale = tl.load(scale_ptr + b_ptrs_offset + ((k * BLOCK_SIZE_K) // quantization_group_size))
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

# Dequantize weight (fp8 -> bf16)
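# Split the byte into sign / exponent / mantissa fields, re-bias the
# exponent, left-align the mantissa into bf16's 7 mantissa bits,
# reassemble the bf16 bit pattern, then apply the per-group scale.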
dst_exponent = ((b & _exponent_mask) >> fp8_mantisa_bits).to(tl.uint16)
sign = ((b & _sign_mask) >> (fp8_mantisa_bits + fp8_exponent_bits)).to(tl.uint16)
dst_mantisa = (b & _mantisa_mask).to(tl.uint16)
b = ((sign << (bf16_exponent_bits + bf16_mantisa_bits)) | dst_mantisa << (bf16_mantisa_bits - fp8_mantisa_bits))
dst_exponent = (dst_exponent + _exp_bias)
b = (b | (dst_exponent << bf16_mantisa_bits)).to(tl.uint16)
b = (b.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)

accumulator += tl.dot(a, b)


if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]

accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)

def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -272,6 +416,65 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
)


def invoke_fused_moe_kernel_fp8(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, q_scales: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
quantization_group_size: int,
config: Dict[str, Any]) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
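# 1-D launch grid: one program per (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of the
# padded-token by output-feature space.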

# currently, only fp8 (E4M3) -> bf16 dequantization is supported
fp8_mantisa_bits = 3
fp8_exponent_bits = 4
bf16_exponent_bits = 8
bf16_mantisa_bits = 7

# auxiliary masks used for dequantizing the fp8 data
_sign_mask = 1 << (fp8_mantisa_bits + fp8_exponent_bits)
_mantisa_mask = (1 << fp8_mantisa_bits) - 1
_exponent_mask = ((1 << fp8_exponent_bits) - 1) << fp8_mantisa_bits
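# for E4M3 these evaluate to _sign_mask = 0x80, _exponent_mask = 0x78,
# _mantisa_mask = 0x07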

fused_moe_kernel_fp8[grid](
A,
B,
C,
q_scales,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
bf16_exponent_bits=bf16_exponent_bits,
bf16_mantisa_bits=bf16_mantisa_bits,
fp8_exponent_bits=fp8_exponent_bits,
fp8_mantisa_bits=fp8_mantisa_bits,
_sign_mask=_sign_mask,
_mantisa_mask=_mantisa_mask,
_exponent_mask=_exponent_mask,
quantization_group_size=quantization_group_size,
**config,
)

def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
device_name = torch.cuda.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
@@ -363,7 +566,9 @@ def fused_experts(hidden_states: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None):
a2_scale: Optional[torch.Tensor] = None,
quantization_group_size: int = 256,
quantization_group_size2: int = 256):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
@@ -404,7 +609,14 @@
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
}

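# For fp8 (uint8-packed) weights, clamp BLOCK_SIZE_K to the quantization
# group size so each K-block of a column maps to exactly one scale.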
if w1.dtype == torch.uint8:
block_size = config['BLOCK_SIZE_K']
config['BLOCK_SIZE_K'] = min(block_size, quantization_group_size)
if w2.dtype == torch.uint8:
config_2 = config.copy()
block_size = config['BLOCK_SIZE_K']
config_2['BLOCK_SIZE_K'] = min(block_size, quantization_group_size2)

intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
@@ -420,7 +632,23 @@
compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16)

invoke_fused_moe_kernel(hidden_states,
if w1.dtype == torch.uint8:
invoke_fused_moe_kernel_fp8(
hidden_states,
w1,
intermediate_cache1,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
quantization_group_size,
config)
else:
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1_scale,
@@ -438,7 +666,23 @@

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

invoke_fused_moe_kernel(intermediate_cache2,
if w2.dtype == torch.uint8:
invoke_fused_moe_kernel_fp8(
intermediate_cache2,
w2,
intermediate_cache3,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
quantization_group_size2,
config_2)
else:
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
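
As a reference for reviewers (not part of the diff): the E4M3 decode performed inside fused_moe_kernel_fp8 can be sanity-checked on the host with plain integer ops. The sketch below is a minimal Python illustration under the same assumptions the kernel makes (E4M3 with exponent bias 7, one scale per quantization group); it reconstructs the value arithmetically rather than repacking bf16 bits, ignores subnormals and NaN, and the function name is illustrative only.

import torch

def dequant_e4m3_ref(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # q: uint8 tensor holding E4M3 bit patterns; scale: per-group scales
    # broadcastable against q.
    q = q.to(torch.int32)
    sign = (q >> 7) & 0x1   # _sign_mask     = 0x80
    exp = (q >> 3) & 0xF    # _exponent_mask = 0x78
    mant = q & 0x7          # _mantisa_mask  = 0x07
    # value = (-1)^sign * 2^(exp - 7) * (1 + mant / 8), then the group scale.
    magnitude = torch.ldexp(1.0 + mant.float() / 8.0, exp - 7)
    value = (1.0 - 2.0 * sign.float()) * magnitude
    return (value * scale.float()).to(torch.bfloat16)
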
47 changes: 33 additions & 14 deletions vllm/model_executor/layers/quantization/deepspeedfp.py
@@ -3,12 +3,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs

import gc

class DeepSpeedFPConfig(QuantizationConfig):
"""Config for DeepSpeed FP quantizer. It supports fp6 and fp8.
@@ -152,32 +152,51 @@ def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
raise ImportError("Please install deepspeed>=0.14.2 via "
"`pip install deepspeed>=0.14.2` to use "
"deepspeedfp quantizer.") from err
data = torch.empty((
orig_shape.numel() // quant_config.group_size,
quant_config.group_size * quant_config.weight_bits // 8 + 4,
),
dtype=torch.int8)
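# 3-D shapes (e.g. stacked MoE expert weights) start from an empty
# placeholder; ds_quantize_ later replaces self.data with the packed
# uint8 payload.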
if len(orig_shape) == 3:
data = torch.empty(0, dtype=torch.uint8)
else:
data = torch.empty(orig_shape, dtype=torch.uint8)
self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
self.orig_shape = orig_shape
self.quant_config = quant_config
self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size)
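# Pick the largest power-of-two group size (>= 16) that divides the last
# weight dimension and does not exceed the configured group size.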
g_size = max(
[
2**i for i in range(4, int(math.log2(quant_config.group_size)+1)) \
if orig_shape[-1] % (2**i) == 0
]
)
self.fp_quantizer = FP_Quantize(group_size=g_size)
self.fp_quantizer.orig_shape = orig_shape
self.fp_quantizer.orig_dtype = params_dtype
self.fp_quantizer.scales = torch.empty(orig_shape.numel() // g_size, 4,
dtype=torch.uint8, device=self.data.device)
return self

def ds_quantize_(self, tensor: torch.Tensor):
assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
return self.data.copy_(
self.fp_quantizer.quantize(
assert tensor.device.type == "cuda" and tensor.dtype != torch.uint8
prev_data = self.data
q_data, _ = self.fp_quantizer.quantize(
tensor.data,
q_bits=self.quant_config.weight_bits,
))
return_meta_tensor=True
)
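# Free the original full-precision tensor and the previous buffer eagerly
# so peak GPU memory stays bounded while the weights are quantized.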
del tensor
del prev_data
tensor = None
prev_data = None
gc.collect()
torch.cuda.empty_cache()
self.data = q_data
return self.data

def quantization_scales(self):
return self.fp_quantizer.get_scales()

def ds_dequantize(self, fp_out=None) -> torch.Tensor:
"""
Return a tensor containing the dequantized weights of this parameter.
"""
assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
assert self.data.device.type == "cuda" and self.data.dtype == torch.uint8
return self.fp_quantizer.dequantize(
self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)

@@ -186,7 +205,7 @@ def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
Return a tensor where only the weights at `indices` are dequantized
(to save HBM -> SRAM bandwidth).
"""
assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
assert self.data.device.type == "cuda" and self.data.dtype == torch.uint8
return self.fp_quantizer.selective_dequantize(
self.data,
indices,