load_models.py
"""Loaders for the Recognize Anything (RAM++ / RAM / Tag2Text) tagging models."""
import time

import torch

from recognize_anything.ram import get_transform, inference_ram, inference_tag2text
from recognize_anything.ram.models import ram, ram_plus, tag2text
def load_ram_plus(image_size):
    """Load the RAM++ tagger and return the model plus an inference closure."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    transform = get_transform(image_size=image_size)
    model = ram_plus(pretrained="/data/pretrained/ram_plus_swin_large_14m.pth",
                     image_size=image_size,
                     vit='swin_l')
    model.eval()
    model = model.to(device)
    print("Loaded ram_plus_swin_large_14m.pth")

    def inference(image):
        start_time = time.perf_counter()
        transformed = transform(image).unsqueeze(0).to(device)
        result = inference_ram(transformed, model)
        print(f"Processed image in {time.perf_counter() - start_time:0.4f}s")
        # inference_ram returns English and Chinese tags as " | "-joined strings.
        return {
            "english": result[0].split(" | "),
            "chinese": result[1].split(" | ")
        }

    return {
        "device": device,
        "model": model,
        "inference": inference,
        "transform": transform
    }
def load_ram(image_size):
    """Load the RAM tagger and return the model plus an inference closure."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    transform = get_transform(image_size=image_size)
    model = ram(pretrained="/data/pretrained/ram_swin_large_14m.pth",
                image_size=image_size,
                vit='swin_l')
    model.eval()
    model = model.to(device)
    print("Loaded ram_swin_large_14m.pth")

    def inference(image):
        start_time = time.perf_counter()
        transformed = transform(image).unsqueeze(0).to(device)
        result = inference_ram(transformed, model)
        print(f"Processed image in {time.perf_counter() - start_time:0.4f}s")
        return {
            "english": result[0].split(" | "),
            "chinese": result[1].split(" | ")
        }

    return {
        "device": device,
        "model": model,
        "inference": inference,
        "transform": transform
    }
def load_tag2text(image_size, threshold=0.68, delete_tag_index=None):
    """Load the Tag2Text tagger/captioner and return the model plus an inference closure."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    transform = get_transform(image_size=image_size)
    # Delete some tags that may disturb captioning:
    # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
    model = tag2text(pretrained="/data/pretrained/tag2text_swin_14m.pth",
                     image_size=image_size,
                     vit='swin_b',
                     delete_tag_index=(delete_tag_index or [127, 2961, 3351, 3265, 3338, 3355, 3359]))
    model.threshold = threshold
    model.eval()
    model = model.to(device)
    print("Loaded tag2text_swin_14m.pth")

    def inference(image):
        start_time = time.perf_counter()
        transformed = transform(image).unsqueeze(0).to(device)
        result = inference_tag2text(transformed, model)
        print(f"Processed image in {time.perf_counter() - start_time:0.4f}s")
        # inference_tag2text returns (model_tags, user_tags, caption); tags may be None.
        return {
            "model_tags": result[0] and result[0].split(" | "),
            "user_tags": result[1] and result[1].split(" | "),
            "image_caption": result[2]
        }

    return {
        "device": device,
        "model": model,
        "inference": inference,
        "transform": transform
    }
def load_model(model_name, image_size, threshold, delete_tag_index):
    """Dispatch to the loader for the requested model."""
    print(f"Loading model: {model_name}")
    match model_name:
        case "ram_plus":
            return load_ram_plus(image_size)
        case "ram":
            return load_ram(image_size)
        case "tag2text":
            return load_tag2text(image_size, threshold, delete_tag_index)
        case _:
            # Fail loudly instead of silently returning None for a typo'd name.
            raise ValueError(f"Unknown model name: {model_name}")
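
# A minimal usage sketch, run only when the module is executed directly.
# Assumptions not in the original file: Pillow is installed, a local
# "example.jpg" exists, and 384 is the input size used with the Swin
# checkpoints referenced above.
if __name__ == "__main__":
    from PIL import Image

    loaded = load_model("ram_plus", image_size=384, threshold=0.68,
                        delete_tag_index=None)
    image = Image.open("example.jpg").convert("RGB")
    tags = loaded["inference"](image)
    print("English tags:", tags["english"])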