Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add DType.get_dtype() #3810

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

martinvuyk
Copy link
Contributor

@martinvuyk martinvuyk commented Nov 25, 2024

Add DType.get_dtype().

This will enable a lot of vector based optimizations for more generic function signatures.

This is a split-off from PR #3577

Examples:

fn memset[type: Copyable](ptr: UnsafePointer[type], value: type, count: Int):
    alias dt = DType.get_dtype[type]()

    @parameter
    if dt is not DType.invalid:
        var p = ptr.bitcast[Scalar[dt]]()
        _memset_impl[dt](p, rebind[Scalar[dt]](value), count)
    else:
        for i in range(count):
            (ptr + i).init_pointee_copy(value)

Signed-off-by: martinvuyk <[email protected]>
@martinvuyk martinvuyk requested a review from a team as a code owner November 25, 2024 21:00
The `DType` if matched, otherwise `DType.invalid`.
"""

if _type_is_eq[T, SIMD[DType.bool, size]]():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know compiler internals, but I figured adding a @parameter would be redundant after the first execution (?). Should I add it anyway?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add it anyways as I'm not sure.

Re the implementation itself, is there a way we can leverage the element_type alias from SIMD to simplify this? Conceptually, we want to just retrieve T.element_type for SIMD types and return DType.invalid otherwise.

Copy link
Contributor Author

@martinvuyk martinvuyk Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To do that we'd somehow need to rebind the AnyType to a SIMD, which I don't think the compiler will allow when the type is not SIMD. I think we'd need a version of _type_is_eq that works for unbound parameters like what @rd4com showed on Discord: _type_is_eq[SIMD[_, _]]() and then we'd need an "unbound rebind" where we'd do dtype = rebind[SIMD[_, _]](value).element_type, but I think this would need some very heavy involvement on the compiler side.

I had another idea that we could also add another method .get_simd() -> (DType, Int) that finds the vector characteristics by brute force

@parameter
for i in range(len(dtypes)):
    @parameter
    for j in range(len(sizes)):
        alias size = sizes.get[j, Int]()

        @parameter
        if sizeof[T]() // size != 1:
            continue

        alias dt = dtypes.get[i, DType]()

        @parameter
        if _type_is_eq[T, SIMD[dt, size]](value)
            return dt, size
return DType.invalid, 1

This seems like it would make compile times suffer, but it would actually only need alias sizes = (1, 2, 4, 8) which would then work for only one size sizeof[T]() // size == 1, and then it just unrolls the comparison for every dtype. It's still expensive if not used properly, but I think we could execute it only once and save the results as aliases inside the struct like:

alias _res = DType.get_simd[T]()
alias _D = Self._res[0]
alias _size = Self._res[1]
alias _SIMD = SIMD[Self._D, Self._size]
alias _SpanT = Span[Self._SIMD, _]

We can leave this for another PR if you'd like. I think the current .get_dtype() implementation is good for most scalar use cases or those where one knows the size of the SIMD vectors. WDYT?

Signed-off-by: martinvuyk <[email protected]>
@skongum02 skongum02 deleted the branch modular:main January 29, 2025 18:58
@skongum02 skongum02 closed this Jan 29, 2025
@skongum02 skongum02 reopened this Jan 29, 2025
@skongum02 skongum02 changed the base branch from nightly to main January 29, 2025 20:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
waiting for response Needs action/response from contributor before a PR can proceed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants