Skip to content

Commit

Permalink
Merge pull request gnuradio#2 from mormj/variant
Browse files Browse the repository at this point in the history
More Vector considerations and serialization
  • Loading branch information
jsallay authored Oct 31, 2022
2 parents 152a80f + e8cbf95 commit 8d8f45c
Show file tree
Hide file tree
Showing 7 changed files with 528 additions and 108 deletions.
121 changes: 110 additions & 11 deletions include/pmtv/base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,23 @@ class pmt {
explicit operator T() const {
return std::visit([](const auto& arg) -> T {
using U = std::decay_t<decltype(arg)>;
if constexpr (UniformVectorInsidePmt<U>) {
if constexpr(std::constructible_from<T, typename U::element_type>) {
return T(*arg);
}
else throw std::runtime_error("Invalid PMT Cast");
}
if constexpr(std::constructible_from<T, U>) return T(arg);
else throw std::runtime_error("Invalid PMT Cast");
}, _value.base()); }

_pmt_storage storage() const noexcept { return _value; }

operator _pmt_storage() const {
return storage();
}


protected:
_pmt_storage _value;
};
Expand All @@ -116,7 +127,8 @@ pmt::pmt(const T& other) {
// Vector of uniform arithmetic types
_value = std::make_shared<std::vector<typename T::value_type>>(other.begin(), other.end());
}
else if constexpr(associative_array<T>) {
//else if constexpr(associative_array<T>) {
else if constexpr(PmtMap<T>) {
// Map or hash table
_value = std::make_shared<std::map<std::string, _pmt_storage>>(other.begin(), other.end());
}
Expand Down Expand Up @@ -167,13 +179,24 @@ bool PmtEqual(const T& arg, const U& other) {
return arg == other;
}
}
else if constexpr(IsSharedPtr<T> && UniformVector<U>) {
else if constexpr(UniformVectorInsidePmt<T> && UniformVectorInsidePmt<U>) {
return std::visit([&arg, &other]() -> bool {
return PmtEqual(*arg, *other); }
);
}
else if constexpr(UniformVectorInsidePmt<T> && UniformVector<U>) {
return std::visit([&arg, &other]() -> bool {
return PmtEqual(*arg, other); }
);
}
else if constexpr(UniformVector<T> && UniformVectorInsidePmt<U>) {
return std::visit([&arg, &other]() -> bool {
return PmtEqual(arg, *other); }
);
}
else if constexpr(UniformVector<T> && UniformVector<U>) {
if constexpr(std::is_same_v<T, U>) {
// if constexpr(std::is_same_v<T, U>) {
if constexpr(std::is_same_v<typename T::value_type, typename U::value_type>) {
if (arg.size() == other.size()) {
return std::equal(arg.begin(), arg.end(), other.begin());
}
Expand All @@ -183,11 +206,12 @@ bool PmtEqual(const T& arg, const U& other) {
}
}
else {
return false;
// std::cerr << typeid(T).name() << " " << typeid(U).name() << std::endl;
return PmtEqual(arg, other);
}
}
else {
std::cerr << typeid(T).name() << " " << typeid(U).name() << std::endl;
// std::cerr << typeid(T).name() << " " << typeid(U).name() << std::endl;
return false;
}
// else if constexpr(std::is_convertible_v<T, U>) return arg == other;
Expand Down Expand Up @@ -243,10 +267,12 @@ size_t pmt::serialize(std::streambuf& sb) const {
else if (container == pmt_container_type::UNIFORM_VECTOR) {
std::visit([&length, &sb](auto&& arg) {
using T = std::decay_t<decltype(arg)>;
if constexpr(UniformVector<T>) {
auto id = element_type<T>();
if constexpr(UniformVectorInsidePmt<T>) {
auto id = element_type<typename T::element_type>();
length += sb.sputn(reinterpret_cast<const char*>(&id), 1);
length += sb.sputn(reinterpret_cast<const char*>(arg.data()), arg.size()*sizeof(arg[0]));
uint64_t sz = arg->size();
length += sb.sputn(reinterpret_cast<const char*>(&sz), sizeof(uint64_t));
length += sb.sputn(reinterpret_cast<const char*>(arg->data()), arg->size()*sizeof((*arg)[0]));
}
}, _value.base());

Expand All @@ -259,8 +285,8 @@ pmt pmt::deserialize(std::streambuf& sb)
{
uint16_t version;
pmt_container_type container;
sb.sgetn(reinterpret_cast<char*>(&version), sizeof(version));
sb.sgetn(reinterpret_cast<char*>(&container), sizeof(container));
sb.sgetn(reinterpret_cast<char*>(&version), 2);
sb.sgetn(reinterpret_cast<char*>(&container), 2);

pmt ret;
if (container == pmt_container_type::EMPTY) {
Expand Down Expand Up @@ -337,7 +363,76 @@ pmt pmt::deserialize(std::streambuf& sb)
}
}
else if (container == pmt_container_type::UNIFORM_VECTOR) {

pmt_element_type T_type;
sb.sgetn(reinterpret_cast<char*>(&T_type), 1);
uint64_t sz;
sb.sgetn(reinterpret_cast<char*>(&sz), sizeof(uint64_t));

switch(T_type) {
case pmt_element_type::UINT8: {
std::vector<uint8_t> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::UINT16: {
std::vector<uint16_t> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::UINT32: {
std::vector<uint32_t> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::UINT64: {
std::vector<uint64_t> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::INT8: {
std::vector<int8_t> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::INT16: {
std::vector<int16_t> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::INT32: {
std::vector<int32_t> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::INT64: {
std::vector<int64_t> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::FLOAT: {
std::vector<float> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::DOUBLE: {
std::vector<double> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::COMPLEX_FLOAT: {
std::vector<std::complex<float>> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
} break;
case pmt_element_type::COMPLEX_DOUBLE: {
std::vector<std::complex<double>> val(sz);
sb.sgetn(reinterpret_cast<char*>(val.data()), sz*sizeof(val[0]));
ret = pmt(val);
}
default:{

}
}
}

return ret;
Expand Down Expand Up @@ -420,4 +515,8 @@ std::ostream& operator<<(std::ostream& os, const P& value) {
}



// Explicit cast std::vector


} // namespace pmtv
141 changes: 141 additions & 0 deletions include/pmtv/map.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#pragma once

#include <map>
#include <pmtv/base.hpp>
#include <ranges>
#include <span>

namespace pmtv {

/**
* @brief map of keys of type string to pmts
*
*/
class map : public pmt {
public:
using key_type = std::string;
using mapped_type = _pmt_storage;
using value_type = std::pair<const key_type, mapped_type>;
using reference = value_type &;
using const_reference = const value_type &;
using map_type = std::map<key_type, mapped_type>;
using size_type = size_t;
using map_sptr = std::shared_ptr<map_type>;

using iterator = map_type::iterator;
using const_iterator = map_type::const_iterator;
// Construct empty map
map() : pmt(map_type{}) {}
// Copy from std map
map(const map_type &other) : pmt(other) {}

// Copy from std map
map(const std::map<std::string, pmt> &other) : pmt(map_type{}) {
for (auto &[k, v] : other) {
// FIXME - the [] operator seems to be returning
// the variant by value, not by reference
this->operator[](k) = v.storage();
// auto x = this->operator[](k);
// x = v.storage();
}
}

// // Copy from pmt
// template <class T, typename = IsPmt<T>>
// map(const T& other) {
// if (other.data_type() != data_type())
// throw ConversionError(other, "map");
// _map = other;
// }
// map(std::initializer_list<value_type> il) {
// _MakeEmptyMap();
// for (auto& [k, v]: il)
// this->operator[](k) = v;
// }
// //template <class T>
// //map(std::map<string
// ~map() {}

/**************************************************************************
* Iterators
**************************************************************************/
typename map_type::iterator begin() noexcept { return _get_map()->begin(); }
typename map_type::const_iterator begin() const noexcept {
return _get_map()->begin();
}
typename map_type::iterator end() noexcept { return _get_map()->end(); }
typename map_type::const_iterator end() const noexcept {
return _get_map()->end();
}

/**************************************************************************
* Element Access
**************************************************************************/
mapped_type &at(const key_type &key) { return _get_map()->at(key); }
const mapped_type &at(const key_type &key) const {
return _get_map()->at(key);
}
mapped_type &operator[](const key_type &key) {
return _get_map()->operator[](key);
}

// size_t size() const { return _get_map()->size(); }
// size_t count(const key_type& key) const { return
// _get_map()->count(key); }

// static constexpr Data data_type() { return
// DataTraits<type>::enum_value; } const pmt& get_pmt_buffer() const {
// return _map; }

// //! Equality Comparisons
// // Declared as class members so that we don't do implicit
// conversions. template <class U> bool operator==(const U& x) const;
// template <class U>
// bool operator!=(const U& x) const { return !(operator==(x));}
// void pre_serial_update() const {
// // It may look odd to declare this function as const when it
// modifies
// // count. But count is part of the internal interface, so to the
// // user, this is a const function.
// std::shared_ptr<base_buffer> scalar = _map._scalar;
// scalar->data_as<type>()->mutate_count(_get_map()->size());
// }

private:
map_sptr _get_map() { return std::get<map_sptr>(_value); }
const map_sptr _get_map() const { return std::get<map_sptr>(_value); }
};

// template <class T, class U>
// using IsNotVectorT = std::enable_if_t<!std::is_same_v<uniform_vector<T>, U>,
// bool>;

// // Reversed case. This allows for x == y and y == x
// template <class T, class U, typename = IsNotVectorT<T, U> >
// bool operator==(const U& y, const uniform_vector<T>& x) {
// return x.operator==(y);
// }

// template<>
// bool PmtEqual(const std::vector<T>& arg, const uniform_vector<T>& other) {

// }

template <class T> using IsMap = std::enable_if_t<std::is_same_v<map, T>, bool>;

// Need to have map operator here because it has pmts in it.
template <class T, IsMap<T> = true>
std::ostream &operator<<(std::ostream &os, const T &value) {
os << "{ ";
bool first = true;
for (const auto &[k, v] : value) {
if (!first)
os << ", ";
first = false;
os << k << ": " << pmt(v);
}
os << " }";
return os;
}

} // namespace pmtv
Loading

0 comments on commit 8d8f45c

Please sign in to comment.