Remove need for async_trait
m1guelpf committed Dec 15, 2023
1 parent 0fc0954 commit 654a227
Showing 19 changed files with 113 additions and 60 deletions.
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default.

18 changes: 8 additions & 10 deletions README.md
@@ -6,17 +6,17 @@ It's output should be interchangeable with [Replicate's own Cog](https://github.

## Highlights

- 📦 **Docker containers without the pain.** Writing your own `Dockerfile` can be a bewildering process. With Cog, you define your environment [inside your Cargo.toml](#how-it-works) and it generates a Docker image with all the best practices: Nvidia base images, efficient caching of dependencies, minimal image sizes, sensible environment variable defaults, and so on.

- 🤬️ **No more CUDA hell.** Cog knows which CUDA/cuDNN/tch/tensorflow combos are compatible and will set it all up correctly for you.

- **Define the inputs and outputs for your model in Rust.** Then, Cog generates an OpenAPI schema and validates the inputs and outputs with JSONSchema.

- 🎁 **Automatic HTTP prediction server**: Your model's types are used to dynamically generate a RESTful HTTP API using [axum](https://github.com/tokio-rs/axum).

- ☁️ **Cloud storage.** Files can be read and written directly to Amazon S3 and Google Cloud Storage. (Coming soon.)

- 🚀 **Ready for production.** Deploy your model anywhere that Docker images run. Your own infrastructure, or [Replicate](https://replicate.com).

## How it works

@@ -35,7 +35,6 @@ Define how predictions are run on your model on your `main.rs`:

```rust
use anyhow::Result;
use async_trait::async_trait;
use cog_rust::Cog;
use schemars::JsonSchema;
use std::collections::HashMap;
@@ -55,7 +54,6 @@ struct ResnetModel {
model: Box<dyn ModuleT + Send>,
}

#[async_trait]
impl Cog for ResnetModel {
type Request = ModelRequest;
type Response = HashMap<String, f64>;
@@ -118,8 +116,8 @@ As the non-Python ML ecosystem slowly flourishes (see [whisper.cpp](https://gith

## Prerequisites

- **macOS, Linux or Windows**. Cog works anywhere Rust works.
- **Docker**. Cog uses Docker to create a container for your model. You'll need to [install Docker](https://docs.docker.com/get-docker/) before you can run Cog.

## Install

2 changes: 1 addition & 1 deletion cli/Cargo.toml
@@ -23,7 +23,7 @@ mime_guess = "2.0.4"
serde_json = "1.0.96"
webbrowser = "0.8.10"
cargo_metadata = "0.15.4"
cog-core = { path = "../core", version = "0.1.0" }
cog-core = { path = "../core", version = "0.2.0" }
clap = { version = "4.3.3", features = ["derive"] }
tokio = { version = "1.28.2", features = ["full"] }
reqwest = { version = "0.11.18", features = ["json"] }
4 changes: 2 additions & 2 deletions cli/src/commands/predict.rs
@@ -32,11 +32,11 @@ pub async fn handle(
let mut predictor = Predictor::new(image);

predictor.start().await;
predict_individual_inputs(&mut predictor, inputs, output).await;
predict_individual_inputs(&predictor, inputs, output).await;
}

async fn predict_individual_inputs(
predictor: &mut Predictor,
predictor: &Predictor,
inputs: Option<Vec<String>>,
mut output: Option<PathBuf>,
) {
4 changes: 2 additions & 2 deletions cli/src/docker/predictor.rs
@@ -36,7 +36,7 @@ impl Predictor {
serde_json::from_str::<SchemaObject>(
image
.as_array()
.and_then(|v| v.get(0))
.and_then(|v| v.first())
.and_then(|v| {
v.get("Config")
.and_then(Value::as_object)
@@ -143,7 +143,7 @@ impl Predictor {

let state = container
.as_array()
.and_then(|v| v.get(0))
.and_then(|v| v.first())
.and_then(Value::as_object)
.and_then(|v| v.get("State"))
.and_then(Value::as_object)
3 changes: 1 addition & 2 deletions core/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "cog-core"
version = "0.1.1"
version = "0.2.0"
description = "Core types and traits for rust-cog, a Rust toolkit for machine learning."
readme = { workspace = true }
edition = { workspace = true }
@@ -15,7 +15,6 @@ anyhow = "1.0.71"
serde = "1.0.164"
thiserror = "1.0.40"
serde_json = "1.0.96"
async-trait = "0.1.68"
url = { version = "2.4.0", features = ["serde"] }
tokio = { version = "1.31.0", features = ["rt"] }
chrono = { version = "0.4.26", features = ["serde"] }
3 changes: 1 addition & 2 deletions core/src/http.rs
@@ -1,9 +1,8 @@
use std::collections::HashMap;

use chrono::{DateTime, Utc};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use url::Url;

#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
13 changes: 7 additions & 6 deletions core/src/spec.rs
@@ -1,14 +1,13 @@
use anyhow::Result;
use async_trait::async_trait;
use core::fmt::Debug;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
use std::future::Future;

use crate::http::Request;

/// A Cog model
#[async_trait]
pub trait Cog: Sized + Send {
type Request: DeserializeOwned + JsonSchema + Send;
type Response: CogResponse + Debug + JsonSchema;
@@ -18,20 +17,22 @@ pub trait Cog: Sized + Send {
/// # Errors
///
/// Returns an error if setup fails.
async fn setup() -> Result<Self>;
fn setup() -> impl Future<Output = Result<Self>> + Send;

/// Run a prediction on the model
///
/// # Errors
///
/// Returns an error if the prediction fails.
fn predict(&self, input: Self::Request) -> Result<Self::Response>;
}

/// A response from a Cog model
#[async_trait]
pub trait CogResponse: Send {
/// Convert the response into a JSON value
async fn into_response(self, request: Request) -> Result<Value>;
fn into_response(self, request: Request) -> impl Future<Output = Result<Value>> + Send;
}

#[async_trait]
impl<T: Serialize + Send + 'static> CogResponse for T {
async fn into_response(self, _: Request) -> Result<Value> {
// We use spawn_blocking here to allow blocking code in serde Serialize impls (used in `Path`, for example).
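
The spec.rs change leans on return-position `impl Trait` in traits (stable since Rust 1.75): declaring `fn ... -> impl Future<Output = ...> + Send` keeps the explicit `Send` bound that `#[async_trait]` used to provide, while implementations can still be written as plain `async fn`s. The following is a minimal, self-contained sketch of that pattern, not code from this commit: the `Request` stand-in is hypothetical, and the `spawn_blocking` body is only an assumption based on the comment that survives in the hunk above.

```rust
use std::future::Future;

use anyhow::Result;
use serde::Serialize;
use serde_json::Value;

// Hypothetical stand-in so the sketch compiles on its own.
pub struct Request;

/// Declared without `#[async_trait]`: the `Send` bound on the returned
/// future is spelled out directly in the trait signature.
pub trait CogResponse: Send {
    fn into_response(self, request: Request) -> impl Future<Output = Result<Value>> + Send;
}

/// The blanket impl can still use `async fn`, which satisfies the
/// `impl Future<...> + Send` return type on Rust 1.75+.
impl<T: Serialize + Send + 'static> CogResponse for T {
    async fn into_response(self, _: Request) -> Result<Value> {
        // spawn_blocking keeps potentially blocking Serialize impls off the
        // async executor, as the comment in core/src/spec.rs notes.
        Ok(tokio::task::spawn_blocking(move || serde_json::to_value(self)).await??)
    }
}
```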
1 change: 0 additions & 1 deletion examples/blur/Cargo.toml
@@ -13,7 +13,6 @@ serde = "1.0.163"
image = "0.24.6"
anyhow = "1.0.71"
schemars = "0.8.12"
async-trait = "0.1.68"
cog-rust = { path = "../../lib" }
tokio = { version = "1.28.2", features = ["full"] }

2 changes: 0 additions & 2 deletions examples/blur/src/main.rs
@@ -1,5 +1,4 @@
use anyhow::Result;
use async_trait::async_trait;
use cog_rust::{Cog, Path};
use schemars::JsonSchema;

@@ -13,7 +12,6 @@ struct ModelRequest {

struct BlurModel {}

#[async_trait]
impl Cog for BlurModel {
type Request = ModelRequest;
type Response = Path;
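
With the attribute and the `async-trait` dependency gone from the examples, a model only needs a plain trait impl. A rough sketch of what a minimal model might look like after this change: the `EchoModel` type and its request are hypothetical, it assumes `setup` and `predict` are the only required items of `Cog`, and it assumes `serde`/`schemars` derives as used by the blur example.

```rust
use anyhow::Result;
use cog_rust::Cog;
use schemars::JsonSchema;
use serde::Deserialize;

#[derive(Deserialize, JsonSchema)]
struct EchoRequest {
    text: String,
}

struct EchoModel;

impl Cog for EchoModel {
    type Request = EchoRequest;
    type Response = String;

    // A plain `async fn` now satisfies `fn setup() -> impl Future<...> + Send`,
    // so no `#[async_trait]` attribute or extra import is needed.
    async fn setup() -> Result<Self> {
        Ok(Self)
    }

    fn predict(&self, input: Self::Request) -> Result<Self::Response> {
        Ok(input.text)
    }
}
```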