Skip to content

Commit

Permalink
Update for CRAN submission
Browse files Browse the repository at this point in the history
  • Loading branch information
daehwankim12 committed Oct 16, 2024
1 parent 140c562 commit 5b55336
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 66 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ BugReports: https://github.com/daehwankim12/fastmed/issues
License: MIT + file LICENSE
Imports:
Rcpp,
RcppEigen,
RcppParallel,
data.table
LinkingTo:
Expand All @@ -21,6 +20,7 @@ LinkingTo:
RcppParallel
Suggests:
rmarkdown,
withr,
testthat (>= 3.0.0)
Encoding: UTF-8
RoxygenNote: 7.3.2
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
YEAR: 2024
COPYRIGHT HOLDER: fastmed authors
COPYRIGHT HOLDER: Daehwan Kim
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Generated by roxygen2: do not edit by hand

export(mediation_analysis)
import(Rcpp)
import(data.table)
importFrom(Rcpp,evalCpp)
importFrom(RcppParallel,setThreadOptions)
useDynLib(fastmed, .registration = TRUE)
3 changes: 2 additions & 1 deletion R/mediation_analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
#' )
#' }
#'
#' @import data.table
#' @import data.table Rcpp
#' @importFrom Rcpp evalCpp
#' @importFrom RcppParallel setThreadOptions
#' @useDynLib fastmed, .registration = TRUE
#' @export
Expand Down
Binary file modified src/fastmed.so
Binary file not shown.
77 changes: 54 additions & 23 deletions src/mediation_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,26 @@ struct MediationWorker : public Worker {
const std::vector<std::string> mediator_cols;
const std::vector<std::string> outcome_cols;
const int nrep;
const std::string output_file;
const Rcpp::DataFrame combinations;
const std::string output_file;

// Constructor
MediationWorker(const NumericMatrix data,
const CharacterVector column_names,
const CharacterVector exposure_cols,
const CharacterVector mediator_cols,
const CharacterVector outcome_cols,
int nrep,
std::string output_file,
Rcpp::DataFrame combinations)
: data(data),
column_names(Rcpp::as<std::vector<std::string>>(column_names)),
exposure_cols(Rcpp::as<std::vector<std::string>>(exposure_cols)),
mediator_cols(Rcpp::as<std::vector<std::string>>(mediator_cols)),
outcome_cols(Rcpp::as<std::vector<std::string>>(outcome_cols)),
nrep(nrep),
output_file(output_file),
combinations(combinations) {}
// Constructs a worker from the supplied arguments.
// - data_, nrep_, output_file_ and combinations_ initialize the members of the
//   same name without the trailing underscore.
// - column_names_, exposure_cols_, mediator_cols_ and outcome_cols_ are
//   converted from CharacterVector to std::vector<std::string> via Rcpp::as
//   before being stored.
// Parameters carry a trailing underscore so that they are not spelled the same
// as the members they initialize.
MediationWorker(const NumericMatrix data_,
const CharacterVector column_names_,
const CharacterVector exposure_cols_,
const CharacterVector mediator_cols_,
const CharacterVector outcome_cols_,
int nrep_,
std::string output_file_,
Rcpp::DataFrame combinations_)
: data(data_),
column_names(Rcpp::as<std::vector<std::string>>(column_names_)),
exposure_cols(Rcpp::as<std::vector<std::string>>(exposure_cols_)),
mediator_cols(Rcpp::as<std::vector<std::string>>(mediator_cols_)),
outcome_cols(Rcpp::as<std::vector<std::string>>(outcome_cols_)),
nrep(nrep_),
output_file(output_file_),
combinations(combinations_) {}

// Function call operator that works on the specified range (begin/end)
void operator()(std::size_t begin, std::size_t end) {
Expand All @@ -120,6 +120,11 @@ struct MediationWorker : public Worker {
Rcpp::CharacterVector mediator_vec = combinations["mediator"];
Rcpp::CharacterVector outcome_vec = combinations["outcome"];

// Local buffer to accumulate results
std::vector<std::string> local_buffer;
const size_t buffer_size = 100;


for (std::size_t idx = begin; idx < end; ++idx) {
std::string exposure_col = Rcpp::as<std::string>(exposure_vec[idx]);
std::string mediator_col = Rcpp::as<std::string>(mediator_vec[idx]);
Expand Down Expand Up @@ -219,14 +224,40 @@ struct MediationWorker : public Worker {
<< percentile_cpp(tot_samples, 2.5) << "," << percentile_cpp(tot_samples, 97.5) << ","
<< p_value_cpp(tot_samples, 0.0) << "\n";

// Write results to file (thread-safe)
{
std::lock_guard<std::mutex> lock(file_mutex);
outfile << result.str();
}
// Add result to local buffer
local_buffer.push_back(result.str());

// If buffer is full, write to file
if (local_buffer.size() >= buffer_size) {
// Lock mutex and write buffer to file
{
std::lock_guard<std::mutex> lock(file_mutex);
std::ofstream outfile(output_file, std::ios::app);
if (!outfile.is_open()) {
throw std::runtime_error("Unable to open output file");
}
for (const auto& line : local_buffer) {
outfile << line;
}
outfile.close();
}
// Clear local buffer
local_buffer.clear();
}
}

outfile.close();
// After processing all combinations, write any remaining results in the buffer
if (!local_buffer.empty()) {
std::lock_guard<std::mutex> lock(file_mutex);
std::ofstream outfile(output_file, std::ios::app);
if (!outfile.is_open()) {
throw std::runtime_error("Unable to open output file");
}
for (const auto& line : local_buffer) {
outfile << line;
}
outfile.close();
}
} catch (const std::exception& e) {
Rcpp::stop(std::string("Error in worker: ") + e.what());
}
Expand Down
Binary file modified src/mediation_analysis.o
Binary file not shown.
101 changes: 61 additions & 40 deletions tests/testthat/test-mediation_analysis.R
Original file line number Diff line number Diff line change
@@ -1,102 +1,123 @@
# tests/testthat/test-mediation_analysis.R
test_that("mediation_analysis works correctly with default prefixes", {
# Create example data
set.seed(123)
test_data_default <- data.table(
test_data_default <- data.table::data.table(
CECUM1 = rnorm(50),
CECUM2 = rnorm(50),
SERUM1 = rnorm(50),
SERUM2 = rnorm(50),
CORTEX1 = rnorm(50),
CORTEX2 = rnorm(50)
)

# Temporary output CSV file
output_csv <- tempfile(fileext = ".csv")

# Use local_tempfile to create a temporary output CSV file
output_csv <- withr::local_tempfile(fileext = ".csv")
# Run mediation_analysis function
mediation_analysis(
data = test_data_default,
columns = list(exposure = c("CECUM"), mediator = c("SERUM"), outcome = c("CORTEX")),
nrep = 100, # Reduced for faster testing
output_file = output_csv,
num_threads = 2 # Adjust thread count for testing environment
output_file = output_csv, # Use the temporary file name
num_threads = 1 # Adjust thread count for testing environment
)

# Read results
results <- fread(output_csv)

results <- data.table::fread(output_csv)
# Expected result: 2 (CECUM) * 2 (SERUM) * 2 (CORTEX) = 8 combinations
expect_equal(nrow(results), 8)

# Remove temporary file
file.remove(output_csv)
})

test_that("mediation_analysis works correctly with custom prefixes", {
# Create example data
set.seed(123)
test_data_custom <- data.table(
test_data_custom <- data.table::data.table(
EXP1 = rnorm(50),
EXP2 = rnorm(50),
MED1 = rnorm(50),
MED2 = rnorm(50),
OUT1 = rnorm(50),
OUT2 = rnorm(50)
)

# Temporary output CSV file
output_csv <- tempfile(fileext = ".csv")

# Use local_tempfile to create a temporary output CSV file
output_csv <- withr::local_tempfile(fileext = ".csv")
# Run mediation_analysis function
mediation_analysis(
data = test_data_custom,
columns = list(exposure = c("EXP"), mediator = c("MED"), outcome = c("OUT")),
nrep = 100, # Reduced for faster testing
output_file = output_csv,
num_threads = 2 # Adjust thread count for testing environment
output_file = output_csv, # Use the temporary file name
num_threads = 1 # Adjust thread count for testing environment
)

# Read results
results <- fread(output_csv)

results <- data.table::fread(output_csv)
# Expected result: 2 (EXP) * 2 (MED) * 2 (OUT) = 8 combinations
expect_equal(nrow(results), 8)

# Remove temporary file
file.remove(output_csv)
})

test_that("mediation_analysis handles large datasets correctly", {
# Create large example data
set.seed(123)
test_data_large <- data.table(
test_data_large <- data.table::data.table(
CECUM1 = rnorm(1000),
CECUM2 = rnorm(1000),
SERUM1 = rnorm(1000),
SERUM2 = rnorm(1000),
CORTEX1 = rnorm(1000),
CORTEX2 = rnorm(1000)
)

# Temporary output CSV file
output_csv <- tempfile(fileext = ".csv")

# Use local_tempfile to create a temporary output CSV file
output_csv <- withr::local_tempfile(fileext = ".csv")
# Run mediation_analysis function
mediation_analysis(
data = test_data_large,
columns = list(exposure = c("CECUM"), mediator = c("SERUM"), outcome = c("CORTEX")),
nrep = 100, # Reduced for faster testing
output_file = output_csv,
num_threads = 2 # Adjust thread count for testing environment
output_file = output_csv, # Use the temporary file name
num_threads = 1 # Adjust thread count for testing environment
)

# Read results
results <- fread(output_csv)

results <- data.table::fread(output_csv)
# Expected result: 2 (CECUM) * 2 (SERUM) * 2 (CORTEX) = 8 combinations
expect_equal(nrow(results), 8)

# Remove temporary file
file.remove(output_csv)
})

# Verifies that running the analysis on more than one thread still produces
# one result row per exposure/mediator/outcome combination.
test_that("mediation_analysis works correctly with multiple threads", {
  # Create large example data
  set.seed(123)
  test_data_large <- data.table::data.table(
    CECUM1 = rnorm(1000),
    CECUM2 = rnorm(1000),
    SERUM1 = rnorm(1000),
    SERUM2 = rnorm(1000),
    CORTEX1 = rnorm(1000),
    CORTEX2 = rnorm(1000)
  )

  # Use local_tempfile to create a temporary output CSV file; withr removes it
  # automatically when the test finishes, so no manual cleanup is needed.
  output_csv <- withr::local_tempfile(fileext = ".csv")

  # Run mediation_analysis function
  mediation_analysis(
    data = test_data_large,
    columns = list(exposure = c("CECUM"), mediator = c("SERUM"), outcome = c("CORTEX")),
    nrep = 100, # Reduced for faster testing
    output_file = output_csv, # Use the temporary file name
    # CRAN policy allows at most two threads/cores during checks; two threads
    # are still enough to exercise the multi-threaded code path.
    num_threads = 2
  )

  # Read results
  results <- data.table::fread(output_csv)

  # Expected result: 2 (CECUM) * 2 (SERUM) * 2 (CORTEX) = 8 combinations
  expect_equal(nrow(results), 8)
})

0 comments on commit 5b55336

Please sign in to comment.