From f2c7a49113e83783eeba8c87015ab24766032196 Mon Sep 17 00:00:00 2001 From: program-- Date: Wed, 12 Jun 2024 14:37:17 -0700 Subject: [PATCH] refactor: fix tests for instance sharing refactor [no ci] --- .../forcing/ForcingsEngineDataProvider.hpp | 31 +++++++++++++++-- .../ForcingsEngineGriddedDataProvider.hpp | 14 +++++--- .../ForcingsEngineLumpedDataProvider.hpp | 14 +++++--- src/forcing/CMakeLists.txt | 1 + .../ForcingsEngineGriddedDataProvider.cpp | 34 +++++++++++-------- .../ForcingsEngineLumpedDataProvider.cpp | 4 +-- test/CMakeLists.txt | 1 + ...ForcingsEngineGriddedDataProvider_Test.cpp | 15 ++++---- .../ForcingsEngineLumpedDataProvider_Test.cpp | 16 +++++---- 9 files changed, 88 insertions(+), 42 deletions(-) diff --git a/include/forcing/ForcingsEngineDataProvider.hpp b/include/forcing/ForcingsEngineDataProvider.hpp index 60529002dc..aa1706af8a 100644 --- a/include/forcing/ForcingsEngineDataProvider.hpp +++ b/include/forcing/ForcingsEngineDataProvider.hpp @@ -16,14 +16,25 @@ static constexpr auto default_time_format = "%Y-%m-%d %H:%M:%S"; namespace detail { +//! Storage for Forcings Engine-specific BMI instances. struct ForcingsEngineStorage { using key_type = std::string; using bmi_type = models::bmi::Bmi_Py_Adapter; using value_type = std::shared_ptr; - value_type get(const key_type& key); - void set(const key_type& key, value_type value); - void clear(); + value_type get(const key_type& key) + { + auto pos = data_.find(key); + if (pos == data_.end()) { + return nullptr; + } + + return pos->second; + } + + void set(const key_type& key, value_type value) { data_[key] = value; } + + void clear() { data_.clear(); } private: //! Instance map of underlying BMI models. @@ -90,6 +101,11 @@ struct ForcingsEngineDataProvider return (epoch - time_begin_) / time_step_; } + std::shared_ptr model() noexcept + { + return bmi_; + } + /* Remaining virtual member functions from DataProvider must be implemented by derived classes. */ @@ -151,6 +167,15 @@ struct ForcingsEngineDataProvider time_current_index_++; } + void next(double time) { + const auto start = bmi_->GetCurrentTime(); + bmi_->UpdateUntil(time); + const auto end = bmi_->GetCurrentTime(); + time_current_index_ += static_cast( + (end - start) / bmi_->GetTimeStep() + ); + } + //! Forcings Engine instance std::shared_ptr bmi_ = nullptr; diff --git a/include/forcing/ForcingsEngineGriddedDataProvider.hpp b/include/forcing/ForcingsEngineGriddedDataProvider.hpp index e0b775ec12..2002b7f8bc 100644 --- a/include/forcing/ForcingsEngineGriddedDataProvider.hpp +++ b/include/forcing/ForcingsEngineGriddedDataProvider.hpp @@ -10,11 +10,15 @@ namespace data_access { struct ForcingsEngineGriddedDataProvider : public ForcingsEngineDataProvider { + using data_type = data_type; + using selection_type = selection_type; + using base_type = ForcingsEngineDataProvider; + ~ForcingsEngineGriddedDataProvider() override = default; - Cell get_value(const GridDataSelector& selector, data_access::ReSampleMethod m) override; + data_type get_value(const selection_type& selector, data_access::ReSampleMethod m) override; - std::vector get_values(const GridDataSelector& selector, data_access::ReSampleMethod m) override; + std::vector get_values(const selection_type& selector, data_access::ReSampleMethod m) override; static std::unique_ptr make_gridded_instance( const std::string& init, @@ -27,10 +31,12 @@ struct ForcingsEngineGriddedDataProvider } private: + friend base_type; + ForcingsEngineGriddedDataProvider( const std::string& init, - std::size_t time_begin_seconds, - std::size_t time_end_seconds + std::time_t time_begin_seconds, + std::time_t time_end_seconds ); int var_grid_id_{}; diff --git a/include/forcing/ForcingsEngineLumpedDataProvider.hpp b/include/forcing/ForcingsEngineLumpedDataProvider.hpp index 71940ec432..8c33051beb 100644 --- a/include/forcing/ForcingsEngineLumpedDataProvider.hpp +++ b/include/forcing/ForcingsEngineLumpedDataProvider.hpp @@ -10,13 +10,17 @@ namespace data_access { struct ForcingsEngineLumpedDataProvider : public ForcingsEngineDataProvider { + using data_type = data_type; + using selection_type = selection_type; + using base_type = ForcingsEngineDataProvider; + static constexpr auto bad_index = static_cast(-1); ~ForcingsEngineLumpedDataProvider() override = default; - double get_value(const CatchmentAggrDataSelector& selector, data_access::ReSampleMethod m) override; + data_type get_value(const selection_type& selector, data_access::ReSampleMethod m) override; - std::vector get_values(const CatchmentAggrDataSelector& selector, data_access::ReSampleMethod m) override; + std::vector get_values(const selection_type& selector, data_access::ReSampleMethod m) override; /** * @brief Get the index in `CAT-ID` for a given divide in the instance cache. @@ -45,10 +49,12 @@ struct ForcingsEngineLumpedDataProvider } private: + friend base_type; + ForcingsEngineLumpedDataProvider( const std::string& init, - std::size_t time_begin_seconds, - std::size_t time_end_seconds + std::time_t time_begin_seconds, + std::time_t time_end_seconds ); /** diff --git a/src/forcing/CMakeLists.txt b/src/forcing/CMakeLists.txt index d374032b2e..79b8475dd2 100644 --- a/src/forcing/CMakeLists.txt +++ b/src/forcing/CMakeLists.txt @@ -32,6 +32,7 @@ if(NGEN_WITH_PYTHON) PRIVATE "${CMAKE_CURRENT_LIST_DIR}/ForcingsEngineDataProvider.cpp" "${CMAKE_CURRENT_LIST_DIR}/ForcingsEngineLumpedDataProvider.cpp" + "${CMAKE_CURRENT_LIST_DIR}/ForcingsEngineGriddedDataProvider.cpp" ) target_link_libraries(forcing PUBLIC pybind11::embed NGen::ngen_bmi) endif() diff --git a/src/forcing/ForcingsEngineGriddedDataProvider.cpp b/src/forcing/ForcingsEngineGriddedDataProvider.cpp index dd9d2bcd82..e7bef70ba1 100644 --- a/src/forcing/ForcingsEngineGriddedDataProvider.cpp +++ b/src/forcing/ForcingsEngineGriddedDataProvider.cpp @@ -4,13 +4,13 @@ namespace data_access { -using BaseProvider = ForcingsEngineDataProvider; using Provider = ForcingsEngineGriddedDataProvider; +using BaseProvider = Provider::base_type; Provider::ForcingsEngineGriddedDataProvider( const std::string& init, - std::size_t time_begin_seconds, - std::size_t time_end_seconds + std::time_t time_begin_seconds, + std::time_t time_end_seconds ) : BaseProvider(init, time_begin_seconds, time_end_seconds) { @@ -46,6 +46,11 @@ Provider::ForcingsEngineGriddedDataProvider( } Cell Provider::get_value(const GridDataSelector& selector, data_access::ReSampleMethod m) +{ + throw std::runtime_error{"ForcingsEngineGriddedDataProvider::get_value() is not implemented"}; +} + +std::vector Provider::get_values(const GridDataSelector& selector, data_access::ReSampleMethod m) { if (m != ReSampleMethod::SUM && m != ReSampleMethod::MEAN) { throw std::runtime_error{"Given ReSampleMethod " + std::to_string(m) + " not implemented."}; @@ -54,33 +59,32 @@ Cell Provider::get_value(const GridDataSelector& selector, data_access::ReSample const auto start = clock_type::from_time_t(selector.initial_time()); const auto end = std::chrono::seconds{selector.duration()} + start; const auto step = std::chrono::seconds{record_duration()}; - - auto cell = selector.cells()[0]; // FIXME: bad semantics, + std::vector cells = { selector.cells().begin(), selector.cells().end() }; for (auto current = start; current < end; current += step) { - bmi_->UpdateUntil(current.time_since_epoch().count()); - + this->next(current.time_since_epoch().count()); + boost::span values{ static_cast(bmi_->GetValuePtr(selector.variable())), var_grid_.rows * var_grid_.columns }; - cell.value += values[cell.x + cell.y * var_grid_.rows]; + for (auto& cell : cells) { + cell.value += values[cell.x + cell.y * var_grid_.rows]; + } } if (m == ReSampleMethod::MEAN) { const auto time_step_seconds = step.count(); const auto time_duration = std::chrono::duration_cast(end - start).count(); const auto num_time_steps = time_duration / time_step_seconds; - cell.value /= num_time_steps; - } - return cell; -} + for (auto& cell : cells) { + cell.value /= num_time_steps; + } + } -std::vector Provider::get_values(const GridDataSelector& selector, data_access::ReSampleMethod m) -{ - + return cells; } diff --git a/src/forcing/ForcingsEngineLumpedDataProvider.cpp b/src/forcing/ForcingsEngineLumpedDataProvider.cpp index 9d3fd32b59..2be0966203 100644 --- a/src/forcing/ForcingsEngineLumpedDataProvider.cpp +++ b/src/forcing/ForcingsEngineLumpedDataProvider.cpp @@ -8,8 +8,8 @@ using Provider = ForcingsEngineLumpedDataProvider; Provider::ForcingsEngineLumpedDataProvider( const std::string& init, - std::size_t time_begin_seconds, - std::size_t time_end_seconds + std::time_t time_begin_seconds, + std::time_t time_end_seconds ) : BaseProvider(init, time_begin_seconds, time_end_seconds) { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 376aab7f15..632c5c16cc 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -439,6 +439,7 @@ ngen_add_test( test_forcings_engine OBJECTS forcing/ForcingsEngineLumpedDataProvider_Test.cpp + forcing/ForcingsEngineGriddedDataProvider_Test.cpp LIBRARIES NGen::forcing REQUIRES diff --git a/test/forcing/ForcingsEngineGriddedDataProvider_Test.cpp b/test/forcing/ForcingsEngineGriddedDataProvider_Test.cpp index fd87195b9e..2b6d39ab73 100644 --- a/test/forcing/ForcingsEngineGriddedDataProvider_Test.cpp +++ b/test/forcing/ForcingsEngineGriddedDataProvider_Test.cpp @@ -39,13 +39,14 @@ struct ForcingsEngineGriddedDataProviderTest static const forcing_params default_params; static std::shared_ptr gil_; - static provider_type* provider_; + static std::unique_ptr provider_; static mpi_info mpi_; }; using TestFixture = ForcingsEngineGriddedDataProviderTest; +const forcing_params TestFixture::default_params = { "", "ForcingsEngine", "2024-01-17 01:00:00", "2024-01-17 06:00:00" }; std::shared_ptr TestFixture::gil_ = nullptr; -TestFixture::provider_type* TestFixture::provider_ = nullptr; +std::unique_ptr TestFixture::provider_ = nullptr; mpi_info TestFixture::mpi_ = {}; void TestFixture::SetUpTestSuite() @@ -67,7 +68,7 @@ void TestFixture::SetUpTestSuite() void TestFixture::TearDownTestSuite() { - provider_->finalize_all(); + data_access::detail::forcings_engine_instances.clear(); gil_.reset(); #if NGEN_WITH_MPI @@ -82,11 +83,11 @@ void TestFixture::TearDownTestSuite() */ TEST_F(ForcingsEngineGriddedDataProviderTest, Storage) { - auto* inst_a = data_access::ForcingsEngineGriddedDataProvider::instance(config_file, default_params.start_time, default_params.end_time); - ASSERT_EQ(inst_a, provider_); + auto inst_a = data_access::ForcingsEngineGriddedDataProvider::make_gridded_instance(config_file, default_params.start_time, default_params.end_time); + ASSERT_EQ(inst_a->model(), provider_->model()); - auto* inst_b = data_access::ForcingsEngineGriddedDataProvider::instance(config_file, default_params.start_time, default_params.end_time); - ASSERT_EQ(inst_a, inst_b); + auto inst_b = data_access::ForcingsEngineGriddedDataProvider::make_gridded_instance(config_file, default_params.start_time, default_params.end_time); + ASSERT_EQ(inst_a->model(), inst_b->model()); } TEST_F(ForcingsEngineGriddedDataProviderTest, VariableAccess) diff --git a/test/forcing/ForcingsEngineLumpedDataProvider_Test.cpp b/test/forcing/ForcingsEngineLumpedDataProvider_Test.cpp index fb8e03859b..e355f0676a 100644 --- a/test/forcing/ForcingsEngineLumpedDataProvider_Test.cpp +++ b/test/forcing/ForcingsEngineLumpedDataProvider_Test.cpp @@ -1,5 +1,6 @@ #include +#include "ForcingsEngineDataProvider.hpp" #include "ForcingsEngineLumpedDataProvider.hpp" #include "NGenConfig.h" @@ -40,7 +41,7 @@ struct ForcingsEngineLumpedDataProviderTest static const forcing_params default_params; static std::shared_ptr gil_; - static provider_type* provider_; + static std::unique_ptr provider_; static mpi_info mpi_; }; @@ -50,7 +51,7 @@ using TestFixture = ForcingsEngineLumpedDataProviderTest; /* Static member initialization */ const forcing_params TestFixture::default_params = { "", "ForcingsEngine", "2024-01-17 01:00:00", "2024-01-17 06:00:00" }; std::shared_ptr TestFixture::gil_ = nullptr; -TestFixture::provider_type* TestFixture::provider_ = nullptr; +std::unique_ptr TestFixture::provider_ = nullptr; mpi_info TestFixture::mpi_ = {}; // Initialize MPI if available, get Python GIL, and initialize forcings engine. @@ -75,7 +76,7 @@ void TestFixture::SetUpTestSuite() // Destroy providers, GIL, and finalize MPI void TestFixture::TearDownTestSuite() { - provider_->finalize_all(); + data_access::detail::forcings_engine_instances.clear(); gil_.reset(); #if NGEN_WITH_MPI @@ -92,11 +93,12 @@ void TestFixture::TearDownTestSuite() */ TEST_F(ForcingsEngineLumpedDataProviderTest, Storage) { - auto* inst_a = data_access::ForcingsEngineLumpedDataProvider::instance(config_file, default_params.start_time, default_params.end_time); - ASSERT_EQ(inst_a, provider_); + auto inst_a = data_access::ForcingsEngineLumpedDataProvider::make_lumped_instance(config_file, default_params.start_time, default_params.end_time); + ASSERT_EQ(inst_a->model(), provider_->model()); + - auto* inst_b = data_access::ForcingsEngineLumpedDataProvider::instance(config_file, default_params.start_time, default_params.end_time); - ASSERT_EQ(inst_a, inst_b); + auto inst_b = data_access::ForcingsEngineLumpedDataProvider::make_lumped_instance(config_file, default_params.start_time, default_params.end_time); + ASSERT_EQ(inst_a->model(), inst_b->model()); } TEST_F(ForcingsEngineLumpedDataProviderTest, Timing)