diff --git a/DESCRIPTION b/DESCRIPTION index 990da96..f39bae7 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -12,7 +12,6 @@ BugReports: https://github.com/daehwankim12/fastmed/issues License: MIT + file LICENSE Imports: Rcpp, - RcppEigen, RcppParallel, data.table LinkingTo: @@ -21,6 +20,7 @@ LinkingTo: RcppParallel Suggests: rmarkdown, + withr, testthat (>= 3.0.0) Encoding: UTF-8 RoxygenNote: 7.3.2 diff --git a/LICENSE b/LICENSE index 7391ef9..de849e2 100644 --- a/LICENSE +++ b/LICENSE @@ -1,2 +1,2 @@ YEAR: 2024 -COPYRIGHT HOLDER: fastmed authors +COPYRIGHT HOLDER: Daehwan Kim diff --git a/NAMESPACE b/NAMESPACE index 776f160..d6ac0db 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,8 @@ # Generated by roxygen2: do not edit by hand export(mediation_analysis) +import(Rcpp) import(data.table) +importFrom(Rcpp,evalCpp) importFrom(RcppParallel,setThreadOptions) useDynLib(fastmed, .registration = TRUE) diff --git a/R/mediation_analysis.R b/R/mediation_analysis.R index 80f99a7..6dbd7d0 100644 --- a/R/mediation_analysis.R +++ b/R/mediation_analysis.R @@ -39,7 +39,8 @@ #' ) #' } #' -#' @import data.table +#' @import data.table Rcpp +#' @importFrom Rcpp evalCpp #' @importFrom RcppParallel setThreadOptions #' @useDynLib fastmed, .registration = TRUE #' @export diff --git a/src/fastmed.so b/src/fastmed.so index 2ca32b6..5f4e964 100755 Binary files a/src/fastmed.so and b/src/fastmed.so differ diff --git a/src/mediation_analysis.cpp b/src/mediation_analysis.cpp index 62862a4..96402c5 100644 --- a/src/mediation_analysis.cpp +++ b/src/mediation_analysis.cpp @@ -82,26 +82,26 @@ struct MediationWorker : public Worker { const std::vector mediator_cols; const std::vector outcome_cols; const int nrep; - const std::string output_file; const Rcpp::DataFrame combinations; + const std::string output_file; // Constructor - MediationWorker(const NumericMatrix data, - const CharacterVector column_names, - const CharacterVector exposure_cols, - const CharacterVector mediator_cols, - const CharacterVector outcome_cols, - int nrep, - std::string output_file, - Rcpp::DataFrame combinations) - : data(data), - column_names(Rcpp::as>(column_names)), - exposure_cols(Rcpp::as>(exposure_cols)), - mediator_cols(Rcpp::as>(mediator_cols)), - outcome_cols(Rcpp::as>(outcome_cols)), - nrep(nrep), - output_file(output_file), - combinations(combinations) {} + MediationWorker(const NumericMatrix data_, + const CharacterVector column_names_, + const CharacterVector exposure_cols_, + const CharacterVector mediator_cols_, + const CharacterVector outcome_cols_, + int nrep_, + std::string output_file_, + Rcpp::DataFrame combinations_) + : data(data_), + column_names(Rcpp::as>(column_names_)), + exposure_cols(Rcpp::as>(exposure_cols_)), + mediator_cols(Rcpp::as>(mediator_cols_)), + outcome_cols(Rcpp::as>(outcome_cols_)), + nrep(nrep_), + output_file(output_file_), + combinations(combinations_) {} // Function call operator that work for the specified range (begin/end) void operator()(std::size_t begin, std::size_t end) { @@ -120,6 +120,11 @@ struct MediationWorker : public Worker { Rcpp::CharacterVector mediator_vec = combinations["mediator"]; Rcpp::CharacterVector outcome_vec = combinations["outcome"]; + // Local buffer to accumulate results + std::vector local_buffer; + const size_t buffer_size = 100; + + for (std::size_t idx = begin; idx < end; ++idx) { std::string exposure_col = Rcpp::as(exposure_vec[idx]); std::string mediator_col = Rcpp::as(mediator_vec[idx]); @@ -219,14 +224,40 @@ struct MediationWorker : public Worker { << percentile_cpp(tot_samples, 2.5) << "," << percentile_cpp(tot_samples, 97.5) << "," << p_value_cpp(tot_samples, 0.0) << "\n"; - // Write results to file (thread-safe) - { - std::lock_guard lock(file_mutex); - outfile << result.str(); - } + // Add result to local buffer + local_buffer.push_back(result.str()); + + // If buffer is full, write to file + if (local_buffer.size() >= buffer_size) { + // Lock mutex and write buffer to file + { + std::lock_guard lock(file_mutex); + std::ofstream outfile(output_file, std::ios::app); + if (!outfile.is_open()) { + throw std::runtime_error("Unable to open output file"); + } + for (const auto& line : local_buffer) { + outfile << line; + } + outfile.close(); + } + // Clear local buffer + local_buffer.clear(); + } } - outfile.close(); + // After processing all combinations, write any remaining results in the buffer + if (!local_buffer.empty()) { + std::lock_guard lock(file_mutex); + std::ofstream outfile(output_file, std::ios::app); + if (!outfile.is_open()) { + throw std::runtime_error("Unable to open output file"); + } + for (const auto& line : local_buffer) { + outfile << line; + } + outfile.close(); + } } catch (const std::exception& e) { Rcpp::stop(std::string("Error in worker: ") + e.what()); } diff --git a/src/mediation_analysis.o b/src/mediation_analysis.o index d442492..117b760 100644 Binary files a/src/mediation_analysis.o and b/src/mediation_analysis.o differ diff --git a/tests/testthat/test-mediation_analysis.R b/tests/testthat/test-mediation_analysis.R index 4837968..4e35658 100644 --- a/tests/testthat/test-mediation_analysis.R +++ b/tests/testthat/test-mediation_analysis.R @@ -1,8 +1,7 @@ -# tests/testthat/test-mediation_analysis.R test_that("mediation_analysis works correctly with default prefixes", { # Create example data set.seed(123) - test_data_default <- data.table( + test_data_default <- data.table::data.table( CECUM1 = rnorm(50), CECUM2 = rnorm(50), SERUM1 = rnorm(50), @@ -10,33 +9,30 @@ test_that("mediation_analysis works correctly with default prefixes", { CORTEX1 = rnorm(50), CORTEX2 = rnorm(50) ) - - # Temporary output CSV file - output_csv <- tempfile(fileext = ".csv") - + + # Use local_tempfile to create a temporary output CSV file + output_csv <- withr::local_tempfile(fileext = ".csv") + # Run mediation_analysis function mediation_analysis( data = test_data_default, columns = list(exposure = c("CECUM"), mediator = c("SERUM"), outcome = c("CORTEX")), nrep = 100, # Reduced for faster testing - output_file = output_csv, - num_threads = 2 # Adjust thread count for testing environment + output_file = output_csv, # Use the temporary file name + num_threads = 1 # Adjust thread count for testing environment ) - + # Read results - results <- fread(output_csv) - + results <- data.table::fread(output_csv) + # Expected result: 2 (CECUM) * 2 (SERUM) * 2 (CORTEX) = 8 combinations expect_equal(nrow(results), 8) - - # Remove temporary file - file.remove(output_csv) }) test_that("mediation_analysis works correctly with custom prefixes", { # Create example data set.seed(123) - test_data_custom <- data.table( + test_data_custom <- data.table::data.table( EXP1 = rnorm(50), EXP2 = rnorm(50), MED1 = rnorm(50), @@ -44,33 +40,30 @@ test_that("mediation_analysis works correctly with custom prefixes", { OUT1 = rnorm(50), OUT2 = rnorm(50) ) - - # Temporary output CSV file - output_csv <- tempfile(fileext = ".csv") - + + # Use local_tempfile to create a temporary output CSV file + output_csv <- withr::local_tempfile(fileext = ".csv") + # Run mediation_analysis function mediation_analysis( data = test_data_custom, columns = list(exposure = c("EXP"), mediator = c("MED"), outcome = c("OUT")), nrep = 100, # Reduced for faster testing - output_file = output_csv, - num_threads = 2 # Adjust thread count for testing environment + output_file = output_csv, # Use the temporary file name + num_threads = 1 # Adjust thread count for testing environment ) - + # Read results - results <- fread(output_csv) - + results <- data.table::fread(output_csv) + # Expected result: 2 (EXP) * 2 (MED) * 2 (OUT) = 8 combinations expect_equal(nrow(results), 8) - - # Remove temporary file - file.remove(output_csv) }) test_that("mediation_analysis handles large datasets correctly", { # Create large example data set.seed(123) - test_data_large <- data.table( + test_data_large <- data.table::data.table( CECUM1 = rnorm(1000), CECUM2 = rnorm(1000), SERUM1 = rnorm(1000), @@ -78,25 +71,53 @@ test_that("mediation_analysis handles large datasets correctly", { CORTEX1 = rnorm(1000), CORTEX2 = rnorm(1000) ) - - # Temporary output CSV file - output_csv <- tempfile(fileext = ".csv") - + + # Use local_tempfile to create a temporary output CSV file + output_csv <- withr::local_tempfile(fileext = ".csv") + # Run mediation_analysis function mediation_analysis( data = test_data_large, columns = list(exposure = c("CECUM"), mediator = c("SERUM"), outcome = c("CORTEX")), nrep = 100, # Reduced for faster testing - output_file = output_csv, - num_threads = 2 # Adjust thread count for testing environment + output_file = output_csv, # Use the temporary file name + num_threads = 1 # Adjust thread count for testing environment ) - + # Read results - results <- fread(output_csv) - + results <- data.table::fread(output_csv) + # Expected result: 2 (CECUM) * 2 (SERUM) * 2 (CORTEX) = 8 combinations expect_equal(nrow(results), 8) - - # Remove temporary file - file.remove(output_csv) }) + +test_that("mediation_analysis works correctly with multiple threads", { + # Create large example data + set.seed(123) + test_data_large <- data.table::data.table( + CECUM1 = rnorm(1000), + CECUM2 = rnorm(1000), + SERUM1 = rnorm(1000), + SERUM2 = rnorm(1000), + CORTEX1 = rnorm(1000), + CORTEX2 = rnorm(1000) + ) + + # Use local_tempfile to create a temporary output CSV file + output_csv <- withr::local_tempfile(fileext = ".csv") + + # Run mediation_analysis function + mediation_analysis( + data = test_data_large, + columns = list(exposure = c("CECUM"), mediator = c("SERUM"), outcome = c("CORTEX")), + nrep = 100, # Reduced for faster testing + output_file = output_csv, # Use the temporary file name + num_threads = 4 # Adjust thread count for testing environment + ) + + # Read results + results <- data.table::fread(output_csv) + + # Expected result: 2 (CECUM) * 2 (SERUM) * 2 (CORTEX) = 8 combinations + expect_equal(nrow(results), 8) +}) \ No newline at end of file