Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make model type backwards compatible #1212

Merged
merged 13 commits on Feb 1, 2025
54 changes: 37 additions & 17 deletions examples/python/model-chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,39 @@ def main(args):
search_options['batch_size'] = 1

if args.verbose: print(search_options)

# Get model type
model_type = None
if hasattr(model, "type"):
model_type = model.type
else:
import json, os

with open(os.path.join(args.model_path, "genai_config.json"), "r") as f:
genai_config = json.load(f)
model_type = genai_config["model"]["type"]

# Set chat template
default_chat_template = ""
if args.chat_template:
if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
else:
if model.type.startswith("phi2") or model.type.startswith("phi3"):
elif args.chat_template == default_chat_template:
if model_type.startswith("phi2") or model_type.startswith("phi3"):
args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
elif model.type.startswith("phi4"):
elif model_type.startswith("phi4"):
args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>'
elif model.type.startswith("llama3"):
elif model_type.startswith("llama3"):
args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
elif model.type.startswith("llama2"):
elif model_type.startswith("llama2"):
args.chat_template = '<s>{input}'
elif model_type.startswith("qwen2"):
args.chat_template = '<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n'
else:
raise ValueError(f"Chat Template for model type {model.type} is not known. Please provide chat template using --chat_template")
raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template")
baijumeswani marked this conversation as resolved.

if args.verbose:
print("Model type is:", model.type)
print("Model type is:", model_type)
print("Chat Template is:", args.chat_template)

params = og.GeneratorParams(model)
Expand All @@ -55,16 +70,21 @@ def main(args):
if args.verbose: print("Generator created")

# Set system prompt
if model.type.startswith('phi2') or model.type.startswith('phi3'):
system_prompt = f"<|system|>\n{args.system_prompt}<|end|>"
elif model.type.startswith('phi4'):
system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>"
elif model.type.startswith("llama3"):
system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
elif model.type.startswith("llama2"):
system_prompt = f"<s>[INST] <<SYS>>\n{args.system_prompt}\n<</SYS>>"
else:
system_prompt = args.system_prompt
default_system_prompt = "You are a helpful assistant."
if args.system_prompt == default_system_prompt:
if model_type.startswith('phi2') or model_type.startswith('phi3'):
system_prompt = f"<|system|>\n{args.system_prompt}<|end|>"
elif model_type.startswith('phi4'):
system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>"
elif model_type.startswith("llama3"):
system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
elif model_type.startswith("llama2"):
system_prompt = f"<s>[INST] <<SYS>>\n{args.system_prompt}\n<</SYS>>"
elif model_type.startswith("qwen2"):
qwen_system_prompt = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
system_prompt = f"<|im_start|>system\n{qwen_system_prompt}<|im_end|>\n"
else:
system_prompt = args.system_prompt

system_tokens = tokenizer.encode(system_prompt)
generator.append_tokens(system_tokens)
Expand Down
52 changes: 36 additions & 16 deletions examples/python/model-qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,37 +26,57 @@ def main(args):
search_options['batch_size'] = 1

if args.verbose: print(search_options)

# Get model type
model_type = None
if hasattr(model, "type"):
model_type = model.type
else:
import json, os

with open(os.path.join(args.model_path, "genai_config.json"), "r") as f:
genai_config = json.load(f)
model_type = genai_config["model"]["type"]

# Set chat template
default_chat_template = ""
if args.chat_template:
if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
else:
if model.type.startswith("phi2") or model.type.startswith("phi3"):
elif args.chat_template == default_chat_template:
if model_type.startswith("phi2") or model_type.startswith("phi3"):
args.chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
elif model.type.startswith("phi4"):
elif model_type.startswith("phi4"):
args.chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>'
elif model.type.startswith("llama3"):
elif model_type.startswith("llama3"):
args.chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
elif model.type.startswith("llama2"):
elif model_type.startswith("llama2"):
args.chat_template = '<s>{input}'
elif model_type.startswith("qwen2"):
args.chat_template = '<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n'
else:
raise ValueError(f"Chat Template for model type {model.type} is not known. Please provide chat template using --chat_template")
raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template")

params = og.GeneratorParams(model)
params.set_search_options(**search_options)
generator = og.Generator(model, params)

# Set system prompt
if model.type.startswith('phi2') or model.type.startswith('phi3'):
system_prompt = f"<|system|>\n{args.system_prompt}<|end|>"
elif model.type.startswith('phi4'):
system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>"
elif model.type.startswith("llama3"):
system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
elif model.type.startswith("llama2"):
system_prompt = f"<s>[INST] <<SYS>>\n{args.system_prompt}\n<</SYS>>"
else:
system_prompt = args.system_prompt
default_system_prompt = "You are a helpful assistant."
if args.system_prompt == default_system_prompt:
if model_type.startswith('phi2') or model_type.startswith('phi3'):
system_prompt = f"<|system|>\n{args.system_prompt}<|end|>"
elif model_type.startswith('phi4'):
system_prompt = f"<|im_start|>system<|im_sep|>\n{args.system_prompt}<|im_end|>"
elif model_type.startswith("llama3"):
system_prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{args.system_prompt}<|eot_id|>"
elif model_type.startswith("llama2"):
system_prompt = f"<s>[INST] <<SYS>>\n{args.system_prompt}\n<</SYS>>"
elif model_type.startswith("qwen2"):
qwen_system_prompt = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
system_prompt = f"<|im_start|>system\n{qwen_system_prompt}<|im_end|>\n"
else:
system_prompt = args.system_prompt

system_tokens = tokenizer.encode(system_prompt)
generator.append_tokens(system_tokens)
Expand Down
Loading