Add Float8QuantizedTensor (AQT subclass) and replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs #1599
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1599
Note: links to docs will display an error until the docs builds have completed.
❌ 2 new failures as of commit 2f15cc1 with merge base 32d9b0b.
(This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Thanks! We also want to split out Float8 (and floatx) specific AQT implementations as well; I talked to @jainapurva before.
Yep, that makes sense. When I talked to her earlier she said she is planning to create these AQT subclasses, so I decided to do this part of the refactor.
torchao/dtypes/__init__.py (Outdated)
@@ -38,6 +34,7 @@
    "to_affine_quantized_fpx",
    "to_affine_quantized_floatx",
Please remove floatx, float8 should replace floatx.
Oh I left it in since it's still in use in other parts of the code base (autoquant, autoquant v2), and I wasn't sure if I should be touching those - is it ok to replace all instances across the whole codebase?
Yes, we want all instances replaced. Autoquant is using it for Float8, so it would be better to rename it to float8.
Done!
@@ -209,19 +215,64 @@ def __repr__(self):
    )


class Float8QuantizedTensor(AffineQuantizedTensor):
I'm not a fan of this: it introduces one more abstraction (Float8QuantizedTensor) while keeping the complexity of AffineQuantizedTensor. I think either staying with AQT, or just writing a float8 tensor without using AQT, would seem more attractive.
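To illustrate the two designs being weighed, here is a toy, self-contained sketch (plain Python on lists; these classes only mimic the shape of the real torchao tensor subclasses and are not its API):

```python
# Toy sketch of the two designs under discussion; illustrative only.

class AffineQuantizedTensorToy:
    """Stands in for AffineQuantizedTensor: carries generic affine metadata
    (scale AND zero_point), even when a dtype never needs all of it."""
    def __init__(self, data, scale, zero_point):
        self.data, self.scale, self.zero_point = data, scale, zero_point

    def dequantize(self):
        return [(q - self.zero_point) * self.scale for q in self.data]

# Option A (this PR): subclass AQT, inheriting its complexity.
class Float8QuantizedTensorToy(AffineQuantizedTensorToy):
    def __init__(self, data, scale):
        # float8 quantization is symmetric, so zero_point is always 0.
        super().__init__(data, scale, zero_point=0.0)

# Option B (reviewer's suggestion): a standalone float8 tensor, no AQT.
class StandaloneFloat8Tensor:
    def __init__(self, data, scale):
        self.data, self.scale = data, scale

    def dequantize(self):
        return [q * self.scale for q in self.data]
```

Both options dequantize identically for float8; the trade-off is consistency across dtypes (Option A) versus a smaller, float8-only surface (Option B).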
Interesting - cc @jainapurva @jerryzh168, thoughts on this?
For context, AQT subclassing was part of a BE effort for the week; I'll share the doc with you internally.
Removing the AQT abstraction is easy; the only reason I kept it was consistency across all dtypes. Though I do agree that it adds another level of abstraction.
Discussed offline, closing until internal discussions are finalized.
Context
Currently, AQT has the method from_hp_to_floatx for float8 quantization, and from_hp_to_fpx for low-precision floating point data types like fp6 (technically it can support fp1 through fp7). from_hp_to_floatx re-uses from_hp_to_intx, which in turn uses generic quantization primitives. Overall, in the current state the float8 path is confusing for developers, due to both the naming ("floatx") and the use of generic functions that include many parameters unrelated to float8 quantization.
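As a rough illustration of why the reuse is confusing, here is a toy sketch (pure Python on lists, with illustrative names only, not the actual torchao primitives): the generic affine path needs a zero point and an integer clamp range, neither of which is meaningful for symmetric float8 scaling:

```python
# Toy sketch, not the torchao API: contrasts a generic intx-style
# primitive with a dedicated, symmetric float8-style path.

def generic_affine_quantize(data, scale, zero_point, qmin, qmax):
    """Generic intx-style primitive: requires zero_point and an integer
    clamp range [qmin, qmax], which are irrelevant for float8."""
    return [max(qmin, min(qmax, round(x / scale) + zero_point)) for x in data]

def quantize_float8(data, max_abs_fp8=448.0):
    """Dedicated float8-style path: symmetric scaling only (448.0 is the
    max representable value of the e4m3 format). No zero_point, no int
    clamp range."""
    amax = max(abs(x) for x in data) or 1.0
    scale = amax / max_abs_fp8
    return [x / scale for x in data], scale
```

The dedicated path exposes only the knobs float8 actually uses, which is the separation of concerns this PR stack is after.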
Summary of changes
The goal of this PR stack is to refactor this into a clean separation of concerns, with simpler internal API surfaces for code used in float8 quantization for inference.
Specifically:
Note: I will add float8 static quantization in a separate set of PRs.