From c3f6744e00056b8ae0505d73586c8583f816b8d6 Mon Sep 17 00:00:00 2001
From: "lcy.seso"
Date: Mon, 30 Dec 2024 08:56:40 +0000
Subject: [PATCH 1/2] Fix the bug where the Torch library is not correctly
 linked.

---
 CMakeLists.txt              |  2 +-
 csrc/CMakeLists.txt         |  6 ++--
 csrc/common.h               |  4 +--
 csrc/dequant_impl_packed.cu |  6 ++--
 csrc/ops.cc                 | 67 ++++++++++++++++++++++++-------------
 setup.py                    | 11 ++++--
 vptq/__init__.py            |  5 ++-
 vptq/layers/__init__.py     |  5 ++-
 vptq/layers/model_base.py   | 13 -------
 vptq/ops/quant_gemm.py      | 36 +++++++++++++++-----
 10 files changed, 97 insertions(+), 58 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 119b656..4e1ca9b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -4,7 +4,7 @@
 # --------------------------------------------------------------------------
 cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
 
-project(vptq_cuda_ops LANGUAGES C CXX CUDA)
+project(vptq LANGUAGES C CXX CUDA)
 
 # Prohibit in-source builds
 if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR})
diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
index 16f9ca9..b58d128 100644
--- a/csrc/CMakeLists.txt
+++ b/csrc/CMakeLists.txt
@@ -2,8 +2,9 @@
 # Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the
 # MIT License.
 # --------------------------------------------------------------------------
-set(TARGET "cuda_ops")
-file(GLOB_RECURSE SOURCES "*.cu")
+set(TARGET "vptq")
+file(GLOB_RECURSE SOURCES "*.cu" "*.cc")
+message(STATUS "Building ${TARGET} with ${SOURCES}")
 
 cuda_add_library(${TARGET} SHARED ${SOURCES})
 
@@ -29,3 +30,4 @@ target_compile_options(
   --use_fast_math
   --generate-line-info>)
 target_compile_features(${TARGET} PUBLIC cxx_std_17 cuda_std_17)
+target_link_libraries(${TARGET} "${TORCH_LIBRARIES}")
diff --git a/csrc/common.h b/csrc/common.h
index a2c58f1..cd8b9ba 100644
--- a/csrc/common.h
+++ b/csrc/common.h
@@ -1,6 +1,6 @@
-
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+#pragma once
 
 #include
 #include
 
@@ -39,4 +39,4 @@ inline void gpuAssert(cudaError_t code, const char* file, int line) {
             line);
     TORCH_CHECK(false, cudaGetErrorString(code));
   }
-}
\ No newline at end of file
+}
diff --git a/csrc/dequant_impl_packed.cu b/csrc/dequant_impl_packed.cu
index 5263cff..0194959 100644
--- a/csrc/dequant_impl_packed.cu
+++ b/csrc/dequant_impl_packed.cu
@@ -325,7 +325,7 @@ __global__ void DequantizeWithOutliers_PackIndice(
 // @param weight_bias
 // @return torch::Tensor
 torch::Tensor launch_deqantize_outliers_cuda_packkernel(
-    const int* outf_x_inf, const torch::Tensor& q_indice,
+    const int64_t* outf_x_inf, const torch::Tensor& q_indice,
     const torch::Tensor& centroids,
     const c10::optional<torch::Tensor>& q_indice_residual,
     const c10::optional<torch::Tensor>& residual_centroids,
@@ -534,7 +534,7 @@ torch::Tensor launch_deqantize_outliers_cuda_packkernel(
 // @param bias
 // @return torch::Tensor
 torch::Tensor launch_gemv_outliers_cuda_packkernel(
-    const int out_features, const torch::Tensor& input,
+    const int64_t out_features, const torch::Tensor& input,
     const torch::Tensor& q_indice, const torch::Tensor& centroids,
     const c10::optional<torch::Tensor>& q_indice_residual,
     const c10::optional<torch::Tensor>& residual_centroids,
@@ -544,7 +544,7 @@ torch::Tensor launch_gemv_outliers_cuda_packkernel(
     const torch::Tensor& weight_bias, const c10::optional<torch::Tensor>& bias) {
   OptionalCUDAGuard cudaguard(input.device().index());
 
-  const int base_groupsize = centroids.size(-1);
+  const int64_t base_groupsize = centroids.size(-1);
   int index_bits = log2(centroids.size(1));
   int res_index_bits = residual_centroids.has_value()
                            ? log2(residual_centroids.value().size(1))
diff --git a/csrc/ops.cc b/csrc/ops.cc
index e010308..5202f48 100644
--- a/csrc/ops.cc
+++ b/csrc/ops.cc
@@ -4,7 +4,9 @@
 #include "common.h"
 #include
 #include
+
 #include
+#include <torch/library.h>
 
 #define CHECK_CUDA(x) \
   TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
@@ -15,7 +17,7 @@
   CHECK_CONTIGUOUS(x)
 
 torch::Tensor launch_deqantize_outliers_cuda_packkernel(
-    const int* outf_x_inf, const torch::Tensor& q_indice,
+    const int64_t* outf_x_inf, const torch::Tensor& q_indice,
     const torch::Tensor& centroids,
     const c10::optional<torch::Tensor>& q_indice_residual,
     const c10::optional<torch::Tensor>& residual_centroids,
@@ -25,7 +27,7 @@ torch::Tensor launch_deqantize_outliers_cuda_packkernel(
     const torch::Tensor& weight_bias);
 
 torch::Tensor launch_gemv_outliers_cuda_packkernel(
-    const int out_features, const torch::Tensor& input,
+    const int64_t out_features, const torch::Tensor& input,
     const torch::Tensor& q_indice, const torch::Tensor& centroids,
     const c10::optional<torch::Tensor>& q_indice_residual,
     const c10::optional<torch::Tensor>& residual_centroids,
@@ -42,8 +44,8 @@ torch::Tensor dequant(const torch::Tensor& q_indice,
                       const c10::optional<torch::Tensor>& outliers_centroids,
                       const c10::optional<torch::Tensor>& invperm,
                       const torch::Tensor& weight_scale,
-                      const torch::Tensor& weight_bias, int groupsize,
-                      int in_features, int out_features) {
+                      const torch::Tensor& weight_bias, int64_t groupsize,
+                      int64_t in_features, int64_t out_features) {
   auto dev_index = q_indice.device().index();
 
   CHECK_INPUT(q_indice);
@@ -85,7 +87,7 @@ torch::Tensor dequant(const torch::Tensor& q_indice,
   at::cuda::OptionalCUDAGuard guard(q_indice.device());
   torch::Tensor output;
 
-  const int out_f_x_in_f[2] = {out_features, in_features};
+  const int64_t out_f_x_in_f[2] = {out_features, in_features};
 
   output = launch_deqantize_outliers_cuda_packkernel(
       out_f_x_in_f, q_indice, centroids, q_indice_residual, residual_centroids,
@@ -106,8 +108,9 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input,
                         const c10::optional<torch::Tensor>& invperm,
                         const torch::Tensor& weight_scale,
                         const torch::Tensor& weight_bias,
-                        const c10::optional<torch::Tensor>& bias, int groupsize,
-                        int in_features, int out_features) {
+                        const c10::optional<torch::Tensor>& bias,
+                        int64_t groupsize, int64_t in_features,
+                        int64_t out_features) {
   CHECK_INPUT(q_indice);
   CHECK_INPUT(input);
   if (q_indice_residual.has_value()) {
@@ -155,22 +158,40 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input,
   return output;
 }
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("dequant", &dequant,
-        R"DOC(Dequantize matrix weights to fp16.
-function type:
-const torch::Tensor& qweight,
-const torch::Tensor& scales,
-const torch::Tensor& qzeros,
-Tensor g_idx, int groupsize, int bits, int in_features
-)DOC");
+TORCH_LIBRARY_IMPL(vptq, CUDA, m) {
+  m.impl("dequant", dequant);
+  m.impl("gemm", wqA16Gemm);
+}
 
-  m.def("gemm", &wqA16Gemm,
-        R"DOC(Compute the gemm output, usually gemv.
-function type:
-const torch::Tensor& qweight,
-const torch::Tensor& scales,
-const torch::Tensor& qzeros,
-tensor g_idx, int groupsize, int bits, int in_features
+TORCH_LIBRARY(vptq, m) {
+  m.def(
+      R"DOC(dequant(Tensor q_indice,
+                    Tensor centroids,
+                    Tensor? q_indice_residual,
+                    Tensor? residual_centroids,
+                    Tensor? q_indice_outliers,
+                    Tensor? outliers_centroids,
+                    Tensor? invperm,
+                    Tensor weight_scale,
+                    Tensor weight_bias,
+                    int groupsize,
+                    int in_features,
+                    int out_features) -> Tensor
+)DOC");
+  m.def(
+      R"DOC(gemm(Tensor input,
+                 Tensor q_indice,
+                 Tensor centroids,
+                 Tensor? q_indice_residual,
+                 Tensor? residual_centroids,
+                 Tensor? q_indice_outliers,
+                 Tensor? outliers_centroids,
+                 Tensor? invperm,
+                 Tensor weight_scale,
+                 Tensor weight_bias,
+                 Tensor? bias,
+                 int groupsize,
+                 int in_features,
+                 int out_features) -> Tensor
 )DOC");
 }
diff --git a/setup.py b/setup.py
index 1ea628b..ed2a380 100644
--- a/setup.py
+++ b/setup.py
@@ -32,7 +32,7 @@ def get_requirements():
 
 class CMakeExtension(Extension):
     """ specify the root folder of the CMake projects"""
 
-    def __init__(self, name="cuda_ops", cmake_lists_dir=".", **kwargs):
+    def __init__(self, name, cmake_lists_dir=".", **kwargs):
         Extension.__init__(self, name, sources=[], **kwargs)
         self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
 
@@ -40,10 +40,15 @@ def __init__(self, name="cuda_ops", cmake_lists_dir=".", **kwargs):
 class CMakeBuildExt(build_ext):
     """launches the CMake build."""
 
+    def get_ext_filename(self, name):
+        return f"lib{name}.so"
+
     def copy_extensions_to_source(self) -> None:
         build_py = self.get_finalized_command("build_py")
         for ext in self.extensions:
-            source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
+            source_path = os.path.join(
+                self.build_lib, self.get_ext_filename(ext.name)
+            )
             inplace_file, _ = self._get_inplace_equivalent(build_py, ext)
 
             target_path = os.path.join(build_py.build_lib, "vptq", inplace_file)
@@ -169,7 +174,7 @@ def run(self):
     version=get_version(),
     description=description,
     author="Wang Yang, Wen JiCheng",
-    ext_modules=[CMakeExtension()],
+    ext_modules=[CMakeExtension("vptq")],
     cmdclass={
         "build_ext": CMakeBuildExt,
         "clean": Clean,
diff --git a/vptq/__init__.py b/vptq/__init__.py
index b378176..471c9f6 100644
--- a/vptq/__init__.py
+++ b/vptq/__init__.py
@@ -9,4 +9,7 @@
 __version__ = importlib.metadata.version("vptq")
 
-__all__ = ["AutoModelForCausalLM", "VQuantLinear"]
+__all__ = [
+    "AutoModelForCausalLM",
+    "VQuantLinear",
+]
diff --git a/vptq/layers/__init__.py b/vptq/layers/__init__.py
index dc4a281..04bfa53 100644
--- a/vptq/layers/__init__.py
+++ b/vptq/layers/__init__.py
@@ -5,4 +5,7 @@
 from vptq.layers.model_base import AutoModelForCausalLM, VQuantLinear
 
-__all__ = ["AutoModelForCausalLM", "VQuantLinear"]
+__all__ = [
+    "AutoModelForCausalLM",
+    "VQuantLinear",
+]
diff --git a/vptq/layers/model_base.py b/vptq/layers/model_base.py
index f7a2cd7..1e97040 100644
--- a/vptq/layers/model_base.py
+++ b/vptq/layers/model_base.py
@@ -4,7 +4,6 @@
 # --------------------------------------------------------------------------
 
 import glob
-import importlib.util
 from pathlib import Path
 
 import accelerate
@@ -203,18 +202,6 @@ def from_pretrained(
         preload_module_classes=["VQuantLinear"]
     )
 
-    # check cuda kernel exist
-    if importlib.util.find_spec("vptq.cuda_ops") is not None:
-        pass
-    else:
-        print((
-            "!!! Warning !!!: CUDA kernels are not found, "
-            "please check CUDA and VPTQ installation."
-        ))
-        print((
-            "!!! Warning !!!: Running on Torch implementations, "
-            "which is extremely slow."
-        ))
     model.eval()
     torch.cuda.empty_cache()
diff --git a/vptq/ops/quant_gemm.py b/vptq/ops/quant_gemm.py
index 8279280..845a21c 100644
--- a/vptq/ops/quant_gemm.py
+++ b/vptq/ops/quant_gemm.py
@@ -4,22 +4,40 @@
 # --------------------------------------------------------------------------
 
 __all__ = [
-    'dequant',
-    'quant_gemm',
+    "dequant",
+    "quant_gemm",
 ]
 
 import math
+import os
 
 import torch
 from torch.nn import functional as F
 
+
+def _load_library(filename: str) -> bool:
+    """Load a shared library from the given filename."""
+    try:
+        libdir = os.path.dirname(os.path.dirname(__file__))
+        torch.ops.load_library(os.path.join(libdir, filename))
+        print(f"Successfully loaded: '{filename}'")
+        return True
+    except Exception as error:
+        print((
+            f"{error}\n"
+            "!!! Warning !!!: CUDA kernels are not found, "
+            "please check CUDA and VPTQ installation."
+        ))
+        print((
+            "!!! Warning !!!: Running on Torch implementations, "
+            "which is extremely slow."
+        ))
+        return False
+
+
 # isort: off
 # we need to import the CUDA kernels after importing torch
-__cuda_ops_installed = True
-try:
-    from vptq import cuda_ops
-except ImportError:
-    __cuda_ops_installed = False
+__cuda_ops_installed: bool = _load_library("libvptq.so")
 
 
 def unpack_index_tensor(
@@ -226,7 +244,7 @@ def quant_gemm(
     enable_norm = weight_scale is not None and weight_bias is not None
 
     if (x.numel() // x.shape[-1] < 3) and __cuda_ops_installed:
-        out = cuda_ops.gemm(
+        out = torch.ops.vptq.gemm(
             x,
             indices,
             centroids_,
@@ -245,7 +263,7 @@
         return out
     else:
         if __cuda_ops_installed:
-            weight = cuda_ops.dequant(
+            weight = torch.ops.vptq.dequant(
                 indices,
                 centroids_,
                 residual_indices,

From d330554145c45951ef06de90b4d3d5ea964fb405 Mon Sep 17 00:00:00 2001
From: "lcy.seso"
Date: Mon, 30 Dec 2024 08:56:40 +0000
Subject: [PATCH 2/2] Fix the bug where the Torch library is not correctly
 linked.
---
 setup.py               | 4 +++-
 vptq/ops/quant_gemm.py | 2 --
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index ed2a380..942624e 100644
--- a/setup.py
+++ b/setup.py
@@ -51,7 +51,9 @@ def copy_extensions_to_source(self) -> None:
             )
             inplace_file, _ = self._get_inplace_equivalent(build_py, ext)
 
-            target_path = os.path.join(build_py.build_lib, "vptq", inplace_file)
+            target_path = os.path.join(
+                build_py.build_lib, "vptq", "ops", inplace_file
+            )
 
             # Always copy, even if source is older than destination, to ensure
             # that the right extensions for the current Python/platform are
diff --git a/vptq/ops/quant_gemm.py b/vptq/ops/quant_gemm.py
index 845a21c..ccd7cf6 100644
--- a/vptq/ops/quant_gemm.py
+++ b/vptq/ops/quant_gemm.py
@@ -35,8 +35,6 @@ def _load_library(filename: str) -> bool:
         return False
 
 
-# isort: off
-# we need to import the CUDA kernels after importing torch
 __cuda_ops_installed: bool = _load_library("libvptq.so")
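
Note for reviewers: after this pair of patches the kernels are registered through the
PyTorch dispatcher (TORCH_LIBRARY / TORCH_LIBRARY_IMPL) instead of a pybind11 module,
so user code reaches them via the torch.ops namespace once the shared library is
loaded. A minimal sketch of that flow, not part of the patch; the library path is an
assumption based on where PATCH 2/2 copies the built artifact, and no real call is
made because valid arguments depend on VPTQ's packed-index layout:

    import torch

    # Assumed location: PATCH 2/2 copies the built library into vptq/ops/.
    # Inside the package, vptq.ops.quant_gemm._load_library performs this step.
    torch.ops.load_library("vptq/ops/libvptq.so")

    # TORCH_LIBRARY(vptq, m) declares the "dequant" and "gemm" schemas, and
    # TORCH_LIBRARY_IMPL(vptq, CUDA, m) binds their CUDA kernels, so both ops
    # now resolve through the dispatcher:
    print(torch.ops.vptq.dequant)  # OpOverloadPacket for the "dequant" schema
    print(torch.ops.vptq.gemm)     # OpOverloadPacket for the "gemm" schema

Because the schemas mark optional arguments as Tensor?, Python callers may pass None
for the residual/outlier tensors, which is what vptq/ops/quant_gemm.py relies on.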