-
Notifications
You must be signed in to change notification settings - Fork 3
/
setup.py
71 lines (64 loc) · 2.39 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, _find_cuda_home
import torch
import platform
# Number of CPUs visible to this process. os.cpu_count() returns None when the
# count cannot be determined; fall back to 1 so consumers that do arithmetic
# on it (e.g. min(CPU_COUNT, 8) for the nvcc thread count) always get an int.
CPU_COUNT = os.cpu_count() or 1
# Extra flags appended to both the host ("cxx") and device ("nvcc") compile
# command lines of the extension; empty by default.
generator_flag = []
# Filesystem location of the imported torch package.
# NOTE(review): not referenced anywhere else in this file — confirm it is
# unused before removing.
torch_dir = torch.__path__[0]
# Extra flags appended only to the nvcc command line; empty by default.
cc_flag = []
def find_cublas_headers():
    """Locate the CUDA toolkit include and library directories used for cuBLAS.

    The CUDA root comes from ``torch.utils.cpp_extension._find_cuda_home``.

    Returns:
        A ``(include_dir, library_dir)`` tuple. ``include_dir`` is
        ``<cuda_root>/include`` on every platform; ``library_dir`` is
        ``<cuda_root>/lib/x64`` on Windows and ``<cuda_root>/lib64`` elsewhere.

    Raises:
        EnvironmentError: if no CUDA installation can be located.
    """
    cuda_root = _find_cuda_home()
    if cuda_root is None:
        raise EnvironmentError("CUDA environment not found, ensure that you have CUDA toolkit installed locally, and have added it to your environment variables as CUDA_HOME=/path/to/cuda-12.x")

    # Headers sit in the same place everywhere; only the library subdirectory
    # layout differs between Windows and other platforms.
    if platform.system() == "Windows":
        library_subdir = ("lib", "x64")
    else:
        library_subdir = ("lib64",)

    include_dir = os.path.join(cuda_root, "include")
    library_dir = os.path.join(cuda_root, *library_subdir)
    return include_dir, library_dir
def append_nvcc_threads(nvcc_extra_args, max_threads=8):
    """Return a copy of ``nvcc_extra_args`` with an nvcc ``--threads`` option appended.

    The thread count is taken from the ``NVCC_THREADS`` environment variable
    when it is set to a non-empty value. Otherwise it is the number of CPUs,
    capped at ``max_threads`` and never below 1.

    Args:
        nvcc_extra_args: List of nvcc flags; it is not modified.
        max_threads: Upper bound on the automatically chosen thread count
            (default 8, the previously hard-coded cap).

    Returns:
        A new list equal to ``nvcc_extra_args + ["--threads", <count>]``.
    """
    nvcc_threads = os.getenv("NVCC_THREADS")
    if not nvcc_threads:
        # os.cpu_count() returns None when the CPU count is undeterminable;
        # min(None, 8) would raise TypeError, so fall back to one thread.
        cpu_count = os.cpu_count() or 1
        nvcc_threads = str(max(1, min(cpu_count, max_threads)))
    # List concatenation builds a new list, leaving the caller's list intact.
    return nvcc_extra_args + ["--threads", nvcc_threads]
# Resolve the CUDA toolkit paths once. find_cublas_headers() returns
# (include_dir, library_dir): the headers belong in include_dirs and the
# library directory in library_dirs, because the linker never searches
# include directories when resolving -lcublas / -lcublasLt.
cublas_include_dir, cublas_library_dir = find_cublas_headers()

# Build the "cublas_ops_ext" CUDA extension and package the Python sources.
setup(
    name="cublas_ops",
    version="0.0.5",
    ext_modules=[
        CUDAExtension(
            "cublas_ops_ext",
            [
                "src/cublas_hgemm.cpp",
                "src/cublas_hgemm_kernel.cu",
                "src/cublas_hgemm_batched_kernel.cu",
                "src/cublaslt_hgemm_kernel.cu",
                "src/cublaslt_hgemm_batched_kernel.cu",
                "src/simt_hgemv.cu",
            ],
            extra_compile_args={
                # NOTE(review): "-O3" / "-std=c++17" are GCC/Clang spellings;
                # if Windows (MSVC) builds are intended, confirm the host
                # compiler accepts or correctly ignores them.
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        # Undefine the __CUDA_NO_HALF*/BFLOAT16 macros so the
                        # half/bfloat16 operator and conversion overloads are
                        # available to the kernels (presumably PyTorch's
                        # extension build defines them by default — verify).
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_HALF2_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        # Fast, reduced-precision math intrinsics.
                        "--use_fast_math",
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            # Link against cuBLAS and cuBLASLt.
            libraries=["cublas", "cublasLt"],
            include_dirs=[cublas_include_dir],
            library_dirs=[cublas_library_dir],
        ),
    ],
    # NOTE(review): find_packages() already skips directories whose names
    # contain "." and directories without __init__.py, so these exclude
    # patterns are most likely no-ops; kept as a defensive filter.
    packages=find_packages(
        exclude=[".misc", "__pycache__", ".vscode", "cublas_ops.egg-info"]
    ),
    cmdclass={"build_ext": BuildExtension},
)