Skip to content

Commit

Permalink
Merge branch 'develop' into rethinking-device-framework
Browse files Browse the repository at this point in the history
  • Loading branch information
inakleinbottle authored Aug 18, 2023
2 parents 26227ed + 5ecbdb1 commit 3dd9ced
Show file tree
Hide file tree
Showing 57 changed files with 2,005 additions and 1,252 deletions.
1 change: 1 addition & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Version 0.0.7:
MKL/BLAS+LAPACK.
- Added function to query ring characteristics of a ScalarType - currently
unused.
- Added KeyScalarStream for constructing streams from array-like data more easily.
- Implemented the `from_type_details` function for scalar types. This fixes a bug when constructing objects using the dlpack protocol.


Expand Down
6 changes: 6 additions & 0 deletions algebra/include/roughpy/algebra/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,15 @@ class RPY_EXPORT Context : public ContextBase
void
cbh_fallback(FreeTensor& collector, const std::vector<Lie>& lies) const;

void
cbh_fallback(FreeTensor& collector, Slice<const Lie*> lies) const;

public:
RPY_NO_DISCARD virtual Lie
cbh(const std::vector<Lie>& lies, VectorType vtype) const;
RPY_NO_DISCARD virtual Lie
cbh(Slice<const Lie*> lies, VectorType vtype) const;

RPY_NO_DISCARD virtual Lie
cbh(const Lie& left, const Lie& right, VectorType vtype) const;

Expand Down
59 changes: 48 additions & 11 deletions algebra/src/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ void Context::cbh_fallback(FreeTensor& collector, const std::vector<Lie>& lies)
}
}
}
void Context::cbh_fallback(FreeTensor& collector, Slice<const Lie*> lies) const
{
for (const auto* alie : lies) {
if (alie->dimension() != 0) {
collector.fmexp(this->lie_to_tensor(*alie));
}
}
}

Lie Context::cbh(const std::vector<Lie>& lies, VectorType vtype) const
{
Expand All @@ -173,6 +181,19 @@ Lie Context::cbh(const std::vector<Lie>& lies, VectorType vtype) const

return tensor_to_lie(collector.log());
}
Lie Context::cbh(Slice<const Lie*> lies, VectorType vtype) const {
if (lies.size() == 1) {
return convert(*lies[0], vtype);
}

FreeTensor collector = zero_free_tensor(vtype);
collector[0] = scalars::Scalar(1);

cbh_fallback(collector, lies);

return tensor_to_lie(collector.log());
}

Lie Context::cbh(const Lie& left, const Lie& right, VectorType vtype) const
{
FreeTensor tmp = lie_to_tensor(left).exp();
Expand Down Expand Up @@ -241,12 +262,16 @@ context_pointer rpy::algebra::get_context(
}

if (found.empty()) {
RPY_THROW(std::invalid_argument,"cannot find a context maker for the "
"width, depth, dtype, and preferences set");
RPY_THROW(
std::invalid_argument,
"cannot find a context maker for the "
"width, depth, dtype, and preferences set"
);
}

if (found.size() > 1) {
RPY_THROW(std::invalid_argument,
RPY_THROW(
std::invalid_argument,
"found multiple context maker candidates for specified width, "
"depth, dtype, and preferences set"
);
Expand Down Expand Up @@ -314,30 +339,42 @@ UnspecifiedAlgebraType Context::free_multiply(
const ConstRawUnspecifiedAlgebraType right
) const
{
RPY_THROW(std::runtime_error,"free tensor multiply is not implemented for "
"arbitrary types with this backend");
RPY_THROW(
std::runtime_error,
"free tensor multiply is not implemented for "
"arbitrary types with this backend"
);
}
UnspecifiedAlgebraType Context::shuffle_multiply(
ConstRawUnspecifiedAlgebraType left,
ConstRawUnspecifiedAlgebraType right
) const
{
RPY_THROW(std::runtime_error,"shuffle multiply is not implemented for "
"arbitrary types with this backend");
RPY_THROW(
std::runtime_error,
"shuffle multiply is not implemented for "
"arbitrary types with this backend"
);
}
UnspecifiedAlgebraType Context::half_shuffle_multiply(
ConstRawUnspecifiedAlgebraType left,
ConstRawUnspecifiedAlgebraType right
) const
{
RPY_THROW(std::runtime_error,"half shuffle multiply is not implemented for "
"arbitrary types with this backend");
RPY_THROW(
std::runtime_error,
"half shuffle multiply is not implemented for "
"arbitrary types with this backend"
);
}
UnspecifiedAlgebraType Context::adjoint_to_left_multiply_by(
ConstRawUnspecifiedAlgebraType multiplier,
ConstRawUnspecifiedAlgebraType argument
) const
{
RPY_THROW(std::runtime_error,"adjoint of left multiply is not implemented for "
"arbitrary types with this backend");
RPY_THROW(
std::runtime_error,
"adjoint of left multiply is not implemented for "
"arbitrary types with this backend"
);
}
4 changes: 4 additions & 0 deletions platform/include/roughpy/platform/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
CEREAL_SPECIALIZE_FOR_ALL_ARCHIVES(T, M)
# define RPY_SERIAL_SERIALIZE_BARE(V) archive(V)
# define RPY_SERIAL_REGISTER_CLASS(T) CEREAL_REGISTER_TYPE(T)
# define RPY_SERIAL_FORCE_DYNAMIC_INIT(LIB) CEREAL_FORCE_DYNAMIC_INIT(LIB)
# define RPY_SERIAL_DYNAMIC_INIT(LIB) CEREAL_REGISTER_DYNAMIC_INIT(LIB)

# define RPY_SERIAL_LOAD_AND_CONSTRUCT(T) \
namespace cereal { \
Expand All @@ -102,6 +104,8 @@
# define RPY_SERIAL_SPECIALIZE_TYPES(T)
# define RPY_SERIAL_REGISTER_CLASS(T)
# define RPY_SERIAL_SERIALIZE_BARE(V) (void) 0
# define RPY_SERIAL_FORCE_DYNAMIC_INIT(LIB)
# define RPY_SERIAL_DYNAMIC_INIT(LIB)
#endif

#define RPY_SERIAL_SERIALIZE_FN() \
Expand Down
10 changes: 8 additions & 2 deletions roughpy/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ target_sources(RoughPy_PyModule PRIVATE
src/algebra/tensor_key_iterator.h
src/args/convert_timestamp.cpp
src/args/convert_timestamp.h
src/args/dlpack_helpers.cpp
src/args/dlpack_helpers.h
src/args/kwargs_to_path_metadata.h
src/args/kwargs_to_path_metadata.cpp
src/args/kwargs_to_vector_construction.cpp
Expand All @@ -81,6 +83,8 @@ target_sources(RoughPy_PyModule PRIVATE
src/args/numpy.h
src/args/parse_schema.cpp
src/args/parse_schema.h
src/args/strided_copy.cpp
src/args/strided_copy.h
src/intervals/date_time_interval.cpp
src/intervals/date_time_interval.h
src/intervals/dyadic.h
Expand All @@ -97,6 +101,8 @@ target_sources(RoughPy_PyModule PRIVATE
src/intervals/real_interval.h
src/intervals/segmentation.cpp
src/intervals/segmentation.h
src/scalars/parse_key_scalar_stream.cpp
src/scalars/parse_key_scalar_stream.h
src/scalars/r_py_polynomial.cpp
src/scalars/r_py_polynomial.h
src/scalars/scalar_type.h
Expand All @@ -117,8 +123,8 @@ target_sources(RoughPy_PyModule PRIVATE
src/streams/piecewise_abelian_stream.h
src/streams/r_py_tick_construction_helper.cpp
src/streams/r_py_tick_construction_helper.h
src/streams/py_schema_context.cpp
src/streams/py_schema_context.h
src/streams/py_parametrization.cpp
src/streams/py_parametrization.h
src/streams/schema.cpp
src/streams/schema.h
src/streams/stream.cpp
Expand Down
1 change: 0 additions & 1 deletion roughpy/src/algebra/free_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ static FreeTensor construct_free_tensor(py::object data, py::kwargs kwargs)
{std::move(buffer), helper.vtype}
);

if (options.cleanup) { options.cleanup(); }

RPY_DBG_ASSERT(result.coeff_type() != nullptr);

Expand Down
1 change: 0 additions & 1 deletion roughpy/src/algebra/lie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ static Lie construct_lie(py::object data, py::kwargs kwargs)

auto result = helper.ctx->construct_lie({std::move(buffer), helper.vtype});

if (options.cleanup) { options.cleanup(); }

return result;
}
Expand Down
1 change: 0 additions & 1 deletion roughpy/src/algebra/shuffle_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ static ShuffleTensor construct_shuffle(py::object data, py::kwargs kwargs)
{std::move(buffer), helper.vtype}
);

if (options.cleanup) { options.cleanup(); }

return result;
}
Expand Down
74 changes: 74 additions & 0 deletions roughpy/src/args/dlpack_helpers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//
// Created by sam on 11/08/23.
//

#include "dlpack_helpers.h"

using namespace rpy;
using namespace python;



#define DLTHROW(type, bits) \
RPY_THROW(std::runtime_error, \
std::to_string(bits) + " bit " #type " is not supported")

const string&
python::type_id_for_dl_info(const DLDataType& dtype, const DLDevice& device)
{
if (device.device_type != kDLCPU) {
RPY_THROW(
std::runtime_error,
"for the time being, constructing non-cpu "
"devices are not supported by RoughPy"
);
}

switch (dtype.code) {
case kDLFloat:
switch (dtype.bits) {
case 16: return scalars::type_id_of<scalars::half>();
case 32: return scalars::type_id_of<float>();
case 64: return scalars::type_id_of<double>();
default: DLTHROW(float, dtype.bits);
}
case kDLInt:
switch (dtype.bits) {
case 8: return scalars::type_id_of<char>();
case 16: return scalars::type_id_of<short>();
case 32: return scalars::type_id_of<int>();
case 64: return scalars::type_id_of<long long>();
default: DLTHROW(int, dtype.bits);
}
case kDLUInt:
switch (dtype.bits) {
case 8: return scalars::type_id_of<unsigned char>();
case 16: return scalars::type_id_of<unsigned short>();
case 32: return scalars::type_id_of<unsigned int>();
case 64: return scalars::type_id_of<unsigned long long>();
default: DLTHROW(uint, dtype.bits);
}
case kDLBfloat:
if (dtype.bits == 16) {
return scalars::type_id_of<scalars::bfloat16>();
} else {
DLTHROW(bfloat, dtype.bits);
}
case kDLComplex: DLTHROW(complex, dtype.bits);
case kDLOpaqueHandle: DLTHROW(opaquehandle, dtype.bits);
case kDLBool: DLTHROW(bool, dtype.bits);
}
RPY_UNREACHABLE_RETURN({});
}

#undef DLTHROW



const scalars::ScalarType*
python::scalar_type_for_dl_info(const DLDataType& dtype, const DLDevice& device)
{
return scalars::ScalarType::for_id(type_id_for_dl_info(dtype, device));
}


90 changes: 90 additions & 0 deletions roughpy/src/args/dlpack_helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//
// Created by sam on 11/08/23.
//

#ifndef ROUGHPY_DLPACK_HELPERS_H
#define ROUGHPY_DLPACK_HELPERS_H

#include "roughpy_module.h"
#include "dlpack.h"

#include <roughpy/scalars/scalars_fwd.h>
#include <roughpy/scalars/scalar_type.h>

namespace rpy {
namespace python {

constexpr scalars::ScalarTypeCode
convert_from_dl_typecode(uint8_t code) noexcept {
return static_cast<scalars::ScalarTypeCode>(code);
}
constexpr uint8_t
convert_to_dl_typecode(scalars::ScalarTypeCode code) noexcept {
return static_cast<uint8_t>(code);
}

constexpr scalars::ScalarDeviceType
convert_from_dl_device_type(DLDeviceType type) noexcept {
return static_cast<scalars::ScalarDeviceType>(type);
}

constexpr DLDeviceType
convert_to_dl_device_type(scalars::ScalarDeviceType type) noexcept
{
return static_cast<DLDeviceType>(type);
}

constexpr scalars::BasicScalarInfo
convert_from_dl_datatype(const DLDataType& dtype) noexcept {
return {
convert_from_dl_typecode(dtype.code),
dtype.bits,
dtype.lanes
};
}

inline DLDataType
convert_to_dl_datatype(const scalars::BasicScalarInfo& info) noexcept
{
return {
convert_to_dl_typecode(info.code),
info.bits,
info.lanes
};
}

constexpr scalars::ScalarDeviceInfo
convert_from_dl_device_info(const DLDevice& device) noexcept
{
return {
convert_from_dl_device_type(device.device_type),
device.device_id
};
}

constexpr DLDevice
convert_to_dl_device_info(const scalars::ScalarDeviceInfo& device) noexcept {
return {
convert_to_dl_device_type(device.device_type),
device.device_id
};
}


const string&
type_id_for_dl_info(const DLDataType& dtype, const DLDevice& device);

const scalars::ScalarType*
scalar_type_for_dl_info(const DLDataType& dtype, const DLDevice& device);

inline const scalars::ScalarType* scalar_type_of_dl_info(const DLDataType& dtype, const DLDevice& device)
{
return scalars::ScalarType::from_type_details(
convert_from_dl_datatype(dtype),
convert_from_dl_device_info(device));
}

}
}

#endif// ROUGHPY_DLPACK_HELPERS_H
Loading

0 comments on commit 3dd9ced

Please sign in to comment.