Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make fp8 compatible with tensor parallelism #65

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

lw
Copy link
Contributor

@lw lw commented Dec 31, 2024

No description provided.

[ghstack-poisoned]
@lw
Copy link
Contributor Author

lw commented Dec 31, 2024

Stack from ghstack (oldest at bottom):

lw added a commit that referenced this pull request Dec 31, 2024
ghstack-source-id: db07e928f48cb886a86e017755ec4372c0f7ec3e
ghstack-comment-id: 2566319697
Pull Request resolved: #65
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Dec 31, 2024
return 1


def mul_tiled(a, *bs):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that you need to apply such a function in the test of pytorch/pytorch#143760 to "manually" do tiled multiplication/division to compute scaled results.
Here "if b is m x n" only appears when it's DTensor sub-row-wise scaling, in which case the local tensor of b would always have m x 1 shape. So is it correct that:

  1. on L38 with local_map we can always assume no tiled multiplication is needed; and
  2. on L46 if you're willing to also use local_map, tile multiplication can be avoided too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking a look!

I am learning to use DTensors and I thought it was more idiomatic to express the calculation on the "global" distributed tensor, rather than on the local shard. In order to do so, however, we need to know how many shards there are and reshape accordingly, which arguably isn't that pretty either.

I believe the "fundamental" reason for it is that we're stacking the different components in the wrong order. Here we first replace the matmuls with our custom function, and then we propagate DTensors through it (which means our function needs to know how to handle DTensors). However, I believe the ideal solution would be to first propagate DTensors through some regular matmuls, then take the resulting graph and swap the local matmuls with our function. The issue is that I don't really know how to achieve that, and our code was already written this way before we started supporting DTensors.

(There's also another open question which is how to integrate this with async-TP)

As for local_map, this is currently an unfortunate implementation detail. Ideally the scaling is supposed to be done by the _scaled_mm operator internally, which is what it does! However, because the row-wise scaled-mm is slow (when using slow accum), we use the tensor-wise (un)scaled-mm and do the scaling ourselves. If we were able to make the row-wise scaled-mm faster we could avoid local_map altogether.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the analysis! Makes a lot of sense to me!

@lw lw marked this pull request as ready for review January 17, 2025 17:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants