
What does Local Decoder return? #30

Open
QuanHoangDanh opened this issue Jan 20, 2025 · 16 comments

@QuanHoangDanh

I'm not sure whether the Local Decoder predicts the next byte or the next patch (consisting of multiple bytes)?

@Vectorrent
Contributor

Vectorrent commented Jan 20, 2025

I feel like this picture makes it clear:

[Image]

The model predicts bytes, not patches. If you look at the byte-level tokenizer, you'll see the same.

And yes, that means your inputs will be huge, compared to the standard subword-token transformer. That is what the downsampling is for.

@Spidartist

Thank you for your response (this is my other account). So BLT outputs just one byte per iteration? If so, generating a paragraph would take far more iterations than a BPE-based LLM, right?

@Vectorrent
Contributor

> Thank you for your response (this is my other account). So BLT outputs just one byte per iteration? If so, generating a paragraph would take far more iterations than a BPE-based LLM, right?

Yes, a single character can take anywhere from 1 to 4 bytes to produce. So, 1 to 4 forward passes to predict a single character. Rather than sub-word tokenization, it's a bit like sub-character tokenization.

However, some of that complexity is reclaimed, because the downsampling (patching) reduces the sequence length through the intermediate (latent/global) transformer layers. Forward passes are generally faster compared to a subword-based model.
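To make the 1-to-4-bytes point concrete, here is a minimal, self-contained Python check of UTF-8 byte counts (illustrative only, not code from this repo):

```python
# UTF-8 uses 1 to 4 bytes per character; each byte is a separate
# prediction step for a byte-level decoder.
for ch in ["a", "é", "漢", "🙂"]:
    print(ch, "->", len(ch.encode("utf-8")), "byte(s)")
# a -> 1 byte(s), é -> 2 byte(s), 漢 -> 3 byte(s), 🙂 -> 4 byte(s)
```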

@Spidartist

> Thank you for your response (this is my other account). So BLT outputs just one byte per iteration? If so, generating a paragraph would take far more iterations than a BPE-based LLM, right?
>
> Yes, a single character can take anywhere from 1 to 4 bytes to produce. So, 1 to 4 forward passes to predict a single character. Rather than sub-word tokenization, it's a bit like sub-character tokenization.
>
> However, some of that complexity is reclaimed, because the downsampling (patching) reduces the sequence length through the intermediate (latent/global) transformer layers. Forward passes are generally faster compared to a subword-based model.

Thank you for a lot of useful information, but I doubt that each forward pass is really ~3 times faster than a forward pass of a BPE-based LLM.

@QuanHoangDanh
Author

I think each forward pass generates just one byte, which would make BLT slow in practice. They use FlexAttention instead of vanilla attention to speed up training and inference. Could you clarify this for me? @EntilZha

@Jeoyal

Jeoyal commented Jan 21, 2025

> I think each forward pass generates just one byte, which would make BLT slow in practice. They use FlexAttention instead of vanilla attention to speed up training and inference. Could you clarify this for me? @EntilZha

Both the Local Encoder and Decoder are lightweight Transformers. They are responsible solely for the task of converting bytes to patches and patches to bytes.
[Image]

@QuanHoangDanh
Author

@Jeoyal Thank you for the information. But my point is that I don't believe a lightweight Transformer will process more than 3 times faster than a standard Transformer, given that each token (a sub-word spanning many characters) consists of many bytes (sub-characters). The paper also does not include a throughput comparison between BLT and Llama, so I think that, for a text of 100 characters for example, BLT will have a slower generation speed than Llama.

@Jeoyal

Jeoyal commented Jan 21, 2025

> @Jeoyal Thank you for the information. But my point is that I don't believe a lightweight Transformer will process more than 3 times faster than a standard Transformer, given that each token (a sub-word spanning many characters) consists of many bytes (sub-characters). The paper also does not include a throughput comparison between BLT and Llama, so I think that, for a text of 100 characters for example, BLT will have a slower generation speed than Llama.

@QuanHoangDanh Sorry, I might have had a slight misunderstanding earlier. However, in reality, the input to the transformer layers of the Local Encoder/Decoder is the entire sequence, not just the inputs related to a single patch, right?

@QuanHoangDanh
Author

QuanHoangDanh commented Jan 21, 2025

@Jeoyal I think so. The local encoder takes a byte sequence as input and outputs the hidden states of the byte sequence plus the patch embeddings. The local decoder takes the hidden states from the local encoder and the patch embeddings obtained after passing through the latent transformer, and outputs the predicted byte sequence. In practice, during training it can run in parallel (a characteristic of transformers), but since the output is at the byte level, I think it will need to run more times than a typical LLM. What do you think about this, @Vectorrent?
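A rough pseudocode sketch of the data flow described above (hypothetical function names, not the actual bytelatent API):

```python
# Hypothetical sketch of the BLT forward pass described above (not this repo's real API).
def blt_forward(byte_ids, local_encoder, latent_transformer, local_decoder):
    # 1. Local encoder: byte sequence -> per-byte hidden states + patch embeddings.
    byte_hidden, patch_embeddings = local_encoder(byte_ids)

    # 2. Latent/global transformer: runs over the much shorter patch sequence.
    patch_hidden = latent_transformer(patch_embeddings)

    # 3. Local decoder: combines byte-level hidden states with patch-level context
    #    and predicts the next byte at every position.
    next_byte_logits = local_decoder(byte_hidden, patch_hidden)
    return next_byte_logits
```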

@Jeoyal

Jeoyal commented Jan 21, 2025

@QuanHoangDanh What do you mean by "during training it can run in parallel (a characteristic of transformers)"?

@QuanHoangDanh
Author

@Jeoyal Sorry, I don't quite understand what you mean. By "running in parallel" I mean that a transformer can be trained in parallel, unlike an RNN. Could you state your question more clearly?

@Jeoyal

Jeoyal commented Jan 21, 2025

> @Jeoyal Sorry, I don't quite understand what you mean. By "running in parallel" I mean that a transformer can be trained in parallel, unlike an RNN. Could you state your question more clearly?

Sorry, I misunderstood. Thank you for your explanation.

@Vectorrent
Contributor

Vectorrent commented Jan 21, 2025

> @Jeoyal I think so. The local encoder takes a byte sequence as input and outputs the hidden states of the byte sequence plus the patch embeddings. The local decoder takes the hidden states from the local encoder and the patch embeddings obtained after passing through the latent transformer, and outputs the predicted byte sequence. In practice, during training it can run in parallel (a characteristic of transformers), but since the output is at the byte level, I think it will need to run more times than a typical LLM. What do you think about this, @Vectorrent?

Yes, it will need to run more times. A standard transformer will output tokens; the BLT will output bytes. You will need anywhere from 1-4 bytes (each one requiring a single forward pass) to produce a single character.

@EntilZha
Contributor

Hi all, thanks for all the interest and questions! Hopefully I can help with a few things, although @Vectorrent seems to have helped with a few questions already, thanks!

To answer the original question, @Jeoyal is correct that the local encoder/decoder translate bytes to/from patches. The global/latent transformer is responsible for processing these patch representations.

Re model efficiency @Spidartist, in our paper/experiments we focused on FLOP-matched experiments. As mentioned, since our model operates on bytes, this requires more forward passes, but those are mainly passes through the lightweight local encoder/decoder, which don't cost that many FLOPs. We gain the FLOP advantage by calling the large latent/global transformer less frequently, since that is what dominates in terms of FLOP usage. A regular LM only has one type of forward pass, a big one; in BLT, we have the option to take a small forward pass or a large one.

In terms of actual runtime instead of FLOPs, that will heavily depend on the implementation of the transformer, e.g., attention layers. I'd also note that the runtime speed has different issues between training and inference. During training, you can run all parts of the model as normal and cuda/gpu/torch will handle pipelining things in parallel. Our main focus while working on the paper was improving training speed (specifically, words per second) to be able to complete more experiments faster.

We spent less time on optimizing inference, and it's one area we're looking to improve. As with normal LMs, you do need to generate the next byte and then feed it back into the model. I think what will ultimately determine how fast inference ends up being is how the overhead of running more generation steps compares against the individual steps being smaller on average (certainly smaller for the local models). There is also still a lot to be improved here, e.g., adding caches.

@QuanHoangDanh, this will greatly depend on the patch size you're targeting. E.g., if you modify the threshold such that patches are ~20 characters long, I'd guess that this would have faster inference, provided the engineering in each implementation is similar.

Hopefully that helps!
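To illustrate the "generate the next byte and feed it back" loop mentioned above, here is a toy greedy-decoding sketch (the dummy model and function names are hypothetical, not the repo's generation code):

```python
import random

# Toy byte-level greedy decoding: one full forward pass per generated byte.
# `dummy_model` stands in for the whole BLT stack and just returns random logits.
def dummy_model(byte_ids):
    return [[random.random() for _ in range(256)] for _ in byte_ids]

def generate_bytes(prompt: bytes, n_new_bytes: int, model=dummy_model) -> bytes:
    byte_ids = list(prompt)
    for _ in range(n_new_bytes):
        logits = model(byte_ids)                                  # one forward pass per new byte
        next_byte = max(range(256), key=lambda b: logits[-1][b])  # greedy argmax over 256 byte values
        byte_ids.append(next_byte)
    return bytes(byte_ids)

print(generate_bytes(b"Hello, ", 8))
```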

@QuanHoangDanh
Author

QuanHoangDanh commented Jan 23, 2025

@EntilZha Many thanks for the detailed answer. I think I understand what you mean in the sentence below, but let me restate it to check that I understand correctly.

> We gain the FLOP advantage by calling the large latent/global transformer less frequently, since that is what dominates in terms of FLOP usage. A regular LM only has one type of forward pass, a big one; in BLT, we have the option to take a small forward pass or a large one.

With normal LMs, the token is the smallest element, and every token must pass through the LLM. With BLT, assuming the average patch size is p and treating the byte as the smallest element, bytes only pass through the Latent Transformer once per patch; equivalently, the (p-1) bytes before the last byte in a patch do not (implicitly) pass through the Latent Transformer. Do I understand it correctly?

@EntilZha
Contributor

I don't 100% follow, but I think you have the right idea. Suppose you have a patch p of k bytes. The local model will run for each one, so for a given patch, the local model is called k times.

For each patch, the global/latent model is called once. So if you have data of length l in bytes where the average patch length is n bytes, then the compute used across the data is l/n * FLOPs per global pass + l * FLOPs per local pass. Or in other words, the FLOP cost is (number of patches in the training data * FLOPs per global pass) + (number of bytes in the training data * FLOPs per local pass).
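As a worked example of that cost model (the per-pass FLOP numbers below are made up purely for illustration):

```python
# Illustrative FLOP accounting for the formula above; all numbers are hypothetical.
l = 1_000_000            # training data length in bytes
n = 4.5                  # average patch length in bytes
flops_global = 1e9       # FLOPs per global/latent forward pass (made up)
flops_local = 5e7        # FLOPs per local encoder/decoder pass (made up)

blt_flops = (l / n) * flops_global + l * flops_local
print(f"BLT: {blt_flops:.2e} FLOPs over the data")

# Regular LM for comparison, assuming ~4 bytes per BPE token (also made up).
bytes_per_token = 4.0
lm_flops = (l / bytes_per_token) * flops_global
print(f"Token LM: {lm_flops:.2e} FLOPs over the data")
```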

One detail, if memory serves correctly, is that the global model receives both bytes and patches. I.e., if the text "The world" is one patch, the global model would receive both the per-byte position representations (i.e., the hash embeddings) and the patch representation for that text.

In a regular LM, you'd have (number of tokens in the training data * FLOPs per LLM pass). You could calculate the average number of bytes per token in the training data to convert this to a FLOP cost per byte. There are two things to note here: one is that as the patch size grows, BLT can see more bytes of training data for the same FLOP cost. The second is that it's pretty hard to change the average token size for regular LMs (you have to retrain BPE, which, due to how it works, tends to prefer smaller tokens), whereas with BLT you can do that by changing the entropy threshold. Hope that helps!
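To estimate the average bytes per token for a regular BPE tokenizer (as a point of comparison), something like the following works, assuming the tiktoken package is installed; any BPE vocabulary gives a similar ballpark:

```python
import tiktoken  # pip install tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "The Byte Latent Transformer groups bytes into patches based on entropy."
n_bytes = len(text.encode("utf-8"))
n_tokens = len(enc.encode(text))
print(f"{n_bytes / n_tokens:.2f} bytes per token on this sample")
```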
