
Improve performance of tt.load and tt.store for FP8 when converting block ptr to regular ptrs #2374

Closed
etiotto opened this issue Sep 27, 2024 · 2 comments · Fixed by #2502, #2514 or #2534

etiotto commented Sep 27, 2024

We would like to remove the RewriteTensorPointer pass, which rewrites block pointers into regular pointers (except when it determines that load/store operations on block pointers can be converted to 2D block reads/writes). The idea is to avoid losing semantic information too early and instead handle block pointers that cannot be used to generate 2D block reads/writes when the load/store operation is lowered.

For this scheme to work, we first need to improve the lowering code for tt.load and tt.store operations that use a block pointer with an element type that is not (currently) supported by the 2D block read instructions available on the target GPU (e.g. FP8).
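
For reference, the sketch below (kernel name, launch shapes, and the `xpu` device string are illustrative assumptions, not taken from any tutorial) shows the kind of tt.load/tt.store through a block pointer with an FP8 element type that the improved lowering has to handle when 2D block reads/writes are not available:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_tile_kernel(in_ptr, out_ptr, M, N, stride_m, stride_n,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointers: these become tt.make_tensor_ptr plus tt.load/tt.store
    # on block pointers in the Triton IR.
    src = tl.make_block_ptr(base=in_ptr, shape=(M, N),
                            strides=(stride_m, stride_n),
                            offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    dst = tl.make_block_ptr(base=out_ptr, shape=(M, N),
                            strides=(stride_m, stride_n),
                            offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tile = tl.load(src, boundary_check=(0, 1))   # tt.load on a block ptr
    tl.store(dst, tile, boundary_check=(0, 1))   # tt.store on a block ptr


# With an FP8 element type (requires a PyTorch build with float8 support) the
# 2D block read path may not apply, so the lowering falls back to regular
# loads/stores; making that fallback efficient is the point of this issue.
x = torch.randn(1024, 1024, device="xpu").to(torch.float8_e5m2)
y = torch.empty_like(x)
grid = (1024 // 64, 1024 // 64)
copy_tile_kernel[grid](x, y, 1024, 1024, x.stride(0), x.stride(1),
                       BLOCK_M=64, BLOCK_N=64)
```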

See #2359 (comment) for more context.

etiotto commented Oct 9, 2024

The first step is to improve the axis analysis and add support for block pointers to it (#2451).
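
As a rough illustration only, the snippet below is a deliberately simplified, hypothetical model (plain Python, not part of Triton) of the per-dimension facts the axis analysis would need to derive from a block pointer's operands in order to drive coalescing; the real analysis is more involved:

```python
# Hypothetical helper, not part of Triton: a toy model of the per-dimension
# facts (contiguity, divisibility) an axis analysis could derive from the
# operands of a block pointer.
def block_ptr_axis_info(block_shape, strides, base_alignment_bytes=16, elem_bytes=1):
    facts = []
    for dim, (size, stride) in enumerate(zip(block_shape, strides)):
        unit_stride = stride == 1
        # A unit-stride dimension is contiguous across the whole block tile.
        contiguity = size if unit_stride else 1
        # Divisibility in elements: from the base alignment on the unit-stride
        # dimension, from the stride itself otherwise (a simplification).
        divisibility = base_alignment_bytes // elem_bytes if unit_stride else stride
        facts.append({"dim": dim, "contiguity": contiguity,
                      "divisibility": divisibility})
    return facts


# A 64x64 FP8 tile with row-major strides (1024, 1): the last dimension is
# contiguous, which is what allows the loads to be coalesced.
print(block_ptr_axis_info(block_shape=(64, 64), strides=(1024, 1)))
```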

etiotto commented Oct 21, 2024

A reduced test derived from tutorial 06 now performs better when we coalesce block pointers directly than when we rewrite them to regular pointers and then coalesce:

Reduced attn test:

(screenshot of the reduced attn test omitted)

Rewrite block pointers to regular pointers and then coalesce:

create kernel:_attn_fwd
fused-attention-batch4-head32-d64-fwd-causal=True:
    N_CTX  Triton [FP8]
0  1024.0     126.31971

Avoid the rewrite and coalesce block pointers directly:

create kernel:_attn_fwd
fused-attention-batch4-head32-d64-fwd-causal=True:
    N_CTX  Triton [FP8]
0  1024.0    135.596459

The performance of tutorial 06 (unmodified) is still not up to par: the axis info analysis cannot yet detect contiguity for all block pointers in the kernel, so some of them are not coalesced.
