From d568c68285aed7e5cb65fc64ac8db76c6641cbf9 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Mon, 7 Aug 2023 14:59:30 +0100 Subject: [PATCH 01/33] Rework schema and channels to make it more flexible than before --- .../include/roughpy/platform/serialization.h | 4 + roughpy/src/streams/py_schema_context.cpp | 2 +- roughpy/src/streams/py_schema_context.h | 2 +- roughpy/src/streams/schema.cpp | 13 +- roughpy/src/streams/tick_stream.cpp | 2 +- streams/CMakeLists.txt | 17 +- streams/include/roughpy/streams/channels.h | 328 +------------ .../streams/channels/categorical_channel.h | 50 ++ .../streams/channels/increment_channel.h | 35 ++ .../streams/channels/lead_laggable_channel.h | 47 ++ .../roughpy/streams/channels/lie_channel.h | 33 ++ .../roughpy/streams/channels/stream_channel.h | 120 +++++ .../roughpy/streams/channels/value_channel.h | 36 ++ .../{schema_context.h => parametrization.h} | 26 +- streams/include/roughpy/streams/schema.h | 119 ++--- .../streams/stream_construction_helper.h | 4 +- streams/src/channels.cpp | 458 ------------------ streams/src/channels/categorical_channel.cpp | 79 +++ streams/src/channels/increment_channel.cpp | 9 + .../src/channels/lead_laggable_channel.cpp | 53 ++ streams/src/channels/lie_channel.cpp | 14 + streams/src/channels/stream_channel.cpp | 67 +++ streams/src/channels/value_channel.cpp | 10 + .../sound_file_data_source.cpp | 2 +- ...schema_context.cpp => parametrization.cpp} | 8 +- streams/src/schema.cpp | 78 +-- streams/src/stream_construction_helper.cpp | 6 +- 27 files changed, 663 insertions(+), 959 deletions(-) create mode 100644 streams/include/roughpy/streams/channels/categorical_channel.h create mode 100644 streams/include/roughpy/streams/channels/increment_channel.h create mode 100644 streams/include/roughpy/streams/channels/lead_laggable_channel.h create mode 100644 streams/include/roughpy/streams/channels/lie_channel.h create mode 100644 streams/include/roughpy/streams/channels/stream_channel.h create mode 100644 streams/include/roughpy/streams/channels/value_channel.h rename streams/include/roughpy/streams/{schema_context.h => parametrization.h} (53%) delete mode 100644 streams/src/channels.cpp create mode 100644 streams/src/channels/categorical_channel.cpp create mode 100644 streams/src/channels/increment_channel.cpp create mode 100644 streams/src/channels/lead_laggable_channel.cpp create mode 100644 streams/src/channels/lie_channel.cpp create mode 100644 streams/src/channels/stream_channel.cpp create mode 100644 streams/src/channels/value_channel.cpp rename streams/src/{schema_context.cpp => parametrization.cpp} (91%) diff --git a/platform/include/roughpy/platform/serialization.h b/platform/include/roughpy/platform/serialization.h index 09a78e5a4..1756e6613 100644 --- a/platform/include/roughpy/platform/serialization.h +++ b/platform/include/roughpy/platform/serialization.h @@ -76,6 +76,8 @@ CEREAL_SPECIALIZE_FOR_ALL_ARCHIVES(T, M) # define RPY_SERIAL_SERIALIZE_BARE(V) archive(V) # define RPY_SERIAL_REGISTER_CLASS(T) CEREAL_REGISTER_TYPE(T) +# define RPY_SERIAL_FORCE_DYNAMIC_INIT(LIB) CEREAL_FORCE_DYNAMIC_INIT(LIB) +# define RPY_SERIAL_DYNAMIC_INIT(LIB) CEREAL_REGISTER_DYNAMIC_INIT(LIB) # define RPY_SERIAL_LOAD_AND_CONSTRUCT(T) \ namespace cereal { \ @@ -102,6 +104,8 @@ # define RPY_SERIAL_SPECIALIZE_TYPES(T) # define RPY_SERIAL_REGISTER_CLASS(T) # define RPY_SERIAL_SERIALIZE_BARE(V) (void) 0 +# define RPY_SERIAL_FORCE_DYNAMIC_INIT(LIB) +# define RPY_SERIAL_DYNAMIC_INIT(LIB) #endif #define RPY_SERIAL_SERIALIZE_FN() \ diff --git 
a/roughpy/src/streams/py_schema_context.cpp b/roughpy/src/streams/py_schema_context.cpp index 4dfc24833..69364e1a9 100644 --- a/roughpy/src/streams/py_schema_context.cpp +++ b/roughpy/src/streams/py_schema_context.cpp @@ -27,7 +27,7 @@ intervals::RealInterval python::PySchemaContext::convert_parameter_interval( return reparametrize({inf, sup, arg.type()}); } - return SchemaContext::convert_parameter_interval(arg); + return Parameterization::convert_parameter_interval(arg); } void python::PySchemaContext::set_reference_dt(py::object dt_reference) diff --git a/roughpy/src/streams/py_schema_context.h b/roughpy/src/streams/py_schema_context.h index adffc0df3..53225d95a 100644 --- a/roughpy/src/streams/py_schema_context.h +++ b/roughpy/src/streams/py_schema_context.h @@ -42,7 +42,7 @@ namespace rpy { namespace python { -class PySchemaContext : public streams::SchemaContext +class PySchemaContext : public streams::Parameterization { py::object m_dt_reference; optional m_dt_conversion{}; diff --git a/roughpy/src/streams/schema.cpp b/roughpy/src/streams/schema.cpp index d536693fe..e170f2863 100644 --- a/roughpy/src/streams/schema.cpp +++ b/roughpy/src/streams/schema.cpp @@ -130,19 +130,19 @@ void rpy::python::init_schema(py::module_& m) auto* plist = labels.ptr(); py::ssize_t i = 0; for (auto&& item : *schema) { - switch (item.second.type()) { + switch (item.second->type()) { case streams::ChannelType::Increment: PyList_SET_ITEM( plist, i++, PyUnicode_FromString(item.first.c_str()) ); break; case streams::ChannelType::Value: - if (item.second.is_lead_lag()) { + if (item.second->is_lead_lag()) { PyList_SET_ITEM( plist, i++, PyUnicode_FromString( (item.first - + item.second.label_suffix(0)) + + item.second->label_suffix(0)) .c_str() ) ); @@ -150,7 +150,7 @@ void rpy::python::init_schema(py::module_& m) plist, i++, PyUnicode_FromString( (item.first - + item.second.label_suffix(1)) + + item.second->label_suffix(1)) .c_str() ) ); @@ -162,13 +162,13 @@ void rpy::python::init_schema(py::module_& m) } break; case streams::ChannelType::Categorical: { - auto nvariants = item.second.num_variants(); + auto nvariants = item.second->num_variants(); for (dimn_t idx = 0; idx < nvariants; ++idx) { PyList_SET_ITEM( plist, i++, PyUnicode_FromString( (item.first - + item.second.label_suffix(idx)) + + item.second->label_suffix(idx)) .c_str() ) ); @@ -178,6 +178,7 @@ void rpy::python::init_schema(py::module_& m) case streams::ChannelType::Lie: break; } } + RPY_CHECK(i == schema->width()); return labels; }); } diff --git a/roughpy/src/streams/tick_stream.cpp b/roughpy/src/streams/tick_stream.cpp index 44bccb674..d0ba06bf8 100644 --- a/roughpy/src/streams/tick_stream.cpp +++ b/roughpy/src/streams/tick_stream.cpp @@ -169,7 +169,7 @@ static py::object construct(const py::object& data, const py::kwargs& kwargs) += python::py_to_scalar(pmd.scalar_type, tick.data); break; case streams::ChannelType::Value: - if (channel.is_lead_lag()) { + if (channel->is_lead_lag()) { auto lag_key = key + 1; const auto& prev_lead = previous_values[idx]; const auto& prev_lag = previous_values[idx + 1]; diff --git a/streams/CMakeLists.txt b/streams/CMakeLists.txt index 5fc2149d0..e07782345 100644 --- a/streams/CMakeLists.txt +++ b/streams/CMakeLists.txt @@ -25,10 +25,21 @@ add_roughpy_component(Streams src/external_data_sources/sound_file_data_source.cpp src/external_data_sources/sound_file_data_source.h src/schema.cpp - src/schema_context.cpp + src/parametrization.cpp src/stream_construction_helper.cpp - src/channels.cpp + 
src/channels/categorical_channel.cpp + src/channels/stream_channel.cpp + src/channels/increment_channel.cpp + src/channels/value_channel.cpp + src/channels/lie_channel.cpp + src/channels/lead_laggable_channel.cpp PUBLIC_HEADERS + include/roughpy/streams/channels/stream_channel.h + include/roughpy/streams/channels/increment_channel.h + include/roughpy/streams/channels/value_channel.h + include/roughpy/streams/channels/categorical_channel.h + include/roughpy/streams/channels/lie_channel.h + include/roughpy/streams/channels/lead_laggable_channel.h include/roughpy/streams/dyadic_caching_layer.h include/roughpy/streams/dynamically_constructed_stream.h include/roughpy/streams/lie_increment_stream.h @@ -39,7 +50,7 @@ add_roughpy_component(Streams include/roughpy/streams/external_data_stream.h include/roughpy/streams/piecewise_abelian_stream.h include/roughpy/streams/schema.h - include/roughpy/streams/schema_context.h + include/roughpy/streams/parametrization.h include/roughpy/streams/stream_construction_helper.h include/roughpy/streams/channels.h PRIVATE_DEPS diff --git a/streams/include/roughpy/streams/channels.h b/streams/include/roughpy/streams/channels.h index ffe1ab06e..51fb512dd 100644 --- a/streams/include/roughpy/streams/channels.h +++ b/streams/include/roughpy/streams/channels.h @@ -1,328 +1,12 @@ #ifndef ROUGHPY_STREAMS_CHANNELS_H_ #define ROUGHPY_STREAMS_CHANNELS_H_ -#include -#include -#include -#include -#include +#include "channels/stream_channel.h" +#include "channels/lead_laggable_channel.h" +#include "channels/increment_channel.h" +#include "channels/value_channel.h" +#include "channels/categorical_channel.h" +#include "channels/lie_channel.h" -#include -#include - -namespace rpy { -namespace streams { - -enum struct ChannelType : uint8_t -{ - Increment = 0, - Value = 1, - Categorical = 2, - Lie = 3, -}; - -enum struct StaticChannelType : uint8_t -{ - Value = 0, - Categorical = 1 -}; - -struct IncrementChannelInfo { -}; -struct ValueChannelInfo { - bool lead_lag = false; -}; - -struct CategoricalChannelInfo { - std::vector variants; -}; -struct LieChannelInfo { - deg_t width; - deg_t depth; - algebra::VectorType vtype; -}; - -/** - * @brief Abstract description of a single channel of data. - * - * A stream channel contains metadata about one particular channel of - * incoming raw data. The data can take one of four possible forms: - * increment data, value data, categorical data, or lie data. Each form - * of data has its own set of required and optional metadata, including - * the number and labels of a categorical variable. These objects are - * typically accessed via a schema, which maintains the collection of - * all channels associated with a stream. - */ -class RPY_EXPORT StreamChannel -{ - ChannelType m_type; - union - { - IncrementChannelInfo increment_info; - ValueChannelInfo value_info; - CategoricalChannelInfo categorical_info; - LieChannelInfo lie_info; - }; - - template - static void inplace_construct(void* address, T&& value) noexcept( - is_nothrow_constructible, - decltype(value)>::value) - { - ::new (address) remove_cv_ref_t(std::forward(value)); - } - -public: - RPY_NO_DISCARD - ChannelType type() const noexcept { return m_type; } - - RPY_NO_DISCARD dimn_t num_variants() const - { - switch (m_type) { - case ChannelType::Increment: return 1; - case ChannelType::Value: return value_info.lead_lag ? 
2 : 1; - case ChannelType::Categorical: - return categorical_info.variants.size(); - case ChannelType::Lie: return lie_info.width; - } - RPY_UNREACHABLE(); - } - - RPY_NO_DISCARD - string label_suffix(dimn_t variant_no) const; - - StreamChannel(); - StreamChannel(const StreamChannel& arg); - StreamChannel(StreamChannel&& arg) noexcept; - - explicit StreamChannel(ChannelType type); - - explicit StreamChannel(IncrementChannelInfo&& info) - : m_type(ChannelType::Increment), increment_info(std::move(info)) - {} - - explicit StreamChannel(ValueChannelInfo&& info) - : m_type(ChannelType::Value), value_info(std::move(info)) - {} - - explicit StreamChannel(CategoricalChannelInfo&& info) - : m_type(ChannelType::Categorical), categorical_info(std::move(info)) - {} - - explicit StreamChannel(LieChannelInfo&& info) - : m_type(ChannelType::Lie), lie_info(std::move(info)) - {} - - ~StreamChannel(); - - StreamChannel& operator=(const StreamChannel& other); - StreamChannel& operator=(StreamChannel&& other) noexcept; - - StreamChannel& operator=(IncrementChannelInfo&& info) - { - this->~StreamChannel(); - m_type = ChannelType::Increment; - inplace_construct(&increment_info, std::move(info)); - return *this; - } - - StreamChannel& operator=(ValueChannelInfo&& info) - { - this->~StreamChannel(); - m_type = ChannelType::Value; - inplace_construct(&value_info, std::move(info)); - return *this; - } - - StreamChannel& operator=(CategoricalChannelInfo&& info) - { - this->~StreamChannel(); - m_type = ChannelType::Categorical; - inplace_construct(&categorical_info, std::move(info)); - return *this; - } - - RPY_NO_DISCARD - dimn_t variant_id_of_label(string_view label) const; - - void set_lie_info(deg_t width, deg_t depth, algebra::VectorType vtype); - void set_lead_lag(bool new_value) - { - RPY_CHECK(m_type == ChannelType::Value); - value_info.lead_lag = new_value; - } - RPY_NO_DISCARD - bool is_lead_lag() const - { - RPY_CHECK(m_type == ChannelType::Value); - return value_info.lead_lag; - } - - StreamChannel& add_variant(string variant_label); - - /** - * @brief Insert a new variant if it doesn't already exist - * @param variant_label variant label to insert. - * @return referene to this channel. 
- */ - StreamChannel& insert_variant(string variant_label); - - RPY_NO_DISCARD - std::vector get_variants() const; - - RPY_SERIAL_SAVE_FN(); - RPY_SERIAL_LOAD_FN(); -}; - -class RPY_EXPORT StaticChannel -{ - StaticChannelType m_type; - union - { - ValueChannelInfo value_info; - CategoricalChannelInfo categorical_info; - }; - - template - static void inplace_construct(void* address, T&& value) noexcept( - is_nothrow_constructible, - decltype(value)>::value) - { - ::new (address) remove_cv_ref_t(std::forward(value)); - } - -public: - StaticChannel(); - StaticChannel(const StaticChannel& other); - StaticChannel(StaticChannel&& other) noexcept; - - explicit StaticChannel(ValueChannelInfo&& info) - : m_type(StaticChannelType::Value), value_info(std::move(info)) - {} - - explicit StaticChannel(CategoricalChannelInfo&& info) - : m_type(StaticChannelType::Categorical), - categorical_info(std::move(info)) - {} - - ~StaticChannel(); - - StaticChannel& operator=(const StaticChannel& other); - StaticChannel& operator=(StaticChannel&& other) noexcept; - - RPY_NO_DISCARD - StaticChannelType type() const noexcept { return m_type; } - - RPY_NO_DISCARD - string label_suffix(dimn_t index) const; - - RPY_NO_DISCARD - dimn_t num_variants() const noexcept; - - RPY_NO_DISCARD - std::vector get_variants() const; - - RPY_NO_DISCARD - dimn_t variant_id_of_label(const string& label) const; - - StaticChannel& insert_variant(string new_variant); - - StaticChannel& add_variant(string new_variant); - - RPY_SERIAL_LOAD_FN(); - RPY_SERIAL_SAVE_FN(); -}; - -RPY_SERIAL_SERIALIZE_FN_EXT(IncrementChannelInfo) -{ - (void) value; - RPY_SERIAL_SERIALIZE_NVP("data", 0); -} - -RPY_SERIAL_SERIALIZE_FN_EXT(ValueChannelInfo) -{ - RPY_SERIAL_SERIALIZE_NVP("lead_lag", value.lead_lag); -} - -RPY_SERIAL_SERIALIZE_FN_EXT(CategoricalChannelInfo) -{ - RPY_SERIAL_SERIALIZE_NVP("variants", value.variants); -} - -RPY_SERIAL_SERIALIZE_FN_EXT(LieChannelInfo) -{ - RPY_SERIAL_SERIALIZE_NVP("width", value.width); - RPY_SERIAL_SERIALIZE_NVP("depth", value.depth); - RPY_SERIAL_SERIALIZE_NVP("vector_type", value.vtype); -} - -RPY_SERIAL_SAVE_FN_IMPL(StreamChannel) -{ - RPY_SERIAL_SERIALIZE_NVP("type", m_type); - switch (m_type) { - case ChannelType::Increment: - RPY_SERIAL_SERIALIZE_NVP("increment", increment_info); - break; - case ChannelType::Value: - RPY_SERIAL_SERIALIZE_NVP("value", value_info); - break; - case ChannelType::Categorical: - RPY_SERIAL_SERIALIZE_NVP("categorical", categorical_info); - break; - case ChannelType::Lie: RPY_SERIAL_SERIALIZE_NVP("lie", lie_info); break; - } -} - -RPY_SERIAL_LOAD_FN_IMPL(StreamChannel) -{ - RPY_SERIAL_SERIALIZE_NVP("type", m_type); - switch (m_type) { - case ChannelType::Increment: - new (&increment_info) IncrementChannelInfo; - RPY_SERIAL_SERIALIZE_NVP("increment", increment_info); - break; - case ChannelType::Value: - new (&value_info) ValueChannelInfo; - RPY_SERIAL_SERIALIZE_NVP("value", value_info); - break; - case ChannelType::Categorical: - new (&categorical_info) CategoricalChannelInfo; - RPY_SERIAL_SERIALIZE_NVP("categorical", categorical_info); - break; - case ChannelType::Lie: - new (&lie_info) LieChannelInfo; - RPY_SERIAL_SERIALIZE_NVP("lie", lie_info); - break; - } -} - -RPY_SERIAL_SAVE_FN_IMPL(StaticChannel) -{ - RPY_SERIAL_SERIALIZE_NVP("type", m_type); - switch (m_type) { - case StaticChannelType::Value: - RPY_SERIAL_SERIALIZE_NVP("value", value_info); - break; - case StaticChannelType::Categorical: - RPY_SERIAL_SERIALIZE_NVP("categorical", categorical_info); - break; - } -} - 
-RPY_SERIAL_LOAD_FN_IMPL(StaticChannel) -{ - RPY_SERIAL_SERIALIZE_NVP("type", m_type); - switch (m_type) { - case StaticChannelType::Value: - ::new (&value_info) ValueChannelInfo; - break; - case StaticChannelType::Categorical: - ::new (&categorical_info) CategoricalChannelInfo; - break; - } -} - -}// namespace streams -}// namespace rpy #endif// ROUGHPY_STREAMS_CHANNELS_H_ diff --git a/streams/include/roughpy/streams/channels/categorical_channel.h b/streams/include/roughpy/streams/channels/categorical_channel.h new file mode 100644 index 000000000..509c4a956 --- /dev/null +++ b/streams/include/roughpy/streams/channels/categorical_channel.h @@ -0,0 +1,50 @@ +// +// Created by sam on 07/08/23. +// + +#ifndef ROUGHPY_STREAMS_CATEGORICAL_CHANNEL_H +#define ROUGHPY_STREAMS_CATEGORICAL_CHANNEL_H + +#include "stream_channel.h" + +namespace rpy { +namespace streams { + +class RPY_EXPORT CategoricalChannel : public StreamChannel +{ + std::vector m_variants; + +public: + + CategoricalChannel() : StreamChannel(ChannelType::Categorical, nullptr) + {} + + dimn_t num_variants() const override; + string label_suffix(dimn_t variant_no) const override; + dimn_t variant_id_of_label(string_view label) const override; + const std::vector& get_variants() const override; + + StreamChannel& add_variant(string variant_label) override; + StreamChannel& insert_variant(string variant_label) override; + + RPY_SERIAL_SERIALIZE_FN(); +}; + + +RPY_SERIAL_SERIALIZE_FN_IMPL(CategoricalChannel) { + RPY_SERIAL_SERIALIZE_BASE(StreamChannel); + RPY_SERIAL_SERIALIZE_NVP("variants", m_variants); +} + + + +}// namespace streams +}// namespace rpy + +RPY_SERIAL_SPECIALIZE_TYPES(rpy::streams::CategoricalChannel, + rpy::serial::specialization::member_serialize) + + + +#endif// ROUGHPY_STREAMS_CATEGORICAL_CHANNEL_H + diff --git a/streams/include/roughpy/streams/channels/increment_channel.h b/streams/include/roughpy/streams/channels/increment_channel.h new file mode 100644 index 000000000..b8b822d45 --- /dev/null +++ b/streams/include/roughpy/streams/channels/increment_channel.h @@ -0,0 +1,35 @@ +// +// Created by sam on 07/08/23. +// + +#ifndef ROUGHPY_STREAMS_INCREMENT_CHANNEL_H +#define ROUGHPY_STREAMS_INCREMENT_CHANNEL_H + +#include "lead_laggable_channel.h" + +namespace rpy { namespace streams { + + + +class RPY_EXPORT IncrementChannel : public LeadLaggableChannel +{ +public: + + IncrementChannel() : LeadLaggableChannel(ChannelType::Increment, nullptr) + {} + + + RPY_SERIAL_SERIALIZE_FN(); +}; + +RPY_SERIAL_SERIALIZE_FN_IMPL(IncrementChannel) { + RPY_SERIAL_SERIALIZE_BASE(LeadLaggableChannel); +} + +}} + +RPY_SERIAL_SPECIALIZE_TYPES(rpy::streams::IncrementChannel, + rpy::serial::specialization::member_serialize) + + +#endif// ROUGHPY_STREAMS_INCREMENT_CHANNEL_H diff --git a/streams/include/roughpy/streams/channels/lead_laggable_channel.h b/streams/include/roughpy/streams/channels/lead_laggable_channel.h new file mode 100644 index 000000000..7bb4b8267 --- /dev/null +++ b/streams/include/roughpy/streams/channels/lead_laggable_channel.h @@ -0,0 +1,47 @@ +// +// Created by sam on 07/08/23. 
+// + +#ifndef ROUGHPY_STREAMS_LEAD_LAGGABLE_CHANNEL_H +#define ROUGHPY_STREAMS_LEAD_LAGGABLE_CHANNEL_H + +#include "stream_channel.h" + +namespace rpy { +namespace streams { + +class RPY_EXPORT LeadLaggableChannel : public StreamChannel +{ + bool m_use_leadlag = false; + +protected: + + using StreamChannel::StreamChannel; + +public: + dimn_t num_variants() const override; + string label_suffix(dimn_t variant_no) const override; + dimn_t variant_id_of_label(string_view label) const override; + const std::vector& get_variants() const override; + void set_lead_lag(bool new_value) override; + bool is_lead_lag() const override; + + RPY_SERIAL_SERIALIZE_FN(); +}; + +RPY_SERIAL_SERIALIZE_FN_IMPL(LeadLaggableChannel) +{ + RPY_SERIAL_SERIALIZE_BASE(StreamChannel); + RPY_SERIAL_SERIALIZE_NVP("use_leadlag", m_use_leadlag); +} + +}// namespace streams +}// namespace rpy + +RPY_SERIAL_SPECIALIZE_TYPES( + rpy::streams::LeadLaggableChannel, + rpy::serial::specialization::member_serialize +) + + +#endif// ROUGHPY_STREAMS_LEAD_LAGGABLE_CHANNEL_H diff --git a/streams/include/roughpy/streams/channels/lie_channel.h b/streams/include/roughpy/streams/channels/lie_channel.h new file mode 100644 index 000000000..e4866df6f --- /dev/null +++ b/streams/include/roughpy/streams/channels/lie_channel.h @@ -0,0 +1,33 @@ +// +// Created by sam on 07/08/23. +// + +#ifndef ROUGHPY_STREAMS_LIE_CHANNEL_H +#define ROUGHPY_STREAMS_LIE_CHANNEL_H + +#include "stream_channel.h" + +namespace rpy { +namespace streams { + +class RPY_EXPORT LieChannel : public StreamChannel +{ +public: + LieChannel() : StreamChannel(ChannelType::Lie, nullptr) {} + + RPY_SERIAL_SERIALIZE_FN(); +}; + +RPY_SERIAL_SERIALIZE_FN_IMPL(LieChannel) +{ + RPY_SERIAL_SERIALIZE_BASE(StreamChannel); +} + +}// namespace streams +}// namespace rpy + +RPY_SERIAL_SPECIALIZE_TYPES( + rpy::streams::LieChannel, rpy::serial::specialization::member_serialize +) + +#endif// ROUGHPY_STREAMS_LIE_CHANNEL_H diff --git a/streams/include/roughpy/streams/channels/stream_channel.h b/streams/include/roughpy/streams/channels/stream_channel.h new file mode 100644 index 000000000..4cd80b24d --- /dev/null +++ b/streams/include/roughpy/streams/channels/stream_channel.h @@ -0,0 +1,120 @@ +// +// Created by sam on 07/08/23. +// + +#ifndef ROUGHPY_STREAMS_STREAM_CHANNEL_H +#define ROUGHPY_STREAMS_STREAM_CHANNEL_H + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rpy { +namespace streams { + +enum struct ChannelType : uint8_t +{ + Increment = 0, + Value = 1, + Categorical = 2, + Lie = 3, +}; + +/** + * @brief Abstract description of a single channel of data. + * + * A stream channel contains metadata about one particular channel of + * incoming raw data. The data can take one of four possible forms: + * increment data, value data, categorical data, or lie data. Each form + * of data has its own set of required and optional metadata, including + * the number and labels of a categorical variable. These objects are + * typically accessed via a schema, which maintains the collection of + * all channels associated with a stream. 
+ */ +class RPY_EXPORT StreamChannel +{ + ChannelType m_type; + const scalars::ScalarType* p_scalar_type = nullptr; + +protected: + explicit StreamChannel( + ChannelType type, const scalars::ScalarType* channel_dtype + ) + : m_type(type), p_scalar_type(channel_dtype) + {} + +public: + + StreamChannel() : m_type(ChannelType::Increment), p_scalar_type(nullptr) + {} + + RPY_NO_DISCARD ChannelType type() const noexcept { return m_type; } + + RPY_NO_DISCARD virtual dimn_t num_variants() const; + + RPY_NO_DISCARD virtual string label_suffix(dimn_t variant_no) const; + + virtual ~StreamChannel(); + + RPY_NO_DISCARD virtual dimn_t variant_id_of_label(string_view label) const; + + virtual void + set_lie_info(deg_t width, deg_t depth, algebra::VectorType vtype); + + virtual void set_lead_lag(bool new_value); + + RPY_NO_DISCARD virtual bool is_lead_lag() const; + + virtual void convert_input( + scalars::ScalarPointer& dst, const scalars::ScalarPointer& src, + dimn_t count + ) const; + + template + void convert_input(scalars::ScalarPointer& dst, const T& single_data) const + { + convert_input(dst, {scalars::type_id_of(), &single_data}, 1); + } + + virtual StreamChannel& add_variant(string variant_label); + + /** + * @brief Insert a new variant if it doesn't already exist + * @param variant_label variant label to insert. + * @return referene to this channel. + */ + virtual StreamChannel& insert_variant(string variant_label); + + RPY_NO_DISCARD virtual const std::vector& get_variants() const; + + RPY_SERIAL_SAVE_FN(); + RPY_SERIAL_LOAD_FN(); +}; + +RPY_SERIAL_SAVE_FN_IMPL(StreamChannel) +{ + RPY_SERIAL_SERIALIZE_NVP("type", m_type); + RPY_SERIAL_SERIALIZE_NVP( + "dtype_id", (p_scalar_type == nullptr) ? "" : p_scalar_type->id() + ); +} +RPY_SERIAL_LOAD_FN_IMPL(StreamChannel) { + RPY_SERIAL_SERIALIZE_NVP("type", m_type); + string id; + RPY_SERIAL_SERIALIZE_NVP("dtype_id", id); + if (!id.empty()) { + p_scalar_type = scalars::get_type(id); + } +} + +}// namespace streams +}// namespace rpy + + + +#endif// ROUGHPY_STREAMS_STREAM_CHANNEL_H diff --git a/streams/include/roughpy/streams/channels/value_channel.h b/streams/include/roughpy/streams/channels/value_channel.h new file mode 100644 index 000000000..6fba0e42c --- /dev/null +++ b/streams/include/roughpy/streams/channels/value_channel.h @@ -0,0 +1,36 @@ +// +// Created by sam on 07/08/23. 
+// + +#ifndef ROUGHPY_STREAMS_VALUE_CHANNEL_H +#define ROUGHPY_STREAMS_VALUE_CHANNEL_H + +#include "lead_laggable_channel.h" + +namespace rpy { +namespace streams { + +class RPY_EXPORT ValueChannel : public LeadLaggableChannel +{ + +public: + ValueChannel() : LeadLaggableChannel(ChannelType::Value, nullptr) {} + + RPY_SERIAL_SERIALIZE_FN(); +}; + +RPY_SERIAL_SERIALIZE_FN_IMPL(ValueChannel) +{ + RPY_SERIAL_SERIALIZE_BASE(LeadLaggableChannel); +} + +}// namespace streams +}// namespace rpy + +RPY_SERIAL_SPECIALIZE_TYPES( + rpy::streams::ValueChannel, + rpy::serial::specialization::member_serialize +) + + +#endif// ROUGHPY_STREAMS_VALUE_CHANNEL_H diff --git a/streams/include/roughpy/streams/schema_context.h b/streams/include/roughpy/streams/parametrization.h similarity index 53% rename from streams/include/roughpy/streams/schema_context.h rename to streams/include/roughpy/streams/parametrization.h index c339cd69e..e1ad6aa51 100644 --- a/streams/include/roughpy/streams/schema_context.h +++ b/streams/include/roughpy/streams/parametrization.h @@ -1,5 +1,5 @@ -#ifndef ROUGHPY_STREAMS_SCHEMA_CONTEXT_H_ -#define ROUGHPY_STREAMS_SCHEMA_CONTEXT_H_ +#ifndef ROUGHPY_STREAMS_PARAMETRIZATION_H_ +#define ROUGHPY_STREAMS_PARAMETRIZATION_H_ #include #include @@ -10,28 +10,36 @@ namespace rpy { namespace streams { -class RPY_EXPORT SchemaContext +class RPY_EXPORT Parameterization { param_t m_param_offset = 0.0; param_t m_param_scaling = 1.0; + + bool b_add_to_channels = false; + idimn_t m_is_channel = -1; + public: + virtual ~Parameterization(); + + void add_as_channel() noexcept { b_add_to_channels = true; } - virtual ~SchemaContext(); + RPY_NO_DISCARD bool is_channel() const noexcept + { + return b_add_to_channels || (m_is_channel >= 0); + } - RPY_NO_DISCARD - intervals::RealInterval + RPY_NO_DISCARD intervals::RealInterval reparametrize(const intervals::RealInterval& arg) const { return {m_param_offset + m_param_scaling * arg.inf(), m_param_offset + m_param_scaling * arg.sup(), arg.type()}; } - RPY_NO_DISCARD - virtual intervals::RealInterval + RPY_NO_DISCARD virtual intervals::RealInterval convert_parameter_interval(const intervals::Interval& arg) const; }; }// namespace streams }// namespace rpy -#endif// ROUGHPY_STREAMS_SCHEMA_CONTEXT_H_ +#endif// ROUGHPY_STREAMS_PARAMETRIZATION_H_ diff --git a/streams/include/roughpy/streams/schema.h b/streams/include/roughpy/streams/schema.h index e644bc990..b8c9d5699 100644 --- a/streams/include/roughpy/streams/schema.h +++ b/streams/include/roughpy/streams/schema.h @@ -50,7 +50,7 @@ #include #include "channels.h" -#include "schema_context.h" +#include "parametrization.h" namespace rpy { namespace streams { @@ -73,22 +73,18 @@ namespace streams { * data, this is a 1-1 mapping. However, categorical values must be expanded * and other types of data might occupy multiple stream dimensions. 
*/ -class RPY_EXPORT StreamSchema : private std::vector> +class RPY_EXPORT StreamSchema + : private std::vector>> { - using base_type = std::vector>; - using static_vec_type = std::vector>; - - static_vec_type m_static_channels; + using base_type = std::vector>>; bool m_is_final = false; - std::unique_ptr p_context; + std::unique_ptr p_context; public: using typename base_type::const_iterator; using typename base_type::iterator; - using static_iterator = typename static_vec_type::iterator; - using static_const_iterator = typename static_vec_type::const_iterator; using typename base_type::value_type; using lie_key = typename algebra::LieBasis::key_type; @@ -98,15 +94,13 @@ class RPY_EXPORT StreamSchema : private std::vector> using base_type::reserve; using base_type::size; - static bool compare_labels(string_view item_label, - string_view ref_label) noexcept; + static bool + compare_labels(string_view item_label, string_view ref_label) noexcept; private: - RPY_NO_DISCARD - dimn_t channel_it_to_width(const_iterator channel_it) const; + RPY_NO_DISCARD dimn_t channel_it_to_width(const_iterator channel_it) const; - RPY_NO_DISCARD - dimn_t width_to_iterator(const_iterator end) const; + RPY_NO_DISCARD dimn_t width_to_iterator(const_iterator end) const; /** * @brief Get the iterator to the channel corresponding to stream_dim @@ -126,18 +120,19 @@ class RPY_EXPORT StreamSchema : private std::vector> StreamSchema& operator=(const StreamSchema&) = delete; StreamSchema& operator=(StreamSchema&&) noexcept = default; - RPY_NO_DISCARD - const_iterator nth(dimn_t idx) const noexcept + RPY_NO_DISCARD const_iterator nth(dimn_t idx) const noexcept { RPY_DBG_ASSERT(idx < size()); return begin() + static_cast(idx); } - RPY_NO_DISCARD - SchemaContext* context() const noexcept { return p_context.get(); } + RPY_NO_DISCARD Parameterization* context() const noexcept + { + return p_context.get(); + } template - enable_if_t::value> + enable_if_t::value> init_context(Args&&... 
args) { RPY_DBG_ASSERT(!p_context); @@ -145,96 +140,54 @@ class RPY_EXPORT StreamSchema : private std::vector> } public: - RPY_NO_DISCARD - iterator find(const string& label); - - RPY_NO_DISCARD - const_iterator find(const string& label) const; - - RPY_NO_DISCARD - static_iterator begin_static() noexcept - { - return m_static_channels.begin(); - } + RPY_NO_DISCARD iterator find(const string& label); - RPY_NO_DISCARD - static_iterator end_static() noexcept { return m_static_channels.end(); } + RPY_NO_DISCARD const_iterator find(const string& label) const; - RPY_NO_DISCARD - static_const_iterator begin_static() const noexcept + RPY_NO_DISCARD bool contains(const string& label) const { - return m_static_channels.cbegin(); + return find(label) != end(); } - RPY_NO_DISCARD - static_const_iterator end_static() const noexcept - { - return m_static_channels.cend(); - } - - RPY_NO_DISCARD - static_iterator find_static(const string& label); - - RPY_NO_DISCARD - static_const_iterator find_static(const string& label) const; + RPY_NO_DISCARD dimn_t width() const; - RPY_NO_DISCARD - bool contains(const string& label) const { return find(label) != end(); } + RPY_NO_DISCARD dimn_t channel_to_stream_dim(dimn_t channel_no) const; - RPY_NO_DISCARD - dimn_t width() const; + RPY_NO_DISCARD dimn_t + channel_variant_to_stream_dim(dimn_t channel_no, dimn_t variant_no) const; - RPY_NO_DISCARD - dimn_t channel_to_stream_dim(dimn_t channel_no) const; - - RPY_NO_DISCARD - dimn_t channel_variant_to_stream_dim(dimn_t channel_no, - dimn_t variant_no) const; - - RPY_NO_DISCARD - pair stream_dim_to_channel(dimn_t stream_dim) const; + RPY_NO_DISCARD pair stream_dim_to_channel(dimn_t stream_dim + ) const; private: - RPY_NO_DISCARD - static string label_from_channel_it(const_iterator channel_it, - dimn_t variant_id); + RPY_NO_DISCARD static string + label_from_channel_it(const_iterator channel_it, dimn_t variant_id); public: - RPY_NO_DISCARD - string label_of_stream_dim(dimn_t stream_dim) const; + RPY_NO_DISCARD string label_of_stream_dim(dimn_t stream_dim) const; - RPY_NO_DISCARD - string_view label_of_channel_id(dimn_t channel_id) const; + RPY_NO_DISCARD string_view label_of_channel_id(dimn_t channel_id) const; - RPY_NO_DISCARD - string label_of_channel_variant(dimn_t channel_id, - dimn_t channel_variant) const; + RPY_NO_DISCARD string + label_of_channel_variant(dimn_t channel_id, dimn_t channel_variant) const; - RPY_NO_DISCARD - dimn_t label_to_stream_dim(const string& label) const; + RPY_NO_DISCARD dimn_t label_to_stream_dim(const string& label) const; - StreamChannel& insert(string label, StreamChannel&& channel_data); + StreamChannel& insert(string label, std::unique_ptr&& channel_data); - StreamChannel& insert(StreamChannel&& channel_data); StreamChannel& insert_increment(string label); StreamChannel& insert_value(string label); StreamChannel& insert_categorical(string label); StreamChannel& insert_lie(string label); - StaticChannel& insert_static_value(string label); - StaticChannel& insert_static_categorical(string label); - - RPY_NO_DISCARD - bool is_final() const noexcept { return m_is_final; } + RPY_NO_DISCARD bool is_final() const noexcept { return m_is_final; } void finalize() noexcept { m_is_final = true; } - RPY_NO_DISCARD - intervals::RealInterval + RPY_NO_DISCARD intervals::RealInterval adjust_interval(const intervals::Interval& arg) const; - RPY_NO_DISCARD - lie_key label_to_lie_key(const string& label); + RPY_NO_DISCARD lie_key label_to_lie_key(const string& label); RPY_SERIAL_SERIALIZE_FN(); }; diff 
--git a/streams/include/roughpy/streams/stream_construction_helper.h b/streams/include/roughpy/streams/stream_construction_helper.h index 1e3d0f028..22837030c 100644 --- a/streams/include/roughpy/streams/stream_construction_helper.h +++ b/streams/include/roughpy/streams/stream_construction_helper.h @@ -117,8 +117,8 @@ void StreamConstructionHelper::add_value(param_t timestamp, string_view label, const auto found = p_schema->find(string(label)); RPY_CHECK(found != p_schema->end()); - auto lead_idx = found->second.variant_id_of_label("lead"); - auto lag_idx = found->second.variant_id_of_label("lag"); + auto lead_idx = found->second->variant_id_of_label("lead"); + auto lag_idx = found->second->variant_id_of_label("lag"); scalars::Scalar current(std::forward(value)); diff --git a/streams/src/channels.cpp b/streams/src/channels.cpp deleted file mode 100644 index 22f747fe9..000000000 --- a/streams/src/channels.cpp +++ /dev/null @@ -1,458 +0,0 @@ -// Copyright (c) 2023 the RoughPy Developers. All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// -// 3. Neither the name of the copyright holder nor the names of its contributors -// may be used to endorse or promote products derived from this software without -// specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -// POSSIBILITY OF SUCH DAMAGE. - -// -// Created by user on 04/07/23. 
-// -#include - -#include - -using namespace rpy; -using namespace streams; - -StreamChannel::StreamChannel() - : m_type(ChannelType::Increment), increment_info() -{} -StreamChannel::StreamChannel(const StreamChannel& arg) : m_type(arg.m_type) -{ - switch (m_type) { - case ChannelType::Increment: - inplace_construct(&increment_info, arg.increment_info); - break; - case ChannelType::Value: - inplace_construct(&value_info, arg.value_info); - break; - case ChannelType::Categorical: - inplace_construct(&categorical_info, arg.categorical_info); - break; - case ChannelType::Lie: - inplace_construct(&lie_info, arg.lie_info); - break; - } -} -StreamChannel::StreamChannel(StreamChannel&& arg) noexcept : m_type(arg.m_type) -{ - switch (m_type) { - case ChannelType::Increment: - inplace_construct(&increment_info, arg.increment_info); - break; - case ChannelType::Value: - inplace_construct(&value_info, arg.value_info); - break; - case ChannelType::Categorical: - inplace_construct(&categorical_info, arg.categorical_info); - break; - case ChannelType::Lie: - inplace_construct(&lie_info, arg.lie_info); - break; - } -} - -StreamChannel::StreamChannel(ChannelType type) : m_type(type) -{ - switch (m_type) { - case ChannelType::Increment: - inplace_construct(&increment_info, IncrementChannelInfo()); - break; - case ChannelType::Value: - inplace_construct(&value_info, ValueChannelInfo()); - break; - case ChannelType::Categorical: - inplace_construct(&categorical_info, CategoricalChannelInfo()); - break; - case ChannelType::Lie: - inplace_construct(&lie_info, LieChannelInfo()); - break; - } -} - -StreamChannel& StreamChannel::operator=(const StreamChannel& other) -{ - if (&other != this) { - this->~StreamChannel(); - m_type = other.m_type; - switch (m_type) { - case ChannelType::Increment: - inplace_construct(&increment_info, other.increment_info); - break; - case ChannelType::Value: - inplace_construct(&value_info, other.value_info); - break; - case ChannelType::Categorical: - inplace_construct(&categorical_info, other.categorical_info); - break; - case ChannelType::Lie: - inplace_construct(&lie_info, other.lie_info); - break; - } - } - return *this; -} -StreamChannel& StreamChannel::operator=(StreamChannel&& other) noexcept -{ - if (&other != this) { - this->~StreamChannel(); - m_type = other.m_type; - switch (m_type) { - case ChannelType::Increment: - inplace_construct(&increment_info, - std::move(other.increment_info)); - break; - case ChannelType::Value: - inplace_construct(&value_info, std::move(other.value_info)); - break; - case ChannelType::Categorical: - inplace_construct(&categorical_info, - std::move(other.categorical_info)); - break; - case ChannelType::Lie: - inplace_construct(&lie_info, std::move(other.lie_info)); - break; - } - } - return *this; -} - - - -StreamChannel::~StreamChannel() -{ - switch (m_type) { - case ChannelType::Increment: - increment_info.~IncrementChannelInfo(); - break; - case ChannelType::Value: value_info.~ValueChannelInfo(); break; - case ChannelType::Categorical: - categorical_info.~CategoricalChannelInfo(); - break; - case ChannelType::Lie: lie_info.~LieChannelInfo(); break; - } -} - - - -string StreamChannel::label_suffix(dimn_t variant_no) const -{ - switch (m_type) { - case ChannelType::Increment: return ""; - case ChannelType::Value: - if (value_info.lead_lag) { - RPY_CHECK(variant_no < 2); - return (variant_no == 0) ? 
":lead" : ":lag"; - } else { - return ""; - }; - case ChannelType::Categorical: - RPY_CHECK(variant_no < categorical_info.variants.size()); - return ":" + categorical_info.variants[variant_no]; - case ChannelType::Lie: - RPY_CHECK(variant_no < static_cast(lie_info.width)); - return ":" + std::to_string(variant_no + 1); - } - RPY_UNREACHABLE_RETURN({}); -} - -void StreamChannel::set_lie_info(deg_t width, deg_t depth, - algebra::VectorType vtype) -{ - RPY_CHECK(m_type == ChannelType::Lie); - lie_info.width = width; - lie_info.depth = depth; - lie_info.vtype = vtype; -} - -dimn_t StreamChannel::variant_id_of_label(string_view label) const -{ - switch (m_type) { - case ChannelType::Increment: return 0; - case ChannelType::Value: - if (value_info.lead_lag) { - if (label == "lead") { - return 0; - } else if (label == "lag") { - return 1; - } else { - RPY_THROW(std::runtime_error, - "unrecognised variant label for type value"); - } - } else { - return 0; - } - case ChannelType::Categorical: break; - case ChannelType::Lie: - deg_t i = std::stoi(string(label)); - RPY_CHECK(i < lie_info.width); - return i; - } - - auto it = std::find(categorical_info.variants.begin(), - categorical_info.variants.end(), label); - if (it == categorical_info.variants.end()) { - RPY_THROW(std::runtime_error, - "unrecognised variant label for type categorical"); - } - - return static_cast(it - categorical_info.variants.begin()); -} - -StreamChannel& StreamChannel::add_variant(string variant_label) -{ - RPY_CHECK(m_type == ChannelType::Categorical); - - if (variant_label.empty()) { - variant_label = std::to_string(categorical_info.variants.size()); - } - - auto found = std::find(categorical_info.variants.begin(), - categorical_info.variants.end(), variant_label); - if (found != categorical_info.variants.end()) { - RPY_THROW(std::runtime_error,"variant with label " + variant_label - + " already exists"); - } - - categorical_info.variants.push_back(std::move(variant_label)); - return *this; -} -StreamChannel& StreamChannel::insert_variant(string variant_label) -{ - RPY_CHECK(m_type == ChannelType::Categorical); - - if (variant_label.empty()) { - variant_label = std::to_string(categorical_info.variants.size()); - } - - auto var_begin = categorical_info.variants.begin(); - auto var_end = categorical_info.variants.end(); - - auto found = std::find(var_begin, var_end, variant_label); - if (found == var_end) { - categorical_info.variants.push_back(std::move(variant_label)); - } - - return *this; -} - -std::vector StreamChannel::get_variants() const -{ - std::vector variants; - switch (m_type) { - case ChannelType::Increment: break; - case ChannelType::Value: - if (value_info.lead_lag) { - variants.push_back("lead"); - variants.push_back("lag"); - } - break; - case ChannelType::Categorical: - variants = categorical_info.variants; - break; - case ChannelType::Lie: - variants.reserve(lie_info.width); - for (deg_t i = 0; i < lie_info.width; ++i) { - variants.push_back(std::to_string(i)); - } - break; - } - return variants; -} - -/////////////////////////////////////////////////////////////////////////////// -// Static channel -/////////////////////////////////////////////////////////////////////////////// - -StaticChannel::StaticChannel() : m_type(StaticChannelType::Value) -{ - inplace_construct(&value_info, ValueChannelInfo()); -} -StaticChannel::StaticChannel(const StaticChannel& other) : m_type(other.m_type) -{ - switch (m_type) { - case StaticChannelType::Value: - inplace_construct(&value_info, other.value_info); - break; - case 
StaticChannelType::Categorical: - inplace_construct(&categorical_info, other.categorical_info); - break; - } -} - -StaticChannel::StaticChannel(StaticChannel&& other) noexcept - : m_type(other.m_type) -{ - switch (m_type) { - case StaticChannelType::Value: - inplace_construct(&value_info, std::move(other.value_info)); - break; - case StaticChannelType::Categorical: - inplace_construct(&categorical_info, - std::move(other.categorical_info)); - break; - } -} -StaticChannel::~StaticChannel() -{ - switch (m_type) { - case StaticChannelType::Value: value_info.~ValueChannelInfo(); break; - case StaticChannelType::Categorical: - categorical_info.~CategoricalChannelInfo(); - break; - } -} - -StaticChannel& StaticChannel::operator=(const StaticChannel& other) -{ - if (&other != this) { - this->~StaticChannel(); - m_type = other.m_type; - switch (m_type) { - case StaticChannelType::Value: - inplace_construct(&value_info, other.value_info); - break; - case StaticChannelType::Categorical: - inplace_construct(&categorical_info, other.categorical_info); - break; - } - } - return *this; -} -StaticChannel& StaticChannel::operator=(StaticChannel&& other) noexcept -{ - if (&other != this) { - this->~StaticChannel(); - m_type = other.m_type; - switch (m_type) { - case StaticChannelType::Value: - inplace_construct(&value_info, std::move(other.value_info)); - break; - case StaticChannelType::Categorical: - inplace_construct(&categorical_info, - std::move(other.categorical_info)); - break; - } - } - return *this; -} -string StaticChannel::label_suffix(dimn_t index) const -{ - switch (m_type) { - case StaticChannelType::Value: return {}; - case StaticChannelType::Categorical: { - return categorical_info.variants[index]; - } - } - RPY_UNREACHABLE_RETURN({}); -} -dimn_t StaticChannel::num_variants() const noexcept -{ - switch (m_type) { - case StaticChannelType::Value: return 1; - case StaticChannelType::Categorical: - return categorical_info.variants.size(); - } - RPY_UNREACHABLE_RETURN(0); -} -std::vector StaticChannel::get_variants() const -{ - switch (m_type) { - case StaticChannelType::Value: return {}; - case StaticChannelType::Categorical: return categorical_info.variants; - } - RPY_UNREACHABLE_RETURN({}); -} -dimn_t StaticChannel::variant_id_of_label(const string& label) const -{ - switch (m_type) { - case StaticChannelType::Value: return 0; - case StaticChannelType::Categorical: { - const auto begin = categorical_info.variants.begin(); - const auto end = categorical_info.variants.end(); - const auto found = std::find(begin, end, label); - if (found == end) { - RPY_THROW(std::runtime_error,"label " + label - + " not a valid " - "variant of this " - "channel"); - } - return static_cast(found - begin); - } - } - RPY_UNREACHABLE_RETURN(0); -} -StaticChannel& StaticChannel::insert_variant(string new_variant) -{ - RPY_CHECK(m_type == StaticChannelType::Categorical); - const auto begin = categorical_info.variants.begin(); - const auto end = categorical_info.variants.end(); - const auto found = std::find(begin, end, new_variant); - if (found == end) { - categorical_info.variants.push_back(std::move(new_variant)); - } - return *this; -} -StaticChannel& StaticChannel::add_variant(string new_variant) -{ - RPY_CHECK(m_type == StaticChannelType::Categorical); - const auto begin = categorical_info.variants.begin(); - const auto end = categorical_info.variants.end(); - const auto found = std::find(begin, end, new_variant); - - RPY_CHECK(found == end); - categorical_info.variants.push_back(std::move(new_variant)); - 
return *this; -} - -#define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::StreamChannel -#define RPY_SERIAL_DO_SPLIT - -#include - -#define RPY_SERIAL_EXTERNAL rpy::streams -#define RPY_SERIAL_IMPL_CLASSNAME IncrementChannelInfo - -#include - -#define RPY_SERIAL_EXTERNAL rpy::streams -#define RPY_SERIAL_IMPL_CLASSNAME ValueChannelInfo - -#include - -#define RPY_SERIAL_EXTERNAL rpy::streams -#define RPY_SERIAL_IMPL_CLASSNAME CategoricalChannelInfo - -#include - -#define RPY_SERIAL_EXTERNAL rpy::streams -#define RPY_SERIAL_IMPL_CLASSNAME LieChannelInfo - -#include - -#define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::StaticChannel -#define RPY_SERIAL_DO_SPLIT - -#include diff --git a/streams/src/channels/categorical_channel.cpp b/streams/src/channels/categorical_channel.cpp new file mode 100644 index 000000000..d0c1c2bf0 --- /dev/null +++ b/streams/src/channels/categorical_channel.cpp @@ -0,0 +1,79 @@ +// +// Created by sam on 07/08/23. +// + +#include + +#include + +using namespace rpy; +using namespace rpy::streams; + + +dimn_t CategoricalChannel::num_variants() const +{ + return m_variants.size(); +} +string CategoricalChannel::label_suffix(dimn_t variant_no) const +{ + RPY_CHECK(variant_no < m_variants.size()); + return ":" + m_variants[variant_no]; +} +dimn_t CategoricalChannel::variant_id_of_label(string_view label) const +{ + auto it = std::find(m_variants.begin(), m_variants.end(), label); + if (it == m_variants.end()) { + RPY_THROW(std::runtime_error, + "unrecognised variant label for type categorical"); + } + + return static_cast(it - m_variants.begin()); +} + +const std::vector& CategoricalChannel::get_variants() const +{ + return m_variants; +} +StreamChannel& CategoricalChannel::add_variant(string variant_label) +{ + string label; + if (variant_label.empty()) { + label = std::to_string(m_variants.size()); + } else { + label = variant_label; + } + + auto var_begin = m_variants.begin(); + auto var_end = m_variants.end(); + auto found = std::find(var_begin, var_end, label); + if (found != var_end) { + RPY_THROW(std::runtime_error, + "variant with label " + label + " already exists"); + } + m_variants.push_back(std::move(label)); + + return *this; + +} +StreamChannel& CategoricalChannel::insert_variant(string variant_label) +{ + string label; + if (variant_label.empty()) { + label = std::to_string(m_variants.size()); + } else { + label = variant_label; + } + + auto var_begin = m_variants.begin(); + auto var_end = m_variants.end(); + auto found = std::find(var_begin, var_end, label); + if (found == var_end) { + m_variants.push_back(std::move(label)); + } + + return *this; +} + +RPY_SERIAL_REGISTER_CLASS(rpy::streams::CategoricalChannel) +#define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::CategoricalChannel +#include \ No newline at end of file diff --git a/streams/src/channels/increment_channel.cpp b/streams/src/channels/increment_channel.cpp new file mode 100644 index 000000000..46c6ddb73 --- /dev/null +++ b/streams/src/channels/increment_channel.cpp @@ -0,0 +1,9 @@ +// +// Created by sam on 07/08/23. +// +#include + +RPY_SERIAL_REGISTER_CLASS(rpy::streams::IncrementChannel); + +#define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::IncrementChannel +#include \ No newline at end of file diff --git a/streams/src/channels/lead_laggable_channel.cpp b/streams/src/channels/lead_laggable_channel.cpp new file mode 100644 index 000000000..5d0d9967c --- /dev/null +++ b/streams/src/channels/lead_laggable_channel.cpp @@ -0,0 +1,53 @@ +// +// Created by sam on 07/08/23. 
+// + +#include + +using namespace rpy; +using namespace rpy::streams; + +dimn_t LeadLaggableChannel::num_variants() const +{ + return (m_use_leadlag) ? 2 : 1; +} +string LeadLaggableChannel::label_suffix(dimn_t variant_no) const +{ + if (m_use_leadlag) { + if (variant_no == 0) { + return ":lead"; + } else if (variant_no == 1){ + return ":lag"; + } + RPY_THROW(std::invalid_argument, "variant is not valid for a lead-lag channel"); + } + return StreamChannel::label_suffix(variant_no); +} +dimn_t LeadLaggableChannel::variant_id_of_label(string_view label) const +{ + if (m_use_leadlag) { + if (label == "lead") { + return 0; + } else if (label == "lag") { + return 1; + } + } + return StreamChannel::variant_id_of_label(label); +} +const std::vector& LeadLaggableChannel::get_variants() const +{ + static const std::vector leadlag { "lead", "lag" }; + return (m_use_leadlag) ? leadlag : StreamChannel::get_variants(); +} +void LeadLaggableChannel::set_lead_lag(bool new_value) +{ + m_use_leadlag = new_value; +} +bool LeadLaggableChannel::is_lead_lag() const +{ + return m_use_leadlag; +} + +RPY_SERIAL_REGISTER_CLASS(LeadLaggableChannel) +#define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::LeadLaggableChannel +#include \ No newline at end of file diff --git a/streams/src/channels/lie_channel.cpp b/streams/src/channels/lie_channel.cpp new file mode 100644 index 000000000..372a6af36 --- /dev/null +++ b/streams/src/channels/lie_channel.cpp @@ -0,0 +1,14 @@ +// +// Created by sam on 07/08/23. +// + +#include + + + + + + +RPY_SERIAL_REGISTER_CLASS(rpy::streams::LieChannel) +#define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::LieChannel +#include \ No newline at end of file diff --git a/streams/src/channels/stream_channel.cpp b/streams/src/channels/stream_channel.cpp new file mode 100644 index 000000000..da5373306 --- /dev/null +++ b/streams/src/channels/stream_channel.cpp @@ -0,0 +1,67 @@ +// +// Created by sam on 07/08/23. +// + + +#include + + +using namespace rpy; +using namespace rpy::streams; + +StreamChannel::~StreamChannel() {} + +dimn_t StreamChannel::num_variants() const +{ + return 1; +} + +string StreamChannel::label_suffix(rpy::dimn_t variant_no) const +{ + return ""; +} + +dimn_t StreamChannel::variant_id_of_label(string_view label) const { return 0; } +void StreamChannel::set_lie_info( + deg_t width, deg_t depth, algebra::VectorType vtype +) +{} +StreamChannel& StreamChannel::add_variant(string variant_label) +{ + return *this; +} +StreamChannel& StreamChannel::insert_variant(string variant_label) +{ + return *this; +} +const std::vector& StreamChannel::get_variants() const +{ + static const std::vector no_variants; + return no_variants; +} + +void StreamChannel::set_lead_lag(bool new_value) {} +bool StreamChannel::is_lead_lag() const { return false; } +void StreamChannel::convert_input( + scalars::ScalarPointer& dst, const scalars::ScalarPointer& src, + dimn_t count +) const +{ + if (count == 0) { return; } + RPY_CHECK(!src.is_null()); + RPY_CHECK(dst.type() != nullptr); + + if (dst.is_null()) { + dst = dst.type()->allocate(count); + } + + dst.type()->convert_copy(dst, src, count); +} + + + + +#define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::StreamChannel +#define RPY_SERIAL_DO_SPLIT + +#include \ No newline at end of file diff --git a/streams/src/channels/value_channel.cpp b/streams/src/channels/value_channel.cpp new file mode 100644 index 000000000..63e086d26 --- /dev/null +++ b/streams/src/channels/value_channel.cpp @@ -0,0 +1,10 @@ +// +// Created by sam on 07/08/23. 
+// + +#include + +RPY_SERIAL_REGISTER_CLASS(rpy::streams::ValueChannel) +#define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::ValueChannel + +#include \ No newline at end of file diff --git a/streams/src/external_data_sources/sound_file_data_source.cpp b/streams/src/external_data_sources/sound_file_data_source.cpp index 7cf5eb5c1..e4853809f 100644 --- a/streams/src/external_data_sources/sound_file_data_source.cpp +++ b/streams/src/external_data_sources/sound_file_data_source.cpp @@ -147,7 +147,7 @@ dimn_t SoundFileDataSource::query_impl( for (const auto& [_, chan] : schema) { auto out_idx = schema.channel_to_stream_dim(i); - if (chan.type() == ChannelType::Value) { + if (chan->type() == ChannelType::Value) { working[out_idx] = current[i] - previous[i]; } else { working[out_idx] = current[i]; diff --git a/streams/src/schema_context.cpp b/streams/src/parametrization.cpp similarity index 91% rename from streams/src/schema_context.cpp rename to streams/src/parametrization.cpp index 13fab1ff8..b749db63f 100644 --- a/streams/src/schema_context.cpp +++ b/streams/src/parametrization.cpp @@ -29,17 +29,15 @@ // Created by user on 04/07/23. // -#include - - +#include using namespace rpy; using namespace streams; -SchemaContext::~SchemaContext() = default; +Parameterization::~Parameterization() = default; intervals::RealInterval -SchemaContext::convert_parameter_interval(const intervals::Interval& arg) const +Parameterization::convert_parameter_interval(const intervals::Interval& arg) const { return {m_param_offset + m_param_scaling * arg.inf(), m_param_offset + m_param_scaling * arg.sup(), arg.type()}; diff --git a/streams/src/schema.cpp b/streams/src/schema.cpp index 3fbbb8f5a..165411926 100644 --- a/streams/src/schema.cpp +++ b/streams/src/schema.cpp @@ -71,7 +71,7 @@ bool StreamSchema::compare_labels( dimn_t StreamSchema::channel_it_to_width(const_iterator channel_it) const { - return channel_it->second.num_variants(); + return channel_it->second->num_variants(); } dimn_t StreamSchema::width_to_iterator(const_iterator end) const @@ -119,30 +119,6 @@ typename StreamSchema::iterator StreamSchema::find(const string& label) return it_end; } -typename StreamSchema::static_iterator -StreamSchema::find_static(const string& label) -{ - RPY_CHECK(!m_is_final); - auto it_current = m_static_channels.begin(); - const auto it_end = m_static_channels.end(); - - for (; it_current != it_end; ++it_current) { - if (compare_labels(it_current->first, label)) { return it_current; } - } - return it_end; -} -typename StreamSchema::static_const_iterator -StreamSchema::find_static(const string& label) const -{ - auto it_current = m_static_channels.cbegin(); - const auto it_end = m_static_channels.cend(); - - for (; it_current != it_end; ++it_current) { - if (compare_labels(it_current->first, label)) { return it_current; } - } - return it_end; -} - dimn_t StreamSchema::width() const { return width_to_iterator(end()); } dimn_t StreamSchema::channel_to_stream_dim(dimn_t channel_no) const @@ -158,7 +134,7 @@ dimn_t StreamSchema::channel_variant_to_stream_dim( auto it = nth(channel_no); auto so_far = width_to_iterator(it); - RPY_CHECK(variant_no < it->second.num_variants()); + RPY_CHECK(variant_no < it->second->num_variants()); return so_far + variant_no; } @@ -175,7 +151,7 @@ string StreamSchema::label_from_channel_it( const_iterator channel_it, dimn_t variant_id ) { - return channel_it->first + channel_it->second.label_suffix(variant_id); + return channel_it->first + channel_it->second->label_suffix(variant_id); } string 
StreamSchema::label_of_stream_dim(dimn_t stream_dim) const @@ -222,11 +198,13 @@ dimn_t StreamSchema::label_to_stream_dim(const string& label) const const string_view variant_label( &*variant_begin, static_cast(label.end() - variant_begin) ); - result += channel->second.variant_id_of_label(variant_label); + result += channel->second->variant_id_of_label(variant_label); return result; } -StreamChannel& StreamSchema::insert(string label, StreamChannel&& channel_data) +StreamChannel& StreamSchema::insert( + string label, std::unique_ptr&& channel_data +) { RPY_CHECK(!m_is_final); if (label.empty()) { label = std::to_string(size()); } @@ -235,10 +213,10 @@ StreamChannel& StreamSchema::insert(string label, StreamChannel&& channel_data) // Silly, but handle it gracefully. auto pos = find(label); - if (pos != end()) { return pos->second; } + if (pos != end()) { return *pos->second; } - return base_type::insert(pos, {std::move(label), std::move(channel_data)}) - ->second; + return *base_type::insert(pos, {std::move(label), std::move(channel_data)}) + ->second; } intervals::RealInterval @@ -248,52 +226,24 @@ StreamSchema::adjust_interval(const intervals::Interval& arg) const return intervals::RealInterval(arg); } -StreamChannel& StreamSchema::insert(StreamChannel&& channel_data) -{ - return insert(std::to_string(width()), std::move(channel_data)); -} StreamChannel& StreamSchema::insert_increment(string label) { - return insert(std::move(label), StreamChannel(IncrementChannelInfo())); + return insert(std::move(label), std::make_unique()); } StreamChannel& StreamSchema::insert_value(string label) { - return insert(std::move(label), StreamChannel(ValueChannelInfo())); + return insert(std::move(label), std::make_unique()); } StreamChannel& StreamSchema::insert_categorical(string label) { - return insert(std::move(label), StreamChannel(CategoricalChannelInfo())); + return insert(std::move(label), std::make_unique()); } StreamChannel& StreamSchema::insert_lie(string label) { - return insert(std::move(label), StreamChannel(LieChannelInfo())); -} - -StaticChannel& StreamSchema::insert_static_value(string label) -{ - auto found = find_static(label); - auto end = end_static(); - - if (found != end) { return found->second; } - - return m_static_channels - .insert(end, {std::move(label), StaticChannel(ValueChannelInfo())}) - ->second; + return insert(std::move(label), std::make_unique()); } -StaticChannel& StreamSchema::insert_static_categorical(string label) -{ - auto found = find_static(label); - auto end = end_static(); - - if (found != end) { return found->second; } - - return m_static_channels - .insert(end, - {std::move(label), StaticChannel(CategoricalChannelInfo())}) - ->second; -} typename StreamSchema::lie_key StreamSchema::label_to_lie_key(const string& label) { diff --git a/streams/src/stream_construction_helper.cpp b/streams/src/stream_construction_helper.cpp index d472dbd21..47196a99d 100644 --- a/streams/src/stream_construction_helper.cpp +++ b/streams/src/stream_construction_helper.cpp @@ -66,7 +66,7 @@ void StreamConstructionHelper::add_categorical(param_t timestamp, { const auto found = p_schema->find(string(channel)); RPY_CHECK(found != p_schema->end()); - RPY_CHECK(variant < found->second.num_variants()); + RPY_CHECK(variant < found->second->num_variants()); auto key = static_cast(found - p_schema->begin()) + static_cast(variant) + 1; next_entry(timestamp)[key] += p_ctx->ctype()->one(); @@ -78,7 +78,7 @@ void StreamConstructionHelper::add_categorical(param_t timestamp, RPY_CHECK(channel < 
p_schema->size()); const auto channel_item = p_schema->nth(channel); - const auto variants = channel_item->second.get_variants(); + const auto& variants = channel_item->second->get_variants(); const auto found = std::find(variants.begin(), variants.end(), variant); RPY_CHECK(found != variants.end()); @@ -107,6 +107,6 @@ optional StreamConstructionHelper::type_of(string_view label) const { const auto& schema = *p_schema; auto found = schema.find(string(label)); - if (found != schema.end()) { return found->second.type(); } + if (found != schema.end()) { return found->second->type(); } return {}; } From cceba99562b894dd91cbc5d0fb0e350762238cd2 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Mon, 7 Aug 2023 15:51:06 +0100 Subject: [PATCH 02/33] rename parametrization var and include in width calculation. --- roughpy/CMakeLists.txt | 4 ++-- ..._schema_context.cpp => py_parametrization.cpp} | 2 +- .../{py_schema_context.h => py_parametrization.h} | 0 .../src/streams/r_py_tick_construction_helper.cpp | 4 ++-- roughpy/src/streams/schema.cpp | 2 +- streams/include/roughpy/streams/parametrization.h | 5 +++++ streams/include/roughpy/streams/schema.h | 15 ++++++++------- streams/src/schema.cpp | 10 ++++++++-- 8 files changed, 27 insertions(+), 15 deletions(-) rename roughpy/src/streams/{py_schema_context.cpp => py_parametrization.cpp} (97%) rename roughpy/src/streams/{py_schema_context.h => py_parametrization.h} (100%) diff --git a/roughpy/CMakeLists.txt b/roughpy/CMakeLists.txt index 67136a96f..80e047f3f 100644 --- a/roughpy/CMakeLists.txt +++ b/roughpy/CMakeLists.txt @@ -117,8 +117,8 @@ target_sources(RoughPy_PyModule PRIVATE src/streams/piecewise_abelian_stream.h src/streams/r_py_tick_construction_helper.cpp src/streams/r_py_tick_construction_helper.h - src/streams/py_schema_context.cpp - src/streams/py_schema_context.h + src/streams/py_parametrization.cpp + src/streams/py_parametrization.h src/streams/schema.cpp src/streams/schema.h src/streams/stream.cpp diff --git a/roughpy/src/streams/py_schema_context.cpp b/roughpy/src/streams/py_parametrization.cpp similarity index 97% rename from roughpy/src/streams/py_schema_context.cpp rename to roughpy/src/streams/py_parametrization.cpp index 69364e1a9..da5890af0 100644 --- a/roughpy/src/streams/py_schema_context.cpp +++ b/roughpy/src/streams/py_parametrization.cpp @@ -2,7 +2,7 @@ // Created by user on 04/07/23. 
// -#include "py_schema_context.h" +#include "py_parametrization.h" #include "intervals/date_time_interval.h" diff --git a/roughpy/src/streams/py_schema_context.h b/roughpy/src/streams/py_parametrization.h similarity index 100% rename from roughpy/src/streams/py_schema_context.h rename to roughpy/src/streams/py_parametrization.h diff --git a/roughpy/src/streams/r_py_tick_construction_helper.cpp b/roughpy/src/streams/r_py_tick_construction_helper.cpp index 51696c7ef..b1dad4314 100644 --- a/roughpy/src/streams/r_py_tick_construction_helper.cpp +++ b/roughpy/src/streams/r_py_tick_construction_helper.cpp @@ -9,7 +9,7 @@ #include "args/convert_timestamp.h" #include "args/parse_schema.h" -#include "py_schema_context.h" +#include "py_parametrization.h" using namespace rpy; using namespace rpy::streams; @@ -33,7 +33,7 @@ python::RPyTickConstructionHelper::RPyTickConstructionHelper( m_reference_time(py::none()), m_time_conversion_options{PyDateTimeResolution::Seconds} { - if (!p_schema->is_final() && p_schema->context() == nullptr) { + if (!p_schema->is_final() && p_schema->parametrization() == nullptr) { p_schema->init_context(m_time_conversion_options); } RPY_CHECK(!schema_only || !p_schema->is_final()); diff --git a/roughpy/src/streams/schema.cpp b/roughpy/src/streams/schema.cpp index e170f2863..3c39c3b87 100644 --- a/roughpy/src/streams/schema.cpp +++ b/roughpy/src/streams/schema.cpp @@ -60,7 +60,7 @@ inline void init_channel_item(py::module_& m) .value("LieChannel", ChannelType::Lie) .export_values(); - py::class_ cls(m, "StreamChannel"); + //py::class_> cls(m, "StreamChannel"); } std::shared_ptr diff --git a/streams/include/roughpy/streams/parametrization.h b/streams/include/roughpy/streams/parametrization.h index e1ad6aa51..9edbc6ff7 100644 --- a/streams/include/roughpy/streams/parametrization.h +++ b/streams/include/roughpy/streams/parametrization.h @@ -28,6 +28,11 @@ class RPY_EXPORT Parameterization return b_add_to_channels || (m_is_channel >= 0); } + RPY_NO_DISCARD bool needs_adding() const noexcept + { + return b_add_to_channels && m_is_channel < 0; + } + RPY_NO_DISCARD intervals::RealInterval reparametrize(const intervals::RealInterval& arg) const { diff --git a/streams/include/roughpy/streams/schema.h b/streams/include/roughpy/streams/schema.h index b8c9d5699..cea798cf8 100644 --- a/streams/include/roughpy/streams/schema.h +++ b/streams/include/roughpy/streams/schema.h @@ -74,13 +74,13 @@ namespace streams { * and other types of data might occupy multiple stream dimensions. */ class RPY_EXPORT StreamSchema - : private std::vector>> + : private std::vector>> { - using base_type = std::vector>>; + using base_type = std::vector>>; bool m_is_final = false; - std::unique_ptr p_context; + std::unique_ptr p_parameterization; public: using typename base_type::const_iterator; @@ -126,17 +126,18 @@ class RPY_EXPORT StreamSchema return begin() + static_cast(idx); } - RPY_NO_DISCARD Parameterization* context() const noexcept + RPY_NO_DISCARD Parameterization* parametrization() const noexcept { - return p_context.get(); + return p_parameterization.get(); } template enable_if_t::value> init_context(Args&&... 
args) { - RPY_DBG_ASSERT(!p_context); - p_context = std::make_unique(std::forward(args)...); + RPY_DBG_ASSERT(!p_parameterization); + p_parameterization + = std::make_unique(std::forward(args)...); } public: diff --git a/streams/src/schema.cpp b/streams/src/schema.cpp index 165411926..000a20c91 100644 --- a/streams/src/schema.cpp +++ b/streams/src/schema.cpp @@ -119,7 +119,13 @@ typename StreamSchema::iterator StreamSchema::find(const string& label) return it_end; } -dimn_t StreamSchema::width() const { return width_to_iterator(end()); } +dimn_t StreamSchema::width() const { + auto channels_width = width_to_iterator(end()); + if (p_parameterization && p_parameterization->needs_adding()) { + channels_width += 1; + } + return channels_width; +} dimn_t StreamSchema::channel_to_stream_dim(dimn_t channel_no) const { @@ -222,7 +228,7 @@ StreamChannel& StreamSchema::insert( intervals::RealInterval StreamSchema::adjust_interval(const intervals::Interval& arg) const { - if (p_context) { return p_context->convert_parameter_interval(arg); } + if (p_parameterization) { return p_parameterization->convert_parameter_interval(arg); } return intervals::RealInterval(arg); } From a1ac3f4605803934a0d7f5e8911b1f40b9b4e08c Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Mon, 7 Aug 2023 16:14:49 +0100 Subject: [PATCH 03/33] Added method to get time channel key. --- streams/include/roughpy/streams/schema.h | 2 ++ streams/src/schema.cpp | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/streams/include/roughpy/streams/schema.h b/streams/include/roughpy/streams/schema.h index cea798cf8..a16663563 100644 --- a/streams/include/roughpy/streams/schema.h +++ b/streams/include/roughpy/streams/schema.h @@ -190,6 +190,8 @@ class RPY_EXPORT StreamSchema RPY_NO_DISCARD lie_key label_to_lie_key(const string& label); + RPY_NO_DISCARD lie_key time_channel_to_lie_key() const; + RPY_SERIAL_SERIALIZE_FN(); }; diff --git a/streams/src/schema.cpp b/streams/src/schema.cpp index 000a20c91..5a85de645 100644 --- a/streams/src/schema.cpp +++ b/streams/src/schema.cpp @@ -256,6 +256,13 @@ StreamSchema::label_to_lie_key(const string& label) auto idx = label_to_stream_dim(label); return static_cast(idx) + 1; } +typename StreamSchema::lie_key StreamSchema::time_channel_to_lie_key() const +{ + RPY_CHECK(p_parameterization); + RPY_CHECK(p_parameterization->needs_adding()); + + return static_cast(width_to_iterator(end())) + 1; +} #define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::StreamSchema From 9f286a2712feea93958f3d0f4114c9901a8751e9 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Mon, 7 Aug 2023 16:51:52 +0100 Subject: [PATCH 04/33] Width without param for building values. 
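
The schema width now counts the parametrization slot whenever it still
needs to be added as a channel, so code that builds values for the data
channels alone needs the old count; width_without_param() exposes it.
A minimal sketch of the intended relationship (illustrative only, and it
assumes the schema's parametrization has been configured so that
needs_adding() returns true):

    StreamSchema schema;
    schema.insert_increment("x");
    schema.insert_increment("y");
    // schema.width_without_param() == 2   // data channels only
    // schema.width() == 3                 // extra slot for the time channel
    // schema.time_channel_to_lie_key() == 3
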
--- streams/include/roughpy/streams/schema.h | 1 + streams/src/schema.cpp | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/streams/include/roughpy/streams/schema.h b/streams/include/roughpy/streams/schema.h index a16663563..a4196665a 100644 --- a/streams/include/roughpy/streams/schema.h +++ b/streams/include/roughpy/streams/schema.h @@ -151,6 +151,7 @@ class RPY_EXPORT StreamSchema } RPY_NO_DISCARD dimn_t width() const; + RPY_NO_DISCARD dimn_t width_without_param() const; RPY_NO_DISCARD dimn_t channel_to_stream_dim(dimn_t channel_no) const; diff --git a/streams/src/schema.cpp b/streams/src/schema.cpp index 5a85de645..650b8792f 100644 --- a/streams/src/schema.cpp +++ b/streams/src/schema.cpp @@ -120,12 +120,17 @@ typename StreamSchema::iterator StreamSchema::find(const string& label) } dimn_t StreamSchema::width() const { - auto channels_width = width_to_iterator(end()); + auto channels_width = width_without_param(); if (p_parameterization && p_parameterization->needs_adding()) { channels_width += 1; } return channels_width; } +dimn_t StreamSchema::width_without_param() const { + return width_to_iterator(end()); +} + + dimn_t StreamSchema::channel_to_stream_dim(dimn_t channel_no) const { From b959809da1756e2ea6643b06f99111017fcda03d Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 12:26:03 +0100 Subject: [PATCH 05/33] Schema now creates parametrization by default --- streams/include/roughpy/streams/schema.h | 2 +- streams/src/schema.cpp | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/streams/include/roughpy/streams/schema.h b/streams/include/roughpy/streams/schema.h index a4196665a..b58d86b1c 100644 --- a/streams/include/roughpy/streams/schema.h +++ b/streams/include/roughpy/streams/schema.h @@ -111,7 +111,7 @@ class RPY_EXPORT StreamSchema const_iterator stream_dim_to_channel_it(dimn_t& stream_dim) const; public: - StreamSchema() = default; + StreamSchema(); StreamSchema(const StreamSchema&) = delete; StreamSchema(StreamSchema&&) noexcept = default; diff --git a/streams/src/schema.cpp b/streams/src/schema.cpp index 650b8792f..28d36422b 100644 --- a/streams/src/schema.cpp +++ b/streams/src/schema.cpp @@ -38,6 +38,9 @@ using namespace rpy; using namespace rpy::streams; +StreamSchema::StreamSchema() : p_parameterization(new Parameterization) +{} + StreamSchema::StreamSchema(dimn_t width) { reserve(width); From 0ed532356a3ecaba523a51c7a1058b2a69b0bb62 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 12:28:59 +0100 Subject: [PATCH 06/33] Remove checks for null parametrization --- streams/src/schema.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/streams/src/schema.cpp b/streams/src/schema.cpp index 28d36422b..cac5c2851 100644 --- a/streams/src/schema.cpp +++ b/streams/src/schema.cpp @@ -124,7 +124,7 @@ typename StreamSchema::iterator StreamSchema::find(const string& label) dimn_t StreamSchema::width() const { auto channels_width = width_without_param(); - if (p_parameterization && p_parameterization->needs_adding()) { + if (p_parameterization->needs_adding()) { channels_width += 1; } return channels_width; @@ -236,8 +236,7 @@ StreamChannel& StreamSchema::insert( intervals::RealInterval StreamSchema::adjust_interval(const intervals::Interval& arg) const { - if (p_parameterization) { return p_parameterization->convert_parameter_interval(arg); } - return intervals::RealInterval(arg); + return p_parameterization->convert_parameter_interval(arg); } From 5f3d69c1cf292180f63a70a314ff916f9317443f Mon 
Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 12:30:46 +0100 Subject: [PATCH 07/33] Always construct the schema. --- roughpy/src/args/kwargs_to_path_metadata.cpp | 40 ++++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/roughpy/src/args/kwargs_to_path_metadata.cpp b/roughpy/src/args/kwargs_to_path_metadata.cpp index a6d1fabfc..b5c58b716 100644 --- a/roughpy/src/args/kwargs_to_path_metadata.cpp +++ b/roughpy/src/args/kwargs_to_path_metadata.cpp @@ -137,26 +137,6 @@ python::kwargs_to_metadata(const pybind11::kwargs& kwargs) md.width = md.ctx->width(); md.scalar_type = md.ctx->ctype(); - if (!md.schema) { - md.schema = std::make_shared(); - for (deg_t i = 0; i < md.width; ++i) { - switch (ch_type) { - case streams::ChannelType::Increment: - md.schema->insert_increment(""); - break; - case streams::ChannelType::Value: - md.schema->insert_value(""); - break; - case streams::ChannelType::Categorical: - md.schema->insert_categorical(""); - break; - case streams::ChannelType::Lie: - md.schema->insert_lie(""); - break; - } - } - } - auto schema_width = static_cast(md.schema->width()); if (schema_width != md.width) { md.width = schema_width; @@ -188,6 +168,26 @@ python::kwargs_to_metadata(const pybind11::kwargs& kwargs) } } + if (!md.schema) { + md.schema = std::make_shared(); + for (deg_t i = 0; i < md.width; ++i) { + switch (ch_type) { + case streams::ChannelType::Increment: + md.schema->insert_increment(""); + break; + case streams::ChannelType::Value: + md.schema->insert_value(""); + break; + case streams::ChannelType::Categorical: + md.schema->insert_categorical(""); + break; + case streams::ChannelType::Lie: + md.schema->insert_lie(""); + break; + } + } + } + if (kwargs.contains("vtype")) { md.vector_type = kwargs["vtype"].cast(); } From 3cafd4c30ab2ccfc102a0c6ada3ffd04427d919b Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 12:46:27 +0100 Subject: [PATCH 08/33] Missed a constructor --- streams/src/schema.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/streams/src/schema.cpp b/streams/src/schema.cpp index cac5c2851..0b03c4b8e 100644 --- a/streams/src/schema.cpp +++ b/streams/src/schema.cpp @@ -42,6 +42,7 @@ StreamSchema::StreamSchema() : p_parameterization(new Parameterization) {} StreamSchema::StreamSchema(dimn_t width) + : p_parameterization(new Parameterization) { reserve(width); for (dimn_t i = 0; i < width; ++i) { insert_increment(std::to_string(i)); } From 188d53f24e24d02a74671a0826134710968d24b1 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 13:33:58 +0100 Subject: [PATCH 09/33] Don't construct schema here, it causes problems downstream. 
--- roughpy/src/args/kwargs_to_path_metadata.cpp | 40 ++++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/roughpy/src/args/kwargs_to_path_metadata.cpp b/roughpy/src/args/kwargs_to_path_metadata.cpp index b5c58b716..a6d1fabfc 100644 --- a/roughpy/src/args/kwargs_to_path_metadata.cpp +++ b/roughpy/src/args/kwargs_to_path_metadata.cpp @@ -137,6 +137,26 @@ python::kwargs_to_metadata(const pybind11::kwargs& kwargs) md.width = md.ctx->width(); md.scalar_type = md.ctx->ctype(); + if (!md.schema) { + md.schema = std::make_shared(); + for (deg_t i = 0; i < md.width; ++i) { + switch (ch_type) { + case streams::ChannelType::Increment: + md.schema->insert_increment(""); + break; + case streams::ChannelType::Value: + md.schema->insert_value(""); + break; + case streams::ChannelType::Categorical: + md.schema->insert_categorical(""); + break; + case streams::ChannelType::Lie: + md.schema->insert_lie(""); + break; + } + } + } + auto schema_width = static_cast(md.schema->width()); if (schema_width != md.width) { md.width = schema_width; @@ -168,26 +188,6 @@ python::kwargs_to_metadata(const pybind11::kwargs& kwargs) } } - if (!md.schema) { - md.schema = std::make_shared(); - for (deg_t i = 0; i < md.width; ++i) { - switch (ch_type) { - case streams::ChannelType::Increment: - md.schema->insert_increment(""); - break; - case streams::ChannelType::Value: - md.schema->insert_value(""); - break; - case streams::ChannelType::Categorical: - md.schema->insert_categorical(""); - break; - case streams::ChannelType::Lie: - md.schema->insert_lie(""); - break; - } - } - } - if (kwargs.contains("vtype")) { md.vector_type = kwargs["vtype"].cast(); } From 877614e6eceac112284bd11667f193db2f6c9bab Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 13:36:25 +0100 Subject: [PATCH 10/33] Added cbh for slice of Lie pointers --- algebra/include/roughpy/algebra/context.h | 6 +++ algebra/src/context.cpp | 59 ++++++++++++++++++----- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/algebra/include/roughpy/algebra/context.h b/algebra/include/roughpy/algebra/context.h index 31f35769f..b7d430892 100644 --- a/algebra/include/roughpy/algebra/context.h +++ b/algebra/include/roughpy/algebra/context.h @@ -176,9 +176,15 @@ class RPY_EXPORT Context : public ContextBase void cbh_fallback(FreeTensor& collector, const std::vector& lies) const; + void + cbh_fallback(FreeTensor& collector, Slice lies) const; + public: RPY_NO_DISCARD virtual Lie cbh(const std::vector& lies, VectorType vtype) const; + RPY_NO_DISCARD virtual Lie + cbh(Slice lies, VectorType vtype) const; + RPY_NO_DISCARD virtual Lie cbh(const Lie& left, const Lie& right, VectorType vtype) const; diff --git a/algebra/src/context.cpp b/algebra/src/context.cpp index 8858eca4f..5934e0029 100644 --- a/algebra/src/context.cpp +++ b/algebra/src/context.cpp @@ -161,6 +161,14 @@ void Context::cbh_fallback(FreeTensor& collector, const std::vector& lies) } } } +void Context::cbh_fallback(FreeTensor& collector, Slice lies) const +{ + for (const auto* alie : lies) { + if (alie->dimension() != 0) { + collector.fmexp(this->lie_to_tensor(*alie)); + } + } +} Lie Context::cbh(const std::vector& lies, VectorType vtype) const { @@ -173,6 +181,19 @@ Lie Context::cbh(const std::vector& lies, VectorType vtype) const return tensor_to_lie(collector.log()); } +Lie Context::cbh(Slice lies, VectorType vtype) const { + if (lies.size() == 1) { + return convert(*lies[0], vtype); + } + + FreeTensor collector = zero_free_tensor(vtype); + collector[0] = 
scalars::Scalar(1); + + cbh_fallback(collector, lies); + + return tensor_to_lie(collector.log()); +} + Lie Context::cbh(const Lie& left, const Lie& right, VectorType vtype) const { FreeTensor tmp = lie_to_tensor(left).exp(); @@ -241,12 +262,16 @@ context_pointer rpy::algebra::get_context( } if (found.empty()) { - RPY_THROW(std::invalid_argument,"cannot find a context maker for the " - "width, depth, dtype, and preferences set"); + RPY_THROW( + std::invalid_argument, + "cannot find a context maker for the " + "width, depth, dtype, and preferences set" + ); } if (found.size() > 1) { - RPY_THROW(std::invalid_argument, + RPY_THROW( + std::invalid_argument, "found multiple context maker candidates for specified width, " "depth, dtype, and preferences set" ); @@ -314,30 +339,42 @@ UnspecifiedAlgebraType Context::free_multiply( const ConstRawUnspecifiedAlgebraType right ) const { - RPY_THROW(std::runtime_error,"free tensor multiply is not implemented for " - "arbitrary types with this backend"); + RPY_THROW( + std::runtime_error, + "free tensor multiply is not implemented for " + "arbitrary types with this backend" + ); } UnspecifiedAlgebraType Context::shuffle_multiply( ConstRawUnspecifiedAlgebraType left, ConstRawUnspecifiedAlgebraType right ) const { - RPY_THROW(std::runtime_error,"shuffle multiply is not implemented for " - "arbitrary types with this backend"); + RPY_THROW( + std::runtime_error, + "shuffle multiply is not implemented for " + "arbitrary types with this backend" + ); } UnspecifiedAlgebraType Context::half_shuffle_multiply( ConstRawUnspecifiedAlgebraType left, ConstRawUnspecifiedAlgebraType right ) const { - RPY_THROW(std::runtime_error,"half shuffle multiply is not implemented for " - "arbitrary types with this backend"); + RPY_THROW( + std::runtime_error, + "half shuffle multiply is not implemented for " + "arbitrary types with this backend" + ); } UnspecifiedAlgebraType Context::adjoint_to_left_multiply_by( ConstRawUnspecifiedAlgebraType multiplier, ConstRawUnspecifiedAlgebraType argument ) const { - RPY_THROW(std::runtime_error,"adjoint of left multiply is not implemented for " - "arbitrary types with this backend"); + RPY_THROW( + std::runtime_error, + "adjoint of left multiply is not implemented for " + "arbitrary types with this backend" + ); } From d91d17a837439d4c5a293adcab6aead15a08179f Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 13:36:46 +0100 Subject: [PATCH 11/33] Overhaul of LieIncrementStream --- roughpy/src/streams/lie_increment_stream.cpp | 16 +- .../roughpy/streams/lie_increment_stream.h | 43 +++-- streams/src/lie_increment_stream.cpp | 147 ++++++++++-------- 3 files changed, 115 insertions(+), 91 deletions(-) diff --git a/roughpy/src/streams/lie_increment_stream.cpp b/roughpy/src/streams/lie_increment_stream.cpp index 9966fabfb..c50e32a61 100644 --- a/roughpy/src/streams/lie_increment_stream.cpp +++ b/roughpy/src/streams/lie_increment_stream.cpp @@ -139,11 +139,13 @@ static py::object lie_increment_stream_from_increments( if (icol < 0 || icol >= increment_size) { RPY_THROW(py::value_error, "index out of bounds"); } + RPY_CHECK(icol < buffer.size()); indices.reserve(num_increments); - for (idimn_t i = 0; i < num_increments; ++i) { + indices.push_back(buffer[icol].to_scalar_t()); + for (idimn_t i = 1; i < num_increments; ++i) { indices.push_back(static_cast( - buffer[i * increment_size + icol].to_scalar_t() + indices.back() + buffer[i * increment_size + icol].to_scalar_t() )); } } else if (py::isinstance(indices_arg)) { @@ -171,11 +173,17 @@ 
static py::object lie_increment_stream_from_increments( "number of indices"); } + if (!md.schema) { + md.schema = std::make_shared(md.width); + } + + auto result = streams::Stream(streams::LieIncrementStream( - std::move(buffer).copy_or_move(), indices, + buffer, indices, {md.width, effective_support, md.ctx, md.scalar_type, md.vector_type ? *md.vector_type : algebra::VectorType::Dense, - md.resolution} + md.resolution}, + md.schema )); if (md.support) { result.restrict_to(*md.support); } diff --git a/streams/include/roughpy/streams/lie_increment_stream.h b/streams/include/roughpy/streams/lie_increment_stream.h index c04c94ff9..700cef2dd 100644 --- a/streams/include/roughpy/streams/lie_increment_stream.h +++ b/streams/include/roughpy/streams/lie_increment_stream.h @@ -43,29 +43,29 @@ namespace streams { class RPY_EXPORT LieIncrementStream : public DyadicCachingLayer { - scalars::KeyScalarArray m_buffer; - boost::container::flat_map m_mapping; - using base_t = DyadicCachingLayer; public: - LieIncrementStream(scalars::KeyScalarArray&& buffer, - boost::container::flat_map&& mapping, - StreamMetadata&& md) - : DyadicCachingLayer(std::move(md)), m_buffer(std::move(buffer)), - m_mapping(std::move(mapping)) - {} + using Lie = algebra::Lie; + +private: + boost::container::flat_map m_data; + +public: + using DyadicCachingLayer::DyadicCachingLayer; - LieIncrementStream(scalars::KeyScalarArray&& buffer, Slice indices, - StreamMetadata md); + LieIncrementStream( + const scalars::KeyScalarArray& buffer, Slice indices, + StreamMetadata md, std::shared_ptr schema + ); - RPY_NO_DISCARD - bool empty(const intervals::Interval& interval) const noexcept override; + RPY_NO_DISCARD bool empty(const intervals::Interval& interval + ) const noexcept override; protected: - RPY_NO_DISCARD - algebra::Lie log_signature_impl(const intervals::Interval& interval, - const algebra::Context& ctx) const override; + RPY_NO_DISCARD algebra::Lie log_signature_impl( + const intervals::Interval& interval, const algebra::Context& ctx + ) const override; public: RPY_SERIAL_SERIALIZE_FN(); @@ -75,8 +75,7 @@ RPY_SERIAL_SERIALIZE_FN_IMPL(LieIncrementStream) { StreamMetadata md = metadata(); RPY_SERIAL_SERIALIZE_NVP("metadata", md); - RPY_SERIAL_SERIALIZE_NVP("buffer", m_buffer); - RPY_SERIAL_SERIALIZE_NVP("mapping", m_mapping); + RPY_SERIAL_SERIALIZE_NVP("data", m_data); } }// namespace streams @@ -91,12 +90,10 @@ RPY_SERIAL_LOAD_AND_CONSTRUCT(rpy::streams::LieIncrementStream) StreamMetadata md; RPY_SERIAL_SERIALIZE_NVP("metadata", md); - scalars::KeyScalarArray buffer; - RPY_SERIAL_SERIALIZE_VAL(buffer); - boost::container::flat_map mapping; - RPY_SERIAL_SERIALIZE_VAL(mapping); + boost::container::flat_map data; + RPY_SERIAL_SERIALIZE_VAL(data); - construct(std::move(buffer), std::move(mapping), std::move(md)); + construct(std::move(md)); } #endif diff --git a/streams/src/lie_increment_stream.cpp b/streams/src/lie_increment_stream.cpp index 4b65021f3..0563a0918 100644 --- a/streams/src/lie_increment_stream.cpp +++ b/streams/src/lie_increment_stream.cpp @@ -35,94 +35,113 @@ using namespace rpy; using namespace rpy::streams; -LieIncrementStream::LieIncrementStream(scalars::KeyScalarArray&& buffer, - Slice indices, - StreamMetadata metadata) - : base_t(std::move(metadata)), m_buffer(std::move(buffer)) +LieIncrementStream::LieIncrementStream( + const scalars::KeyScalarArray& buffer, Slice indices, + StreamMetadata metadata, std::shared_ptr schema +) + : base_t(std::move(metadata), std::move(schema)) { + using scalars::Scalar; + const 
auto& md = this->metadata(); - for (dimn_t i = 0; i < indices.size(); ++i) { - m_mapping[indices[i]] = i * md.width; + const auto& ctx = *md.default_context; + + const auto& sch = this->schema(); + const auto* param = sch.parametrization(); + const bool param_needs_adding = param != nullptr && param->needs_adding(); + const key_type param_slot + = (param_needs_adding) ? sch.time_channel_to_lie_key() : 0; + + m_data.reserve(indices.size()); + + if (!buffer.is_null()) { + if (buffer.has_keys()) { + /* + * The data is sparse, so we need to do some careful checking to make + * sure we pick out individual increments. + * TODO: Need key-scalar stream to implement this properly. + */ + RPY_THROW( + std::runtime_error, + "creating a Lie increment stream with sparse data is not " + "currently supported" + ); + + } else { + /* + * The data is dense. The only tricky part for this case is dealing with + * adding the "time" channel if it is given. + * + * Until we construct the relevant support mechanisms, we assume that + * the provided increments have degree 1. + * TODO: Add support for key-scalar streams. + */ + + const auto width = sch.width_without_param(); + + const char* dptr = buffer.raw_cast(); + const auto stride = buffer.type()->itemsize() * width; + param_t previous_param = 0.0; + for (auto index : indices) { + + algebra::VectorConstructionData cdata{ + scalars::KeyScalarArray(buffer.type(), dptr, width), + md.cached_vector_type}; + + auto [it, inserted] + = m_data.try_emplace(index, ctx.construct_lie(cdata)); + + if (inserted && param_needs_adding) { + /* + * We've inserted a new element, so we should now add the param + * value if it is needed. + */ + it->second[param_slot] = Scalar(index - previous_param); + } + previous_param = index; + dptr += stride; + } + } } - - // std::cerr << m_mapping.begin()->first << ' ' << - // (--m_mapping.end())->first << '\n'; } -algebra::Lie -LieIncrementStream::log_signature_impl(const intervals::Interval& interval, - const algebra::Context& ctx) const +algebra::Lie LieIncrementStream::log_signature_impl( + const intervals::Interval& interval, const algebra::Context& ctx +) const { const auto& md = metadata(); - // if (empty(interval)) { - // return ctx.zero_lie(md.cached_vector_type); - // } - - rpy::algebra::SignatureData data{scalars::ScalarStream(ctx.ctype()), - {}, - md.cached_vector_type}; - - if (m_mapping.size() == 1) { - data.data_stream.set_elts_per_row(m_buffer.size()); - } else if (m_mapping.size() > 1) { - auto row1 = (++m_mapping.begin())->second; - auto row0 = m_mapping.begin()->second; - data.data_stream.set_elts_per_row(row1 - row0); - } auto begin = (interval.type() == intervals::IntervalType::Opencl) - ? m_mapping.upper_bound(interval.inf()) - : m_mapping.lower_bound(interval.inf()); + ? m_data.upper_bound(interval.inf()) + : m_data.lower_bound(interval.inf()); auto end = (interval.type() == intervals::IntervalType::Opencl) - ? m_mapping.upper_bound(interval.sup()) - : m_mapping.lower_bound(interval.sup()); + ? 
m_data.upper_bound(interval.sup()) + : m_data.lower_bound(interval.sup()); if (begin == end) { return ctx.zero_lie(md.cached_vector_type); } - data.data_stream.reserve_size(end - begin); + std::vector lies; + lies.reserve(static_cast(end - begin)); - for (auto it1 = begin, it = it1++; it1 != end; ++it, ++it1) { - data.data_stream.push_back( - {m_buffer[it->second].to_pointer(), it1->second - it->second}); + for (auto it = begin; it != end; ++it) { + lies.push_back(&it->second); } - // Case it = it1 - 1 and it1 == end - --end; - data.data_stream.push_back({m_buffer[end->second].to_pointer(), - m_buffer.size() - end->second}); - - if (m_buffer.keys() != nullptr) { - data.key_stream.reserve(end - begin); - ++end; - for (auto it = begin; it != end; ++it) { - data.key_stream.push_back(m_buffer.keys() + it->second); - } - } - - RPY_CHECK(ctx.width() == md.width); - // assert(ctx.depth() == md.depth); - return ctx.log_signature(data); + return ctx.cbh(lies, md.cached_vector_type); } -bool LieIncrementStream::empty( - const intervals::Interval& interval) const noexcept +bool LieIncrementStream::empty(const intervals::Interval& interval +) const noexcept { - // std::cerr << "Checking " << interval; - // for (auto& item : m_mapping) { - // if (item.first >= interval.inf() && item.first < interval.sup()) { - // std::cerr << ' ' << item.first << ", " << item.second << ';'; - // } - // } auto begin = (interval.type() == intervals::IntervalType::Opencl) - ? m_mapping.upper_bound(interval.inf()) - : m_mapping.lower_bound(interval.inf()); + ? m_data.upper_bound(interval.inf()) + : m_data.lower_bound(interval.inf()); auto end = (interval.type() == intervals::IntervalType::Opencl) - ? m_mapping.upper_bound(interval.sup()) - : m_mapping.lower_bound(interval.sup()); + ? 
m_data.upper_bound(interval.sup()) + : m_data.lower_bound(interval.sup()); - // std::cerr << ' ' << (begin == end) << '\n'; return begin == end; } From 57315f2ff949c6d2c489b3d4b09d478942eb77af Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 13:37:09 +0100 Subject: [PATCH 12/33] ScalarStream missing from sources --- scalars/CMakeLists.txt | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scalars/CMakeLists.txt b/scalars/CMakeLists.txt index c8b203632..e456e7d6e 100644 --- a/scalars/CMakeLists.txt +++ b/scalars/CMakeLists.txt @@ -15,7 +15,7 @@ set(RPY_USE_MKL ${MKL_FOUND}) configure_file(scalar_blas_defs.h.in scalar_blas_defs.h @ONLY) add_roughpy_component(Scalars - SOURCES + SOURCES src/linear_algebra/blas.h src/linear_algebra/blas_complex_double.cpp src/linear_algebra/blas_complex_float.cpp @@ -71,13 +71,14 @@ add_roughpy_component(Scalars #src/scalar_implementations/complex_float/complex_float_type.cpp #src/scalar_implementations/complex_float/complex_float_type.h ${CMAKE_CURRENT_BINARY_DIR}/scalar_blas_defs.h - PUBLIC_HEADERS + PUBLIC_HEADERS include/roughpy/scalars/scalars_fwd.h include/roughpy/scalars/scalar_pointer.h include/roughpy/scalars/scalar.h include/roughpy/scalars/scalar_type.h include/roughpy/scalars/scalar_interface.h include/roughpy/scalars/scalar_array.h + include/roughpy/scalars/scalar_stream.h include/roughpy/scalars/owned_scalar_array.h include/roughpy/scalars/key_scalar_array.h include/roughpy/scalars/scalar_traits.h @@ -86,15 +87,15 @@ add_roughpy_component(Scalars include/roughpy/scalars/scalar_blas.h include/roughpy/scalars/scalar_type_helper.h include/roughpy/scalars.h - PUBLIC_DEPS + PUBLIC_DEPS Boost::boost Eigen3::Eigen Libalgebra_lite::Libalgebra_lite BLAS::BLAS LAPACK::LAPACK - PRIVATE_DEPS + PRIVATE_DEPS PCGRandom::pcg_random - NEEDS + NEEDS RoughPy::Core RoughPy::Platform ) From 27c3f82d730df3a0eafd80f13844289cb9ed1d2f Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 15:00:54 +0100 Subject: [PATCH 13/33] Added key scalar stream --- scalars/CMakeLists.txt | 2 + .../roughpy/scalars/key_scalar_stream.h | 47 +++++++++++++++ .../include/roughpy/scalars/scalar_stream.h | 5 +- scalars/src/key_scalar_stream.cpp | 58 +++++++++++++++++++ 4 files changed, 110 insertions(+), 2 deletions(-) create mode 100644 scalars/include/roughpy/scalars/key_scalar_stream.h create mode 100644 scalars/src/key_scalar_stream.cpp diff --git a/scalars/CMakeLists.txt b/scalars/CMakeLists.txt index e456e7d6e..1ba0c95b0 100644 --- a/scalars/CMakeLists.txt +++ b/scalars/CMakeLists.txt @@ -46,6 +46,7 @@ add_roughpy_component(Scalars src/random_impl.cpp src/random_impl.h src/scalar_blas_impl.h + src/key_scalar_stream.cpp src/scalar_implementations/half/half_type.h src/scalar_implementations/half/half_type.cpp src/scalar_implementations/float/float_type.cpp @@ -86,6 +87,7 @@ add_roughpy_component(Scalars include/roughpy/scalars/scalar_matrix.h include/roughpy/scalars/scalar_blas.h include/roughpy/scalars/scalar_type_helper.h + include/roughpy/scalars/key_scalar_stream.h include/roughpy/scalars.h PUBLIC_DEPS Boost::boost diff --git a/scalars/include/roughpy/scalars/key_scalar_stream.h b/scalars/include/roughpy/scalars/key_scalar_stream.h new file mode 100644 index 000000000..9f2bb6292 --- /dev/null +++ b/scalars/include/roughpy/scalars/key_scalar_stream.h @@ -0,0 +1,47 @@ +// +// Created by sam on 08/08/23. 
+// + +#ifndef ROUGHPY_SCALARS_KEY_SCALAR_STREAM_H +#define ROUGHPY_SCALARS_KEY_SCALAR_STREAM_H + +#include "scalars_fwd.h" +#include "scalar_array.h" +#include "key_scalar_array.h" +#include "scalar_stream.h" + +#include +#include + +namespace rpy { namespace scalars { + + +class RPY_EXPORT KeyScalarStream : public ScalarStream { + std::vector m_key_stream; + + +public: + + KeyScalarStream(); + KeyScalarStream(const KeyScalarStream&); + KeyScalarStream(KeyScalarStream&&) noexcept; + + KeyScalarStream& operator=(const KeyScalarStream&); + KeyScalarStream& operator=(KeyScalarStream&&) noexcept; + + RPY_NO_DISCARD + KeyScalarArray operator[](dimn_t row) const noexcept; + + void reserve_size(dimn_t num_rows); + + void push_back(const ScalarPointer& scalar_ptr, const key_type* key_ptr=nullptr); + void push_back(const ScalarArray& scalar_data, const key_type* key_ptr=nullptr); + +}; + + + +}} + + +#endif// ROUGHPY_SCALARS_KEY_SCALAR_STREAM_H diff --git a/scalars/include/roughpy/scalars/scalar_stream.h b/scalars/include/roughpy/scalars/scalar_stream.h index 0e596f0d6..03730b142 100644 --- a/scalars/include/roughpy/scalars/scalar_stream.h +++ b/scalars/include/roughpy/scalars/scalar_stream.h @@ -38,9 +38,10 @@ namespace scalars { class RPY_EXPORT ScalarStream { +protected: std::vector m_stream; - // boost::container::small_vector m_elts_per_row; - std::vector m_elts_per_row; + boost::container::small_vector m_elts_per_row; +// std::vector m_elts_per_row; const ScalarType* p_type; public: diff --git a/scalars/src/key_scalar_stream.cpp b/scalars/src/key_scalar_stream.cpp new file mode 100644 index 000000000..a559ef072 --- /dev/null +++ b/scalars/src/key_scalar_stream.cpp @@ -0,0 +1,58 @@ +// +// Created by sam on 08/08/23. +// + +#include + +using namespace rpy; +using namespace rpy::scalars; + +KeyScalarStream::KeyScalarStream() = default; +KeyScalarStream::KeyScalarStream(const rpy::scalars::KeyScalarStream&) + = default; +KeyScalarStream::KeyScalarStream(rpy::scalars::KeyScalarStream&&) noexcept + = default; + +KeyScalarStream& KeyScalarStream::operator=(const KeyScalarStream&) = default; +KeyScalarStream& KeyScalarStream::operator=(KeyScalarStream&&) noexcept + = default; + + + +void KeyScalarStream::reserve_size(dimn_t num_rows) +{ + ScalarStream::reserve_size(num_rows); + m_key_stream.reserve(num_rows); +} +void KeyScalarStream::push_back( + const ScalarPointer& scalar_ptr, const key_type* key_ptr +) +{ + if (key_ptr != nullptr) { + if (m_key_stream.empty()) { + m_key_stream.resize(m_stream.size(), nullptr); + } + m_key_stream.push_back(key_ptr); + } + ScalarStream::push_back(scalar_ptr); +} +void KeyScalarStream::push_back( + const ScalarArray& scalar_data, const key_type* key_ptr +) +{ + if (key_ptr != nullptr) { + if (m_key_stream.empty()) { + m_key_stream.resize(m_stream.size(), nullptr); + } + m_key_stream.push_back(key_ptr); + } + ScalarStream::push_back(scalar_data); +} + +KeyScalarArray KeyScalarStream::operator[](dimn_t row) const noexcept +{ + return { + ScalarStream::operator[](row), + m_key_stream.empty() ? 
nullptr : m_key_stream[row] + }; +} From a6b10bbe8ee8d187988d652b91ae8cbf7e7ba5cc Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 15:48:15 +0100 Subject: [PATCH 14/33] Fixed a couple of issues relating to serialization --- .../streams/channels/categorical_channel.h | 1 + .../streams/channels/increment_channel.h | 1 + .../streams/channels/lead_laggable_channel.h | 1 + .../roughpy/streams/channels/lie_channel.h | 1 + .../roughpy/streams/channels/value_channel.h | 1 + streams/src/test_lie_increment_stream.cpp | 27 ++++++++++++------- streams/src/test_schema.cpp | 12 ++++----- 7 files changed, 29 insertions(+), 15 deletions(-) diff --git a/streams/include/roughpy/streams/channels/categorical_channel.h b/streams/include/roughpy/streams/channels/categorical_channel.h index 509c4a956..5e64ee5a4 100644 --- a/streams/include/roughpy/streams/channels/categorical_channel.h +++ b/streams/include/roughpy/streams/channels/categorical_channel.h @@ -45,6 +45,7 @@ RPY_SERIAL_SPECIALIZE_TYPES(rpy::streams::CategoricalChannel, rpy::serial::specialization::member_serialize) +RPY_SERIAL_REGISTER_CLASS(rpy::streams::CategoricalChannel) #endif// ROUGHPY_STREAMS_CATEGORICAL_CHANNEL_H diff --git a/streams/include/roughpy/streams/channels/increment_channel.h b/streams/include/roughpy/streams/channels/increment_channel.h index b8b822d45..aa2dd8554 100644 --- a/streams/include/roughpy/streams/channels/increment_channel.h +++ b/streams/include/roughpy/streams/channels/increment_channel.h @@ -31,5 +31,6 @@ RPY_SERIAL_SERIALIZE_FN_IMPL(IncrementChannel) { RPY_SERIAL_SPECIALIZE_TYPES(rpy::streams::IncrementChannel, rpy::serial::specialization::member_serialize) +RPY_SERIAL_REGISTER_CLASS(rpy::streams::IncrementChannel) #endif// ROUGHPY_STREAMS_INCREMENT_CHANNEL_H diff --git a/streams/include/roughpy/streams/channels/lead_laggable_channel.h b/streams/include/roughpy/streams/channels/lead_laggable_channel.h index 7bb4b8267..e4660ab2e 100644 --- a/streams/include/roughpy/streams/channels/lead_laggable_channel.h +++ b/streams/include/roughpy/streams/channels/lead_laggable_channel.h @@ -43,5 +43,6 @@ RPY_SERIAL_SPECIALIZE_TYPES( rpy::serial::specialization::member_serialize ) +RPY_SERIAL_REGISTER_CLASS(rpy::streams::LeadLaggableChannel) #endif// ROUGHPY_STREAMS_LEAD_LAGGABLE_CHANNEL_H diff --git a/streams/include/roughpy/streams/channels/lie_channel.h b/streams/include/roughpy/streams/channels/lie_channel.h index e4866df6f..34b177a1c 100644 --- a/streams/include/roughpy/streams/channels/lie_channel.h +++ b/streams/include/roughpy/streams/channels/lie_channel.h @@ -30,4 +30,5 @@ RPY_SERIAL_SPECIALIZE_TYPES( rpy::streams::LieChannel, rpy::serial::specialization::member_serialize ) +RPY_SERIAL_REGISTER_CLASS(rpy::streams::LieChannel) #endif// ROUGHPY_STREAMS_LIE_CHANNEL_H diff --git a/streams/include/roughpy/streams/channels/value_channel.h b/streams/include/roughpy/streams/channels/value_channel.h index 6fba0e42c..ec8fc1629 100644 --- a/streams/include/roughpy/streams/channels/value_channel.h +++ b/streams/include/roughpy/streams/channels/value_channel.h @@ -32,5 +32,6 @@ RPY_SERIAL_SPECIALIZE_TYPES( rpy::serial::specialization::member_serialize ) +RPY_SERIAL_REGISTER_CLASS(rpy::streams::ValueChannel) #endif// ROUGHPY_STREAMS_VALUE_CHANNEL_H diff --git a/streams/src/test_lie_increment_stream.cpp b/streams/src/test_lie_increment_stream.cpp index 702b6cf96..b51629510 100644 --- a/streams/src/test_lie_increment_stream.cpp +++ b/streams/src/test_lie_increment_stream.cpp @@ -92,12 +92,17 @@ class 
LieIncrementStreamTests : public ::testing::Test static constexpr deg_t depth = 2; algebra::context_pointer ctx; - StreamMetadata md{width, intervals::RealInterval(0.0, 1.0), ctx, - ctx->ctype(), algebra::VectorType::Dense}; + StreamMetadata md{ + width, intervals::RealInterval(0.0, 1.0), ctx, ctx->ctype(), + algebra::VectorType::Dense}; LieIncrementStreamTests() - : gen(1.0), ctx(algebra::get_context(width, depth, gen.ctype, - {{"backend", "libalgebra_lite"}})) + : gen(1.0), ctx(algebra::get_context( + width, depth, gen.ctype, + { + {"backend", "libalgebra_lite"} + } + )) {} scalars::OwnedScalarArray random_data(dimn_t rows, dimn_t cols = width) @@ -131,7 +136,9 @@ TEST_F(LieIncrementStreamTests, TestLogSignatureSingleIncrement) algebra::VectorType::Dense}; auto idx = indices(1); const streams::LieIncrementStream path( - scalars::KeyScalarArray(std::move(data)), idx, md); + scalars::KeyScalarArray(std::move(data)), idx, md, + std::make_shared(ctx->width()) + ); auto ctx1 = ctx->get_alike(1); auto lsig = path.log_signature(intervals::RealInterval(0.0, 1.0), 1, *ctx1); @@ -146,15 +153,17 @@ TEST_F(LieIncrementStreamTests, TestLogSignatureTwoIncrementsDepth1) auto data = random_data(2); - algebra::VectorConstructionData edata{scalars::KeyScalarArray(ctx->ctype()), - algebra::VectorType::Dense}; + algebra::VectorConstructionData edata{ + scalars::KeyScalarArray(ctx->ctype()), algebra::VectorType::Dense}; edata.data.allocate_scalars(width); edata.data.type()->convert_copy(edata.data, data, width); for (int i = 0; i < width; ++i) { edata.data[i] += data[i + width]; } auto idx = indices(2); - const LieIncrementStream path(scalars::KeyScalarArray(std::move(data)), idx, - md); + const LieIncrementStream path( + scalars::KeyScalarArray(std::move(data)), idx, md, + std::make_shared(ctx->width()) + ); auto ctx1 = ctx->get_alike(1); diff --git a/streams/src/test_schema.cpp b/streams/src/test_schema.cpp index f3141deb7..39785ed68 100644 --- a/streams/src/test_schema.cpp +++ b/streams/src/test_schema.cpp @@ -80,7 +80,7 @@ TEST(Schema, TestLabelCompareEmptyRefString) TEST(Schema, TestStreamChannelIncrementSerialization) { - StreamChannel channel(ChannelType::Increment); + IncrementChannel channel; std::stringstream ss; { @@ -88,7 +88,7 @@ TEST(Schema, TestStreamChannelIncrementSerialization) oarch(channel); } - StreamChannel in_channel; + IncrementChannel in_channel; { archives::JSONInputArchive iarch(ss); iarch(in_channel); @@ -99,7 +99,7 @@ TEST(Schema, TestStreamChannelIncrementSerialization) TEST(Schema, TestStreamChannelValueSerialization) { - StreamChannel channel(ChannelType::Value); + ValueChannel channel; std::stringstream ss; { @@ -107,7 +107,7 @@ TEST(Schema, TestStreamChannelValueSerialization) oarch(channel); } - StreamChannel in_channel; + ValueChannel in_channel; { archives::JSONInputArchive iarch(ss); iarch(in_channel); @@ -118,7 +118,7 @@ TEST(Schema, TestStreamChannelValueSerialization) TEST(Schema, TestStreamChannelCategoricalSerialization) { - StreamChannel channel(ChannelType::Categorical); + CategoricalChannel channel; channel.add_variant("first").add_variant("second"); std::stringstream ss; @@ -127,7 +127,7 @@ TEST(Schema, TestStreamChannelCategoricalSerialization) oarch(channel); } - StreamChannel in_channel; + CategoricalChannel in_channel; { archives::JSONInputArchive iarch(ss); iarch(in_channel); From e88669e2c589a26b0ec303182b87f87cb2869552 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Tue, 8 Aug 2023 15:49:33 +0100 Subject: [PATCH 15/33] Fixed a couple of issues relating to 
serialization --- streams/src/channels/categorical_channel.cpp | 1 - streams/src/channels/increment_channel.cpp | 1 - streams/src/channels/lead_laggable_channel.cpp | 1 - streams/src/channels/lie_channel.cpp | 1 - streams/src/channels/value_channel.cpp | 1 - 5 files changed, 5 deletions(-) diff --git a/streams/src/channels/categorical_channel.cpp b/streams/src/channels/categorical_channel.cpp index d0c1c2bf0..6c40a11b3 100644 --- a/streams/src/channels/categorical_channel.cpp +++ b/streams/src/channels/categorical_channel.cpp @@ -74,6 +74,5 @@ StreamChannel& CategoricalChannel::insert_variant(string variant_label) return *this; } -RPY_SERIAL_REGISTER_CLASS(rpy::streams::CategoricalChannel) #define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::CategoricalChannel #include \ No newline at end of file diff --git a/streams/src/channels/increment_channel.cpp b/streams/src/channels/increment_channel.cpp index 46c6ddb73..c6942ba82 100644 --- a/streams/src/channels/increment_channel.cpp +++ b/streams/src/channels/increment_channel.cpp @@ -3,7 +3,6 @@ // #include -RPY_SERIAL_REGISTER_CLASS(rpy::streams::IncrementChannel); #define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::IncrementChannel #include \ No newline at end of file diff --git a/streams/src/channels/lead_laggable_channel.cpp b/streams/src/channels/lead_laggable_channel.cpp index 5d0d9967c..1034c0fbc 100644 --- a/streams/src/channels/lead_laggable_channel.cpp +++ b/streams/src/channels/lead_laggable_channel.cpp @@ -48,6 +48,5 @@ bool LeadLaggableChannel::is_lead_lag() const return m_use_leadlag; } -RPY_SERIAL_REGISTER_CLASS(LeadLaggableChannel) #define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::LeadLaggableChannel #include \ No newline at end of file diff --git a/streams/src/channels/lie_channel.cpp b/streams/src/channels/lie_channel.cpp index 372a6af36..141da065d 100644 --- a/streams/src/channels/lie_channel.cpp +++ b/streams/src/channels/lie_channel.cpp @@ -9,6 +9,5 @@ -RPY_SERIAL_REGISTER_CLASS(rpy::streams::LieChannel) #define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::LieChannel #include \ No newline at end of file diff --git a/streams/src/channels/value_channel.cpp b/streams/src/channels/value_channel.cpp index 63e086d26..ac1e51268 100644 --- a/streams/src/channels/value_channel.cpp +++ b/streams/src/channels/value_channel.cpp @@ -4,7 +4,6 @@ #include -RPY_SERIAL_REGISTER_CLASS(rpy::streams::ValueChannel) #define RPY_SERIAL_IMPL_CLASSNAME rpy::streams::ValueChannel #include \ No newline at end of file From 390989ececd7f65a43754a3cd20ec6fc1edf9540 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Wed, 9 Aug 2023 16:41:41 +0100 Subject: [PATCH 16/33] Lots of work on the parsing of key-scalar streams. 
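
parse_key_scalar_stream is intended to turn Python input (DLPack tensors,
buffer objects, or other sequences) into a KeyScalarStream holding one row
per increment, borrowing the underlying memory when the layout allows it
and otherwise copying into the accompanying KeyScalarArray. A rough usage
sketch from the binding side (illustrative only, using the names introduced
in this patch; data is the py::object supplied by the user):

    python::PyToBufferOptions options;
    options.type = nullptr;  // let the parser deduce the scalar type
    auto parsed = python::parse_key_scalar_stream(data, options);
    // parsed.data_stream holds the per-increment rows;
    // parsed.data_buffer owns any key/scalar data that had to be copied.
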
--- roughpy/CMakeLists.txt | 4 + roughpy/src/args/strided_copy.cpp | 38 +++ roughpy/src/args/strided_copy.h | 29 ++ .../src/scalars/parse_key_scalar_stream.cpp | 255 ++++++++++++++++++ roughpy/src/scalars/parse_key_scalar_stream.h | 36 +++ roughpy/src/scalars/scalars.cpp | 116 +++----- roughpy/src/scalars/scalars.h | 59 +++- .../roughpy/scalars/key_scalar_stream.h | 1 + scalars/src/key_scalar_stream.cpp | 3 +- 9 files changed, 451 insertions(+), 90 deletions(-) create mode 100644 roughpy/src/args/strided_copy.cpp create mode 100644 roughpy/src/args/strided_copy.h create mode 100644 roughpy/src/scalars/parse_key_scalar_stream.cpp create mode 100644 roughpy/src/scalars/parse_key_scalar_stream.h diff --git a/roughpy/CMakeLists.txt b/roughpy/CMakeLists.txt index 80e047f3f..ed414dd3c 100644 --- a/roughpy/CMakeLists.txt +++ b/roughpy/CMakeLists.txt @@ -81,6 +81,8 @@ target_sources(RoughPy_PyModule PRIVATE src/args/numpy.h src/args/parse_schema.cpp src/args/parse_schema.h + src/args/strided_copy.cpp + src/args/strided_copy.h src/intervals/date_time_interval.cpp src/intervals/date_time_interval.h src/intervals/dyadic.h @@ -97,6 +99,8 @@ target_sources(RoughPy_PyModule PRIVATE src/intervals/real_interval.h src/intervals/segmentation.cpp src/intervals/segmentation.h + src/scalars/parse_key_scalar_stream.cpp + src/scalars/parse_key_scalar_stream.h src/scalars/r_py_polynomial.cpp src/scalars/r_py_polynomial.h src/scalars/scalar_type.h diff --git a/roughpy/src/args/strided_copy.cpp b/roughpy/src/args/strided_copy.cpp new file mode 100644 index 000000000..c271c356d --- /dev/null +++ b/roughpy/src/args/strided_copy.cpp @@ -0,0 +1,38 @@ +// +// Created by sam on 09/08/23. +// + +#include "strided_copy.h" +#include + +void rpy::python::stride_copy( + void* dst, const void* src, const py::ssize_t itemsize, + const py::ssize_t ndim, const py::ssize_t* shape_in, + const py::ssize_t* strides_in, + const py::ssize_t* strides_out +) noexcept +{ + RPY_DBG_ASSERT(ndim == 1 || ndim == 2); + + auto* dptr = static_cast(dst); + const auto* sptr = static_cast(src); + + if (ndim == 1) { + for (py::ssize_t i=0; i +#include + +namespace rpy { namespace python { + +void stride_copy(void* RPY_RESTRICT dst, + const void* RPY_RESTRICT src, + const py::ssize_t itemsize, + const py::ssize_t ndim, + const py::ssize_t* shape_in, + const py::ssize_t* strides_in, + const py::ssize_t* strides_out + ) noexcept; + + + +}} + + +#endif// ROUGHPY_STRIDED_COPY_H diff --git a/roughpy/src/scalars/parse_key_scalar_stream.cpp b/roughpy/src/scalars/parse_key_scalar_stream.cpp new file mode 100644 index 000000000..64fb3f78d --- /dev/null +++ b/roughpy/src/scalars/parse_key_scalar_stream.cpp @@ -0,0 +1,255 @@ +// +// Created by sam on 08/08/23. 
+// + +#include "parse_key_scalar_stream.h" +#include "args/numpy.h" +#include "args/strided_copy.h" +#include "dlpack.h" + +using namespace rpy; +using namespace rpy::python; + +using rpy::scalars::KeyScalarArray; +using rpy::scalars::KeyScalarStream; + +namespace { + +inline void buffer_to_stream( + ParsedKeyScalarStream& result, const py::buffer_info& buf_info, + PyToBufferOptions& options +); + +inline void dl_to_stream( + ParsedKeyScalarStream& result, const py::object& dl_object, + PyToBufferOptions& options +); + +}// namespace + +python::ParsedKeyScalarStream python::parse_key_scalar_stream( + const py::object& data, rpy::python::PyToBufferOptions& options +) +{ + + ParsedKeyScalarStream result; + + /* + * A key-data stream should not represent a single (key-)scalar value, + * so we only need to deal with the following types: + * 1) An array/buffer of values, + * 2) A key-array dict, // implement later + * 3) Any other kind of sequential data + */ + + if (py::hasattr(data, "__dlpack__")) { + dl_to_stream(result, data, options); + } else if (py::isinstance(data)) { + const auto buffer_data = py::reinterpret_borrow(data); + buffer_to_stream(result, buffer_data.request(), options); + } else if (py::isinstance(data)) { + RPY_THROW( + std::runtime_error, + "constructing from a dict of arrays/lists is not yet supported" + ); + } else if (py::isinstance(data)) { + // We always need to make a copy from a Python object + result.data_buffer = py_to_buffer(data, options); + + /* + * Now we need to use the options.shape information to construct the + * stream. How we should interpret the shape values are determined by + * whether keys are given: + * 1) If result.data_buffer contains keys, then there are shape.size() + * increments, where the ith increment contains shape[i] values. + * 2) If result.data_buffer does not contain keys, then there are two + * cases to handle: + * a) if shape.size() != 2, then there are shape.size() increments, + * where the size of the ith increment is shape[i]; + * b) if shape.size() == 2 then + * - if shape[0]*shape[1] == data_buffer.size() then there are + * shape[0] increments with size shape[1] + * - otherwise there are 2 increments of sizes shape[0] and + * shape[1]. + * + * Oof, that's a lot of cases. 
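+     * For a concrete illustration of case 2(b): with no keys present, a
+     * buffer of 6 values with shape == {2, 3} gives 2 increments of 3
+     * values each, while a buffer of 5 values with shape == {2, 3} gives
+     * 2 increments of sizes 2 and 3 respectively.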
+ */ + + const auto buf_size = result.data_buffer.size(); + if (result.data_buffer.has_keys()) { + result.data_stream.reserve_size(options.shape.size()); + + scalars::ScalarPointer sptr(result.data_buffer); + const key_type* kptr = result.data_buffer.keys(); + dimn_t check = 0; + for (auto incr_size : options.shape) { + result.data_stream.push_back( + {sptr, static_cast(incr_size)}, kptr + ); + sptr += incr_size; + kptr += incr_size; + check += incr_size; + } + + RPY_CHECK(check == buf_size); + } else if (options.shape.size() != 2 || options.shape[0] * options.shape[1] != buf_size) { + result.data_stream.reserve_size(options.shape.size()); + scalars::ScalarPointer sptr(result.data_buffer); + dimn_t check = 0; + + for (auto incr_size : options.shape) { + result.data_stream.push_back( + {sptr, static_cast(incr_size)} + ); + sptr += incr_size; + check += incr_size; + } + + RPY_CHECK(check == buf_size); + } else { + RPY_DBG_ASSERT(options.shape[0] * options.shape[1] == buf_size); + + const auto num_increments = static_cast(options.shape[0]); + const auto incr_size = static_cast(options.shape[1]); + result.data_stream.set_elts_per_row(incr_size); + + scalars::ScalarPointer sptr(result.data_buffer); + for (dimn_t i = 0; i < num_increments; ++i) { + result.data_stream.push_back(sptr); + sptr += incr_size; + } + } + + } else { + RPY_THROW( + std::invalid_argument, + "could not parse argument to a valid scalar array type" + ); + } + + return result; +} + +void buffer_to_stream( + ParsedKeyScalarStream& result, const py::buffer_info& buf_info, + PyToBufferOptions& options +) +{ + RPY_CHECK(buf_info.ndim <= 2 && buf_info.ndim > 0); + + /* + * Generally we will want to borrow the data if it is at all possible. There + * are some caveats though. First, we can only borrow the data if it is + * contiguous and C-layout. Second, we can only borrow if the data is a + * simple type and this type matches the requested data type (if set). + */ + + auto type_id = py_buffer_to_type_id(buf_info); + + if (options.type == nullptr) { + options.type = scalars::ScalarType::for_id(type_id); + } + + // Imperfect check for whether the chosen data type is the same. 
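+    // (only the type ids are compared, so distinct scalar types that
+    // happen to share an id would be treated as matching here)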
+ bool borrow = options.type->id() == type_id; + + // Check if the array is C-contiguous + auto acc_stride = buf_info.itemsize; + for (auto dim = buf_info.ndim; dim > 0;) { + const auto& this_stride = buf_info.strides[--dim]; + RPY_CHECK(this_stride > 0); + borrow &= this_stride == acc_stride; + acc_stride *= buf_info.shape[dim]; + } + + if (borrow) { + if (buf_info.ndim == 1) { + result.data_stream.set_elts_per_row(buf_info.shape[0]); + result.data_stream.reserve_size(1); + result.data_stream.push_back({type_id, buf_info.ptr}); + } else { + const auto num_increments = static_cast(buf_info.shape[0]); + result.data_stream.set_elts_per_row(buf_info.shape[1]); + result.data_stream.reserve_size(num_increments); + + const auto* ptr = static_cast(buf_info.ptr); + const auto stride = buf_info.strides[0]; + for (dimn_t i = 0; i < num_increments; ++i) { + result.data_stream.push_back({options.type, ptr}); + ptr += stride; + } + } + } else { + std::vector tmp(buf_info.size * buf_info.itemsize); + py::ssize_t tmp_strides[2]{}; + dimn_t tmp_shape[2]{}; + tmp_strides[buf_info.ndim - 1] = buf_info.itemsize; + bool transposed + = buf_info.ndim == 2 && buf_info.shape[0] < buf_info.shape[1]; + if (buf_info.ndim == 2) { + if (transposed) { + tmp_strides[0] = buf_info.shape[1]; + tmp_shape[0] = buf_info.shape[1]; + tmp_shape[1] = buf_info.shape[0]; + } else { + tmp_strides[0] = buf_info.shape[0]; + tmp_shape[0] = buf_info.shape[0]; + tmp_shape[1] = buf_info.shape[1]; + } + } + + stride_copy( + tmp.data(), buf_info.ptr, buf_info.itemsize, buf_info.ndim, + buf_info.shape.data(), buf_info.strides.data(), tmp_strides + ); + + // Now that we're C-contiguous, convert_copy into the result. + result.data_buffer = KeyScalarArray(options.type); + result.data_buffer.allocate_scalars(buf_info.size); + options.type->convert_copy( + result.data_buffer, {type_id, tmp.data()}, buf_info.size + ); + + if (buf_info.ndim == 1) { + result.data_stream.reserve_size(1); + result.data_stream.set_elts_per_row(buf_info.size); + result.data_stream.push_back({options.type, tmp.data()}); + } else { + // shape[0] increments of size shape[1] + RPY_DBG_ASSERT( + buf_info.shape[0] * buf_info.shape[1] == buf_info.size + ); + result.data_stream.reserve_size(tmp_shape[0]); + result.data_stream.set_elts_per_row(tmp_shape[1]); + + scalars::ScalarPointer sptr(result.data_buffer); + for (dimn_t i=0; i(); + RPY_CHECK(dltensor != nullptr); + auto& tensor = dltensor->dl_tensor; + + RPY_CHECK(tensor.device.device_type == kDLCPU); + RPY_CHECK(tensor.ndim == 1 || tensor.ndim == 2); + + const auto* type + = python::dlpack_dtype_to_scalar_type(tensor.dtype, tensor.device); +} diff --git a/roughpy/src/scalars/parse_key_scalar_stream.h b/roughpy/src/scalars/parse_key_scalar_stream.h new file mode 100644 index 000000000..2e1580aa1 --- /dev/null +++ b/roughpy/src/scalars/parse_key_scalar_stream.h @@ -0,0 +1,36 @@ +// +// Created by sam on 08/08/23. 
+// + +#ifndef ROUGHPY_PARSE_KEY_SCALAR_STREAM_H +#define ROUGHPY_PARSE_KEY_SCALAR_STREAM_H + +#include "roughpy_module.h" + +#include +#include + +#include "scalar_type.h" +#include "scalars.h" + +namespace rpy { +namespace python { + +struct ParsedKeyScalarStream { + /// Parsed key-scalar stream + scalars::KeyScalarStream data_stream; + + /// Buffer holding key/scalar data if a copy had to be made + scalars::KeyScalarArray data_buffer; + +}; + + +RPY_NO_DISCARD +ParsedKeyScalarStream +parse_key_scalar_stream(const py::object& data, PyToBufferOptions& options); + +} +}// namespace rpy + +#endif// ROUGHPY_PARSE_KEY_SCALAR_STREAM_H diff --git a/roughpy/src/scalars/scalars.cpp b/roughpy/src/scalars/scalars.cpp index 11e5401a3..f5d9b920f 100644 --- a/roughpy/src/scalars/scalars.cpp +++ b/roughpy/src/scalars/scalars.cpp @@ -42,6 +42,7 @@ #include "scalar_type.h" using namespace rpy; +using namespace rpy::python; using namespace pybind11::literals; static const char* SCALAR_DOC = R"edoc( @@ -136,8 +137,8 @@ void python::init_scalars(pybind11::module_& m) type = scalars::ScalarTypeCode::NAME; \ break -static const scalars::ScalarType* -dlpack_dtype_to_scalar_type(DLDataType dtype, DLDevice device) +const scalars::ScalarType* +python::dlpack_dtype_to_scalar_type(DLDataType dtype, DLDevice device) { using scalars::ScalarDeviceType; @@ -266,44 +267,7 @@ static bool try_fill_buffer_dlpack( return true; } -struct arg_size_info { - idimn_t num_values; - idimn_t num_keys; -}; -enum class ground_data_type -{ - UnSet, - Scalars, - KeyValuePairs -}; - -static inline bool is_scalar(py::handle arg) -{ - return (py::isinstance(arg) || py::isinstance(arg) - || RPyPolynomial_Check(arg.ptr())); -} - -static inline bool -is_key(py::handle arg, python::AlternativeKeyType* alternative) -{ - if (alternative != nullptr) { - return py::isinstance(arg) - || py::isinstance(arg, alternative->py_key_type); - } - if (py::isinstance(arg)) { return true; } - return false; -} - -static inline bool -is_kv_pair(py::handle arg, python::AlternativeKeyType* alternative) -{ - if (py::isinstance(arg)) { - auto tpl = py::reinterpret_borrow(arg); - if (tpl.size() == 2) { return is_key(tpl[0], alternative); } - } - return false; -} static void check_and_set_dtype(python::PyToBufferOptions& options, py::handle arg) @@ -318,24 +282,24 @@ check_and_set_dtype(python::PyToBufferOptions& options, py::handle arg) } static bool check_ground_type( - py::handle object, ground_data_type& ground_type, + py::handle object, GroundDataType& ground_type, python::PyToBufferOptions& options ) { py::handle scalar; - if (::is_scalar(object)) { - if (ground_type == ground_data_type::UnSet) { - ground_type = ground_data_type::Scalars; - } else if (ground_type != ground_data_type::Scalars) { + if (python::is_scalar(object)) { + if (ground_type == GroundDataType::UnSet) { + ground_type = GroundDataType::Scalars; + } else if (ground_type != GroundDataType::Scalars) { RPY_THROW( py::value_error, "inconsistent scalar/key-scalar-pair data" ); } scalar = object; } else if (is_kv_pair(object, options.alternative_key)) { - if (ground_type == ground_data_type::UnSet) { - ground_type = ground_data_type::KeyValuePairs; - } else if (ground_type != ground_data_type::KeyValuePairs) { + if (ground_type == GroundDataType::UnSet) { + ground_type = GroundDataType::KeyValuePairs; + } else if (ground_type != GroundDataType::KeyValuePairs) { RPY_THROW( py::value_error, "inconsistent scalar/key-scalar-pair data" ); @@ -355,7 +319,7 @@ static bool check_ground_type( static void 
compute_size_and_type_recurse( python::PyToBufferOptions& options, std::vector& leaves, - const py::handle& object, ground_data_type& ground_type, dimn_t depth + const py::handle& object, GroundDataType& ground_type, dimn_t depth ) { @@ -376,7 +340,7 @@ static void compute_size_and_type_recurse( // We've not visited this depth before, // add our length to the list options.shape.push_back(length); - } else if (ground_type == ground_data_type::Scalars) { + } else if (ground_type == GroundDataType::Scalars) { // We have visited this depth before, // check our length is consistent with the others if (length != options.shape[depth]) { @@ -424,9 +388,9 @@ static void compute_size_and_type_recurse( ); } switch (ground_type) { - case ground_data_type::UnSet: - ground_type = ground_data_type::KeyValuePairs; - case ground_data_type::KeyValuePairs: break; + case GroundDataType::UnSet: + ground_type = GroundDataType::KeyValuePairs; + case GroundDataType::KeyValuePairs: break; default: RPY_THROW(py::type_error, "mismatched types in array argument"); } @@ -443,19 +407,19 @@ static void compute_size_and_type_recurse( } } -static arg_size_info compute_size_and_type( +ArgSizeInfo python::compute_size_and_type( python::PyToBufferOptions& options, std::vector& leaves, py::handle arg ) { - arg_size_info info = {0, 0}; + ArgSizeInfo info = {0, 0}; RPY_CHECK(py::isinstance(arg)); - ground_data_type ground_type = ground_data_type::UnSet; + GroundDataType ground_type = GroundDataType::UnSet; compute_size_and_type_recurse(options, leaves, arg, ground_type, 0); - if (ground_type == ground_data_type::KeyValuePairs) { + if (ground_type == GroundDataType::KeyValuePairs) { options.shape.clear(); for (const auto& obj : leaves) { @@ -470,7 +434,7 @@ static arg_size_info compute_size_and_type( for (auto& shape_i : options.shape) { info.num_values *= shape_i; } } - if (info.num_values == 0 || ground_type == ground_data_type::UnSet) { + if (info.num_values == 0 || ground_type == GroundDataType::UnSet) { options.shape.clear(); leaves.clear(); } @@ -536,10 +500,7 @@ scalars::KeyScalarArray python::py_to_buffer( update_dtype_and_allocate(result, options, 1, 0); assign_py_object_to_scalar(result, object); - return result; - } - - if (is_kv_pair(object, options.alternative_key)) { + } else if (is_kv_pair(object, options.alternative_key)) { /* * Now for tuples of length 2, which we expect to be a kv-pair */ @@ -550,16 +511,11 @@ scalars::KeyScalarArray python::py_to_buffer( update_dtype_and_allocate(result, options, 1, 1); handle_sequence_tuple(result, result.keys(), object, options); - return result; - } - - if (py::hasattr(object, "__dlpack__")) { + } else if (py::hasattr(object, "__dlpack__")) { // If we used the dlpack interface, then the result is // already constructed. 
- if (try_fill_buffer_dlpack(result, options, object)) { return result; } - } - - if (py::isinstance(object)) { + try_fill_buffer_dlpack(result, options, object); + } else if (py::isinstance(object)) { // Fall back to the buffer protocol auto info = py::reinterpret_borrow(object).request(); auto type_id = py_buffer_to_type_id(info); @@ -580,10 +536,7 @@ scalars::KeyScalarArray python::py_to_buffer( options.shape.assign(info.shape.begin(), info.shape.end()); } - return result; - } - - if (py::isinstance(object)) { + } else if (py::isinstance(object)) { auto dict_arg = py::reinterpret_borrow(object); options.shape.push_back(static_cast(dict_arg.size())); @@ -600,10 +553,7 @@ scalars::KeyScalarArray python::py_to_buffer( handle_dict(ptr, key_ptr, options, dict_arg); } - return result; - } - - if (py::isinstance(object)) { + } else if (py::isinstance(object)) { std::vector leaves; auto size_info = compute_size_and_type(options, leaves, object); @@ -638,14 +588,14 @@ scalars::KeyScalarArray python::py_to_buffer( } } } - - return result; + } else { + RPY_THROW( + std::invalid_argument, + "could not parse argument to a valid scalar array type" + ); } - RPY_THROW( - std::invalid_argument, - "could not parse argument to a valid scalar type" - ); + return result; } scalars::Scalar diff --git a/roughpy/src/scalars/scalars.h b/roughpy/src/scalars/scalars.h index b000770ac..900a77373 100644 --- a/roughpy/src/scalars/scalars.h +++ b/roughpy/src/scalars/scalars.h @@ -37,6 +37,10 @@ #include #include +#include "r_py_polynomial.h" + +#include "dlpack.h" + namespace rpy { namespace python { @@ -45,6 +49,18 @@ struct RPY_NO_EXPORT AlternativeKeyType { std::function converter; }; +struct ArgSizeInfo { + idimn_t num_values; + idimn_t num_keys; +}; + +enum class GroundDataType +{ + UnSet, + Scalars, + KeyValuePairs +}; + struct RPY_NO_EXPORT PyToBufferOptions { /// Scalar type to use. If null, will be set to the resulting type const scalars::ScalarType* type = nullptr; @@ -62,9 +78,6 @@ struct RPY_NO_EXPORT PyToBufferOptions { /// All Python types will (try) to be converted to double. 
bool no_check_imported = false; - /// cleanup function to be called when we're finished with the data - std::function cleanup = nullptr; - /// Alternative acceptable key_type/conversion pair AlternativeKeyType* alternative_key = nullptr; }; @@ -76,12 +89,41 @@ inline py::type get_py_rational() ); } -inline py::type get_py_decimal() { +inline py::type get_py_decimal() +{ return py::reinterpret_borrow( - py::module_::import("decimal").attr("Decimal") - ); + py::module_::import("decimal").attr("Decimal") + ); } +inline bool is_scalar(py::handle arg) +{ + return (py::isinstance(arg) || py::isinstance(arg) + || RPyPolynomial_Check(arg.ptr())); +} + +inline bool is_key(py::handle arg, python::AlternativeKeyType* alternative) +{ + if (alternative != nullptr) { + return py::isinstance(arg) + || py::isinstance(arg, alternative->py_key_type); + } + if (py::isinstance(arg)) { return true; } + return false; +} + +inline bool is_kv_pair(py::handle arg, python::AlternativeKeyType* alternative) +{ + if (py::isinstance(arg)) { + auto tpl = py::reinterpret_borrow(arg); + if (tpl.size() == 2) { return is_key(tpl[0], alternative); } + } + return false; +} + +const scalars::ScalarType* +dlpack_dtype_to_scalar_type(DLDataType dtype, DLDevice device); + scalars::KeyScalarArray py_to_buffer(const py::handle& arg, PyToBufferOptions& options); @@ -91,6 +133,11 @@ void assign_py_object_to_scalar(scalars::ScalarPointer ptr, py::handle object); scalars::Scalar py_to_scalar(const scalars::ScalarType* type, py::handle object); +ArgSizeInfo compute_size_and_type( + python::PyToBufferOptions& options, std::vector& leaves, + py::handle arg +); + void init_scalars(py::module_& m); }// namespace python diff --git a/scalars/include/roughpy/scalars/key_scalar_stream.h b/scalars/include/roughpy/scalars/key_scalar_stream.h index 9f2bb6292..e793f8712 100644 --- a/scalars/include/roughpy/scalars/key_scalar_stream.h +++ b/scalars/include/roughpy/scalars/key_scalar_stream.h @@ -23,6 +23,7 @@ class RPY_EXPORT KeyScalarStream : public ScalarStream { public: KeyScalarStream(); + KeyScalarStream(const ScalarType* type); KeyScalarStream(const KeyScalarStream&); KeyScalarStream(KeyScalarStream&&) noexcept; diff --git a/scalars/src/key_scalar_stream.cpp b/scalars/src/key_scalar_stream.cpp index a559ef072..594ceb181 100644 --- a/scalars/src/key_scalar_stream.cpp +++ b/scalars/src/key_scalar_stream.cpp @@ -17,7 +17,8 @@ KeyScalarStream& KeyScalarStream::operator=(const KeyScalarStream&) = default; KeyScalarStream& KeyScalarStream::operator=(KeyScalarStream&&) noexcept = default; - +KeyScalarStream::KeyScalarStream(const ScalarType* type) : ScalarStream(type) +{} void KeyScalarStream::reserve_size(dimn_t num_rows) { From 1129bf2095ebc6170995db9e4dc829590f12030f Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Wed, 9 Aug 2023 17:13:48 +0100 Subject: [PATCH 17/33] tiny bit more work. 
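A quick note on the shape handling being built up in parse_key_scalar_stream: when the
parsed buffer carries no keys, the shape reported by py_to_buffer is mapped onto
increment sizes as in the sketch below. This is an illustrative Python sketch only; the
helper name and signature are invented for exposition and are not part of the patch.

    def increment_sizes(shape, buf_size):
        # Sketch of the key-less shape rules described in the comment in
        # parse_key_scalar_stream (patch 16): a genuine 2-d buffer gives
        # shape[0] increments of size shape[1], anything else gives one
        # increment per shape entry, the i-th of size shape[i].
        if len(shape) == 2 and shape[0] * shape[1] == buf_size:
            return [shape[1]] * shape[0]
        return list(shape)
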
--- .../src/scalars/parse_key_scalar_stream.cpp | 101 +++++++++++++++++- 1 file changed, 100 insertions(+), 1 deletion(-) diff --git a/roughpy/src/scalars/parse_key_scalar_stream.cpp b/roughpy/src/scalars/parse_key_scalar_stream.cpp index 64fb3f78d..8e3873973 100644 --- a/roughpy/src/scalars/parse_key_scalar_stream.cpp +++ b/roughpy/src/scalars/parse_key_scalar_stream.cpp @@ -223,7 +223,7 @@ void buffer_to_stream( result.data_stream.set_elts_per_row(tmp_shape[1]); scalars::ScalarPointer sptr(result.data_buffer); - for (dimn_t i=0; i(); + borrow = false; + } + } else { + borrow &= options.type == type; + } + + // Check if the array is C-contiguous + const auto itemsize = tensor.dtype.bits / 8; + auto size = 1; + for (int64_t i=0; i(tensor.shape[0]); + result.data_stream.set_elts_per_row(tensor.shape[1]); + result.data_stream.reserve_size(num_increments); + + const auto* ptr = static_cast(tensor.data); + const auto stride = tensor.shape[0] * itemsize; + for (dimn_t i = 0; i < num_increments; ++i) { + result.data_stream.push_back({options.type, ptr}); + ptr += stride; + } + } + } else { + std::vector tmp(size * itemsize); + py::ssize_t out_strides[2]{}; + py::ssize_t in_strides[2] {}; + dimn_t tmp_shape[2]{}; + out_strides[tensor.ndim - 1] = itemsize; + bool transposed = tensor.ndim == 2 && tensor.shape[0] < tensor.shape[1]; + if (tensor.ndim == 2) { + if (transposed) { + out_strides[0] = tensor.shape[1]; + tmp_shape[0] = tensor.shape[1]; + tmp_shape[1] = tensor.shape[0]; + } else { + out_strides[0] = tensor.shape[0]; + tmp_shape[0] = tensor.shape[0]; + tmp_shape[1] = tensor.shape[1]; + } + if (tensor.strides != nullptr) { + in_strides[0] = tensor.strides[0]; + in_strides[1] = tensor.strides[1]; + } else { + in_strides[0] = tensor.shape[0]; + in_strides[1] = tensor.shape[1]; + } + + } + + // TODO: We can't use tensor. objects because they are the wrong type. + stride_copy( + tmp.data(), tensor.data, itemsize, tensor.ndim, + tensor.shape, tensor.strides, out_strides + ); + + // Now that we're C-contiguous, convert_copy into the result. 
+ result.data_buffer = KeyScalarArray(options.type); + result.data_buffer.allocate_scalars(size); + options.type->convert_copy( + result.data_buffer, {type, tmp.data()}, size // TODO: This isn't right + ); + + if (tensor.ndim == 1) { + result.data_stream.reserve_size(1); + result.data_stream.set_elts_per_row(size); + result.data_stream.push_back({options.type, tmp.data()}); + } else { + // shape[0] increments of size shape[1] + RPY_DBG_ASSERT( + tensor.shape[0] * tensor.shape[1] == size + ); + result.data_stream.reserve_size(tmp_shape[0]); + result.data_stream.set_elts_per_row(tmp_shape[1]); + + scalars::ScalarPointer sptr(result.data_buffer); + for (dimn_t i = 0; i < tmp_shape[0]; ++i) { + result.data_stream.push_back(sptr); + sptr += tmp_shape[1]; + } + } + } } From 64a46538bbdc6cc327602d96487c3239ff023277 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Thu, 10 Aug 2023 10:55:35 +0100 Subject: [PATCH 18/33] Finish off dl-to-stream --- .../src/scalars/parse_key_scalar_stream.cpp | 100 ++++++++++++++---- roughpy/src/scalars/parse_key_scalar_stream.h | 2 + 2 files changed, 80 insertions(+), 22 deletions(-) diff --git a/roughpy/src/scalars/parse_key_scalar_stream.cpp b/roughpy/src/scalars/parse_key_scalar_stream.cpp index 8e3873973..3e99daea0 100644 --- a/roughpy/src/scalars/parse_key_scalar_stream.cpp +++ b/roughpy/src/scalars/parse_key_scalar_stream.cpp @@ -27,6 +27,65 @@ inline void dl_to_stream( }// namespace +#define DLTHROW(type, bits) \ + RPY_THROW(std::runtime_error, std::to_string(bits) + " bit " #type " is not supported") + +string python::dl_to_type_id(const DLDataType& dtype, const DLDevice& RPY_UNUSED_VAR device) +{ + RPY_CHECK(dtype.lanes == 1); + switch (dtype.code) { + case kDLFloat: + switch (dtype.bits) { + case 16: + return scalars::type_id_of(); + case 32: + return scalars::type_id_of(); + case 64: + return scalars::type_id_of(); + default: + DLTHROW(float, dtype.bits); + } + case kDLInt: + switch (dtype.bits) { + case 8: + return scalars::type_id_of(); + case 16: + return scalars::type_id_of(); + case 32: + return scalars::type_id_of(); + case 64: + return scalars::type_id_of(); + default: + DLTHROW(int, dtype.bits); + } + case kDLUInt: + switch (dtype.bits) { + case 8: return scalars::type_id_of(); + case 16: return scalars::type_id_of(); + case 32: return scalars::type_id_of(); + case 64: return scalars::type_id_of(); + default: DLTHROW(uint, dtype.bits); + } + case kDLBfloat: + if (dtype.bits == 16) { + return scalars::type_id_of(); + } else { + DLTHROW(bfloat, dtype.bits); + } + case kDLComplex: + DLTHROW(complex, dtype.bits); + case kDLOpaqueHandle: + DLTHROW(opaquehandle, dtype.bits); + case kDLBool: + DLTHROW(bool, dtype.bits); + } + +} + +#undef DLTHROW + + + python::ParsedKeyScalarStream python::parse_key_scalar_stream( const py::object& data, rpy::python::PyToBufferOptions& options ) @@ -253,22 +312,17 @@ void dl_to_stream( const auto* type = python::dlpack_dtype_to_scalar_type(tensor.dtype, tensor.device); + const auto type_id = dl_to_type_id(tensor.dtype, tensor.device); - bool borrow = true; if (options.type == nullptr) { - if (type != nullptr) { - options.type = type; - } else { - options.type = scalars::ScalarType::of(); - borrow = false; - } - } else { - borrow &= options.type == type; + options.type = scalars::ScalarType::for_id(type_id); } + bool borrow = options.type->id() == type_id; + // Check if the array is C-contiguous const auto itemsize = tensor.dtype.bits / 8; - auto size = 1; + int64_t size = 1; for (int64_t i=0; i tmp(size * itemsize); 
py::ssize_t out_strides[2]{}; py::ssize_t in_strides[2] {}; - dimn_t tmp_shape[2]{}; + py::ssize_t in_shape[2] {}; + dimn_t out_shape[2]{}; out_strides[tensor.ndim - 1] = itemsize; bool transposed = tensor.ndim == 2 && tensor.shape[0] < tensor.shape[1]; if (tensor.ndim == 2) { if (transposed) { out_strides[0] = tensor.shape[1]; - tmp_shape[0] = tensor.shape[1]; - tmp_shape[1] = tensor.shape[0]; + out_shape[0] = tensor.shape[1]; + out_shape[1] = tensor.shape[0]; } else { out_strides[0] = tensor.shape[0]; - tmp_shape[0] = tensor.shape[0]; - tmp_shape[1] = tensor.shape[1]; + out_shape[0] = tensor.shape[0]; + out_shape[1] = tensor.shape[1]; } if (tensor.strides != nullptr) { in_strides[0] = tensor.strides[0]; @@ -317,19 +372,20 @@ void dl_to_stream( in_strides[1] = tensor.shape[1]; } + in_shape[1] = tensor.shape[1]; } + in_shape[0] = tensor.shape[0]; - // TODO: We can't use tensor. objects because they are the wrong type. stride_copy( tmp.data(), tensor.data, itemsize, tensor.ndim, - tensor.shape, tensor.strides, out_strides + in_shape, in_strides, out_strides ); // Now that we're C-contiguous, convert_copy into the result. result.data_buffer = KeyScalarArray(options.type); result.data_buffer.allocate_scalars(size); options.type->convert_copy( - result.data_buffer, {type, tmp.data()}, size // TODO: This isn't right + result.data_buffer, {type_id, tmp.data()}, size ); if (tensor.ndim == 1) { @@ -341,13 +397,13 @@ void dl_to_stream( RPY_DBG_ASSERT( tensor.shape[0] * tensor.shape[1] == size ); - result.data_stream.reserve_size(tmp_shape[0]); - result.data_stream.set_elts_per_row(tmp_shape[1]); + result.data_stream.reserve_size(out_shape[0]); + result.data_stream.set_elts_per_row(out_shape[1]); scalars::ScalarPointer sptr(result.data_buffer); - for (dimn_t i = 0; i < tmp_shape[0]; ++i) { + for (dimn_t i = 0; i < out_shape[0]; ++i) { result.data_stream.push_back(sptr); - sptr += tmp_shape[1]; + sptr += out_shape[1]; } } } diff --git a/roughpy/src/scalars/parse_key_scalar_stream.h b/roughpy/src/scalars/parse_key_scalar_stream.h index 2e1580aa1..0183af030 100644 --- a/roughpy/src/scalars/parse_key_scalar_stream.h +++ b/roughpy/src/scalars/parse_key_scalar_stream.h @@ -25,6 +25,8 @@ struct ParsedKeyScalarStream { }; +string dl_to_type_id(const DLDataType& dtype, const DLDevice& device); + RPY_NO_DISCARD ParsedKeyScalarStream From 1fc04fa977bb10396828fc6fb8aedf46d2eff6ed Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Thu, 10 Aug 2023 11:53:17 +0100 Subject: [PATCH 19/33] Add method for getting max row size in stream --- scalars/include/roughpy/scalars/scalar_stream.h | 2 ++ scalars/src/scalar_stream.cpp | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/scalars/include/roughpy/scalars/scalar_stream.h b/scalars/include/roughpy/scalars/scalar_stream.h index 03730b142..c223dda18 100644 --- a/scalars/include/roughpy/scalars/scalar_stream.h +++ b/scalars/include/roughpy/scalars/scalar_stream.h @@ -63,6 +63,8 @@ class RPY_EXPORT ScalarStream RPY_NO_DISCARD dimn_t row_count() const noexcept { return m_stream.size(); } + RPY_NO_DISCARD dimn_t max_row_size() const noexcept; + RPY_NO_DISCARD ScalarArray operator[](dimn_t row) const noexcept; RPY_NO_DISCARD diff --git a/scalars/src/scalar_stream.cpp b/scalars/src/scalar_stream.cpp index 6cb21b65e..53640c4c8 100644 --- a/scalars/src/scalar_stream.cpp +++ b/scalars/src/scalar_stream.cpp @@ -32,11 +32,14 @@ #include +#include + #include #include #include #include + using namespace rpy; using namespace rpy::scalars; @@ -81,6 +84,15 @@ dimn_t 
ScalarStream::col_count(dimn_t i) const noexcept return m_elts_per_row[i]; } +dimn_t ScalarStream::max_row_size() const noexcept { + if (m_elts_per_row.empty()) { return 0; } + if (m_elts_per_row.size() == 1) { return m_elts_per_row[0]; } + + auto max_elt = std::max_element(m_elts_per_row.begin(), m_elts_per_row.end()); + + return *max_elt; +} + ScalarArray ScalarStream::operator[](dimn_t row) const noexcept { return {ScalarPointer(p_type, m_stream[row]), col_count(row)}; From 251eb776528c3cb6948634c80756f66ad3af4599 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Thu, 10 Aug 2023 12:36:11 +0100 Subject: [PATCH 20/33] reworked constructor for lie increment stream to use key-scalar-stream --- roughpy/src/streams/lie_increment_stream.cpp | 101 +++++++++++-------- 1 file changed, 59 insertions(+), 42 deletions(-) diff --git a/roughpy/src/streams/lie_increment_stream.cpp b/roughpy/src/streams/lie_increment_stream.cpp index c50e32a61..ee6f40d3c 100644 --- a/roughpy/src/streams/lie_increment_stream.cpp +++ b/roughpy/src/streams/lie_increment_stream.cpp @@ -37,6 +37,7 @@ #include #include "args/kwargs_to_path_metadata.h" +#include "scalars/parse_key_scalar_stream.h" #include "scalars/scalar_type.h" #include "scalars/scalars.h" #include "stream.h" @@ -84,21 +85,8 @@ static py::object lie_increment_stream_from_increments( options.max_nested = 2; options.allow_scalar = false; - auto buffer = python::py_to_buffer(data, options); - - idimn_t increment_size = 0; - idimn_t num_increments = 0; - - if (options.shape.empty()) { - increment_size = static_cast(buffer.size()); - num_increments = 1; - } else if (options.shape.size() == 1) { - increment_size = options.shape[0]; - num_increments = 1; - } else { - increment_size = options.shape[1]; - num_increments = options.shape[0]; - } + // auto buffer = python::py_to_buffer(data, options); + auto ks_stream = python::parse_key_scalar_stream(data, options); if (md.scalar_type == nullptr) { if (options.type != nullptr) { @@ -108,20 +96,25 @@ static py::object lie_increment_stream_from_increments( } } - RPY_CHECK( - buffer.size() - == static_cast(increment_size * num_increments) - ); RPY_CHECK(md.scalar_type != nullptr); if (!md.ctx) { + if (md.width == 0) { + md.width = static_cast(ks_stream.data_stream.max_row_size()); + } + if (md.width == 0 || md.depth == 0) { - RPY_THROW(py::value_error, + RPY_THROW( + py::value_error, "either ctx or both width and depth must be specified" ); } md.ctx = algebra::get_context(md.width, md.depth, md.scalar_type); } + dimn_t num_increments = ks_stream.data_stream.row_count(); + + RPY_CHECK(num_increments > 0); + auto effective_support = intervals::RealInterval::right_unbounded(0.0, md.interval_type); @@ -134,25 +127,51 @@ static py::object lie_increment_stream_from_increments( buffer_to_indices(indices, info); } else if (py::isinstance(indices_arg)) { // Interpret this as a column in the data; - auto icol = indices_arg.cast(); - if (icol < 0) { icol += increment_size; } - if (icol < 0 || icol >= increment_size) { - RPY_THROW(py::value_error, "index out of bounds"); - } - RPY_CHECK(icol < buffer.size()); + auto icol = indices_arg.cast(); indices.reserve(num_increments); - indices.push_back(buffer[icol].to_scalar_t()); - for (idimn_t i = 1; i < num_increments; ++i) { - indices.push_back(static_cast( - indices.back() + buffer[i * increment_size + icol].to_scalar_t() - )); + + auto add_index = [&indices](param_t val) { + if (indices.empty()) { + indices.push_back(val); + } else { + indices.push_back(indices.back() + val); + } + }; + 
+ for (idimn_t i = 0; i < num_increments; ++i) { + auto row = ks_stream.data_stream[i]; + + if (row.has_keys()) { + + const auto* begin = row.keys(); + const auto* end = begin + row.size(); + + auto found = std::find(begin, end, icol + 1); + + if (found == end) { + RPY_THROW( + std::invalid_argument, + "cannot find index column in provided data" + ); + } + + const auto pos = static_cast(found - begin); + + add_index(row[pos].to_scalar_t()); + } else { + RPY_CHECK(icol < row.size()); + add_index(row[icol].to_scalar_t()); + } } } else if (py::isinstance(indices_arg)) { indices = indices_arg.cast>(); } else { - RPY_THROW(py::type_error,"unexpected type provided to 'indices' " - "argument"); + RPY_THROW( + py::type_error, + "unexpected type provided to 'indices' " + "argument" + ); } if (!indices.empty()) { @@ -165,21 +184,21 @@ static py::object lie_increment_stream_from_increments( if (indices.empty()) { indices.reserve(num_increments); - for (idimn_t i = 0; i < num_increments; ++i) { - indices.emplace_back(i); - } + for (dimn_t i = 0; i < num_increments; ++i) { indices.emplace_back(i); } } else if (static_cast(indices.size()) != num_increments) { - RPY_THROW(py::value_error,"mismatch between number of rows in data and " - "number of indices"); + RPY_THROW( + py::value_error, + "mismatch between number of rows in data and " + "number of indices" + ); } if (!md.schema) { md.schema = std::make_shared(md.width); } - auto result = streams::Stream(streams::LieIncrementStream( - buffer, indices, + ks_stream, indices, {md.width, effective_support, md.ctx, md.scalar_type, md.vector_type ? *md.vector_type : algebra::VectorType::Dense, md.resolution}, @@ -187,8 +206,6 @@ static py::object lie_increment_stream_from_increments( )); if (md.support) { result.restrict_to(*md.support); } - if (options.cleanup) { options.cleanup(); } - return py::reinterpret_steal( python::RPyStream_FromStream(std::move(result)) ); From 6543412f5a3e501ee77c41f062ea58256ebc8e9e Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Thu, 10 Aug 2023 14:44:57 +0100 Subject: [PATCH 21/33] Set SPReal dtype instead of using "float32" --- tests/scalars/dlpack/test_construct_from_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/scalars/dlpack/test_construct_from_jax.py b/tests/scalars/dlpack/test_construct_from_jax.py index 89ceb38d2..1e4ee4d4b 100644 --- a/tests/scalars/dlpack/test_construct_from_jax.py +++ b/tests/scalars/dlpack/test_construct_from_jax.py @@ -34,7 +34,7 @@ def test_increment_stream_from_jax_array(self): [-0.6642682, -0.12166289, 1.04914618, -1.01415539, -1.58841276, -2.54356289] ]) - stream = rp.LieIncrementStream.from_increments(np.array(array), width=6, depth=2, dtype="float32") + stream = rp.LieIncrementStream.from_increments(np.array(array), width=6, depth=2, dtype=rp.SPReal) lsig01 = stream.log_signature(rp.RealInterval(0, 1)) lsig12 = stream.log_signature(rp.RealInterval(1, 2)) From 1f4ce28be5dca731600ed8b0647bee2278927763 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Thu, 10 Aug 2023 15:05:19 +0100 Subject: [PATCH 22/33] Better implementation of ScalarType::for_id --- scalars/src/scalar_type.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/scalars/src/scalar_type.cpp b/scalars/src/scalar_type.cpp index 29faeb6d3..9a880ba66 100644 --- a/scalars/src/scalar_type.cpp +++ b/scalars/src/scalar_type.cpp @@ -56,7 +56,14 @@ const ScalarType* ScalarType::host_type() const noexcept { return this; } const ScalarType* ScalarType::for_id(const string& id) { - return 
ScalarType::of(); + try { + const auto* type = get_type(id); + if (type) { return type; } + return ScalarType::of(); + } catch (std::exception&) { + return ScalarType::of(); + } + // TODO: needs more thorough implementation } const ScalarType* ScalarType::from_type_details( const BasicScalarInfo& details, const ScalarDeviceInfo& device From 7fd0a02c9a76d75700624ba18fd6f56eb7562b81 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Fri, 11 Aug 2023 11:48:56 +0100 Subject: [PATCH 23/33] Added a couple of debug asserts --- scalars/src/scalar_stream.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scalars/src/scalar_stream.cpp b/scalars/src/scalar_stream.cpp index 53640c4c8..02f87d417 100644 --- a/scalars/src/scalar_stream.cpp +++ b/scalars/src/scalar_stream.cpp @@ -118,11 +118,13 @@ void ScalarStream::reserve_size(dimn_t num_rows) { m_stream.reserve(num_rows); } void ScalarStream::push_back(const ScalarPointer& data) { RPY_CHECK(m_elts_per_row.size() == 1 && m_elts_per_row[0] > 0); + RPY_DBG_ASSERT(!data.is_null()); m_stream.push_back(data.ptr()); } void ScalarStream::push_back(const ScalarArray& data) { if (m_elts_per_row.size() == 1) { + RPY_DBG_ASSERT(!data.is_null()); m_stream.push_back(data.ptr()); if (data.size() != m_elts_per_row[0]) { m_elts_per_row.reserve(m_stream.size() + 1); From afea15d2a2f1fe1d210c1ccca9f2e0bf9cb0e050 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Fri, 11 Aug 2023 11:49:17 +0100 Subject: [PATCH 24/33] KeyScalarStream constructor --- .../roughpy/streams/lie_increment_stream.h | 6 +++ streams/src/lie_increment_stream.cpp | 45 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/streams/include/roughpy/streams/lie_increment_stream.h b/streams/include/roughpy/streams/lie_increment_stream.h index 700cef2dd..c5d097233 100644 --- a/streams/include/roughpy/streams/lie_increment_stream.h +++ b/streams/include/roughpy/streams/lie_increment_stream.h @@ -37,6 +37,7 @@ #include #include #include +#include namespace rpy { namespace streams { @@ -59,6 +60,11 @@ class RPY_EXPORT LieIncrementStream : public DyadicCachingLayer StreamMetadata md, std::shared_ptr schema ); + explicit LieIncrementStream( + const scalars::KeyScalarStream& ks_stream, Slice indices, + StreamMetadata md, std::shared_ptr schema + ); + RPY_NO_DISCARD bool empty(const intervals::Interval& interval ) const noexcept override; diff --git a/streams/src/lie_increment_stream.cpp b/streams/src/lie_increment_stream.cpp index 0563a0918..a01b21e88 100644 --- a/streams/src/lie_increment_stream.cpp +++ b/streams/src/lie_increment_stream.cpp @@ -105,6 +105,51 @@ LieIncrementStream::LieIncrementStream( } } +LieIncrementStream::LieIncrementStream( + const scalars::KeyScalarStream& ks_stream, Slice indices, + StreamMetadata mdarg, std::shared_ptr schema_arg +) + : DyadicCachingLayer(std::move(mdarg), std::move(schema_arg)) +{ + using scalars::Scalar; + RPY_CHECK(indices.size() == ks_stream.row_count()); + + const auto& md = this->metadata(); + const auto& ctx = *md.default_context; + + const auto& sch = this->schema(); + const auto* param = sch.parametrization(); + const bool param_needs_adding = param != nullptr && param->needs_adding(); + const key_type param_slot + = (param_needs_adding) ? 
sch.time_channel_to_lie_key() : 0; + + m_data.reserve(indices.size()); + param_t previous_param = 0.0; + + for (dimn_t i=0; isecond[param_slot] = Scalar(index - previous_param); + } + previous_param = index; + + } +} + algebra::Lie LieIncrementStream::log_signature_impl( const intervals::Interval& interval, const algebra::Context& ctx ) const From dbd66139d5b1f69e6888d4419e569bf3c7c93f6a Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Fri, 11 Aug 2023 11:49:34 +0100 Subject: [PATCH 25/33] Fixed bug in strided copy. --- roughpy/src/args/strided_copy.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roughpy/src/args/strided_copy.cpp b/roughpy/src/args/strided_copy.cpp index c271c356d..1e79f9814 100644 --- a/roughpy/src/args/strided_copy.cpp +++ b/roughpy/src/args/strided_copy.cpp @@ -25,7 +25,7 @@ void rpy::python::stride_copy( } } else { for (py::ssize_t i=0; i Date: Fri, 11 Aug 2023 11:49:54 +0100 Subject: [PATCH 26/33] Fixed a few bugs --- .../src/scalars/parse_key_scalar_stream.cpp | 43 +++++++++++-------- roughpy/src/scalars/parse_key_scalar_stream.h | 11 +++-- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/roughpy/src/scalars/parse_key_scalar_stream.cpp b/roughpy/src/scalars/parse_key_scalar_stream.cpp index 3e99daea0..a0d72839c 100644 --- a/roughpy/src/scalars/parse_key_scalar_stream.cpp +++ b/roughpy/src/scalars/parse_key_scalar_stream.cpp @@ -13,19 +13,17 @@ using namespace rpy::python; using rpy::scalars::KeyScalarArray; using rpy::scalars::KeyScalarStream; -namespace { -inline void buffer_to_stream( +static inline void buffer_to_stream( ParsedKeyScalarStream& result, const py::buffer_info& buf_info, PyToBufferOptions& options ); -inline void dl_to_stream( +static inline void dl_to_stream( ParsedKeyScalarStream& result, const py::object& dl_object, PyToBufferOptions& options ); -}// namespace #define DLTHROW(type, bits) \ RPY_THROW(std::runtime_error, std::to_string(bits) + " bit " #type " is not supported") @@ -63,7 +61,7 @@ string python::dl_to_type_id(const DLDataType& dtype, const DLDevice& RPY_UNUSED case 8: return scalars::type_id_of(); case 16: return scalars::type_id_of(); case 32: return scalars::type_id_of(); - case 64: return scalars::type_id_of(); + case 64: return scalars::type_id_of(); default: DLTHROW(uint, dtype.bits); } case kDLBfloat: @@ -80,18 +78,18 @@ string python::dl_to_type_id(const DLDataType& dtype, const DLDevice& RPY_UNUSED DLTHROW(bool, dtype.bits); } + RPY_UNREACHABLE_RETURN({}); } #undef DLTHROW -python::ParsedKeyScalarStream python::parse_key_scalar_stream( +void python::parse_key_scalar_stream(ParsedKeyScalarStream& result, const py::object& data, rpy::python::PyToBufferOptions& options ) { - ParsedKeyScalarStream result; /* * A key-data stream should not represent a single (key-)scalar value, @@ -186,7 +184,6 @@ python::ParsedKeyScalarStream python::parse_key_scalar_stream( ); } - return result; } void buffer_to_stream( @@ -306,17 +303,25 @@ void dl_to_stream( RPY_CHECK(dltensor != nullptr); auto& tensor = dltensor->dl_tensor; + const auto type_id = dl_to_type_id(tensor.dtype, tensor.device); + + if (tensor.ndim == 0 || tensor.shape[0] == 0 || tensor.ndim > 1 && tensor.shape[1] == 0) { + if (options.type == nullptr) { + options.type = scalars::ScalarType::for_id(type_id); + } + return; + } + RPY_CHECK(tensor.device.device_type == kDLCPU); RPY_CHECK(tensor.ndim == 1 || tensor.ndim == 2); RPY_CHECK(tensor.dtype.lanes == 1); - const auto* type - = python::dlpack_dtype_to_scalar_type(tensor.dtype, tensor.device); - 
const auto type_id = dl_to_type_id(tensor.dtype, tensor.device); - if (options.type == nullptr) { options.type = scalars::ScalarType::for_id(type_id); + } + result.data_stream = KeyScalarStream(options.type); + result.data_buffer = KeyScalarArray(options.type); bool borrow = options.type->id() == type_id; @@ -340,9 +345,8 @@ void dl_to_stream( result.data_stream.reserve_size(num_increments); const auto* ptr = static_cast(tensor.data); - const auto stride = tensor.shape[0] * itemsize; + const auto stride = tensor.shape[1] * itemsize; for (dimn_t i = 0; i < num_increments; ++i) { - result.data_stream.push_back({options.type, ptr}); ptr += stride; } } @@ -354,13 +358,15 @@ void dl_to_stream( dimn_t out_shape[2]{}; out_strides[tensor.ndim - 1] = itemsize; bool transposed = tensor.ndim == 2 && tensor.shape[0] < tensor.shape[1]; + + in_shape[0] = tensor.shape[0]; if (tensor.ndim == 2) { if (transposed) { - out_strides[0] = tensor.shape[1]; + out_strides[0] = tensor.shape[1]*itemsize; out_shape[0] = tensor.shape[1]; out_shape[1] = tensor.shape[0]; } else { - out_strides[0] = tensor.shape[0]; + out_strides[0] = tensor.shape[0]*itemsize; out_shape[0] = tensor.shape[0]; out_shape[1] = tensor.shape[1]; } @@ -368,13 +374,12 @@ void dl_to_stream( in_strides[0] = tensor.strides[0]; in_strides[1] = tensor.strides[1]; } else { - in_strides[0] = tensor.shape[0]; - in_strides[1] = tensor.shape[1]; + in_strides[0] = tensor.shape[1]*itemsize; + in_strides[1] = itemsize; } in_shape[1] = tensor.shape[1]; } - in_shape[0] = tensor.shape[0]; stride_copy( tmp.data(), tensor.data, itemsize, tensor.ndim, diff --git a/roughpy/src/scalars/parse_key_scalar_stream.h b/roughpy/src/scalars/parse_key_scalar_stream.h index 0183af030..d9e413ab0 100644 --- a/roughpy/src/scalars/parse_key_scalar_stream.h +++ b/roughpy/src/scalars/parse_key_scalar_stream.h @@ -22,17 +22,16 @@ struct ParsedKeyScalarStream { /// Buffer holding key/scalar data if a copy had to be made scalars::KeyScalarArray data_buffer; - }; string dl_to_type_id(const DLDataType& dtype, const DLDevice& device); +void parse_key_scalar_stream( + ParsedKeyScalarStream& result, const py::object& data, + PyToBufferOptions& options +); -RPY_NO_DISCARD -ParsedKeyScalarStream -parse_key_scalar_stream(const py::object& data, PyToBufferOptions& options); - -} +}// namespace python }// namespace rpy #endif// ROUGHPY_PARSE_KEY_SCALAR_STREAM_H From db5056eef17051b8ff78d183cf7a053ecb488498 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Fri, 11 Aug 2023 11:50:37 +0100 Subject: [PATCH 27/33] Force construct inplace, rather than relying on nrvo --- roughpy/src/streams/lie_increment_stream.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/roughpy/src/streams/lie_increment_stream.cpp b/roughpy/src/streams/lie_increment_stream.cpp index ee6f40d3c..eac848bd8 100644 --- a/roughpy/src/streams/lie_increment_stream.cpp +++ b/roughpy/src/streams/lie_increment_stream.cpp @@ -86,7 +86,8 @@ static py::object lie_increment_stream_from_increments( options.allow_scalar = false; // auto buffer = python::py_to_buffer(data, options); - auto ks_stream = python::parse_key_scalar_stream(data, options); + python::ParsedKeyScalarStream ks_stream; + python::parse_key_scalar_stream(ks_stream, data, options); if (md.scalar_type == nullptr) { if (options.type != nullptr) { @@ -113,7 +114,6 @@ static py::object lie_increment_stream_from_increments( dimn_t num_increments = ks_stream.data_stream.row_count(); - RPY_CHECK(num_increments > 0); auto effective_support = 
intervals::RealInterval::right_unbounded(0.0, md.interval_type); @@ -198,7 +198,7 @@ static py::object lie_increment_stream_from_increments( } auto result = streams::Stream(streams::LieIncrementStream( - ks_stream, indices, + ks_stream.data_stream, indices, {md.width, effective_support, md.ctx, md.scalar_type, md.vector_type ? *md.vector_type : algebra::VectorType::Dense, md.resolution}, From d8b4ea6b775c896f07b99f98ced451069f5b6205 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Fri, 11 Aug 2023 11:51:53 +0100 Subject: [PATCH 28/33] Remove call to cleanup --- roughpy/src/algebra/free_tensor.cpp | 1 - roughpy/src/algebra/lie.cpp | 1 - roughpy/src/algebra/shuffle_tensor.cpp | 1 - 3 files changed, 3 deletions(-) diff --git a/roughpy/src/algebra/free_tensor.cpp b/roughpy/src/algebra/free_tensor.cpp index 96c3f802c..2fff38b78 100644 --- a/roughpy/src/algebra/free_tensor.cpp +++ b/roughpy/src/algebra/free_tensor.cpp @@ -138,7 +138,6 @@ static FreeTensor construct_free_tensor(py::object data, py::kwargs kwargs) {std::move(buffer), helper.vtype} ); - if (options.cleanup) { options.cleanup(); } RPY_DBG_ASSERT(result.coeff_type() != nullptr); diff --git a/roughpy/src/algebra/lie.cpp b/roughpy/src/algebra/lie.cpp index d402ea6ff..1a69a2924 100644 --- a/roughpy/src/algebra/lie.cpp +++ b/roughpy/src/algebra/lie.cpp @@ -83,7 +83,6 @@ static Lie construct_lie(py::object data, py::kwargs kwargs) auto result = helper.ctx->construct_lie({std::move(buffer), helper.vtype}); - if (options.cleanup) { options.cleanup(); } return result; } diff --git a/roughpy/src/algebra/shuffle_tensor.cpp b/roughpy/src/algebra/shuffle_tensor.cpp index 7e020d96f..e34c23218 100644 --- a/roughpy/src/algebra/shuffle_tensor.cpp +++ b/roughpy/src/algebra/shuffle_tensor.cpp @@ -103,7 +103,6 @@ static ShuffleTensor construct_shuffle(py::object data, py::kwargs kwargs) {std::move(buffer), helper.vtype} ); - if (options.cleanup) { options.cleanup(); } return result; } From 821228037b456b007fd1ac056bca51b7eabf7e93 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Fri, 11 Aug 2023 12:41:10 +0100 Subject: [PATCH 29/33] Update changelog --- CHANGELOG | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG b/CHANGELOG index 384d874ae..6fe3298fe 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -18,6 +18,7 @@ Version 0.0.7: MKL/BLAS+LAPACK. - Added function to query ring characteristics of a ScalarType - currently unused. + - Added KeyScalarStream for constructing streams from array-like data more easily. 
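For orientation, the construction this changelog entry refers to is the one exercised by
the jax test earlier in the series. A minimal usage sketch, assuming only the public
names that test already uses (LieIncrementStream.from_increments, RealInterval, SPReal);
it is illustrative and not part of the patch:

    import numpy as np
    import roughpy as rp

    # Two increments of width 2; rows are increments, columns are channels.
    data = np.array([[0.1, 0.2],
                     [0.3, 0.4]], dtype=np.float32)

    # The reworked constructor (patch 20) can infer the width from the maximum
    # row size, but it can still be given explicitly as before.
    stream = rp.LieIncrementStream.from_increments(
        data, width=2, depth=2, dtype=rp.SPReal
    )

    lsig01 = stream.log_signature(rp.RealInterval(0, 1))
    lsig12 = stream.log_signature(rp.RealInterval(1, 2))
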
From bb14a044a1a6a90f7227cb8639f676a7fb4ad8dc Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Fri, 11 Aug 2023 15:40:15 +0100 Subject: [PATCH 30/33] Fixed a bad macro guard name --- scalars/scalar_blas_defs.h.in | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scalars/scalar_blas_defs.h.in b/scalars/scalar_blas_defs.h.in index d2a937c13..45963b6a4 100644 --- a/scalars/scalar_blas_defs.h.in +++ b/scalars/scalar_blas_defs.h.in @@ -4,7 +4,7 @@ #cmakedefine RPY_USE_MKL -#if @RPY_USE_MKL@ +#ifdef RPY_USE_MKL # include # include #else @@ -14,7 +14,7 @@ #include -#if @RPY_USE_MKL@ +#ifdef RPY_USE_MKL # define EIGEN_USE_MKL_ALL #endif From c878591d292ada4d4fd5fa56f5500bdd83700ab5 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Fri, 11 Aug 2023 16:14:19 +0100 Subject: [PATCH 31/33] Added new dlpack helpers --- roughpy/CMakeLists.txt | 2 + roughpy/src/args/dlpack_helpers.cpp | 74 +++++++++++++++ roughpy/src/args/dlpack_helpers.h | 90 +++++++++++++++++++ .../src/scalars/parse_key_scalar_stream.cpp | 75 ++-------------- roughpy/src/scalars/parse_key_scalar_stream.h | 1 - roughpy/src/scalars/scalars.cpp | 33 +------ roughpy/src/scalars/scalars.h | 3 - 7 files changed, 176 insertions(+), 102 deletions(-) create mode 100644 roughpy/src/args/dlpack_helpers.cpp create mode 100644 roughpy/src/args/dlpack_helpers.h diff --git a/roughpy/CMakeLists.txt b/roughpy/CMakeLists.txt index ed414dd3c..86639fcff 100644 --- a/roughpy/CMakeLists.txt +++ b/roughpy/CMakeLists.txt @@ -73,6 +73,8 @@ target_sources(RoughPy_PyModule PRIVATE src/algebra/tensor_key_iterator.h src/args/convert_timestamp.cpp src/args/convert_timestamp.h + src/args/dlpack_helpers.cpp + src/args/dlpack_helpers.h src/args/kwargs_to_path_metadata.h src/args/kwargs_to_path_metadata.cpp src/args/kwargs_to_vector_construction.cpp diff --git a/roughpy/src/args/dlpack_helpers.cpp b/roughpy/src/args/dlpack_helpers.cpp new file mode 100644 index 000000000..95dd30cb5 --- /dev/null +++ b/roughpy/src/args/dlpack_helpers.cpp @@ -0,0 +1,74 @@ +// +// Created by sam on 11/08/23. 
+// + +#include "dlpack_helpers.h" + +using namespace rpy; +using namespace python; + + + +#define DLTHROW(type, bits) \ + RPY_THROW(std::runtime_error, \ + std::to_string(bits) + " bit " #type " is not supported") + +const string& +python::type_id_for_dl_info(const DLDataType& dtype, const DLDevice& device) +{ + if (device.device_type != kDLCPU) { + RPY_THROW( + std::runtime_error, + "for the time being, constructing non-cpu " + "devices are not supported by RoughPy" + ); + } + + switch (dtype.code) { + case kDLFloat: + switch (dtype.bits) { + case 16: return scalars::type_id_of(); + case 32: return scalars::type_id_of(); + case 64: return scalars::type_id_of(); + default: DLTHROW(float, dtype.bits); + } + case kDLInt: + switch (dtype.bits) { + case 8: return scalars::type_id_of(); + case 16: return scalars::type_id_of(); + case 32: return scalars::type_id_of(); + case 64: return scalars::type_id_of(); + default: DLTHROW(int, dtype.bits); + } + case kDLUInt: + switch (dtype.bits) { + case 8: return scalars::type_id_of(); + case 16: return scalars::type_id_of(); + case 32: return scalars::type_id_of(); + case 64: return scalars::type_id_of(); + default: DLTHROW(uint, dtype.bits); + } + case kDLBfloat: + if (dtype.bits == 16) { + return scalars::type_id_of(); + } else { + DLTHROW(bfloat, dtype.bits); + } + case kDLComplex: DLTHROW(complex, dtype.bits); + case kDLOpaqueHandle: DLTHROW(opaquehandle, dtype.bits); + case kDLBool: DLTHROW(bool, dtype.bits); + } + RPY_UNREACHABLE_RETURN({}); +} + +#undef DLTHROW + + + +const scalars::ScalarType* +python::scalar_type_for_dl_info(const DLDataType& dtype, const DLDevice& device) +{ + return scalars::ScalarType::for_id(type_id_for_dl_info(dtype, device)); +} + + diff --git a/roughpy/src/args/dlpack_helpers.h b/roughpy/src/args/dlpack_helpers.h new file mode 100644 index 000000000..c0ca2b888 --- /dev/null +++ b/roughpy/src/args/dlpack_helpers.h @@ -0,0 +1,90 @@ +// +// Created by sam on 11/08/23. 
+// + +#ifndef ROUGHPY_DLPACK_HELPERS_H +#define ROUGHPY_DLPACK_HELPERS_H + +#include "roughpy_module.h" +#include "dlpack.h" + +#include +#include + +namespace rpy { +namespace python { + +constexpr scalars::ScalarTypeCode +convert_from_dl_typecode(uint8_t code) noexcept { + return static_cast(code); +} +constexpr uint8_t +convert_to_dl_typecode(scalars::ScalarTypeCode code) noexcept { + return static_cast(code); +} + +constexpr scalars::ScalarDeviceType +convert_from_dl_device_type(DLDeviceType type) noexcept { + return static_cast(type); +} + +constexpr DLDeviceType +convert_to_dl_device_type(scalars::ScalarDeviceType type) noexcept +{ + return static_cast(type); +} + +constexpr scalars::BasicScalarInfo +convert_from_dl_datatype(const DLDataType& dtype) noexcept { + return { + convert_from_dl_typecode(dtype.code), + dtype.bits, + dtype.lanes + }; +} + +inline DLDataType +convert_to_dl_datatype(const scalars::BasicScalarInfo& info) noexcept +{ + return { + convert_to_dl_typecode(info.code), + info.bits, + info.lanes + }; +} + +constexpr scalars::ScalarDeviceInfo +convert_from_dl_device_info(const DLDevice& device) noexcept +{ + return { + convert_from_dl_device_type(device.device_type), + device.device_id + }; +} + +constexpr DLDevice +convert_to_dl_device_info(const scalars::ScalarDeviceInfo& device) noexcept { + return { + convert_to_dl_device_type(device.device_type), + device.device_id + }; +} + + +const string& +type_id_for_dl_info(const DLDataType& dtype, const DLDevice& device); + +const scalars::ScalarType* +scalar_type_for_dl_info(const DLDataType& dtype, const DLDevice& device); + +inline const scalars::ScalarType* scalar_type_of_dl_info(const DLDataType& dtype, const DLDevice& device) +{ + return scalars::ScalarType::from_type_details( + convert_from_dl_datatype(dtype), + convert_from_dl_device_info(device)); +} + +} +} + +#endif// ROUGHPY_DLPACK_HELPERS_H diff --git a/roughpy/src/scalars/parse_key_scalar_stream.cpp b/roughpy/src/scalars/parse_key_scalar_stream.cpp index a0d72839c..74cfddb1b 100644 --- a/roughpy/src/scalars/parse_key_scalar_stream.cpp +++ b/roughpy/src/scalars/parse_key_scalar_stream.cpp @@ -5,7 +5,7 @@ #include "parse_key_scalar_stream.h" #include "args/numpy.h" #include "args/strided_copy.h" -#include "dlpack.h" +#include "args/dlpack_helpers.h" using namespace rpy; using namespace rpy::python; @@ -25,66 +25,6 @@ static inline void dl_to_stream( ); -#define DLTHROW(type, bits) \ - RPY_THROW(std::runtime_error, std::to_string(bits) + " bit " #type " is not supported") - -string python::dl_to_type_id(const DLDataType& dtype, const DLDevice& RPY_UNUSED_VAR device) -{ - RPY_CHECK(dtype.lanes == 1); - switch (dtype.code) { - case kDLFloat: - switch (dtype.bits) { - case 16: - return scalars::type_id_of(); - case 32: - return scalars::type_id_of(); - case 64: - return scalars::type_id_of(); - default: - DLTHROW(float, dtype.bits); - } - case kDLInt: - switch (dtype.bits) { - case 8: - return scalars::type_id_of(); - case 16: - return scalars::type_id_of(); - case 32: - return scalars::type_id_of(); - case 64: - return scalars::type_id_of(); - default: - DLTHROW(int, dtype.bits); - } - case kDLUInt: - switch (dtype.bits) { - case 8: return scalars::type_id_of(); - case 16: return scalars::type_id_of(); - case 32: return scalars::type_id_of(); - case 64: return scalars::type_id_of(); - default: DLTHROW(uint, dtype.bits); - } - case kDLBfloat: - if (dtype.bits == 16) { - return scalars::type_id_of(); - } else { - DLTHROW(bfloat, dtype.bits); - } - case kDLComplex: - 
DLTHROW(complex, dtype.bits); - case kDLOpaqueHandle: - DLTHROW(opaquehandle, dtype.bits); - case kDLBool: - DLTHROW(bool, dtype.bits); - } - - RPY_UNREACHABLE_RETURN({}); -} - -#undef DLTHROW - - - void python::parse_key_scalar_stream(ParsedKeyScalarStream& result, const py::object& data, rpy::python::PyToBufferOptions& options ) @@ -303,12 +243,12 @@ void dl_to_stream( RPY_CHECK(dltensor != nullptr); auto& tensor = dltensor->dl_tensor; - const auto type_id = dl_to_type_id(tensor.dtype, tensor.device); + const auto type_id = python::type_id_for_dl_info(tensor.dtype, tensor.device); + if (options.type == nullptr) { + options.type = scalars::ScalarType::for_id(type_id); + } if (tensor.ndim == 0 || tensor.shape[0] == 0 || tensor.ndim > 1 && tensor.shape[1] == 0) { - if (options.type == nullptr) { - options.type = scalars::ScalarType::for_id(type_id); - } return; } @@ -316,10 +256,6 @@ void dl_to_stream( RPY_CHECK(tensor.ndim == 1 || tensor.ndim == 2); RPY_CHECK(tensor.dtype.lanes == 1); - if (options.type == nullptr) { - options.type = scalars::ScalarType::for_id(type_id); - - } result.data_stream = KeyScalarStream(options.type); result.data_buffer = KeyScalarArray(options.type); @@ -347,6 +283,7 @@ void dl_to_stream( const auto* ptr = static_cast(tensor.data); const auto stride = tensor.shape[1] * itemsize; for (dimn_t i = 0; i < num_increments; ++i) { + result.data_stream.push_back({options.type, ptr}); ptr += stride; } } diff --git a/roughpy/src/scalars/parse_key_scalar_stream.h b/roughpy/src/scalars/parse_key_scalar_stream.h index d9e413ab0..0c13b0e48 100644 --- a/roughpy/src/scalars/parse_key_scalar_stream.h +++ b/roughpy/src/scalars/parse_key_scalar_stream.h @@ -24,7 +24,6 @@ struct ParsedKeyScalarStream { scalars::KeyScalarArray data_buffer; }; -string dl_to_type_id(const DLDataType& dtype, const DLDevice& device); void parse_key_scalar_stream( ParsedKeyScalarStream& result, const py::object& data, diff --git a/roughpy/src/scalars/scalars.cpp b/roughpy/src/scalars/scalars.cpp index f5d9b920f..73d96ace4 100644 --- a/roughpy/src/scalars/scalars.cpp +++ b/roughpy/src/scalars/scalars.cpp @@ -41,6 +41,8 @@ #include "r_py_polynomial.h" #include "scalar_type.h" +#include "args/dlpack_helpers.h" + using namespace rpy; using namespace rpy::python; using namespace pybind11::literals; @@ -132,35 +134,7 @@ void python::init_scalars(pybind11::module_& m) * objects. 
*/ -#define DOCASE(NAME) \ - case static_cast(scalars::ScalarTypeCode::NAME): \ - type = scalars::ScalarTypeCode::NAME; \ - break -const scalars::ScalarType* -python::dlpack_dtype_to_scalar_type(DLDataType dtype, DLDevice device) -{ - using scalars::ScalarDeviceType; - - scalars::ScalarTypeCode type; - switch (dtype.code) { - DOCASE(Float); - DOCASE(Int); - DOCASE(UInt); - DOCASE(OpaqueHandle); - DOCASE(BFloat); - DOCASE(Complex); - DOCASE(Bool); - } - - return scalars::ScalarType::from_type_details( - {type, dtype.bits, dtype.lanes}, - {static_cast(device.device_type), - device.device_id} - ); -} - -#undef DOCASE static inline void dl_copy_strided( std::int32_t ndim, std::int64_t* shape, std::int64_t* strides, @@ -228,8 +202,9 @@ static bool try_fill_buffer_dlpack( auto* strides = dltensor.strides; // This function throws if no matching dtype is found + const auto* tensor_stype - = dlpack_dtype_to_scalar_type(dltensor.dtype, dltensor.device); + = python::scalar_type_of_dl_info(dltensor.dtype, dltensor.device); if (options.type == nullptr) { options.type = tensor_stype; buffer = scalars::KeyScalarArray(options.type); diff --git a/roughpy/src/scalars/scalars.h b/roughpy/src/scalars/scalars.h index 900a77373..fac1c2423 100644 --- a/roughpy/src/scalars/scalars.h +++ b/roughpy/src/scalars/scalars.h @@ -121,9 +121,6 @@ inline bool is_kv_pair(py::handle arg, python::AlternativeKeyType* alternative) return false; } -const scalars::ScalarType* -dlpack_dtype_to_scalar_type(DLDataType dtype, DLDevice device); - scalars::KeyScalarArray py_to_buffer(const py::handle& arg, PyToBufferOptions& options); From 9d9e4fd0c398c22dd7a255b319f8da136389e0a0 Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Fri, 11 Aug 2023 16:47:22 +0100 Subject: [PATCH 32/33] Removed unnecessary check, now done in get-id function --- roughpy/src/scalars/parse_key_scalar_stream.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/roughpy/src/scalars/parse_key_scalar_stream.cpp b/roughpy/src/scalars/parse_key_scalar_stream.cpp index 74cfddb1b..dac29df36 100644 --- a/roughpy/src/scalars/parse_key_scalar_stream.cpp +++ b/roughpy/src/scalars/parse_key_scalar_stream.cpp @@ -252,7 +252,6 @@ void dl_to_stream( return; } - RPY_CHECK(tensor.device.device_type == kDLCPU); RPY_CHECK(tensor.ndim == 1 || tensor.ndim == 2); RPY_CHECK(tensor.dtype.lanes == 1); From 1cf91456333a55558f58baf7c1a4870546dab41c Mon Sep 17 00:00:00 2001 From: Sam Morley Date: Thu, 17 Aug 2023 12:28:50 +0100 Subject: [PATCH 33/33] Added trampoline classes for stream channels --- roughpy/src/streams/schema.cpp | 127 ++++++++++++++++++++++++++++++++- 1 file changed, 126 insertions(+), 1 deletion(-) diff --git a/roughpy/src/streams/schema.cpp b/roughpy/src/streams/schema.cpp index 3c39c3b87..650eb9c7e 100644 --- a/roughpy/src/streams/schema.cpp +++ b/roughpy/src/streams/schema.cpp @@ -41,6 +41,130 @@ using namespace rpy::python; using namespace rpy::streams; using namespace pybind11::literals; +namespace rpy { +namespace python { +class RPY_LOCAL RPyStreamChannel : public StreamChannel +{ +public: + using StreamChannel::StreamChannel; + + dimn_t num_variants() const override + { + PYBIND11_OVERRIDE(dimn_t, StreamChannel, num_variants); + } + string label_suffix(dimn_t variant_no) const override + { + PYBIND11_OVERRIDE(string, StreamChannel, label_suffix, variant_no); + } + dimn_t variant_id_of_label(string_view label) const override + { + PYBIND11_OVERRIDE(dimn_t, StreamChannel, variant_id_of_label, label); + } + void + set_lie_info(deg_t width, deg_t depth, 
algebra::VectorType vtype) override + { + if (type() == ChannelType::Lie) { + PYBIND11_OVERRIDE( + void, StreamChannel, set_lie_info, width, depth, vtype + ); + } else { + RPY_THROW( + std::runtime_error, + "set_lie_info should only be used for Lie-type channels" + ); + } + } + + StreamChannel& add_variant(string variant_label) override + { + if (type() == ChannelType::Categorical) { + PYBIND11_OVERRIDE( + StreamChannel&, StreamChannel, add_variant, variant_label + ); + } else { + RPY_THROW( + std::runtime_error, + "only categorical channels can have variants" + ); + } + } + StreamChannel& insert_variant(string variant_label) override + { + if (type() == ChannelType::Categorical) { + PYBIND11_OVERRIDE( + StreamChannel&, StreamChannel, insert_variant, variant_label + ); + } else { + RPY_THROW( + std::runtime_error, + "only categorical channels can have variants" + ); + } + } + const std::vector& get_variants() const override + { + if (type() == ChannelType::Categorical) { + PYBIND11_OVERRIDE(const std::vector&, StreamChannel, get_variants); + } else { + RPY_THROW( + std::runtime_error, + "only categorical channels can have variants" + ); + } + } +}; + + +class RPY_LOCAL RPyLeadLaggableChannel : public LeadLaggableChannel +{ +public: + + using LeadLaggableChannel::LeadLaggableChannel; + + dimn_t num_variants() const override + { + PYBIND11_OVERRIDE(dimn_t, LeadLaggableChannel, num_variants); + } + string label_suffix(dimn_t variant_no) const override + { + PYBIND11_OVERRIDE(string, LeadLaggableChannel, label_suffix, variant_no); + } + dimn_t variant_id_of_label(string_view label) const override + { + PYBIND11_OVERRIDE(dimn_t, LeadLaggableChannel, variant_id_of_label, label); + } + const std::vector& get_variants() const override + { + PYBIND11_OVERRIDE(const std::vector&, LeadLaggableChannel, get_variants); + } + void set_lead_lag(bool new_value) override + { + PYBIND11_OVERRIDE(void, LeadLaggableChannel, set_lead_lag, new_value); + } + bool is_lead_lag() const override + { + PYBIND11_OVERRIDE(bool, LeadLaggableChannel, is_lead_lag); + } + void + set_lie_info(deg_t width, deg_t depth, algebra::VectorType vtype) override + { + RPY_THROW(std::runtime_error, "set_lie_info is only available for Lie-type channels"); + } + StreamChannel& add_variant(string variant_label) override + { + RPY_THROW(std::runtime_error, "variants are only available for categorical channels"); + } + StreamChannel& insert_variant(string variant_label) override + { + RPY_THROW(std::runtime_error, "variants are only available for categorical channels"); + } +}; + + + +}// namespace python +}// namespace rpy + namespace { // class PyChannelItem { @@ -60,7 +184,8 @@ inline void init_channel_item(py::module_& m) .value("LieChannel", ChannelType::Lie) .export_values(); - //py::class_> cls(m, "StreamChannel"); + // py::class_> cls(m, + // "StreamChannel"); } std::shared_ptr