-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
86 lines (67 loc) · 2.3 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright the author(s) of DLK.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import os
from setuptools import find_packages, setup
def write_version_py():
    """Generate ``dlk/version.py`` from ``dlk/version.txt``.

    Reads the version string from ``dlk/version.txt`` (surrounding whitespace
    is stripped) and writes a one-line module ``dlk/version.py`` that defines
    ``__version__``. Both paths are relative to the current working directory.

    Returns:
        The version string that was read.

    Raises:
        OSError: If ``dlk/version.txt`` cannot be read or ``dlk/version.py``
            cannot be written.
    """
    # Explicit UTF-8 keeps the result independent of the platform locale.
    with open(os.path.join("dlk", "version.txt"), encoding="utf-8") as f:
        version = f.read().strip()

    # write version info to dlk/version.py
    with open(os.path.join("dlk", "version.py"), "w", encoding="utf-8") as f:
        f.write(f'__version__ = "{version}"\n')
    return version
# Regenerate dlk/version.py and keep the version string for the setup() call.
version = write_version_py()
# Optional native extension. It is only attempted when CUDA_HOME is present in
# the environment and BUILD_CUDA parses to the integer 1; otherwise the package
# is installed without any compiled extension.
cmdclass = {}
extensions = []

if ("CUDA_HOME" in os.environ) and int(os.environ.get("BUILD_CUDA", "0")) == 1:
    try:
        from torch.utils import cpp_extension

        extensions = [
            cpp_extension.CppExtension(
                "dlk.ngram_repeat_block_cuda",
                sources=[
                    "dlk/cuda/ngram_repeat_block_cuda.cpp",
                    "dlk/cuda/ngram_repeat_block_cuda_kernel.cu",
                ],
            ),
        ]
        cmdclass = {"build_ext": cpp_extension.BuildExtension}
    except Exception as exc:
        # Best effort: a missing or broken torch must not prevent a plain
        # install. Unlike a bare ``except``, this lets KeyboardInterrupt and
        # SystemExit propagate. Reset both values so a partial failure cannot
        # leave an extension configured without its build command.
        extensions = []
        cmdclass = {}
        print(f"Skipping optional CUDA extension build: {exc!r}")

if "READTHEDOCS" in os.environ:
    # don't build extensions when generating docs
    extensions = []
    if "build_ext" in cmdclass:
        del cmdclass["build_ext"]
def package_files(directory):
    """Return the path of every file found beneath ``directory``.

    The tree is traversed with ``os.walk``; each result is built as
    ``os.path.join(".", <walked dir>, <file name>)``. A directory that does
    not exist yields an empty list.
    """
    return [
        os.path.join(".", root, name)
        for root, _subdirs, names in os.walk(directory)
        for name in names
    ]
# Extra data files to ship: everything under dlk/utils/display/label_colors
# and dlk/utils/display/fonts.
added_files = []
added_files.extend(
    package_files(os.path.join("dlk", "utils", "display", "label_colors"))
)
added_files.extend(package_files(os.path.join("dlk", "utils", "display", "fonts")))

with open("README.md", encoding="utf-8") as f:
    readme = f.read()

with open("LICENSE", encoding="utf-8") as f:
    license = f.read()

with open("requirements.txt", encoding="utf-8") as f:
    requirements = f.read()

# One requirement per line. Blank lines and "#" comment lines are skipped so
# they are not handed to setuptools as (invalid) requirement strings;
# splitlines() also copes with CRLF line endings and an empty file.
install_requires = [
    line.strip()
    for line in requirements.splitlines()
    if line.strip() and not line.strip().startswith("#")
]

# Only ship the dlk package and its subpackages.
pkgs = [p for p in find_packages() if p.startswith("dlk")]

# NOTE(review): no ``name=`` is passed here; presumably it is supplied by
# setup.cfg / pyproject.toml — verify.
# NOTE(review): ``license`` receives the full LICENSE file text; setuptools
# documents this field as a short license name — confirm this is intended.
setup(
    url="https://github.com/cstsunfu/dlk",
    description="dlk: Deep Learning Kit",
    # The README was read above and its content type declared, but the text
    # itself was never passed on; supply it so it reaches package metadata.
    long_description=readme,
    long_description_content_type="text/markdown",
    version=version,
    ext_modules=extensions,
    cmdclass=cmdclass,
    package_data={"": added_files},
    license=license,
    include_package_data=True,
    packages=pkgs,
    install_requires=install_requires,
)