-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
126 lines (98 loc) · 5.19 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import math
import torch
import logging
from PIL import Image, ImageDraw, ImageFont
def _box_mask(obj_boxes, H, W, device):
    """Return an (H, W) float mask that is 1 inside any of ``obj_boxes`` and 0 elsewhere."""
    # Created on the attention map's device so CPU and GPU maps both work.
    mask = torch.zeros((H, W), device=device)
    for obj_box in obj_boxes:
        x_min, y_min = int(obj_box[0] * W), int(obj_box[1] * H)
        x_max, y_max = int(obj_box[2] * W), int(obj_box[3] * H)
        mask[y_min:y_max, x_min:x_max] = 1
    return mask


def _ca_loss_sum(attn_maps, bboxes, object_positions):
    """Return the unnormalised sum of per-object loss terms over ``attn_maps`` (int 0 if none)."""
    total = 0
    for attn_map in attn_maps:
        b, i, _ = attn_map.shape
        # Assumes the spatial axis is a flattened square map.
        H = W = int(math.sqrt(i))
        for obj_idx, obj_boxes in enumerate(bboxes):
            positions = object_positions[obj_idx]
            if len(positions) == 0:
                # No tokens to supervise for this object; averaging over zero
                # positions would divide by zero.
                continue
            mask = _box_mask(obj_boxes, H, W, attn_map.device)
            obj_loss = 0
            for obj_position in positions:
                ca_map_obj = attn_map[:, :, obj_position].reshape(b, H, W)
                inside_mass = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)
                total_mass = ca_map_obj.reshape(b, -1).sum(dim=-1)
                activation_value = inside_mass / total_mass
                obj_loss += torch.mean((1 - activation_value) ** 2)
            total += obj_loss / len(positions)
    return total


def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
    """Compute the box-guided cross-attention loss over mid and up attention maps.

    For each attention map, object and token of that object, the attention mass
    inside the object's boxes is divided by the token's total attention mass.
    The term ``mean_over_batch((1 - inside / total) ** 2)`` is averaged over
    the object's tokens. These per-object values are summed over objects and
    maps, then divided by
    ``len(bboxes) * (len(attn_maps_mid) + len(attn_maps_up[0]))``.

    Args:
        attn_maps_mid: Iterable of 3-D tensors ``(b, i, j)``. Here ``i = H * W``
            for a square ``H == W`` map, and the last axis is indexed by the
            token positions.
        attn_maps_up: Sequence whose first element holds further maps with the
            same layout. Only ``attn_maps_up[0]`` is used.
        bboxes: Per object, a list of boxes ``(x_min, y_min, x_max, y_max)``.
            The coordinates are multiplied by the map width/height, so they are
            presumably normalised to [0, 1].
        object_positions: Per object, the token indices (last axis of each
            map) belonging to it.

    Returns:
        A scalar float tensor. A zero tensor (on CUDA when available) is
        returned when there are no objects, no attention maps, or no object has
        any token position.
    """

    def _zero_loss():
        zero = torch.tensor(0).float()
        return zero.cuda() if torch.cuda.is_available() else zero

    object_number = len(bboxes)
    if object_number == 0:
        return _zero_loss()
    num_maps = len(attn_maps_up[0]) + len(attn_maps_mid)
    if num_maps == 0:
        return _zero_loss()
    loss = _ca_loss_sum(attn_maps_mid, bboxes, object_positions)
    loss = loss + _ca_loss_sum(attn_maps_up[0], bboxes, object_positions)
    if not torch.is_tensor(loss):
        # Every object was skipped (no token positions), so nothing was accumulated.
        return _zero_loss()
    return loss / (object_number * num_maps)
def Phrase2idx(prompt, phrases):
    """Map each ';'-separated phrase to the positions of its words in ``prompt``.

    ``prompt`` is stripped of leading/trailing '.' characters and split on
    single spaces. For every word of every phrase, the index of its FIRST
    occurrence in that word list is returned, shifted by +1. The shift
    presumably accounts for a start-of-sequence token at position 0; confirm
    against the tokenizer.

    Args:
        prompt: The full text prompt.
        phrases: Phrases separated by ';', e.g. ``"cat; a dog"``.

    Returns:
        A list with one list of word positions per phrase.

    Raises:
        ValueError: If a phrase word does not occur in ``prompt``.
    """
    prompt_list = prompt.strip('.').split(' ')
    # First-occurrence index of every word, built once so each lookup is O(1).
    first_index = {}
    for idx, word in enumerate(prompt_list):
        first_index.setdefault(word, idx)
    object_positions = []
    for obj in (x.strip() for x in phrases.split(';')):
        obj_position = []
        for word in obj.split(' '):
            if word not in first_index:
                raise ValueError(
                    f"Word {word!r} of phrase {obj!r} not found in prompt {prompt!r}"
                )
            obj_position.append(first_index[word] + 1)
        object_positions.append(obj_position)
    return object_positions
def draw_box(pil_img, bboxes, phrases, save_path, image_size=512,
             font_path='./FreeMono.ttf', font_size=25):
    """Draw red, labelled boxes onto ``pil_img`` (in place) and save it.

    Args:
        pil_img: PIL image to draw on; it is modified in place.
        bboxes: Per phrase, a list of boxes ``(x0, y0, x1, y1)``. The
            coordinates are multiplied by the canvas size in pixels.
        phrases: Labels separated by ';', paired with ``bboxes`` in order.
            Unpaired trailing entries on either side are ignored.
        save_path: Destination passed to ``pil_img.save``.
        image_size: Pixel scale for the box coordinates. An int scales both
            axes; the default of 512 keeps the historical behaviour. ``None``
            uses the actual ``pil_img.size`` (width, height).
        font_path: TrueType font file used for the labels.
        font_size: Font size for the labels. If the font cannot be loaded,
            Pillow's built-in default font is used instead.
    """
    if image_size is None:
        scale_x, scale_y = pil_img.size
    else:
        scale_x = scale_y = image_size
    try:
        font = ImageFont.truetype(font_path, font_size)
    except OSError:
        # A missing label font should not prevent the visualisation from being saved.
        font = ImageFont.load_default()
    draw = ImageDraw.Draw(pil_img)
    labels = [x.strip() for x in phrases.split(';')]
    for obj_bboxes, phrase in zip(bboxes, labels):
        for obj_bbox in obj_bboxes:
            left = int(obj_bbox[0] * scale_x)
            top = int(obj_bbox[1] * scale_y)
            right = int(obj_bbox[2] * scale_x)
            bottom = int(obj_bbox[3] * scale_y)
            draw.rectangle([left, top, right, bottom], outline='red', width=5)
            # Offset the label slightly inside the box's top-left corner.
            draw.text((left + 5, top + 5), phrase, font=font, fill=(255, 0, 0))
    pil_img.save(save_path)
def setup_logger(save_path, logger_name):
    """Return the named logger, configured to write INFO+ records to a log file.

    The file is ``<save_path>/<logger_name>.log``. ``save_path`` is created if
    it does not exist. The logger returned by ``logging.getLogger`` is shared
    process-wide, so repeated calls reuse the existing handler for the same
    file instead of attaching another one; a second handler would duplicate
    every record.

    Args:
        save_path: Directory for the log file.
        logger_name: Logger name, also used as the log file's base name.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    log_file = os.path.abspath(os.path.join(save_path, f"{logger_name}.log"))
    for handler in logger.handlers:
        # FileHandler stores the absolute path of its file in baseFilename.
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == log_file:
            return logger
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    # Write UTF-8 explicitly so non-ASCII messages do not depend on the locale.
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger
def load_text_inversion(text_encoder, tokenizer, placeholder_token, embedding_ckp_path):
    """Register ``placeholder_token`` and load its learned embedding into ``text_encoder``.

    Args:
        text_encoder: Model exposing ``resize_token_embeddings`` and
            ``get_input_embeddings()``. Its embedding matrix is updated in place.
        tokenizer: Tokenizer exposing ``add_tokens`` and
            ``convert_tokens_to_ids``. The token is added in place.
        placeholder_token: New token string to register.
        embedding_ckp_path: Path to a ``torch.save`` checkpoint indexable by
            ``placeholder_token``.

    Returns:
        ``(text_encoder, tokenizer)``, the same objects that were passed in.

    Raises:
        ValueError: If the tokenizer already contains ``placeholder_token``.
        KeyError: Presumably, if the checkpoint has no entry for
            ``placeholder_token`` (assuming it is a dict).
    """
    num_added_tokens = tokenizer.add_tokens(placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )
    placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
    # Grow the embedding matrix so the newly added token id has a row.
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = text_encoder.get_input_embeddings().weight.data
    # Load onto the CPU so checkpoints saved from a GPU also open on CPU-only
    # machines. The indexed assignment below copies the values into
    # token_embeds, converting to its device and dtype.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    learned_embedding = torch.load(embedding_ckp_path, map_location="cpu")
    # Overwrite the new token's row with its learned (textual-inversion) embedding.
    token_embeds[placeholder_token_id] = learned_embedding[placeholder_token]
    return text_encoder, tokenizer