-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathinstall.py
146 lines (130 loc) · 5.26 KB
/
install.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import subprocess
import os, sys
from typing import Any
import pkg_resources
from tqdm import tqdm
import urllib.request
from packaging import version as pv
# Locate the webui "models" directory. The newer webui layout exposes it from
# modules.paths_internal, older ones from modules.paths; if neither import
# works (for any reason) fall back to "./models" resolved against the CWD.
try:
    from modules.paths_internal import models_path
except BaseException:  # catch-all, equivalent to a bare ``except:``
    try:
        from modules.paths import models_path
    except BaseException:  # catch-all, equivalent to a bare ``except:``
        models_path = os.path.abspath("models")

# Directory containing this script (symlinks resolved) and its requirements.
BASE_PATH = os.path.split(os.path.realpath(__file__))[0]
req_file = os.path.join(BASE_PATH, "requirements.txt")

# Where the face-swap model lives and where it is fetched from.
models_dir = os.path.join(models_path, "insightface")
model_url = (
    "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/inswapper_128.onnx"
)
model_name = os.path.basename(model_url)
model_path = os.path.join(models_dir, model_name)
def pip_install(*args):
    """Run ``pip install <args...>`` with the interpreter executing this script.

    The pip exit status is not checked: a failed install does not raise here.
    """
    command = [sys.executable, "-m", "pip", "install"]
    command.extend(args)
    subprocess.run(command)
def pip_uninstall(*args):
    """Run ``pip uninstall -y <args...>`` (no confirmation prompt) with this interpreter.

    The pip exit status is not checked, so nothing is raised on failure.
    """
    command = [sys.executable, "-m", "pip", "uninstall", "-y"]
    command += args
    subprocess.run(command)
def is_installed(
    package: str, version: str | None = None, strict: bool = True
) -> bool:
    """Return whether ``package`` is installed at an acceptable version.

    Args:
        package: distribution name, or a full requirement string such as
            ``"insightface==0.7.3"``; it is passed unchanged to
            ``pkg_resources.get_distribution``.
        version: version to check against. ``None`` means any installed
            version is acceptable.
        strict: if True the installed version string must equal ``version``
            exactly; if False it must be >= ``version`` as compared by
            ``packaging.version``.

    Returns:
        True when the package is present and its version is acceptable,
        False otherwise. Any lookup/parse error (including the package not
        being installed) is printed and reported as False, not raised.
    """
    try:
        dist = pkg_resources.get_distribution(package)
        if dist is None:
            return False
        installed_version = dist.version
        if version is None:
            # No pin requested: presence alone satisfies the requirement.
            # Comparing against None used to report False in strict mode
            # (forcing a reinstall on every start) and crash in pv.parse
            # in non-strict mode.
            return True
        if strict:
            return installed_version == version
        return pv.parse(installed_version) >= pv.parse(version)
    except Exception as e:
        print(f"Error: {e}")
        return False
def download(url, path):
    """Download ``url`` to ``path`` while showing a tqdm progress bar.

    The response body is streamed from a single ``urlopen`` request into a
    sibling ``<path>.part`` file, which is moved onto ``path`` only after the
    whole body has been received. An interrupted download therefore never
    leaves a truncated file at ``path``. This script downloads only when
    ``path`` does not exist, so a partial file there would never be repaired.

    Raises:
        Errors from ``urllib.request.urlopen`` or file I/O, and ``OSError``
        when fewer bytes than the advertised Content-Length arrive. The
        ``.part`` file is removed before the exception propagates.
    """
    tmp_path = path + ".part"
    try:
        with urllib.request.urlopen(url) as response:
            # 0 when the server sends no Content-Length (size unknown).
            total = int(response.headers.get('Content-Length', 0))
            received = 0
            with tqdm(total=total, desc='Downloading...', unit='B', unit_scale=True, unit_divisor=1024) as progress:
                with open(tmp_path, "wb") as out:
                    while chunk := response.read(1 << 16):
                        out.write(chunk)
                        received += len(chunk)
                        # Advance by the bytes actually received, not a nominal block size.
                        progress.update(len(chunk))
        if total and received < total:
            raise OSError(f"Incomplete download of {url}: got {received} of {total} bytes")
        os.replace(tmp_path, path)
    except BaseException:
        # Also covers Ctrl+C: never leave a half-written temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
# Ensure the insightface model directory exists (no exists/create race) and
# fetch the swapper model if it is not there yet. An existing file at
# model_path is used as-is.
os.makedirs(models_dir, exist_ok=True)
if not os.path.exists(model_path):
    download(model_url, model_path)
# Recall which device ReActor used last time from last_device.txt.
# - Missing/unreadable file: first run. Seed the file with "CPU".
# - Readable file with an unknown value: last_device stays None and is
#   resolved later in this script.
last_device = None
first_run = False
available_devices = ["CPU", "CUDA"]
last_device_log = os.path.join(BASE_PATH, "last_device.txt")
try:
    with open(last_device_log) as f:
        last_device = f.readline().strip()
    if last_device not in available_devices:
        last_device = None
except (OSError, ValueError):
    # OSError: file missing or unreadable; ValueError covers UnicodeDecodeError.
    # Narrowed from a bare ``except:`` so Ctrl+C / SystemExit are not swallowed.
    last_device = "CPU"
    first_run = True
    with open(last_device_log, "w") as txt:
        txt.write(last_device)
# Install dependencies:
#   1. the onnxruntime flavour matching the torch backend available here
#      (and persist the resulting device choice to last_device.txt);
#   2. every line of requirements.txt that is missing or at the wrong version.
# If anything was (re)installed, ask the user to restart the server.
with open(req_file) as file:
    install_count = 0  # pip installs triggered during this run
    ort = "onnxruntime-gpu"
    import torch
    cuda_version = None
    try:
        if torch.cuda.is_available():
            cuda_version = torch.version.cuda
            print(f"CUDA {cuda_version}")
            if first_run or last_device is None:
                last_device = "CUDA"
        elif torch.backends.mps.is_available() or hasattr(torch, 'dml') or hasattr(torch, 'privateuseone'):
            # Non-CUDA accelerator backend: use the plain onnxruntime package.
            ort = "onnxruntime"
            # Prevent errors when ORT-GPU is installed but plain ORT is wanted.
            if first_run:
                pip_uninstall("onnxruntime", "onnxruntime-gpu")
            # No CUDA available, so never keep CUDA as the remembered device.
            if last_device == "CUDA" or last_device is None:
                last_device = "CPU"
        else:
            if last_device == "CUDA" or last_device is None:
                last_device = "CPU"
        with open(os.path.join(BASE_PATH, "last_device.txt"), "w") as txt:
            txt.write(last_device)
        if cuda_version is not None:
            # Compare the major version only. float() would reject a
            # three-part string such as "12.4.1"; this assumes a dotted
            # "MAJOR.MINOR[...]" format (TODO confirm for all torch builds).
            cuda_major = int(cuda_version.split(".")[0])
            if cuda_major >= 12:  # CU12.x
                extra_index_url = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/"
            else:  # CU11.8
                extra_index_url = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-11/pypi/simple"
            if not is_installed(ort, "1.17.1", True):
                install_count += 1
                ort = "onnxruntime-gpu==1.17.1"
                pip_uninstall("onnxruntime", "onnxruntime-gpu")
                pip_install(ort, "--extra-index-url", extra_index_url)
        elif not is_installed(ort, "1.18.1", False):
            install_count += 1
            pip_install(ort, "-U")
    except Exception as e:
        print(e)
        print(f"\nERROR: Failed to install {ort} - ReActor won't start")
        raise

    for package in file:
        package = package.strip()
        # Blank lines and comments are not requirements; don't hand them to pip.
        if not package or package.startswith("#"):
            continue
        package_version = None
        # Reset for every line: a ">=" requirement must not make later "=="
        # pins compare leniently.
        strict = True
        try:
            if "==" in package:
                package_version = package.split('==')[1]
            elif ">=" in package:
                package_version = package.split('>=')[1]
                strict = False
            if not is_installed(package, package_version, strict):
                install_count += 1
                pip_install(package)
        except Exception as e:
            print(e)
            print(f"\nERROR: Failed to install {package} - ReActor won't start")
            raise

if install_count > 0:
    print("""
+---------------------------------+
--- PLEASE, RESTART the Server! ---
+---------------------------------+
""")