Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix the bug where the Torch library is not correctly linked. #152

Merged
merged 2 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# --------------------------------------------------------------------------

cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(vptq_cuda_ops LANGUAGES C CXX CUDA)
project(vptq LANGUAGES C CXX CUDA)

# Prohibit in-source builds
if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR})
Expand Down
6 changes: 4 additions & 2 deletions csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the
# MIT License.
# --------------------------------------------------------------------------
set(TARGET "cuda_ops")
file(GLOB_RECURSE SOURCES "*.cu")
set(TARGET "vptq")
file(GLOB_RECURSE SOURCES "*.cu" "*.cc")
message(STATUS "Building ${TARGET} with ${SOURCES}")

cuda_add_library(${TARGET} SHARED ${SOURCES})

Expand All @@ -29,3 +30,4 @@ target_compile_options(
--use_fast_math
--generate-line-info>)
target_compile_features(${TARGET} PUBLIC cxx_std_17 cuda_std_17)
target_link_libraries(${TARGET} "${TORCH_LIBRARIES}")
4 changes: 2 additions & 2 deletions csrc/common.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
Expand Down Expand Up @@ -39,4 +39,4 @@ inline void gpuAssert(cudaError_t code, const char* file, int line) {
line);
TORCH_CHECK(false, cudaGetErrorString(code));
}
}
}
6 changes: 3 additions & 3 deletions csrc/dequant_impl_packed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ __global__ void DequantizeWithOutliers_PackIndice(
// @param weight_bias
// @return torch::Tensor
torch::Tensor launch_deqantize_outliers_cuda_packkernel(
const int* outf_x_inf, const torch::Tensor& q_indice,
const int64_t* outf_x_inf, const torch::Tensor& q_indice,
const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
const c10::optional<torch::Tensor>& residual_centroids,
Expand Down Expand Up @@ -534,7 +534,7 @@ torch::Tensor launch_deqantize_outliers_cuda_packkernel(
// @param bias
// @return torch::Tensor
torch::Tensor launch_gemv_outliers_cuda_packkernel(
const int out_features, const torch::Tensor& input,
const int64_t out_features, const torch::Tensor& input,
const torch::Tensor& q_indice, const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
const c10::optional<torch::Tensor>& residual_centroids,
Expand All @@ -544,7 +544,7 @@ torch::Tensor launch_gemv_outliers_cuda_packkernel(
const torch::Tensor& weight_bias,
const c10::optional<torch::Tensor>& bias) {
OptionalCUDAGuard cudaguard(input.device().index());
const int base_groupsize = centroids.size(-1);
const int64_t base_groupsize = centroids.size(-1);
int index_bits = log2(centroids.size(1));
int res_index_bits = residual_centroids.has_value()
? log2(residual_centroids.value().size(1))
Expand Down
67 changes: 44 additions & 23 deletions csrc/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
#include "common.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <torch/extension.h>
#include <torch/library.h>

#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
Expand All @@ -15,7 +17,7 @@
CHECK_CONTIGUOUS(x)

torch::Tensor launch_deqantize_outliers_cuda_packkernel(
const int* outf_x_inf, const torch::Tensor& q_indice,
const int64_t* outf_x_inf, const torch::Tensor& q_indice,
const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
const c10::optional<torch::Tensor>& residual_centroids,
Expand All @@ -25,7 +27,7 @@ torch::Tensor launch_deqantize_outliers_cuda_packkernel(
const torch::Tensor& weight_bias);

torch::Tensor launch_gemv_outliers_cuda_packkernel(
const int out_features, const torch::Tensor& input,
const int64_t out_features, const torch::Tensor& input,
const torch::Tensor& q_indice, const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
const c10::optional<torch::Tensor>& residual_centroids,
Expand All @@ -42,8 +44,8 @@ torch::Tensor dequant(const torch::Tensor& q_indice,
const c10::optional<torch::Tensor>& outliers_centroids,
const c10::optional<torch::Tensor>& invperm,
const torch::Tensor& weight_scale,
const torch::Tensor& weight_bias, int groupsize,
int in_features, int out_features) {
const torch::Tensor& weight_bias, int64_t groupsize,
int64_t in_features, int64_t out_features) {
auto dev_index = q_indice.device().index();

CHECK_INPUT(q_indice);
Expand Down Expand Up @@ -85,7 +87,7 @@ torch::Tensor dequant(const torch::Tensor& q_indice,

at::cuda::OptionalCUDAGuard guard(q_indice.device());
torch::Tensor output;
const int out_f_x_in_f[2] = {out_features, in_features};
const int64_t out_f_x_in_f[2] = {out_features, in_features};

output = launch_deqantize_outliers_cuda_packkernel(
out_f_x_in_f, q_indice, centroids, q_indice_residual, residual_centroids,
Expand All @@ -106,8 +108,9 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input,
const c10::optional<torch::Tensor>& invperm,
const torch::Tensor& weight_scale,
const torch::Tensor& weight_bias,
const c10::optional<torch::Tensor>& bias, int groupsize,
int in_features, int out_features) {
const c10::optional<torch::Tensor>& bias,
int64_t groupsize, int64_t in_features,
int64_t out_features) {
CHECK_INPUT(q_indice);
CHECK_INPUT(input);
if (q_indice_residual.has_value()) {
Expand Down Expand Up @@ -155,22 +158,40 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input,
return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dequant", &dequant,
R"DOC(Dequantize matrix weights to fp16.
function type:
const torch::Tensor& qweight,
const torch::Tensor& scales,
const torch::Tensor& qzeros,
Tensor g_idx, int groupsize, int bits, int in_features
)DOC");
// Register the CUDA-backend implementations for the operators of the `vptq`
// library: "dequant" is served by dequant() and "gemm" by wqA16Gemm().
TORCH_LIBRARY_IMPL(vptq, CUDA, m) {
  m.impl("dequant", &dequant);
  m.impl("gemm", &wqA16Gemm);
}

m.def("gemm", &wqA16Gemm,
R"DOC(Compute the gemm output, usually gemv.
function type:
const torch::Tensor& qweight,
const torch::Tensor& scales,
const torch::Tensor& qzeros,
tensor g_idx, int groupsize, int bits, int in_features
// Declares the operator schemas of the `vptq` library. Only the signatures
// are defined here; the kernels are attached per backend elsewhere in this
// file (see the TORCH_LIBRARY_IMPL(vptq, CUDA, m) block). Python code calls
// these as torch.ops.vptq.dequant and torch.ops.vptq.gemm.
//
// Schema arguments follow the parameter order of dequant() and wqA16Gemm():
// `Tensor?` lines up with the c10::optional<torch::Tensor> parameters and
// `int` with the int64_t parameters (groupsize, in_features, out_features).
TORCH_LIBRARY(vptq, m) {
// Schema for dequant(); returns a single Tensor.
m.def(
R"DOC(dequant(Tensor q_indice,
Tensor centroids,
Tensor? q_indice_residual,
Tensor? residual_centroids,
Tensor? q_indice_outliers,
Tensor? outliers_centroids,
Tensor? invperm,
Tensor weight_scale,
Tensor weight_bias,
int groupsize,
int in_features,
int out_features) -> Tensor
)DOC");
// Schema for gemm() (implemented by wqA16Gemm); compared with dequant it
// takes a leading `input` tensor and an optional `bias`.
m.def(
R"DOC(gemm(Tensor input,
Tensor q_indice,
Tensor centroids,
Tensor? q_indice_residual,
Tensor? residual_centroids,
Tensor? q_indice_outliers,
Tensor? outliers_centroids,
Tensor? invperm,
Tensor weight_scale,
Tensor weight_bias,
Tensor? bias,
int groupsize,
int in_features,
int out_features) -> Tensor
)DOC");
}
15 changes: 11 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,28 @@ def get_requirements():
class CMakeExtension(Extension):
""" specify the root folder of the CMake projects"""

def __init__(self, name="cuda_ops", cmake_lists_dir=".", **kwargs):
def __init__(self, name, cmake_lists_dir=".", **kwargs):
Extension.__init__(self, name, sources=[], **kwargs)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)


class CMakeBuildExt(build_ext):
"""launches the CMake build."""

def get_ext_filename(self, name):
    """Return the shared-library file name for the extension `name`.

    The result is always ``lib<name>.so``; `copy_extensions_to_source`
    uses it to locate the built library inside ``build_lib``.
    """
    return "lib{}.so".format(name)

def copy_extensions_to_source(self) -> None:
build_py = self.get_finalized_command("build_py")
for ext in self.extensions:
source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
source_path = os.path.join(
self.build_lib, self.get_ext_filename(ext.name)
)
inplace_file, _ = self._get_inplace_equivalent(build_py, ext)

target_path = os.path.join(build_py.build_lib, "vptq", inplace_file)
target_path = os.path.join(
build_py.build_lib, "vptq", "ops", inplace_file
)

# Always copy, even if source is older than destination, to ensure
# that the right extensions for the current Python/platform are
Expand Down Expand Up @@ -169,7 +176,7 @@ def run(self):
version=get_version(),
description=description,
author="Wang Yang, Wen JiCheng",
ext_modules=[CMakeExtension()],
ext_modules=[CMakeExtension("vptq")],
cmdclass={
"build_ext": CMakeBuildExt,
"clean": Clean,
Expand Down
5 changes: 4 additions & 1 deletion vptq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@

__version__ = importlib.metadata.version("vptq")

__all__ = ["AutoModelForCausalLM", "VQuantLinear"]
__all__ = [
"AutoModelForCausalLM",
"VQuantLinear",
]
5 changes: 4 additions & 1 deletion vptq/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@

from vptq.layers.model_base import AutoModelForCausalLM, VQuantLinear

__all__ = ["AutoModelForCausalLM", "VQuantLinear"]
__all__ = [
"AutoModelForCausalLM",
"VQuantLinear",
]
13 changes: 0 additions & 13 deletions vptq/layers/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# --------------------------------------------------------------------------

import glob
import importlib.util
from pathlib import Path

import accelerate
Expand Down Expand Up @@ -203,18 +202,6 @@ def from_pretrained(
preload_module_classes=["VQuantLinear"]
)

# check cuda kernel exist
if importlib.util.find_spec("vptq.cuda_ops") is not None:
pass
else:
print((
"!!! Warning !!!: CUDA kernels are not found, "
"please check CUDA and VPTQ installation."
))
print((
"!!! Warning !!!: Running on Torch implementations, "
"which is extremely slow."
))
model.eval()

torch.cuda.empty_cache()
Expand Down
38 changes: 27 additions & 11 deletions vptq/ops/quant_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,38 @@
# --------------------------------------------------------------------------

__all__ = [
'dequant',
'quant_gemm',
"dequant",
"quant_gemm",
]

import math
import os

import torch
from torch.nn import functional as F

# isort: off
# we need to import the CUDA kernels after importing torch
__cuda_ops_installed = True
try:
from vptq import cuda_ops
except ImportError:
__cuda_ops_installed = False

def _load_library(filename: str) -> bool:
    """Try to register the shared library `filename` with PyTorch.

    The library is looked up in the directory one level above the directory
    that contains this file and is loaded via `torch.ops.load_library`.

    Returns:
        True when loading succeeded. On any exception the error and two
        warning messages are printed and False is returned, so callers can
        fall back to the slower Torch implementation.
    """
    try:
        package_root = os.path.dirname(os.path.dirname(__file__))
        library_path = os.path.join(package_root, filename)
        torch.ops.load_library(library_path)
        print(f"Successfully loaded: '{filename}'")
    except Exception as error:
        # Best-effort: a missing or broken library must not break the import.
        print(
            f"{error}\n"
            "!!! Warning !!!: CUDA kernels are not found, "
            "please check CUDA and VPTQ installation."
        )
        print(
            "!!! Warning !!!: Running on Torch implementations, "
            "which is extremely slow."
        )
        return False
    return True


__cuda_ops_installed: bool = _load_library("libvptq.so")


def unpack_index_tensor(
Expand Down Expand Up @@ -226,7 +242,7 @@ def quant_gemm(
enable_norm = weight_scale is not None and weight_bias is not None

if (x.numel() // x.shape[-1] < 3) and __cuda_ops_installed:
out = cuda_ops.gemm(
out = torch.ops.vptq.gemm(
x,
indices,
centroids_,
Expand All @@ -245,7 +261,7 @@ def quant_gemm(
return out
else:
if __cuda_ops_installed:
weight = cuda_ops.dequant(
weight = torch.ops.vptq.dequant(
indices,
centroids_,
residual_indices,
Expand Down
Loading