
Support for a mask during autoregressive generation with Key-Value Caching #292

Open
Oufattole opened this issue Nov 12, 2024 · 10 comments

@Oufattole

Why isn't a mask supported when key-value caching is enabled here?

@lucidrains
Owner

@Oufattole in what scenario would you need masking when doing autoregressive decoding?

@Oufattole
Author

Oufattole commented Nov 12, 2024

I'm trying to do sliding window inference, but the initial prompts I feed the transformer have different lengths, so I think I should mask out the padding, as we do during autoregressive pretraining.

I'm applying transformers to medical trajectories as part of this open-source project, which provides ML tooling for modeling patient time-series data (you tokenize a patient's irregularly sampled time-series observations, such as medications, diagnoses, and procedures). I'm interested in generating future trajectories and evaluating them. Here is the relevant code I am currently using to generate trajectories. For now I simply don't cache key-value pairs, so that I can apply masks, but that is prohibitively slow.
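A minimal sketch of the setup described above, in plain Python so it runs without dependencies: prompts of different lengths are left-padded so that every row's last real token sits at the final position, and a boolean mask marks which positions are real tokens versus padding that attention should ignore. The helper name `left_pad` and the pad id `0` are hypothetical; a real pipeline would build tensors instead of lists.

```python
from typing import List, Tuple

PAD = 0  # hypothetical padding token id


def left_pad(prompts: List[List[int]], pad: int = PAD) -> Tuple[List[List[int]], List[List[bool]]]:
    """Left-pad variable-length prompts to a common length.

    Returns the padded batch and a mask where True marks a real token
    and False marks padding that attention should not attend to.
    """
    max_len = max(len(p) for p in prompts)
    padded, mask = [], []
    for p in prompts:
        n_pad = max_len - len(p)
        padded.append([pad] * n_pad + list(p))   # padding goes on the left
        mask.append([False] * n_pad + [True] * len(p))
    return padded, mask


batch, mask = left_pad([[5, 6, 7], [8]])
print(batch)  # [[5, 6, 7], [0, 0, 8]]
print(mask)   # [[True, True, True], [False, False, True]]
```

Left padding is the usual choice for generation because the next token is then appended at the same index for every row in the batch.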

@lucidrains
Owner

lucidrains commented Nov 12, 2024

@Oufattole yes I see, so you are off the beaten path

sliding window isn't supported here yet

@lucidrains
Owner

@Oufattole you can do away with masking by slicing the cached key values before passing them back in
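A minimal sketch of slicing a key-value cache, assuming a hypothetical layout of one `(keys, values)` pair per layer, each a list of per-position vectors ordered oldest to newest. With real tensors the same operation is a slice along the sequence dimension of each cached tensor; which dimension that is depends on the library's cache layout, so the list layout and the name `trim_kv_cache` here are assumptions, not the library's API.

```python
from typing import List, Tuple

Vector = List[float]
# Hypothetical cache layout: one (keys, values) pair per layer,
# each a list of per-position vectors ordered oldest -> newest.
LayerCache = Tuple[List[Vector], List[Vector]]


def trim_kv_cache(cache: List[LayerCache], keep: int) -> List[LayerCache]:
    """Keep only the most recent `keep` positions in every layer's cache."""
    if keep <= 0:
        # k[-0:] would return the whole list, so handle zero explicitly
        return [([], []) for _ in cache]
    return [(k[-keep:], v[-keep:]) for k, v in cache]


# Two layers, four cached positions each; keys hold i, values hold -i.
cache = [([[float(i)] for i in range(4)], [[float(-i)] for i in range(4)]) for _ in range(2)]
trimmed = trim_kv_cache(cache, 2)
print(trimmed[0][0])  # [[2.0], [3.0]]
```

Positions that are sliced away can no longer be attended to, which is why slicing can stand in for a mask over those positions.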

@Oufattole
Author

Ahhh I see, thank you, I'll try that! With medical data, unlike in NLP and CV, many patient trajectories are very short and you don't need a long sequence length at all. For example, in my dataset 80% of patients are below the 512 max sequence length, but a small subset exceed 30k (and this is after aggressively reducing the vocabulary, i.e. which time-series variables we model; before that reduction, some of these patients exceeded 300k).

I am naively trying sliding windows, but if you'd recommend a better approach for handling such extreme variation in sequence length, I'd be happy to try it.

@Oufattole
Author

Wait, actually, I think you do support masking the left-padded tokens with the seq_start_pos arg here, @lucidrains.
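For a left-padded batch, a per-row start index carries the same information as a boolean padding mask: every position at or after the start is a real token. The sketch below shows that equivalence in plain Python. The helper names are hypothetical, and whether seq_start_pos follows exactly this convention is an assumption to check against the library source.

```python
from typing import List


def start_positions(lengths: List[int]) -> List[int]:
    """For a left-padded batch, the index of each row's first real token."""
    max_len = max(lengths)
    return [max_len - n for n in lengths]


def mask_from_start(starts: List[int], seq_len: int) -> List[List[bool]]:
    """Rebuild the padding mask: positions >= start are real tokens."""
    return [[pos >= s for pos in range(seq_len)] for s in starts]


starts = start_positions([3, 1])
print(starts)                      # [0, 2]
print(mask_from_start(starts, 3))  # [[True, True, True], [False, False, True]]
```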

@lucidrains
Owner

lucidrains commented Nov 13, 2024

@Oufattole so that hyperparameter was actually built for variable prompt lengths iirc. i'll have to take a closer look to really know if it can be repurposed for what you are doing

during sliding window decoding, you'll have to slice the cached key values once you decode past the window length
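A toy sketch of that decode loop, assuming a cache with one entry per position (here just the token id; a real cache holds per-layer key/value tensors) and a hypothetical `next_token` callable standing in for one model step that attends only to the cached positions. After each new token is appended, the cache is sliced back to the window length, so it never grows beyond `window`. This illustrates the general idea and is not the library's generation code.

```python
from typing import Callable, List, Tuple


def sliding_window_decode(
    prompt: List[int],
    steps: int,
    window: int,
    next_token: Callable[[List[int]], int],
) -> Tuple[List[int], List[int]]:
    """Decode `steps` tokens while keeping at most `window` cached positions."""
    assert window > 0
    tokens = list(prompt)
    cache = tokens[-window:]  # prefill, cropped to the window (a new list)
    for _ in range(steps):
        tok = next_token(cache)  # hypothetical model step over the cache only
        tokens.append(tok)
        cache.append(tok)  # the new position's K/V enter the cache
        if len(cache) > window:
            cache = cache[-window:]  # slice off positions that slid out of the window
    return tokens, cache


# Toy "model": the next token is one more than the largest token it can still see.
tokens, cache = sliding_window_decode([1, 2, 3], steps=4, window=3, next_token=lambda ctx: max(ctx) + 1)
print(tokens)  # [1, 2, 3, 4, 5, 6, 7]
print(cache)   # [5, 6, 7]
```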

@lucidrains
Owner

@Oufattole what specialty is this and what exactly are you trending in the EMR that hits 300k in length?

@Oufattole
Author

Oufattole commented Nov 13, 2024

Yes, I think you already do this kv-cache slicing during generation here when restricting to the max_seq_length (i.e. in the sliding window setting). Am I correct about this?

I'll send you an email about the broader EHR modeling question, which I realize may be out of scope for this GitHub issue.

@lucidrains
Owner

@Oufattole it has been a while, let me review it tomorrow morning and see if it can be made to work for your issue

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants