-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'develop' into rethinking-device-framework
- Loading branch information
Showing
57 changed files
with
2,005 additions
and
1,252 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// | ||
// Created by sam on 11/08/23. | ||
// | ||
|
||
#include "dlpack_helpers.h" | ||
|
||
using namespace rpy; | ||
using namespace python; | ||
|
||
|
||
|
||
#define DLTHROW(type, bits) \ | ||
RPY_THROW(std::runtime_error, \ | ||
std::to_string(bits) + " bit " #type " is not supported") | ||
|
||
const string& | ||
python::type_id_for_dl_info(const DLDataType& dtype, const DLDevice& device) | ||
{ | ||
if (device.device_type != kDLCPU) { | ||
RPY_THROW( | ||
std::runtime_error, | ||
"for the time being, constructing non-cpu " | ||
"devices are not supported by RoughPy" | ||
); | ||
} | ||
|
||
switch (dtype.code) { | ||
case kDLFloat: | ||
switch (dtype.bits) { | ||
case 16: return scalars::type_id_of<scalars::half>(); | ||
case 32: return scalars::type_id_of<float>(); | ||
case 64: return scalars::type_id_of<double>(); | ||
default: DLTHROW(float, dtype.bits); | ||
} | ||
case kDLInt: | ||
switch (dtype.bits) { | ||
case 8: return scalars::type_id_of<char>(); | ||
case 16: return scalars::type_id_of<short>(); | ||
case 32: return scalars::type_id_of<int>(); | ||
case 64: return scalars::type_id_of<long long>(); | ||
default: DLTHROW(int, dtype.bits); | ||
} | ||
case kDLUInt: | ||
switch (dtype.bits) { | ||
case 8: return scalars::type_id_of<unsigned char>(); | ||
case 16: return scalars::type_id_of<unsigned short>(); | ||
case 32: return scalars::type_id_of<unsigned int>(); | ||
case 64: return scalars::type_id_of<unsigned long long>(); | ||
default: DLTHROW(uint, dtype.bits); | ||
} | ||
case kDLBfloat: | ||
if (dtype.bits == 16) { | ||
return scalars::type_id_of<scalars::bfloat16>(); | ||
} else { | ||
DLTHROW(bfloat, dtype.bits); | ||
} | ||
case kDLComplex: DLTHROW(complex, dtype.bits); | ||
case kDLOpaqueHandle: DLTHROW(opaquehandle, dtype.bits); | ||
case kDLBool: DLTHROW(bool, dtype.bits); | ||
} | ||
RPY_UNREACHABLE_RETURN({}); | ||
} | ||
|
||
#undef DLTHROW | ||
|
||
|
||
|
||
const scalars::ScalarType* | ||
python::scalar_type_for_dl_info(const DLDataType& dtype, const DLDevice& device) | ||
{ | ||
return scalars::ScalarType::for_id(type_id_for_dl_info(dtype, device)); | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
// | ||
// Created by sam on 11/08/23. | ||
// | ||
|
||
#ifndef ROUGHPY_DLPACK_HELPERS_H | ||
#define ROUGHPY_DLPACK_HELPERS_H | ||
|
||
#include "roughpy_module.h" | ||
#include "dlpack.h" | ||
|
||
#include <roughpy/scalars/scalars_fwd.h> | ||
#include <roughpy/scalars/scalar_type.h> | ||
|
||
namespace rpy { | ||
namespace python { | ||
|
||
constexpr scalars::ScalarTypeCode | ||
convert_from_dl_typecode(uint8_t code) noexcept { | ||
return static_cast<scalars::ScalarTypeCode>(code); | ||
} | ||
constexpr uint8_t | ||
convert_to_dl_typecode(scalars::ScalarTypeCode code) noexcept { | ||
return static_cast<uint8_t>(code); | ||
} | ||
|
||
constexpr scalars::ScalarDeviceType | ||
convert_from_dl_device_type(DLDeviceType type) noexcept { | ||
return static_cast<scalars::ScalarDeviceType>(type); | ||
} | ||
|
||
constexpr DLDeviceType | ||
convert_to_dl_device_type(scalars::ScalarDeviceType type) noexcept | ||
{ | ||
return static_cast<DLDeviceType>(type); | ||
} | ||
|
||
constexpr scalars::BasicScalarInfo | ||
convert_from_dl_datatype(const DLDataType& dtype) noexcept { | ||
return { | ||
convert_from_dl_typecode(dtype.code), | ||
dtype.bits, | ||
dtype.lanes | ||
}; | ||
} | ||
|
||
inline DLDataType | ||
convert_to_dl_datatype(const scalars::BasicScalarInfo& info) noexcept | ||
{ | ||
return { | ||
convert_to_dl_typecode(info.code), | ||
info.bits, | ||
info.lanes | ||
}; | ||
} | ||
|
||
constexpr scalars::ScalarDeviceInfo | ||
convert_from_dl_device_info(const DLDevice& device) noexcept | ||
{ | ||
return { | ||
convert_from_dl_device_type(device.device_type), | ||
device.device_id | ||
}; | ||
} | ||
|
||
constexpr DLDevice | ||
convert_to_dl_device_info(const scalars::ScalarDeviceInfo& device) noexcept { | ||
return { | ||
convert_to_dl_device_type(device.device_type), | ||
device.device_id | ||
}; | ||
} | ||
|
||
|
||
const string& | ||
type_id_for_dl_info(const DLDataType& dtype, const DLDevice& device); | ||
|
||
const scalars::ScalarType* | ||
scalar_type_for_dl_info(const DLDataType& dtype, const DLDevice& device); | ||
|
||
inline const scalars::ScalarType* scalar_type_of_dl_info(const DLDataType& dtype, const DLDevice& device) | ||
{ | ||
return scalars::ScalarType::from_type_details( | ||
convert_from_dl_datatype(dtype), | ||
convert_from_dl_device_info(device)); | ||
} | ||
|
||
} | ||
} | ||
|
||
#endif// ROUGHPY_DLPACK_HELPERS_H |
Oops, something went wrong.