Skip to content

Commit

Permalink
Implemented from type details function.
Browse files Browse the repository at this point in the history
  • Loading branch information
inakleinbottle committed Aug 7, 2023
1 parent 6e220b3 commit 35c9f63
Showing 1 changed file with 94 additions and 37 deletions.
131 changes: 94 additions & 37 deletions scalars/src/scalar_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,59 @@ const ScalarType* ScalarType::for_id(const string& id)
{
return ScalarType::of<double>();
}
const ScalarType* ScalarType::from_type_details(const BasicScalarInfo& details,
const ScalarDeviceInfo& device)
const ScalarType* ScalarType::from_type_details(
const BasicScalarInfo& details, const ScalarDeviceInfo& device
)
{
return ScalarType::of<double>();
RPY_CHECK(device.device_type == ScalarDeviceType::CPU);

switch (details.code) {
case ScalarTypeCode::Int: RPY_FALLTHROUGH;
case ScalarTypeCode::UInt: return nullptr;
case ScalarTypeCode::Float:
switch (details.bits) {
case 16: return ScalarType::of<half>();
case 32: return ScalarType::of<float>();
case 64: return ScalarType::of<double>();
default:
RPY_THROW(
std::invalid_argument,
"supported float bits are 16, 32, and 64 only"
);
}
case ScalarTypeCode::BFloat:
switch (details.bits) {
case 16: return ScalarType::of<bfloat16>();
default:
RPY_THROW(
std::invalid_argument,
"supported bfloat bits are 16 only"
);
}
case ScalarTypeCode::OpaqueHandle: RPY_FALLTHROUGH;
case ScalarTypeCode::Complex: RPY_FALLTHROUGH;
case ScalarTypeCode::Bool:
RPY_THROW(
std::invalid_argument,
"opaque handles, complex, and bool arguments are not "
"currently supported"
);
}
RPY_UNREACHABLE_RETURN(nullptr);
}
Scalar ScalarType::from(long long int numerator,
long long int denominator) const
Scalar
ScalarType::from(long long int numerator, long long int denominator) const
{
if (denominator == 0) { RPY_THROW(std::invalid_argument, "division by zero"); }
if (denominator == 0) {
RPY_THROW(std::invalid_argument, "division by zero");
}
auto result = allocate(1);
assign(result, numerator, denominator);
return {result, flags::OwnedPointer};
}
void ScalarType::convert_fill(ScalarPointer out, ScalarPointer in, dimn_t count,
const string& id) const
void ScalarType::convert_fill(
ScalarPointer out, ScalarPointer in, dimn_t count, const string& id
) const
{
if (!id.empty()) {
try {
Expand Down Expand Up @@ -222,10 +260,10 @@ static const pair<string, ScalarTypeInfo> reserved[]

static std::mutex s_scalar_type_cache_lock;
static std::unordered_map<string, const ScalarType*> s_scalar_type_cache{
{string("f32"), ScalarType::of<float>()},
{string("f64"), ScalarType::of<double>()},
{string("f16"), ScalarType::of<half>()},
{string("bf16"), ScalarType::of<bfloat16>()},
{ string("f32"), ScalarType::of<float>()},
{ string("f64"), ScalarType::of<double>()},
{ string("f16"), ScalarType::of<half>()},
{ string("bf16"), ScalarType::of<bfloat16>()},
{string("rational"), ScalarType::of<rational_scalar_type>()}
};

Expand All @@ -236,15 +274,20 @@ void rpy::scalars::register_type(const ScalarType* type)
const auto& identifier = type->id();
for (const auto& i : reserved) {
if (identifier == i.first) {
RPY_THROW(std::runtime_error,"cannot register identifier " + identifier
+ ", it is reserved");
RPY_THROW(
std::runtime_error,
"cannot register identifier " + identifier
+ ", it is reserved"
);
}
}

auto& entry = s_scalar_type_cache[identifier];
if (entry != nullptr) {
RPY_THROW(std::runtime_error,"type with id " + identifier
+ " is already registered");
RPY_THROW(
std::runtime_error,
"type with id " + identifier + " is already registered"
);
}

entry = type;
Expand All @@ -262,8 +305,8 @@ const ScalarType* rpy::scalars::get_type(const string& id)
RPY_THROW(std::runtime_error, "no type with id " + id);
}

const ScalarType* scalars::get_type(const string& id,
const ScalarDeviceInfo& device)
const ScalarType*
scalars::get_type(const string& id, const ScalarDeviceInfo& device)
{
// TODO: Needs implementation
return nullptr;
Expand Down Expand Up @@ -323,8 +366,9 @@ ROUGHPY_MAKE_TYPE_ID_OF(rational_poly_scalar, "RationalPoly")
#undef ROUGHPY_MAKE_TYPE_ID_OF

#define MAKE_CONVERSION_FUNCTION(SRC, DST, SRC_T, DST_T) \
static void SRC##_to_##DST(ScalarPointer dst, ScalarPointer src, \
dimn_t count) \
static void SRC##_to_##DST( \
ScalarPointer dst, ScalarPointer src, dimn_t count \
) \
{ \
const auto* src_p = src.raw_cast<const SRC_T>(); \
auto* dst_p = dst.raw_cast<DST_T>(); \
Expand Down Expand Up @@ -365,33 +409,39 @@ static std::unordered_map<string, conversion_function> s_conversion_cache{

#undef ADD_DEF_CONV

static inline string type_ids_to_key(const string& src_type,
const string& dst_type)
static inline string
type_ids_to_key(const string& src_type, const string& dst_type)
{
return src_type + "->" + dst_type;
}

const conversion_function& rpy::scalars::get_conversion(const string& src_id,
const string& dst_id)
const conversion_function&
rpy::scalars::get_conversion(const string& src_id, const string& dst_id)
{
std::lock_guard<std::mutex> access(s_conversion_lock);

auto found = s_conversion_cache.find(type_ids_to_key(src_id, dst_id));
if (found != s_conversion_cache.end()) { return found->second; }

RPY_THROW(std::runtime_error,"no conversion function from " + src_id + " to "
+ dst_id);
RPY_THROW(
std::runtime_error,
"no conversion function from " + src_id + " to " + dst_id
);
}
void rpy::scalars::register_conversion(const string& src_id,
const string& dst_id,
conversion_function converter)
void rpy::scalars::register_conversion(
const string& src_id, const string& dst_id,
conversion_function converter
)
{
std::lock_guard<std::mutex> access(s_conversion_lock);

auto& found = s_conversion_cache[type_ids_to_key(src_id, dst_id)];
if (found != nullptr) {
RPY_THROW(std::runtime_error,"conversion from " + src_id + " to " + dst_id
+ " already registered");
RPY_THROW(
std::runtime_error,
"conversion from " + src_id + " to " + dst_id
+ " already registered"
);
} else {
found = std::move(converter);
}
Expand All @@ -400,13 +450,17 @@ void rpy::scalars::register_conversion(const string& src_id,
std::unique_ptr<RandomGenerator>
ScalarType::get_rng(const string& bit_generator, Slice<uint64_t> seed) const
{
RPY_THROW(std::runtime_error,
"no random number generators are defined for this scalar type");
RPY_THROW(
std::runtime_error,
"no random number generators are defined for this scalar type"
);
}

std::unique_ptr<BlasInterface> ScalarType::get_blas() const
{
RPY_THROW(std::runtime_error, "No blas/lapack available for this scalar type");
RPY_THROW(
std::runtime_error, "No blas/lapack available for this scalar type"
);
}

const std::string& rpy::scalars::id_from_basic_info(const BasicScalarInfo& info)
Expand All @@ -422,21 +476,24 @@ const std::string& rpy::scalars::id_from_basic_info(const BasicScalarInfo& info)
case 16: return reserved[I16IDX].second.id;
case 32: return reserved[I32IDX].second.id;
case 64: return reserved[I64IDX].second.id;
default: RPY_THROW(std::runtime_error, "unsupported integer type");
default:
RPY_THROW(std::runtime_error, "unsupported integer type");
}
case ScalarTypeCode::UInt:
switch (info.bits) {
case 8: return reserved[U8IDX].second.id;
case 16: return reserved[U16IDX].second.id;
case 32: return reserved[U32IDX].second.id;
case 64: return reserved[U64IDX].second.id;
default: RPY_THROW(std::runtime_error, "unsupported integer type");
default:
RPY_THROW(std::runtime_error, "unsupported integer type");
}
case ScalarTypeCode::Float:
switch (info.bits) {
case 32: return ScalarType::of<float>()->id();
case 64: return ScalarType::of<double>()->id();
default: RPY_THROW(std::runtime_error, "unsupported float type");
default:
RPY_THROW(std::runtime_error, "unsupported float type");
}
default: RPY_THROW(std::runtime_error, "unsupported scalar type");
}
Expand Down

0 comments on commit 35c9f63

Please sign in to comment.