-
Notifications
You must be signed in to change notification settings - Fork 162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
多轮对话模板 #404
Comments
我进一步封装了官方代码，使用方式如下：
class Qwen2VL:
    """Thin wrapper around Qwen2-VL for multi-turn chat with optional images.

    NOTE(review): relies on ``Qwen2VLForConditionalGeneration``,
    ``AutoProcessor`` and ``process_vision_info`` (Qwen2-VL / transformers
    stack) plus ``torch`` and ``gc`` being imported at module level — not
    shown in this snippet.
    """

    def __init__(self, model_path=None, max_new_tokens=1024, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28):
        """Load model and processor.

        Args:
            model_path: Local path or Hugging Face repo id.
            max_new_tokens: Cap on generated tokens per turn.
            min_pixels: Lower bound on image resolution fed to the processor.
            max_pixels: Upper bound on image resolution fed to the processor.
        """
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)
        # Extra kwargs forwarded to model.generate() on every turn.
        self.gen_config = {
            "max_new_tokens": max_new_tokens,
        }

    def parse_input(self, query=None, imgs=None):
        """Build a one-turn ``messages`` list from a text query and optional image(s).

        Args:
            query: User text prompt.
            imgs: None, a single image reference (str: URL / local path /
                base64), or a list of such references.

        Returns:
            A one-element list in the chat-template message format.
        """
        if imgs is None:
            return [{"role": "user", "content": query}]
        if isinstance(imgs, str):
            imgs = [imgs]
        # Images first, then the text prompt, matching the official examples.
        content = [{"type": "image", "image": img} for img in imgs]
        content.append({"type": "text", "text": query})
        return [{"role": "user", "content": content}]

    def chat(self, query=None, imgs=None, history=None):
        """Run one conversational turn.

        ``history`` is extended in place (user turn + assistant reply are
        appended) and also returned, so callers can thread it between turns.

        Args:
            query: User text prompt for this turn.
            imgs: Optional image reference(s) — see ``parse_input``.
            history: Prior messages list, or None to start a conversation.

        Returns:
            Tuple of (assistant response string, updated history list).
        """
        if history is None:
            history = []
        history.extend(self.parse_input(query, imgs))
        text = self.processor.apply_chat_template(
            history, tokenize=False, add_generation_prompt=True, add_vision_id=True
        )
        image_inputs, video_inputs = process_vision_info(history)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # BUG FIX: the original hard-coded "cuda"; with device_map="auto" the
        # model may live elsewhere (e.g. CPU) — follow the model's device.
        inputs = inputs.to(self.model.device)
        generated_ids = self.model.generate(**inputs, **self.gen_config)
        # Strip the prompt tokens so only the newly generated reply is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        history.append({"role": "assistant", "content": response})
        # Release per-turn tensors to keep GPU memory flat across turns.
        del inputs, generated_ids, generated_ids_trimmed
        torch.cuda.empty_cache()
        gc.collect()
        return response, history
# Example: a two-turn conversation with the wrapper above.
vlm = Qwen2VL(model_path="local path/repo id")

# Turn 1: text-only query, starting a fresh conversation.
response, history = vlm.chat(query="hello", history=None)
print(response, history)

# Turn 2: text plus image(s).
# An image may be given as a URL, a local path, or base64 data;
# pass one or several as a list: ["image"] or ["image1", "image2"].
response, history = vlm.chat(query="please describe the image", imgs=["image_url"], history=history)
print(response, history)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
求个多轮对话模板
The text was updated successfully, but these errors were encountered: