# setup.py -- build/install script for exllamav2 (forked from turboderp-org/exllamav2)
from setuptools import setup, Extension
from torch.utils import cpp_extension
from torch import version as torch_version
import os
# ---------------------------------------------------------------------------
# Build configuration for the native extension.
# ---------------------------------------------------------------------------

# Name passed to the extension constructor as the module name.
extension_name = "exllamav2_ext"

# Forwarded to setup() as its `verbose` option.
verbose = False

# Set to True to append extra diagnostic flags to the host-compiler options.
ext_debug = False

# The extension is configured for building unless the EXLLAMA_NOCOMPILE
# environment variable is present (its value is not inspected).
precompile = "EXLLAMA_NOCOMPILE" not in os.environ

# True when running on Windows (os.name == "nt").
windows = os.name == "nt"

# Host (C++) compiler flags: "/Ox" on Windows, "-O3" everywhere else.
extra_cflags = ["/Ox"] if windows else ["-O3"]
if ext_debug:
    extra_cflags += ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]

# Device compiler flags.
extra_cuda_cflags = ["-lineinfo", "-O3"]
if torch_version.hip:
    # torch.version.hip is truthy (HIP/ROCm build of PyTorch): add the
    # hipBLAS half-precision define.
    extra_cuda_cflags += ["-DHIPBLAS_USE_HIP_HALF"]

# Per-compiler flag lists in the mapping form accepted by the extension's
# `extra_compile_args` argument.
extra_compile_args = {
    "cxx": extra_cflags,
    "nvcc": extra_cuda_cflags,
}
def _native_build_kwargs():
    """Return the setup() keyword arguments that build the native extension.

    The result carries a single CUDAExtension under "ext_modules" and maps the
    "build_ext" command to PyTorch's BuildExtension under "cmdclass".
    """
    sources = [
        "exllamav2/exllamav2_ext/ext.cpp",
        "exllamav2/exllamav2_ext/cuda/h_gemm.cu",
        "exllamav2/exllamav2_ext/cuda/lora.cu",
        "exllamav2/exllamav2_ext/cuda/pack_tensor.cu",
        "exllamav2/exllamav2_ext/cuda/quantize.cu",
        "exllamav2/exllamav2_ext/cuda/q_matrix.cu",
        "exllamav2/exllamav2_ext/cuda/q_attn.cu",
        "exllamav2/exllamav2_ext/cuda/q_mlp.cu",
        "exllamav2/exllamav2_ext/cuda/q_gemm.cu",
        "exllamav2/exllamav2_ext/cuda/rms_norm.cu",
        "exllamav2/exllamav2_ext/cuda/rope.cu",
        "exllamav2/exllamav2_ext/cuda/cache.cu",
        "exllamav2/exllamav2_ext/cpp/quantize_func.cpp",
        "exllamav2/exllamav2_ext/cpp/sampling.cpp",
        "exllamav2/exllamav2_ext/cuda/quip/quiptools.cu",
        "exllamav2/exllamav2_ext/cuda/quip/quiptools_e8p_gemv.cu",
    ]
    extension = cpp_extension.CUDAExtension(
        extension_name,
        sources,
        extra_compile_args=extra_compile_args,
        # cublas is listed as a link library on Windows only.
        libraries=["cublas"] if windows else [],
    )
    return {
        "ext_modules": [extension],
        "cmdclass": {"build_ext": cpp_extension.BuildExtension},
    }


# The extension object is only constructed when precompilation is enabled;
# otherwise setup() receives no extension-related arguments at all.
if precompile:
    setup_kwargs = _native_build_kwargs()
else:
    setup_kwargs = {}
# Obtain the package version by executing the contents of exllamav2/version.py
# in a fresh namespace dict and reading __version__ from it. The path is
# relative to the current working directory.
# NOTE(review): exec() runs whatever that file contains; it is a file inside
# this source tree, not external input.
version_py = {}
with open("exllamav2/version.py", encoding="utf8") as fp:
    exec(fp.read(), version_py)
version = version_py["__version__"]
print("Version:", version)
# Distribution metadata and install-time requirements. Collected in a dict so
# the setup() call below is a single forwarding statement.
_package_info = {
    "name": "exllamav2",
    "version": version,
    "packages": [
        "exllamav2",
        "exllamav2.generator",
        "exllamav2.quip",
        # Currently not listed as packages:
        # "exllamav2.generator.filters",
        # "exllamav2.server",
        # "exllamav2.exllamav2_ext",
        # "exllamav2.exllamav2_ext.cpp",
        # "exllamav2.exllamav2_ext.cuda",
        # "exllamav2.exllamav2_ext.cuda.quant",
    ],
    "url": "https://github.com/turboderp/exllamav2",
    "license": "MIT",
    "author": "turboderp",
    "install_requires": [
        "pandas",
        "ninja",
        "fastparquet",
        "torch>=2.0.1",
        "safetensors>=0.3.2",
        "sentencepiece>=0.1.97",
        "pygments",
        "websockets",
        "regex",
    ],
    "include_package_data": True,
    "verbose": verbose,
}

# Metadata first, then whatever extension build arguments were prepared in
# setup_kwargs (an empty dict when the extension is not being compiled).
setup(**_package_info, **setup_kwargs)