Example code:

```python
import torch
import torch.nn.functional as F
from transformers import LlamaTokenizer, LlamaForCausalLM
from sageattention import sageattn

# Replace PyTorch's scaled-dot-product attention with SageAttention
F.scaled_dot_product_attention = sageattn

# Load the pretrained LLaMA model and tokenizer
model_name = "llama-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Move the model to GPU (if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Prepare the input text
input_text = "Once upon a time, there was a little girl"
inputs = tokenizer(input_text, return_tensors="pt").to(device)

# Run inference
with torch.no_grad():
    output = model.generate(**inputs, max_length=50, num_return_sequences=1)

# Decode the output
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)
```
Device environment:

torch 2.5.0
triton 3.1.0
transformers 4.45.2
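For reference, a quick way to confirm these versions at runtime (a minimal sketch, not part of the original report):

```python
# Print the versions of the packages listed above.
import torch
import triton
import transformers

print("torch:", torch.__version__)                 # expected 2.5.0
print("triton:", triton.__version__)               # expected 3.1.0
print("transformers:", transformers.__version__)   # expected 4.45.2
```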
Expected output:

Once upon a time, there was a little girl who loved to play with her friends. One day, she decided to play with her friends in the forest. She was very happy. She played with her friends in the forest. She played with

but got:
Once upon a time, there was a little girl whole and the 1882P a.
ficenda2P avalN64YourEm.
ficOnDe Ce the GISP.
gev
We have not tested the accuracy of using `F.scaled_dot_product_attention = sageattn` in Llama.
As a suggestion, you could try replacing the Llama attention with SageAttention in modeling_llama.py.
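As a hedged alternative to editing modeling_llama.py directly, here is a minimal, untested sketch of a signature-compatible adapter (the name `sdpa_to_sage` is ours, not from the SageAttention repo). It assumes `sageattn` accepts `(batch, heads, seq_len, head_dim)` FP16 tensors plus an `is_causal` flag, and it routes anything SageAttention may not handle (additive masks, dropout, a custom scale) back to the stock kernel:

```python
import torch.nn.functional as F
from sageattention import sageattn

# Keep a handle to the stock PyTorch kernel so we can fall back to it.
_orig_sdpa = F.scaled_dot_product_attention

def sdpa_to_sage(query, key, value, attn_mask=None, dropout_p=0.0,
                 is_causal=False, scale=None, **kwargs):
    """Hypothetical adapter matching F.scaled_dot_product_attention's
    keyword signature. Only unmasked, purely causal, no-dropout calls
    are offloaded to sageattn; everything else uses the original kernel."""
    if attn_mask is not None or dropout_p != 0.0 or scale is not None or kwargs:
        return _orig_sdpa(query, key, value, attn_mask=attn_mask,
                          dropout_p=dropout_p, is_causal=is_causal,
                          scale=scale, **kwargs)
    return sageattn(query, key, value, is_causal=is_causal)

F.scaled_dot_product_attention = sdpa_to_sage
```

With this in place the script above can stay unchanged, and it may help isolate whether the garbled output comes from the kernel itself or from the blanket `F.scaled_dot_product_attention = sageattn` assignment, whose signature does not match PyTorch's.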