From 8d44a7916f26dd382d21be31793648ac65a213b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roberto=20Di=20Remigio=20Eik=C3=A5s?= Date: Thu, 13 Jul 2023 09:45:31 +0200 Subject: [PATCH] __add__ and __iadd__ --- src/vampyr/trees/trees.h | 99 +++++++++++++++++++--------------------- 1 file changed, 46 insertions(+), 53 deletions(-) diff --git a/src/vampyr/trees/trees.h b/src/vampyr/trees/trees.h index ce8ed28a..5950455a 100644 --- a/src/vampyr/trees/trees.h +++ b/src/vampyr/trees/trees.h @@ -2,8 +2,8 @@ #include -#include #include +#include #include #include @@ -12,6 +12,18 @@ #include namespace vampyr { +template +auto impl__add__(mrcpp::FunctionTree *inp_a, mrcpp::FunctionTree *inp_b) -> std::unique_ptr> { + using namespace mrcpp; + auto out = std::make_unique>(inp_a->getMRA()); + FunctionTreeVector vec; + vec.push_back({1.0, inp_a}); + vec.push_back({1.0, inp_b}); + build_grid(*out, vec); + add(-1.0, *out, vec); + return out; +}; + template void trees(pybind11::module &m) { using namespace mrcpp; namespace py = pybind11; @@ -30,12 +42,11 @@ template void trees(pybind11::module &m) { py::return_value_policy::reference_internal) .def("rootScale", &MWTree::getRootScale) .def("depth", &MWTree::getDepth) - .def( - "setZero", - [](MWTree *out) { - out->setZero(); - return out; - }) + .def("setZero", + [](MWTree *out) { + out->setZero(); + return out; + }) .def("clear", &MWTree::clear) .def("setName", &MWTree::setName) .def("name", &MWTree::getName) @@ -61,37 +72,34 @@ template void trees(pybind11::module &m) { .def("integrate", &FunctionTree::integrate) .def("quadrature", [](FunctionTree *tree) { + if constexpr (D != 1) { throw std::runtime_error("quadrature only implemented for 1D"); } - if constexpr (D != 1) { - throw std::runtime_error("quadrature only implemented for 1D"); - } + // Current implementation only makes sense in 1D - // Current implementation only makes sense in 1D + std::vector vec_pts; + // Iterate over all end nodes + for (int i = 0; i < tree->getNEndNodes(); i++) { + MWNode &node = tree->getEndMWNode(i); - std::vector vec_pts; - // Iterate over all end nodes - for (int i = 0; i < tree->getNEndNodes(); i++) { - MWNode &node = tree->getEndMWNode(i); + Eigen::MatrixXd pts; + node.getPrimitiveQuadPts(pts); - Eigen::MatrixXd pts; - node.getPrimitiveQuadPts(pts); + // Flatten the MatrixXd and add the points from this node to the vector + vec_pts.insert(vec_pts.end(), pts.data(), pts.data() + pts.size()); + } - // Flatten the MatrixXd and add the points from this node to the vector - vec_pts.insert(vec_pts.end(), pts.data(), pts.data() + pts.size()); - } + // Now we need to create an Eigen vector from our std::vector + Eigen::VectorXd final_pts = + Eigen::Map(vec_pts.data(), vec_pts.size()); - // Now we need to create an Eigen vector from our std::vector - Eigen::VectorXd final_pts = Eigen::Map(vec_pts.data(), vec_pts.size()); - - // Now final_pts holds all the points from all nodes - return final_pts; + // Now final_pts holds all the points from all nodes + return final_pts; + }) + .def("normalize", + [](FunctionTree *out) { + out->normalize(); + return out; }) - .def( - "normalize", - [](FunctionTree *out) { - out->normalize(); - return out; - }) .def( "saveTree", [](FunctionTree &obj, const std::string &filename) { @@ -137,25 +145,10 @@ template void trees(pybind11::module &m) { return out; }, py::is_operator()) - .def( - "__add__", - [](FunctionTree *inp_a, FunctionTree *inp_b) { - auto out = std::make_unique>(inp_a->getMRA()); - FunctionTreeVector vec; - vec.push_back({1.0, inp_a}); - vec.push_back({1.0, inp_b}); - build_grid(*out, vec); - add(-1.0, *out, vec); - return out; - }, - py::is_operator()) + .def("__add__", &impl__add__, py::is_operator()) .def( "__iadd__", - [](FunctionTree *out, FunctionTree *inp) { - refine_grid(*out, *inp); - out->add(1.0, *inp); - return out; - }, + [](FunctionTree *out, FunctionTree *inp) { return impl__add__(out, inp); }, py::is_operator()) .def( "__sub__", @@ -292,11 +285,11 @@ template void trees(pybind11::module &m) { .def("hasParent", &MWNode::hasParent) .def("hasCoefs", &MWNode::hasCoefs) .def("quadrature", - [](MWNode &node) { - Eigen::MatrixXd pts; - node.getPrimitiveQuadPts(pts); - return pts; - }) + [](MWNode &node) { + Eigen::MatrixXd pts; + node.getPrimitiveQuadPts(pts); + return pts; + }) .def("center", &MWNode::getCenter) .def("upperBounds", &MWNode::getUpperBounds) .def("lowerBounds", &MWNode::getLowerBounds)