Thanks for your hard work. I tried to conduct batch inference but encountered some errors. My code looks like:

```python
prompts = tokenizer(test_dataset, return_tensors='pt', padding=True, truncation=True)
gen_tokens = model.generate(
    **prompts,
    do_sample=False,
    max_new_tokens=30,
)
gen_text = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
```
The error message says something about "reporting a bug to PyTorch". I think the problem roots in `hidden_states.to(torch.float32)`. I see that your evaluation code only has `inference_on_one`. Can you provide more guidance on batch inference?
Thank you for your time and consideration.
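For reference, a minimal sketch of batched greedy generation with a decoder-only model. The model name below is a stand-in tiny checkpoint, not the repo's model; the two changes from the snippet above that commonly fix batched `generate` crashes are switching the tokenizer to left padding (decoder-only models otherwise start generating from pad tokens) and moving the tokenized batch onto the model's device:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint for illustration; substitute your own model.
model_name = "sshleifer/tiny-gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"  # critical for batched decoder-only generation
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-style tokenizers have no pad token

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

batch = ["first prompt", "second prompt"]
prompts = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
prompts = {k: v.to(model.device) for k, v in prompts.items()}

with torch.no_grad():
    gen_tokens = model.generate(
        **prompts,
        do_sample=False,
        max_new_tokens=30,
        pad_token_id=tokenizer.pad_token_id,  # silences the missing-pad warning
    )
gen_text = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
```

Whether this also cures the `hidden_states.to(torch.float32)` crash depends on the quantization setup, but left padding plus correct device placement is the usual first step.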
Sorry for missing some critical information. I am using QLoRA. Here is my configuration:

```python
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```
I have 1460 test samples. Without batch inference, evaluation can take up to 46 minutes on a single RTX 3090.
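With that config, a sketch of how the 4-bit base model is typically loaded for inference. The checkpoint name and adapter path below are placeholders, not the repo's actual values:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Placeholder names: substitute your base checkpoint and saved LoRA adapter.
model = AutoModelForCausalLM.from_pretrained(
    "your-base-model",
    quantization_config=bnb_config,
    device_map="auto",  # places the quantized weights on the GPU
)
# Attach the QLoRA adapter trained on top of the quantized base:
# from peft import PeftModel
# model = PeftModel.from_pretrained(model, "path/to/lora-adapter")
model.eval()
```

Note that with `bnb_4bit_compute_dtype=torch.bfloat16`, intermediate activations stay in bf16; an explicit `hidden_states.to(torch.float32)` cast in the forward path can interact badly with batched inputs on quantized layers, which may be the source of the crash.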