Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add DType.get_dtype() #3810

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions stdlib/src/builtin/dtype.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ from collections import KeyElement
from collections.string import StringSlice
from hashlib._hasher import _HashableWithHasher, _Hasher
from sys import bitwidthof, os_is_windows, sizeof
from sys.intrinsics import _type_is_eq


alias _mIsSigned = UInt8(1)
alias _mIsInteger = UInt8(1 << 7)
Expand Down Expand Up @@ -699,6 +701,76 @@ struct DType(
else:
raise Error("only arithmetic types are supported")

# ===----------------------------------------------------------------------===#
# utils
# ===----------------------------------------------------------------------===#

@staticmethod
fn get_dtype[T: AnyType, size: Int = 1]() -> DType:
"""Get the `DType` if the given Type is a `SIMD[_, size]` of a `DType`.

Parameters:
T: AnyType.
size: The SIMD size to compare against.

Returns:
The `DType` if matched, otherwise `DType.invalid`.
"""

@parameter
if _type_is_eq[T, SIMD[DType.bool, size]]():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know compiler internals, but I figured adding a @parameter would be redundant after the first execution (?). Should I add it anyway?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add it anyways as I'm not sure.

Re the implementation itself, is there a way we can leverage the element_type alias from SIMD to simplify this? Conceptually, we want to just retrieve T.element_type for SIMD types and return DType.invalid otherwise.

Copy link
Contributor Author

@martinvuyk martinvuyk Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To do that we'd somehow need to rebind the AnyType to a SIMD, which I don't think the compiler will allow when the type is not SIMD. I think we'd need a version of _type_is_eq that works for unbound parameters like what @rd4com showed on Discord: _type_is_eq[SIMD[_, _]]() and then we'd need an "unbound rebind" where we'd do dtype = rebind[SIMD[_, _]](value).element_type, but I think this would need some very heavy involvement on the compiler side.

I had another idea that we could also add another method .get_simd() -> (DType, Int) that finds the vector characteristics by brute force

@parameter
for i in range(len(dtypes)):
    @parameter
    for j in range(len(sizes)):
        alias size = sizes.get[j, Int]()

        @parameter
        if sizeof[T]() // size != 1:
            continue

        alias dt = dtypes.get[i, DType]()

        @parameter
        if _type_is_eq[T, SIMD[dt, size]](value)
            return dt, size
return DType.invalid, 1

This seems like it would make compile times suffer, but it would actually only need alias sizes = (1, 2, 4, 8) which would then work for only one size sizeof[T]() // size == 1, and then it just unrolls the comparison for every dtype. It's still expensive if not used properly, but I think we could execute it only once and save the results as aliases inside the struct like:

alias _res = DType.get_simd[T]()
alias _D = Self._res[0]
alias _size = Self._res[1]
alias _SIMD = SIMD[Self._D, Self._size]
alias _SpanT = Span[Self._SIMD, _]

We can leave this for another PR if you'd like. I think the current .get_dtype() implementation is good for most scalar use cases or those where one knows the size of the SIMD vectors. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JoeLoser gentle ping. I think we won't be able to implement what you're suggesting without having reflection first. So could we move forward with this PR?

return DType.bool
elif _type_is_eq[T, SIMD[DType.int8, size]]():
return DType.int8
elif _type_is_eq[T, SIMD[DType.uint8, size]]():
return DType.uint8
elif _type_is_eq[T, SIMD[DType.int16, size]]():
return DType.int16
elif _type_is_eq[T, SIMD[DType.uint16, size]]():
return DType.uint16
elif _type_is_eq[T, SIMD[DType.int32, size]]():
return DType.int32
elif _type_is_eq[T, SIMD[DType.uint32, size]]():
return DType.uint32
elif _type_is_eq[T, SIMD[DType.int64, size]]():
return DType.int64
elif _type_is_eq[T, SIMD[DType.uint64, size]]():
return DType.uint64
elif _type_is_eq[T, SIMD[DType.index, size]]():
return DType.index
elif _type_is_eq[T, SIMD[DType.float8e5m2, size]]():
return DType.float8e5m2
elif _type_is_eq[T, SIMD[DType.float8e5m2fnuz, size]]():
return DType.float8e5m2fnuz
elif _type_is_eq[T, SIMD[DType.float8e4m3, size]]():
return DType.float8e4m3
elif _type_is_eq[T, SIMD[DType.float8e4m3fnuz, size]]():
return DType.float8e4m3fnuz
elif _type_is_eq[T, SIMD[DType.bfloat16, size]]():
return DType.bfloat16
elif _type_is_eq[T, SIMD[DType.float16, size]]():
return DType.float16
elif _type_is_eq[T, SIMD[DType.float32, size]]():
return DType.float32
elif _type_is_eq[T, SIMD[DType.tensor_float32, size]]():
return DType.tensor_float32
elif _type_is_eq[T, SIMD[DType.float64, size]]():
return DType.float64
else:
return DType.invalid

@staticmethod
fn is_scalar[T: AnyType]() -> Bool:
"""Whether the given Type is a Scalar of a DType.

Parameters:
T: AnyType.

Returns:
The result.
"""
return Self.get_dtype[T]() is not DType.invalid


# ===-------------------------------------------------------------------===#
# integral_type_of
Expand Down
24 changes: 24 additions & 0 deletions stdlib/test/builtin/test_dtype.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,34 @@ def test_from_str():
assert_equal(DType._from_str("DType.blahblah"), DType.invalid)


def test_get_dtype():
def _test[D: DType]():
assert_equal(DType.get_dtype[Scalar[D]](), D)

@parameter
for i in range(32):
assert_equal(DType.get_dtype[SIMD[D, i], i](), D)

_test[DType.int8]()
_test[DType.int16]()
_test[DType.int32]()
_test[DType.int64]()
_test[DType.uint8]()
_test[DType.uint16]()
_test[DType.uint32]()
_test[DType.uint64]()
_test[DType.float16]()
_test[DType.float32]()
_test[DType.float64]()
_test[DType.index]()
_test[DType.bool]()


def main():
test_equality()
test_stringable()
test_representable()
test_key_element()
test_sizeof()
test_from_str()
test_get_dtype()
3 changes: 2 additions & 1 deletion stdlib/test/python/my_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def __init__(self, bar):

class AbstractPerson(ABC):
@abstractmethod
def method(self): ...
def method(self):
...


def my_function(name):
Expand Down