diff --git a/src/fastmed.so b/src/fastmed.so index 9759846..02e3fe3 100755 Binary files a/src/fastmed.so and b/src/fastmed.so differ diff --git a/src/mediation_analysis.cpp b/src/mediation_analysis.cpp index d226a9c..ddb462a 100644 --- a/src/mediation_analysis.cpp +++ b/src/mediation_analysis.cpp @@ -10,7 +10,11 @@ #include #include #include +#include +#include #include +#include +#include // [[Rcpp::depends(RcppEigen, RcppParallel)]] @@ -70,60 +74,102 @@ VectorXd linear_regression(const MatrixXd& X, const VectorXd& y) { return X.colPivHouseholderQr().solve(y); } -// Mutex for thread-safe file writing -std::mutex file_mutex; +// Thread-safe file handler +class FileHandler { +private: + std::ofstream file; + std::mutex mutex; + std::condition_variable cv; + std::atomic is_writing{false}; + std::atomic active_writers{0}; + +public: + FileHandler(const std::string& filename) : file(filename, std::ios::app) { + if (!file.is_open()) { + throw std::runtime_error("Unable to open output file"); + } + } + + ~FileHandler() { + waitForCompletion(); + if (file.is_open()) { + file.close(); + } + } + + void write(const std::string& data) { + std::unique_lock lock(mutex); + active_writers++; + is_writing = true; + file << data; + file.flush(); + active_writers--; + is_writing = (active_writers > 0); + cv.notify_all(); + } + + void waitForCompletion() { + std::unique_lock lock(mutex); + cv.wait(lock, [this] { return !is_writing; }); + } +}; // Worker class for parallel processing -struct MediationWorker : public Worker { - // Inputs - const RMatrix data; +class MediationWorker : public Worker { +private: + // Inputs (using shared_ptr for large data structures) + const std::shared_ptr data; const std::vector column_names; const std::vector exposure_cols; const std::vector mediator_cols; const std::vector outcome_cols; const int nrep; - const std::string output_file; const Rcpp::DataFrame combinations; - + std::shared_ptr file_handler; + + // Local buffer + std::vector local_buffer; + const size_t buffer_size = 1000; + +public: // Constructor - MediationWorker(const NumericMatrix data_, + MediationWorker(const std::shared_ptr data_, const CharacterVector column_names_, const CharacterVector exposure_cols_, const CharacterVector mediator_cols_, const CharacterVector outcome_cols_, int nrep_, - std::string output_file_, - Rcpp::DataFrame combinations_) + Rcpp::DataFrame combinations_, + std::shared_ptr file_handler_) : data(data_), column_names(Rcpp::as>(column_names_)), exposure_cols(Rcpp::as>(exposure_cols_)), mediator_cols(Rcpp::as>(mediator_cols_)), outcome_cols(Rcpp::as>(outcome_cols_)), nrep(nrep_), - output_file(output_file_), - combinations(combinations_) {} + combinations(combinations_), + file_handler(file_handler_) { + local_buffer.reserve(buffer_size); + } // Function call operator that work for the specified range (begin/end) void operator()(std::size_t begin, std::size_t end) { try { - std::ofstream outfile(output_file, std::ios::app); - if (!outfile.is_open()) { - throw std::runtime_error("Unable to open output file"); - } - // Random number generator std::mt19937 gen(std::random_device{}()); - std::uniform_int_distribution<> dis(0, data.nrow() - 1); + std::uniform_int_distribution<> dis(0, data->rows() - 1); // Get column vectors from combinations DataFrame Rcpp::CharacterVector exposure_vec = combinations["exposure"]; Rcpp::CharacterVector mediator_vec = combinations["mediator"]; Rcpp::CharacterVector outcome_vec = combinations["outcome"]; - // Local buffer to accumulate results - std::vector local_buffer; - const size_t buffer_size = 100; - + // Pre-allocate vectors for bootstrap samples + std::vector ind0_samples(nrep); + std::vector ind1_samples(nrep); + std::vector dir0_samples(nrep); + std::vector dir1_samples(nrep); + std::vector tot_samples(nrep); for (std::size_t idx = begin; idx < end; ++idx) { std::string exposure_col = Rcpp::as(exposure_vec[idx]); @@ -143,19 +189,16 @@ struct MediationWorker : public Worker { } // Extract data for current combination - VectorXd exposure(data.nrow()), mediator(data.nrow()), outcome(data.nrow()); - for (int row = 0; row < data.nrow(); ++row) { - exposure(row) = data(row, exp_idx); - mediator(row) = data(row, med_idx); - outcome(row) = data(row, out_idx); - } + VectorXd exposure = data->col(exp_idx); + VectorXd mediator = data->col(med_idx); + VectorXd outcome = data->col(out_idx); // Prepare design matrices - MatrixXd X_med(data.nrow(), 2); + MatrixXd X_med(data->rows(), 2); X_med.col(0).setOnes(); X_med.col(1) = exposure; - MatrixXd X_out(data.nrow(), 3); + MatrixXd X_out(data->rows(), 3); X_out.col(0).setOnes(); X_out.col(1) = mediator; X_out.col(2) = exposure; @@ -165,17 +208,10 @@ struct MediationWorker : public Worker { VectorXd beta_out = linear_regression(X_out, outcome); // Bootstrap - std::vector ind0_samples, ind1_samples, dir0_samples, dir1_samples, tot_samples; - ind0_samples.reserve(nrep); - ind1_samples.reserve(nrep); - dir0_samples.reserve(nrep); - dir1_samples.reserve(nrep); - tot_samples.reserve(nrep); - for (int rep = 0; rep < nrep; ++rep) { // Sample with replacement - VectorXd boot_exposure(data.nrow()), boot_mediator(data.nrow()), boot_outcome(data.nrow()); - for (int m = 0; m < data.nrow(); ++m) { + VectorXd boot_exposure(data->rows()), boot_mediator(data->rows()), boot_outcome(data->rows()); + for (int m = 0; m < data->rows(); ++m) { int sample_idx = dis(gen); boot_exposure[m] = exposure[sample_idx]; boot_mediator[m] = mediator[sample_idx]; @@ -194,17 +230,17 @@ struct MediationWorker : public Worker { double y01 = y00 + beta_out_boot[2]; double y11 = y10 + beta_out_boot[2]; - ind0_samples.push_back(y10 - y00); - ind1_samples.push_back(y11 - y01); - dir0_samples.push_back(y01 - y00); - dir1_samples.push_back(y11 - y10); - tot_samples.push_back(y11 - y00); + ind0_samples[rep] = y10 - y00; + ind1_samples[rep] = y11 - y01; + dir0_samples[rep] = y01 - y00; + dir1_samples[rep] = y11 - y10; + tot_samples[rep] = y11 - y00; } // Calculate statistics std::string combination = exposure_col + "_" + mediator_col + "_" + outcome_col; - // Prepare result string + // Prepare result string using stringstream for efficiency std::stringstream result; result << std::fixed << std::setprecision(6); result << combination << "," @@ -229,39 +265,30 @@ struct MediationWorker : public Worker { // If buffer is full, write to file if (local_buffer.size() >= buffer_size) { - // Lock mutex and write buffer to file - { - std::lock_guard lock(file_mutex); - std::ofstream outfile(output_file, std::ios::app); - if (!outfile.is_open()) { - throw std::runtime_error("Unable to open output file"); - } - for (const auto& line : local_buffer) { - outfile << line; - } - outfile.close(); - } - // Clear local buffer - local_buffer.clear(); + flush_buffer(); } } - // After processing all combinations, write any remaining results in the buffer + // After the loop, flush any remaining data if (!local_buffer.empty()) { - std::lock_guard lock(file_mutex); - std::ofstream outfile(output_file, std::ios::app); - if (!outfile.is_open()) { - throw std::runtime_error("Unable to open output file"); - } - for (const auto& line : local_buffer) { - outfile << line; - } - outfile.close(); + flush_buffer(); } } catch (const std::exception& e) { Rcpp::stop(std::string("Error in worker: ") + e.what()); } } + +private: + void flush_buffer() { + if (!local_buffer.empty()) { + std::string combined_buffer; + for (const auto& line : local_buffer) { + combined_buffer += line; + } + file_handler->write(combined_buffer); + local_buffer.clear(); + } + } }; // [[Rcpp::export]] @@ -274,9 +301,20 @@ void mediation_analysis_cpp(NumericMatrix data, int nrep, std::string output_file) { try { - MediationWorker worker(data, column_names, exposure_cols, mediator_cols, outcome_cols, nrep, output_file, combinations); + // Convert NumericMatrix to Eigen::MatrixXd and wrap in shared_ptr + std::shared_ptr data_eigen = std::make_shared(Rcpp::as>(data)); + + // Create FileHandler + std::shared_ptr file_handler = std::make_shared(output_file); + + // Create worker and run parallel processing + MediationWorker worker(data_eigen, column_names, exposure_cols, mediator_cols, outcome_cols, nrep, combinations, file_handler); parallelFor(0, combinations.nrows(), worker); + + // Ensure all writes are completed before function returns + file_handler->waitForCompletion(); + } catch (const std::exception& e) { Rcpp::stop(std::string("Error in mediation_analysis_cpp: ") + e.what()); } -} +} \ No newline at end of file diff --git a/src/mediation_analysis.o b/src/mediation_analysis.o index 2293707..55182cb 100644 Binary files a/src/mediation_analysis.o and b/src/mediation_analysis.o differ