diff --git a/configs/minigpt4/README.md b/configs/minigpt4/README.md
index 01e53954639..23666fc9f95 100644
--- a/configs/minigpt4/README.md
+++ b/configs/minigpt4/README.md
@@ -34,9 +34,10 @@ For Vicuna model, please refer to [MiniGPT-4 page](https://github.com/Vision-CAI
### Pretrained models
-| Model | Params (M) | Flops (G) | Config | Download |
-| :------------------------------ | :--------: | :-------: | :--------------------------------------: | :------------------------------------------------------------------------------------------------------------: |
-| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth) |
+| Model | Params (M) | Flops (G) | Config | Download |
+| :------------------------------ | :--------: | :-------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------: |
+| `minigpt-4_baichuan-7b_caption` | 8094.77 | N/A | [config](minigpt-4_baichuan-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth) |
+| `minigpt-4_vicuna-7b_caption`\* | 8121.32 | N/A | [config](minigpt-4_vicuna-7b_caption.py) | [model](https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth) |
*Models with * are converted from the [official repo](https://github.com/Vision-CAIR/MiniGPT-4/tree/main). The config files of these models are only for inference. We haven't reproduced the training results.*
diff --git a/configs/minigpt4/metafile.yml b/configs/minigpt4/metafile.yml
index a7879d986f2..f70cc9ba604 100644
--- a/configs/minigpt4/metafile.yml
+++ b/configs/minigpt4/metafile.yml
@@ -19,8 +19,19 @@ Models:
- Task: Image Caption
Dataset: COCO
Metrics: null
- Weights: https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_linear-projection_20230615-714b5f52.pth
+ Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_vicuna7b_20230615-714b5f52.pth
Config: configs/minigpt4/minigpt-4_vicuna-7b_caption.py
Converted From:
Weights: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
Code: https://github.com/Vision-CAIR/MiniGPT-4/tree/main
+ - Name: minigpt-4_baichuan-7b_caption
+ Metadata:
+ FLOPs: null
+ Parameters: 8094769024
+ In Collection: MiniGPT4
+ Results:
+ - Task: Image Caption
+ Dataset: COCO
+ Metrics: null
+ Weights: https://download.openmmlab.com/mmclassification/v1/minigpt4/minigpt-4_linear_baichuan7b_20231011-5dca7ed6.pth
+ Config: configs/minigpt4/minigpt-4_baichuan-7b_caption.py
diff --git a/configs/minigpt4/minigpt-4_baichuan-7b_caption.py b/configs/minigpt4/minigpt-4_baichuan-7b_caption.py
new file mode 100644
index 00000000000..7e610a099c8
--- /dev/null
+++ b/configs/minigpt4/minigpt-4_baichuan-7b_caption.py
@@ -0,0 +1,190 @@
+_base_ = [
+ '../_base_/default_runtime.py',
+]
+
+data_preprocessor = dict(
+ type='MultiModalDataPreprocessor',
+ mean=[122.770938, 116.7460125, 104.09373615],
+ std=[68.5005327, 66.6321579, 70.32316305],
+ to_rgb=True,
+)
+
+# dataset settings
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='Resize',
+ scale=(224, 224),
+ interpolation='bicubic',
+ backend='pillow'),
+ dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+ dict(
+ type='CleanCaption',
+ keys='chat_content',
+ remove_chars='',
+ lowercase=False),
+ dict(
+ type='PackInputs',
+ algorithm_keys=['chat_content', 'lang'],
+ meta_keys=['image_id']),
+]
+
+train_dataloader = dict(
+ batch_size=2,
+ num_workers=4,
+ dataset=dict(
+ type='MiniGPT4Dataset',
+ data_root='YOUR_DATA_DIRECTORY',
+ ann_file='YOUR_DATA_FILE',
+ pipeline=train_pipeline),
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ collate_fn=dict(type='default_collate'),
+ drop_last=False,
+)
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='Resize',
+ scale=(224, 224),
+ interpolation='bicubic',
+ backend='pillow'),
+ dict(type='PackInputs', meta_keys=['image_id']),
+]
+
+test_evaluator = dict(
+ type='COCOCaption',
+ ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
+)
+
+test_dataloader = dict(
+ batch_size=1,
+ dataset=dict(
+ type='COCOCaption',
+ data_root='data/coco',
+ ann_file='annotations/coco_karpathy_val.json',
+ pipeline=test_pipeline))
+
+# model settings
+model = dict(
+ type='MiniGPT4',
+ vision_encoder=dict(
+ type='BEiTViT',
+ # eva-g without the final layer
+ arch=dict(
+ embed_dims=1408,
+ num_layers=39,
+ num_heads=16,
+ feedforward_channels=6144,
+ ),
+ img_size=224,
+ patch_size=14,
+ layer_scale_init_value=0.0,
+ frozen_stages=39,
+ use_abs_pos_emb=True,
+ use_rel_pos_bias=False,
+ final_norm=False,
+ use_shared_rel_pos_bias=False,
+ out_type='raw',
+ pretrained= # noqa
+ 'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_eva-g-p14_20230615-e908c021.pth' # noqa
+ ),
+ q_former_model=dict(
+ type='Qformer',
+ model_style='bert-base-uncased',
+ vision_model_width=1408,
+ add_cross_attention=True,
+ cross_attention_freq=2,
+ num_query_token=32,
+ pretrained= # noqa
+ 'https://download.openmmlab.com/mmpretrain/v1.0/minigpt4/minigpt-4_qformer_20230615-1dfa889c.pth' # noqa
+ ),
+ lang_encoder=dict(
+ type='AutoModelForCausalLM',
+ name_or_path='baichuan-inc/baichuan-7B',
+ trust_remote_code=True),
+ tokenizer=dict(
+ type='AutoTokenizer',
+ name_or_path='baichuan-inc/baichuan-7B',
+ trust_remote_code=True),
+ task='caption',
+ prompt_template=dict([('en', '###Ask: {} ###Answer: '),
+ ('zh', '###问:{} ###答:')]),
+ raw_prompts=dict([
+ ('en', [('<Img><ImageHere></Img> '
+ 'Describe this image in detail.'),
+ ('<Img><ImageHere></Img> '
+ 'Take a look at this image and describe what you notice.'),
+ ('<Img><ImageHere></Img> '
+ 'Please provide a detailed description of the picture.'),
+ ('<Img><ImageHere></Img> '
+ 'Could you describe the contents of this image for me?')]),
+ ('zh', [('<Img><ImageHere></Img> '
+ '详细描述这张图片。'), ('<Img><ImageHere></Img> '
+ '浏览这张图片并描述你注意到什么。'),
+ ('<Img><ImageHere></Img> '
+ '请对这张图片进行详细的描述。'),
+ ('<Img><ImageHere></Img> '
+ '你能为我描述这张图片的内容吗?')])
+ ]),
+ max_txt_len=160,
+ end_sym='###')
+
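+# DeepSpeed ZeRO stage 2 shards optimizer states and gradients across ranks;
+# loss_scale=0 selects dynamic fp16 loss scaling.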
+strategy = dict(
+ type='DeepSpeedStrategy',
+ fp16=dict(
+ enabled=True,
+ auto_cast=False,
+ fp16_master_weights_and_grads=False,
+ loss_scale=0,
+ loss_scale_window=1000,
+ hysteresis=1,
+ min_loss_scale=1,
+ initial_scale_power=16,
+ ),
+ inputs_to_half=[0],
+ zero_optimization=dict(
+ stage=2,
+ allgather_partitions=True,
+ allgather_bucket_size=2e8,
+ reduce_scatter=True,
+ reduce_bucket_size='auto',
+ overlap_comm=True,
+ contiguous_gradients=True,
+ ),
+)
+
+# schedule settings
+optim_wrapper = dict(
+ type='DeepSpeedOptimWrapper',
+ optimizer=dict(type='AdamW', lr=1e-3, weight_decay=0.05))
+
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1e-3 / 500,
+ by_epoch=False,
+ begin=0,
+ end=500,
+ ),
+ dict(
+ type='CosineAnnealingLR',
+ eta_min=2e-4,
+ by_epoch=False,
+ begin=500,
+ ),
+]
+
+train_cfg = dict(by_epoch=True, max_epochs=6)
+test_cfg = dict()
+
+runner_type = 'FlexibleRunner'
+
+default_hooks = dict(
+ checkpoint=dict(
+ type='CheckpointHook',
+ interval=1,
+ by_epoch=True,
+ save_last=True,
+ max_keep_ckpts=1,
+ ))
diff --git a/configs/minigpt4/minigpt-4_vicuna-7b_caption.py b/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
index 704760af4e9..f468e2d8fac 100644
--- a/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
+++ b/configs/minigpt4/minigpt-4_vicuna-7b_caption.py
@@ -55,13 +55,25 @@
type='AutoModelForCausalLM', name_or_path='YOUR_PATH_TO_VICUNA'),
tokenizer=dict(type='LlamaTokenizer', name_or_path='YOUR_PATH_TO_VICUNA'),
task='caption',
- prompt_template='###Human: {} ###Assistant: ',
- raw_prompts=[
- '<Img><ImageHere></Img> Describe this image in detail.',
- '<Img><ImageHere></Img> Take a look at this image and describe what you notice.', # noqa
- '<Img><ImageHere></Img> Please provide a detailed description of the picture.', # noqa
- '<Img><ImageHere></Img> Could you describe the contents of this image for me?', # noqa
- ],
+ prompt_template=dict([('en', '###Ask: {} ###Answer: '),
+ ('zh', '###问:{} ###答:')]),
+ raw_prompts=dict([
+ ('en', [('<Img><ImageHere></Img> '
+ 'Describe this image in detail.'),
+ ('<Img><ImageHere></Img> '
+ 'Take a look at this image and describe what you notice.'),
+ ('<Img><ImageHere></Img> '
+ 'Please provide a detailed description of the picture.'),
+ ('<Img><ImageHere></Img> '
+ 'Could you describe the contents of this image for me?')]),
+ ('zh', [('<Img><ImageHere></Img> '
+ '详细描述这张图片。'), ('<Img><ImageHere></Img> '
+ '浏览这张图片并描述你注意到什么。'),
+ ('<Img><ImageHere></Img> '
+ '请对这张图片进行详细的描述。'),
+ ('<Img><ImageHere></Img> '
+ '你能为我描述这张图片的内容吗?')])
+ ]),
max_txt_len=160,
end_sym='###')
diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py
index 29753d7070b..e621e157714 100644
--- a/mmpretrain/datasets/__init__.py
+++ b/mmpretrain/datasets/__init__.py
@@ -43,6 +43,7 @@
from .gqa_dataset import GQA
from .iconqa import IconQA
from .infographic_vqa import InfographicVQA
+ from .minigpt4_dataset import MiniGPT4Dataset
from .nocaps import NoCaps
from .ocr_vqa import OCRVQA
from .refcoco import RefCOCO
@@ -56,5 +57,6 @@
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
- 'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA'
+ 'VSR', 'VizWiz', 'OCRVQA', 'InfographicVQA', 'IconQA',
+ 'MiniGPT4Dataset'
])
diff --git a/mmpretrain/datasets/minigpt4_dataset.py b/mmpretrain/datasets/minigpt4_dataset.py
new file mode 100644
index 00000000000..e14e5c354e2
--- /dev/null
+++ b/mmpretrain/datasets/minigpt4_dataset.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import mmengine
+from mmengine.dataset import BaseDataset
+from mmengine.fileio import get_file_backend
+
+from mmpretrain.registry import DATASETS
+
+
+@DATASETS.register_module()
+class MiniGPT4Dataset(BaseDataset):
+ """Dataset for training MiniGPT4.
+
+ MiniGPT4 dataset directory:
+
+ minigpt4_dataset
+ ├── image
+ │ ├── id0.jpg
+ │ │── id1.jpg
+ │ │── id2.jpg
+ │ └── ...
+ └── conversation_data.json
+
+ The structure of conversation_data.json:
+
+ [
+ // English data
+ {
+ "id": str(id0),
+ "conversation": "###Ask: [Ask content]
+ ###Answer: [Answer content]"
+ },
+
+ // Chinese data
+ {
+ "id": str(id1),
+ "conversation": "###问: [Ask content]
+ ###答:[Answer content]"
+ },
+
+ ...
+ ]
+
+ Args:
+ data_root (str): The root directory for ``ann_file`` and ``image``.
+ ann_file (str): Conversation file path.
+ **kwargs: Other keyword arguments in :class:`BaseDataset`.
+ """
+
+ def load_data_list(self) -> List[dict]:
+ file_backend = get_file_backend(self.data_root)
+ conversation_path = file_backend.join_path(self.data_root,
+ self.ann_file)
+ conversation = mmengine.load(conversation_path)
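+ # Map each unique conversation id to a contiguous integer index,
+ # which is packed as the sample's image_id.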
+ img_ids = {}
+ n = 0
+ for conv in conversation:
+ img_id = conv['id']
+ if img_id not in img_ids.keys():
+ img_ids[img_id] = n
+ n += 1
+
+ img_root = file_backend.join_path(self.data_root, 'image')
+ data_list = []
+ for conv in conversation:
+ img_file = '{}.jpg'.format(conv['id'])
+ chat_content = conv['conversation']
+ lang = 'en' if chat_content.startswith('###Ask: ') else 'zh'
+ data_info = {
+ 'image_id': img_ids[conv['id']],
+ 'img_path': file_backend.join_path(img_root, img_file),
+ 'chat_content': chat_content,
+ 'lang': lang,
+ }
+
+ data_list.append(data_info)
+
+ return data_list
diff --git a/mmpretrain/models/multimodal/minigpt4/minigpt4.py b/mmpretrain/models/multimodal/minigpt4/minigpt4.py
index eccbb27ef14..d25d0b6be36 100644
--- a/mmpretrain/models/multimodal/minigpt4/minigpt4.py
+++ b/mmpretrain/models/multimodal/minigpt4/minigpt4.py
@@ -31,12 +31,12 @@ class MiniGPT4(BaseModel):
True.
num_query_token (int): Number of query tokens of Qformer. Defaults to
32.
- prompt_template (str): Prompt template of the model. Defaults to
- '###Human: {} ###Assistant: '.
- raw_prompts (list): Prompts for training. Defaults to None.
+ prompt_template (dict): Multi-language prompt template of the model.
+ Defaults to dict([('en', '###Ask: {} ###Answer: '),
+ ('zh', '###问:{} ###答:')]).
+ raw_prompts (dict): Prompts for training. Defaults to dict().
max_txt_len (int): Max token length while doing tokenization. Defaults
to 32.
- end_sym (str): Ended symbol of the sequence. Defaults to '\\n'.
+ end_sym (str): End symbol of the sequence. Defaults to '###'.
generation_cfg (dict): The config of text generation. Defaults to
dict().
data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
@@ -54,10 +54,12 @@ def __init__(self,
freeze_vit: bool = True,
freeze_q_former: bool = True,
num_query_token: int = 32,
- prompt_template: str = '###Human: {} ###Assistant: ',
- raw_prompts: Optional[list] = None,
+ prompt_template: dict = dict([('en',
+ '###Ask: {} ###Answer: '),
+ ('zh', '###问:{} ###答:')]),
+ raw_prompts: dict = dict(),
max_txt_len: int = 32,
- end_sym: str = '\n',
+ end_sym: str = '###',
generation_cfg: dict = dict(),
data_preprocessor: Optional[dict] = None,
init_cfg: Optional[dict] = None):
@@ -135,16 +137,23 @@ def __init__(self,
self.end_token_id = self.llama_tokenizer.encode(end_sym)[-1]
# set prompts
- if raw_prompts is not None:
- filted_prompts = [
- raw_prompt for raw_prompt in raw_prompts
+ self.en_prompt_list, self.zh_prompt_list = [], []
+ if raw_prompts.get('en') is not None:
+ en_filtered_prompts = [
+ raw_prompt for raw_prompt in raw_prompts['en']
if '<ImageHere>' in raw_prompt
]
- self.prompt_list = [
- prompt_template.format(p) for p in filted_prompts
+ self.en_prompt_list = [
+ prompt_template['en'].format(p) for p in en_filtered_prompts
+ ]
+ if raw_prompts.get('zh') is not None:
+ zh_filtered_prompts = [
+ raw_prompt for raw_prompt in raw_prompts['zh']
+ if '<ImageHere>' in raw_prompt
+ ]
+ self.zh_prompt_list = [
+ prompt_template['zh'].format(p) for p in zh_filtered_prompts
]
- else:
- self.prompt_list = []
# update generation configs
self.generation_cfg = dict(
@@ -153,7 +162,7 @@ def __init__(self,
do_sample=True,
min_length=1,
top_p=0.9,
- repetition_penalty=1.0,
+ repetition_penalty=1.1,
length_penalty=1.0,
temperature=1.0)
self.generation_cfg.update(**generation_cfg)
@@ -161,6 +170,10 @@ def __init__(self,
if hasattr(self, 'register_load_state_dict_post_hook'):
self.register_load_state_dict_post_hook(self._load_llama_proj_hook)
+ def half(self):
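+ """Cast only the language model to fp16; other modules keep their dtype."""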
+ self.llama_model = self.llama_model.half()
+ return self
+
def encode_img(self,
images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""The function to encode the images."""
@@ -184,33 +197,39 @@ def encode_img(self,
return inputs_llama, atts_llama
def prompt_wrap(self, img_embeds: torch.Tensor, atts_img: torch.Tensor,
- prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
+ prompt: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
"""The function to wrap the image and prompt.
- Currently, the function only supports applying one prompt to all input
- images in the one batch.
+ The number of prompts must equal the batch size, i.e.
+ len(prompt) == img_embeds.shape[0].
Args:
img_embeds (torch.Tensor): The embedding of the input images.
atts_img (torch.Tensor): Attention map of the image embeddings.
- prompt (str): The prompt of the batch data.
+ prompt (List[str]): The prompt of the batch data.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The embedding and attention map.
"""
- if prompt:
- batch_size = img_embeds.shape[0]
- p_before, p_after = prompt.split('<ImageHere>')
+ if len(prompt) > 0:
+ p_before_list, p_after_list = [], []
+ for pro in prompt:
+ p_before, p_after = pro.split('<ImageHere>')
+ p_before_list.append(p_before)
+ p_after_list.append(p_after)
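+ # Tokenize the prompt halves of the whole batch at once; padding to the
+ # longest half keeps the tensors rectangular, so every image embedding
+ # is wrapped by its own per-sample prompt.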
p_before_tokens = self.llama_tokenizer(
- p_before, return_tensors='pt',
+ p_before_list,
+ return_tensors='pt',
+ padding='longest',
add_special_tokens=False).to(img_embeds.device)
p_after_tokens = self.llama_tokenizer(
- p_after, return_tensors='pt',
+ p_after_list,
+ return_tensors='pt',
+ padding='longest',
add_special_tokens=False).to(img_embeds.device)
p_before_embeds = self.llama_model.model.embed_tokens(
- p_before_tokens.input_ids).expand(batch_size, -1, -1)
+ p_before_tokens.input_ids)
p_after_embeds = self.llama_model.model.embed_tokens(
- p_after_tokens.input_ids).expand(batch_size, -1, -1)
+ p_after_tokens.input_ids)
wrapped_img_embeds = torch.cat(
[p_before_embeds, img_embeds, p_after_embeds], dim=1)
wrapped_atts_img = atts_img[:, :1].expand(
@@ -234,17 +253,22 @@ def loss(self,
"""
img_embeds, atts_img = self.encode_img(images)
- if self.task == 'caption' and self.prompt_list:
- prompt = random.choice(self.prompt_list)
- img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
- prompt)
-
self.llama_tokenizer.padding_side = 'right'
- text = [t + self.end_sym for t in data_samples['text_input']]
+ prompts, texts = [], []
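+ # Each chat_content carries both the question and the reference answer;
+ # split at the language-specific answer marker so the question wraps the
+ # image and only the answer tokens are used as regression targets.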
+ for t in data_samples:
+ chat_content = t.chat_content
+ split_mark = '###Answer: ' if t.lang == 'en' else '###答:'
+ prompt, text = chat_content.split(split_mark)
+ prompt += split_mark
+ text += self.end_sym
+ prompts.append(prompt)
+ texts.append(text)
+
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
to_regress_tokens = self.llama_tokenizer(
- text,
+ texts,
return_tensors='pt',
padding='longest',
truncation=True,
@@ -295,10 +319,12 @@ def predict(
with torch.no_grad():
img_embeds, atts_img = self.encode_img(images)
- if self.task == 'caption' and self.prompt_list:
- prompt = random.choice(self.prompt_list)
- img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
- prompt)
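+ # Pick a random prompt in each sample's language; samples without a
+ # 'lang' field fall back to the English prompt list.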
+ prompts = [
+ random.choice(self.zh_prompt_list) if hasattr(t, 'lang')
+ and t.lang == 'zh' else random.choice(self.en_prompt_list)
+ for t in data_samples
+ ]
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompts)
batch_size = img_embeds.shape[0]
bos = torch.ones(
@@ -336,7 +362,6 @@ def post_process(
for output, data_sample in zip(outputs, data_samples):
if self.task == 'caption':
output = output.split('###')[0]
- output = output.split('Assistant:')[-1].strip()
data_sample.pred_caption = output
else:
# raw output
diff --git a/projects/gradio_demo/conversation.py b/projects/gradio_demo/conversation.py
new file mode 100644
index 00000000000..3c5946900b0
--- /dev/null
+++ b/projects/gradio_demo/conversation.py
@@ -0,0 +1,137 @@
+# Modified from
+# https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
+import dataclasses
+from typing import List
+
+import torch
+
+
+@dataclasses.dataclass
+class Conversation:
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ sep: str = '###'
+
+ def get_prompt(self):
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ': ' + message + self.sep
+ else:
+ ret += role + ':'
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=[role for role in self.roles],
+ messages=[[y for y in x] for x in self.messages],
+ sep=self.sep,
+ )
+
+ def dict(self):
+ return {
+ 'system': self.system,
+ 'roles': self.roles,
+ 'messages': self.messages,
+ 'sep': self.sep,
+ }
+
+
+EN_CONV_VISION = Conversation(
+ system='Give the following image. '
+ 'You will be able to see the image once I provide it to you. '
+ 'Please answer my questions in detail.',
+ roles=['Ask', 'Answer'],
+ messages=[],
+ sep='###',
+)
+
+ZH_CONV_VISION = Conversation(
+ system='给定一张图片,请仔细观察这张图片,并回答我的问题。',
+ roles=['问', '答'],
+ messages=[],
+ sep='###',
+)
+
+
+class Chat:
+
+ def __init__(self, inferencer, device, is_half=False):
+ self.device = device
+ self.inferencer = inferencer
+ self.model = inferencer.model
+ self.is_half = is_half
+ if is_half:
+ self.model = self.model.half()
+ self.model = self.model.to(device)
+ self.max_length = 2000
+
+ def upload_img(self, image, conv, img_list):
+ img = next(self.inferencer.preprocess([image]))
+ img = self.model.data_preprocessor(img, False)['images']
+ img = img.to(self.device)
+ image_emb, _ = self.model.encode_img(img)
+ img_list.append(image_emb)
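+ # Append the image placeholder as a message; get_context_emb() later
+ # splices the image embedding in at this position.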
+ conv.append_message(conv.roles[0], '<Img><ImageHere></Img>')
+
+ def get_context_emb(self, conv, img_list):
+ prompt = conv.get_prompt()
+ prompt_segs = prompt.split('<ImageHere>')
+ seg_tokens = [
+ self.model.llama_tokenizer(
+ seg, return_tensors='pt',
+ add_special_tokens=(i == 0)).to(self.device).input_ids
+ for i, seg in enumerate(prompt_segs)
+ ]
+ seg_embs = [
+ self.model.llama_model.model.embed_tokens(seg_token)
+ for seg_token in seg_tokens
+ ]
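+ # Interleave text-segment embeddings with the uploaded image embeddings:
+ # text, image, text, image, ..., trailing text segment.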
+ mixed_embs = [
+ emb for pair in zip(seg_embs[:-1], img_list) for emb in pair
+ ] + [seg_embs[-1]]
+ mixed_embs = torch.cat(mixed_embs, dim=1)
+ return mixed_embs
+
+ def ask(self, text, conv):
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[
+ 0] and conv.messages[-1][1][-6:] == '</Img>':
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
+ else:
+ conv.append_message(conv.roles[0], text)
+
+ def answer(self, conv, img_list, generation_cfg):
+ conv.append_message(conv.roles[1], None)
+ embs = self.get_context_emb(conv, img_list)
+ cur_max_len = generation_cfg['max_new_tokens'] + embs.shape[1]
+ if cur_max_len > self.max_length:
+ print('Warning: The number of tokens in the current conversation '
+ 'exceeds the max length. '
+ 'The model will not see the contexts outside the range.')
+ begin_idx = max(0, cur_max_len - self.max_length)
+ embs = embs[:, begin_idx:]
+ if self.is_half:
+ embs = embs.half()
+ outputs = self.model.llama_model.generate(
+ inputs_embeds=embs,
+ eos_token_id=self.model.end_token_id,
+ **generation_cfg)
+
+ output_token = outputs[0]
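+ # Drop a leading <unk> (id 0) or BOS <s> (id 1) token that LLaMA-family
+ # tokenizers sometimes emit at the start of a generation.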
+ if output_token[0] == 0:
+ output_token = output_token[1:]
+ elif output_token[0] == 1:
+ output_token = output_token[1:]
+ output_text = self.model.llama_tokenizer.decode(
+ output_token,
+ add_special_tokens=False,
+ skip_special_tokens=True)
+ output_text = output_text.split('###')[0]
+ conv.messages[-1][1] = output_text
+ return output_text
diff --git a/projects/gradio_demo/minigpt4_demo.py b/projects/gradio_demo/minigpt4_demo.py
new file mode 100644
index 00000000000..e4d61426fa7
--- /dev/null
+++ b/projects/gradio_demo/minigpt4_demo.py
@@ -0,0 +1,144 @@
+import argparse
+
+import gradio as gr
+import numpy as np
+import torch
+from conversation import EN_CONV_VISION, ZH_CONV_VISION, Chat
+
+from mmpretrain import ImageCaptionInferencer
+
+parser = argparse.ArgumentParser(description='MiniGPT4 demo')
+parser.add_argument(
+ 'cfg', type=str, help='config file for minigpt4 (absolute path)')
+parser.add_argument(
+ 'ckpt', type=str, help='pretrained file for minigpt4 (absolute path)')
+args = parser.parse_args()
+
+if torch.cuda.is_available():
+ devices = [
+ torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
+ ]
+elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+ devices = [torch.device('mps')]
+else:
+ devices = [torch.device('cpu')]
+
+
+def get_free_device():
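+ # Prefer the CUDA device with the most free memory; if that cannot be
+ # queried, fall back to a random choice among the available devices.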
+ if torch.cuda.is_available() and hasattr(torch.cuda, 'mem_get_info'):
+ free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
+ select = max(zip(free, range(len(free))))[1]
+ else:
+ import random
+ select = random.randint(0, len(devices) - 1)
+ return devices[select]
+
+
+device = get_free_device()
+inferencer = ImageCaptionInferencer(model=args.cfg, pretrained=args.ckpt)
+model = inferencer.model
+chat = Chat(inferencer, device=device, is_half=(device.type != 'cpu'))
+
+
+def reset(chat_state, img_list):
+ if chat_state is not None:
+ chat_state.messages = []
+ if img_list is not None:
+ img_list = []
+ return (None, gr.update(value=None, interactive=True),
+ gr.update(
+ value=None,
+ placeholder='Please upload your image first',
+ interactive=False),
+ gr.update(value='Upload & Start Chat',
+ interactive=True), chat_state, img_list,
+ gr.update(value='Restart', interactive=False),
+ gr.update(value='English', interactive=True))
+
+
+def upload_img(gr_img, language, chat_state):
+ if gr_img is None:
+ return (None,
+ gr.update(
+ placeholder='Please upload your image first',
+ interactive=False),
+ gr.update(value='Upload & Start Chat',
+ interactive=True), chat_state, None,
+ gr.update(value='Restart', interactive=False),
+ gr.update(value='English', interactive=True))
+
+ if (language == 'English'):
+ chat_state = EN_CONV_VISION.copy()
+ else:
+ chat_state = ZH_CONV_VISION.copy()
+ img_list = []
+ gr_img_array = np.asarray(gr_img)
+ chat.upload_img(gr_img_array, chat_state, img_list)
+ return (gr.update(interactive=False),
+ gr.update(placeholder='Type and press Enter', interactive=True),
+ gr.update(value='Start Chatting',
+ interactive=False), chat_state, img_list,
+ gr.update(value='Restart',
+ interactive=True), gr.update(interactive=False))
+
+
+def ask(user_message, chatbot, chat_state):
+ if (len(user_message) == 0):
+ return gr.update(
+ value=None,
+ placeholder='Input should not be empty!',
+ interactive=True), chatbot, chat_state
+ chat.ask(user_message, chat_state)
+ chatbot = chatbot + [[user_message, None]]
+ return '', chatbot, chat_state
+
+
+def answer(chatbot, chat_state, img_list):
+ llm_message = chat.answer(
+ conv=chat_state,
+ img_list=img_list,
+ generation_cfg=model.generation_cfg)
+ chatbot[-1][1] = llm_message
+ return chatbot, chat_state, img_list
+
+
+if __name__ == '__main__':
+ title = 'MMPretrain MiniGPT-4 Inference Demo'
+ with gr.Blocks(analytics_enabled=False, title=title) as demo:
+ gr.Markdown(f'# {title}')
+ with gr.Row():
+ with gr.Column():
+ image = gr.Image(type='pil')
+ language = gr.Dropdown(['English', 'Chinese'],
+ label='Language',
+ info='Select chatbot\'s language',
+ value='English',
+ interactive=True)
+ upload_button = gr.Button(
+ value='Upload & Start Chat', interactive=True)
+ clear = gr.Button(value='Restart', interactive=False)
+
+ with gr.Column():
+ chat_state = gr.State()
+ img_list = gr.State()
+ chatbot = gr.Chatbot(
+ label='MiniGPT-4', min_width=320, height=600)
+ text_input = gr.Textbox(
+ label='User',
+ placeholder='Please upload your image first',
+ interactive=False)
+
+ upload_button.click(upload_img, [image, language, chat_state], [
+ image, text_input, upload_button, chat_state, img_list, clear,
+ language
+ ])
+ text_input.submit(ask, [text_input, chatbot, chat_state],
+ [text_input, chatbot, chat_state]).then(
+ answer, [chatbot, chat_state, img_list],
+ [chatbot, chat_state, img_list])
+ clear.click(reset, [chat_state, img_list], [
+ chatbot, image, text_input, upload_button, chat_state, img_list,
+ clear, language
+ ])
+
+ demo.launch(share=True)