Sachin Goyal, Ziwei Ji, Ankit Rawat, Aditya Menon, Sanjiv Kumar, Vaishnavh Nagarajan, 2024
Transformer-based language models produce tokens sequentially, with each token influenced by the preceding hidden states. Introducing <pause> tokens, which delay token generation, lets the model compute additional hidden states before committing to an answer. This paper demonstrates that injecting these delays during both pretraining and finetuning improves performance on a range of downstream datasets.
- Pause Token Mechanism: Introduces a <pause> token that delays next-token generation, allowing the model to manipulate more hidden vectors before committing to its output.
- Pause-Pretraining and Finetuning: Demonstrates the importance of incorporating the <pause> token during both pretraining and finetuning to realize consistent performance benefits.
- Selecting the Optimal Number of <pause> Tokens: Studies how many <pause> tokens to append for different downstream tasks, which matters for the practical applicability of the approach.
Figure: (a) In standard inference (and finetuning), the model's output is generated immediately after the last prefix token. (b) In pause-inference (and pause-finetuning), output generation begins only after a specified number of <pause> tokens have been appended to the prefix.
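To make the inference-time difference concrete, here is a minimal sketch of pause-inference in PyTorch. The `model` interface (a callable returning logits of shape `(batch, seq_len, vocab)`), the `PAUSE_ID` value, and `M_INF` are assumptions for illustration, not the authors' actual code.

```python
import torch

PAUSE_ID = 50257          # hypothetical id of the <pause> token
M_INF = 10                # number of <pause> tokens appended at inference

@torch.no_grad()
def pause_greedy_decode(model, prefix_ids, max_new_tokens=32, eos_id=None):
    """Append M_INF <pause> tokens to the prefix, then decode greedily.

    Nothing is sampled while the model is still consuming <pause> tokens:
    generation of the answer starts only after the last <pause> token.
    """
    ids = list(prefix_ids) + [PAUSE_ID] * M_INF   # delayed start of the answer
    for _ in range(max_new_tokens):
        logits = model(torch.tensor([ids]))       # (1, seq_len, vocab)
        next_id = int(logits[0, -1].argmax())
        if eos_id is not None and next_id == eos_id:
            break
        ids.append(next_id)
    return ids[len(prefix_ids) + M_INF:]          # answer tokens only
```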
Pause-pretraining: To lengthen the input sequence and give the model extra hidden states to compute over, <pause> tokens are inserted at uniformly random locations during pretraining. For a given pretraining sequence $p_{1:N}$, we insert $M_{\text{pt}}$ copies of the <pause> token at uniformly random positions to obtain the pause-injected sequence $\tilde{p}_{1:N+M_{\text{pt}}}$, and train with the standard next-token prediction loss while ignoring any loss term that asks the model to predict a <pause> token. Let $S_{\text{ignore}} = \{k : \tilde{p}_{k+1} = \texttt{<pause>}\}$ denote those positions. Then, for the decoder-only language model $f$, the pause-pretraining loss is

$$\mathcal{L}_{\text{PausePT}}\big(f, \tilde{p}_{1:N+M_{\text{pt}}}\big) = \sum_{\substack{k=1,\; k \notin S_{\text{ignore}}}}^{N+M_{\text{pt}}-1} \mathcal{L}_{\text{CE}}\big(\tilde{p}_{k+1}, f(\tilde{p}_{1:k})\big),$$

where $\mathcal{L}_{\text{CE}}$ denotes the cross-entropy loss.
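A minimal sketch of this loss, assuming a pause-injected batch of token ids and a `model` that returns logits of shape `(batch, seq_len, vocab)`; the function name and the use of PyTorch's `-100` ignore-index convention are illustrative choices, not the paper's implementation.

```python
import torch
import torch.nn.functional as F

def pause_pretraining_loss(model, pause_injected_ids, pause_id):
    """pause_injected_ids: (batch, N + M_pt) sequence with <pause> tokens inserted."""
    inputs = pause_injected_ids[:, :-1]
    targets = pause_injected_ids[:, 1:].clone()
    # S_ignore: positions k where the next token p~_{k+1} is a <pause> token.
    targets[targets == pause_id] = -100            # skipped by cross_entropy
    logits = model(inputs)                          # (batch, seq_len - 1, vocab)
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,
    )
```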
Pause-finetuning: In downstream finetuning, we are given a prefix $p_{1:N}$ annotated with a target $t_{1:T}$. We append $M_{\text{ft}}$ copies of the <pause> token to the prefix to create the new prefix $\tilde{p}_{1:N+M_{\text{ft}}}$, and ignore the model's outputs until the last <pause> token is seen. We apply the standard next-token prediction loss on the target with the new prefix, thus minimizing

$$\sum_{k=0}^{T-1} \mathcal{L}_{\text{CE}}\big(t_{k+1}, f([\tilde{p}_{1:N+M_{\text{ft}}}, t_{1:k}])\big),$$

where $[\cdot\,,\cdot]$ denotes concatenation. All of the model's parameters are updated, including the embedding of the <pause> token, as is standard. We term this pause-finetuning.
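A corresponding sketch of the finetuning loss, again with an assumed `model` interface and illustrative names; it appends $M_{\text{ft}}$ <pause> tokens to the prefix and computes the loss on the target tokens only.

```python
import torch
import torch.nn.functional as F

def pause_finetuning_loss(model, prefix_ids, target_ids, pause_id, m_ft):
    """prefix_ids: (batch, N), target_ids: (batch, T); unpadded for simplicity."""
    batch = prefix_ids.size(0)
    pauses = torch.full((batch, m_ft), pause_id, dtype=prefix_ids.dtype)
    new_prefix = torch.cat([prefix_ids, pauses], dim=1)     # p~_{1:N+M_ft}
    full = torch.cat([new_prefix, target_ids], dim=1)
    inputs, targets = full[:, :-1], full[:, 1:].clone()
    # Outputs are ignored until the last <pause> token: mask every position
    # whose prediction target still lies inside the (prefix + pause) region.
    targets[:, : new_prefix.size(1) - 1] = -100
    logits = model(inputs)
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,
    )
```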
Four training configurations are compared: Standard Pretraining and Standard Finetuning (StdPT_StdFT), Standard Pretraining and Pause-Finetuning (StdPT_PauseFT), Pause-Pretraining and Standard Finetuning (PausePT_StdFT), and Pause-Pretraining and Pause-Finetuning (PausePT_PauseFT).
We consider decoder-only models of sizes 1B and 130M for our main experiments. For ablations, we focus on the 1B model. Both standard and pause models are pretrained on the C4 English mixture (Raffel et al., 2020) using the causal next-token prediction objective. For pause-pretraining, <pause> tokens are inserted at 10% of the sequence length. The only architectural change is the single <pause> token embedding, adding a negligible number of new parameters.
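The pause-injection step itself is simple; the sketch below inserts <pause> tokens at uniformly random positions amounting to 10% of the sequence length. The exact sampling and rounding details of the paper's data pipeline are not specified here, so this is only an illustration.

```python
import random

def inject_pauses(token_ids, pause_id, fraction=0.10, seed=None):
    """Insert pause_id at uniformly random positions, ~fraction of the length."""
    rng = random.Random(seed)
    n_pauses = int(len(token_ids) * fraction)
    out = list(token_ids)
    for _ in range(n_pauses):
        out.insert(rng.randrange(len(out) + 1), pause_id)
    return out

# Example: a 2048-token pretraining sequence gains roughly 205 <pause> tokens.
```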
- There is no claim that <pause> tokens are helpful for all downstream tasks.
- The most pressing next step is to find ways to make delays helpful directly on a standard-pretrained model, since the gains currently require <pause> tokens to be introduced at pretraining time.