This loss seems to consume a lot of memory. #13

Open · piekey1994 opened this issue Apr 14, 2023 · 4 comments

@piekey1994
Copy link

The idea of this paper is really great and much easier to understand than PPO.
However, if there are six candidate responses, the batch size has to be at least 6 just to compute the loss once. If the model is large, it seems difficult for a GPU to support that forward pass. I think the generated tokens in the paper were truncated to 192, which is far below the 2048 configured in ordinary LLM training. Is this also the reason?
Is there any optimization strategy to solve this problem? For example, each step could compute only the ranking loss of one pair of responses plus the SFT loss on the better response of that pair, as sketched below. I don't know if this is feasible.
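
To make the idea concrete, here is a minimal PyTorch sketch (not the repository's code; the tensor layout and the helper `sequence_logprob` are assumptions made only for illustration):

```python
import torch
import torch.nn.functional as F

def sequence_logprob(logits, labels, mask):
    # Length-normalized log-probability of each response; assumes logits and
    # labels are already shifted so that logits[:, t] predicts labels[:, t].
    logp = torch.log_softmax(logits, dim=-1)                          # (B, T, V)
    token_logp = torch.gather(logp, -1, labels.unsqueeze(-1)).squeeze(-1)
    return (token_logp * mask).sum(-1) / mask.sum(-1)                 # (B,)

def pairwise_rrhf_loss(logits, labels, mask, rewards):
    # logits/labels/mask hold exactly two candidate responses to the same query.
    p = sequence_logprob(logits, labels, mask)    # (2,) ranking scores
    better = rewards.argmax()
    worse = 1 - better
    rank_loss = F.relu(p[worse] - p[better])      # hinge: penalize a mis-ordered pair
    # The paper's SFT term is the usual token-level cross-entropy on the best
    # response; the normalized score is reused here only to keep the sketch short.
    sft_loss = -p[better]
    return rank_loss + sft_loss
```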

@GanjinZero
Owner

When you are doing ordinary LLM training, the batch size you can fit is the same as the maximum number of candidate responses you can rank here. If you can train an LLM at length 2048 with bsz=4, you can also train RRHF at length 2048 with 4 candidate responses per query.
I don't think our loss has larger memory consumption than vanilla pre-training.

To reduce memory consumption further, you can pre-select the candidate responses and only calculate the loss on those.
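
For example (a hypothetical sketch, not code from this repository), one could keep only the highest-reward response plus a few randomly chosen ones before the forward pass:

```python
import random

def preselect_responses(responses, rewards, k=1):
    # Keep the highest-reward response plus k randomly chosen others, so fewer
    # sequences have to be scored by the model in each training step.
    best = max(range(len(responses)), key=lambda i: rewards[i])
    others = [i for i in range(len(responses)) if i != best]
    keep = [best] + random.sample(others, min(k, len(others)))
    return [responses[i] for i in keep], [rewards[i] for i in keep]
```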

@piekey1994
Author

piekey1994 commented Apr 15, 2023

> When you are doing ordinary LLM training, the batch size you can fit is the same as the maximum number of candidate responses you can rank here. If you can train an LLM at length 2048 with bsz=4, you can also train RRHF at length 2048 with 4 candidate responses per query. I don't think our loss has larger memory consumption than vanilla pre-training.
>
> To reduce memory consumption further, you can pre-select the candidate responses and only calculate the loss on those.

I know what you mean, but for example, if I want to train a 60B LLaMA model now, I may only be able to use batch size 2 or even 1. In that case, how can we train an RRHF model?
By "consuming more GPU memory" I mean that when training with PPO, I don't need to compute the loss over multiple responses in a single step.

@GanjinZero
Owner

If you can only use bsz=2, you can still use RRHF to rank those two responses. If you can only fit bsz=1, you would need to either truncate the input or use something like LoRA.
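
For the LoRA option, something along these lines with the Hugging Face `peft` library could work (an illustration only; this repository does not ship it, and the checkpoint path and hyper-parameters are placeholders):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Placeholder path; substitute the LLaMA checkpoint you actually use.
model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")

lora_config = LoraConfig(
    r=8,                                  # rank of the low-rank adapters
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # attention projections in LLaMA blocks
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the small adapter matrices are trained
```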

@GanjinZero
Owner

One possible way to save GPU memory that we have not implemented: every response shares the same query, so we would not need to recompute the query many times. If the query is much longer than the responses, this would save a lot of GPU memory.
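
Roughly, with a Hugging Face causal LM it could look like the sketch below (unimplemented and illustrative only; for training one would also have to make sure gradients flow through the cached query states and that padding and positions are handled correctly):

```python
import torch

def score_responses_with_shared_query(model, query_ids, response_ids):
    # query_ids: (1, Lq) token ids of the shared query.
    # response_ids: (R, Lr) token ids of R candidate responses, padded to the same length.
    R = response_ids.size(0)

    # Encode the query once and keep its key/value cache.
    query_out = model(query_ids, use_cache=True)
    past = query_out.past_key_values

    # Broadcast the single query's cached keys/values across the R responses.
    past = tuple(tuple(t.expand(R, -1, -1, -1) for t in layer) for layer in past)

    # The attention mask must cover the cached query tokens plus the response tokens.
    attn = torch.ones(
        R, query_ids.size(1) + response_ids.size(1),
        dtype=torch.long, device=response_ids.device,
    )

    out = model(response_ids, past_key_values=past, attention_mask=attn)
    return out.logits  # (R, Lr, vocab): logits for the response tokens only
```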
