Support for a mask during autoregressive generation with Key-Value Caching #292

Why isn't a mask supported when key-value caching is enabled here?

Comments
@Oufattole in what scenario would you need masking when doing autoregressive decoding?
I'm trying to do sliding window inference, but the lengths of the initial prompts differ, so I think I should mask out the padding, as we do during autoregressive pretraining. I'm applying transformers to medical trajectories as part of this open source project, which provides ML tooling for modeling patient time-series data (a patient's irregularly sampled observations, such as medications, diagnoses, and procedures, are tokenized). I'm interested in generating future trajectories and evaluating them. Here is the relevant code I am currently using to generate trajectories. For now I simply don't cache key-value pairs so that I can apply masks, but that is prohibitively slow.
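For readers following along, here is a minimal sketch of what that no-cache masked decoding loop looks like. It is not the project's actual code: the model interface (a forward pass accepting a boolean `mask` of shape `(batch, seq)` that is True on real tokens), left-padded prompts, and greedy decoding are all assumptions.

```python
import torch

@torch.no_grad()
def generate_without_cache(model, prompts, prompt_lens, gen_len):
    # prompts: (batch, seq) left-padded token ids; prompt_lens: (batch,) real-token counts
    seq = prompts
    positions = torch.arange(seq.shape[1], device=seq.device)[None, :]
    mask = positions >= (seq.shape[1] - prompt_lens[:, None])  # False on the left padding

    for _ in range(gen_len):
        logits = model(seq, mask=mask)                            # full forward pass every step (slow)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy pick, for brevity
        seq = torch.cat([seq, next_token], dim=-1)
        # generated tokens are always real, so the mask grows with True entries
        mask = torch.cat([mask, torch.ones_like(next_token, dtype=torch.bool)], dim=-1)

    return seq
```

Recomputing attention over the whole sequence at every step is what makes this prohibitively slow compared to KV caching.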
@Oufattole yes I see, so you are off the beaten path; sliding windows aren't supported here yet
@Oufattole you can do away with masking by slicing the cached key/values before passing them back in
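A hedged sketch of that cache-slicing idea: it assumes the KV cache can be read and rebuilt as a list of per-layer `(key, value)` tensors shaped `(batch, heads, cached_seq, dim_head)`; the real cache object in the library may be structured differently.

```python
from typing import List, Tuple
import torch

def slice_kv_cache(cache: List[Tuple[torch.Tensor, torch.Tensor]],
                   window: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """Keep only the most recent `window` positions in every layer's cached keys/values."""
    return [(k[..., -window:, :], v[..., -window:, :]) for k, v in cache]
```

Applying something like this after each decoding step keeps the cache at the window length, so stale or padded positions are simply dropped rather than masked.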
Ahhh I see, thank you, I'll try that! With medical data, unlike in NLP and CV, many patient trajectories are very short and you don't need a long sequence length at all. For example, in my dataset 80% of patients fall below the 512 max sequence length, but a small subset exceed 30k tokens (and that is after extreme reductions in the vocabulary, i.e. which time-series variables we model, prior to which some of these patients hit over 300k). I am naively trying to use sliding windows, but if there is a better approach you recommend for handling such extreme sequence-length variation, I would be happy to try it.
Wait, actually, I think you do support masking the left-padded tokens with the seq_start_pos arg here, @lucidrains.
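A hedged sketch of how `seq_start_pos` could be supplied for left-padded prompts of varying length; `transformer` is assumed to be an x-transformers `TransformerWrapper`, and whether the generation wrapper forwards this argument should be verified against the source.

```python
import torch

pad_id = 0
prompts = torch.tensor([
    [pad_id, pad_id, pad_id, pad_id, 11, 12, 13],   # 3 real tokens
    [21, 22, 23, 24, 25, 26, 27],                   # 7 real tokens
    [pad_id, pad_id, 31, 32, 33, 34, 35],           # 5 real tokens
])
prompt_lens = torch.tensor([3, 7, 5])               # real tokens per prompt
padded_len = prompts.shape[1]
seq_start_pos = padded_len - prompt_lens            # index of each sequence's first real token

# `transformer` is assumed to exist; it would be constructed elsewhere
logits = transformer(prompts, seq_start_pos=seq_start_pos)
```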
@Oufattole so that hyperparameter was actually built for variable prompt lengths iirc. i'll have to take a closer look to really know if it can be repurposed for what you are doing; during sliding window decoding, you'll have to slice the cached key values once you decode past the window length
@Oufattole what specialty is this and what exactly are you trending in the EMR that hits 300k in length?
Yes, I think you already do this kv-cache slicing during generation here when restricting to the max_seq_length (i.e. in the sliding window setting). Am I correct about this? I'll send you an email regarding the broader EHR modeling question, which I realize may be out of scope for this GitHub issue.
@Oufattole it has been a while, let me review it tomorrow morning and see if it can be made to work for your issue