Moved sparse operator() into tensor_impl_t (#841)
cliffburdick authored Jan 30, 2025
1 parent 5446cbc commit 1777afc
Showing 2 changed files with 100 additions and 84 deletions.
82 changes: 3 additions & 79 deletions include/matx/core/sparse_tensor.h
@@ -108,85 +108,9 @@ class sparse_tensor_t
__MATX_INLINE__ ~sparse_tensor_t() = default;

// Size getters.
index_t Nse() const { return values_.size() / sizeof(VAL); }
index_t crdSize(int l) const { return coordinates_[l].size() / sizeof(CRD); }
index_t posSize(int l) const { return positions_[l].size() / sizeof(POS); }

// Locates position of an element at given indices, or returns -1 when not
// found.
template <int L = 0>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t
GetPos(index_t *lvlsz, index_t *lvl, index_t pos) const {
if constexpr (L < LVL) {
using ftype = std::tuple_element_t<L, typename TF::LVLSPECS>;
if constexpr (ftype::lvltype == LvlType::Dense) {
// Dense level: pos * size + i.
// TODO: see below, use a constexpr GetLvlSize(L) instead?
const index_t dpos = pos * lvlsz[L] + lvl[L];
if constexpr (L + 1 < LVL) {
return GetPos<L + 1>(lvlsz, lvl, dpos);
} else {
return dpos;
}
} else if constexpr (ftype::lvltype == LvlType::Singleton) {
// Singleton level: pos if crd[pos] == i and next levels match.
if (this->CRDData(L)[pos] == lvl[L]) {
if constexpr (L + 1 < LVL) {
return GetPos<L + 1>(lvlsz, lvl, pos);
} else {
return pos;
}
}
} else if constexpr (ftype::lvltype == LvlType::Compressed ||
ftype::lvltype == LvlType::CompressedNonUnique) {
// Compressed level: scan for match on i and test next levels.
const CRD *c = this->CRDData(L);
const POS *p = this->POSData(L);
for (index_t pp = p[pos], hi = p[pos + 1]; pp < hi; pp++) {
if (c[pp] == lvl[L]) {
if constexpr (L + 1 < LVL) {
const index_t cpos = GetPos<L + 1>(lvlsz, lvl, pp);
if constexpr (ftype::lvltype == LvlType::Compressed) {
return cpos; // always end scan (unique)
} else if (cpos != -1) {
return cpos; // only end scan on success (non-unique)
}
} else {
return pp;
}
}
}
}
}
return -1; // not found
}

// Element getter (viz. "lhs = Acoo(0,0);"). Note that due to the compact
// nature of sparse data structures, these storage formats do not provide
// cheap random access to their elements. Instead, the implementation will
// search for a stored element at the given position (which involves a scan
// at each compressed level). The implicit value zero is returned when the
// element cannot be found. So, although functional for testing, clients
should avoid using getters inside performance critical regions, since
// the implementation is far worse than O(1).
template <typename... Is>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ VAL
operator()(Is... indices) const noexcept {
static_assert(
sizeof...(Is) == DIM,
"Number of indices of operator() must match rank of sparse tensor");
cuda::std::array<index_t, DIM> dim{indices...};
cuda::std::array<index_t, LVL> lvl;
cuda::std::array<index_t, LVL> lvlsz;
TF::dim2lvl(dim.data(), lvl.data(), /*asSize=*/false);
// TODO: only compute once and provide a constexpr LvlSize(l) instead?
TF::dim2lvl(this->Shape().data(), lvlsz.data(), /*asSize=*/true);
const index_t pos = GetPos(lvlsz.data(), lvl.data(), 0);
if (pos != -1) {
return this->Data()[pos];
}
return static_cast<VAL>(0); // implicit zero
}
index_t Nse() const { return static_cast<index_t>(values_.size() / sizeof(VAL)); }
index_t crdSize(int l) const { return static_cast<index_t>(coordinates_[l].size() / sizeof(CRD)); }
index_t posSize(int l) const { return static_cast<index_t>(positions_[l].size() / sizeof(POS)); }

private:
// Primary storage of sparse tensor (explicitly stored element values).
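The size getters that remain in sparse_tensor.h derive element counts from byte sizes. Below is a minimal standalone sketch of that computation, assuming (as the code above suggests) that the underlying storage reports its size in bytes; the names here are illustrative, not MatX APIs.

```cpp
// Illustrative sketch only: derive an element count from a byte-sized buffer,
// mirroring Nse(), crdSize(), and posSize() above. Assumes storage reports bytes.
#include <cstddef>
#include <vector>

using index_t = long long;  // stand-in for matx::index_t

struct toy_storage {
  std::vector<unsigned char> bytes;                 // raw byte buffer
  std::size_t size() const { return bytes.size(); } // size in bytes
};

template <typename VAL>
index_t element_count(const toy_storage &s) {
  // The cast keeps the signed index type used throughout the interface.
  return static_cast<index_t>(s.size() / sizeof(VAL));
}
```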
102 changes: 97 additions & 5 deletions include/matx/core/tensor_impl.h
@@ -41,6 +41,7 @@
#include "matx/core/type_utils.h"
#include "matx/core/tensor_utils.h"
#include "matx/operators/set.h"
#include "matx/core/sparse_tensor_format.h"
//#include "matx_exec_kernel.h"
#include "iterator.h"
#include "matx/core/make_tensor.h"
@@ -58,6 +59,7 @@ struct DenseTensorData {
template <typename T, typename CRD, typename POS, typename TF>
struct SparseTensorData {
using sparse_data = bool;
using value_type = T;
using crd_type = CRD;
using pos_type = POS;
using Format = TF;
@@ -894,6 +896,86 @@ MATX_IGNORE_WARNING_POP_GCC
return data_.ldata_ + GetValC<0, Is...>(cuda::std::make_tuple(indices...));
}

// Locates position of an element at given indices, or returns -1 when not
// found.
template <int L = 0>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t
GetPos(index_t *lvlsz, index_t *lvl, index_t pos) const {
static constexpr int LVL = TensorData::Format::LVL;
if constexpr (L < LVL) {
using ftype = std::tuple_element_t<L, typename TensorData::Format::LVLSPECS>;
if constexpr (ftype::lvltype == ::matx::experimental::LvlType::Dense) {
// Dense level: pos * size + i.
// TODO: see below, use a constexpr GetLvlSize(L) instead?
const index_t dpos = pos * lvlsz[L] + lvl[L];
if constexpr (L + 1 < LVL) {
return GetPos<L + 1>(lvlsz, lvl, dpos);
} else {
return dpos;
}
} else if constexpr (ftype::lvltype == ::matx::experimental::LvlType::Singleton) {
// Singleton level: pos if crd[pos] == i and next levels match.
if (CRDData(L)[pos] == lvl[L]) {
if constexpr (L + 1 < LVL) {
return GetPos<L + 1>(lvlsz, lvl, pos);
} else {
return pos;
}
}
} else if constexpr (ftype::lvltype == ::matx::experimental::LvlType::Compressed ||
ftype::lvltype == ::matx::experimental::LvlType::CompressedNonUnique) {
// Compressed level: scan for match on i and test next levels.
const typename TensorData::crd_type *c = CRDData(L);
const typename TensorData::pos_type *p = POSData(L);
for (index_t pp = p[pos], hi = p[pos + 1]; pp < hi; pp++) {
if (c[pp] == lvl[L]) {
if constexpr (L + 1 < LVL) {
const index_t cpos = GetPos<L + 1>(lvlsz, lvl, pp);
if constexpr (ftype::lvltype == ::matx::experimental::LvlType::Compressed) {
return cpos; // always end scan (unique)
} else if (cpos != -1) {
return cpos; // only end scan on success (non-unique)
}
} else {
return pp;
}
}
}
}
}
return -1; // not found
}

// Element getter (viz. "lhs = Acoo(0,0);"). Note that due to the compact
// nature of sparse data structures, these storage formats do not provide
// cheap random access to their elements. Instead, the implementation will
// search for a stored element at the given position (which involves a scan
// at each compressed level). The implicit value zero is returned when the
// element cannot be found. So, although functional for testing, clients
should avoid using getters inside performance critical regions, since
// the implementation is far worse than O(1).
template <typename... Is>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T GetSparseValue(Is... indices) const noexcept {
static constexpr int DIM = TensorData::Format::DIM;
static constexpr int LVL = TensorData::Format::LVL;

static_assert(sizeof...(Is) == DIM,
"Number of indices of operator() must match rank of sparse tensor");
cuda::std::array<index_t, DIM> dim{indices...};
cuda::std::array<index_t, LVL> lvl;
cuda::std::array<index_t, LVL> lvlsz;
TensorData::Format::dim2lvl(dim.data(), lvl.data(), /*asSize=*/false);
// TODO: only compute once and provide a constexpr LvlSize(l) instead?
TensorData::Format::dim2lvl(Shape().data(), lvlsz.data(), /*asSize=*/true);
const index_t pos = GetPos(lvlsz.data(), lvl.data(), 0);
if (pos != -1) {
const typename TensorData::value_type tmp = Data()[pos];
return tmp;
}

return static_cast<typename TensorData::value_type>(0); // implicit zero
}
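The getter therefore performs a search whose cost depends on the format (a scan per compressed level) and falls back to the implicit zero when no stored element matches, as the comment above notes. Below is a small self-contained sketch of the same semantics for a COO-style layout (a compressed-non-unique level over rows and a singleton level over columns); the names are illustrative and not part of MatX.

```cpp
// Illustrative only: COO-style element getter with implicit-zero semantics,
// mirroring GetSparseValue above. A toy standalone example, not the MatX API.
#include <cstdint>
#include <cstdio>
#include <vector>

using index_t = std::int64_t;

struct toy_coo {
  index_t rows, cols;
  std::vector<index_t> pos;  // {0, nse}: one segment over all stored elements
  std::vector<index_t> row;  // row coordinate per stored element
  std::vector<index_t> col;  // column coordinate per stored element
  std::vector<double>  val;  // stored values
};

double coo_at(const toy_coo &A, index_t i, index_t j) {
  // Compressed-non-unique level: scan the stored elements for a row match,
  // and only stop when the singleton (column) level also matches.
  for (index_t pp = A.pos[0], hi = A.pos[1]; pp < hi; pp++) {
    if (A.row[pp] == i && A.col[pp] == j) {
      return A.val[pp];
    }
  }
  return 0.0;  // implicit zero: element not stored
}

int main() {
  // 2x2 matrix with a single stored nonzero at (0, 0), as in "lhs = Acoo(0,0);"
  toy_coo A{2, 2, {0, 1}, {0}, {0}, {3.5}};
  std::printf("%f %f\n", coo_at(A, 0, 0), coo_at(A, 1, 1));  // 3.500000 0.000000
  return 0;
}
```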

/**
* Check if a tensor is linear in memory for all elements in the view
*
@@ -942,10 +1024,15 @@ MATX_IGNORE_WARNING_POP_GCC
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const noexcept
{
static_assert(sizeof...(Is) == M, "Number of indices of operator() must match rank of tensor");
if constexpr (!is_sparse_data_v<TensorData>) {
#ifndef NDEBUG
assert(data_.ldata_ != nullptr);
#endif
return *(data_.ldata_ + GetValC<0, Is...>(cuda::std::make_tuple(indices...)));
}
else { // Sparse tensor getter
return GetSparseValue(indices...);
}
}
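With this change a single operator() serves both storage kinds: the branch is resolved at compile time from the data-descriptor type, so dense views keep the direct pointer arithmetic while sparse views route to GetSparseValue. Below is a minimal sketch of that dispatch pattern; it defines its own toy detection trait rather than MatX's actual is_sparse_data_v, and the type names are hypothetical.

```cpp
// Illustrative only: compile-time dispatch on a "sparse data" marker type,
// mirroring how operator() above selects the dense or sparse code path.
#include <cstdio>
#include <type_traits>

struct dense_data { };
struct sparse_data_desc { using sparse_data = bool; };  // marker alias, as in SparseTensorData

template <typename T, typename = void>
struct is_sparse_data : std::false_type { };
template <typename T>
struct is_sparse_data<T, std::void_t<typename T::sparse_data>> : std::true_type { };
template <typename T>
inline constexpr bool is_sparse_data_v = is_sparse_data<T>::value;

template <typename Data>
const char *access_path() {
  if constexpr (!is_sparse_data_v<Data>) {
    return "dense: dereference computed offset";
  } else {
    return "sparse: search stored elements (GetSparseValue)";
  }
}

int main() {
  std::printf("%s\n%s\n", access_path<dense_data>(), access_path<sparse_data_desc>());
  return 0;
}
```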

/**
@@ -961,11 +1048,16 @@ MATX_IGNORE_WARNING_POP_GCC
std::enable_if_t<std::conjunction_v<std::is_integral<Is>...>, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) noexcept
{
if constexpr (!is_sparse_data_v<TensorData>) {
static_assert(sizeof...(Is) == M, "Number of indices of operator() must match rank of tensor");
#ifndef NDEBUG
assert(data_.ldata_ != nullptr);
#endif
return *(data_.ldata_ + GetVal<0, Is...>(cuda::std::make_tuple(indices...)));
}
else {
return GetSparseValue(indices...);
}
}

/**
