Merge pull request #177 from gkumbhat/fix_content_standalone_detection
🐛 Fix content standalone endpoint response object
gkumbhat authored Aug 23, 2024
2 parents de0aa8c + 8ac2ba3 commit 52bc871
Showing 2 changed files with 96 additions and 8 deletions.
7 changes: 5 additions & 2 deletions src/models.rs
@@ -21,7 +21,10 @@ use std::collections::HashMap;
 
 use serde::{Deserialize, Serialize};
 
-use crate::{clients::detector::ContextType, pb};
+use crate::{
+    clients::detector::{ContentAnalysisResponse, ContextType},
+    pb,
+};
 
 /// Parameters relevant to each detector
 #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
@@ -351,7 +354,7 @@ impl TextContentDetectionHttpRequest {
 #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
 pub struct TextContentDetectionResult {
     /// Detection results
-    pub detections: Vec<TokenClassificationResult>,
+    pub detections: Vec<ContentAnalysisResponse>,
 }
 /// Streaming classification result on text produced by a text generation model, containing
 /// information from the original text generation output as well as the result of
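With this change, the standalone content-detection endpoint returns the detector client's own response type instead of TokenClassificationResult. A minimal sketch of how the two types fit together, assuming serde with the derive feature; the field names on ContentAnalysisResponse below are illustrative guesses at a typical content-analysis shape, not copied from clients::detector:

    use serde::{Deserialize, Serialize};

    // Hypothetical stand-in for clients::detector::ContentAnalysisResponse;
    // the real definition may differ.
    #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct ContentAnalysisResponse {
        pub start: usize,           // span start within the analyzed text
        pub end: usize,             // span end within the analyzed text
        pub text: String,           // the flagged span itself
        pub detection: String,      // detector-specific label
        pub detection_type: String, // detector family, e.g. a category name
        pub score: f64,             // confidence, compared against the threshold
    }

    // After this commit, the standalone endpoint's result wraps these directly:
    #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
    pub struct TextContentDetectionResult {
        /// Detection results
        pub detections: Vec<ContentAnalysisResponse>,
    }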
97 changes: 91 additions & 6 deletions src/orchestrator/unary.rs
@@ -30,7 +30,7 @@ use super::{
 };
 use crate::{
     clients::detector::{
-        ContentAnalysisRequest, ContextDocsDetectionRequest, ContextType,
+        ContentAnalysisRequest, ContentAnalysisResponse, ContextDocsDetectionRequest, ContextType,
         GenerationDetectionRequest,
     },
     models::{
@@ -221,19 +221,55 @@ impl Orchestrator {
             request_id = ?task.request_id,
             "handling text content detection task"
         );
 
         let ctx = self.ctx.clone();
         let task_handle = tokio::spawn(async move {
             let content = task.content.clone();
+            // No masking applied, so offset change is 0
+            let offset: usize = 0;
+            let text_with_offsets = [(offset, content)].to_vec();
+
             let detectors = task.detectors.clone();
+
+            let chunker_ids = get_chunker_ids(&ctx, &detectors)?;
+            let chunks = chunk_task(&ctx, chunker_ids, text_with_offsets).await?;
+
             // Call detectors
-            let detections = input_detection_task(&ctx, &detectors, content.clone(), None).await?;
-            debug!(?detections);
+            let detections = try_join_all(
+                task.detectors
+                    .iter()
+                    .map(|(detector_id, detector_params)| {
+                        let ctx = ctx.clone();
+                        let detector_id = detector_id.clone();
+                        let detector_params = detector_params.clone();
+                        let detector_config = ctx.config.detectors.get(&detector_id).unwrap();
+
+                        let chunker_id = detector_config.chunker_id.as_str();
+
+                        let default_threshold = detector_config.default_threshold;
+
+                        let chunk = chunks.get(chunker_id).unwrap().clone();
+
+                        async move {
+                            detect_content(
+                                ctx,
+                                detector_id,
+                                default_threshold,
+                                detector_params,
+                                chunk,
+                            )
+                            .await
+                        }
+                    })
+                    .collect::<Vec<_>>(),
+            )
+            .await?
+            .into_iter()
+            .flatten()
+            .collect::<Vec<_>>();
 
             // Send result with detections
-            Ok(TextContentDetectionResult {
-                detections: detections.unwrap_or(vec![]),
-            })
+            Ok(TextContentDetectionResult { detections })
         });
         match task_handle.await {
             // Task completed successfully
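The new handler fans out one future per configured detector and joins them concurrently. A minimal, self-contained sketch of that fan-out pattern, assuming the futures and tokio crates; the detector IDs and the check function here are hypothetical stand-ins for detect_content:

    use futures::future::try_join_all;

    // Hypothetical per-detector call: returns that detector's detections,
    // or an error. Stands in for detect_content above.
    async fn check(id: u32) -> Result<Vec<String>, String> {
        if id == 0 {
            return Err("detector 0 failed".to_string());
        }
        Ok(vec![format!("detection from detector {id}")])
    }

    #[tokio::main]
    async fn main() -> Result<(), String> {
        let detections: Vec<String> = try_join_all((1..=3).map(check))
            .await? // the first Err cancels the remaining futures and propagates
            .into_iter()
            .flatten() // Vec<Vec<String>> -> Vec<String>, as in the commit
            .collect();
        println!("{detections:?}");
        Ok(())
    }

try_join_all runs all futures concurrently and fails fast, so a single failing detector aborts the whole task, which the `?` after `.await` then surfaces as the task's error.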
@@ -519,6 +555,55 @@ pub async fn detect(
     Ok::<Vec<TokenClassificationResult>, Error>(results)
 }
 
+/// Sends a request to a detector service and applies threshold.
+/// TODO: Cleanup by removing duplicate code and merging it with above `detect` function
+#[instrument(skip_all)]
+pub async fn detect_content(
+    ctx: Arc<Context>,
+    detector_id: String,
+    default_threshold: f64,
+    detector_params: DetectorParams,
+    chunks: Vec<Chunk>,
+) -> Result<Vec<ContentAnalysisResponse>, Error> {
+    let detector_id = detector_id.clone();
+    let threshold = detector_params.threshold().unwrap_or(default_threshold);
+    let contents: Vec<_> = chunks.iter().map(|chunk| chunk.text.clone()).collect();
+    let response = if contents.is_empty() {
+        // skip detector call as contents is empty
+        Vec::default()
+    } else {
+        let request = ContentAnalysisRequest::new(contents);
+        debug!(%detector_id, ?request, "sending detector request");
+        ctx.detector_client
+            .text_contents(&detector_id, request)
+            .await
+            .map_err(|error| {
+                debug!(%detector_id, ?error, "error received from detector");
+                Error::DetectorRequestFailed {
+                    id: detector_id.clone(),
+                    error,
+                }
+            })?
+    };
+    debug!(%detector_id, ?response, "received detector response");
+    if chunks.len() != response.len() {
+        return Err(Error::Other(format!(
+            "Detector {detector_id} did not return expected number of responses"
+        )));
+    }
+    let results = chunks
+        .into_iter()
+        .zip(response)
+        .flat_map(|(_chunk, response)| {
+            response
+                .into_iter()
+                .filter_map(|resp| (resp.score >= threshold).then_some(resp))
+                .collect::<Vec<_>>()
+        })
+        .collect::<Vec<_>>();
+    Ok::<Vec<ContentAnalysisResponse>, Error>(results)
+}
+
 /// Calls a detector that implements the /api/v1/text/generation endpoint
 pub async fn detect_for_generation(
     ctx: Arc<Context>,
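detect_content ends by keeping only detections whose score clears the per-request threshold (falling back to the detector's configured default). A standalone illustration of that thresholding step, with a hypothetical Analysis type standing in for ContentAnalysisResponse:

    // Illustration only: Analysis is a made-up stand-in type.
    #[derive(Debug)]
    struct Analysis {
        detection: &'static str,
        score: f64,
    }

    fn main() {
        let threshold = 0.5;
        // One inner Vec per chunk, as returned by the detector client.
        let response = vec![
            vec![Analysis { detection: "pii", score: 0.9 }],
            vec![Analysis { detection: "hap", score: 0.2 }],
        ];
        let results: Vec<Analysis> = response
            .into_iter()
            .flat_map(|per_chunk| {
                per_chunk
                    .into_iter()
                    // bool::then_some keeps the item only when the predicate holds
                    .filter_map(|r| (r.score >= threshold).then_some(r))
            })
            .collect();
        // Only the 0.9-score "pii" detection survives the 0.5 threshold.
        println!("{results:?}");
    }

The chunks.len() != response.len() guard before this step matters because the zip would otherwise silently drop detections if a detector returned fewer per-chunk result lists than chunks sent.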
