Skip to content

Commit

Permalink
refactor: fix tests for instance sharing refactor [no ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
program-- committed Jun 12, 2024
1 parent 214d83a commit f2c7a49
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 42 deletions.
31 changes: 28 additions & 3 deletions include/forcing/ForcingsEngineDataProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,25 @@ static constexpr auto default_time_format = "%Y-%m-%d %H:%M:%S";

namespace detail {

//! Storage for Forcings Engine-specific BMI instances.
struct ForcingsEngineStorage {
using key_type = std::string;
using bmi_type = models::bmi::Bmi_Py_Adapter;
using value_type = std::shared_ptr<bmi_type>;

value_type get(const key_type& key);
void set(const key_type& key, value_type value);
void clear();
value_type get(const key_type& key)
{
auto pos = data_.find(key);
if (pos == data_.end()) {
return nullptr;
}

return pos->second;
}

void set(const key_type& key, value_type value) { data_[key] = value; }

void clear() { data_.clear(); }

private:
//! Instance map of underlying BMI models.
Expand Down Expand Up @@ -90,6 +101,11 @@ struct ForcingsEngineDataProvider
return (epoch - time_begin_) / time_step_;
}

std::shared_ptr<models::bmi::Bmi_Py_Adapter> model() noexcept
{
return bmi_;
}

/* Remaining virtual member functions from DataProvider must be implemented
by derived classes. */

Expand Down Expand Up @@ -151,6 +167,15 @@ struct ForcingsEngineDataProvider
time_current_index_++;
}

void next(double time) {
const auto start = bmi_->GetCurrentTime();
bmi_->UpdateUntil(time);
const auto end = bmi_->GetCurrentTime();
time_current_index_ += static_cast<int64_t>(
(end - start) / bmi_->GetTimeStep()
);
}

//! Forcings Engine instance
std::shared_ptr<models::bmi::Bmi_Py_Adapter> bmi_ = nullptr;

Expand Down
14 changes: 10 additions & 4 deletions include/forcing/ForcingsEngineGriddedDataProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ namespace data_access {
struct ForcingsEngineGriddedDataProvider
: public ForcingsEngineDataProvider<Cell, GridDataSelector>
{
using data_type = data_type;
using selection_type = selection_type;
using base_type = ForcingsEngineDataProvider<data_type, selection_type>;

~ForcingsEngineGriddedDataProvider() override = default;

Cell get_value(const GridDataSelector& selector, data_access::ReSampleMethod m) override;
data_type get_value(const selection_type& selector, data_access::ReSampleMethod m) override;

std::vector<Cell> get_values(const GridDataSelector& selector, data_access::ReSampleMethod m) override;
std::vector<data_type> get_values(const selection_type& selector, data_access::ReSampleMethod m) override;

static std::unique_ptr<ForcingsEngineDataProvider> make_gridded_instance(
const std::string& init,
Expand All @@ -27,10 +31,12 @@ struct ForcingsEngineGriddedDataProvider
}

private:
friend base_type;

ForcingsEngineGriddedDataProvider(
const std::string& init,
std::size_t time_begin_seconds,
std::size_t time_end_seconds
std::time_t time_begin_seconds,
std::time_t time_end_seconds
);

int var_grid_id_{};
Expand Down
14 changes: 10 additions & 4 deletions include/forcing/ForcingsEngineLumpedDataProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ namespace data_access {
struct ForcingsEngineLumpedDataProvider
: public ForcingsEngineDataProvider<double, CatchmentAggrDataSelector>
{
using data_type = data_type;
using selection_type = selection_type;
using base_type = ForcingsEngineDataProvider<data_type, selection_type>;

static constexpr auto bad_index = static_cast<std::size_t>(-1);

~ForcingsEngineLumpedDataProvider() override = default;

double get_value(const CatchmentAggrDataSelector& selector, data_access::ReSampleMethod m) override;
data_type get_value(const selection_type& selector, data_access::ReSampleMethod m) override;

std::vector<double> get_values(const CatchmentAggrDataSelector& selector, data_access::ReSampleMethod m) override;
std::vector<data_type> get_values(const selection_type& selector, data_access::ReSampleMethod m) override;

/**
* @brief Get the index in `CAT-ID` for a given divide in the instance cache.
Expand Down Expand Up @@ -45,10 +49,12 @@ struct ForcingsEngineLumpedDataProvider
}

private:
friend base_type;

ForcingsEngineLumpedDataProvider(
const std::string& init,
std::size_t time_begin_seconds,
std::size_t time_end_seconds
std::time_t time_begin_seconds,
std::time_t time_end_seconds
);

/**
Expand Down
1 change: 1 addition & 0 deletions src/forcing/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ if(NGEN_WITH_PYTHON)
PRIVATE
"${CMAKE_CURRENT_LIST_DIR}/ForcingsEngineDataProvider.cpp"
"${CMAKE_CURRENT_LIST_DIR}/ForcingsEngineLumpedDataProvider.cpp"
"${CMAKE_CURRENT_LIST_DIR}/ForcingsEngineGriddedDataProvider.cpp"
)
target_link_libraries(forcing PUBLIC pybind11::embed NGen::ngen_bmi)
endif()
Expand Down
34 changes: 19 additions & 15 deletions src/forcing/ForcingsEngineGriddedDataProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

namespace data_access {

using BaseProvider = ForcingsEngineDataProvider<Cell, GridDataSelector>;
using Provider = ForcingsEngineGriddedDataProvider;
using BaseProvider = Provider::base_type;

Provider::ForcingsEngineGriddedDataProvider(
const std::string& init,
std::size_t time_begin_seconds,
std::size_t time_end_seconds
std::time_t time_begin_seconds,
std::time_t time_end_seconds
)
: BaseProvider(init, time_begin_seconds, time_end_seconds)
{
Expand Down Expand Up @@ -46,6 +46,11 @@ Provider::ForcingsEngineGriddedDataProvider(
}

Cell Provider::get_value(const GridDataSelector& selector, data_access::ReSampleMethod m)
{
throw std::runtime_error{"ForcingsEngineGriddedDataProvider::get_value() is not implemented"};
}

std::vector<Cell> Provider::get_values(const GridDataSelector& selector, data_access::ReSampleMethod m)
{
if (m != ReSampleMethod::SUM && m != ReSampleMethod::MEAN) {
throw std::runtime_error{"Given ReSampleMethod " + std::to_string(m) + " not implemented."};
Expand All @@ -54,33 +59,32 @@ Cell Provider::get_value(const GridDataSelector& selector, data_access::ReSample
const auto start = clock_type::from_time_t(selector.initial_time());
const auto end = std::chrono::seconds{selector.duration()} + start;
const auto step = std::chrono::seconds{record_duration()};

auto cell = selector.cells()[0]; // FIXME: bad semantics,

std::vector<Cell> cells = { selector.cells().begin(), selector.cells().end() };
for (auto current = start; current < end; current += step) {
bmi_->UpdateUntil(current.time_since_epoch().count());

this->next(current.time_since_epoch().count());
boost::span<double> values{
static_cast<double*>(bmi_->GetValuePtr(selector.variable())),
var_grid_.rows * var_grid_.columns
};

cell.value += values[cell.x + cell.y * var_grid_.rows];
for (auto& cell : cells) {
cell.value += values[cell.x + cell.y * var_grid_.rows];
}
}

if (m == ReSampleMethod::MEAN) {
const auto time_step_seconds = step.count();
const auto time_duration = std::chrono::duration_cast<std::chrono::seconds>(end - start).count();
const auto num_time_steps = time_duration / time_step_seconds;
cell.value /= num_time_steps;
}

return cell;
}
for (auto& cell : cells) {
cell.value /= num_time_steps;
}
}

std::vector<Cell> Provider::get_values(const GridDataSelector& selector, data_access::ReSampleMethod m)
{

return cells;
}


Expand Down
4 changes: 2 additions & 2 deletions src/forcing/ForcingsEngineLumpedDataProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ using Provider = ForcingsEngineLumpedDataProvider;

Provider::ForcingsEngineLumpedDataProvider(
const std::string& init,
std::size_t time_begin_seconds,
std::size_t time_end_seconds
std::time_t time_begin_seconds,
std::time_t time_end_seconds
)
: BaseProvider(init, time_begin_seconds, time_end_seconds)
{
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ ngen_add_test(
test_forcings_engine
OBJECTS
forcing/ForcingsEngineLumpedDataProvider_Test.cpp
forcing/ForcingsEngineGriddedDataProvider_Test.cpp
LIBRARIES
NGen::forcing
REQUIRES
Expand Down
15 changes: 8 additions & 7 deletions test/forcing/ForcingsEngineGriddedDataProvider_Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ struct ForcingsEngineGriddedDataProviderTest

static const forcing_params default_params;
static std::shared_ptr<utils::ngenPy::InterpreterUtil> gil_;
static provider_type* provider_;
static std::unique_ptr<provider_type> provider_;
static mpi_info mpi_;
};

using TestFixture = ForcingsEngineGriddedDataProviderTest;
const forcing_params TestFixture::default_params = { "", "ForcingsEngine", "2024-01-17 01:00:00", "2024-01-17 06:00:00" };
std::shared_ptr<utils::ngenPy::InterpreterUtil> TestFixture::gil_ = nullptr;
TestFixture::provider_type* TestFixture::provider_ = nullptr;
std::unique_ptr<TestFixture::provider_type> TestFixture::provider_ = nullptr;
mpi_info TestFixture::mpi_ = {};

void TestFixture::SetUpTestSuite()
Expand All @@ -67,7 +68,7 @@ void TestFixture::SetUpTestSuite()

void TestFixture::TearDownTestSuite()
{
provider_->finalize_all();
data_access::detail::forcings_engine_instances.clear();
gil_.reset();

#if NGEN_WITH_MPI
Expand All @@ -82,11 +83,11 @@ void TestFixture::TearDownTestSuite()
*/
TEST_F(ForcingsEngineGriddedDataProviderTest, Storage)
{
auto* inst_a = data_access::ForcingsEngineGriddedDataProvider::instance(config_file, default_params.start_time, default_params.end_time);
ASSERT_EQ(inst_a, provider_);
auto inst_a = data_access::ForcingsEngineGriddedDataProvider::make_gridded_instance(config_file, default_params.start_time, default_params.end_time);
ASSERT_EQ(inst_a->model(), provider_->model());

auto* inst_b = data_access::ForcingsEngineGriddedDataProvider::instance(config_file, default_params.start_time, default_params.end_time);
ASSERT_EQ(inst_a, inst_b);
auto inst_b = data_access::ForcingsEngineGriddedDataProvider::make_gridded_instance(config_file, default_params.start_time, default_params.end_time);
ASSERT_EQ(inst_a->model(), inst_b->model());
}

TEST_F(ForcingsEngineGriddedDataProviderTest, VariableAccess)
Expand Down
16 changes: 9 additions & 7 deletions test/forcing/ForcingsEngineLumpedDataProvider_Test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <gtest/gtest.h>

#include "ForcingsEngineDataProvider.hpp"
#include "ForcingsEngineLumpedDataProvider.hpp"

#include "NGenConfig.h"
Expand Down Expand Up @@ -40,7 +41,7 @@ struct ForcingsEngineLumpedDataProviderTest
static const forcing_params default_params;

static std::shared_ptr<utils::ngenPy::InterpreterUtil> gil_;
static provider_type* provider_;
static std::unique_ptr<provider_type> provider_;
static mpi_info mpi_;
};

Expand All @@ -50,7 +51,7 @@ using TestFixture = ForcingsEngineLumpedDataProviderTest;
/* Static member initialization */
const forcing_params TestFixture::default_params = { "", "ForcingsEngine", "2024-01-17 01:00:00", "2024-01-17 06:00:00" };
std::shared_ptr<utils::ngenPy::InterpreterUtil> TestFixture::gil_ = nullptr;
TestFixture::provider_type* TestFixture::provider_ = nullptr;
std::unique_ptr<TestFixture::provider_type> TestFixture::provider_ = nullptr;
mpi_info TestFixture::mpi_ = {};

// Initialize MPI if available, get Python GIL, and initialize forcings engine.
Expand All @@ -75,7 +76,7 @@ void TestFixture::SetUpTestSuite()
// Destroy providers, GIL, and finalize MPI
void TestFixture::TearDownTestSuite()
{
provider_->finalize_all();
data_access::detail::forcings_engine_instances.clear();
gil_.reset();

#if NGEN_WITH_MPI
Expand All @@ -92,11 +93,12 @@ void TestFixture::TearDownTestSuite()
*/
TEST_F(ForcingsEngineLumpedDataProviderTest, Storage)
{
auto* inst_a = data_access::ForcingsEngineLumpedDataProvider::instance(config_file, default_params.start_time, default_params.end_time);
ASSERT_EQ(inst_a, provider_);
auto inst_a = data_access::ForcingsEngineLumpedDataProvider::make_lumped_instance(config_file, default_params.start_time, default_params.end_time);
ASSERT_EQ(inst_a->model(), provider_->model());


auto* inst_b = data_access::ForcingsEngineLumpedDataProvider::instance(config_file, default_params.start_time, default_params.end_time);
ASSERT_EQ(inst_a, inst_b);
auto inst_b = data_access::ForcingsEngineLumpedDataProvider::make_lumped_instance(config_file, default_params.start_time, default_params.end_time);
ASSERT_EQ(inst_a->model(), inst_b->model());
}

TEST_F(ForcingsEngineLumpedDataProviderTest, Timing)
Expand Down

0 comments on commit f2c7a49

Please sign in to comment.