
Performance question regarding next token prediction task #77

Open

HeMuling opened this issue Sep 23, 2024 · 1 comment

@HeMuling
I tried to perform the next token prediction task using the pretrained model hyenadna-small-32k-seqlen-hf, and I found the results not very solid. Here's the code I tried:

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# instantiate the pretrained model and tokenizer
checkpoint = 'hyenadna-small-32k-seqlen-hf'
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, config=config)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, config=config)
model.eval()

# tokenize a short DNA sequence
seq = 'AGCTACATTGGCC'
tok_seq = tokenizer(seq)['input_ids']
print(tok_seq)

tok_seq = torch.LongTensor(tok_seq).unsqueeze(0)  # add a batch dimension
print(tokenizer.batch_decode(tok_seq))

# forward pass, then decode the argmax prediction at every position
with torch.no_grad():
    out = model(tok_seq)
print(tokenizer.batch_decode(out.logits.argmax(-1)))

and I get:

[7, 9, 8, 10, 7, 8, 7, 10, 10, 9, 9, 8, 8, 1]
['AGCTACATTGGCC[SEP]']

['AAATAAATTGTAAC']

In my understanding, I've set the model up to perform next token prediction, so if I input the sequence 'AGCTACATTGGCC', the model should return something like 'AGCTACATTGGCC' + new_predicted_token (i.e. keep most of the previous bases the same), but the sequence I get differs a lot from what I input. I wonder if there is anything wrong with my understanding or my code.
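For what it's worth, a causal LM's logit at position i is its prediction for the token at position i + 1, so decoding the argmax at every position gives the model's per-position next-token guesses, not a reconstruction of the input; only the last position is the continuation of the full sequence. A minimal sketch of the shifted comparison, assuming the model and tokenizer objects from the snippet above:

# logits at position i predict the token at position i + 1,
# so shift by one before comparing predictions to the input
with torch.no_grad():
    logits = model(tok_seq).logits

pred_ids = logits.argmax(-1)                  # per-position next-token guesses
matches = pred_ids[:, :-1] == tok_seq[:, 1:]  # align prediction i with input token i + 1
print(matches.float().mean().item())          # fraction of next tokens predicted correctly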

@CYorick
CYorick commented Oct 26, 2024

This is my code:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the pretrained HyenaDNA model and tokenizer
checkpoint = 'LongSafari/hyenadna-large-1m-seqlen-hf'
cache_dir = '/blue/sai.zhang/chen.yongzhuo/LLM_evaluation/cache'

tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, cache_dir=cache_dir
)

model.eval()

# Define the initial DNA sequence
initial_sequence = 'ATCGATCGATCGATCGATCGATCGA'

# Tokenize the initial sequence
input_ids = tokenizer(initial_sequence, return_tensors='pt')['input_ids'].to(model.device)

# Perform next-token prediction (one-step)
with torch.no_grad():
    outputs = model(input_ids)
    next_token_logits = outputs.logits[:, -1, :]  # Get logits for the next token
    next_token_id = torch.argmax(next_token_logits, dim=-1)

# Decode the next token
next_token = tokenizer.decode(next_token_id)
print(f"Next token: {next_token}")

extended_sequence = initial_sequence + next_token
print(f"Extended Sequence: {extended_sequence}")
