Skip to content

Commit

Permalink
Add callbacks to the multithreaded mol suppliers (rdkit#7763)
Browse files Browse the repository at this point in the history
* MultithreadedMolSuppliers can now be destroyed without being used.

This was previously not possible

* add callbacks to the multithreaded readers

* document the new functions

* switch to storing the queues in unique_ptrs

* only do those tests when in MT mode
  • Loading branch information
greglandrum authored Aug 29, 2024
1 parent 90c5334 commit 9dc263b
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 32 deletions.
5 changes: 5 additions & 0 deletions Code/GraphMol/FileParsers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,8 @@ rdkit_catch_test(v2MolSuppliers v2_suppliers_catch.cpp
rdkit_catch_test(v2FileParsersCatchTest v2_file_parsers_catch.cpp
LINK_LIBRARIES FileParsers)

# Only register the multithreaded supplier test when both
# RDK_TEST_MULTITHREADED and RDK_BUILD_THREADSAFE_SSS are set.
if(RDK_TEST_MULTITHREADED AND RDK_BUILD_THREADSAFE_SSS)
rdkit_catch_test(multithreadedSupplierCatchTest multithreaded_supplier_catch.cpp
LINK_LIBRARIES FileParsers)
endif()

41 changes: 27 additions & 14 deletions Code/GraphMol/FileParsers/MultithreadedMolSupplier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,25 @@ MultithreadedMolSupplier::~MultithreadedMolSupplier() {
endThreads();
// destroy all objects in the input queue
d_inputQueue->clear();
// delete the pointer to the input queue
delete d_inputQueue;
std::tuple<RWMol*, std::string, unsigned int> r;
while (d_outputQueue->pop(r)) {
RWMol* m = std::get<0>(r);
delete m;
if (df_started) {
std::tuple<RWMol *, std::string, unsigned int> r;
while (d_outputQueue->pop(r)) {
RWMol *m = std::get<0>(r);
delete m;
}
}
// destroy all objects in the output queue
d_outputQueue->clear();
// delete the pointer to the output queue
delete d_outputQueue;
}

void MultithreadedMolSupplier::reader() {
std::string record;
unsigned int lineNum, index;
while (extractNextRecord(record, lineNum, index)) {
auto r = std::tuple<std::string, unsigned int, unsigned int>{
record, lineNum, index};
if (readCallback) {
record = readCallback(record, index);
}
auto r = std::make_tuple(record, lineNum, index);
d_inputQueue->push(r);
}
d_inputQueue->setDone();
Expand All @@ -47,12 +47,15 @@ void MultithreadedMolSupplier::writer() {
while (d_inputQueue->pop(r)) {
try {
auto mol = processMoleculeRecord(std::get<0>(r), std::get<1>(r));
auto temp = std::tuple<RWMol*, std::string, unsigned int>{
if (mol && writeCallback) {
writeCallback(*mol, std::get<0>(r), std::get<2>(r));
}
auto temp = std::tuple<RWMol *, std::string, unsigned int>{
mol, std::get<0>(r), std::get<2>(r)};
d_outputQueue->push(temp);
} catch (...) {
// fill the queue with a null value
auto nullValue = std::tuple<RWMol*, std::string, unsigned int>{
auto nullValue = std::tuple<RWMol *, std::string, unsigned int>{
nullptr, std::get<0>(r), std::get<2>(r)};
d_outputQueue->push(nullValue);
}
Expand All @@ -65,19 +68,29 @@ void MultithreadedMolSupplier::writer() {
}

//! Pops the next entry from the output queue and returns its molecule.
/*!
  The reader/writer threads are started on the first call (guarded by
  df_started). For each popped entry the record text and record id are
  stored in d_lastItemText / d_lastRecordId, and ownership of the raw RWMol
  pointer is transferred to the returned unique_ptr. If a next-callback has
  been set it is applied to non-null molecules before they are returned.

  \return the molecule, or nullptr when the popped entry holds a null
  molecule (writer() queues one when processing a record throws) or when
  pop() returns false (presumably the queue is finished and empty - confirm
  against ConcurrentQueue).
*/
std::unique_ptr<RWMol> MultithreadedMolSupplier::next() {
// NOTE(review): `r` is declared both here and a few lines below; this looks
// like the removed and added sides of the diff shown together - confirm
// against the actual file.
std::tuple<RWMol*, std::string, unsigned int> r;
// lazily start the worker threads the first time a molecule is requested
if (!df_started) {
df_started = true;
startThreads();
}
std::tuple<RWMol *, std::string, unsigned int> r;
if (d_outputQueue->pop(r)) {
// remember which record this molecule came from
d_lastItemText = std::get<1>(r);
d_lastRecordId = std::get<2>(r);
// take ownership of the raw pointer that was stored in the queue
std::unique_ptr<RWMol> res{std::get<0>(r)};
// the callback is skipped for null molecules
if (res && nextCallback) {
nextCallback(*res, *this);
}
return res;
}
return nullptr;
}

//! Joins the reader thread and every writer thread.
/*!
  Returns immediately when df_started is false, i.e. when the threads were
  never launched; calling join() on a std::thread that is not joinable
  throws std::system_error.
*/
void MultithreadedMolSupplier::endThreads() {
// nothing to join if the supplier was never used
if (!df_started) {
return;
}
d_readerThread.join();
// NOTE(review): the loop header appears twice (old and new formatting);
// this looks like both sides of the diff shown together - confirm against
// the actual file.
for (auto& thread : d_writerThreads) {
for (auto &thread : d_writerThreads) {
thread.join();
}
}
Expand Down
53 changes: 49 additions & 4 deletions Code/GraphMol/FileParsers/MultithreadedMolSupplier.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <RDGeneral/RDThreads.h>
#include <RDGeneral/StreamOps.h>

#include <atomic>
#include <functional>
#include <utility>
#include <boost/tokenizer.hpp>

Expand All @@ -44,6 +45,7 @@ class RDKIT_FILEPARSERS_EXPORT MultithreadedMolSupplier : public MolSupplier {
~MultithreadedMolSupplier() override;
//! pop elements from the output queue
std::unique_ptr<RWMol> next() override;

//! returns true when all records have been read from the supplier
bool atEnd() override;

Expand All @@ -58,6 +60,40 @@ class RDKIT_FILEPARSERS_EXPORT MultithreadedMolSupplier : public MolSupplier {
//! returns the text block for the last extracted item
std::string getLastItemText() const;

//! sets the callback to be applied to molecules before they are returned by
//! the next() function
/*!
  \param cb: a callable that takes a reference to an RWMol and a const
  reference to the MultithreadedMolSupplier. This can modify the molecule in
  place. It is stored in a
  std::function<void(RWMol &, const MultithreadedMolSupplier &)>, so anything
  convertible to that type is accepted.

  \note next() only invokes the callback for non-null molecules.
*/
template <typename T>
void setNextCallback(T cb) {
  // cb is already our own copy, so move it into the std::function instead of
  // copying it a second time
  nextCallback = std::move(cb);
}
//! sets the callback to be applied to molecules after they are processed, but
//! before they are written to the output queue
/*!
  \param cb: a callable that takes a reference to an RWMol, a const reference
  to the string record, and an unsigned int record id. This can modify the
  molecule in place. It is stored in a
  std::function<void(RWMol &, const std::string &, unsigned int)>, so
  anything convertible to that type is accepted.

  \note writer() only invokes the callback for non-null molecules.
  NOTE(review): writer() presumably runs on the writer threads, so the
  callback should be safe to call concurrently, and since the threads appear
  to start on the first call to next() it should presumably be set before
  then - confirm.
*/
template <typename T>
void setWriteCallback(T cb) {
  // cb is already our own copy, so move it into the std::function instead of
  // copying it a second time
  writeCallback = std::move(cb);
}
//! sets the callback to be applied to input text records before they are
//! added to the input queue
/*!
  \param cb: a callable that takes a const reference to the string record and
  an unsigned int record id and returns the modified string record. It is
  stored in a std::function<std::string(const std::string &, unsigned int)>,
  so anything convertible to that type is accepted.

  \note reader() replaces each extracted record with the callback's return
  value before queuing it.
  NOTE(review): the reader thread appears to start on the first call to
  next(), so the callback should presumably be set before then - confirm.
*/
template <typename T>
void setReadCallback(T cb) {
  // cb is already our own copy, so move it into the std::function instead of
  // copying it a second time
  readCallback = std::move(cb);
}

protected:
//! starts reader and writer threads
void startThreads();
Expand Down Expand Up @@ -92,16 +128,25 @@ class RDKIT_FILEPARSERS_EXPORT MultithreadedMolSupplier : public MolSupplier {
std::thread d_readerThread; //!< single reader thread

protected:
std::atomic<bool> df_started = false;
std::atomic<unsigned int> d_lastRecordId =
0; //!< stores last extracted record id
std::string d_lastItemText; //!< stores last extracted record
const unsigned int d_numReaderThread = 1; //!< number of reader thread

ConcurrentQueue<std::tuple<std::string, unsigned int, unsigned int>>
*d_inputQueue; //!< concurrent input queue
ConcurrentQueue<std::tuple<RWMol *, std::string, unsigned int>>
*d_outputQueue; //!< concurrent output queue
std::unique_ptr<
ConcurrentQueue<std::tuple<std::string, unsigned int, unsigned int>>>
d_inputQueue; //!< concurrent input queue
std::unique_ptr<
ConcurrentQueue<std::tuple<RWMol *, std::string, unsigned int>>>
d_outputQueue; //!< concurrent output queue
Parameters d_params;
std::function<void(RWMol &, const MultithreadedMolSupplier &)> nextCallback =
nullptr;
std::function<void(RWMol &, const std::string &, unsigned int)>
writeCallback = nullptr;
std::function<std::string(const std::string &, unsigned int)> readCallback =
nullptr;
};
} // namespace FileParsers
} // namespace v2
Expand Down
11 changes: 4 additions & 7 deletions Code/GraphMol/FileParsers/MultithreadedSDMolSupplier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ MultithreadedSDMolSupplier::MultithreadedSDMolSupplier(
dp_inStream = openAndCheckStream(fileName);
initFromSettings(true, params, parseParams);
POSTCONDITION(dp_inStream, "bad instream");
startThreads();
}

MultithreadedSDMolSupplier::MultithreadedSDMolSupplier(
Expand All @@ -31,13 +30,11 @@ MultithreadedSDMolSupplier::MultithreadedSDMolSupplier(
dp_inStream = inStream;
initFromSettings(takeOwnership, params, parseParams);
POSTCONDITION(dp_inStream, "bad instream");
startThreads();
}

MultithreadedSDMolSupplier::MultithreadedSDMolSupplier() {
dp_inStream = nullptr;
initFromSettings(false, d_params, d_parseParams);
startThreads();
}

void MultithreadedSDMolSupplier::initFromSettings(
Expand All @@ -47,12 +44,12 @@ void MultithreadedSDMolSupplier::initFromSettings(
d_params = params;
d_parseParams = parseParams;
d_params.numWriterThreads = getNumThreadsToUse(params.numWriterThreads);
d_inputQueue =
d_inputQueue.reset(
new ConcurrentQueue<std::tuple<std::string, unsigned int, unsigned int>>(
d_params.sizeInputQueue);
d_outputQueue =
d_params.sizeInputQueue));
d_outputQueue.reset(
new ConcurrentQueue<std::tuple<RWMol *, std::string, unsigned int>>(
d_params.sizeOutputQueue);
d_params.sizeOutputQueue));

df_end = false;
d_line = 0;
Expand Down
11 changes: 4 additions & 7 deletions Code/GraphMol/FileParsers/MultithreadedSmilesMolSupplier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ MultithreadedSmilesMolSupplier::MultithreadedSmilesMolSupplier(
CHECK_INVARIANT(!(dp_inStream->eof()), "early EOF");
// set df_takeOwnership = true
initFromSettings(true, params, parseParams);
startThreads();
POSTCONDITION(dp_inStream, "bad instream");
}

Expand All @@ -32,14 +31,12 @@ MultithreadedSmilesMolSupplier::MultithreadedSmilesMolSupplier(
CHECK_INVARIANT(!(inStream->eof()), "early EOF");
dp_inStream = inStream;
initFromSettings(takeOwnership, params, parseParams);
startThreads();
POSTCONDITION(dp_inStream, "bad instream");
}

MultithreadedSmilesMolSupplier::MultithreadedSmilesMolSupplier() {
dp_inStream = nullptr;
initFromSettings(true, d_params, d_parseParams);
startThreads();
}

MultithreadedSmilesMolSupplier::~MultithreadedSmilesMolSupplier() {
Expand All @@ -57,12 +54,12 @@ void MultithreadedSmilesMolSupplier::initFromSettings(
d_params = params;
d_parseParams = parseParams;
d_params.numWriterThreads = getNumThreadsToUse(d_params.numWriterThreads);
d_inputQueue =
d_inputQueue.reset(
new ConcurrentQueue<std::tuple<std::string, unsigned int, unsigned int>>(
d_params.sizeInputQueue);
d_outputQueue =
d_params.sizeInputQueue));
d_outputQueue.reset(
new ConcurrentQueue<std::tuple<RWMol *, std::string, unsigned int>>(
d_params.sizeOutputQueue);
d_params.sizeOutputQueue));
df_end = false;
d_line = -1;
}
Expand Down
Loading

0 comments on commit 9dc263b

Please sign in to comment.