What does Local Decoder return? #30
I feel like this picture makes it clear: the model predicts bytes, not patches. If you look at the byte-level tokenizer, you'll see the same. And yes, that means your inputs will be huge, compared to the standard subword-token transformer. That is what the downsampling is for.
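To make the size difference concrete, here is a rough, self-contained Python sketch (my own example, not the repo's tokenizer; the whitespace split is just a crude stand-in for BPE):

```python
# Byte-level inputs are much longer than subword-token inputs for the same text.
text = "Byte Latent Transformers patch byte sequences dynamically."

byte_ids = list(text.encode("utf-8"))  # one input position per UTF-8 byte
crude_tokens = text.split()            # crude stand-in for a BPE tokenizer

print(len(byte_ids))     # 58 byte-level positions
print(len(crude_tokens)) # 7 whitespace "tokens"
```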
Thank you for your response (this is my other account). That means BLT outputs just one byte per iteration. If so, then generating a paragraph would take a ton of iterations compared with a BPE-based LLM, right?
Yes, a single character can take anywhere from 1 to 4 bytes to produce. So, 1 to 4 forward passes, to predict a single character. Rather than sub-word tokenization, it's a bit like sub-character tokenization. However, some of that complexity is reclaimed, because the downsampling (patching) reduces the sequence length through the intermediate (latent/global) transformer layers. Forward passes are generally faster, compared to a subword-based model.
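For anyone unfamiliar with UTF-8, here is a minimal Python example (mine, not from the repo) showing why a single character corresponds to 1-4 byte predictions:

```python
# UTF-8 encodes characters with a variable number of bytes,
# so an autoregressive byte model needs 1-4 forward passes per character.
for ch in ["a", "é", "中", "🙂"]:
    encoded = ch.encode("utf-8")
    print(ch, len(encoded), list(encoded))
# a 1 [97]
# é 2 [195, 169]
# 中 3 [228, 184, 173]
# 🙂 4 [240, 159, 153, 130]
```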
Thank you for all the useful information, but I doubt whether each forward pass is ~3 times faster than each forward pass of a BPE-based LLM.
I think each forward pass generates just one byte, making this BLT slow in practice. They use FlexAttention instead of vanilla Attention to speed up training and inference. Could you clarify this for me? @EntilZha
Both the Local Encoder and Decoder are lightweight Transformers. They are responsible solely for the task of converting bytes to patches and patches to bytes.
@Jeoyal Thank you for the information. But my point is that I do not believe a lightweight Transformer will process >3 times faster than a standard Transformer, while each token (sub-word, made up of many characters) consists of many bytes (sub-characters). The paper also does not mention a throughput comparison between BLT and Llama, so I think, for example, with a text of 100 characters, BLT will have a slower generation speed than Llama.
@QuanHoangDanh Sorry, I might have had a slight misunderstanding earlier. However, in reality, the input to the transformer layers of the Local Encoder/Decoder is the entire sequence, not just patch-related inputs, right?
@Jeoyal I think so. The local encoder takes the byte sequence as input and outputs the hidden states of the byte sequence plus the patch embeddings. The local decoder takes the hidden states from the local encoder and the patch embeddings obtained after passing through the latent transformer, and outputs the predicted byte sequence. In practice, during training it can run in parallel (a characteristic of transformers), but since the output is at the byte level, I think it will need to run more times compared to a typical LLM. What do you think about this issue @Vectorrent ?
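To make that data flow easier to follow, here is a conceptual sketch (my own pseudocode, not the actual code in this repo; `local_encoder`, `latent_transformer`, and `local_decoder` are hypothetical stand-ins):

```python
# Conceptual data flow of one BLT forward pass, as described above.
def blt_forward(byte_ids, patch_lengths,
                local_encoder, latent_transformer, local_decoder):
    # 1. Lightweight local encoder: full byte sequence in,
    #    per-byte hidden states plus one embedding per patch out.
    byte_hidden, patch_embeddings = local_encoder(byte_ids, patch_lengths)

    # 2. Large latent/global transformer: operates on the (much shorter)
    #    patch sequence, i.e. len(patch_lengths) positions, not len(byte_ids).
    patch_hidden = latent_transformer(patch_embeddings)

    # 3. Lightweight local decoder: combines the byte hidden states with the
    #    processed patch representations and predicts the next byte.
    next_byte_logits = local_decoder(byte_hidden, patch_hidden)
    return next_byte_logits
```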
@QuanHoangDanh What do you mean by "during training, it can run in parallel (a characteristic of transformers)"?
@Jeoyal Sorry, I don't quite understand what you want to ask. What I mean by running in parallel is that the transformer can be trained in parallel, unlike an RNN model. Can you state your question more clearly?
Sorry, I misunderstood. Thank you for your explanation.
Yes, it will need to run more times. A standard transformer will output tokens; the BLT will output bytes. You will need anywhere from 1-4 bytes (each one requiring a single forward pass) to produce a single character.
Hi all, thanks for all the interest and questions! Hopefully I can help with a few things, although @Vectorrent seems to have helped with a few questions already, thanks!

To answer the original question, @Jeoyal is correct that the local encoder/decoder translate bytes to/from patches. The global/latent transformer is responsible for processing these patch representations.

Re model efficiency @Spidartist , in our paper/experiments, we focused on FLOP-matched experiments. As mentioned, since our model operates on bytes, this requires more forward passes, but those are mainly of the lightweight local encoder/decoder, which don't cost that many FLOPs. We gain the FLOP advantage by calling the large latent/global transformer less frequently, since that is what dominates FLOP usage. A regular LM only has one type of forward pass, a big one; in BLT, we have the option to take a small forward pass or a large one. In terms of actual runtime instead of FLOPs, that will heavily depend on the implementation of the transformer, e.g., the attention layers.

I'd also note that runtime speed has different issues between training and inference. During training, you can run all parts of the model as normal and cuda/gpu/torch will handle pipelining things in parallel. Our main focus while working on the paper was improving training speed (specifically, words per second) to be able to complete more experiments faster. We spent less time on optimizing inference, and it's one area we're looking to improve. As with normal LMs, you do need to generate the next byte and then feed it back into the model. I think what will end up determining how fast inference is, is how the overhead of running more generation steps compares to the individual steps being smaller on average (certainly smaller for the local models). There is also still a lot to be improved here, e.g., adding caches.

@QuanHoangDanh, this will greatly depend on the patch size you're targeting. E.g., if you modify the threshold such that patches are ~20 characters long, I'd guess that this would have faster inference, provided the engineering in each implementation is similar. Hopefully that helps!
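To restate the FLOP argument numerically, a back-of-the-envelope sketch (the costs and patch sizes below are made up purely for illustration, not measurements from the paper):

```python
# BLT pays a small local cost on every byte, but the expensive global
# transformer is only invoked roughly once per patch.
def flops_per_byte_blt(local_flops, global_flops, avg_patch_bytes):
    return local_flops + global_flops / avg_patch_bytes

def flops_per_byte_token_lm(global_flops, avg_token_bytes):
    # A token LM pays its single big forward pass once per token.
    return global_flops / avg_token_bytes

# Hypothetical costs: 1 GFLOP local pass, 50 GFLOP global pass.
print(flops_per_byte_blt(1e9, 50e9, avg_patch_bytes=8))   # ~7.25 GFLOPs per byte
print(flops_per_byte_token_lm(50e9, avg_token_bytes=4))   # ~12.5 GFLOPs per byte
```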
@EntilZha Many thanks for the detailed answer. I understand what you mean in the sentence below, but let me explain again to see if I understand correctly.
With normal LMs, a token is the smallest element and each token must pass through the LLM. But with BLT, assuming the average patch size is p and treating each byte as the smallest element, a byte only passes through the Latent Transformer as part of a patch; equivalently, the (p-1) bytes before the last byte in a patch do not individually pass through the Latent Transformer. Do I understand it correctly?
I don't 100% follow, but I think you have the right idea. For each patch, the global/latent model is called once, so the number of large forward passes depends on the number of patches rather than the number of bytes, while every byte still gets a lightweight local step. One detail, if my memory serves correctly, is that the global model receives both bytes and patches. In a regular LM, by contrast, you'd take the one big forward pass for every token.
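A toy numeric example of the call counts being discussed (the sizes are assumptions, not measurements):

```python
n_bytes = 100        # bytes in the generated text
avg_patch_size = 5   # average bytes per patch (depends on the patching threshold)

global_calls = n_bytes / avg_patch_size  # big latent/global transformer passes
local_steps = n_bytes                    # lightweight byte-level decode steps

print(global_calls, local_steps)  # 20.0 big passes vs. 100 small byte steps
```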
I don't know whether the Local Decoder predicts the next byte or the next patch (consisting of multiple bytes)?