Skip to content

Commit

Permalink
refactor: redesign instance sharing for forcings engine providers [no…
Browse files Browse the repository at this point in the history
… ci]

Instance map stores BMI adapters rather than pointers to the data
providers. Also modifies factory methods slightly to adjust how
new data providers are created, and how timing is handled for each
provider instance.
  • Loading branch information
program-- committed Jun 10, 2024
1 parent d776e7d commit 214d83a
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 111 deletions.
117 changes: 49 additions & 68 deletions include/forcing/ForcingsEngineDataProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ static constexpr auto forcings_engine_python_class = "NWMv3_Forcing_Engine_BMI_
static constexpr auto forcings_engine_python_classpath = "NextGen_Forcings_Engine.NWMv3_Forcing_Engine_BMI_model";
static constexpr auto default_time_format = "%Y-%m-%d %H:%M:%S";

namespace detail {

struct ForcingsEngineStorage {
using key_type = std::string;
using bmi_type = models::bmi::Bmi_Py_Adapter;
using value_type = std::shared_ptr<bmi_type>;

value_type get(const key_type& key);
void set(const key_type& key, value_type value);
void clear();

private:
//! Instance map of underlying BMI models.
std::unordered_map<key_type, value_type> data_;
};

static ForcingsEngineStorage forcings_engine_instances{};

} // namespace detail

//! Parse time string from format.
//! Utility function for ForcingsEngineLumpedDataProvider constructor.
time_t parse_time(const std::string& time, const std::string& fmt);
Expand Down Expand Up @@ -70,56 +90,14 @@ struct ForcingsEngineDataProvider
return (epoch - time_begin_) / time_step_;
}

// Temporary (?) function to clear out instances of this type.
static void finalize_all() {
instances_.clear();
}

/* Remaining virtual member functions from DataProvider must be implemented
by derived classes. */

data_type get_value(const selection_type& selector, data_access::ReSampleMethod m) override = 0;

std::vector<data_type> get_values(const selection_type& selector, data_access::ReSampleMethod m) override = 0;


/* Friend functions */
static ForcingsEngineDataProvider* instance(
const std::string& init,
const std::string& time_begin,
const std::string& time_end,
const std::string& time_fmt = default_time_format
)
{
auto& inst = instances_.at(init);
if (inst != nullptr) {
assert(inst->time_begin_.time_since_epoch() == std::chrono::seconds{parse_time(time_begin, time_fmt)});
assert(inst->time_end_.time_since_epoch() == std::chrono::seconds{parse_time(time_end, time_fmt)});
}

return inst.get();
}

template<typename Derived>
static ForcingsEngineDataProvider* make_instance(
const std::string& init,
const std::string& time_begin,
const std::string& time_end,
const std::string& time_fmt = default_time_format
)
{
auto time_begin_epoch = static_cast<size_t>(parse_time(time_begin, time_fmt));
auto time_end_epoch = static_cast<size_t>(parse_time(time_end, time_fmt));

auto provider = std::unique_ptr<Derived>{
new Derived{init, time_begin_epoch, time_end_epoch}
};

return set_instance(init, std::move(provider));
}

protected:

// TODO: It may make more sense to have time_begin_seconds and time_end_seconds coalesced into
// a single argument: `clock_type::duration time_duration`, since the forcings engine
// manages time via a duration rather than time points. !! Need to double check
Expand All @@ -134,44 +112,47 @@ struct ForcingsEngineDataProvider

assert_forcings_engine_requirements();

bmi_ = std::make_unique<models::bmi::Bmi_Py_Adapter>(
"ForcingsEngine",
init,
forcings_engine_python_classpath,
/*allow_exceed_end=*/true,
/*has_fixed_time_step=*/true,
utils::getStdOut()
);
bmi_ = detail::forcings_engine_instances.get(init);
if (bmi_ == nullptr) {
bmi_ = std::make_shared<models::bmi::Bmi_Py_Adapter>(
"ForcingsEngine",
init,
forcings_engine_python_classpath,
/*allow_exceed_end=*/true,
/*has_fixed_time_step=*/true,
utils::getStdOut()
);

detail::forcings_engine_instances.set(init, bmi_);
}

time_step_ = std::chrono::seconds{static_cast<int64_t>(bmi_->GetTimeStep())};
time_current_index_ = std::chrono::seconds{static_cast<int64_t>(bmi_->GetCurrentTime())} / time_step_;
var_output_names_ = bmi_->GetOutputVarNames();
}

static ForcingsEngineDataProvider* set_instance(
template<typename Derived>
static std::unique_ptr<ForcingsEngineDataProvider> make_instance(
const std::string& init,
std::unique_ptr<ForcingsEngineDataProvider>&& instance
const std::string& time_begin,
const std::string& time_end,
const std::string& time_fmt = default_time_format
)
{
instances_[init] = std::move(instance);
return instances_[init].get();
};

//! Instance map
//! @note this map will exist for each of the
//! 3 instance types (lumped, gridded, mesh).
static std::unordered_map<
std::string,
std::unique_ptr<ForcingsEngineDataProvider>
> instances_;

// TODO: this, or just push the scope on time members up?
void increment_time()
{
auto time_begin_epoch = parse_time(time_begin, time_fmt);
auto time_end_epoch = parse_time(time_end, time_fmt);
return std::unique_ptr<Derived>{
new Derived{init, time_begin_epoch, time_end_epoch}
};
}

void next() {
bmi_->Update();
time_current_index_++;
}

//! Forcings Engine instance
std::unique_ptr<models::bmi::Bmi_Py_Adapter> bmi_ = nullptr;
std::shared_ptr<models::bmi::Bmi_Py_Adapter> bmi_ = nullptr;

//! Output variable names
std::vector<std::string> var_output_names_{};
Expand Down
2 changes: 1 addition & 1 deletion include/forcing/ForcingsEngineGriddedDataProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct ForcingsEngineGriddedDataProvider

std::vector<Cell> get_values(const GridDataSelector& selector, data_access::ReSampleMethod m) override;

static ForcingsEngineDataProvider* make_gridded_instance(
static std::unique_ptr<ForcingsEngineDataProvider> make_gridded_instance(
const std::string& init,
const std::string& time_start,
const std::string& time_end,
Expand Down
10 changes: 1 addition & 9 deletions include/forcing/ForcingsEngineLumpedDataProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct ForcingsEngineLumpedDataProvider
*/
std::size_t variable_index(const std::string& variable) noexcept;

static ForcingsEngineDataProvider* make_lumped_instance(
static std::unique_ptr<ForcingsEngineDataProvider> make_lumped_instance(
const std::string& init,
const std::string& time_start,
const std::string& time_end,
Expand All @@ -51,14 +51,6 @@ struct ForcingsEngineLumpedDataProvider
std::size_t time_end_seconds
);

/**
* @brief Update to next timestep.
*
* @return true
* @return false
*/
bool next();

/**
* @brief Get a forcing value from the instance
*
Expand Down
11 changes: 10 additions & 1 deletion include/forcing/GridDataSelector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ struct BoundingBox
: box_(std::move(box))
{}

template<typename Geometry>
BoundingBox(const Geometry& geom)
: box_(boost::geometry::return_envelope<box_t>(geom))
{}

double xmin() const noexcept
{
return box_.min_corner().get<0>();
Expand Down Expand Up @@ -58,6 +63,10 @@ struct BoundingBox
return poly;
}

const box_t& as_box() const noexcept {
return box_;
}

private:
box_t box_;
};
Expand Down Expand Up @@ -155,7 +164,7 @@ struct GridDataSelector {
const auto ydiff = static_cast<double>(grid.rows) / (ymax - ymin);
const auto xdiff = static_cast<double>(grid.columns) / (xmax - xmin);

const auto bbox = BoundingBox{ boost::geometry::return_envelope<box_t>(polygon) };
const BoundingBox bbox = { polygon };
for (double row = bbox.ymin(); row < bbox.ymax() - ydiff; row += ydiff) {
for (double col = bbox.xmin(); col < bbox.xmax() - xdiff; row += xdiff) {
const box_t cell_box = {
Expand Down
4 changes: 0 additions & 4 deletions src/forcing/ForcingsEngineGriddedDataProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ namespace data_access {
using BaseProvider = ForcingsEngineDataProvider<Cell, GridDataSelector>;
using Provider = ForcingsEngineGriddedDataProvider;

//! Gridded Forcings Engine instances storage
template<>
std::unordered_map<std::string, std::unique_ptr<BaseProvider>> BaseProvider::instances_{};

Provider::ForcingsEngineGriddedDataProvider(
const std::string& init,
std::size_t time_begin_seconds,
Expand Down
50 changes: 22 additions & 28 deletions src/forcing/ForcingsEngineLumpedDataProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,15 @@ namespace data_access {
using BaseProvider = ForcingsEngineDataProvider<double, CatchmentAggrDataSelector>;
using Provider = ForcingsEngineLumpedDataProvider;

//! Lumped Forcings Engine instances storage
template<>
std::unordered_map<std::string, std::unique_ptr<BaseProvider>> BaseProvider::instances_{};

Provider::ForcingsEngineLumpedDataProvider(
const std::string& init,
std::size_t time_begin_seconds,
std::size_t time_end_seconds
)
: BaseProvider(init, time_begin_seconds, time_end_seconds)
{
const auto current_time = bmi_->GetCurrentTime();
const auto time_step = bmi_->GetTimeStep();

// Check that CAT-ID is an available output name, otherwise we most likely aren't
// running the correct configuration of the forcings engine for this class.
Expand All @@ -29,34 +27,39 @@ Provider::ForcingsEngineLumpedDataProvider(
}
var_output_names_.erase(cat_id_pos);

// Initialize the value cache
const auto id_dim = static_cast<std::size_t>(bmi_->GetVarNbytes("CAT-ID") / bmi_->GetVarItemsize("CAT-ID"));
const auto var_dim = get_available_variable_names().size();

bmi_->Update();
this->increment_time();
if (current_time <= time_step) {
next();
}

// Copy CAT-ID values into instance vector
const auto cat_id = boost::span<const int>(
static_cast<const int*>(bmi_->GetValuePtr("CAT-ID")),
id_dim
);

// temporary map to ensure uniqueness
std::unordered_map<int, int> uniq;
uniq.reserve(id_dim);
var_divides_.reserve(id_dim);
for (int i = 0; i < id_dim; ++i) {
const auto id = cat_id[i];

uniq[id]++;
if (uniq[id] > 1) {
throw std::runtime_error{"Non-unique catchment ID found in lumped forcings engine domain: " + std::to_string(id)};
}

var_divides_[id] = i;
}

if (current_time <= time_step) {
// temporary map to ensure uniqueness
// note: if this instance is not starting from the
// beginning, then this uniqueness check was already
// performed.
std::unordered_map<int, int> uniq;
for (const auto id : cat_id) {
if (uniq[id] > 0) {
throw std::runtime_error{"Non-unique catchment ID found in lumped forcings engine domain: " + std::to_string(id)};
}
uniq[id]++;
}
}

var_cache_ = decltype(var_cache_){{ id_dim, var_dim }};

// Cache initial iteration
Expand Down Expand Up @@ -101,14 +104,6 @@ std::size_t Provider::variable_index(const std::string& variable) noexcept
return std::distance(vars.begin(), pos);
}

bool Provider::next()
{
bmi_->Update();
this->increment_time();
this->update_value_storage_();
return true;
}

void Provider::update_value_storage_()
{
const auto outputs = this->get_available_variable_names();
Expand Down Expand Up @@ -195,10 +190,8 @@ double Provider::get_value(
double acc = 0.0;
auto current_time = start;
while (current_time < end) {
if (!this->next()) {
break;
}

this->next();
this->update_value_storage_();
acc += this->at(divide_index, var_index);
current_time += step;
}
Expand Down Expand Up @@ -234,6 +227,7 @@ std::vector<double> Provider::get_values(
for (auto current_time = start; current_time < end; current_time += step) {
values.push_back(this->at(divide_index, var_index));
this->next();
this->update_value_storage_();
}

return values;
Expand Down
Loading

0 comments on commit 214d83a

Please sign in to comment.