PPOTrainer training raises 'tuple' object has no attribute 'logits' error #2640
I'm trying to implement PPO training similar to DeepSeek's R1 approach, but I'm encountering an error related to model caching and logits extraction. I'm new to RL and would appreciate guidance on the correct approach.

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

def load_models_and_ref():
    model = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    ref_base = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    ref_model = PeftModel.from_pretrained(ref_base, CHECKPOINT_PATH)
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
        ref_model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )

    generation_config = GenerationConfig.from_pretrained(MODEL_PATH)
    generation_config.pad_token_id = ref_base.config.eos_token_id
    generation_config.return_dict = True
    model.config.return_dict = True
    ref_model.config.return_dict = True
    model.generation_config = generation_config
    ref_model.generation_config = generation_config

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    tokenizer.pad_token = tokenizer.eos_token
    return model, ref_model, tokenizer
```
```python
ppo_config = PPOConfig(
    output_dir="./RL",
    learning_rate=1e-5,
    batch_size=BATCH_SIZE,
    mini_batch_size=1,
    gradient_accumulation_steps=8,
    cliprange=0.2,
    cliprange_value=0.2,
    vf_coef=0.5,
    optim="adamw_8bit",
)

ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=ref_model,
    reward_model=reward_model,
    value_model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)

ppo_trainer.train()
```

Error I'm getting:

```
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.
Traceback (most recent call last):
  File "/home/user/app/app.py", line 248, in <module>
    ppo_trainer.train()
  File "/home/user/.pyenv/versions/3.10.16/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py", line 441, in train
    ref_logits = ref_output.logits[:, context_length - 1 : -1]
AttributeError: 'tuple' object has no attribute 'logits'
```
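The deprecation warning at the top of the log names one mitigation. A minimal sketch of applying it to the generation config built above, assuming this transformers version reads `return_legacy_cache` from the `GenerationConfig` (whether that also fixes the trainer's `.logits` access is a separate question; see the accepted answer below):

```python
# Sketch based on the warning text above (assumption: this transformers
# version honors `return_legacy_cache` set on the GenerationConfig).
# Keeps `generate` returning the legacy tuple-of-tuples cache format.
generation_config.return_legacy_cache = True
model.generation_config = generation_config
ref_model.generation_config = generation_config
```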
Answered by qgallouedec (Jan 23, 2025):
If you're trying to replicate R1, then you should instead use the GRPO trainer.
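For reference, a minimal GRPO setup with TRL follows the pattern of the `GRPOTrainer` examples; the model name, dataset, and reward function below are illustrative placeholders, not taken from this thread:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Toy reward function: completions closer to 100 characters score higher.
# An R1-style run would instead use verifiable rewards (e.g. checking
# math answers or running unit tests).
def reward_len(completions, **kwargs):
    return [-abs(100 - len(completion)) for completion in completions]

dataset = load_dataset("trl-lib/tldr", split="train")  # any prompt dataset

training_args = GRPOConfig(output_dir="qwen2-0.5b-grpo")
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

Unlike PPO, GRPO needs no separate value model or value-head wrapper, which sidesteps the tuple-output mismatch hit above.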