
Remove pyarrow dep from datafusion. Add in PyScalarValue wrapper and rename DataFusionError to PyDataFusionError to be less confusing
timsaucer committed Jan 20, 2025
1 parent 9650a82 commit f0d25a2
Showing 25 changed files with 520 additions and 186 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

(Generated lockfile; diff not rendered.)

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -38,7 +38,7 @@ tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread", "sync
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]}
arrow = { version = "53", features = ["pyarrow"] }
-datafusion = { version = "44.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
+datafusion = { version = "44.0.0", features = ["avro", "unicode_expressions"] }
datafusion-substrait = { version = "44.0.0", optional = true }
datafusion-proto = { version = "44.0.0" }
datafusion-ffi = { version = "44.0.0" }
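
Note: with the "pyarrow" feature gone from the datafusion dependency, Python interop for Arrow data goes through the arrow crate, whose "pyarrow" feature stays enabled above. A minimal sketch of that remaining conversion path (the helper function here is illustrative, not code from this commit):

use arrow::pyarrow::ToPyArrow;
use arrow::record_batch::RecordBatch;
use pyo3::prelude::*;

// Convert a Rust RecordBatch into a pyarrow.RecordBatch on the Python
// side. ToPyArrow comes from arrow-rs's "pyarrow" feature; this commit
// only drops datafusion's duplicate of the integration.
fn batch_to_python(py: Python<'_>, batch: &RecordBatch) -> PyResult<PyObject> {
    batch.to_pyarrow(py)
}
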
3 changes: 2 additions & 1 deletion python/tests/test_indexing.py
@@ -44,7 +44,8 @@ def test_err(df):
with pytest.raises(Exception) as e_info:
df["c"]

-    assert "Schema error: No field named c." in e_info.value.args[0]
+    for e in ["SchemaError", "FieldNotFound", 'name: "c"']:
+        assert e in e_info.value.args[0]

with pytest.raises(Exception) as e_info:
df[1]
8 changes: 5 additions & 3 deletions src/catalog.rs
@@ -21,7 +21,7 @@ use std::sync::Arc;
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;

-use crate::errors::DataFusionError;
+use crate::errors::PyDataFusionError;
use crate::utils::wait_for_future;
use datafusion::{
arrow::pyarrow::ToPyArrow,
@@ -97,10 +97,12 @@ impl PyDatabase {
}

fn table(&self, name: &str, py: Python) -> PyResult<PyTable> {
-        if let Some(table) = wait_for_future(py, self.database.table(name))? {
+        if let Some(table) =
+            wait_for_future(py, self.database.table(name)).map_err(PyDataFusionError::from)?
+        {
Ok(PyTable::new(table))
} else {
-            Err(DataFusionError::Common(format!("Table not found: {name}")).into())
+            Err(PyDataFusionError::Common(format!("Table not found: {name}")).into())
}
}

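
The .map_err(PyDataFusionError::from) calls introduced throughout this commit lean on the renamed wrapper type in src/errors.rs. A minimal sketch of the shape that type plausibly has; the real module defines more variants and richer exception mapping, so treat this as illustrative:

use datafusion::error::DataFusionError;
use pyo3::exceptions::PyException;
use pyo3::PyErr;

#[derive(Debug)]
pub enum PyDataFusionError {
    // Errors bubbling up from the datafusion crate itself.
    Execution(DataFusionError),
    // Free-form errors raised by the bindings, e.g. "Table not found".
    Common(String),
}

impl From<DataFusionError> for PyDataFusionError {
    fn from(err: DataFusionError) -> Self {
        PyDataFusionError::Execution(err)
    }
}

// Final hop into pyo3: once a Result carries PyDataFusionError, `?` and
// .into() can turn it into a Python exception, as in the code above.
impl From<PyDataFusionError> for PyErr {
    fn from(err: PyDataFusionError) -> Self {
        match err {
            PyDataFusionError::Execution(e) => PyException::new_err(e.to_string()),
            PyDataFusionError::Common(msg) => PyException::new_err(msg),
        }
    }
}
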
14 changes: 14 additions & 0 deletions src/common/data_type.rs
@@ -23,6 +23,20 @@ use pyo3::{exceptions::PyValueError, prelude::*};

use crate::errors::py_datafusion_err;

+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
+pub struct PyScalarValue(pub ScalarValue);
+
+impl From<ScalarValue> for PyScalarValue {
+    fn from(value: ScalarValue) -> Self {
+        Self(value)
+    }
+}
+impl From<PyScalarValue> for ScalarValue {
+    fn from(value: PyScalarValue) -> Self {
+        value.0
+    }
+}
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(eq, eq_int, name = "RexType", module = "datafusion.common")]
pub enum RexType {
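
The newtype exists because Rust's orphan rule forbids implementing pyo3's conversion traits directly on datafusion's ScalarValue once the "pyarrow" feature no longer provides them; wrapping the value in a local type makes such impls legal. A sketch of the kind of conversion this unlocks, with an intentionally tiny type table (the crate's real extraction handles many more types):

use datafusion::common::ScalarValue;
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;

pub struct PyScalarValue(pub ScalarValue); // the wrapper added above

// Legal only because PyScalarValue is local to this crate.
impl<'py> FromPyObject<'py> for PyScalarValue {
    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
        if let Ok(v) = ob.extract::<i64>() {
            Ok(PyScalarValue(ScalarValue::Int64(Some(v))))
        } else if let Ok(v) = ob.extract::<String>() {
            Ok(PyScalarValue(ScalarValue::Utf8(Some(v))))
        } else {
            Err(PyTypeError::new_err("unsupported Python type for ScalarValue"))
        }
    }
}
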
7 changes: 5 additions & 2 deletions src/config.rs
@@ -21,6 +21,8 @@ use pyo3::types::*;
use datafusion::common::ScalarValue;
use datafusion::config::ConfigOptions;

+use crate::errors::PyDataFusionError;
+
#[pyclass(name = "Config", module = "datafusion", subclass)]
#[derive(Clone)]
pub(crate) struct PyConfig {
@@ -40,7 +42,7 @@ impl PyConfig {
#[staticmethod]
pub fn from_env() -> PyResult<Self> {
Ok(Self {
-            config: ConfigOptions::from_env()?,
+            config: ConfigOptions::from_env().map_err(PyDataFusionError::from)?,
})
}

@@ -60,7 +62,8 @@ impl PyConfig {
let scalar_value = py_obj_to_scalar_value(py, value);
self.config
.set(key, scalar_value.to_string().as_str())
-            .map_err(|e| e.into())
+            .map_err(PyDataFusionError::from)
+            .map_err(PyErr::from)
}

/// Get all configuration options
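
The chained map_err in set spells out both conversion hops: the upstream DataFusionError first becomes the crate's wrapper, then a PyErr so the method can return PyResult. Condensed into a hypothetical standalone helper:

use datafusion::config::ConfigOptions;
use pyo3::prelude::*;

// Hypothetical helper mirroring PyConfig::set's error handling.
fn set_option(config: &mut ConfigOptions, key: &str, value: &str) -> PyResult<()> {
    config
        .set(key, value)
        .map_err(crate::errors::PyDataFusionError::from) // upstream error -> wrapper
        .map_err(PyErr::from) // wrapper -> Python exception
}
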
87 changes: 52 additions & 35 deletions src/context.rs
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
-use crate::errors::{py_datafusion_err, DataFusionError};
+use crate::errors::{py_datafusion_err, PyDataFusionError};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
@@ -288,7 +288,11 @@ impl PySessionContext {
} else {
RuntimeEnvBuilder::default()
};
-        let runtime = Arc::new(runtime_env_builder.build()?);
+        let runtime = Arc::new(
+            runtime_env_builder
+                .build()
+                .map_err(PyDataFusionError::from)?,
+        );
let session_state = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
@@ -359,19 +363,19 @@
.map(|e| e.into_iter().map(|f| f.into()).collect())
.collect(),
);
-        let table_path = ListingTableUrl::parse(path)?;
+        let table_path = ListingTableUrl::parse(path).map_err(PyDataFusionError::from)?;
let resolved_schema: SchemaRef = match schema {
Some(s) => Arc::new(s.0),
None => {
let state = self.ctx.state();
let schema = options.infer_schema(&state, &table_path);
-                wait_for_future(py, schema).map_err(DataFusionError::from)?
+                wait_for_future(py, schema).map_err(PyDataFusionError::from)?
}
};
let config = ListingTableConfig::new(table_path)
.with_listing_options(options)
.with_schema(resolved_schema);
-        let table = ListingTable::try_new(config)?;
+        let table = ListingTable::try_new(config).map_err(PyDataFusionError::from)?;
self.register_table(
name,
&PyTable {
@@ -384,7 +388,7 @@
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
let result = self.ctx.sql(query);
-        let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
+        let df = wait_for_future(py, result).map_err(PyDataFusionError::from)?;
Ok(PyDataFrame::new(df))
}

@@ -401,7 +405,7 @@
SQLOptions::new()
};
let result = self.ctx.sql_with_options(query, options);
-        let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
+        let df = wait_for_future(py, result).map_err(PyDataFusionError::from)?;
Ok(PyDataFrame::new(df))
}

@@ -419,7 +423,7 @@
partitions.0[0][0].schema()
};

-        let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;
+        let table = MemTable::try_new(schema, partitions.0).map_err(PyDataFusionError::from)?;

// generate a random (unique) name for this table if none is provided
// table name cannot start with numeric digit
@@ -435,9 +439,10 @@

self.ctx
.register_table(&*table_name, Arc::new(table))
-            .map_err(DataFusionError::from)?;
+            .map_err(PyDataFusionError::from)?;

-        let table = wait_for_future(py, self._table(&table_name)).map_err(DataFusionError::from)?;
+        let table =
+            wait_for_future(py, self._table(&table_name)).map_err(PyDataFusionError::from)?;

let df = PyDataFrame::new(table);
Ok(df)
@@ -503,7 +508,7 @@ impl PySessionContext {
let schema = stream_reader.schema().as_ref().to_owned();
let batches = stream_reader
.collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()
-            .map_err(DataFusionError::from)?;
+            .map_err(PyDataFusionError::from)?;

(schema, batches)
} else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
@@ -562,14 +567,14 @@ impl PySessionContext {
pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
self.ctx
.register_table(name, table.table())
-            .map_err(DataFusionError::from)?;
+            .map_err(PyDataFusionError::from)?;
Ok(())
}

pub fn deregister_table(&mut self, name: &str) -> PyResult<()> {
self.ctx
.deregister_table(name)
-            .map_err(DataFusionError::from)?;
+            .map_err(PyDataFusionError::from)?;
Ok(())
}

@@ -587,7 +592,10 @@
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
let provider: ForeignTableProvider = provider.into();

-            let _ = self.ctx.register_table(name, Arc::new(provider))?;
+            let _ = self
+                .ctx
+                .register_table(name, Arc::new(provider))
+                .map_err(PyDataFusionError::from)?;

Ok(())
} else {
@@ -603,10 +611,10 @@
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
) -> PyResult<()> {
let schema = partitions.0[0][0].schema();
-        let table = MemTable::try_new(schema, partitions.0)?;
+        let table = MemTable::try_new(schema, partitions.0).map_err(PyDataFusionError::from)?;
self.ctx
.register_table(name, Arc::new(table))
-            .map_err(DataFusionError::from)?;
+            .map_err(PyDataFusionError::from)?;
Ok(())
}

@@ -642,7 +650,7 @@
.collect();

let result = self.ctx.register_parquet(name, path, options);
-        wait_for_future(py, result).map_err(DataFusionError::from)?;
+        wait_for_future(py, result).map_err(PyDataFusionError::from)?;
Ok(())
}

@@ -685,11 +693,11 @@
if path.is_instance_of::<PyList>() {
let paths = path.extract::<Vec<String>>()?;
let result = self.register_csv_from_multiple_paths(name, paths, options);
-            wait_for_future(py, result).map_err(DataFusionError::from)?;
+            wait_for_future(py, result).map_err(PyDataFusionError::from)?;
} else {
let path = path.extract::<String>()?;
let result = self.ctx.register_csv(name, &path, options);
-            wait_for_future(py, result).map_err(DataFusionError::from)?;
+            wait_for_future(py, result).map_err(PyDataFusionError::from)?;
}

Ok(())
@@ -726,7 +734,7 @@
options.schema = schema.as_ref().map(|x| &x.0);

let result = self.ctx.register_json(name, path, options);
-        wait_for_future(py, result).map_err(DataFusionError::from)?;
+        wait_for_future(py, result).map_err(PyDataFusionError::from)?;

Ok(())
}
@@ -756,7 +764,7 @@
options.schema = schema.as_ref().map(|x| &x.0);

let result = self.ctx.register_avro(name, path, options);
-        wait_for_future(py, result).map_err(DataFusionError::from)?;
+        wait_for_future(py, result).map_err(PyDataFusionError::from)?;

Ok(())
}
@@ -772,7 +780,7 @@

self.ctx
.register_table(name, table)
-            .map_err(DataFusionError::from)?;
+            .map_err(PyDataFusionError::from)?;

Ok(())
}
@@ -825,11 +833,16 @@
}

pub fn table_exist(&self, name: &str) -> PyResult<bool> {
-        Ok(self.ctx.table_exist(name)?)
+        Ok(self
+            .ctx
+            .table_exist(name)
+            .map_err(PyDataFusionError::from)?)
}

pub fn empty_table(&self) -> PyResult<PyDataFrame> {
-        Ok(PyDataFrame::new(self.ctx.read_empty()?))
+        Ok(PyDataFrame::new(
+            self.ctx.read_empty().map_err(PyDataFusionError::from)?,
+        ))
}

pub fn session_id(&self) -> String {
@@ -859,10 +872,10 @@
let df = if let Some(schema) = schema {
options.schema = Some(&schema.0);
let result = self.ctx.read_json(path, options);
-            wait_for_future(py, result).map_err(DataFusionError::from)?
+            wait_for_future(py, result).map_err(PyDataFusionError::from)?
} else {
let result = self.ctx.read_json(path, options);
-            wait_for_future(py, result).map_err(DataFusionError::from)?
+            wait_for_future(py, result).map_err(PyDataFusionError::from)?
};
Ok(PyDataFrame::new(df))
}
@@ -909,12 +922,14 @@
let paths = path.extract::<Vec<String>>()?;
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
let result = self.ctx.read_csv(paths, options);
-            let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
+            let df =
+                PyDataFrame::new(wait_for_future(py, result).map_err(PyDataFusionError::from)?);
Ok(df)
} else {
let path = path.extract::<String>()?;
let result = self.ctx.read_csv(path, options);
-            let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
+            let df =
+                PyDataFrame::new(wait_for_future(py, result).map_err(PyDataFusionError::from)?);
Ok(df)
}
}
@@ -952,7 +967,7 @@
.collect();

let result = self.ctx.read_parquet(path, options);
-        let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
+        let df = PyDataFrame::new(wait_for_future(py, result).map_err(PyDataFusionError::from)?);
Ok(df)
}

@@ -972,10 +987,10 @@
let df = if let Some(schema) = schema {
options.schema = Some(&schema.0);
let read_future = self.ctx.read_avro(path, options);
-            wait_for_future(py, read_future).map_err(DataFusionError::from)?
+            wait_for_future(py, read_future).map_err(PyDataFusionError::from)?
} else {
let read_future = self.ctx.read_avro(path, options);
-            wait_for_future(py, read_future).map_err(DataFusionError::from)?
+            wait_for_future(py, read_future).map_err(PyDataFusionError::from)?
};
Ok(PyDataFrame::new(df))
}
@@ -984,7 +999,7 @@
let df = self
.ctx
.read_table(table.table())
-            .map_err(DataFusionError::from)?;
+            .map_err(PyDataFusionError::from)?;
Ok(PyDataFrame::new(df))
}

@@ -1019,7 +1034,9 @@
let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
-        Ok(PyRecordBatchStream::new(stream?))
+        Ok(PyRecordBatchStream::new(
+            stream.map_err(PyDataFusionError::from)?,
+        ))
}
}

@@ -1071,13 +1088,13 @@

pub fn convert_table_partition_cols(
table_partition_cols: Vec<(String, String)>,
-) -> Result<Vec<(String, DataType)>, DataFusionError> {
+) -> Result<Vec<(String, DataType)>, PyDataFusionError> {
table_partition_cols
.into_iter()
.map(|(name, ty)| match ty.as_str() {
"string" => Ok((name, DataType::Utf8)),
"int" => Ok((name, DataType::Int32)),
-            _ => Err(DataFusionError::Common(format!(
+            _ => Err(PyDataFusionError::Common(format!(
"Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'"
))),
})
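
For reference, a hypothetical call site for convert_table_partition_cols above: type strings arriving from Python are validated into Arrow types, and anything other than "string" or "int" surfaces as a Common error.

use datafusion::arrow::datatypes::DataType;

fn example() -> Result<(), PyDataFusionError> {
    let cols = convert_table_partition_cols(vec![
        ("year".to_string(), "int".to_string()),
        ("country".to_string(), "string".to_string()),
    ])?;
    assert_eq!(cols[0], ("year".to_string(), DataType::Int32));
    assert_eq!(cols[1], ("country".to_string(), DataType::Utf8));
    Ok(())
}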