Skip to content

Commit

Permalink
Get column value by column label in Framework Core ASoA (#13498)
Browse files Browse the repository at this point in the history
  • Loading branch information
mytkom authored Dec 11, 2024
1 parent c5cbdc4 commit 47d098d
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 1 deletion.
78 changes: 78 additions & 0 deletions Framework/Core/include/Framework/ASoA.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <array>
#include <cassert>
#include <fmt/format.h>
#include <concepts>
#include <cstring>
#include <gsl/span>
#include <limits>

Expand Down Expand Up @@ -2172,6 +2174,82 @@ std::tuple<typename Cs::type...> getRowData(arrow::Table* table, T rowIterator,
{
return std::make_tuple(getSingleRowData<T, Cs>(table, rowIterator, ci, ai, globalIndex)...);
}

template <typename R, typename T, typename C>
R getColumnValue(const T& rowIterator)
{
return static_cast<R>(static_cast<C>(rowIterator).get());
}

template <typename R, typename T>
using ColumnGetterFunction = R (*)(const T&);

template <typename T, typename R>
concept dynamic_with_common_getter = is_dynamic_column<T> &&
// lambda is callable without additional free args
framework::pack_size(typename T::bindings_t{}) == framework::pack_size(typename T::callable_t::args{}) &&
requires(T t) {
{ t.get() } -> std::convertible_to<R>;
};

template <typename T, typename R>
concept persistent_with_common_getter = is_persistent_v<T> && requires(T t) {
{ t.get() } -> std::convertible_to<R>;
};

template <typename R, typename T, persistent_with_common_getter<R> C>
ColumnGetterFunction<R, T> createGetterPtr(const std::string_view& targetColumnLabel)
{
return targetColumnLabel == C::columnLabel() ? &getColumnValue<R, T, C> : nullptr;
}

template <typename R, typename T, dynamic_with_common_getter<R> C>
ColumnGetterFunction<R, T> createGetterPtr(const std::string_view& targetColumnLabel)
{
std::string_view columnLabel(C::columnLabel());

// allows user to use consistent formatting (with prefix) of all column labels
// by default there isn't 'f' prefix for dynamic column labels
if (targetColumnLabel.starts_with("f") && targetColumnLabel.substr(1) == columnLabel) {
return &getColumnValue<R, T, C>;
}

// check also exact match if user is aware of prefix missing
if (targetColumnLabel == columnLabel) {
return &getColumnValue<R, T, C>;
}

return nullptr;
}

template <typename R, typename T, typename... Cs>
ColumnGetterFunction<R, T> getColumnGetterByLabel(o2::framework::pack<Cs...>, const std::string_view& targetColumnLabel)
{
ColumnGetterFunction<R, T> func;

(void)((func = createGetterPtr<R, T, Cs>(targetColumnLabel), func) || ...);

if (!func) {
throw framework::runtime_error_f("Getter for \"%s\" not found", targetColumnLabel);
}

return func;
}

template <typename T, typename R>
using with_common_getter_t = typename std::conditional<persistent_with_common_getter<T, R> || dynamic_with_common_getter<T, R>, std::true_type, std::false_type>::type;

template <typename R, typename T>
ColumnGetterFunction<R, typename T::iterator> getColumnGetterByLabel(const std::string_view& targetColumnLabel)
{
using TypesWithCommonGetter = o2::framework::selected_pack_multicondition<with_common_getter_t, framework::pack<R>, typename T::columns_t>;

if (targetColumnLabel.size() == 0) {
throw framework::runtime_error("columnLabel: must not be empty");
}

return getColumnGetterByLabel<R, typename T::iterator>(TypesWithCommonGetter{}, targetColumnLabel);
}
} // namespace row_helpers
} // namespace o2::soa

Expand Down
1 change: 0 additions & 1 deletion Framework/Core/include/Framework/BinningPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#ifndef FRAMEWORK_BINNINGPOLICY_H
#define FRAMEWORK_BINNINGPOLICY_H

#include "Framework/ASoA.h"
#include "Framework/HistogramSpec.h" // only for VARIABLE_WIDTH
#include "Framework/Pack.h"

Expand Down
60 changes: 60 additions & 0 deletions Framework/Core/test/benchmark_ASoA.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ DECLARE_SOA_COLUMN_FULL(X, x, float, "x");
DECLARE_SOA_COLUMN_FULL(Y, y, float, "y");
DECLARE_SOA_COLUMN_FULL(Z, z, float, "z");
DECLARE_SOA_DYNAMIC_COLUMN(Sum, sum, [](float x, float y) { return x + y; });
DECLARE_SOA_DYNAMIC_COLUMN(SumFreeArgs, sumFreeArgs, [](float x, float y, float freeArg) { return x + y + freeArg; });
} // namespace test

DECLARE_SOA_TABLE(TestTable, "AOD", "TESTTBL", test::X, test::Y, test::Z, test::Sum<test::X, test::Y>);
Expand Down Expand Up @@ -290,6 +291,36 @@ static void BM_ASoADynamicColumnPresent(benchmark::State& state)

BENCHMARK(BM_ASoADynamicColumnPresent)->Range(8, 8 << maxrange);

static void BM_ASoADynamicColumnPresentGetGetterByLabel(benchmark::State& state)
{
// Seed with a real random value, if available
std::default_random_engine e1(1234567891);
std::uniform_real_distribution<float> uniform_dist(0, 1);

TableBuilder builder;
auto rowWriter = builder.persist<float, float, float>({"x", "y", "z"});
for (auto i = 0; i < state.range(0); ++i) {
rowWriter(0, uniform_dist(e1), uniform_dist(e1), uniform_dist(e1));
}
auto table = builder.finalize();

using Test = o2::soa::InPlaceTable<"A/0"_h, test::X, test::Y, test::Z, test::Sum<test::X, test::Y>>;

for (auto _ : state) {
Test tests{table};
float sum = 0;
auto xGetter = o2::soa::row_helpers::getColumnGetterByLabel<float, Test>("x");
auto yGetter = o2::soa::row_helpers::getColumnGetterByLabel<float, Test>("y");
for (auto& test : tests) {
sum += xGetter(test) + yGetter(test);
}
benchmark::DoNotOptimize(sum);
}
state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(float) * 2);
}

BENCHMARK(BM_ASoADynamicColumnPresentGetGetterByLabel)->Range(8, 8 << maxrange);

static void BM_ASoADynamicColumnCall(benchmark::State& state)
{
// Seed with a real random value, if available
Expand Down Expand Up @@ -317,4 +348,33 @@ static void BM_ASoADynamicColumnCall(benchmark::State& state)
}
BENCHMARK(BM_ASoADynamicColumnCall)->Range(8, 8 << maxrange);

static void BM_ASoADynamicColumnCallGetGetterByLabel(benchmark::State& state)
{
// Seed with a real random value, if available
std::default_random_engine e1(1234567891);
std::uniform_real_distribution<float> uniform_dist(0, 1);

TableBuilder builder;
auto rowWriter = builder.persist<float, float, float>({"x", "y", "z"});
for (auto i = 0; i < state.range(0); ++i) {
rowWriter(0, uniform_dist(e1), uniform_dist(e1), uniform_dist(e1));
}
auto table = builder.finalize();

// SumFreeArgs presence checks if dynamic columns get() is handled correctly during compilation
using Test = o2::soa::InPlaceTable<"A/0"_h, test::X, test::Y, test::Sum<test::X, test::Y>, test::SumFreeArgs<test::X, test::Y>>;

Test tests{table};
for (auto _ : state) {
float sum = 0;
auto sumGetter = o2::soa::row_helpers::getColumnGetterByLabel<float, Test>("Sum");
for (auto& test : tests) {
sum += sumGetter(test);
}
benchmark::DoNotOptimize(sum);
}
state.SetBytesProcessed(state.iterations() * state.range(0) * sizeof(float) * 2);
}
BENCHMARK(BM_ASoADynamicColumnCallGetGetterByLabel)->Range(8, 8 << maxrange);

BENCHMARK_MAIN();

0 comments on commit 47d098d

Please sign in to comment.