
Is there a parallel between tiled GPU/TPU kernels and Cubed chunks? #490

Open
alxmrs opened this issue Jun 25, 2024 · 7 comments

Comments

@alxmrs
Contributor

alxmrs commented Jun 25, 2024

Tile-based operations have been quite successful for creating optimal GPU kernels. The programming model, in my understanding, offers flexibility while taking advantage of cache hierarchies.

http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf

The Triton language takes advantage of this model by providing a sort of MLIR/LLVM middleware for custom kernel acceleration of specific NN ops. JAX now even offers its own portable version of kernel control with tile semantics via Pallas.

https://jax.readthedocs.io/en/latest/pallas/index.html

I can’t help but think that there are parallels between Cubed’s chunked blockwise op and these tile-based techniques. What could an intersection look like?

  • Maybe, as is, business logic written in Cubed already has affordances for GPU/TPU lowering
  • If not, how can we make that so?
  • More diabolical still, could Cubed do this for users automatically when accelerated arrays are used (Jax integration #304)? How similar are tiles to chunks, anyway? The array-aware abstractions of Cubed, to me, seem to offer enough information to enable compute optimizations. Where this is limited, I suspect modifications to Spec could make the difference.
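To make the parallel concrete, here is a minimal sketch of a chunked blockwise op in the spirit of Cubed's primitive (the function name and chunking scheme are illustrative, not Cubed's actual API). Each per-chunk call is independent, which is exactly the kind of unit a GPU/TPU compiler could lower to a tiled kernel:

```python
import numpy as np

def blockwise(func, x, chunk=2):
    """Apply func independently to fixed-size chunks of a 1-D array.

    Illustrative stand-in for a chunked blockwise primitive; each chunk
    could in principle be dispatched to an accelerator kernel.
    """
    pieces = [x[i:i + chunk] for i in range(0, len(x), chunk)]
    return np.concatenate([func(p) for p in pieces])

print(blockwise(lambda c: c * 2, np.arange(6)))
```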
@alxmrs
Contributor Author

alxmrs commented Jun 25, 2024

I believe that Cubed chunks are “macro tiles” within the tile hierarchy.

(attached image: IMG_5639)

@tomwhite
Member

Very interesting — thanks for the pointers, Alex!

I believe that Cubed chunks are “macro tiles” within the tile hierarchy.

To be clear, do you mean that a Cubed chunk would be composed of multiple Triton tiles?

@alxmrs
Contributor Author

alxmrs commented Jun 28, 2024

I was taking a bit of poetic license. :)

I think that Cubed chunks, which live in RAM in userspace, could be considered part of the overall memory hierarchy for accelerated computation in the Triton model. I do think that they provide natural affordances for efficient kernel construction that could be automated by JAX via Pallas or Triton.
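The "macro tile" idea above can be sketched in plain NumPy: a chunk sitting in host RAM is processed as a grid of smaller tiles, the way a Triton or Pallas kernel would carve an array into register/shared-memory-sized blocks. The sizes and the per-tile body here are illustrative only:

```python
import numpy as np

def tiled_square(chunk, tile=4):
    """Square every element of a 2-D chunk, one micro tile at a time.

    Sketch of the tile hierarchy analogy: the outer loops walk the
    chunk (the "macro tile") as a grid of micro tiles; the body of the
    inner loop plays the role of the kernel operating on one tile.
    """
    out = np.empty_like(chunk)
    n, m = chunk.shape
    for i in range(0, n, tile):          # loop over tile rows
        for j in range(0, m, tile):      # loop over tile columns
            blk = chunk[i:i + tile, j:j + tile]
            out[i:i + tile, j:j + tile] = blk * blk  # per-tile kernel body
    return out
```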

@TomNicholas
Member

@rbavery looked at this idea with me and also seemed to think it had merit. I think he was interested in the possibility of using Cubed with PyTorch (which also has an array-like class).

@rbavery
Contributor

rbavery commented Jul 15, 2024

Maybe, as is, business logic written in cubed would have affordances for GPU/TPU lowering

Yes, it'd be cool if Cubed could understand PyTorch tensors, what device they are on, and therefore where to run operations (GPU, CPU, TPU). I'm not very familiar with the internals of how PyTorch handles lowering operations to different accelerators, but I think they are getting more interoperable with Triton: TorchInductor (used by torch.compile) compiles PyTorch code to Triton kernels.
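One rough sketch of what device-aware dispatch could look like, assuming chunks expose the Array API's `__array_namespace__` protocol (the function names here are illustrative, not Cubed's or PyTorch's actual API):

```python
import numpy as np

def infer_namespace(chunk):
    """Discover which array library a chunk belongs to.

    The Array API standard's __array_namespace__ lets a library find
    the backend (NumPy, PyTorch, JAX, ...) of an array without
    importing it; we fall back to NumPy for plain ndarrays.
    """
    if hasattr(chunk, "__array_namespace__"):
        return chunk.__array_namespace__()
    return np

def blockwise_add(chunks_a, chunks_b):
    """Add chunk pairs using each chunk's own namespace.

    Because the op is resolved per chunk, the compute would run on
    whatever device the chunk already lives on.
    """
    out = []
    for a, b in zip(chunks_a, chunks_b):
        xp = infer_namespace(a)
        out.append(xp.add(a, b))
    return out
```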

The docs on PyTorch internals are somewhat scattered; I've found this podcast to be most helpful at explaining them: https://pytorch-dev-podcast.simplecast.com/episodes

@alxmrs
Contributor Author

alxmrs commented Jul 18, 2024

Ryan, I've been really enjoying the PyTorch dev podcast — thanks for sharing it. I'm a few episodes in and have some initial thoughts:

  • Given that CUDA (and GPU programming generally) is asynchronous, I wonder if there will be a performance advantage to using Zarr v3 / async Python.
  • Cubed needs to consider the tradeoff between performance (avoiding GPU synchronization) and debuggability (introducing syncs to know which operation caused an error).
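The first point can be sketched with stdlib asyncio: overlap chunk reads with compute, the way an async Zarr v3 store would allow. `fetch_chunk` and `process` are stand-ins here, not real Zarr calls:

```python
import asyncio

async def fetch_chunk(i):
    """Stand-in for an async store read of chunk i."""
    await asyncio.sleep(0)  # yields control, as a real network read would
    return list(range(i, i + 4))

def process(chunk):
    """Stand-in for per-chunk compute (e.g. a reduction)."""
    return sum(chunk)

async def run(n):
    # All chunk fetches are issued concurrently; compute happens once
    # each chunk arrives. With synchronous I/O these reads would serialize.
    chunks = await asyncio.gather(*(fetch_chunk(i) for i in range(n)))
    return [process(c) for c in chunks]

print(asyncio.run(run(3)))  # → [6, 10, 14]
```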

@alxmrs
Contributor Author

alxmrs commented Jul 24, 2024

I think this AOT compilation route is a good first step for exploring lowering with JAX: https://jax.readthedocs.io/en/latest/aot.html
