Skip to content

Commit

Permalink
Update mediation_analysis.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
daehwankim12 committed Oct 17, 2024
1 parent 2caa5e1 commit 8ebeb5a
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 71 deletions.
Binary file modified src/fastmed.so
Binary file not shown.
180 changes: 109 additions & 71 deletions src/mediation_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
#include <fstream>
#include <iomanip>
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <stdexcept>
#include <memory>
#include <sstream>

// [[Rcpp::depends(RcppEigen, RcppParallel)]]

Expand Down Expand Up @@ -70,60 +74,102 @@ VectorXd linear_regression(const MatrixXd& X, const VectorXd& y) {
return X.colPivHouseholderQr().solve(y);
}

// Mutex for thread-safe file writing
std::mutex file_mutex;
// Thread-safe file handler
class FileHandler {
private:
std::ofstream file;
std::mutex mutex;
std::condition_variable cv;
std::atomic<bool> is_writing{false};
std::atomic<int> active_writers{0};

public:
FileHandler(const std::string& filename) : file(filename, std::ios::app) {
if (!file.is_open()) {
throw std::runtime_error("Unable to open output file");
}
}

~FileHandler() {
waitForCompletion();
if (file.is_open()) {
file.close();
}
}

void write(const std::string& data) {
std::unique_lock<std::mutex> lock(mutex);
active_writers++;
is_writing = true;
file << data;
file.flush();
active_writers--;
is_writing = (active_writers > 0);
cv.notify_all();
}

void waitForCompletion() {
std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [this] { return !is_writing; });
}
};

// Worker class for parallel processing
struct MediationWorker : public Worker {
// Inputs
const RMatrix<double> data;
class MediationWorker : public Worker {
private:
// Inputs (using shared_ptr for large data structures)
const std::shared_ptr<MatrixXd> data;
const std::vector<std::string> column_names;
const std::vector<std::string> exposure_cols;
const std::vector<std::string> mediator_cols;
const std::vector<std::string> outcome_cols;
const int nrep;
const std::string output_file;
const Rcpp::DataFrame combinations;

std::shared_ptr<FileHandler> file_handler;

// Local buffer
std::vector<std::string> local_buffer;
const size_t buffer_size = 1000;

public:
// Constructor
MediationWorker(const NumericMatrix data_,
MediationWorker(const std::shared_ptr<MatrixXd> data_,
const CharacterVector column_names_,
const CharacterVector exposure_cols_,
const CharacterVector mediator_cols_,
const CharacterVector outcome_cols_,
int nrep_,
std::string output_file_,
Rcpp::DataFrame combinations_)
Rcpp::DataFrame combinations_,
std::shared_ptr<FileHandler> file_handler_)
: data(data_),
column_names(Rcpp::as<std::vector<std::string>>(column_names_)),
exposure_cols(Rcpp::as<std::vector<std::string>>(exposure_cols_)),
mediator_cols(Rcpp::as<std::vector<std::string>>(mediator_cols_)),
outcome_cols(Rcpp::as<std::vector<std::string>>(outcome_cols_)),
nrep(nrep_),
output_file(output_file_),
combinations(combinations_) {}
combinations(combinations_),
file_handler(file_handler_) {
local_buffer.reserve(buffer_size);
}

// Function call operator that work for the specified range (begin/end)
void operator()(std::size_t begin, std::size_t end) {
try {
std::ofstream outfile(output_file, std::ios::app);
if (!outfile.is_open()) {
throw std::runtime_error("Unable to open output file");
}

// Random number generator
std::mt19937 gen(std::random_device{}());
std::uniform_int_distribution<> dis(0, data.nrow() - 1);
std::uniform_int_distribution<> dis(0, data->rows() - 1);

// Get column vectors from combinations DataFrame
Rcpp::CharacterVector exposure_vec = combinations["exposure"];
Rcpp::CharacterVector mediator_vec = combinations["mediator"];
Rcpp::CharacterVector outcome_vec = combinations["outcome"];

// Local buffer to accumulate results
std::vector<std::string> local_buffer;
const size_t buffer_size = 100;

// Pre-allocate vectors for bootstrap samples
std::vector<double> ind0_samples(nrep);
std::vector<double> ind1_samples(nrep);
std::vector<double> dir0_samples(nrep);
std::vector<double> dir1_samples(nrep);
std::vector<double> tot_samples(nrep);

for (std::size_t idx = begin; idx < end; ++idx) {
std::string exposure_col = Rcpp::as<std::string>(exposure_vec[idx]);
Expand All @@ -143,19 +189,16 @@ struct MediationWorker : public Worker {
}

// Extract data for current combination
VectorXd exposure(data.nrow()), mediator(data.nrow()), outcome(data.nrow());
for (int row = 0; row < data.nrow(); ++row) {
exposure(row) = data(row, exp_idx);
mediator(row) = data(row, med_idx);
outcome(row) = data(row, out_idx);
}
VectorXd exposure = data->col(exp_idx);
VectorXd mediator = data->col(med_idx);
VectorXd outcome = data->col(out_idx);

// Prepare design matrices
MatrixXd X_med(data.nrow(), 2);
MatrixXd X_med(data->rows(), 2);
X_med.col(0).setOnes();
X_med.col(1) = exposure;

MatrixXd X_out(data.nrow(), 3);
MatrixXd X_out(data->rows(), 3);
X_out.col(0).setOnes();
X_out.col(1) = mediator;
X_out.col(2) = exposure;
Expand All @@ -165,17 +208,10 @@ struct MediationWorker : public Worker {
VectorXd beta_out = linear_regression(X_out, outcome);

// Bootstrap
std::vector<double> ind0_samples, ind1_samples, dir0_samples, dir1_samples, tot_samples;
ind0_samples.reserve(nrep);
ind1_samples.reserve(nrep);
dir0_samples.reserve(nrep);
dir1_samples.reserve(nrep);
tot_samples.reserve(nrep);

for (int rep = 0; rep < nrep; ++rep) {
// Sample with replacement
VectorXd boot_exposure(data.nrow()), boot_mediator(data.nrow()), boot_outcome(data.nrow());
for (int m = 0; m < data.nrow(); ++m) {
VectorXd boot_exposure(data->rows()), boot_mediator(data->rows()), boot_outcome(data->rows());
for (int m = 0; m < data->rows(); ++m) {
int sample_idx = dis(gen);
boot_exposure[m] = exposure[sample_idx];
boot_mediator[m] = mediator[sample_idx];
Expand All @@ -194,17 +230,17 @@ struct MediationWorker : public Worker {
double y01 = y00 + beta_out_boot[2];
double y11 = y10 + beta_out_boot[2];

ind0_samples.push_back(y10 - y00);
ind1_samples.push_back(y11 - y01);
dir0_samples.push_back(y01 - y00);
dir1_samples.push_back(y11 - y10);
tot_samples.push_back(y11 - y00);
ind0_samples[rep] = y10 - y00;
ind1_samples[rep] = y11 - y01;
dir0_samples[rep] = y01 - y00;
dir1_samples[rep] = y11 - y10;
tot_samples[rep] = y11 - y00;
}

// Calculate statistics
std::string combination = exposure_col + "_" + mediator_col + "_" + outcome_col;

// Prepare result string
// Prepare result string using stringstream for efficiency
std::stringstream result;
result << std::fixed << std::setprecision(6);
result << combination << ","
Expand All @@ -229,39 +265,30 @@ struct MediationWorker : public Worker {

// If buffer is full, write to file
if (local_buffer.size() >= buffer_size) {
// Lock mutex and write buffer to file
{
std::lock_guard<std::mutex> lock(file_mutex);
std::ofstream outfile(output_file, std::ios::app);
if (!outfile.is_open()) {
throw std::runtime_error("Unable to open output file");
}
for (const auto& line : local_buffer) {
outfile << line;
}
outfile.close();
}
// Clear local buffer
local_buffer.clear();
flush_buffer();
}
}

// After processing all combinations, write any remaining results in the buffer
// After the loop, flush any remaining data
if (!local_buffer.empty()) {
std::lock_guard<std::mutex> lock(file_mutex);
std::ofstream outfile(output_file, std::ios::app);
if (!outfile.is_open()) {
throw std::runtime_error("Unable to open output file");
}
for (const auto& line : local_buffer) {
outfile << line;
}
outfile.close();
flush_buffer();
}
} catch (const std::exception& e) {
Rcpp::stop(std::string("Error in worker: ") + e.what());
}
}

private:
void flush_buffer() {
if (!local_buffer.empty()) {
std::string combined_buffer;
for (const auto& line : local_buffer) {
combined_buffer += line;
}
file_handler->write(combined_buffer);
local_buffer.clear();
}
}
};

// [[Rcpp::export]]
Expand All @@ -274,9 +301,20 @@ void mediation_analysis_cpp(NumericMatrix data,
int nrep,
std::string output_file) {
try {
MediationWorker worker(data, column_names, exposure_cols, mediator_cols, outcome_cols, nrep, output_file, combinations);
// Convert NumericMatrix to Eigen::MatrixXd and wrap in shared_ptr
std::shared_ptr<MatrixXd> data_eigen = std::make_shared<MatrixXd>(Rcpp::as<Eigen::Map<Eigen::MatrixXd>>(data));

// Create FileHandler
std::shared_ptr<FileHandler> file_handler = std::make_shared<FileHandler>(output_file);

// Create worker and run parallel processing
MediationWorker worker(data_eigen, column_names, exposure_cols, mediator_cols, outcome_cols, nrep, combinations, file_handler);
parallelFor(0, combinations.nrows(), worker);

// Ensure all writes are completed before function returns
file_handler->waitForCompletion();

} catch (const std::exception& e) {
Rcpp::stop(std::string("Error in mediation_analysis_cpp: ") + e.what());
}
}
}
Binary file modified src/mediation_analysis.o
Binary file not shown.

0 comments on commit 8ebeb5a

Please sign in to comment.