How to train discriminators? #3
Hi, thanks for reaching out! I did that to ensure the trained classifier has the same vocabulary as GPT-2. This way, the classifier can process the token sequences produced by the GPT-2 model. Regarding the implementation, the following code might give you a hint:

from transformers import AutoTokenizer, AutoModelForSequenceClassification, GPT2ForSequenceClassification

# `args` and `config` come from the training script
tokenizer = AutoTokenizer.from_pretrained('gpt2-large', use_fast=not args.use_slow_tokenizer)  # the classifier uses the GPT-2 tokenizer
model = AutoModelForSequenceClassification.from_pretrained(
    args.model_name_or_path,  # RoBERTa
    from_tf=bool(".ckpt" in args.model_name_or_path),
    config=config,
    ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)  # initialize the RoBERTa model
gpt_model = GPT2ForSequenceClassification.from_pretrained('gpt2-large')
model.roberta.embeddings.word_embeddings = gpt_model.transformer.wte  # replace the RoBERTa word embeddings with the GPT-2 embeddings
del gpt_model  # the rest of the GPT-2 model is no longer needed

Hope this helps!
Feel free to email me if you have any further questions.
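A minimal sketch of why the shared vocabulary matters (the text string and variable names here are just illustrative): GPT-2 and RoBERTa map the same text to different token ids, so a stock RoBERTa classifier cannot consume GPT-2's token sequences directly, which is what the embedding swap above addresses.

from transformers import AutoTokenizer

# GPT-2 and RoBERTa use different vocabularies, so the same text maps to different ids.
gpt2_tok = AutoTokenizer.from_pretrained('gpt2-large')
roberta_tok = AutoTokenizer.from_pretrained('roberta-base')
text = "the food was great"
print(gpt2_tok(text).input_ids)     # GPT-2 ids
print(roberta_tok(text).input_ids)  # different ids, plus RoBERTa's <s>/</s> special tokens
# Feeding GPT-2 ids into an unmodified RoBERTa would look up the wrong rows of its
# embedding table; swapping in GPT-2's word embeddings makes the ids line up.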
Thanks for providing the code, which is very helpful! Did you write a custom trainer? I ask because I ran your code as follows:

from transformers import GPT2Tokenizer, AutoTokenizer, GPT2ForSequenceClassification, AutoModelForSequenceClassification, RobertaForSequenceClassification, RobertaConfig, TrainingArguments, Trainer
from datasets import load_dataset
import torch

config = RobertaConfig(**{  # wrap the settings in a RobertaConfig so from_pretrained applies them
"attention_probs_dropout_prob": 0.1,
"bos_token_id": 0,
"classifier_dropout": None,
"eos_token_id": 2,
"finetuning_task": "yelp_polarity",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {"0": "1", "1": "2"},
"initializer_range": 0.02,
"intermediate_size": 3072,
"label2id": {"1": 0, "2": 1},
"layer_norm_eps": 1e-05,
"max_position_embeddings": 514,
"model_type": "roberta",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"position_embedding_type": "absolute",
"problem_type": "single_label_classification",
"torch_dtype": "float32",
"vocab_size": 50257 # GPT-2 large vocabulary size
})
dataset = load_dataset('yelp_polarity', split='train[10:20]')
tokenizer = AutoTokenizer.from_pretrained('gpt2-large', use_fast=True)  # the classifier uses the gpt2 tokenizer
model = AutoModelForSequenceClassification.from_pretrained(
    'roberta-base',  # Roberta
    config=config,
    ignore_mismatched_sizes=True,
)  # initialize the Roberta model
gpt_model = GPT2ForSequenceClassification.from_pretrained('gpt2-large')
model.roberta.embeddings.word_embeddings = gpt_model.transformer.wte  # replace the roberta embedding with the gpt2 embedding
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', max_length=512, truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    eval_dataset=tokenized_datasets,
)
trainer.train()
I added a transformation layer to bridge the dimension gap when training the classifier; its weights can be merged into the embedding layer when saving the model. Sorry I missed this detail earlier. You can also try randomly initializing the embedding layer with size (gpt2_vocab_len, roberta_hidden_dim) and updating it during fine-tuning, which should also work.
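A rough sketch of one way such a transformation layer could look, assuming it sits right after the embedding lookup (the thread doesn't show the exact placement): a linear projection from GPT-2-large's 1280-dim embeddings down to roberta-base's 768-dim hidden size, later folded into a single embedding matrix. The ProjectedEmbedding wrapper and the merge step below are illustrative, not the repository's implementation.

import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, GPT2ForSequenceClassification

class ProjectedEmbedding(nn.Module):
    """GPT-2 embedding lookup followed by a linear map into RoBERTa's hidden size."""
    def __init__(self, wte, hidden_dim):
        super().__init__()
        self.wte = wte  # (50257, 1280) GPT-2-large word embeddings
        self.proj = nn.Linear(wte.embedding_dim, hidden_dim, bias=False)  # 1280 -> 768

    def forward(self, input_ids):
        return self.proj(self.wte(input_ids))

model = AutoModelForSequenceClassification.from_pretrained('roberta-base')
gpt_model = GPT2ForSequenceClassification.from_pretrained('gpt2-large')
model.roberta.embeddings.word_embeddings = ProjectedEmbedding(
    gpt_model.transformer.wte, model.config.hidden_size
)
del gpt_model

# ... fine-tune the classifier as usual ...

# When saving, the projection can be merged into a plain embedding matrix, since
# proj(wte(ids)) selects rows of wte.weight @ proj.weight.T; precompute that product once.
pe = model.roberta.embeddings.word_embeddings
merged = nn.Embedding(pe.wte.num_embeddings, pe.proj.out_features)
with torch.no_grad():
    merged.weight.copy_(pe.wte.weight @ pe.proj.weight.T)  # (50257, 768)
model.roberta.embeddings.word_embeddings = merged

# Alternative mentioned above: skip the GPT-2 weights entirely and learn a randomly
# initialized (gpt2_vocab_len, roberta_hidden_dim) embedding during fine-tuning:
# model.roberta.embeddings.word_embeddings = nn.Embedding(50257, model.config.hidden_size)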
Hi,
Thanks for providing the code. :)
I have a question regarding training the classifiers. What do you mean by replacing GPT2-large embeddings with roberta-base? I'm not sure if I totally understand it...