diff --git a/crates/wasm-functions/Cargo.toml b/crates/wasm-functions/Cargo.toml new file mode 100644 index 0000000..713a1f9 --- /dev/null +++ b/crates/wasm-functions/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "wasm-functions" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +arrow = "51.0.0" +wasm-udfs = { version = "0.1.0", path = "../wasm-udfs" } + +[dev-dependencies] +wasm-bindgen-test = "0.3.43" diff --git a/crates/wasm-functions/src/lib.rs b/crates/wasm-functions/src/lib.rs new file mode 100644 index 0000000..8dd11ea --- /dev/null +++ b/crates/wasm-functions/src/lib.rs @@ -0,0 +1,84 @@ +use arrow::array::{Array, ArrayRef, Float64Array}; +use arrow::error::ArrowError; +use std::sync::Arc; +use wasm_udfs::*; + +// ```bash +// cargo install wasm-bindgen-cli +// ``` + +// ```bash +// cargo test --target wasm32-unknown-unknown +// ``` + +// expose function f1 as external function +// add required bindgen, and required serialization/deserialization +export_udf_function!(f1); +// function should return error +export_udf_function!(f_return_error); +// function should panic +// export_udf_function!(f_panic); +// function should return arrow error +export_udf_function!(f_return_arrow_error); + +/// standard datafusion udf ... kind of +/// should return ArrayRef or ArrowError +fn f1(args: &[ArrayRef]) -> Result { + assert_eq!(2, args.len()); + + let base = args[0] + .as_any() + .downcast_ref::() + .expect("cast 0 failed"); + let exponent = args[1] + .as_any() + .downcast_ref::() + .expect("cast 1 failed"); + + assert_eq!(exponent.len(), base.len()); + + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| match (base, exponent) { + (Some(base), Some(exponent)) => Some(base.powf(exponent)), + _ => None, + }) + .collect::(); + + // TODO: do we need arc here? + // only reason to stay to keep api same + // like datafusion udf's + Ok(Arc::new(array)) +} +/// function returns String Error +fn f_return_error(_args: &[ArrayRef]) -> Result { + Err("wasm function returned error".to_string()) +} + +/// function returns error +fn f_return_arrow_error(_args: &[ArrayRef]) -> Result { + Err(ArrowError::DivideByZero) +} + +// fn f_panic(_args: &[ArrayRef]) -> Result { +// panic!("wasm function panicked") +// } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ArrayRef, Float64Array}; + + use std::sync::Arc; + + #[wasm_bindgen_test::wasm_bindgen_test] + fn test_f1() { + let a: ArrayRef = Arc::new(Float64Array::from(vec![2.1, 3.1, 4.1, 5.1])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])); + let args = vec![a, b]; + let result = f1(&args).unwrap(); + + assert_eq!(4, result.len()) + } +} diff --git a/crates/wasm-udfs/Cargo.toml b/crates/wasm-udfs/Cargo.toml new file mode 100644 index 0000000..867c110 --- /dev/null +++ b/crates/wasm-udfs/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "wasm-udfs" +version = "0.1.0" +edition = "2021" + +[dependencies] +arrow = "51.0.0" +paste = "1.0.15" +wasmedge-bindgen = "0.4.1" +wasmedge-bindgen-macro = "0.4.1" diff --git a/crates/wasm-udfs/src/lib.rs b/crates/wasm-udfs/src/lib.rs new file mode 100644 index 0000000..ed22397 --- /dev/null +++ b/crates/wasm-udfs/src/lib.rs @@ -0,0 +1,67 @@ +use arrow::{ + array::{Array, ArrayRef, RecordBatch}, + datatypes::{Field, Schema, SchemaRef}, +}; +pub use paste; +use std::sync::Arc; +pub use wasmedge_bindgen; +pub use wasmedge_bindgen_macro; + +/// packs slice of arrays to a batch +/// with schema generated from array types +pub fn pack_array(args: &[ArrayRef]) -> RecordBatch { + let fields = args + .iter() + .enumerate() + .map(|(i, f)| Field::new(format!("c{}", i), f.data_type().clone(), false)) + .collect::>(); + + let schema = Arc::new(Schema::new(fields)); + + RecordBatch::try_new(schema, args.to_vec()).unwrap() +} + +/// packs slice of arrays to a batch +/// with external schema +pub fn pack_array_with_schema(args: &[ArrayRef], schema: SchemaRef) -> RecordBatch { + RecordBatch::try_new(schema, args.to_vec()).unwrap() +} + +/// creates a arrow ipc blob +pub fn to_ipc(schema: &Schema, batch: RecordBatch) -> Vec { + let blob = vec![]; + let mut stream_writer = arrow::ipc::writer::StreamWriter::try_new(blob, schema).unwrap(); + stream_writer.write(&batch).unwrap(); + + stream_writer.into_inner().unwrap() +} + +/// creates arrow arrays from arrow ipc blob +pub fn from_ipc(payload: &[u8]) -> RecordBatch { + let mut batch = arrow::ipc::reader::StreamReader::try_new(payload, None).unwrap(); + batch.next().unwrap().unwrap() +} + +/// exports wasm function and performs all required +/// arrow ipc serialization/deserialization +/// +/// macro will create new function prefixed with `__wasm_udf_` +/// +// TODO: make this a proc macro maybe ? +#[macro_export] +macro_rules! export_udf_function { + ($name:ident) => { + paste::item! { + #[wasmedge_bindgen_macro::wasmedge_bindgen] + pub fn [<__wasm_udf_$name>](payload: Vec) -> Result,String> { + let args_batch = from_ipc(&payload); + let result = $name(args_batch.columns()); + // let batch = pack_array(&vec![result]); + // to_ipc(&batch.schema(), batch) + result.map(|result| pack_array(&vec![result])) + .map(|batch| to_ipc(&batch.schema(), batch)) + .map_err(|e| e.to_string()) + } + } + }; +} diff --git a/crates/wasmedge-factory/Cargo.toml b/crates/wasmedge-factory/Cargo.toml new file mode 100644 index 0000000..f80dfed --- /dev/null +++ b/crates/wasmedge-factory/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "wasmedge-factory" +version = "0.1.0" +edition = "2021" + +[dependencies] +arrow-udf-wasm = "0.2.2" +async-trait = "0.1.82" +datafusion = "38.0.0" +log = "0.4.22" +project-root = "0.2.2" +thiserror = "1.0.63" +tokio = "1.40.0" +wasm-udfs = { version = "0.1.0", path = "../wasm-udfs" } +# wasmedge-sdk = "0.13.2" +weak-table = "0.3.2" diff --git a/crates/wasmedge-factory/src/lib.rs b/crates/wasmedge-factory/src/lib.rs new file mode 100644 index 0000000..b7158bb --- /dev/null +++ b/crates/wasmedge-factory/src/lib.rs @@ -0,0 +1,452 @@ +use std::{ + path::Path, + sync::{Arc, Weak}, +}; + +use arrow_udf_wasm::Runtime; +use datafusion::{ + arrow::{ + array::{ArrayRef, Float64Array}, + datatypes::DataType, + }, + common::{cast::as_float64_array, exec_err}, + error::{DataFusionError, Result}, + execution::context::{FunctionFactory, RegisterFunction, SessionState}, + logical_expr::{ + create_udf, ColumnarValue, CreateFunction, DefinitionStatement, ScalarUDF, Volatility, + }, +}; +use thiserror::Error; +use tokio::sync::Mutex; +// use wasmedge_sdk::{config::ConfigBuilder, dock::VmDock, Module, VmBuilder}; +use weak_table::WeakValueHashMap; + +mod udf; + +// type ModuleCache = Arc>>>; + +fn test_udf() -> ScalarUDF { + // First, declare the actual implementation of the calculation + let pow = Arc::new(|args: &[ColumnarValue]| { + // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to: + // 1. cast the values to the type we want + // 2. perform the computation for every element in the array (using a loop or SIMD) and construct the result + + // this is guaranteed by DataFusion based on the function's signature. + assert_eq!(args.len(), 2); + + // Expand the arguments to arrays (this is simple, but inefficient for + // single constant values). + let args = ColumnarValue::values_to_arrays(args)?; + + // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics! + let base = as_float64_array(&args[0]).expect("cast failed"); + let exponent = as_float64_array(&args[1]).expect("cast failed"); + + // The array lengths is guaranteed by DataFusion. We assert here to make it obvious. + assert_eq!(exponent.len(), base.len()); + + // 2. perform the computation + let array = base + .iter() + .zip(exponent.iter()) + .map(|(base, exponent)| { + match (base, exponent) { + // in arrow, any value can be null. + // Here we decide to make our UDF to return null when either base or exponent is null. + (Some(base), Some(exponent)) => Some(base.powf(exponent)), + _ => None, + } + }) + .collect::(); + + // `Ok` because no error occurred during the calculation (we should add one if exponent was [0, 1[ and the base < 0 because that panics!) + // `Arc` because arrays are immutable, thread-safe, trait objects. + Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) + }); + + // Next: + // * give it a name so that it shows nicely when the plan is printed + // * declare what input it expects + // * declare its return type + let pow = create_udf( + "f1", + // expects two f64 + vec![DataType::Float64, DataType::Float64], + // returns f64 + Arc::new(DataType::Float64), + Volatility::Immutable, + pow, + ); + pow +} + +pub struct WasmFunctionFactory { + // note: + // https://github.com/WasmEdge/wasmedge-rust-sdk/issues/89 + // comments do not add up to VM interface, on top of it + // UDFs do not modify any state. leaving as it is for now + // may revert it later + // modules: ModuleCache, +} + +#[async_trait::async_trait] +impl FunctionFactory for WasmFunctionFactory { + async fn create( + &self, + _state: &SessionState, + statement: CreateFunction, + ) -> Result { + let return_type = statement.return_type.expect("return type expected"); + let argument_types = statement + .args + .map(|args| { + args.into_iter() + .map(|a| a.data_type) + .collect::>() + }) + .unwrap_or_default(); + let declared_name = statement.name; + let (module_name, method_name) = match &statement.params.as_ { + Some(DefinitionStatement::SingleQuotedDef(path)) => { + println!("Got create function path: {}", path); + Self::wasm_module_function(path)? + } + None => return exec_err!("wasm function not defined "), + Some(f) => return exec_err!("wasm function incorrect {:?} ", f), + }; + + // let rt = Runtime::new(binary); + + // let vm = self.wasm_model_cache_or_load(&module_name).await?; + // let f = crate::udf::WasmFunctionWrapper::new( + // vm, + // declared_name, + // method_name, + // argument_types, + // return_type, + // )?; + + let f = test_udf(); + + Ok(RegisterFunction::Scalar(Arc::new(ScalarUDF::from(f)))) + } +} + +impl Default for WasmFunctionFactory { + fn default() -> Self { + WasmFunctionFactory { + // modules: Arc::new(Mutex::new(WeakValueHashMap::new())), + } + } +} + +impl WasmFunctionFactory { + /// returns cached module or + /// loads, caches module and returns module + /// for given module path + // async fn wasm_model_cache_or_load( + // &self, + // wasm_module_path: &str, + // ) -> std::result::Result, WasmFunctionError> { + // // caching key is bit primitive, but good enough for now + // let mut modules = self.modules.lock().await; + // // lets assume creation of new module will not take too long + // // and lock will be kept for a very short period of time, + // // good enough for now + // match modules.get(wasm_module_path) { + // Some(module) => { + // log::debug!("return cached VM for wasm_module={}", wasm_module_path); + // Ok(module.clone()) + // } + // None => { + // log::debug!("no cached VM for wasm_module={}", wasm_module_path); + // let module = Self::wasm_model_load(wasm_module_path)?; + // modules.insert(wasm_module_path.to_string(), module.clone()); + // Ok(module) + // } + // } + // } + + fn wasm_module_function(s: &str) -> Result<(String, String)> { + match s.split('!').collect::>()[..] { + [module, method] if !module.is_empty() && !method.is_empty() => { + Ok((module.to_string(), method.to_string())) + } + _ => exec_err!("bad module/method format"), + } + } + + // fn wasm_model_load(wasm_module: &str) -> std::result::Result, WasmFunctionError> { + // log::debug!("producing new VM for wasm_module={}", wasm_module); + // let file = Path::new(&wasm_module); + // let module = if file.is_absolute() { + // Module::from_file(None, wasm_module)? + // } else { + // let mut project_root = project_root::get_project_root() + // .map_err(|e| WasmFunctionError::Execution(e.to_string()))?; + // project_root.push(file); + // Module::from_file(None, &project_root)? + // }; + // + // // default configuration will do for now + // let config = ConfigBuilder::default().build()?; + // + // let vm = VmBuilder::new() + // .with_config(config) + // .build()? + // .register_module(None, module)?; + // + // Ok(Arc::new(VmDock::new(vm))) + // } + // #[cfg(test)] + // fn module_cache(&self) -> ModuleCache { + // self.modules.clone() + // } +} + +// #[derive(Error, Debug)] +// pub enum WasmFunctionError { +// #[error("WasmEdge Error: {0}")] +// WasmEdgeError(#[from] Box), +// #[error("Execution Error: {0}")] +// Execution(String), +// } + +// impl From for DataFusionError { +// fn from(e: WasmFunctionError) -> Self { +// // will do for now +// DataFusionError::Execution(e.to_string()) +// } +// } + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use datafusion::{ + arrow::array::{ArrayRef, Float64Array, RecordBatch}, + assert_batches_eq, + execution::context::SessionContext, + }; + + use crate::WasmFunctionFactory; + + #[test] + fn test_module_function_split() { + let (module, method) = WasmFunctionFactory::wasm_module_function("module!method").unwrap(); + assert_eq!("module", module); + assert_eq!("method", method); + + assert!(WasmFunctionFactory::wasm_module_function("!method").is_err()); + } + #[tokio::test] + async fn should_handle_happy_path() -> datafusion::error::Result<()> { + let ctx = + SessionContext::new().with_function_factory(Arc::new(WasmFunctionFactory::default())); + + let a: ArrayRef = Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])); + let b: ArrayRef = Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.1])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?; + + ctx.register_batch("t", batch)?; + + let sql = r#" + CREATE FUNCTION f1(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f1' + "#; + + ctx.sql(sql).await?.show().await?; + + let result = ctx + .sql("select a, b, f1(a,b) from t") + .await? + .collect() + .await?; + let expected = vec![ + "+-----+-----+-------------------+", + "| a | b | f1(t.a,t.b) |", + "+-----+-----+-------------------+", + "| 2.0 | 2.0 | 4.0 |", + "| 3.0 | 3.0 | 27.0 |", + "| 4.0 | 4.0 | 256.0 |", + "| 5.0 | 5.1 | 3670.684197150057 |", + "+-----+-----+-------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + async fn should_handle_error() -> datafusion::error::Result<()> { + let ctx = + SessionContext::new().with_function_factory(Arc::new(WasmFunctionFactory::default())); + + let sql = r#" + CREATE FUNCTION f2(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f_return_error' + "#; + + ctx.sql(sql).await?.show().await?; + + let result = ctx.sql("select f2(1.0,1.0)").await?.show().await; + + assert!(result.is_err()); + assert_eq!( + "Execution error: [Wasm Invocation] wasm function returned error", + result.err().unwrap().to_string() + ); + + Ok(()) + } + + #[tokio::test] + async fn should_handle_arrow_error() -> datafusion::error::Result<()> { + let ctx = + SessionContext::new().with_function_factory(Arc::new(WasmFunctionFactory::default())); + + let sql = r#" + CREATE FUNCTION f2(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f_return_arrow_error' + "#; + + ctx.sql(sql).await?.show().await?; + + let result = ctx.sql("select f2(1.0,1.0)").await?.show().await; + + assert!(result.is_err()); + assert_eq!( + "Execution error: [Wasm Invocation] Divide by zero error", + result.err().unwrap().to_string() + ); + + Ok(()) + } + + #[tokio::test] + #[ignore = "WasmEdge does not handle panic after latest change"] + async fn should_handle_panic() -> datafusion::error::Result<()> { + let ctx = + SessionContext::new().with_function_factory(Arc::new(WasmFunctionFactory::default())); + + let sql = r#" + CREATE FUNCTION f1(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f1' + "#; + // we register good function to verify that panich + // will not put vm to some unexpected state + ctx.sql(sql).await?.show().await?; + + let sql = r#" + CREATE FUNCTION f3(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f_panic' + "#; + + ctx.sql(sql).await?.show().await?; + + let result = ctx.sql("select f3(1.0,1.0)").await?.show().await; + + assert!(result.is_err()); + assert_eq!( + "Execution error: [Wasm Invocation Panic] unreachable", + result.err().unwrap().to_string() + ); + let result = ctx.sql("select f1(1.0,1.0)").await?.collect().await?; + let expected = vec![ + "+---------------------------+", + "| f1(Float64(1),Float64(1)) |", + "+---------------------------+", + "| 1.0 |", + "+---------------------------+", + ]; + + assert_batches_eq!(expected, &result); + Ok(()) + } + + #[tokio::test] + async fn should_create_drop_function() -> datafusion::error::Result<()> { + let function_factory = Arc::new(WasmFunctionFactory::default()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION f1(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f1' + "#; + + ctx.sql(sql).await?.show().await?; + + let sql = r#" + CREATE FUNCTION f2(DOUBLE, DOUBLE) + RETURNS DOUBLE + LANGUAGE WASM + AS 'wasm_function/target/wasm32-unknown-unknown/debug/wasm_function.wasm!f_return_arrow_error' + "#; + + ctx.sql(sql).await?.show().await?; + + let result = ctx.sql("select f1(2.0,2.0)").await?.collect().await?; + let expected = vec![ + "+---------------------------+", + "| f1(Float64(2),Float64(2)) |", + "+---------------------------+", + "| 4.0 |", + "+---------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + // we should have one modules caching + // assert_eq!(1, function_factory.module_cache().lock().await.len()); + + let sql = r#" + DROP FUNCTION f1 + "#; + + ctx.sql(sql).await?.show().await?; + + let sql = r#" + DROP FUNCTION f2 + "#; + + ctx.sql(sql).await?.show().await?; + + // we should have none modules cached + // weak hashmap should drop VM after last function + // has been dropped. + // note, weak hash map is lazy to drop + // assert_eq!( + // 0, + // function_factory + // .module_cache() + // .lock() + // .await + // .keys() + // .collect::>() + // .len() + // ); + + Ok(()) + } +} + +// #[cfg(test)] +// #[ctor::ctor] +// fn init() { +// // Enable RUST_LOG logging configuration for test +// let _ = env_logger::builder().is_test(true).try_init(); +// } diff --git a/crates/wasmedge-factory/src/udf.rs b/crates/wasmedge-factory/src/udf.rs new file mode 100644 index 0000000..3acf023 --- /dev/null +++ b/crates/wasmedge-factory/src/udf.rs @@ -0,0 +1,122 @@ +use std::sync::Arc; + +use datafusion::{ + arrow::{ + array::ArrayRef, + datatypes::{DataType, Field, Schema, SchemaRef}, + }, + common::exec_err, + error::Result, + logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}, +}; +use wasm_udfs::{from_ipc, pack_array_with_schema, to_ipc}; +// use wasmedge_sdk::dock::{Param, VmDock}; + +// #[derive(Debug)] +// pub(crate) struct WasmFunctionWrapper { +// /// name which was used to in `CREATE FUNCTION` statement +// declared_function_name: String, +// /// wasm method to be called, can be found in `AS` part of the statement +// // it would be much better if we could cache method handle +// // but that is not currently supported by wasmedge sdk +// wasm_method: String, +// argument_schema: SchemaRef, +// // TODO: function signature should be extracted from `CREATE FUNCTION` statement +// signature: Signature, +// return_type: DataType, +// /// wasm VM which hosts module +// vm: Arc, +// } + +// impl WasmFunctionWrapper { +// pub(crate) fn new( +// vm: Arc, +// declared_function_name: String, +// wasm_method: String, +// argument_types: Vec, +// return_type: DataType, +// ) -> Result { +// let fields = argument_types +// .iter() +// .enumerate() +// .map(|(i, f)| Field::new(format!("c{}", i), f.clone(), false)) +// .collect::>(); +// +// // we cache the schema +// // as it will be used for every message +// // passed between rust and wasm (not sure if we can avoid that) +// let argument_schema = Arc::new(Schema::new(fields)); +// +// Ok(Self { +// // prefix is not really needed but it looks cool :) +// wasm_method: format!("__wasm_udf_{}", wasm_method), +// declared_function_name, +// signature: Signature::exact(argument_types, Volatility::Volatile), +// return_type, +// argument_schema, +// vm, +// }) +// } +// } + +// impl ScalarUDFImpl for WasmFunctionWrapper { +// fn as_any(&self) -> &dyn std::any::Any { +// self +// } +// +// fn name(&self) -> &str { +// &self.declared_function_name +// } +// +// fn signature(&self) -> &datafusion::logical_expr::Signature { +// &self.signature +// } +// +// fn return_type( +// &self, +// _arg_types: &[datafusion::arrow::datatypes::DataType], +// ) -> Result { +// Ok(self.return_type.clone()) +// } +// +// fn invoke( +// &self, +// args: &[datafusion::logical_expr::ColumnarValue], +// ) -> Result { +// let arrays = ColumnarValue::values_to_arrays(args)?; +// let batch = pack_array_with_schema(&arrays, self.argument_schema.clone()); +// +// let payload = to_ipc(&batch.schema(), batch); +// let params = vec![Param::VecU8(&payload)]; +// +// let call_result = match self.vm.run_func(&self.wasm_method, params) { +// Ok(result) => result, +// // if wasm function panics it should get to this error +// Err(e) => return exec_err!("[Wasm Invocation Panic] {}", e), +// }; +// +// match call_result { +// // function returned result +// // in our case we expect only single result +// // at position 0 +// Ok(mut res) => { +// // we should add errors to the protocol +// let response = res.pop().unwrap().downcast::>().unwrap(); +// let a = from_ipc(&response); +// // aso we expect single column as the result +// let result = a.column(0); +// Ok(ColumnarValue::from(result.clone() as ArrayRef)) +// } +// // function returned error +// Err(err) => { +// exec_err!("[Wasm Invocation] {}", err) +// } +// } +// } +// } +// +// impl Drop for WasmFunctionWrapper { +// fn drop(&mut self) { +// log::debug!("drop wasm function, name={}", self.name()) +// } +// }