Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
* Fixes rdkit#7675

* update expected psql results
these changed because of the version bump to the pickles
  • Loading branch information
greglandrum authored Aug 2, 2024
1 parent 86f197c commit 1790de8
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 64 deletions.
93 changes: 65 additions & 28 deletions Code/GraphMol/ChemReactions/catch_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ TEST_CASE("Github #1632", "[Reaction][PDB][bug]") {
std::unique_ptr<RWMol> mol(SequenceToMol("K", sanitize, flavor));
REQUIRE(mol);
REQUIRE(mol->getAtomWithIdx(0)->getMonomerInfo());
auto res = static_cast<AtomPDBResidueInfo*>(
auto res = static_cast<AtomPDBResidueInfo *>(
mol->getAtomWithIdx(0)->getMonomerInfo());
CHECK(res->getResidueNumber() == 1);
std::unique_ptr<ChemicalReaction> rxn(RxnSmartsToChemicalReaction(
Expand All @@ -53,15 +53,15 @@ TEST_CASE("Github #1632", "[Reaction][PDB][bug]") {
auto p = prods[0][0];
CHECK(p->getNumAtoms() == mol->getNumAtoms() + 1);
REQUIRE(p->getAtomWithIdx(0)->getMonomerInfo());
auto pres = static_cast<AtomPDBResidueInfo*>(
auto pres = static_cast<AtomPDBResidueInfo *>(
p->getAtomWithIdx(0)->getMonomerInfo());
CHECK(pres->getResidueNumber() == 1);
REQUIRE(!p->getAtomWithIdx(4)->getMonomerInfo());
}
}

static void clearAtomMappingProps(ROMol& mol) {
for (auto&& a : mol.atoms()) {
static void clearAtomMappingProps(ROMol &mol) {
for (auto &&a : mol.atoms()) {
a->clear();
}
}
Expand Down Expand Up @@ -186,7 +186,7 @@ TEST_CASE("negative charge queries. Part of testing changes for github #2604",
// we don't have a way to directly create NegativeFormalCharge queries, so
// make one by hand
REQUIRE(rxn->getProducts()[0]->getAtomWithIdx(0)->hasQuery());
static_cast<QueryAtom*>(rxn->getProducts()[0]->getAtomWithIdx(0))
static_cast<QueryAtom *>(rxn->getProducts()[0]->getAtomWithIdx(0))
->expandQuery(makeAtomNegativeFormalChargeQuery(1));
unsigned nWarnings = 0;
unsigned nErrors = 0;
Expand All @@ -202,7 +202,7 @@ TEST_CASE("negative charge queries. Part of testing changes for github #2604",
// we don't have a way to directly create NegativeFormalCharge queries, so
// make one by hand
REQUIRE(rxn->getProducts()[0]->getAtomWithIdx(0)->hasQuery());
static_cast<QueryAtom*>(rxn->getProducts()[0]->getAtomWithIdx(0))
static_cast<QueryAtom *>(rxn->getProducts()[0]->getAtomWithIdx(0))
->expandQuery(makeAtomNegativeFormalChargeQuery(
-1)); // a bit kludgy, but we need to check
unsigned nWarnings = 0;
Expand All @@ -219,7 +219,7 @@ TEST_CASE("negative charge queries. Part of testing changes for github #2604",
// we don't have a way to directly create NegativeFormalCharge queries, so
// make one by hand
REQUIRE(rxn->getProducts()[0]->getAtomWithIdx(0)->hasQuery());
static_cast<QueryAtom*>(rxn->getProducts()[0]->getAtomWithIdx(0))
static_cast<QueryAtom *>(rxn->getProducts()[0]->getAtomWithIdx(0))
->expandQuery(makeAtomNegativeFormalChargeQuery(2));
unsigned nWarnings = 0;
unsigned nErrors = 0;
Expand Down Expand Up @@ -295,7 +295,7 @@ TEST_CASE("reaction data in PNGs 1", "[Reaction][PNG]") {
metadata = PNGStringToMetadata(pngData);
auto iter =
std::find_if(metadata.begin(), metadata.end(),
[](const std::pair<std::string, std::string>& val) {
[](const std::pair<std::string, std::string> &val) {
return val.first == PNGData::rxnSmartsTag;
});
REQUIRE(iter != metadata.end());
Expand Down Expand Up @@ -403,7 +403,7 @@ TEST_CASE("Github #2891", "[Reaction][chirality][bug]") {
{"[C:4][C@:2]([F:1])([Cl])[Br:3]>>[C:4][C@:2]([F:1])[S:3]", 1},
{"[C:4][C@@:2]([F:1])([Cl])[Br:3]>>[C:4][C@:2]([F:1])[S:3]", 2},
};
for (const auto& pr : tests) {
for (const auto &pr : tests) {
std::unique_ptr<ChemicalReaction> rxn(
RxnSmartsToChemicalReaction(pr.first));
REQUIRE(rxn);
Expand Down Expand Up @@ -1287,14 +1287,14 @@ TEST_CASE("CDXML Parser") {
CHECK(rxns.size() == 1);
unsigned int i = 0;
int count = 0;
for (auto& mol : rxns[0]->getReactants()) {
for (auto &mol : rxns[0]->getReactants()) {
CHECK(mol->getProp<unsigned int>("CDX_SCHEME_ID") == 397);
CHECK(mol->getProp<unsigned int>("CDX_STEP_ID") == 398);
CHECK(mol->getProp<unsigned int>("CDX_REAGENT_ID") == i++);
CHECK(MolToSmiles(*mol) == expected[count++]);
}
i = 0;
for (auto& mol : rxns[0]->getProducts()) {
for (auto &mol : rxns[0]->getProducts()) {
CHECK(mol->getProp<unsigned int>("CDX_SCHEME_ID") == 397);
CHECK(mol->getProp<unsigned int>("CDX_STEP_ID") == 398);
CHECK(mol->getProp<unsigned int>("CDX_PRODUCT_ID") == i++);
Expand All @@ -1308,21 +1308,23 @@ TEST_CASE("CDXML Parser") {
}

SECTION("Github #7528 CDXML Grouped Agents in Reactions") {
// The failing case had fragments grouped with labels, ensure the grouped cersion and the ungrouped
// versions have the same results
// The failing case had fragments grouped with labels, ensure the grouped
// cersion and the ungrouped versions have the same results
auto fname = cdxmlbase + "github7467-grouped-fragments.cdxml";
auto rxns = CDXMLFileToChemicalReactions(fname);
CHECK(rxns.size() == 1);
fname = cdxmlbase + "github7467-ungrouped-fragments.cdxml";
auto rxns2 = CDXMLFileToChemicalReactions(fname);

CHECK(ChemicalReactionToRxnSmarts(*rxns[0]) == ChemicalReactionToRxnSmarts(*rxns2[0]));
CHECK(ChemicalReactionToRxnSmarts(*rxns[0]) ==
ChemicalReactionToRxnSmarts(*rxns2[0]));

// Check to see if our understanding of grouped reagents in reactions is correct
// Check to see if our understanding of grouped reagents in reactions is
// correct
fname = cdxmlbase + "reaction-with-grouped-templates.cdxml";
auto rxns3 = CDXMLFileToChemicalReactions(fname);
CHECK(rxns3.size() == 1);
std::string rxnb = R"RXN($RXN
std::string rxnb = R"RXN($RXN
Mrv2004 062120241319
Expand Down Expand Up @@ -1375,13 +1377,15 @@ M END
std::unique_ptr<ChemicalReaction> rxn_mb{RxnBlockToChemicalReaction(rxnb)};
// CDXMLToReaction is sanitized by default, this might be a mistake...
unsigned int failed;
RxnOps::sanitizeRxn(
*rxn_mb, failed,
RxnOps::SANITIZE_ADJUST_REACTANTS | RxnOps::SANITIZE_ADJUST_PRODUCTS,
RxnOps::MatchOnlyAtRgroupsAdjustParams());
RxnOps::sanitizeRxn(
*rxn_mb, failed,
RxnOps::SANITIZE_ADJUST_REACTANTS | RxnOps::SANITIZE_ADJUST_PRODUCTS,
RxnOps::MatchOnlyAtRgroupsAdjustParams());

CHECK(rxns3[0]->getNumReactantTemplates() == rxn_mb->getNumReactantTemplates());
CHECK(ChemicalReactionToRxnSmarts(*rxns3[0]) == ChemicalReactionToRxnSmarts(*rxn_mb));
CHECK(rxns3[0]->getNumReactantTemplates() ==
rxn_mb->getNumReactantTemplates());
CHECK(ChemicalReactionToRxnSmarts(*rxns3[0]) ==
ChemicalReactionToRxnSmarts(*rxn_mb));
}
}

Expand Down Expand Up @@ -1587,7 +1591,7 @@ TEST_CASE("Github #6211: substructmatchparams for chemical reactions") {
{"CC[C@H](N)O", "CC[C@@H](N)O"},
{"CC[C@@H](N)O", "CC[C@H](N)O"},
{"CCC(N)O", "CCC(N)O"}};
for (const auto& [inSmi, outSmi] : data) {
for (const auto &[inSmi, outSmi] : data) {
INFO(inSmi);
MOL_SPTR_VECT reacts = {ROMOL_SPTR(SmilesToMol(inSmi))};
REQUIRE(reacts[0]);
Expand All @@ -1612,7 +1616,7 @@ TEST_CASE("Github #6211: substructmatchparams for chemical reactions") {
{"CC[C@@H](N)O", ""},
{"CCC(N)O", ""}};
rxn->getSubstructParams().useChirality = true;
for (const auto& [inSmi, outSmi] : data) {
for (const auto &[inSmi, outSmi] : data) {
INFO(inSmi);
MOL_SPTR_VECT reacts = {ROMOL_SPTR(SmilesToMol(inSmi))};
REQUIRE(reacts[0]);
Expand All @@ -1631,7 +1635,7 @@ TEST_CASE("Github #6211: substructmatchparams for chemical reactions") {
}
// make sure the parameters are copied
ChemicalReaction cpy(*rxn);
for (const auto& [inSmi, outSmi] : data) {
for (const auto &[inSmi, outSmi] : data) {
INFO(inSmi);
MOL_SPTR_VECT reacts = {ROMOL_SPTR(SmilesToMol(inSmi))};
REQUIRE(reacts[0]);
Expand Down Expand Up @@ -1771,7 +1775,7 @@ TEST_CASE("Github #7028: Spacing bug in compute2DCoordsForReaction") {
REQUIRE(rxn);
RDDepict::compute2DCoordsForReaction(*rxn);
std::vector<std::pair<double, double>> xbounds;
for (const auto& reactant : rxn->getReactants()) {
for (const auto &reactant : rxn->getReactants()) {
REQUIRE(reactant->getNumConformers() == 1);
std::pair<double, double> bounds = {1e8, -1e8};
auto conf = reactant->getConformer();
Expand All @@ -1782,7 +1786,7 @@ TEST_CASE("Github #7028: Spacing bug in compute2DCoordsForReaction") {
}
xbounds.push_back(bounds);
}
for (const auto& product : rxn->getProducts()) {
for (const auto &product : rxn->getProducts()) {
REQUIRE(product->getNumConformers() == 1);
std::pair<double, double> bounds = {1e8, -1e8};
auto conf = product->getConformer();
Expand Down Expand Up @@ -1945,4 +1949,37 @@ M END)RXN";
CHECK(rxn2.getReactants()[0]->getAtomWithIdx(0)->hasProp("molFileValue"));
}
MolPickler::setDefaultPickleProperties(pklOpts);
}
}

TEST_CASE("Github #7675: pickling fails with a HasProp query") {
SECTION("as reported") {
auto rxnb = R"RXN($RXN
Mrv17183 050301241900
1 1
$MOL
Mrv1718305032419002D
1 0 0 0 0 0 999 V2000
3.1458 -0.1208 0.0000 N 0 0 0 0 0 0 0 0 0 0 0 0
V 1 Amine.Cyclic
M END
$MOL
Mrv1718305032419002D
1 0 0 0 0 0 999 V2000
3.1458 -0.1208 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
M END)RXN";
auto rxn = v2::ReactionParser::ReactionFromRxnBlock(rxnb);
REQUIRE(rxn);

std::string pkl;
ReactionPickler::pickleReaction(*rxn, pkl, PicklerOps::AllProps);
ChemicalReaction rxn2;
ReactionPickler::reactionFromPickle(pkl, rxn2);
CHECK(rxn2.getReactants()[0]->getAtomWithIdx(0)->hasProp("molFileValue"));
}
}
23 changes: 22 additions & 1 deletion Code/GraphMol/MolPickler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ using std::uint32_t;
namespace RDKit {

const int32_t MolPickler::versionMajor = 16;
const int32_t MolPickler::versionMinor = 1;
const int32_t MolPickler::versionMinor = 2;
const int32_t MolPickler::versionPatch = 0;
const int32_t MolPickler::endianId = 0xDEADBEEF;

Expand Down Expand Up @@ -359,6 +359,10 @@ QueryDetails getQueryDetails(const Query<int, T const *, true> *query) {
->getVal(),
static_cast<const EqualityQuery<int, T const *, true> *>(query)
->getTol()));
} else if (typeid(*query) == typeid(HasPropQuery<T const *>)) {
return QueryDetails(std::make_tuple(
MolPickler::QUERY_PROPERTY,
static_cast<const HasPropQuery<T const *> *>(query)->getPropName()));
} else if (typeid(*query) == typeid(Query<int, T const *, true>)) {
return QueryDetails(MolPickler::QUERY_NULL);
} else if (typeid(*query) == typeid(RangeQuery<int, T const *, true>)) {
Expand Down Expand Up @@ -450,6 +454,13 @@ void pickleQuery(std::ostream &ss, const Query<int, T const *, true> *query) {
}

} break;
case 5: {
auto v =
boost::get<std::tuple<MolPickler::Tags, std::string>>(qdetails);
streamWrite(ss, std::get<0>(v));
const auto &pval = std::get<1>(v);
streamWrite(ss, MolPickler::QUERY_VALUE, pval);
} break;
default:
throw MolPicklerException(
"do not know how to pickle part of the query.");
Expand Down Expand Up @@ -588,6 +599,16 @@ Query<int, T const *, true> *buildBaseQuery(std::istream &ss, T const *owner,
case MolPickler::QUERY_NULL:
res = new Query<int, T const *, true>();
break;
case MolPickler::QUERY_PROPERTY: {
streamRead(ss, tag, version);
if (tag != MolPickler::QUERY_VALUE) {
throw MolPicklerException(
"Bad pickle format: QUERY_VALUE tag not found.");
}
std::string propName = "";
streamRead(ss, propName, version);
res = makeHasPropQuery<T>(propName);
} break;
default:
throw MolPicklerException("unknown query-type tag encountered");
}
Expand Down
6 changes: 5 additions & 1 deletion Code/GraphMol/MolPickler.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class RDKIT_GRAPHMOL_EXPORT MolPickler {
BEGINSYMMSSSR,
BEGINFASTFIND,
BEGINFINDOTHERORUNKNOWN,
QUERY_PROPERTY,
// add new entries above here
INVALID_TAG = 255
} Tags;
Expand Down Expand Up @@ -305,11 +306,14 @@ class RDKIT_GRAPHMOL_EXPORT MolPickler {
};

namespace PicklerOps {
// clang-format off
using QueryDetails = boost::variant<
MolPickler::Tags, std::tuple<MolPickler::Tags, int32_t>,
std::tuple<MolPickler::Tags, int32_t, int32_t>,
std::tuple<MolPickler::Tags, int32_t, int32_t, int32_t, char>,
std::tuple<MolPickler::Tags, std::set<int32_t>>>;
std::tuple<MolPickler::Tags, std::set<int32_t>>,
std::tuple<MolPickler::Tags, std::string>>;
// clang-format on
template <class T>
QueryDetails getQueryDetails(const Queries::Query<int, T const *, true> *query);

Expand Down
6 changes: 4 additions & 2 deletions Code/GraphMol/QueryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,8 @@ void finalizeQueryFromDescription(
} else if (descr == "AtomInNRings" || descr == "RecursiveStructure") {
// don't need to do anything here because the classes
// automatically have everything set
} else if (descr == "AtomAnd" || descr == "AtomOr" || descr == "AtomXor") {
} else if (descr == "AtomAnd" || descr == "AtomOr" || descr == "AtomXor" ||
descr == "HasProp") {
// don't need to do anything here because the classes
// automatically have everything set
} else {
Expand Down Expand Up @@ -1160,7 +1161,8 @@ void finalizeQueryFromDescription(
} else if (descr == "BondNull") {
query->setDataFunc(nullDataFun);
query->setMatchFunc(nullQueryFun);
} else if (descr == "BondAnd" || descr == "BondOr" || descr == "BondXor") {
} else if (descr == "BondAnd" || descr == "BondOr" || descr == "BondXor" ||
descr == "HasProp") {
// don't need to do anything here because the classes
// automatically have everything set
} else {
Expand Down
10 changes: 5 additions & 5 deletions Code/GraphMol/QueryOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -822,14 +822,12 @@ class HasPropQuery : public Queries::EqualityQuery<int, TargetPtr, true> {

public:
HasPropQuery() : Queries::EqualityQuery<int, TargetPtr, true>(), propname() {
// default is to just do a number of rings query:
this->setDescription("AtomHasProp");
this->setDataFunc(0);
this->setDescription("HasProp");
this->setDataFunc(nullptr);
}
explicit HasPropQuery(std::string v)
: Queries::EqualityQuery<int, TargetPtr, true>(), propname(std::move(v)) {
// default is to just do a number of rings query:
this->setDescription("AtomHasProp");
this->setDescription("HasProp");
this->setDataFunc(nullptr);
}

Expand All @@ -848,6 +846,8 @@ class HasPropQuery : public Queries::EqualityQuery<int, TargetPtr, true> {
res->d_description = this->d_description;
return res;
}

const std::string &getPropName() const { return propname; }
};

typedef Queries::EqualityQuery<int, Atom const *, true> ATOM_PROP_QUERY;
Expand Down
13 changes: 13 additions & 0 deletions Code/GraphMol/catch_pickles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,16 @@ TEST_CASE("parsing old pickles with many features") {
CHECK(m2.getNumAtoms() == m->getNumAtoms());
CHECK(MolToCXSmiles(*m) == MolToCXSmiles(m2));
}

TEST_CASE("github #7675 : pickling HasProp queries") {
SECTION("basics") {
auto mol = "CC"_smarts;
REQUIRE(mol);
mol->getAtomWithIdx(0)->expandQuery(makeHasPropQuery<Atom>("foo"));
mol->getBondWithIdx(0)->expandQuery(makeHasPropQuery<Bond>("foo"));
std::string pkl;
MolPickler::pickleMol(*mol, pkl);
RWMol mol2(pkl);
REQUIRE(mol2.getAtomWithIdx(0)->hasQuery());
}
}
Loading

0 comments on commit 1790de8

Please sign in to comment.