From c40ccce40bc01bfc682eca9c1cfa931e715b2706 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Wed, 15 May 2024 18:00:34 +0000 Subject: [PATCH 1/4] Add a new kernel for fusing the dequantization in fused-moe gemm --- .../layers/fused_moe/fused_moe.py | 253 +++++++++++++++++- .../layers/quantization/deepspeedfp.py | 29 +- vllm/model_executor/models/arctic.py | 20 +- 3 files changed, 282 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bb7938b3715be..08c900a30849f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -163,6 +163,150 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +@triton.jit +def fused_moe_kernel_fp8( + a_ptr, + b_ptr, + c_ptr, + scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + N, + K, + EM, + num_valid_tokens, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + bf16_exponent_bits: tl.constexpr, + bf16_mantisa_bits: tl.constexpr, + fp8_exponent_bits: tl.constexpr, + fp8_mantisa_bits: tl.constexpr, + _sign_mask: tl.constexpr, + _exponent_mask: tl.constexpr, + _mantisa_mask: tl.constexpr, + quantization_group_size: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. 
+ pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs_offset = off_experts * (stride_be // quantization_group_size) + offs_bn[None, :] * (stride_bn // quantization_group_size) + b_ptrs = b_ptr + off_experts * stride_be + offs_bn[None, :] * stride_bn + offs_k[:, None] * stride_bk + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. 
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + _exp_bias = (1 << (bf16_exponent_bits - 1)) - (1 << (fp8_exponent_bits - 1)) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + scale = tl.load(scale_ptr + b_ptrs_offset + ((k * BLOCK_SIZE_K) // quantization_group_size)) + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Dequantize weight (fp8 -> bf16) + dst_exponent = ((b & _exponent_mask) >> fp8_mantisa_bits).to(tl.uint16) + sign = ((b & _sign_mask) >> (fp8_mantisa_bits + fp8_exponent_bits)).to(tl.uint16) + dst_mantisa = (b & _mantisa_mask).to(tl.uint16) + b = ((sign << (bf16_exponent_bits + bf16_mantisa_bits)) | dst_mantisa << (bf16_mantisa_bits - fp8_mantisa_bits)) + dst_exponent = (dst_exponent + _exp_bias) + b = (b | (dst_exponent << bf16_mantisa_bits)).to(tl.uint16) + b = (b.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) + + accumulator += tl.dot(a, b) + + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -272,6 +416,66 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ) +def invoke_fused_moe_kernel_fp8(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, q_scales: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + quantization_group_size: int, + config: Dict[str, Any]) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + + # currently, we are only supporting fp8(E4M3) -> bf16 + fp8_mantisa_bits = 3 + fp8_exponent_bits = 4 + bf16_exponent_bits = 8 + bf16_mantisa_bits = 7 + + # auxilary masks used for dequantizing the fp8 data + _sign_mask = 1 << (_mantisa_bits + _exponent_bits) + _mantisa_mask = (1 << _mantisa_bits) - 1 + _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits + + fused_moe_kernel_fp8[grid]( + A, + B, + C, + q_scales, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, + fp8_mantisa_bits=fp8_mantisa_bits, + bf16_exponent_bits=bf16_exponent_bits, + bf16_mantisa_bits=bf16_mantisa_bits, + fp8_exponent_bits=fp8_exponent_bits, + fp8_mantisa_bits=fp8_mantisa_bits, + _sign_mask=_sign_mask, + 
_mantisa_mask=_mantisa_mask, + _exponent_mask=_exponent_mask, + quantization_group_size=quantization_group_size, + **config, + ) + def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: device_name = torch.cuda.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" @@ -363,7 +567,9 @@ def fused_experts(hidden_states: torch.Tensor, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None): + a2_scale: Optional[torch.Tensor] = None, + quantization_group_size: Optional[int] = 256, + quantization_group_size2: Optional[int] = 256): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" @@ -404,7 +610,14 @@ def fused_experts(hidden_states: torch.Tensor, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1 } - + if w1.dtype == torch.uint8: + block_size = config['BLOCK_SIZE_K'] + config['BLOCK_SIZE_K'] = min(block_size, quantization_group_size) + if w2.dtype == torch.uint8: + config_2 = config + block_size = config['BLOCK_SIZE_K'] + config_2['BLOCK_SIZE_K'] = min(block_size, quantization_group_size2) + intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), device=hidden_states.device, dtype=hidden_states.dtype) @@ -420,7 +633,23 @@ def fused_experts(hidden_states: torch.Tensor, compute_type = (tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16) - invoke_fused_moe_kernel(hidden_states, + if w1.dtype == torch.uint8: + invoke_fused_moe_kernel_fp8( + hidden_states, + w1, + intermediate_cache1, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + quantization_group_size, + config) + else: + invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, a1_scale, @@ -438,7 +667,23 @@ def fused_experts(hidden_states: torch.Tensor, ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel(intermediate_cache2, + if w2.dtype == torch.uint8: + invoke_fused_moe_kernel_fp8( + intermediate_cache2, + w2, + intermediate_cache3, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + quantization_group_size2, + config1) + else: + invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, a2_scale, diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 31cdffbcf0ab9..f58df5c28a68d 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F - +import math from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -152,32 +152,39 @@ def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, raise ImportError("Please install deepspeed>=0.14.2 via " "`pip install deepspeed>=0.14.2` to use " "deepspeedfp quantizer.") from err - data = torch.empty(( - orig_shape.numel() // quant_config.group_size, - quant_config.group_size * quant_config.weight_bits // 8 + 4, - ), - dtype=torch.int8) + data = torch.empty(orig_shape, dtype=torch.uint8) self = torch.Tensor._make_subclass(cls, data, data.requires_grad) self.orig_shape = orig_shape 
self.quant_config = quant_config - self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size) + self.quant_config.group_size = max( + [ + 2**i for i in range(4, int(math.log2(quant_config.group_size)+1)) \ + if orig_shape[-1] % (2**i) == 0 + ] + ) + self.fp_quantizer = FP_Quantize(group_size=self.quant_config.group_size) self.fp_quantizer.orig_shape = orig_shape self.fp_quantizer.orig_dtype = params_dtype + return self def ds_quantize_(self, tensor: torch.Tensor): - assert tensor.device.type == "cuda" and tensor.dtype != torch.int8 + assert tensor.device.type == "cuda" and tensor.dtype != torch.uint8 return self.data.copy_( self.fp_quantizer.quantize( tensor.data, q_bits=self.quant_config.weight_bits, - )) + return_meta_tensor=True + )[0]) + + def quantization_scales(self): + return self.fp_quantizer.get_scales() def ds_dequantize(self, fp_out=None) -> torch.Tensor: """ Return a tensor containing the dequantized weights of this parameter. """ - assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + assert self.data.device.type == "cuda" and self.data.dtype == torch.uint8 return self.fp_quantizer.dequantize( self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) @@ -186,7 +193,7 @@ def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: Return a tensor where only the weights at `indices` are dequantized (to save HBM -> SRAM bandwidth). """ - assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + assert self.data.device.type == "cuda" and self.data.dtype == torch.uint8 return self.fp_quantizer.selective_dequantize( self.data, indices, diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index cb99939cbb17a..c4ec7dfa9882a 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -93,7 +93,7 @@ def __init__(self, self.layer_id = layer_id self.top_k = config.num_experts_per_tok self.intermediate_size = config.intermediate_size // self.tp_size - + self.enable_dequantization_fusion = config.enable_dequantization_fusion self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 self.is_quant = isinstance(quant_config, DeepSpeedFPConfig) self.reduce_results = reduce_results @@ -173,7 +173,7 @@ def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: self.top_k, renormalize=do_normalize) # topk_ids: (num_tokens, k) - if self.is_quant: + if self.is_quant and (not self.enable_dequantization_fusion): if 2 * num_tokens <= self.num_experts: # If much fewer tokens than experts, use selective dequantize. 
ws_dequantized = self.ws.ds_selective_dequantize( @@ -192,11 +192,15 @@ def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = fused_experts( hidden_states, - ws_dequantized if self.is_quant else self.ws, - w2s_dequantized if self.is_quant else self.w2s, + ws_dequantized if (self.is_quant and not self.enable_dequantization_fusion) else self.ws, + w2s_dequantized if (self.is_quant and not self.enable_dequantization_fusion) else self.w2s, topk_weights, topk_ids, - inplace=True) + inplace=True, + w1_scales=self.ws.quantization_scales(), + w2_scales=self.w2s.quantization_scales(), + quantization_group_size=self.ws.fp_quantizer.group_size, + quantization_group_size2=self.w2s.fp_quantizer.group_size) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -405,6 +409,12 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, **kwargs) -> None: super().__init__() + + enable_dequantization_fusion = True + if 'enable_dequantization_fusion' in kwargs: + enable_dequantization_fusion = kwargs['enable_dequantization_fusion'] + config.enable_dequantization_fusion = enable_dequantization_fusion + self.config = config self.model = ArcticModel(config, cache_config, quant_config) self.vocab_size = config.vocab_size From d63f4e1094ab3d9f0e4d1f7bd6e9cfceaadea16b Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Wed, 15 May 2024 22:08:12 +0000 Subject: [PATCH 2/4] fixes & reduce memory pressure of loading expert params --- .../layers/fused_moe/fused_moe.py | 1 - .../layers/quantization/deepspeedfp.py | 3 +- vllm/model_executor/models/arctic.py | 53 +++++++++++++++---- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 08c900a30849f..62b91ecc87072 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -464,7 +464,6 @@ def invoke_fused_moe_kernel_fp8(A: torch.Tensor, B: torch.Tensor, C: torch.Tenso MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, - fp8_mantisa_bits=fp8_mantisa_bits, bf16_exponent_bits=bf16_exponent_bits, bf16_mantisa_bits=bf16_mantisa_bits, fp8_exponent_bits=fp8_exponent_bits, diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index f58df5c28a68d..394dcfd2935dd 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -165,7 +165,8 @@ def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, self.fp_quantizer = FP_Quantize(group_size=self.quant_config.group_size) self.fp_quantizer.orig_shape = orig_shape self.fp_quantizer.orig_dtype = params_dtype - + self.fp_quantizer.scales = torch.empty(orig_shape.numel() // self.quant_config.group_size, 4, + dtype=torch.uint8, device=self.data.device) return self def ds_quantize_(self, tensor: torch.Tensor): diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index c4ec7dfa9882a..54bbf4c03bb7c 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -31,7 +31,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.arctic import ArcticConfig - +import gc logger = 
init_logger(__name__) @@ -145,22 +145,53 @@ def __init__(self, set_weight_attrs(self.w2s, { "weight_loader": self.weight_loader, }) + self.load_completion_w1_w3 = 0 + self.load_completion_w2 = 0 def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int): tp_rank = get_tensor_model_parallel_rank() - param_data = param.ds_dequantize() if self.is_quant else param.data + if not hasattr(param, 'shadow_data'): + param.shadow_data = [None] * self.num_experts shard_size = self.intermediate_size + if param.shadow_data[expert_id] is None: + param.shadow_data[expert_id] = torch.zeros( + (shard_size * 2 if (weight_name.endswith("w1.weight") or weight_name.endswith("w3.weight")) else loaded_weight.shape[0], + loaded_weight.shape[1] if (weight_name.endswith("w1.weight") or weight_name.endswith("w3.weight")) else shard_size), + dtype=loaded_weight.dtype, + device=loaded_weight.device) shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + param.shadow_data[expert_id][0:shard_size, :] = loaded_weight[shard, :] + self.load_completion_w1_w3 += 1 if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] + param.shadow_data[expert_id][shard_size:, :] = loaded_weight[shard, :] + self.load_completion_w1_w3 += 1 if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] - if self.is_quant: - param.ds_quantize_(param_data) + param.shadow_data[expert_id][:, :shard_size] = loaded_weight[:, shard] + self.load_completion_w2 += 1 + if (self.load_completion_w2 == self.num_experts or self.load_completion_w1_w3 == self.num_experts * 2): + new_data = torch.stack(param.shadow_data).to(param.data.device) + + len_sd = len(param.shadow_data) + for _ in range(len_sd): + sd = param.shadow_data.pop() + del sd + del param.shadow_data + + if self.load_completion_w2 == self.num_experts: + self.load_completion_w2 = 0 + if self.load_completion_w1_w3 == self.num_experts * 2: + self.load_completion_w1_w3 = 0 + + if self.is_quant: + param.ds_quantize_(new_data) + del new_data + new_data = None + gc.collect() + torch.cuda.empty_cache() + else: + param_data.copy_(new_data) def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape @@ -197,8 +228,8 @@ def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: topk_weights, topk_ids, inplace=True, - w1_scales=self.ws.quantization_scales(), - w2_scales=self.w2s.quantization_scales(), + w1_scale=self.ws.quantization_scales(), + w2_scale=self.w2s.quantization_scales(), quantization_group_size=self.ws.fp_quantizer.group_size, quantization_group_size2=self.w2s.fp_quantizer.group_size) if self.reduce_results and self.tp_size > 1: @@ -508,6 +539,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) + # print(loaded_weight, param.data, param.data.shape) + # exit() break else: for param_name, weight_name, shard_id in mlp_params_mapping: From 78c38e5b7d0e629f94f5be1c9ba37888bbd52cfd Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Wed, 15 May 2024 22:28:02 +0000 Subject: [PATCH 3/4] minor fixes --- vllm/model_executor/layers/fused_moe/fused_moe.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 62b91ecc87072..c1d83e92cd301 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -437,9 +437,9 @@ def invoke_fused_moe_kernel_fp8(A: torch.Tensor, B: torch.Tensor, C: torch.Tenso bf16_mantisa_bits = 7 # auxilary masks used for dequantizing the fp8 data - _sign_mask = 1 << (_mantisa_bits + _exponent_bits) - _mantisa_mask = (1 << _mantisa_bits) - 1 - _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits + _sign_mask = 1 << (fp8_mantisa_bits + fp8_exponent_bits) + _mantisa_mask = (1 << fp8_mantisa_bits) - 1 + _exponent_mask = ((1 << fp8_exponent_bits) - 1) << fp8_mantisa_bits fused_moe_kernel_fp8[grid]( A, @@ -680,7 +680,7 @@ def fused_experts(hidden_states: torch.Tensor, True, 1, quantization_group_size2, - config1) + config_2) else: invoke_fused_moe_kernel(intermediate_cache2, w2, From afa735c78ce1d0aa02a55354ba9229752f496edb Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Thu, 16 May 2024 01:45:41 +0000 Subject: [PATCH 4/4] fix the group size --- .../layers/quantization/deepspeedfp.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 394dcfd2935dd..8cfec74d574ab 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -8,7 +8,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs - +import gc class DeepSpeedFPConfig(QuantizationConfig): """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. @@ -152,31 +152,42 @@ def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, raise ImportError("Please install deepspeed>=0.14.2 via " "`pip install deepspeed>=0.14.2` to use " "deepspeedfp quantizer.") from err - data = torch.empty(orig_shape, dtype=torch.uint8) + if len(orig_shape) == 3: + data = torch.empty(0, dtype=torch.uint8) + else: + data = torch.empty(orig_shape, dtype=torch.uint8) self = torch.Tensor._make_subclass(cls, data, data.requires_grad) self.orig_shape = orig_shape self.quant_config = quant_config - self.quant_config.group_size = max( + g_size = max( [ 2**i for i in range(4, int(math.log2(quant_config.group_size)+1)) \ if orig_shape[-1] % (2**i) == 0 ] ) - self.fp_quantizer = FP_Quantize(group_size=self.quant_config.group_size) + self.fp_quantizer = FP_Quantize(group_size=g_size) self.fp_quantizer.orig_shape = orig_shape self.fp_quantizer.orig_dtype = params_dtype - self.fp_quantizer.scales = torch.empty(orig_shape.numel() // self.quant_config.group_size, 4, + self.fp_quantizer.scales = torch.empty(orig_shape.numel() // g_size, 4, dtype=torch.uint8, device=self.data.device) return self def ds_quantize_(self, tensor: torch.Tensor): assert tensor.device.type == "cuda" and tensor.dtype != torch.uint8 - return self.data.copy_( - self.fp_quantizer.quantize( + prev_data = self.data + q_data, _ = self.fp_quantizer.quantize( tensor.data, q_bits=self.quant_config.weight_bits, return_meta_tensor=True - )[0]) + ) + del tensor + del prev_data + tensor = None + prev_data = None + gc.collect() + torch.cuda.empty_cache() + self.data = q_data + return self.data def quantization_scales(self): return self.fp_quantizer.get_scales()
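
For reference, the heart of fused_moe_kernel_fp8 is the per-block E4M3 -> bf16 expansion: shift the 4-bit exponent into bf16's 8-bit field and add the bias difference of 120 (bf16 bias 127 minus E4M3 bias 7, computed in the kernel as (1 << 7) - (1 << 3)), left-align the 3 mantissa bits inside bf16's 7-bit mantissa, reattach the sign, bitcast, and multiply by the per-group scale. Below is a minimal host-side PyTorch sketch of that bit manipulation; it is not part of the patch: the function name is made up, it assumes one floating-point scale per group_size consecutive codes in a flattened layout (the kernel itself walks groups along K for each expert and output column, and the patch keeps the scales as raw bytes on FP_Quantize), and it skips the token routing entirely.

import torch

# Bit-layout constants matching fused_moe_kernel_fp8 (fp8 E4M3 -> bf16).
FP8_EXP_BITS, FP8_MAN_BITS = 4, 3
BF16_EXP_BITS, BF16_MAN_BITS = 8, 7
SIGN_MASK = 1 << (FP8_MAN_BITS + FP8_EXP_BITS)                     # 0x80
MANTISSA_MASK = (1 << FP8_MAN_BITS) - 1                            # 0x07
EXPONENT_MASK = ((1 << FP8_EXP_BITS) - 1) << FP8_MAN_BITS          # 0x78
EXP_BIAS = (1 << (BF16_EXP_BITS - 1)) - (1 << (FP8_EXP_BITS - 1))  # 128 - 8 = 120

def dequantize_e4m3_reference(q_bytes: torch.Tensor,
                              scales: torch.Tensor,
                              group_size: int) -> torch.Tensor:
    """Hypothetical reference: q_bytes holds raw E4M3 codes as uint8, scales
    holds one float scale per `group_size` consecutive codes (flattened)."""
    q = q_bytes.to(torch.int32)
    sign = (q & SIGN_MASK) >> (FP8_MAN_BITS + FP8_EXP_BITS)
    # Re-bias the 4-bit exponent into bf16's 8-bit exponent field.
    exponent = ((q & EXPONENT_MASK) >> FP8_MAN_BITS) + EXP_BIAS
    # Left-align the 3 mantissa bits inside bf16's 7-bit mantissa.
    mantissa = (q & MANTISSA_MASK) << (BF16_MAN_BITS - FP8_MAN_BITS)
    # bf16 is the upper half of an equal-valued float32, so assemble the
    # magnitude pattern in bits 16..30 and reinterpret as float32. The kernel
    # ORs the sign bit directly into the bf16 word; negating is equivalent.
    magnitude_bits = ((exponent << BF16_MAN_BITS) | mantissa) << 16
    w = magnitude_bits.view(torch.float32)
    w = torch.where(sign.bool(), -w, w).to(torch.bfloat16)
    # Apply one scale per quantization group (flattened layout for brevity).
    w = w.view(-1, group_size) * scales.view(-1, 1).to(torch.bfloat16)
    return w.view(q_bytes.shape)

At a high level this is what the pre-existing path materializes up front via ds_dequantize before calling fused_experts; the point of the new kernel is to perform the same expansion one BLOCK_SIZE_K tile at a time inside the GEMM, so a full bf16 copy of the expert weights never has to be written to HBM.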