setup.py
import os
import platform
import sys
import pathlib
import torch
from setuptools import setup, find_packages
from torch.__config__ import parallel_info
from torch.utils.cpp_extension import (
    CUDA_HOME,
    BuildExtension,
    CppExtension,
    CUDAExtension,
)
__version__ = "0.1.0"
URL = "https://github.com/fschlatt/window_matmul"
WITH_CUDA = CUDA_HOME is not None


def get_extension():
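    # Collect the C++/CUDA sources for the extension from csrc/window_matmul.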
    setup_dir = pathlib.Path(".")
    src_dir = setup_dir / "csrc" / "window_matmul"
    src_files = []
    src_files.extend(src_dir.rglob("*.cpp"))
    src_files.extend(src_dir.rglob("*.cu"))
    # remove generated 'hip' files, in case of rebuilds
    src_files = [path for path in src_files if "hip" not in str(path)]
    # remove CUDA files if CUDA is not available
    src_files = [
        path for path in src_files if path.parent.name != "cuda" or WITH_CUDA
    ]
    src_files = [str(path) for path in src_files]
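    # Preprocessor macros and compiler/linker flags for the C++ build.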
    define_macros = [("WITH_PYTHON", None)]
    undef_macros = []
    if sys.platform == "win32":
        define_macros += [("torchscatter_EXPORTS", None)]

    extra_compile_args = {"cxx": ["-O3"]}
    if os.name != "nt":  # not on Windows
        extra_compile_args["cxx"] += ["-Wno-sign-compare"]
    extra_link_args = ["-s"]
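    # Enable OpenMP when PyTorch's parallel backend reports it; macOS is
    # skipped because Apple clang does not ship OpenMP support by default.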
    info = parallel_info()
    if (
        "backend: OpenMP" in info
        and "OpenMP not found" not in info
        and sys.platform != "darwin"
    ):
        extra_compile_args["cxx"] += ["-DAT_PARALLEL_OPENMP"]
        if sys.platform == "win32":
            extra_compile_args["cxx"] += ["/openmp"]
        else:
            extra_compile_args["cxx"] += ["-fopenmp"]
    else:
        print("Compiling without OpenMP...")
    # Compile for mac arm64 (Apple Silicon)
    if sys.platform == "darwin" and platform.machine() == "arm64":
        extra_compile_args["cxx"] += ["-arch", "arm64"]
        extra_link_args += ["-arch", "arm64"]
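    # CUDA (or ROCm/HIP) specific macros and nvcc flags.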
    if WITH_CUDA:
        define_macros += [("WITH_CUDA", None)]
        nvcc_flags = os.getenv("NVCC_FLAGS", "")
        nvcc_flags = [] if nvcc_flags == "" else nvcc_flags.split(" ")
        nvcc_flags += ["-O3"]
        if not torch.cuda.device_count():
            nvcc_flags += ["-arch", "all-major"]
        if torch.version.hip:
            # USE_ROCM was added to later versions of PyTorch.
            # Define here to support older PyTorch versions as well:
            define_macros += [("USE_ROCM", None)]
            undef_macros += ["__HIP_NO_HALF_CONVERSIONS__"]
        else:
            nvcc_flags += ["--expt-relaxed-constexpr"]
        extra_compile_args["nvcc"] = nvcc_flags
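    # Build a CUDAExtension when a CUDA toolkit is available, otherwise fall
    # back to a plain CppExtension.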
    Extension = CUDAExtension if WITH_CUDA else CppExtension
    extension = Extension(
        "window_matmul_kernel",
        src_files,
        include_dirs=[str(src_dir)],
        define_macros=define_macros,
        undef_macros=undef_macros,
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args,
    )
    return [extension]
# work-around hipify abs paths
include_package_data = True
if torch.cuda.is_available() and torch.version.hip:
    include_package_data = False
setup(
    name="window_matmul",
    version=__version__,
    description="PyTorch extension for windowed matrix multiplication",
    author="Ferdinand Schlatt",
    author_email="[email protected]",
    url=URL,
    keywords=["pytorch", "matmul", "window"],
    python_requires=">=3.7",
    ext_modules=get_extension(),
    cmdclass={
        "build_ext": BuildExtension.with_options(
            no_python_abi_suffix=True, use_ninja=False
        )
    },
    # find_packages expects an iterable of patterns; a bare string would be
    # iterated character by character and exclude nothing.
    packages=find_packages(exclude=["tests"]),
    include_package_data=include_package_data,
)
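# A typical build, assuming a working C++ toolchain (plus a CUDA toolkit when
# WITH_CUDA is true), would be for example:
#   pip install .
# or, for an editable development install:
#   pip install -e .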