
Multi-turn conversation template #404

Open
THU-Kingmin opened this issue Oct 15, 2024 · 1 comment

Comments

@THU-Kingmin

Requesting a multi-turn conversation template.

@gxlover0625

I wrapped the official code further into a Qwen2VL class and implemented the chat function myself.
The user only needs to pass a natural-language query, the images imgs given as URLs, local paths, or base64 strings (both single and multiple images are supported), and the previous conversation history. It is very easy to use.

Usage

  • Copy the code below
import gc

import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info


class Qwen2VL:
    def __init__(self, model_path=None, max_new_tokens=1024, min_pixels=256*28*28, max_pixels=1280*28*28):
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)
        self.gen_config = {
            "max_new_tokens": max_new_tokens,
        }
    
    def parse_input(self, query=None, imgs=None):
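        """Build a single-turn user message in the Qwen2-VL chat format.

        imgs may be one image or a list of images, each given as a URL,
        a local path, or a base64 data URI. For one image the result
        looks like:
        [{"role": "user", "content": [
            {"type": "image", "image": "https://.../demo.jpg"},
            {"type": "text", "text": "describe this image"}]}]
        """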
        if imgs is None:
            messages = [{"role": "user", "content": query}]
            return messages
        
        if isinstance(imgs, str):
            imgs = [imgs]
        content = []
        for img in imgs:
            content.append({"type": "image", "image": img})
        content.append({"type": "text", "text": query})
        messages = [{"role": "user", "content": content}]
        return messages

    def chat(self, query=None, imgs=None, history=None):
        # history is mutated in place (extend/append) and also returned,
        # so callers can thread it through successive turns.
        if history is None:
            history = []
            
        user_query = self.parse_input(query, imgs)
        history.extend(user_query)

        text = self.processor.apply_chat_template(history, tokenize=False, add_generation_prompt=True, add_vision_id=True)
        image_inputs, video_inputs = process_vision_info(history)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # device_map="auto" typically places the model on GPU; move the inputs there as well.
        inputs = inputs.to("cuda")
        generated_ids = self.model.generate(**inputs, **self.gen_config)
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        history.append({"role": "assistant", "content": response})

        # Free the generation tensors so GPU memory does not accumulate across turns.
        del inputs, generated_ids, generated_ids_trimmed
        torch.cuda.empty_cache()
        gc.collect()
        return response, history
  • Use the chat API
chat_model = Qwen2VL(model_path="local path/repo id")

# First turn
history = None
response, history = chat_model.chat(query="hello", history=history)
print(response, history)

# Second turn
# Supported image types: image_url, local_image_path, base64
# Supported image counts: [image], [image1, image2], ...
response, history = chat_model.chat(query="please describe the image", imgs=["image_url"], history=history)
print(response, history)
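
Since parse_input also accepts base64 images, an encoded local file can be passed the same way. A minimal sketch (the file name demo.jpg is a placeholder; the "data:image;base64,..." form follows the official Qwen2-VL README):

# Third turn: base64-encoded local image
import base64

with open("demo.jpg", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")

response, history = chat_model.chat(
    query="how does this image differ from the previous one?",
    imgs=[f"data:image;base64,{b64}"],
    history=history,
)
print(response)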

Execution result

(screenshot of the output omitted)
