-
Notifications
You must be signed in to change notification settings - Fork 34
/
setup.py
77 lines (65 loc) · 1.96 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python3
import os
import sys
from functools import lru_cache
from subprocess import DEVNULL, call
import torch
from setuptools import setup
from torch.utils import cpp_extension
@lru_cache(None)
def cuda_toolkit_available():
    """Return True if the ``nvcc`` CUDA compiler can be launched.

    Approach taken from
    https://github.com/idiap/fast-transformers/blob/master/setup.py.
    ``nvcc`` is started with no arguments and its output is discarded. Only
    whether the process could be started matters; its exit status is
    deliberately ignored. The result is cached, so the probe runs at most
    once per process.

    Returns:
        bool: True when ``nvcc`` was started, False when it could not be
        executed (not found on ``PATH``, not executable, or any other
        OS-level launch failure).
    """
    try:
        call(["nvcc"], stdout=DEVNULL, stderr=DEVNULL)
    except OSError:
        # FileNotFoundError (nvcc not on PATH) is the common case. A
        # PermissionError or any other launch failure also means the toolkit
        # is unusable, so it should disable the CUDA build rather than crash
        # the whole install.
        return False
    return True
def compile_args():
    """Return the extra C++ compiler flags for the CPU extension.

    The flags request OpenMP (``-fopenmp``) and fast floating-point math
    (``-ffast-math``). On macOS (``sys.platform == "darwin"``) the list is
    prefixed with ``-Xpreprocessor``, which forwards the argument that
    follows it (here ``-fopenmp``) to the preprocessor. This is presumably
    needed because Apple's clang does not accept ``-fopenmp`` directly.

    Returns:
        list[str]: a freshly built list of flags on every call.
    """
    platform_prefix = ["-Xpreprocessor"] if sys.platform == "darwin" else []
    return [*platform_prefix, "-fopenmp", "-ffast-math"]
def ext_modules():
    """Return the list of C++/CUDA extension modules to build.

    ``torchsort.isotonic_cpu`` is always built from
    ``torchsort/isotonic_cpu.cpp`` with the flags from ``compile_args()``.

    ``torchsort.isotonic_cuda`` is added only when two conditions hold:

    * the installed PyTorch was built with CUDA support
      (``torch.version.cuda`` is not None);
    * the ``nvcc`` compiler can be launched.

    A CUDAExtension links against PyTorch's CUDA libraries, which a CPU-only
    PyTorch does not ship. In that case only the CPU extension is built
    instead of failing the whole install.

    Returns:
        list: setuptools extension objects for ``setup(ext_modules=...)``.
    """
    extensions = [
        cpp_extension.CppExtension(
            "torchsort.isotonic_cpu",
            sources=["torchsort/isotonic_cpu.cpp"],
            extra_compile_args=compile_args(),
        ),
    ]
    # torch.version.cuda is None for CPU-only PyTorch builds. Check it first so
    # nvcc is not spawned when its result could not be used anyway.
    torch_has_cuda = torch.version.cuda is not None
    if torch_has_cuda and cuda_toolkit_available():
        extensions.append(
            cpp_extension.CUDAExtension(
                "torchsort.isotonic_cuda",
                sources=["torchsort/isotonic_cuda.cu"],
            ),
        )
    return extensions
# README.md becomes the package's long description (declared as Markdown
# below). Read it explicitly as UTF-8 so the build does not depend on the
# platform's default locale encoding (e.g. cp1252 on Windows), which would
# raise UnicodeDecodeError on any non-ASCII character in the README.
with open("README.md", encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="torchsort",
    # TORCHSORT_VERSION_SUFFIX lets a build append extra text to the version.
    version="0.1.9" + os.getenv("TORCHSORT_VERSION_SUFFIX", ""),
    description="Differentiable sorting and ranking in PyTorch",
    author="Teddy Koker",
    url="https://github.com/teddykoker/torchsort",
    license="Apache",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=["torchsort"],
    classifiers=[
        "Programming Language :: Python :: 3",
    ],
    install_requires=["torch"],
    python_requires=">=3.7",
    extras_require={
        "testing": [
            "pytest",
            # torch itself is not repeated here: it is already in install_requires.
            "fast_soft_sort @ git+https://github.com/google-research/fast-soft-sort.git@6a52ce79869ab16e1e0f39149a84f50f8ad648c5",
        ],
    },
    ext_modules=ext_modules(),
    # torch's BuildExtension drives compilation of the C++/CUDA extensions.
    cmdclass={"build_ext": cpp_extension.BuildExtension},
    include_package_data=True,
)