I tried to perform a next-token prediction task with the pretrained model hyenadna-small-32k-seqlen-hf, but the results were not what I expected. Here's the code I tried:
```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# instantiate pretrained model
checkpoint = 'hyenadna-small-32k-seqlen-hf'
max_length = 500
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, config=config)

seq = 'AGCTACATTGGCC'
tok_seq = tokenizer(seq)['input_ids']
print(tok_seq)
tok_seq = torch.LongTensor(tok_seq).unsqueeze(0)  # add a batch dimension: (1, seq_len)
print(tokenizer.batch_decode(tok_seq))

# single forward pass; no gradients are needed for inference
with torch.no_grad():
    out = model(tok_seq)
# argmax over the vocabulary at every position -> shape (1, seq_len)
pred = out['logits'].argmax(-1)
print(tokenizer.batch_decode(pred))
```
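A note on reading that argmax output, assuming the standard causal-LM convention: the logits at position i score candidates for token i + 1, so the argmax sequence is shifted one position relative to the input, and comparing it to the input position by position will look like a poor match even when the model is doing well. The sketch below uses hypothetical helper names (`align_next_token_predictions`, `shifted_accuracy`) and toy character lists standing in for real model output; it pairs each prediction with the input token it is meant to predict, and treats the prediction at the final position as the only genuinely new token. If the tokenizer appends a special token such as [SEP] (worth checking in the printed `tok_seq`), that final prediction is a guess for what follows the special token, so it may be worth dropping it before the forward pass.

```python
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def align_next_token_predictions(
    input_ids: Sequence[T], predicted_ids: Sequence[T]
) -> Tuple[List[Tuple[T, T]], Optional[T]]:
    """Pair each prediction with the input token it is trying to predict.

    Assumes predicted_ids[i] (argmax of logits at position i) is the model's
    guess for input_ids[i + 1]. The last prediction has no input token to
    compare against, so it is returned separately as the new token.
    """
    if len(input_ids) != len(predicted_ids):
        raise ValueError("input_ids and predicted_ids must have the same length")
    # prediction i is compared against the *next* input token
    pairs = list(zip(predicted_ids[:-1], input_ids[1:]))
    new_token = predicted_ids[-1] if predicted_ids else None
    return pairs, new_token


def shifted_accuracy(input_ids: Sequence[T], predicted_ids: Sequence[T]) -> float:
    """Fraction of positions where prediction i equals input token i + 1."""
    pairs, _ = align_next_token_predictions(input_ids, predicted_ids)
    if not pairs:
        return 0.0
    return sum(p == t for p, t in pairs) / len(pairs)


# Toy example: "GCTA" is a made-up model output, for illustration only.
inp = list("AGCT")
pred = list("GCTA")
naive = sum(p == t for p, t in zip(pred, inp)) / len(inp)  # unshifted comparison
_, new = align_next_token_predictions(inp, pred)
print(naive, shifted_accuracy(inp, pred), new)  # prints: 0.0 1.0 A
```

In the toy case, the unshifted comparison finds no matches, while the shifted one shows every prediction agreeing with the token that actually came next.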
As I understand it, I've set this model up for next-token prediction, so if I input the sequence 'AGCTACATTGGCC', the model should return something like 'AGCTACATTGGCC' plus a newly predicted token (i.e., most of the input bases should stay the same). However, the sequence I get differs a lot from the input. Is something wrong in my understanding or in my code?
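Getting "input plus a new token" generally takes a generation loop rather than a single forward pass: run the model, take the argmax at the last position only, append it, and repeat. Below is a minimal greedy sketch; `greedy_extend` and the toy `next_token_fn` are hypothetical names. With the model from the code above, `next_token_fn` could plausibly be `lambda ids: int(model(torch.LongTensor([ids]))['logits'][0, -1].argmax())`, assuming `ids` excludes any trailing special token the tokenizer may add.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def greedy_extend(
    ids: Sequence[T], next_token_fn: Callable[[List[T]], T], n_new: int
) -> List[T]:
    """Greedily append n_new tokens, one forward pass per new token.

    next_token_fn is a stand-in for "run the model on the current sequence
    and return the argmax at the last position".
    """
    out = list(ids)  # the original tokens are kept unchanged
    for _ in range(n_new):
        out.append(next_token_fn(out))  # only the newest prediction is appended
    return out


# Toy predictor (hypothetical): repeat the token seen four positions back.
print("".join(greedy_extend(list("AGCT"), lambda s: s[-4], 4)))  # prints: AGCTAGCT
```

Because each step conditions on everything generated so far, the input bases are preserved by construction and only the appended tokens come from the model.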