Is there a parallel between tile GPU/TPU kernels and Cubed chunks? #490
Comments
Very interesting - thanks for the pointers, Alex!
To be clear, do you mean that a Cubed chunk would be composed of multiple Triton tiles?
I was taking a bit of poetic license. :) I think that Cubed chunks, which live in RAM in userspace, could be considered part of the overall memory hierarchy for accelerated computation in the Triton model. I do think that they provide natural affordances for efficient kernel construction that can be automated by JAX via Pallas or Triton.
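For concreteness, here is a minimal Pallas sketch of what such a per-chunk kernel might look like (the element-wise add is purely illustrative, and `interpret=True` is used so the sketch runs without a GPU/TPU):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs are tile-sized views that Pallas stages through fast on-chip memory.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # interpreter mode: runs on CPU for demonstration
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]
```

In the speculative picture above, Cubed would hand each in-RAM chunk to a function like `add`, and Pallas or Triton would handle the tiling one level below it.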
@rbavery looked at this idea with me and also seemed to think it had merit. I think he was interested in the possibility of using Cubed with PyTorch (which also has an array-like class).
Yes, it'd be cool if Cubed could understand PyTorch tensors, what device they are on, and therefore where to run operations (GPU, CPU, TPU). I'm not very familiar with the internals of how PyTorch handles lowering operations to different accelerators, but I think they are getting more interoperable with Triton: TorchInductor (used by torch.compile) compiles PyTorch code to Triton kernels. The docs on PyTorch internals are somewhat scattered; I've found this podcast to be the most helpful at explaining them: https://pytorch-dev-podcast.simplecast.com/episodes
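As a quick illustration of that pipeline (a generic PyTorch >= 2.0 sketch, nothing Cubed-specific): on a CUDA device TorchInductor lowers the function below to generated Triton kernels, while on CPU it emits C++/OpenMP code instead.

```python
import torch

def f(x, y):
    return torch.nn.functional.relu(x @ y + 1.0)

compiled_f = torch.compile(f)  # TorchDynamo captures the graph, TorchInductor lowers it

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(128, 128, device=device)
y = torch.randn(128, 128, device=device)
torch.testing.assert_close(compiled_f(x, y), f(x, y))  # matches eager-mode results
```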
Ryan, I've been really enjoying the PyTorch dev podcast, thanks for sharing it. I'm a few episodes in and have some initial thoughts:
I think this AOT compilation route is a good first step to explore lowering with JAX: https://jax.readthedocs.io/en/latest/aot.html
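The core of that page is the jit/lower/compile split; a minimal version, taken almost verbatim from the linked docs:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x + y

x, y = jnp.int32(3), jnp.int32(4)

lowered = jax.jit(f).lower(x, y)  # trace and lower to StableHLO, ahead of time
print(lowered.as_text())          # inspect the lowered IR
compiled = lowered.compile()      # compile with XLA, still ahead of time
print(compiled(x, y))             # 10
```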
Tile-based operations have been quite a success for creating efficient GPU kernels. The programming model, in my understanding, offers flexibility while taking advantage of cache hierarchies.
http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf
The Triton language takes advantage of this model by providing a sort of MLIR/LLVM middleware for custom kernel acceleration of specific NN ops. JAX now even offers its own portable version of kernel control with tile semantics via Pallas.
https://jax.readthedocs.io/en/latest/pallas/index.html
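To make the tile model concrete, here is the canonical vector-add from the Triton tutorials (it needs an NVIDIA GPU): each program instance loads, computes, and stores one BLOCK_SIZE-sized tile, while the compiler handles the mapping onto the cache/SRAM hierarchy.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                        # which tile this instance owns
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                        # guard the ragged last tile
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.rand(4096, device="cuda")
y = torch.rand(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)                 # one program per tile
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
torch.testing.assert_close(out, x + y)
```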
I can’t help but think that there are parallels between Cubed’s chunked blockwise ops and these tile-based techniques. What could an intersection look like?
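For reference, the Cubed side of the analogy looks like this (adapted from the example in Cubed's README; the allowed_mem value and chunk sizes are illustrative). Each output chunk is produced by an independent, bounded-memory task, and those per-chunk tasks are where a Pallas- or Triton-compiled tile kernel could conceivably plug in:

```python
import cubed
import cubed.array_api as xp

spec = cubed.Spec(work_dir="tmp", allowed_mem=100_000)  # hard RAM bound per task
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
c = xp.add(a, b)   # a blockwise op: one task per output chunk
print(c.compute())
```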