change mmdialog dataloader
KzZheng committed Mar 19, 2024
1 parent 21308c1 commit 2121c74
Showing 2 changed files with 105 additions and 58 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -39,7 +39,7 @@ Since our model is trained with two stages **(Stage 1: Unimodal Alignment Stage,

| Stage 1: CC3M | Stage 2: VIST | Stage 2: MMDialog |
:------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:
[Download](https://drive.google.com/file/d/1y-VUXubIzFe0iq5_CJUaE3HKhlrdn4n2/view?usp=sharing) | [Download](https://drive.google.com/file/d/1rjTsKwF8_pqcNLbdZdurqZLSpKoo2K9F/view?usp=drive_link) | [Download](https://drive.google.com/file/d/1ehyX8Ykn1pbU5J8yM47catSswId0m5FZ/view?usp=drive_link)
[Download](https://drive.google.com/file/d/1y-VUXubIzFe0iq5_CJUaE3HKhlrdn4n2/view?usp=sharing) | [Download](https://drive.google.com/file/d/1rjTsKwF8_pqcNLbdZdurqZLSpKoo2K9F/view?usp=drive_link) | [Download](https://drive.google.com/file/d/1uo0LU-X11F1FIPC2h62s4Uzl6rBSAoQH/view?usp=sharing)

Stage 2 needs the pretrained weights in Stage 1, so always download Stage 1 weights first.
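For example, one way to fetch the Stage 1 checkpoint first is a minimal sketch like the following (assumptions: the `gdown` package is installed, and the output filename is an arbitrary choice; the Drive ID is the Stage 1 link from the table above):

import gdown

# Stage 1 (CC3M) weights must be in place before any Stage 2 training or download.
stage1_id = "1y-VUXubIzFe0iq5_CJUaE3HKhlrdn4n2"
gdown.download(id=stage1_id, output="stage1_cc3m.ckpt", quiet=False)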

161 changes: 104 additions & 57 deletions dataloader.py
@@ -312,72 +312,118 @@ def __init__(self, data_path: str, input_processor=None, output_vis_processor=No
self.input_processor = input_processor
self.output_vis_processor = output_vis_processor
self.output_img_id = input_processor.tokenizer.convert_tokens_to_ids(ALL_IMG_TOKENS[0])
eos_token = input_processor.tokenizer.eos_token
# eos_token = input_processor.tokenizer.eos_token
self.load_preprocessed_image_features = False

system_prompt="Give the following images in <Img>ImageContent</Img> format. "\
"You will be able to see the images once I provide it to you. Please generate conversation with appropriate image."
"You will be able to see the images once I provide it to you. Please generate conversations with appropriate images."

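# Known-bad image IDs in the MMDialog dump; a conversation is cut off when one of these appears.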
error_image_ids = ['-3872362534310063124', '3713630103994725971']

self.sources, self.targets, self.input_image_path, self.output_image_path = [], [], [], []
self.caption, self.task_names = [], []
data_folder = os.path.dirname(data_path)
with open(data_path, 'r') as f:
all_data = f.readlines()
for data in tqdm(all_data):
data = json.loads(data)
data_num = data['conversation_id']
conversation = data['conversation']
if len(conversation)<2:
continue
history_prompt = system_prompt
history_images = []
for i, conv in enumerate(conversation):
turn = conv['turn']
turn_text = turn[0]['__TEXT__']
if len(turn)==1:
turn_image_path = None
else:
turn_image_path = os.path.join(data_folder, f"{turn[1]['__MEDIA__']}.jpg")
if not os.path.exists(turn_image_path):
# print(f'Cannot Find: {turn_image_path}')
all_sources, all_targets = [], []

preprocessed_data_pkl = data_path.replace('.txt', '.pkl')

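# Load the cached preprocessing if present (the cache is bypassed in test mode).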
if os.path.exists(preprocessed_data_pkl) and not self.test:
print("Loading saved data...")
all_data = torch.load(preprocessed_data_pkl)
self.sources = all_data['sources']
self.targets = all_data['targets']
self.input_image_path = all_data['input_image_path']
self.output_image_path = all_data['output_image_path']
self.caption = all_data['caption']
self.task_names = all_data['task_names']
del all_data
else:
data_folder = os.path.dirname(data_path)
with open(data_path, 'r') as f:
all_data = f.readlines()

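# No cache yet: parse the raw JSONL, one conversation per line.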
for data in tqdm(all_data):
data = json.loads(data)
data_id = data['conversation_id']
conversation = data['conversation']
if len(conversation)<2:
continue
history_prompt = []
user1_counting = 0
remove_idx = 0
for i, conv in enumerate(conversation):
turn = conv['turn']
turn_text = turn[0]['__TEXT__']
if len(turn_text)==0 or turn_text.endswith('.jpg'):
break
if len(turn_text.split(' '))>20 and not test:
break
if i%2==0:
source = history_prompt + "###Human:"
else:
source = history_prompt + "###Assistant:"
if not self.test:
tokened_text = input_processor.tokenizer(turn_text, return_tensors="pt", add_special_tokens=False).input_ids
if len(tokened_text[0])>50:
break

if i>0:
self.sources.append(source)
if turn_image_path is not None:
target = f"{turn_text} {ALL_IMG_TOKENS_STR} ###"
else:
target = f"{turn_text} ###"
self.targets.append(target)
self.caption.append(None)
self.task_names.append(f'mmdialog{data_num}_{i}')
if len(history_images):
self.input_image_path.append(copy.deepcopy(history_images))
if len(turn)==1:
turn_image_path = None
else:
self.input_image_path.append([None])
self.output_image_path.append(turn_image_path)
turn_image_path = os.path.join(data_folder, f"{turn[1]['__MEDIA__']}.jpg")
image_stem = Path(turn_image_path).stem
if not os.path.exists(turn_image_path) or image_stem in error_image_ids:
break

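# From the second turn on, snapshot the running history as the source and this turn as the target.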
if i>0:
step_prompt = copy.deepcopy(history_prompt)
if i%2==0:
step_prompt.append("###Human:")
else:
step_prompt.append("###Assistant:")


if turn_image_path is not None:
step_target = [f"{turn_text} {ALL_IMG_TOKENS_STR}", turn_image_path]
else:
step_target = [turn_text]

if turn_image_path is not None:
history_prompt = source + f" {turn_text} <Img><ImageHere></Img>\n"
history_images.append(turn_image_path)
else:
history_prompt = source + f" {turn_text}\n"

if (i%2==1 and i>2):
pattern = "###Human:(.*?)###Human:"
match = re.search(pattern, history_prompt, re.DOTALL)
match_text = match.group(0)
history_prompt = history_prompt.replace(match_text, "###Human:")
pattern2 = '<Img><ImageHere></Img>'
match_num = len(re.findall(pattern2, match_text))
if match_num>0:
history_images = history_images[match_num:]
all_sources.append(step_prompt)
all_targets.append(step_target)
self.task_names.append(f'{data_id}_{i}-mm')

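# Label the speaker; once a second Human turn shows up, everything before it is dropped from the history below.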
if i%2==0:
turn_text = f"###Human: {turn_text}"
user1_counting+=1
if user1_counting==2:
remove_idx = len(history_prompt)
else:
turn_text = f"###Assistant: {turn_text}"

if user1_counting == 2 and i%2==1:
history_prompt = history_prompt[remove_idx:]
user1_counting = 1
remove_idx = 0
history_prompt.append(turn_text)
if turn_image_path is not None:
history_prompt.append(turn_image_path)

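# Flatten each history into a prompt string: image paths become <Img><ImageHere></Img> placeholders and their files are collected in input_image_path.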
for source, target in zip(all_sources, all_targets):
new_source = [system_prompt]
input_images = []
for s in source:
if s.endswith('.jpg'):
input_images.append(s)
new_source.append('<Img><ImageHere></Img>')
else:
new_source.append(s)
self.sources.append(' '.join(new_source))
self.targets.append(f"{target[0]} ###")
self.input_image_path.append(input_images if len(input_images)>0 else [None])
self.output_image_path.append(target[1] if len(target)>1 else None)
self.caption.append(None)

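# Persist everything so the cache branch at the top of __init__ can load it on the next run.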
with open(preprocessed_data_pkl, 'wb') as f:
torch.save({'sources': self.sources,
'targets': self.targets,
'input_image_path': self.input_image_path,
'output_image_path': self.output_image_path,
'caption': self.caption,
'task_names': self.task_names,
}, f)

self.valid_idx = list(range(len(self.sources)))
print('Load data done with {} samples!'.format(len(self.sources)))
@@ -387,7 +433,8 @@ def __getitem__(self, i):
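# If loading a sample raises, log the error and retry with a random valid index.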
try:
item = super().__getitem__(i)
break
except:
except Exception as e:
print(e)
i = random.choice(self.valid_idx)
return item

