-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
288 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
## Quick Start | ||
|
||
```shell | ||
cargo run -r -F cuda --example owlv2 -- --device cuda:0 --dtype fp16 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
use anyhow::Result; | ||
use usls::{models::OWLv2, Annotator, DataLoader, Options}; | ||
|
||
#[derive(argh::FromArgs)] | ||
/// Example | ||
struct Args { | ||
/// dtype | ||
#[argh(option, default = "String::from(\"auto\")")] | ||
dtype: String, | ||
|
||
/// device | ||
#[argh(option, default = "String::from(\"cpu:0\")")] | ||
device: String, | ||
|
||
/// source image | ||
#[argh(option, default = "vec![String::from(\"./assets/bus.jpg\")]")] | ||
source: Vec<String>, | ||
|
||
/// open class names | ||
#[argh( | ||
option, | ||
default = "vec![ | ||
String::from(\"person\"), | ||
String::from(\"hand\"), | ||
String::from(\"shoes\"), | ||
String::from(\"bus\"), | ||
String::from(\"car\"), | ||
String::from(\"dog\"), | ||
String::from(\"cat\"), | ||
String::from(\"sign\"), | ||
String::from(\"tie\"), | ||
String::from(\"monitor\"), | ||
String::from(\"glasses\"), | ||
String::from(\"tree\"), | ||
String::from(\"head\"), | ||
]" | ||
)] | ||
labels: Vec<String>, | ||
} | ||
|
||
fn main() -> Result<()> { | ||
tracing_subscriber::fmt() | ||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) | ||
.with_timer(tracing_subscriber::fmt::time::ChronoLocal::rfc_3339()) | ||
.init(); | ||
let args: Args = argh::from_env(); | ||
|
||
// options | ||
let options = Options::owlv2_base_ensemble() | ||
// owlv2_base() | ||
.with_model_dtype(args.dtype.as_str().try_into()?) | ||
.with_model_device(args.device.as_str().try_into()?) | ||
.with_class_names(&args.labels.iter().map(|x| x.as_str()).collect::<Vec<_>>()) | ||
.commit()?; | ||
let mut model = OWLv2::new(options)?; | ||
|
||
// load | ||
let xs = DataLoader::try_read_batch(&args.source)?; | ||
|
||
// run | ||
let ys = model.forward(&xs)?; | ||
|
||
// annotate | ||
let annotator = Annotator::default() | ||
.with_bboxes_thickness(3) | ||
.with_saveout(model.spec()); | ||
annotator.annotate(&xs, &ys); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# OWLv2: Scaling Open-Vocabulary Object Detection | ||
|
||
## Official Repository | ||
|
||
The official repository can be found on: [Hugging Face](https://huggingface.co/google/owlv2-base-patch16-ensemble) | ||
|
||
## Example | ||
|
||
Refer to the [example](../../../examples/owlv2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/// Model configuration for `OWLv2` | ||
impl crate::Options { | ||
pub fn owlv2() -> Self { | ||
Self::default() | ||
.with_model_name("owlv2") | ||
.with_model_kind(crate::Kind::VisionLanguage) | ||
// 1st & 3rd: text | ||
.with_model_ixx(0, 0, (1, 1, 1).into()) // TODO | ||
.with_model_ixx(0, 1, 1.into()) | ||
.with_model_ixx(2, 0, (1, 1, 1).into()) | ||
.with_model_ixx(2, 1, 1.into()) | ||
.with_model_max_length(16) | ||
// 2nd: image | ||
.with_model_ixx(1, 0, (1, 1, 1).into()) | ||
.with_model_ixx(1, 1, 3.into()) | ||
.with_model_ixx(1, 2, 960.into()) | ||
.with_model_ixx(1, 3, 960.into()) | ||
.with_image_mean(&[0.48145466, 0.4578275, 0.40821073]) | ||
.with_image_std(&[0.26862954, 0.261_302_6, 0.275_777_1]) | ||
.with_resize_mode(crate::ResizeMode::FitAdaptive) | ||
.with_normalize(true) | ||
.with_class_confs(&[0.1]) | ||
.with_model_num_dry_run(0) | ||
} | ||
|
||
pub fn owlv2_base() -> Self { | ||
Self::owlv2().with_model_file("base-patch16.onnx") | ||
} | ||
|
||
pub fn owlv2_base_ensemble() -> Self { | ||
Self::owlv2().with_model_file("base-patch16-ensemble.onnx") | ||
} | ||
|
||
pub fn owlv2_base_ft() -> Self { | ||
Self::owlv2().with_model_file("base-patch16-ft.onnx") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
use aksr::Builder; | ||
use anyhow::Result; | ||
use image::DynamicImage; | ||
use ndarray::{s, Axis}; | ||
use rayon::prelude::*; | ||
|
||
use crate::{elapsed, Bbox, DynConf, Engine, Options, Processor, Ts, Xs, Ys, X, Y}; | ||
|
||
#[derive(Debug, Builder)] | ||
pub struct OWLv2 { | ||
engine: Engine, | ||
height: usize, | ||
width: usize, | ||
batch: usize, | ||
names: Vec<String>, | ||
names_with_prompt: Vec<String>, | ||
confs: DynConf, | ||
ts: Ts, | ||
processor: Processor, | ||
spec: String, | ||
input_ids: X, | ||
attention_mask: X, | ||
} | ||
|
||
impl OWLv2 { | ||
pub fn new(options: Options) -> Result<Self> { | ||
let engine = options.to_engine()?; | ||
let (batch, height, width, ts) = ( | ||
engine.batch().opt(), | ||
engine.try_height().unwrap_or(&960.into()).opt(), | ||
engine.try_width().unwrap_or(&960.into()).opt(), | ||
engine.ts.clone(), | ||
); | ||
let spec = engine.spec().to_owned(); | ||
let processor = options | ||
.to_processor()? | ||
.with_image_width(width as _) | ||
.with_image_height(height as _); | ||
let names: Vec<String> = options | ||
.class_names() | ||
.expect("No class names specified.") | ||
.iter() | ||
.map(|x| x.to_string()) | ||
.collect(); | ||
let names_with_prompt: Vec<String> = | ||
names.iter().map(|x| format!("a photo of {}", x)).collect(); | ||
let n = names.len(); | ||
let confs = DynConf::new(options.class_confs(), n); | ||
let input_ids: Vec<f32> = processor | ||
.encode_texts_ids( | ||
&names_with_prompt | ||
.iter() | ||
.map(|x| x.as_str()) | ||
.collect::<Vec<_>>(), | ||
false, | ||
)? | ||
.into_iter() | ||
.flatten() | ||
.collect(); | ||
let input_ids: X = ndarray::Array2::from_shape_vec((n, input_ids.len() / n), input_ids)? | ||
.into_dyn() | ||
.into(); | ||
let attention_mask = X::ones_like(&input_ids); | ||
|
||
Ok(Self { | ||
engine, | ||
height, | ||
width, | ||
batch, | ||
spec, | ||
names, | ||
names_with_prompt, | ||
confs, | ||
ts, | ||
processor, | ||
input_ids, | ||
attention_mask, | ||
}) | ||
} | ||
|
||
fn preprocess(&mut self, xs: &[DynamicImage]) -> Result<Xs> { | ||
let image_embeddings = self.processor.process_images(xs)?; | ||
let xs = Xs::from(vec![ | ||
self.input_ids.clone(), | ||
image_embeddings, | ||
self.attention_mask.clone(), | ||
]); | ||
|
||
Ok(xs) | ||
} | ||
|
||
fn inference(&mut self, xs: Xs) -> Result<Xs> { | ||
self.engine.run(xs) | ||
} | ||
|
||
pub fn forward(&mut self, xs: &[DynamicImage]) -> Result<Ys> { | ||
let ys = elapsed!("preprocess", self.ts, { self.preprocess(xs)? }); | ||
let ys = elapsed!("inference", self.ts, { self.inference(ys)? }); | ||
let ys = elapsed!("postprocess", self.ts, { self.postprocess(ys)? }); | ||
|
||
Ok(ys) | ||
} | ||
|
||
fn postprocess(&mut self, xs: Xs) -> Result<Ys> { | ||
let ys: Vec<Y> = xs[0] | ||
.axis_iter(Axis(0)) | ||
.into_par_iter() | ||
.zip(xs[1].axis_iter(Axis(0)).into_par_iter()) | ||
.enumerate() | ||
.filter_map(|(idx, (clss, bboxes))| { | ||
let (image_height, image_width) = self.processor.image0s_size[idx]; | ||
let ratio = image_height.max(image_width) as f32; | ||
let y_bboxes: Vec<Bbox> = clss | ||
.axis_iter(Axis(0)) | ||
.into_par_iter() | ||
.enumerate() | ||
.filter_map(|(i, clss_)| { | ||
let (class_id, &confidence) = clss_ | ||
.into_iter() | ||
.enumerate() | ||
.max_by(|a, b| a.1.total_cmp(b.1))?; | ||
|
||
let confidence = 1. / ((-confidence).exp() + 1.); | ||
if confidence < self.confs[class_id] { | ||
return None; | ||
} | ||
|
||
let bbox = bboxes.slice(s![i, ..]).mapv(|x| x * ratio); | ||
let (x, y, w, h) = ( | ||
(bbox[0] - bbox[2] / 2.).max(0.0f32), | ||
(bbox[1] - bbox[3] / 2.).max(0.0f32), | ||
bbox[2], | ||
bbox[3], | ||
); | ||
|
||
Some( | ||
Bbox::default() | ||
.with_xywh(x, y, w, h) | ||
.with_confidence(confidence) | ||
.with_id(class_id as isize) | ||
.with_name(&self.names[class_id]), | ||
) | ||
}) | ||
.collect(); | ||
|
||
Some(Y::default().with_bboxes(&y_bboxes)) | ||
}) | ||
.collect(); | ||
|
||
Ok(ys.into()) | ||
} | ||
|
||
pub fn summary(&mut self) { | ||
self.ts.summary(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
mod config; | ||
mod r#impl; | ||
|
||
pub use r#impl::*; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters