Reduce memory consumption when training with PPO #2571

Open: wants to merge 1 commit into main.
setup.py (1 addition, 1 deletion)

@@ -82,7 +82,7 @@
"peft": ["peft>=0.4.0"],
"diffusers": ["diffusers>=0.18.0"],
"deepspeed": ["deepspeed>=0.9.5"],
"benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5", "requests", "deepspeed"],
"benchmark": ["wandb", "ghapi", "openrlbenchmark==0.1.1b4", "requests", "deepspeed"],
"quantization": ["bitsandbytes<=0.41.1"],
}
EXTRAS["dev"] = []
trl/models/modeling_value_head.py (3 additions, 2 deletions)

@@ -164,7 +164,7 @@ def forward(
kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the wrapped model.
"""
kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples
kwargs["output_hidden_states"] = False # this had already been set in the LORA / PEFT examples
Reviewer (Member): Isn't it the default value? Maybe we can simply remove it.
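For context, `output_hidden_states` does default to False on transformers configs, so the explicit assignment is redundant unless a caller overrides it upstream. A quick check, with gpt2 chosen arbitrarily as an example:

```python
from transformers import AutoConfig

# output_hidden_states defaults to False on PretrainedConfig, so models
# do not retain per-layer hidden states unless explicitly asked to.
cfg = AutoConfig.from_pretrained("gpt2")
print(cfg.output_hidden_states)  # False: the library default
```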

kwargs["past_key_values"] = past_key_values

if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
@@ -176,7 +176,8 @@
**kwargs,
)

- last_hidden_state = base_model_output.hidden_states[-1]
+ # last_hidden_state = base_model_output.hidden_states[-1]
Reviewer (Member), suggested change: remove the commented-out `# last_hidden_state = base_model_output.hidden_states[-1]` line instead of keeping it.

+ last_hidden_state = base_model_output.last_hidden_state
lm_logits = base_model_output.logits
loss = base_model_output.loss

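Note on this hunk: the saving comes from no longer asking the wrapped model to materialize one hidden-state tensor per layer. A minimal sketch of the difference, using gpt2 purely as a stand-in for the wrapped model:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
ids = torch.randint(0, model.config.vocab_size, (1, 128))

with torch.no_grad():
    out_all = model(ids, output_hidden_states=True)
    out_lean = model(ids, output_hidden_states=False)

print(len(out_all.hidden_states))  # 13 for gpt2: one tensor per layer plus embeddings
print(out_lean.hidden_states)      # None: the per-layer tensors are never returned
```

One caveat worth flagging: causal-LM output classes generally expose `logits` and, when requested, `hidden_states`, but not a `last_hidden_state` field, so whether `base_model_output.last_hidden_state` resolves here depends on the output class of the wrapped model.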
trl/trainer/ppo_trainer.py (19 additions, 9 deletions)

@@ -417,7 +417,7 @@ def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset],
dataset,
batch_size=self.config.batch_size,
collate_fn=data_collator,
- shuffle=True,
+ shuffle=False,
Reviewer (Member): Is it needed?

drop_last=True,
)
return dataloader
@@ -485,6 +485,7 @@ def generate(
if generate_ref_response:
ref_model = self.model if self.is_peft_model else self.ref_model
if isinstance(query_tensor, List):
+ self.model.eval()
Reviewer (Member): Why do you need it in eval mode here?
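Note: the usual reason to call `eval()` before generation is to disable dropout so sampling is not perturbed; it is orthogonal to gradient tracking. A self-contained sketch of the pattern (gpt2 used only for illustration):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

model.eval()                  # dropout off: forward passes become deterministic
with torch.no_grad():         # separately, no activations are kept for backward
    ids = tok("Hello", return_tensors="pt").input_ids
    out = model.generate(ids, max_new_tokens=8, do_sample=False)
model.train()                 # restore training behaviour before the next PPO update
print(tok.decode(out[0]))
```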

response = self._generate_batched(
self.model,
query_tensor,
Expand All @@ -494,6 +495,7 @@ def generate(
**generation_kwargs,
)
if generate_ref_response:
+ self.ref_model.eval()
ref_response = self._generate_batched(
ref_model,
query_tensor,
@@ -733,9 +735,8 @@ def step(
)

model_inputs_names = list(model_inputs.keys())

full_kl_penalty = self.config.kl_penalty == "full"

+ # torch.cuda.memory._record_memory_history()
Reviewer (Member), suggested change: remove the commented-out `# torch.cuda.memory._record_memory_history()` line.

with torch.no_grad():
all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
self.model,
Expand All @@ -744,6 +745,7 @@ def step(
model_inputs,
response_masks=response_masks,
return_logits=full_kl_penalty,
+ use_cache=False
Reviewer (Member): Why do you disable the cache?
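Note: one plausible motivation is that these passes score complete query+response sequences, so there is no incremental decoding step that could reuse a key/value cache; with `use_cache=True` the per-layer K/V tensors are allocated, returned, and immediately discarded. A small sketch of what the flag controls (gpt2 as a stand-in):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
ids = torch.randint(0, model.config.vocab_size, (1, 128))

with torch.no_grad():
    cached = model(ids, use_cache=True)
    lean = model(ids, use_cache=False)

print(cached.past_key_values is not None)  # True: K/V tensors retained per layer
print(lean.past_key_values)                # None: nothing cached in a one-shot pass
```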

)
with self.optional_peft_ctx():
ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
Expand All @@ -752,8 +754,9 @@ def step(
responses,
model_inputs,
return_logits=full_kl_penalty,
+ use_cache=False
)

+ # torch.cuda.memory._dump_snapshot(f"ppo_batched_forward_pass-{time.time()}.pickle")
Reviewer (Member), suggested change: remove the commented-out `# torch.cuda.memory._dump_snapshot(...)` line.
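For anyone reproducing the measurements, these commented-out calls use PyTorch's CUDA memory snapshot API (a private API, so the exact signatures may change between releases); a sketch of the intended workflow:

```python
import torch

# Start recording allocation/free events (requires a CUDA device).
torch.cuda.memory._record_memory_history(max_entries=100_000)

x = torch.randn(4096, 4096, device="cuda")
y = x @ x  # stand-in for the PPO forward pass being profiled

# Persist the recorded history to disk, then stop recording.
torch.cuda.memory._dump_snapshot("ppo_step.pickle")
torch.cuda.memory._record_memory_history(enabled=None)

# Open ppo_step.pickle at https://pytorch.org/memory_viz to browse allocations.
```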

timing["time/ppo/forward_pass"] = time.time() - t

with torch.no_grad():
@@ -813,13 +816,14 @@
mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
with self.accelerator.accumulate(self.model):
model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}

+ # torch.cuda.memory._record_memory_history()
Reviewer (Member), suggested change: remove the commented-out `# torch.cuda.memory._record_memory_history()` line.

logprobs, logits, vpreds, _ = self.batched_forward_pass(
self.model,
mini_batch_dict["queries"],
mini_batch_dict["responses"],
model_inputs,
return_logits=True,
+ is_training=True,
)
train_stats = self.train_minibatch(
mini_batch_dict["logprobs"],
Expand All @@ -832,7 +836,8 @@ def step(
mini_batch_dict["returns"],
)
all_stats.append(train_stats)

+ # torch.cuda.memory._dump_snapshot("ppo_train_minibatch_bs16_mbs4.pickle")
+ # exit(0)
# typically, early stopping is done at the epoch level
if self.config.early_stopping:
policykl = train_stats["policy/policykl"]
@@ -976,6 +981,8 @@ def batched_forward_pass(
model_inputs: dict,
return_logits: bool = False,
response_masks: Optional[torch.Tensor] = None,
+ is_training: bool = False,
+ use_cache: bool = True
):
"""
Calculate model outputs in multiple batches.
@@ -1001,11 +1008,14 @@
all_logits = []
all_masks = []
all_values = []

- model.eval()
+ if not is_training:
+     model.eval()
+ else:
+     model.train()
for i in range(math.ceil(bs / fbs)):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
input_kwargs["use_cache"] = use_cache
print("use_cache:", use_cache)
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
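A closing note on the new `is_training` switch: under `torch.no_grad()` the practical difference between the two modes is mostly dropout (plus batch-norm statistics in models that have it), which is why the reference/scoring passes stay in eval mode while the optimized minibatch pass opts into train mode. A tiny illustration:

```python
import torch

torch.manual_seed(0)
drop = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

drop.train()
print(drop(x))  # stochastic: roughly half the entries zeroed, the rest scaled by 2
drop.eval()
print(drop(x))  # identity: eval mode keeps log-prob scoring deterministic
```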