diff --git a/configs/minigpt4/README.md b/configs/minigpt4/README.md index 01e53954639..23666fc9f95 100644 --- a/configs/minigpt4/README.md +++ b/configs/minigpt4/README.md @@ -34,9 +34,10 @@ For Vicuna model, please refer to [MiniGPT-4 page](https://github.com/Vision-CAI ### Pretrained models -| Model | Params (M) | Flops (G) | Config | Download | -| :------------------------------ | :--------: | :-------: | :--------------------------------------: | :------------------------------------------------------------------------------------------------------------: | -| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth) | +| Model | Params (M) | Flops (G) | Config | Download | +| :------------------------------ | :--------: | :-------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------: | +| `minigpt-4_baichuan-7b_caption` | 8094.77 | N/A | [config](minigpt-4_baichuan-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth) | +| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth) | *Models with * are converted from the [official repo](https://github.com/Vision-CAIR/MiniGPT-4/tree/main). The config files of these models are only for inference. We haven't reproduce the training results.* diff --git a/configs/minigpt4/metafile.yml b/configs/minigpt4/metafile.yml index a7879d986f2..f70cc9ba604 100644 --- a/configs/minigpt4/metafile.yml +++ b/configs/minigpt4/metafile.yml @@ -19,8 +19,19 @@ Models: - Task: Image Caption Dataset: COCO Metrics: null - Weights: https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth + Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth Config: configs/minigpt4/minigpt-4_vicuna-7b_caption.py Converted From: Weights: https://github.com/Vision-CAIR/MiniGPT-4/tree/main Code: https://github.com/Vision-CAIR/MiniGPT-4/tree/main + - Name: minigpt-4_baichuan-7b_caption + Metadata: + FLOPs: null + Parameters: 8094769024 + In Collection: MiniGPT4 + Results: + - Task: Image Caption + Dataset: COCO + Metrics: null + Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth + Config: configs/minigpt4/minigpt-4_baichuan-7b_caption.py diff --git a/configs/minigpt4/minigpt-4_baichuan-7b_caption.py b/configs/minigpt4/minigpt-4_baichuan-7b_caption.py new file mode 100644 index 00000000000..7e610a099c8 --- /dev/null +++ b/configs/minigpt4/minigpt-4_baichuan-7b_caption.py @@ -0,0 +1,190 @@ +_base_ = [ + '../_base_/default_runtime.py', +] + +data_preprocessor = dict( + type='MultiModalDataPreprocessor', + mean=[122.770938, 116.7460125, 104.09373615], + std=[68.5005327, 66.6321579, 70.32316305], + to_rgb=True, +) + +# dataset settings +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + scale=(224, 224), + interpolation='bicubic', + backend='pillow'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='CleanCaption', + keys='chat_content', + remove_chars='', + lowercase=False), 
+    dict(
+        type='PackInputs',
+        algorithm_keys=['chat_content', 'lang'],
+        meta_keys=['image_id']),
+]
+
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=4,
+    dataset=dict(
+        type='MiniGPT4Dataset',
+        data_root='YOUR_DATA_DIRECTORY',
+        ann_file='YOUR_DATA_FILE',
+        pipeline=train_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    collate_fn=dict(type='default_collate'),
+    drop_last=False,
+)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        scale=(224, 224),
+        interpolation='bicubic',
+        backend='pillow'),
+    dict(type='PackInputs', meta_keys=['image_id']),
+]
+
+test_evaluator = dict(
+    type='COCOCaption',
+    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+)
+
+test_dataloader = dict(
+    batch_size=1,
+    dataset=dict(
+        type='COCOCaption',
+        data_root='data/coco',
+        ann_file='annotations/coco_karpathy_val.json',
+        pipeline=test_pipeline))
+
+# model settings
+model = dict(
+    type='MiniGPT4',
+    vision_encoder=dict(
+        type='BEiTViT',
+        # eva-g without the final layer
+        arch=dict(
+            embed_dims=1408,
+            num_layers=39,
+            num_heads=16,
+            feedforward_channels=6144,
+        ),
+        img_size=224,
+        patch_size=14,
+        layer_scale_init_value=0.0,
+        frozen_stages=39,
+        use_abs_pos_emb=True,
+        use_rel_pos_bias=False,
+        final_norm=False,
+        use_shared_rel_pos_bias=False,
+        out_type='raw',
+        pretrained=  # noqa
+        'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth'  # noqa
+    ),
+    q_former_model=dict(
+        type='Qformer',
+        model_style='bert-base-uncased',
+        vision_model_width=1408,
+        add_cross_attention=True,
+        cross_attention_freq=2,
+        num_query_token=32,
+        pretrained=  # noqa
+        'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth'  # noqa
+    ),
+    lang_encoder=dict(
+        type='AutoModelForCausalLM',
+        name_or_path='baichuan-inc/baichuan-7B',
+        trust_remote_code=True),
+    tokenizer=dict(
+        type='AutoTokenizer',
+        name_or_path='baichuan-inc/baichuan-7B',
+        trust_remote_code=True),
+    task='caption',
+    prompt_template=dict([('en', '###Ask: {} ###Answer: '),
+                          ('zh', '###问:{} ###答:')]),
+    raw_prompts=dict([
+        ('en', [('<Img><ImageHere></Img> '
+                 'Describe this image in detail.'),
+                ('<Img><ImageHere></Img> '
+                 'Take a look at this image and describe what you notice.'),
+                ('<Img><ImageHere></Img> '
+                 'Please provide a detailed description of the picture.'),
+                ('<Img><ImageHere></Img> '
+                 'Could you describe the contents of this image for me?')]),
+        ('zh', [('<Img><ImageHere></Img> '
+                 '详细描述这张图片。'),
+                ('<Img><ImageHere></Img> '
+                 '浏览这张图片并描述你注意到什么。'),
+                ('<Img><ImageHere></Img> '
+                 '请对这张图片进行详细的描述。'),
+                ('<Img><ImageHere></Img> '
+                 '你能为我描述这张图片的内容吗?')])
+    ]),
+    max_txt_len=160,
+    end_sym='###')
+
+strategy = dict(
+    type='DeepSpeedStrategy',
+    fp16=dict(
+        enabled=True,
+        auto_cast=False,
+        fp16_master_weights_and_grads=False,
+        loss_scale=0,
+        loss_scale_window=1000,
+        hysteresis=1,
+        min_loss_scale=1,
+        initial_scale_power=16,
+    ),
+    inputs_to_half=[0],
+    zero_optimization=dict(
+        stage=2,
+        allgather_partitions=True,
+        allgather_bucket_size=2e8,
+        reduce_scatter=True,
+        reduce_bucket_size='auto',
+        overlap_comm=True,
+        contiguous_gradients=True,
+    ),
+)
+
+# schedule settings
+optim_wrapper = dict(
+    type='DeepSpeedOptimWrapper',
+    optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05))
+
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1e-3 / 500,
+        by_epoch=False,
+        begin=0,
+        end=500,
+    ),
+    dict(
+        type='CosineAnnealingLR',
+        eta_min=2e-4,
+        by_epoch=False,
+        begin=500,
+    ),
+]
+
+train_cfg = dict(by_epoch=True, max_epochs=6)
+test_cfg = dict()
+
+runner_type = 'FlexibleRunner'
+
+default_hooks = dict(
+    checkpoint=dict(
+        type='CheckpointHook',
+        interval=1,
+        by_epoch=True,
+        save_last=True,
+        max_keep_ckpts=1,
+    ))
diff --git a/configs/minigpt4/minigpt-4_vicuna-7b_caption.py b/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
index 704760af4e9..f468e2d8fac 100644
--- a/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
+++ b/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
@@ -55,13 +55,25 @@
         type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_VICUNA'),
     tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_VICUNA'),
     task='caption',
-    prompt_template='###Human: {} ###Assistant: ',
-    raw_prompts=[
-        '<Img><ImageHere></Img> Describe this image in detail.',
-        '<Img><ImageHere></Img> Take a look at this image and describe what you notice.',  # noqa
-        '<Img><ImageHere></Img> Please provide a detailed description of the picture.',  # noqa
-        '<Img><ImageHere></Img> Could you describe the contents of this image for me?',  # noqa
-    ],
+    prompt_template=dict([('en', '###Ask: {} ###Answer: '),
+                          ('zh', '###问:{} ###答:')]),
+    raw_prompts=dict([
+        ('en', [('<Img><ImageHere></Img> '
+                 'Describe this image in detail.'),
+                ('<Img><ImageHere></Img> '
+                 'Take a look at this image and describe what you notice.'),
+                ('<Img><ImageHere></Img> '
+                 'Please provide a detailed description of the picture.'),
+                ('<Img><ImageHere></Img> '
+                 'Could you describe the contents of this image for me?')]),
+        ('zh', [('<Img><ImageHere></Img> '
+                 '详细描述这张图片。'),
+                ('<Img><ImageHere></Img> '
+                 '浏览这张图片并描述你注意到什么。'),
+                ('<Img><ImageHere></Img> '
+                 '请对这张图片进行详细的描述。'),
+                ('<Img><ImageHere></Img> '
+                 '你能为我描述这张图片的内容吗?')])
+    ]),
     max_txt_len=160,
     end_sym='###')
diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py
index 29753d7070b..e621e157714 100644
--- a/mmpretrain/datasets/__init__.py
+++ b/mmpretrain/datasets/__init__.py
@@ -43,6 +43,7 @@
     from .gqa_dataset import GQA
     from .iconqa import IconQA
     from .infographic_vqa import InfographicVQA
+    from .minigpt4_dataset import MiniGPT4Dataset
    from .nocaps import NoCaps
    from .ocr_vqa import OCRVQA
    from .refcoco import RefCOCO
@@ -56,5 +57,6 @@
        'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
        'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
        'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
-        'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA'
+        'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA',
+        'MiniGPT4Dataset'
    ])
diff --git a/mmpretrain/datasets/minigpt4_dataset.py b/mmpretrain/datasets/minigpt4_dataset.py
new file mode 100644
index 00000000000..e14e5c354e2
--- /dev/null
+++ b/mmpretrain/datasets/minigpt4_dataset.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.fileio import get_file_backend
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class MiniGPT4Dataset(BaseDataset):
+    """Dataset for training MiniGPT4.
+
+    MiniGPT4 dataset directory:
+
+        minigpt4_dataset
+            ├── image
+            │   ├── id0.jpg
+            │   │── id1.jpg
+            │   │── id2.jpg
+            │   └── ...
+            └── conversation_data.json
+
+    The structure of conversation_data.json:
+
+        [
+            // English data
+            {
+                "id": str(id0),
+                "conversation": "###Ask: [Ask content]
+                ###Answer: [Answer content]"
+            },
+
+            // Chinese data
+            {
+                "id": str(id1),
+                "conversation": "###问: [Ask content]
+                ###答:[Answer content]"
+            },
+
+            ...
+        ]
+
+    Args:
+        data_root (str): The root directory for ``ann_file`` and ``image``.
+        ann_file (str): Conversation file path.
+        **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """ + + def load_data_list(self) -> List[dict]: + file_backend = get_file_backend(self.data_root) + conversation_path = file_backend.join_path(self.data_root, + self.ann_file) + conversation = mmengine.load(conversation_path) + img_ids = {} + n = 0 + for conv in conversation: + img_id = conv['id'] + if img_id not in img_ids.keys(): + img_ids[img_id] = n + n += 1 + + img_root = file_backend.join_path(self.data_root, 'image') + data_list = [] + for conv in conversation: + img_file = '{}.jpg'.format(conv['id']) + chat_content = conv['conversation'] + lang = 'en' if chat_content.startswith('###Ask: ') else 'zh' + data_info = { + 'image_id': img_ids[conv['id']], + 'img_path': file_backend.join_path(img_root, img_file), + 'chat_content': chat_content, + 'lang': lang, + } + + data_list.append(data_info) + + return data_list diff --git a/mmpretrain/models/multimodal/minigpt4/minigpt4.py b/mmpretrain/models/multimodal/minigpt4/minigpt4.py index eccbb27ef14..d25d0b6be36 100644 --- a/mmpretrain/models/multimodal/minigpt4/minigpt4.py +++ b/mmpretrain/models/multimodal/minigpt4/minigpt4.py @@ -31,12 +31,12 @@ class MiniGPT4(BaseModel): True. num_query_token (int): Number of query tokens of Qformer. Defaults to 32. - prompt_template (str): Prompt template of the model. Defaults to - '###Human: {} ###Assistant: '. - raw_prompts (list): Prompts for training. Defaults to None. + prompt_template (dict): Multi-language prompt template of the model. Defaults to dict([ ('en', '###Ask: {} ###Answer: '), + ('zh', '###问:{} ###答:')]) + raw_prompts (dict): Prompts for training. Defaults to dict(). max_txt_len (int): Max token length while doing tokenization. Defaults to 32. - end_sym (str): Ended symbol of the sequence. Defaults to '\\n'. + end_sym (str): Ended symbol of the sequence. Defaults to '###'. generation_cfg (dict): The config of text generation. Defaults to dict(). 
         data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
@@ -54,10 +54,12 @@ def __init__(self,
                  freeze_vit: bool = True,
                  freeze_q_former: bool = True,
                  num_query_token: int = 32,
-                 prompt_template: str = '###Human: {} ###Assistant: ',
-                 raw_prompts: Optional[list] = None,
+                 prompt_template: dict = dict([('en',
+                                                '###Ask: {} ###Answer: '),
+                                               ('zh', '###问:{} ###答:')]),
+                 raw_prompts: dict = dict(),
                  max_txt_len: int = 32,
-                 end_sym: str = '\n',
+                 end_sym: str = '###',
                  generation_cfg: dict = dict(),
                  data_preprocessor: Optional[dict] = None,
                  init_cfg: Optional[dict] = None):
@@ -135,16 +137,23 @@ def __init__(self,
         self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]
 
         # set prompts
-        if raw_prompts is not None:
-            filted_prompts = [
-                raw_prompt for raw_prompt in raw_prompts
+        self.en_prompt_list, self.zh_prompt_list = [], []
+        if raw_prompts.get('en') is not None:
+            en_filted_prompts = [
+                raw_prompt for raw_prompt in raw_prompts['en']
                 if '<ImageHere>' in raw_prompt
             ]
-            self.prompt_list = [
-                prompt_template.format(p) for p in filted_prompts
+            self.en_prompt_list = [
+                prompt_template['en'].format(p) for p in en_filted_prompts
+            ]
+        if raw_prompts.get('zh') is not None:
+            zh_filted_prompts = [
+                raw_prompt for raw_prompt in raw_prompts['zh']
+                if '<ImageHere>' in raw_prompt
+            ]
+            self.zh_prompt_list = [
+                prompt_template['zh'].format(p) for p in zh_filted_prompts
             ]
-        else:
-            self.prompt_list = []
 
         # update generation configs
         self.generation_cfg = dict(
@@ -153,7 +162,7 @@ def __init__(self,
             do_sample=True,
             min_length=1,
             top_p=0.9,
-            repetition_penalty=1.0,
+            repetition_penalty=1.1,
             length_penalty=1.0,
             temperature=1.0)
         self.generation_cfg.update(**generation_cfg)
@@ -161,6 +170,10 @@ def __init__(self,
         if hasattr(self, 'register_load_state_dict_post_hook'):
             self.register_load_state_dict_post_hook(self._load_llama_proj_hook)
 
+    def half(self):
+        self.llama_model = self.llama_model.half()
+        return self
+
     def encode_img(self,
                    images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """The function to encode the images."""
@@ -184,33 +197,39 @@ def encode_img(self,
         return inputs_llama, atts_llama
 
     def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
-                    prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
+                    prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
         """The function to wrap the image and prompt.
 
-        Currently, the function only supports applying one prompt to all input
-        images in the one batch.
+        Make sure that len(prompt) == img_embeds.shape[0].
 
         Args:
             img_embeds (torch.Tensor): The embedding of the input images.
             atts_img (torch.Tensor): Attention map of the image embeddings.
-            prompt (str): The prompt of the batch data.
+            prompt (List[str]): The prompt of the batch data.
 
         Returns:
             Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.
         """
-        if prompt:
-            batch_size = img_embeds.shape[0]
-            p_before, p_after = prompt.split('<ImageHere>')
+        if len(prompt) > 0:
+            p_before_list, p_after_list = [], []
+            for pro in prompt:
+                p_before, p_after = pro.split('<ImageHere>')
+                p_before_list.append(p_before)
+                p_after_list.append(p_after)
             p_before_tokens = self.llama_tokenizer(
-                p_before, return_tensors='pt',
+                p_before_list,
+                return_tensors='pt',
+                padding='longest',
                 add_special_tokens=False).to(img_embeds.device)
             p_after_tokens = self.llama_tokenizer(
-                p_after, return_tensors='pt',
+                p_after_list,
+                return_tensors='pt',
+                padding='longest',
                 add_special_tokens=False).to(img_embeds.device)
             p_before_embeds = self.llama_model.model.embed_tokens(
-                p_before_tokens.input_ids).expand(batch_size, -1, -1)
+                p_before_tokens.input_ids)
             p_after_embeds = self.llama_model.model.embed_tokens(
-                p_after_tokens.input_ids).expand(batch_size, -1, -1)
+                p_after_tokens.input_ids)
             wrapped_img_embeds = torch.cat(
                 [p_before_embeds, img_embeds, p_after_embeds], dim=1)
             wrapped_atts_img = atts_img[:, :1].expand(
@@ -234,17 +253,22 @@ def loss(self,
         """
         img_embeds, atts_img = self.encode_img(images)
 
-        if self.task == 'caption' and self.prompt_list:
-            prompt = random.choice(self.prompt_list)
-            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
-                                                    prompt)
-
         self.llama_tokenizer.padding_side = 'right'
 
-        text = [t + self.end_sym for t in data_samples['text_input']]
+        prompts, texts = [], []
+        for t in data_samples:
+            chat_content = t.chat_content
+            split_mark = '###Answer: ' if t.lang == 'en' else '###答:'
+            prompt, text = chat_content.split(split_mark)
+            prompt += split_mark
+            text += self.end_sym
+            prompts.append(prompt)
+            texts.append(text)
+
+        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
 
         to_regress_tokens = self.llama_tokenizer(
-            text,
+            texts,
             return_tensors='pt',
             padding='longest',
             truncation=True,
@@ -295,10 +319,12 @@ def predict(
 
         with torch.no_grad():
             img_embeds, atts_img = self.encode_img(images)
 
-        if self.task == 'caption' and self.prompt_list:
-            prompt = random.choice(self.prompt_list)
-            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
-                                                    prompt)
+        prompts = [
+            random.choice(self.zh_prompt_list) if hasattr(t, 'lang')
+            and t.lang == 'zh' else random.choice(self.en_prompt_list)
+            for t in data_samples
+        ]
+        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
 
         batch_size = img_embeds.shape[0]
         bos = torch.ones(
@@ -336,7 +362,6 @@ def post_process(
         for output, data_sample in zip(outputs, data_samples):
             if self.task == 'caption':
                 output = output.split('###')[0]
-                output = output.split('Assistant:')[-1].strip()
                 data_sample.pred_caption = output
             else:
                 # raw output
diff --git a/projects/gradio_demo/conversation.py b/projects/gradio_demo/conversation.py
new file mode 100644
index 00000000000..3c5946900b0
--- /dev/null
+++ b/projects/gradio_demo/conversation.py
@@ -0,0 +1,137 @@
+# Modified from
+# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
+import dataclasses
+from typing import List
+
+import torch
+
+
+@dataclasses.dataclass
+class Conversation:
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    sep: str = '###'
+
+    def get_prompt(self):
+        ret = self.system + self.sep
+        for role, message in self.messages:
+            if message:
+                ret += role + ': ' + message + self.sep
+            else:
+                ret += role + ':'
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=[role for role in self.roles],
+            messages=[[y for y in x] for x in self.messages],
+            sep=self.sep,
+        )
+
+    def dict(self):
+        return {
+            'system': self.system,
+            'roles': self.roles,
+            'messages': self.messages,
+            'offset': self.offset,
+            'sep': self.sep,
+        }
+
+
+EN_CONV_VISION = Conversation(
+    system='Give the following image. '
+    'You will be able to see the image once I provide it to you. '
+    'Please answer my questions in detail.',
+    roles=['Ask', 'Answer'],
+    messages=[],
+    sep='###',
+)
+
+ZH_CONV_VISION = Conversation(
+    system='给定一张图片,请仔细观察这张图片,并回答我的问题。',
+    roles=['问', '答'],
+    messages=[],
+    sep='###',
+)
+
+
+class Chat:
+
+    def __init__(self, inferencer, device, is_half=False):
+        self.device = device
+        self.inferencer = inferencer
+        self.model = inferencer.model
+        self.is_half = is_half
+        if is_half:
+            self.model = self.model.half()
+        self.model = self.model.to(device)
+        self.max_length = 2000
+
+    def upload_img(self, image, conv, img_list):
+        img = next(self.inferencer.preprocess([image]))
+        img = self.model.data_preprocessor(img, False)['images']
+        img = img.to(self.device)
+        image_emb, _ = self.model.encode_img(img)
+        img_list.append(image_emb)
+        conv.append_message(conv.roles[0], '<Img><ImageHere></Img>')
+
+    def get_context_emb(self, conv, img_list):
+        prompt = conv.get_prompt()
+        prompt_segs = prompt.split('<ImageHere>')
+        seg_tokens = [
+            self.model.llama_tokenizer(
+                seg, return_tensors='pt',
+                add_special_tokens=(i == 0)).to(self.device).input_ids
+            for i, seg in enumerate(prompt_segs)
+        ]
+        seg_embs = [
+            self.model.llama_model.model.embed_tokens(seg_token)
+            for seg_token in seg_tokens
+        ]
+        mixed_embs = [
+            emb for pair in zip(seg_embs[:-1], img_list) for emb in pair
+        ] + [seg_embs[-1]]
+        mixed_embs = torch.cat(mixed_embs, dim=1)
+        return mixed_embs
+
+    def ask(self, text, conv):
+        if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[
+                0] and conv.messages[-1][1][-6:] == '</Img>':
+            conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
+        else:
+            conv.append_message(conv.roles[0], text)
+
+    def answer(self, conv, img_list, generation_cfg):
+        conv.append_message(conv.roles[1], None)
+        embs = self.get_context_emb(conv, img_list)
+        cur_max_len = generation_cfg['max_new_tokens'] + embs.shape[1]
+        if cur_max_len > self.max_length:
+            print('Warning: The number of tokens in current conversation '
+                  'exceeds the max length. 
' + 'The model will not see the contexts outside the range.') + begin_idx = max(0, cur_max_len - self.max_length) + embs = embs[:, begin_idx:] + if self.is_half: + embs = embs.half() + outputs = self.model.llama_model.generate( + inputs_embeds=embs, + eos_token_id=self.model.end_token_id, + **generation_cfg) + + output_token = outputs[0] + if output_token[0] == 0: + output_token = output_token[1:] + elif output_token[0] == 1: + output_token = output_token[1:] + output_text = self.model.llama_tokenizer.decode( + output_token, + add_special_tokens=False, + skip_special_tokens=True) + output_text = output_text.split('###')[0] + conv.messages[-1][1] = output_text + return output_text diff --git a/projects/gradio_demo/minigpt4_demo.py b/projects/gradio_demo/minigpt4_demo.py new file mode 100644 index 00000000000..e4d61426fa7 --- /dev/null +++ b/projects/gradio_demo/minigpt4_demo.py @@ -0,0 +1,144 @@ +import argparse + +import gradio as gr +import numpy as np +import torch +from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat + +from mmpretrain import ImageCaptionInferencer + +parser = argparse.ArgumentParser(description='MiniGPT4 demo') +parser.add_argument( + 'cfg', type=str, help='config file for minigpt4 (absolute path)') +parser.add_argument( + 'ckpt', type=str, help='pretrained file for minigpt4 (absolute path)') +args = parser.parse_args() + +if torch.cuda.is_available(): + devices = [ + torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count()) + ] +elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + devices = [torch.device('mps')] +else: + devices = [torch.device('cpu')] + + +def get_free_device(): + if hasattr(torch.cuda, 'mem_get_info'): + free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices] + select = max(zip(free, range(len(free))))[1] + else: + import random + select = random.randint(0, len(devices) - 1) + return devices[select] + + +device = get_free_device() +inferencer = ImageCaptionInferencer(model=args.cfg, pretrained=args.ckpt) +model = inferencer.model +chat = Chat(inferencer, device=device, is_half=(device.type != 'cpu')) + + +def reset(chat_state, img_list): + if chat_state is not None: + chat_state.messages = [] + if img_list is not None: + img_list = [] + return (None, gr.update(value=None, interactive=True), + gr.update( + value=None, + placeholder='Please upload your image first', + interactive=False), + gr.update(value='Upload & Start Chat', + interactive=True), chat_state, img_list, + gr.update(value='Restart', interactive=False), + gr.update(value='English', interactive=True)) + + +def upload_img(gr_img, language, chat_state): + if gr_img is None: + return (None, + gr.update( + placeholder='Please upload your image first', + interactive=False), + gr.update(value='Upload & Start Chat', + interactive=True), chat_state, None, + gr.update(value='Restart', interactive=False), + gr.update(value='English', interactive=True)) + + if (language == 'English'): + chat_state = EN_CONV_VISION.copy() + else: + chat_state = ZH_CONV_VISION.copy() + img_list = [] + gr_img_array = np.asarray(gr_img) + chat.upload_img(gr_img_array, chat_state, img_list) + return (gr.update(interactive=False), + gr.update(placeholder='Type and press Enter', interactive=True), + gr.update(value='Start Chatting', + interactive=False), chat_state, img_list, + gr.update(value='Restart', + interactive=True), gr.update(interactive=False)) + + +def ask(user_message, chatbot, chat_state): + if (len(user_message) == 0): + return gr.update( + value=None, + 
placeholder='Input should not be empty!', + interactive=True), chatbot, chat_state + chat.ask(user_message, chat_state) + chatbot = chatbot + [[user_message, None]] + return '', chatbot, chat_state + + +def answer(chatbot, chat_state, img_list): + llm_message = chat.answer( + conv=chat_state, + img_list=img_list, + generation_cfg=model.generation_cfg) + chatbot[-1][1] = llm_message + return chatbot, chat_state, img_list + + +if __name__ == '__main__': + title = 'MMPretrain MiniGPT-4 Inference Demo' + with gr.Blocks(analytics_enabled=False, title=title) as demo: + gr.Markdown(f'# {title}') + with gr.Row(): + with gr.Column(): + image = gr.Image(type='pil') + language = gr.Dropdown(['English', 'Chinese'], + label='Language', + info='Select chatbot\'s language', + value='English', + interactive=True) + upload_button = gr.Button( + value='Upload & Start Chat', interactive=True) + clear = gr.Button(value='Restart', interactive=False) + + with gr.Column(): + chat_state = gr.State() + img_list = gr.State() + chatbot = gr.Chatbot( + label='MiniGPT-4', min_width=320, height=600) + text_input = gr.Textbox( + label='User', + placeholder='Please upload your image first', + interactive=False) + + upload_button.click(upload_img, [image, language, chat_state], [ + image, text_input, upload_button, chat_state, img_list, clear, + language + ]) + text_input.submit(ask, [text_input, chatbot, chat_state], + [text_input, chatbot, chat_state]).then( + answer, [chatbot, chat_state, img_list], + [chatbot, chat_state, img_list]) + clear.click(reset, [chat_state, img_list], [ + chatbot, image, text_input, upload_button, chat_state, img_list, + clear, language + ]) + + demo.launch(share=True)
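A minimal sketch of data preparation, for readers of this diff. `MiniGPT4Dataset` reads a `conversation_data.json` next to an `image/` folder as described in its docstring, and `MiniGPT4.loss` splits each conversation at `###Answer: ` (or `###答:`) before `prompt_wrap` splits the Ask part at `<ImageHere>`. The snippet below is not part of the PR: the directory name, ids and captions are made-up placeholders, and the assumption that each Ask carries a single `<Img><ImageHere></Img>` placeholder is inferred from `prompt_wrap`, not from an official data specification.

# Illustrative sketch only: build a minimal MiniGPT4Dataset-style annotation
# file. Field names follow the MiniGPT4Dataset docstring; the data root, ids
# and captions are hypothetical, and the <Img><ImageHere></Img> marker mirrors
# what prompt_wrap() expects to split on.
import json
import os

data_root = 'data/minigpt4_dataset'  # hypothetical data root
os.makedirs(os.path.join(data_root, 'image'), exist_ok=True)

conversations = [
    {
        # English sample; the image is expected at image/0001.jpg
        'id': '0001',
        'conversation': ('###Ask: <Img><ImageHere></Img> '
                         'Describe this image in detail. '
                         '###Answer: A dog is running on the grass.'),
    },
    {
        # Chinese sample; the image is expected at image/0002.jpg
        'id': '0002',
        'conversation': ('###问:<Img><ImageHere></Img> '
                         '详细描述这张图片。'
                         '###答:一只狗在草地上奔跑。'),
    },
]

with open(os.path.join(data_root, 'conversation_data.json'), 'w') as f:
    json.dump(conversations, f, ensure_ascii=False, indent=2)

With such a file in place, the `YOUR_DATA_DIRECTORY` / `YOUR_DATA_FILE` placeholders in `minigpt-4_baichuan-7b_caption.py` would point to `data/minigpt4_dataset` and `conversation_data.json`, and the Gradio demo above takes its config and checkpoint as the two positional arguments, e.g. `python projects/gradio_demo/minigpt4_demo.py CONFIG CKPT`.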