-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[stdlib] Add DType.get_dtype()
#3810
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: martinvuyk <[email protected]>
The `DType` if matched, otherwise `DType.invalid`. | ||
""" | ||
|
||
if _type_is_eq[T, SIMD[DType.bool, size]](): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know compiler internals, but I figured adding a @parameter
would be redundant after the first execution (?). Should I add it anyway?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add it anyways as I'm not sure.
Re the implementation itself, is there a way we can leverage the element_type
alias from SIMD to simplify this? Conceptually, we want to just retrieve T.element_type
for SIMD types and return DType.invalid
otherwise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To do that we'd somehow need to rebind the AnyType
to a SIMD
, which I don't think the compiler will allow when the type is not SIMD
. I think we'd need a version of _type_is_eq
that works for unbound parameters like what @rd4com showed on Discord: _type_is_eq[SIMD[_, _]]()
and then we'd need an "unbound rebind" where we'd do dtype = rebind[SIMD[_, _]](value).element_type
, but I think this would need some very heavy involvement on the compiler side.
I had another idea that we could also add another method .get_simd() -> (DType, Int)
that finds the vector characteristics by brute force
@parameter
for i in range(len(dtypes)):
@parameter
for j in range(len(sizes)):
alias size = sizes.get[j, Int]()
@parameter
if sizeof[T]() // size != 1:
continue
alias dt = dtypes.get[i, DType]()
@parameter
if _type_is_eq[T, SIMD[dt, size]](value)
return dt, size
return DType.invalid, 1
This seems like it would make compile times suffer, but it would actually only need alias sizes = (1, 2, 4, 8)
which would then work for only one size sizeof[T]() // size == 1
, and then it just unrolls the comparison for every dtype. It's still expensive if not used properly, but I think we could execute it only once and save the results as aliases inside the struct like:
alias _res = DType.get_simd[T]()
alias _D = Self._res[0]
alias _size = Self._res[1]
alias _SIMD = SIMD[Self._D, Self._size]
alias _SpanT = Span[Self._SIMD, _]
We can leave this for another PR if you'd like. I think the current .get_dtype()
implementation is good for most scalar use cases or those where one knows the size of the SIMD
vectors. WDYT?
Signed-off-by: martinvuyk <[email protected]>
Signed-off-by: martinvuyk <[email protected]>
Signed-off-by: martinvuyk <[email protected]>
Add
DType.get_dtype()
.This will enable a lot of vector based optimizations for more generic function signatures.
This is a split-off from PR #3577
Examples: