Skip to content

Commit

Permalink
batch update
Browse files Browse the repository at this point in the history
  • Loading branch information
jaytimm committed Aug 9, 2024
1 parent 1e0e267 commit 6b52f24
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 16 deletions.
9 changes: 6 additions & 3 deletions R/hollr.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#' @param openai_organization The organization ID for the OpenAI API (default is NULL).
#' @param cores The number of cores to use for parallel processing (default is 1).
#' @param batch_size The number of messages to process in each batch (default is 1, only for local models).
#' @param extract_json A logical indicating whether to extract and clean JSON strings from the response (default is TRUE).
#' @return A data.table containing the generated text and metadata.
#' @examples
#' \dontrun{
Expand Down Expand Up @@ -49,7 +50,8 @@ hollr <- function(id,
openai_api_key = Sys.getenv("OPENAI_API_KEY"),
openai_organization = NULL,
cores = 1,
batch_size = 1) {
batch_size = 1,
extract_json = TRUE) {

# Determine if the model is OpenAI or local
is_openai_model <- grepl("gpt-3.5-turbo|gpt-4|gpt-4o|gpt-4o-mini", model, ignore.case = TRUE)
Expand All @@ -65,7 +67,8 @@ hollr <- function(id,
max_new_tokens = max_new_tokens,
max_length = max_length,
system_message = system_message,
batch_size = batch_size))
batch_size = batch_size,
extract_json = extract_json))
}

# Prepare data
Expand Down Expand Up @@ -108,7 +111,7 @@ hollr <- function(id,
force_json,
max_attempts)

cleaned_response <- gsub("^```json|```$", "", validation_result$response)
cleaned_response <- gsub("^\\s*\\[\\s*\\{.*\\}\\s*\\]\\s*$", "", validation_result$response)

list(id = row$id,
annotator_id = row$annotator_id,
Expand Down
73 changes: 62 additions & 11 deletions R/hollr_local_batches.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#' @param max_length The maximum length of the input prompt (default is NULL).
#' @param system_message The message provided by the system (default is '').
#' @param batch_size The number of messages to process in each batch (default is 10).
#' @param extract_json A logical indicating whether to extract and clean JSON strings from the response (default is TRUE).
#' @return A list containing the generated responses for each batch.
#' @examples
#' \dontrun{
Expand All @@ -31,7 +32,8 @@ hollr_local_batches <- function(id,
max_new_tokens = 100,
max_length = NULL,
system_message = '',
batch_size = 10) {
batch_size = 10,
extract_json = TRUE) {
# Prepare data
text_df <- data.table::data.table(id = rep(id, annotators),
annotator_id = .generate_random_ids(annotators),
Expand All @@ -50,13 +52,8 @@ hollr_local_batches <- function(id,
# Create batches
batches <- split(text_df, ceiling(seq_along(text_df$user_message) / batch_size))

# Initialize a list to collect responses
all_responses <- list()

# Process each batch
for (i in seq_along(batches)) {
batch <- batches[[i]]

# Define the function to process each batch
process_batch <- function(batch) {
# Calculate the input length and set max_length if not provided
max_input_length <- max(nchar(batch$user_message))
if (is.null(max_length)) {
Expand All @@ -70,9 +67,63 @@ hollr_local_batches <- function(id,
max_length,
max_new_tokens)

# Collect responses ensuring they are properly structured
all_responses <- c(all_responses, list(response))
return(response)
}

return(all_responses)
# Process each batch using pbapply::pblapply
all_responses <- pbapply::pblapply(batches, process_batch)

# Convert the list of responses to a data.table
responses_sens <- data.table::data.table(id = text_df$id,
annotator_id = text_df$annotator_id,
response = unlist(all_responses))

if (extract_json) {

# Apply the jaxsn_compaction function to each element in the response column
responses_sens[, cleaned_json := sapply(response, .jaxsn_compaction, USE.NAMES = FALSE)]
} else {
responses_sens[, cleaned_json := response]
}

return(responses_sens)
}


#' Extract the first balanced JSON object or array from a string
#'
#' Locates the first `{` or `[` in `text` and returns the substring that runs
#' from that bracket to the bracket that closes it. Nested brackets of the
#' same kind are tracked with a depth counter, and brackets that appear inside
#' double-quoted JSON string literals (including after backslash escapes) are
#' ignored so they cannot end the match early.
#'
#' @param text A single character string, e.g. a raw model response.
#' @return A length-one character vector holding the extracted JSON text, or
#'   `NA_character_` when `text` is missing, contains no opening bracket, or
#'   the first opening bracket is never closed.
#' @noRd
.jaxsn_compaction <- function(text) {
  # Missing or non-scalar input cannot hold a JSON payload
  if (length(text) != 1L || is.na(text)) {
    return(NA_character_)
  }

  # Find the first occurrence of '{' or '['
  start <- as.integer(regexpr("\\{|\\[", text))
  if (start == -1L) {
    return(NA_character_)
  }

  # Split once into single characters instead of calling substring() per index
  chars <- strsplit(substring(text, start), "")[[1L]]

  opening_bracket <- chars[1L]
  closing_bracket <- if (opening_bracket == "{") "}" else "]"

  level <- 0L
  in_string <- FALSE
  escaped <- FALSE

  for (i in seq_along(chars)) {
    char <- chars[i]

    if (in_string) {
      # Inside a JSON string only an unescaped double quote is significant
      if (escaped) {
        escaped <- FALSE
      } else if (char == "\\") {
        escaped <- TRUE
      } else if (char == "\"") {
        in_string <- FALSE
      }
    } else if (char == "\"") {
      in_string <- TRUE
    } else if (char == opening_bracket) {
      level <- level + 1L
    } else if (char == closing_bracket) {
      level <- level - 1L
      if (level == 0L) {
        # Balanced: return everything from the opening to this closing bracket
        return(paste(chars[seq_len(i)], collapse = ""))
      }
    }
  }

  # The opening bracket was never closed (e.g. a truncated generation)
  NA_character_
}
5 changes: 4 additions & 1 deletion man/hollr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/hollr_local_batches.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6b52f24

Please sign in to comment.