From 135885db5354dddaa5d88accb8d203fe7072ac7a Mon Sep 17 00:00:00 2001 From: Kasper Peeters Date: Fri, 1 Dec 2023 20:10:52 +0000 Subject: [PATCH] Fix a bug in accessing the Weight value in Python. Fix an error with setting the multiplier of a node from Python. --- core/ExNode.cc | 4 ++-- core/Storage.cc | 6 ++++++ core/Storage.hh | 3 ++- core/algorithms/substitute.cc | 6 +++++- core/pythoncdb/py_properties.cc | 9 ++++++++- tests/programming.cdb | 15 +++++++++++++++ 6 files changed, 38 insertions(+), 5 deletions(-) diff --git a/core/ExNode.cc b/core/ExNode.cc index 8ca46838a9..cbec45be64 100644 --- a/core/ExNode.cc +++ b/core/ExNode.cc @@ -325,8 +325,8 @@ void ExNode::set_multiplier(pybind11::object mult) if(!ex->is_valid(it)) throw ConsistencyException("Cannot set the multiplier of an iterator before the first 'next'."); - pybind11::object mpq = pybind11::module::import("gmpy2").attr("mpq"); - multiply(it->multiplier, pybind11::cast(mult)); + set(it->multiplier, multiplier_t(mult.attr("numerator").cast(), + mult.attr("denominator").cast()) ); } diff --git a/core/Storage.cc b/core/Storage.cc index f746a07cda..cc28f0c279 100644 --- a/core/Storage.cc +++ b/core/Storage.cc @@ -1014,6 +1014,12 @@ namespace cadabra { num=rat_set.insert(fac).first; } + void set(rset_t::iterator& num, multiplier_t fac) + { + fac.canonicalize(); + num=rat_set.insert(fac).first; + } + void add(rset_t::iterator& num, multiplier_t fac) { fac+=*num; diff --git a/core/Storage.hh b/core/Storage.hh index 40757b238a..923e99f0f7 100644 --- a/core/Storage.hh +++ b/core/Storage.hh @@ -129,7 +129,8 @@ namespace cadabra { void one(rset_t::iterator&); void flip_sign(rset_t::iterator&); void half(rset_t::iterator&); - + void set(rset_t::iterator&, multiplier_t); + /// \ingroup core /// /// Basic storage class for symbolic mathemematical expressions. The diff --git a/core/algorithms/substitute.cc b/core/algorithms/substitute.cc index 0c668d9287..6a4f35c676 100644 --- a/core/algorithms/substitute.cc +++ b/core/algorithms/substitute.cc @@ -18,7 +18,9 @@ substitute::substitute(const Kernel& k, Ex& tr, Ex& args_, bool partial) { if(args.is_empty()) throw ArgumentException("substitute: Replacement rule is an empty expression."); - + + Stopwatch sw; + sw.start(); cadabra::do_list(args, args.begin(), [&](Ex::iterator arrow) { //args.print_recursive_treeform(std::cerr, arrow); if(*arrow->name!="\\arrow" && *arrow->name!="\\equals") @@ -81,6 +83,8 @@ substitute::substitute(const Kernel& k, Ex& tr, Ex& args_, bool partial) } return true; }); + sw.stop(); + std::cerr << "preparation took " << sw << std::endl; } bool substitute::can_apply(iterator st) diff --git a/core/pythoncdb/py_properties.cc b/core/pythoncdb/py_properties.cc index 16e5305caa..2d64aaf24a 100644 --- a/core/pythoncdb/py_properties.cc +++ b/core/pythoncdb/py_properties.cc @@ -360,7 +360,14 @@ namespace cadabra { def_abstract_prop(m, "DependsBase") .def("dependencies", [](const Py_DependsBase & p) { return p.get_prop()->dependencies(p.get_kernel(), p.get_it()); }); def_abstract_prop(m, "WeightBase") - .def("value", [](const Py_WeightBase & p, const std::string& forcedLabel) { return p.get_prop()->value(p.get_kernel(), p.get_it(), forcedLabel); }); + .def("value", [](const Py_WeightBase & p, const std::string& forcedLabel) { + // This is mpq_class, convert to the Python equivalent. + pybind11::object mpq = pybind11::module::import("gmpy2").attr("mpq"); + auto m = p.get_prop()->value(p.get_kernel(), p.get_it(), forcedLabel); + pybind11::object mult = mpq(m.get_num().get_si(), m.get_den().get_si()); + return mult; + }); + def_abstract_prop(m, "DifferentialFormBase") .def("degree", [](const Py_DifferentialFormBase & p) { return p.get_prop()->degree(p.get_props(), p.get_it()); }); diff --git a/tests/programming.cdb b/tests/programming.cdb index 2c814bbb20..ea5a3226bf 100644 --- a/tests/programming.cdb +++ b/tests/programming.cdb @@ -341,5 +341,20 @@ def test16(): {i,j,k}::Indices(isospin, position=independent). assert($\Lambda_{a}$.matches($\Lambda_{i}$)==False) assert($\Lambda_{a}$.matches($\Lambda^{i}$)==False) + print("Test 16 passed") test16() + +def test17(): + x::Weight(value=42, label=field); + tst1 = Weight.get($x$, label="field").value("field") + assert(tst1==42) + print("Test 17a passed") + ex:= 3 a; + ex.top().multiplier = tst1 + tst2:= 42 a - @(ex); + assert(tst2==0) + print("Test 17b passed") + +test17() +