-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reduce memory consumption when training with PPO #2571
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -164,7 +164,7 @@ def forward( | |||
kwargs (`dict`, `optional`): | ||||
Additional keyword arguments, that are passed to the wrapped model. | ||||
""" | ||||
kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples | ||||
kwargs["output_hidden_states"] = False # this had already been set in the LORA / PEFT examples | ||||
kwargs["past_key_values"] = past_key_values | ||||
|
||||
if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": | ||||
|
@@ -176,7 +176,8 @@ def forward( | |||
**kwargs, | ||||
) | ||||
|
||||
last_hidden_state = base_model_output.hidden_states[-1] | ||||
# last_hidden_state = base_model_output.hidden_states[-1] | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
last_hidden_state = base_model_output.last_hidden_state | ||||
lm_logits = base_model_output.logits | ||||
loss = base_model_output.loss | ||||
|
||||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -417,7 +417,7 @@ def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], | |||
dataset, | ||||
batch_size=self.config.batch_size, | ||||
collate_fn=data_collator, | ||||
shuffle=True, | ||||
shuffle=False, | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it needed? |
||||
drop_last=True, | ||||
) | ||||
return dataloader | ||||
|
@@ -485,6 +485,7 @@ def generate( | |||
if generate_ref_response: | ||||
ref_model = self.model if self.is_peft_model else self.ref_model | ||||
if isinstance(query_tensor, List): | ||||
self.model.eval() | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do you need it in eval model here? |
||||
response = self._generate_batched( | ||||
self.model, | ||||
query_tensor, | ||||
|
@@ -494,6 +495,7 @@ def generate( | |||
**generation_kwargs, | ||||
) | ||||
if generate_ref_response: | ||||
self.ref_model.eval() | ||||
ref_response = self._generate_batched( | ||||
ref_model, | ||||
query_tensor, | ||||
|
@@ -733,9 +735,8 @@ def step( | |||
) | ||||
|
||||
model_inputs_names = list(model_inputs.keys()) | ||||
|
||||
full_kl_penalty = self.config.kl_penalty == "full" | ||||
|
||||
# torch.cuda.memory._record_memory_history() | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
with torch.no_grad(): | ||||
all_logprobs, logits_or_none, values, masks = self.batched_forward_pass( | ||||
self.model, | ||||
|
@@ -744,6 +745,7 @@ def step( | |||
model_inputs, | ||||
response_masks=response_masks, | ||||
return_logits=full_kl_penalty, | ||||
use_cache=False | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you disable the cache? |
||||
) | ||||
with self.optional_peft_ctx(): | ||||
ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass( | ||||
|
@@ -752,8 +754,9 @@ def step( | |||
responses, | ||||
model_inputs, | ||||
return_logits=full_kl_penalty, | ||||
use_cache=False | ||||
) | ||||
|
||||
# torch.cuda.memory._dump_snapshot(f"ppo_batched_forward_pass-{time.time()}.pickle") | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
timing["time/ppo/forward_pass"] = time.time() - t | ||||
|
||||
with torch.no_grad(): | ||||
|
@@ -813,13 +816,14 @@ def step( | |||
mini_batch_dict[k] = batch_dict[k][mini_batch_inds] | ||||
with self.accelerator.accumulate(self.model): | ||||
model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names} | ||||
|
||||
# torch.cuda.memory._record_memory_history() | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
logprobs, logits, vpreds, _ = self.batched_forward_pass( | ||||
self.model, | ||||
mini_batch_dict["queries"], | ||||
mini_batch_dict["responses"], | ||||
model_inputs, | ||||
return_logits=True, | ||||
is_training=True, | ||||
) | ||||
train_stats = self.train_minibatch( | ||||
mini_batch_dict["logprobs"], | ||||
|
@@ -832,7 +836,8 @@ def step( | |||
mini_batch_dict["returns"], | ||||
) | ||||
all_stats.append(train_stats) | ||||
|
||||
# torch.cuda.memory._dump_snapshot("ppo_train_minibatch_bs16_mbs4.pickle") | ||||
# exit(0) | ||||
# typically, early stopping is done at the epoch level | ||||
if self.config.early_stopping: | ||||
policykl = train_stats["policy/policykl"] | ||||
|
@@ -976,6 +981,8 @@ def batched_forward_pass( | |||
model_inputs: dict, | ||||
return_logits: bool = False, | ||||
response_masks: Optional[torch.Tensor] = None, | ||||
is_training: bool = False, | ||||
use_cache: bool = True | ||||
): | ||||
""" | ||||
Calculate model outputs in multiple batches. | ||||
|
@@ -1001,11 +1008,14 @@ def batched_forward_pass( | |||
all_logits = [] | ||||
all_masks = [] | ||||
all_values = [] | ||||
|
||||
model.eval() | ||||
|
||||
if not is_training: | ||||
model.eval() | ||||
else: | ||||
model.train() | ||||
for i in range(math.ceil(bs / fbs)): | ||||
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} | ||||
input_kwargs["use_cache"] = use_cache | ||||
print("use_cache:", use_cache) | ||||
query_batch = queries[i * fbs : (i + 1) * fbs] | ||||
response_batch = responses[i * fbs : (i + 1) * fbs] | ||||
if response_masks is not None: | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't it the default value? maybe we can simply remove it