diff --git a/.github/workflows/test-pyodide.yml b/.github/workflows/test-pyodide.yml
index 49e02cb2412a..ff76d7307208 100644
--- a/.github/workflows/test-pyodide.yml
+++ b/.github/workflows/test-pyodide.yml
@@ -39,7 +39,7 @@ jobs:
       - name: Disable incompatible features
         env:
-          FEATURES: parquet|async|json|extract_jsonpath|cloud|polars_cloud|tokio|clipboard|decompress|new_streaming
+          FEATURES: parquet|async|json|extract_jsonpath|catalog|cloud|polars_cloud|tokio|clipboard|decompress|new_streaming
        run: |
          sed -i 's/^ "json",$/ "serde_json",/' crates/polars-python/Cargo.toml
          sed -E -i "/^ \"(${FEATURES})\",$/d" crates/polars-python/Cargo.toml py-polars/Cargo.toml
diff --git a/Cargo.lock b/Cargo.lock
index dfb5050118ba..689f5fa9a32c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3197,6 +3197,7 @@ dependencies = [
  "serde_json",
  "simd-json",
  "simdutf8",
+ "strum_macros",
  "tempfile",
  "tokio",
  "tokio-util",
diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml
index e06a7aad1b8a..c3c1f90a5608 100644
--- a/crates/polars-io/Cargo.toml
+++ b/crates/polars-io/Cargo.toml
@@ -46,6 +46,7 @@ serde = { workspace = true, features = ["rc"], optional = true }
 serde_json = { version = "1", optional = true }
 simd-json = { workspace = true, optional = true }
 simdutf8 = { workspace = true, optional = true }
+strum_macros = { workspace = true, optional = true }
 tokio = { workspace = true, features = ["fs", "net", "rt-multi-thread", "time", "sync"], optional = true }
 tokio-util = { workspace = true, features = ["io", "io-util"], optional = true }
 url = { workspace = true, optional = true }
@@ -59,6 +60,7 @@ home = "0.5.4"
 tempfile = "3"

 [features]
+catalog = ["cloud", "serde", "reqwest", "futures", "strum_macros"]
 default = ["decompress"]
 # support for arrows json parsing
 json = [
diff --git a/crates/polars-io/src/catalog/mod.rs b/crates/polars-io/src/catalog/mod.rs
new file mode 100644
index 000000000000..99f35e4cfba4
--- /dev/null
+++ b/crates/polars-io/src/catalog/mod.rs
@@ -0,0 +1,2 @@
+pub mod schema;
+pub mod unity;
diff --git a/crates/polars-io/src/catalog/schema.rs b/crates/polars-io/src/catalog/schema.rs
new file mode 100644
index 000000000000..33f54e056f03
--- /dev/null
+++ b/crates/polars-io/src/catalog/schema.rs
@@ -0,0 +1,265 @@
+use polars_core::prelude::{DataType, Field};
+use polars_core::schema::{Schema, SchemaRef};
+use polars_error::{polars_bail, polars_err, PolarsResult};
+use polars_utils::pl_str::PlSmallStr;
+
+use super::unity::models::TableInfo;
+
+/// Returns `(schema, hive_schema)`
+pub fn table_info_to_schemas(
+    table_info: &TableInfo,
+) -> PolarsResult<(Option<SchemaRef>, Option<SchemaRef>)> {
+    let Some(columns) = table_info.columns.as_deref() else {
+        return Ok((None, None));
+    };
+
+    let mut schema = Schema::default();
+    let mut hive_schema = Schema::default();
+
+    for (i, col) in columns.iter().enumerate() {
+        let dtype = parse_type_str(&col.type_text)?;
+
+        if let Some(position) = col.position {
+            if usize::try_from(position).unwrap() != i {
+                polars_bail!(
+                    ComputeError:
+                    "not yet supported: position was not ordered"
+                )
+            }
+        }
+
+        if let Some(i) = col.partition_index {
+            if usize::try_from(i).unwrap() != hive_schema.len() {
+                polars_bail!(
+                    ComputeError:
+                    "not yet supported: partition_index was not ordered"
+                )
+            }
+
+            hive_schema.extend([Field::new(col.name.as_str().into(), dtype)]);
+        } else {
+            schema.extend([Field::new(col.name.as_str().into(), dtype)])
+        }
+    }
+
+    Ok((
+        Some(schema.into()),
+        Some(hive_schema)
+            .filter(|x| !x.is_empty())
+            .map(|x| x.into()),
+    ))
+}
+
+/// Parse a type string from a catalog API response.
+///
+/// References:
+/// * https://spark.apache.org/docs/latest/sql-ref-datatypes.html
+/// * https://docs.databricks.com/api/workspace/tables/get
+/// * https://docs.databricks.com/en/sql/language-manual/sql-ref-datatypes.html
+///
+/// Note: `type_precision` and `type_scale` in the API response are defined as supplementary data to
+/// the `type_text`, but from testing they aren't actually used - e.g. a decimal type would have a
+/// `type_text` of `decimal(18, 2)`
+fn parse_type_str(type_text: &str) -> PolarsResult<DataType> {
+    use DataType::*;
+
+    let dtype = match type_text {
+        "boolean" => Boolean,
+
+        "byte" | "tinyint" => Int8,
+        "short" | "smallint" => Int16,
+        "int" | "integer" => Int32,
+        "long" | "bigint" => Int64,
+
+        "float" | "real" => Float32,
+        "double" => Float64,
+
+        "date" => Date,
+        "timestamp" | "timestamp_ltz" | "timestamp_ntz" => {
+            Datetime(polars_core::prelude::TimeUnit::Nanoseconds, None)
+        },
+
+        "string" => String,
+        "binary" => Binary,
+
+        "null" | "void" => Null,
+
+        v => {
+            if v.starts_with("decimal") {
+                // e.g. decimal(38,18)
+                (|| {
+                    let (precision, scale) = v
+                        .get(7..)?
+                        .strip_prefix('(')?
+                        .strip_suffix(')')?
+                        .split_once(',')?;
+                    let precision: usize = precision.parse().ok()?;
+                    let scale: usize = scale.parse().ok()?;
+
+                    Some(DataType::Decimal(Some(precision), Some(scale)))
+                })()
+                .ok_or_else(|| {
+                    polars_err!(
+                        ComputeError:
+                        "type format did not match decimal(int,int): {}",
+                        v
+                    )
+                })?
+            } else if v.starts_with("array") {
+                // e.g. array<int>
+                DataType::List(Box::new(parse_type_str(extract_angle_brackets_inner(
+                    v, "array",
+                )?)?))
+            } else if v.starts_with("struct") {
+                parse_struct_type_str(v)?
+            } else if v.starts_with("map") {
+                // e.g. map<int,string>
+                let inner = extract_angle_brackets_inner(v, "map")?;
+                let (key_type_str, value_type_str) = split_comma_nesting_aware(inner);
+                DataType::List(Box::new(DataType::Struct(vec![
+                    Field::new(
+                        PlSmallStr::from_static("key"),
+                        parse_type_str(key_type_str)?,
+                    ),
+                    Field::new(
+                        PlSmallStr::from_static("value"),
+                        parse_type_str(value_type_str)?,
+                    ),
+                ])))
+            } else {
+                polars_bail!(
+                    ComputeError:
+                    "parse_type_str unknown type name: {}",
+                    v
+                )
+            }
+        },
+    };
+
+    Ok(dtype)
+}
+
+/// `array<inner>` -> `inner`
+fn extract_angle_brackets_inner<'a>(value: &'a str, name: &'static str) -> PolarsResult<&'a str> {
+    let i = value.find('<');
+    let j = value.rfind('>');
+
+    if i.is_none() || j.is_none() || i.unwrap().saturating_add(1) >= j.unwrap() {
+        polars_bail!(
+            ComputeError:
+            "type format did not match {}<...>: {}",
+            name, value
+        )
+    }
+
+    let i = i.unwrap();
+    let j = j.unwrap();
+
+    let inner = value[i + 1..j].trim();
+
+    Ok(inner)
+}
+
+/// `struct<default:decimal(38,18),promotional:struct<default:decimal(38,18)>,effective_list:struct<default:decimal(38,18)>>`
+fn parse_struct_type_str(value: &str) -> PolarsResult<DataType> {
+    let orig_value = value;
+    let mut value = extract_angle_brackets_inner(value, "struct")?;
+
+    let mut fields = vec![];
+
+    while !value.is_empty() {
+        let (field_str, new_value) = split_comma_nesting_aware(value);
+        value = new_value;
+
+        let (name, dtype_str) = field_str.split_once(':').ok_or_else(|| {
+            polars_err!(
+                ComputeError:
+                "type format did not match struct<..>: {}",
+                orig_value
+            )
+        })?;
+
+        let dtype = parse_type_str(dtype_str)?;
+
+        fields.push(Field::new(name.into(), dtype));
+    }
+
+    Ok(DataType::Struct(fields))
+}
+
+/// `default:decimal(38,18),promotional:struct<default:decimal(38,18)>` ->
+/// * 1: `default:decimal(38,18)`
+/// * 2: `promotional:struct<default:decimal(38,18)>`
+///
+/// If there are no splits, returns the full string and an empty string.
+fn split_comma_nesting_aware(value: &str) -> (&str, &str) {
+    let mut bracket_level = 0usize;
+    let mut angle_bracket_level = 0usize;
+
+    for (i, b) in value.as_bytes().iter().enumerate() {
+        match b {
+            b'(' => bracket_level += 1,
+            b')' => bracket_level = bracket_level.saturating_sub(1),
+            b'<' => angle_bracket_level += 1,
+            b'>' => angle_bracket_level = angle_bracket_level.saturating_sub(1),
+            b',' if bracket_level == 0 && angle_bracket_level == 0 => {
+                return (&value[..i], &value[1 + i..])
+            },
+            _ => {},
+        }
+    }
+
+    (value, &value[value.len()..])
+}
+
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn test_parse_type_str_nested_struct() {
+        use super::{parse_type_str, DataType, Field};
+
+        let type_str = "struct<default:decimal(38,18),promotional:struct<default:decimal(38,18)>,effective_list:struct<default:decimal(38,18)>>";
+        let dtype = parse_type_str(type_str).unwrap();
+
+        use DataType::*;
+
+        assert_eq!(
+            dtype,
+            Struct(vec![
+                Field::new("default".into(), Decimal(Some(38), Some(18))),
+                Field::new(
+                    "promotional".into(),
+                    Struct(vec![Field::new(
+                        "default".into(),
+                        Decimal(Some(38), Some(18))
+                    )])
+                ),
+                Field::new(
+                    "effective_list".into(),
+                    Struct(vec![Field::new(
+                        "default".into(),
+                        Decimal(Some(38), Some(18))
+                    )])
+                )
+            ])
+        );
+    }
+
+    #[test]
+    fn test_parse_type_str_map() {
+        use super::{parse_type_str, DataType, Field};
+
+        let type_str = "map<array<int>,array<decimal(18,2)>>";
+        let dtype = parse_type_str(type_str).unwrap();
+
+        use DataType::*;
+
+        assert_eq!(
+            dtype,
+            List(Box::new(Struct(vec![
+                Field::new("key".into(), List(Box::new(Int32))),
+                Field::new("value".into(), List(Box::new(Decimal(Some(18), Some(2)))))
+            ])))
+        );
+    }
+}
diff --git a/crates/polars-io/src/catalog/unity/client.rs b/crates/polars-io/src/catalog/unity/client.rs
new file mode 100644
index 000000000000..53f0814a01c5
--- /dev/null
+++ b/crates/polars-io/src/catalog/unity/client.rs
@@ -0,0 +1,158 @@
+use polars_error::{polars_bail, to_compute_err, PolarsResult};
+
+use super::models::{CatalogInfo, SchemaInfo, TableInfo};
+use super::utils::PageWalker;
+use crate::impl_page_walk;
+use crate::utils::decode_json_response;
+
+/// Unity catalog client.
+pub struct CatalogClient {
+    workspace_url: String,
+    http_client: reqwest::Client,
+}
+
+impl CatalogClient {
+    pub async fn list_catalogs(&self) -> PolarsResult<Vec<CatalogInfo>> {
+        ListCatalogs(PageWalker::new(self.http_client.get(format!(
+            "{}{}",
+            &self.workspace_url, "/api/2.1/unity-catalog/catalogs"
+        ))))
+        .read_all_pages()
+        .await
+    }
+
+    pub async fn list_schemas(&self, catalog_name: &str) -> PolarsResult<Vec<SchemaInfo>> {
+        ListSchemas(PageWalker::new(
+            self.http_client
+                .get(format!(
+                    "{}{}",
+                    &self.workspace_url, "/api/2.1/unity-catalog/schemas"
+                ))
+                .query(&[("catalog_name", catalog_name)]),
+        ))
+        .read_all_pages()
+        .await
+    }
+
+    pub async fn list_tables(
+        &self,
+        catalog_name: &str,
+        schema_name: &str,
+    ) -> PolarsResult<Vec<TableInfo>> {
+        ListTables(PageWalker::new(
+            self.http_client
+                .get(format!(
+                    "{}{}",
+                    &self.workspace_url, "/api/2.1/unity-catalog/tables"
+                ))
+                .query(&[("catalog_name", catalog_name), ("schema_name", schema_name)]),
+        ))
+        .read_all_pages()
+        .await
+    }
+
+    pub async fn get_table_info(
+        &self,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+    ) -> PolarsResult<TableInfo> {
+        let full_table_name = format!(
+            "{}.{}.{}",
+            catalog_name.replace('/', "%2F"),
+            schema_name.replace('/', "%2F"),
+            table_name.replace('/', "%2F")
+        );
+
+        let bytes = async {
+            self.http_client
+                .get(format!(
+                    "{}{}{}",
+                    &self.workspace_url, "/api/2.1/unity-catalog/tables/", full_table_name
+                ))
+                .query(&[("full_name", full_table_name)])
+                .send()
+                .await?
+                .bytes()
+                .await
+        }
+        .await
+        .map_err(to_compute_err)?;
+
+        let out: TableInfo = decode_json_response(&bytes)?;
+
+        Ok(out)
+    }
+}
+
+pub struct CatalogClientBuilder {
+    workspace_url: Option<String>,
+    bearer_token: Option<String>,
+}
+
+#[allow(clippy::derivable_impls)]
+impl Default for CatalogClientBuilder {
+    fn default() -> Self {
+        Self {
+            workspace_url: None,
+            bearer_token: None,
+        }
+    }
+}
+
+impl CatalogClientBuilder {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn with_workspace_url(mut self, workspace_url: impl Into<String>) -> Self {
+        self.workspace_url = Some(workspace_url.into());
+        self
+    }
+
+    pub fn with_bearer_token(mut self, bearer_token: impl Into<String>) -> Self {
+        self.bearer_token = Some(bearer_token.into());
+        self
+    }
+
+    pub fn build(self) -> PolarsResult<CatalogClient> {
+        let Some(workspace_url) = self.workspace_url else {
+            polars_bail!(ComputeError: "expected Some(_) for workspace_url")
+        };
+
+        Ok(CatalogClient {
+            workspace_url,
+            http_client: {
+                let builder = reqwest::ClientBuilder::new().user_agent("polars");
+
+                let builder = if let Some(bearer_token) = self.bearer_token {
+                    use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
+
+                    let mut headers = HeaderMap::new();
+
+                    let mut auth_value =
+                        HeaderValue::from_str(format!("Bearer {}", bearer_token).as_str()).unwrap();
+                    auth_value.set_sensitive(true);
+
+                    headers.insert(AUTHORIZATION, auth_value);
+                    headers.insert(USER_AGENT, "polars".try_into().unwrap());
+
+                    builder.default_headers(headers)
+                } else {
+                    builder
+                };
+
+                builder.build().map_err(to_compute_err)?
+            },
+        })
+    }
+}
+
+pub struct ListCatalogs(pub(crate) PageWalker);
+impl_page_walk!(ListCatalogs, CatalogInfo, key_name = catalogs);
+
+pub struct ListSchemas(pub(crate) PageWalker);
+impl_page_walk!(ListSchemas, SchemaInfo, key_name = schemas);
+
+pub struct ListTables(pub(crate) PageWalker);
+impl_page_walk!(ListTables, TableInfo, key_name = tables);
diff --git a/crates/polars-io/src/catalog/unity/mod.rs b/crates/polars-io/src/catalog/unity/mod.rs
new file mode 100644
index 000000000000..5a9e0c799cdb
--- /dev/null
+++ b/crates/polars-io/src/catalog/unity/mod.rs
@@ -0,0 +1,3 @@
+pub mod client;
+pub mod models;
+pub(crate) mod utils;
diff --git a/crates/polars-io/src/catalog/unity/models.rs b/crates/polars-io/src/catalog/unity/models.rs
new file mode 100644
index 000000000000..da9f604e27ee
--- /dev/null
+++ b/crates/polars-io/src/catalog/unity/models.rs
@@ -0,0 +1,81 @@
+#[derive(Debug, serde::Deserialize)]
+pub struct CatalogInfo {
+    pub name: String,
+    pub comment: Option<String>,
+}
+
+#[derive(Debug, serde::Deserialize)]
+pub struct SchemaInfo {
+    pub name: String,
+    pub comment: Option<String>,
+}
+
+#[derive(Debug, serde::Deserialize)]
+pub struct TableInfo {
+    pub name: String,
+    pub table_id: String,
+    pub table_type: TableType,
+    #[serde(default)]
+    pub comment: Option<String>,
+    #[serde(default)]
+    pub storage_location: Option<String>,
+    #[serde(default)]
+    pub data_source_format: Option<DataSourceFormat>,
+    #[serde(default)]
+    pub columns: Option<Vec<ColumnInfo>>,
+}
+
+#[derive(Debug, strum_macros::Display, serde::Deserialize)]
+#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum TableType {
+    Managed,
+    External,
+    View,
+    MaterializedView,
+    StreamingTable,
+    ManagedShallowClone,
+    Foreign,
+    ExternalShallowClone,
+}
+
+#[derive(Debug, strum_macros::Display, serde::Deserialize)]
+#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum DataSourceFormat {
+    Delta,
+    Csv,
+    Json,
+    Avro,
+    Parquet,
+    Orc,
+    Text,
+
+    // Databricks-specific
+    UnityCatalog,
+    Deltasharing,
+    DatabricksFormat,
+    MysqlFormat,
+    PostgresqlFormat,
+    RedshiftFormat,
+    SnowflakeFormat,
+    SqldwFormat,
+    SqlserverFormat,
+    SalesforceFormat,
+    BigqueryFormat,
+    NetsuiteFormat,
+    WorkdayRaasFormat,
+    HiveSerde,
+    HiveCustom,
+    VectorIndexFormat,
+}
+
+#[derive(Debug, serde::Deserialize)]
+pub struct ColumnInfo {
+    pub name: String,
+    pub type_text: String,
+    pub type_interval_type: Option<String>,
+    pub position: Option<u32>,
+    pub comment: Option<String>,
+    pub partition_index: Option<u32>,
+}
diff --git a/crates/polars-io/src/catalog/unity/utils.rs b/crates/polars-io/src/catalog/unity/utils.rs
new file mode 100644
index 000000000000..56308f0e4466
--- /dev/null
+++ b/crates/polars-io/src/catalog/unity/utils.rs
@@ -0,0 +1,102 @@
+use bytes::Bytes;
+use polars_error::{to_compute_err, PolarsResult};
+use reqwest::RequestBuilder;
+
+/// Support for traversing paginated response values that look like:
+/// ```text
+/// {
+///     $key_name: [$T, $T, ...],
+///     next_page_token: "token" or null,
+/// }
+/// ```
+#[macro_export]
+macro_rules! impl_page_walk {
+    ($S:ty, $T:ty, key_name = $key_name:tt) => {
+        impl $S {
+            pub async fn next(&mut self) -> PolarsResult<Option<Vec<$T>>> {
+                return self
+                    .0
+                    .next(|bytes| {
+                        let Response {
+                            $key_name: out,
+                            next_page_token,
+                        } = decode_json_response(bytes)?;
+
+                        Ok((out, next_page_token))
+                    })
+                    .await;
+
+                #[derive(serde::Deserialize)]
+                struct Response {
+                    #[serde(default = "Vec::new")]
+                    $key_name: Vec<$T>,
+                    #[serde(default)]
+                    next_page_token: Option<String>,
+                }
+            }
+
+            pub async fn read_all_pages(mut self) -> PolarsResult<Vec<$T>> {
+                let Some(mut out) = self.next().await? else {
+                    return Ok(vec![]);
+                };
+
+                while let Some(v) = self.next().await? {
+                    out.extend(v);
+                }
+
+                Ok(out)
+            }
+        }
+    };
+}
+
+pub(crate) struct PageWalker {
+    request: RequestBuilder,
+    next_page_token: Option<String>,
+    has_run: bool,
+}
+
+impl PageWalker {
+    pub(crate) fn new(request: RequestBuilder) -> Self {
+        Self {
+            request,
+            next_page_token: None,
+            has_run: false,
+        }
+    }
+
+    pub(crate) async fn next<T, F>(&mut self, deserializer: F) -> PolarsResult<Option<T>>
+    where
+        F: Fn(&[u8]) -> PolarsResult<(T, Option<String>)>,
+    {
+        let Some(resp_bytes) = self.next_bytes().await? else {
+            return Ok(None);
+        };
+
+        let (value, next_page_token) = deserializer(&resp_bytes)?;
+        self.next_page_token = next_page_token;
+
+        Ok(Some(value))
+    }
+
+    pub(crate) async fn next_bytes(&mut self) -> PolarsResult<Option<Bytes>> {
+        if self.has_run && self.next_page_token.is_none() {
+            return Ok(None);
+        }
+
+        self.has_run = true;
+
+        let request = self.request.try_clone().unwrap();
+
+        let request = if let Some(page_token) = self.next_page_token.take() {
+            request.query(&[("page_token", page_token)])
+        } else {
+            request
+        };
+
+        async { request.send().await?.bytes().await }
+            .await
+            .map(Some)
+            .map_err(to_compute_err)
+    }
+}
diff --git a/crates/polars-io/src/lib.rs b/crates/polars-io/src/lib.rs
index f3540f4e13fd..20c57f69b17e 100644
--- a/crates/polars-io/src/lib.rs
+++ b/crates/polars-io/src/lib.rs
@@ -5,6 +5,8 @@ extern crate core;

 #[cfg(feature = "avro")]
 pub mod avro;
+#[cfg(feature = "catalog")]
+pub mod catalog;
 pub mod cloud;
 #[cfg(any(feature = "csv", feature = "json"))]
 pub mod csv;
diff --git a/crates/polars-io/src/path_utils/hugging_face.rs b/crates/polars-io/src/path_utils/hugging_face.rs
index 8e6f8477360e..b76b011074d4 100644
--- a/crates/polars-io/src/path_utils/hugging_face.rs
+++ b/crates/polars-io/src/path_utils/hugging_face.rs
@@ -3,7 +3,7 @@
 use std::collections::VecDeque;
 use std::path::PathBuf;

-use polars_error::{polars_bail, polars_err, to_compute_err, PolarsResult};
+use polars_error::{polars_bail, to_compute_err, PolarsResult};

 use crate::cloud::{
     extract_prefix_expansion, try_build_http_header_map_from_items_slice, CloudConfig,
@@ -12,6 +12,7 @@ use crate::cloud::{
 use crate::path_utils::HiveIdxTracker;
 use crate::pl_async::with_concurrency_budget;
 use crate::prelude::URL_ENCODE_CHAR_SET;
+use crate::utils::decode_json_response;

 #[derive(Debug, PartialEq)]
 struct HFPathParts {
@@ -294,17 +295,12 @@ pub(super) async fn expand_paths_hf(
         client,
     };

-    fn try_parse_api_response(bytes: &[u8]) -> PolarsResult<Vec<HFAPIResponse>> {
-        serde_json::from_slice::<Vec<HFAPIResponse>>(bytes).map_err(
-            |e| polars_err!(ComputeError: "failed to parse API response as JSON: error: {}, value: {}", e, std::str::from_utf8(bytes).unwrap()),
-        )
-    }
-
     if let Some(matcher) = expansion_matcher {
         while let Some(bytes) = gp.next().await {
             let bytes = bytes?;
             let bytes = bytes.as_ref();
-            entries.extend(try_parse_api_response(bytes)?.into_iter().filter(|x| {
+            let response: Vec<HFAPIResponse> = decode_json_response(bytes)?;
+            entries.extend(response.into_iter().filter(|x| {
                 !x.is_file() || (x.size > 0 && matcher.is_matching(x.path.as_str()))
             }));
         }
@@ -312,11 +308,8 @@ pub(super) async fn expand_paths_hf(
         while let Some(bytes) = gp.next().await {
             let bytes = bytes?;
             let bytes = bytes.as_ref();
-            entries.extend(
-                try_parse_api_response(bytes)?
-                    .into_iter()
-                    .filter(|x| !x.is_file() || x.size > 0),
-            );
+            let response: Vec<HFAPIResponse> = decode_json_response(bytes)?;
+            entries.extend(response.into_iter().filter(|x| !x.is_file() || x.size > 0));
         }
     }

diff --git a/crates/polars-io/src/utils/other.rs b/crates/polars-io/src/utils/other.rs
index ceec5dc46217..6496b0d54d68 100644
--- a/crates/polars-io/src/utils/other.rs
+++ b/crates/polars-io/src/utils/other.rs
@@ -202,6 +202,41 @@ pub fn materialize_projection(
     }
 }

+/// Utility for decoding JSON that adds the response value to the error message if decoding fails.
+/// This makes it much easier to debug errors from parsing network responses.
+#[cfg(feature = "cloud")]
+pub fn decode_json_response<T>(bytes: &[u8]) -> PolarsResult<T>
+where
+    T: for<'de> serde::de::Deserialize<'de>,
+{
+    use polars_core::config;
+    use polars_error::to_compute_err;
+
+    serde_json::from_slice(bytes)
+        .map_err(to_compute_err)
+        .map_err(|e| {
+            e.wrap_msg(|e| {
+                let maybe_truncated = if config::verbose() {
+                    bytes
+                } else {
+                    // Clamp the output on non-verbose
+                    &bytes[..bytes.len().min(4096)]
+                };
+
+                format!(
+                    "error decoding response: {}, response value: {}{}",
+                    e,
+                    String::from_utf8_lossy(maybe_truncated),
+                    if maybe_truncated.len() != bytes.len() {
+                        " ...(set POLARS_VERBOSE=1 to see full response)"
+                    } else {
+                        ""
+                    }
+                )
+            })
+        })
+}
+
 #[cfg(test)]
 mod tests {
     use super::FLOAT_RE;
diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml
index ed172a178cb1..9ef1d38194cb 100644
--- a/crates/polars-lazy/Cargo.toml
+++ b/crates/polars-lazy/Cargo.toml
@@ -38,6 +38,7 @@ serde_json = { workspace = true }
 version_check = { workspace = true }

 [features]
+catalog = ["polars-io/catalog"]
 nightly = ["polars-core/nightly", "polars-pipe?/nightly", "polars-plan/nightly"]
 streaming = ["polars-pipe", "polars-plan/streaming", "polars-ops/chunked_ids", "polars-expr/streaming"]
 new_streaming = ["polars-stream"]
diff --git a/crates/polars-lazy/src/scan/catalog.rs b/crates/polars-lazy/src/scan/catalog.rs
new file mode 100644
index 000000000000..29baa89c0463
--- /dev/null
+++ b/crates/polars-lazy/src/scan/catalog.rs
@@ -0,0 +1,55 @@
+use polars_core::error::{feature_gated, polars_bail, PolarsResult};
+use polars_io::catalog::schema::table_info_to_schemas;
+use polars_io::catalog::unity::models::{DataSourceFormat, TableInfo};
+use polars_io::cloud::CloudOptions;
+
+use crate::frame::LazyFrame;
+
+impl LazyFrame {
+    pub fn scan_catalog_table(
+        table_info: &TableInfo,
+        cloud_options: Option<CloudOptions>,
+    ) -> PolarsResult<Self> {
+        let Some(data_source_format) = &table_info.data_source_format else {
+            polars_bail!(ComputeError: "scan_catalog_table requires Some(_) for data_source_format")
+        };
+
+        let Some(storage_location) = table_info.storage_location.as_deref() else {
+            polars_bail!(ComputeError: "scan_catalog_table requires Some(_) for storage_location")
+        };
+
+        match data_source_format {
+            DataSourceFormat::Parquet => feature_gated!("parquet", {
+                use polars_io::HiveOptions;
+
+                use crate::frame::ScanArgsParquet;
+                let (schema, hive_schema) = table_info_to_schemas(table_info)?;
+
+                let args = ScanArgsParquet {
+                    schema,
+                    cloud_options,
+                    hive_options: HiveOptions {
+                        schema: hive_schema,
+                        ..Default::default()
+                    },
+                    ..Default::default()
+                };
+
+                Self::scan_parquet(storage_location, args)
+            }),
+            DataSourceFormat::Csv => feature_gated!("csv", {
+                use crate::frame::{LazyCsvReader, LazyFileListReader};
+                let (schema, _) = table_info_to_schemas(table_info)?;
+
+                LazyCsvReader::new(storage_location)
+                    .with_schema(schema)
+                    .finish()
+            }),
+            v => polars_bail!(
+                ComputeError:
+                "not yet supported data_source_format: {:?}",
+                v
+            ),
+        }
+    }
+}
diff --git a/crates/polars-lazy/src/scan/mod.rs b/crates/polars-lazy/src/scan/mod.rs
index b868bfe909f5..86792bd8b5c9 100644
--- a/crates/polars-lazy/src/scan/mod.rs
+++ b/crates/polars-lazy/src/scan/mod.rs
@@ -8,3 +8,6 @@ pub(super) mod ipc;
 pub(super) mod ndjson;
 #[cfg(feature = "parquet")]
 pub(super) mod parquet;
+
+#[cfg(feature = "catalog")]
+mod catalog;
diff --git a/crates/polars-python/Cargo.toml b/crates/polars-python/Cargo.toml
index 14c918b44097..5376bf87388e 100644
--- a/crates/polars-python/Cargo.toml
+++ b/crates/polars-python/Cargo.toml
@@ -121,6 +121,7 @@ version_check = { workspace = true }
 [features]
 # Features below are only there to enable building a slim binary during development.
 avro = ["polars/avro"]
+catalog = ["polars-lazy/catalog"]
 parquet = ["polars/parquet", "polars-parquet"]
 ipc = ["polars/ipc"]
 ipc_streaming = ["polars/ipc_streaming"]
diff --git a/crates/polars-python/src/catalog/mod.rs b/crates/polars-python/src/catalog/mod.rs
new file mode 100644
index 000000000000..d014d07832c5
--- /dev/null
+++ b/crates/polars-python/src/catalog/mod.rs
@@ -0,0 +1,243 @@
+use polars::prelude::LazyFrame;
+use polars_io::catalog::unity::client::{CatalogClient, CatalogClientBuilder};
+use polars_io::catalog::unity::models::{CatalogInfo, ColumnInfo, SchemaInfo, TableInfo};
+use polars_io::cloud::credential_provider::PlCredentialProvider;
+use polars_io::pl_async;
+use pyo3::exceptions::PyValueError;
+use pyo3::types::{PyAnyMethods, PyDict, PyList};
+use pyo3::{pyclass, pymethods, Bound, PyObject, PyResult, Python};
+
+use crate::lazyframe::PyLazyFrame;
+use crate::prelude::parse_cloud_options;
+use crate::utils::to_py_err;
+
+macro_rules! pydict_insert_keys {
+    ($dict:expr, {$a:expr}) => {
+        $dict.set_item(stringify!($a), $a).unwrap();
+    };
+
+    ($dict:expr, {$a:expr, $($args:expr),+}) => {
+        pydict_insert_keys!($dict, { $a });
+        pydict_insert_keys!($dict, { $($args),+ });
+    };
+
+    ($dict:expr, {$a:expr, $($args:expr),+,}) => {
+        pydict_insert_keys!($dict, {$a, $($args),+});
+    };
+}
+
+#[pyclass]
+pub struct PyCatalogClient(CatalogClient);
+
+#[pymethods]
+impl PyCatalogClient {
+    #[pyo3(signature = (workspace_url, bearer_token))]
+    #[staticmethod]
+    pub fn new(workspace_url: String, bearer_token: Option<String>) -> PyResult<Self> {
+        let builder = CatalogClientBuilder::new().with_workspace_url(workspace_url);
+
+        let builder = if let Some(bearer_token) = bearer_token {
+            builder.with_bearer_token(bearer_token)
+        } else {
+            builder
+        };
+
+        builder.build().map(PyCatalogClient).map_err(to_py_err)
+    }
+
+    pub fn list_catalogs(&self, py: Python) -> PyResult<PyObject> {
+        let v = py
+            .allow_threads(|| {
+                pl_async::get_runtime().block_on_potential_spawn(self.client().list_catalogs())
+            })
+            .map_err(to_py_err)?;
+
+        PyList::new(
+            py,
+            v.into_iter().map(|CatalogInfo { name, comment }| {
+                let dict = PyDict::new(py);
+
+                pydict_insert_keys!(dict, {
+                    name,
+                    comment,
+                });
+
+                dict
+            }),
+        )
+        .map(|x| x.into())
+    }
+
+    #[pyo3(signature = (catalog_name))]
+    pub fn list_schemas(&self, py: Python, catalog_name: &str) -> PyResult<PyObject> {
+        let v = py
+            .allow_threads(|| {
+                pl_async::get_runtime()
+                    .block_on_potential_spawn(self.client().list_schemas(catalog_name))
+            })
+            .map_err(to_py_err)?;
+
+        PyList::new(
+            py,
+            v.into_iter().map(|SchemaInfo { name, comment }| {
+                let dict = PyDict::new(py);
+
+                pydict_insert_keys!(dict, {
+                    name,
+                    comment,
+                });
+
+                dict
+            }),
+        )
+        .map(|x| x.into())
+    }
+
+    #[pyo3(signature = (catalog_name, schema_name))]
+    pub fn list_tables(
+        &self,
+        py: Python,
+        catalog_name: &str,
+        schema_name: &str,
+    ) -> PyResult<PyObject> {
+        let v = py
+            .allow_threads(|| {
+                pl_async::get_runtime()
+                    .block_on_potential_spawn(self.client().list_tables(catalog_name, schema_name))
+            })
+            .map_err(to_py_err)?;
+
+        PyList::new(
+            py,
+            v.into_iter()
+                .map(|table_entry| table_entry_to_pydict(py, table_entry)),
+        )
+        .map(|x| x.into())
+    }
+
+    #[pyo3(signature = (catalog_name, schema_name, table_name))]
+    pub fn get_table_info(
+        &self,
+        py: Python,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+    ) -> PyResult<PyObject> {
+        let table_entry = py
+            .allow_threads(|| {
+                pl_async::get_runtime().block_on_potential_spawn(self.client().get_table_info(
+                    catalog_name,
+                    schema_name,
+                    table_name,
+                ))
+            })
+            .map_err(to_py_err)?;
+
+        Ok(table_entry_to_pydict(py, table_entry).into())
+    }
+
+    #[pyo3(signature = (catalog_name, schema_name, table_name, cloud_options, credential_provider, retries))]
+    pub fn scan_table(
+        &self,
+        py: Python,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+        cloud_options: Option<Vec<(String, String)>>,
+        credential_provider: Option<PyObject>,
+        retries: usize,
+    ) -> PyResult<PyLazyFrame> {
+        let table_info = py
+            .allow_threads(|| {
+                pl_async::get_runtime().block_on_potential_spawn(self.client().get_table_info(
+                    catalog_name,
+                    schema_name,
+                    table_name,
+                ))
+            })
+            .map_err(to_py_err)?;
+
+        let Some(storage_location) = table_info.storage_location.as_deref() else {
+            return Err(PyValueError::new_err(
+                "cannot scan catalog table: no storage_location found",
+            ));
+        };
+
+        let cloud_options =
+            parse_cloud_options(storage_location, cloud_options.unwrap_or_default())?
+                .with_max_retries(retries)
+                .with_credential_provider(
+                    credential_provider.map(PlCredentialProvider::from_python_func_object),
+                );
+
+        Ok(
+            LazyFrame::scan_catalog_table(&table_info, Some(cloud_options))
+                .map_err(to_py_err)?
+                .into(),
+        )
+    }
+}
+
+impl PyCatalogClient {
+    fn client(&self) -> &CatalogClient {
+        &self.0
+    }
+}
+
+fn table_entry_to_pydict(py: Python, table_entry: TableInfo) -> Bound<'_, PyDict> {
+    let TableInfo {
+        name,
+        comment,
+        table_id,
+        table_type,
+        storage_location,
+        data_source_format,
+        columns,
+    } = table_entry;
+
+    let dict = PyDict::new(py);
+
+    let columns = columns.map(|columns| {
+        columns
+            .into_iter()
+            .map(
+                |ColumnInfo {
+                     name,
+                     type_text,
+                     type_interval_type,
+                     position,
+                     comment,
+                     partition_index,
+                 }| {
+                    let dict = PyDict::new(py);
+
+                    pydict_insert_keys!(dict, {
+                        name,
+                        type_text,
+                        type_interval_type,
+                        position,
+                        comment,
+                        partition_index,
+                    });
+
+                    dict
+                },
+            )
+            .collect::<Vec<_>>()
+    });
+
+    let data_source_format = data_source_format.map(|x| x.to_string());
+    let table_type = table_type.to_string();
+
+    pydict_insert_keys!(dict, {
+        name,
+        comment,
+        table_id,
+        table_type,
+        storage_location,
+        data_source_format,
+        columns,
+    });
+
+    dict
+}
diff --git a/crates/polars-python/src/lib.rs b/crates/polars-python/src/lib.rs
index 640e9d5d7785..4a2471189298 100644
--- a/crates/polars-python/src/lib.rs
+++ b/crates/polars-python/src/lib.rs
@@ -7,6 +7,8 @@

 #[cfg(feature = "csv")]
 pub mod batched_csv;
+#[cfg(feature = "catalog")]
+pub mod catalog;
 #[cfg(feature = "polars_cloud")]
 pub mod cloud;
 pub mod conversion;
diff --git a/crates/polars-python/src/utils.rs b/crates/polars-python/src/utils.rs
index 703a95cfd74e..ffd2d175f885 100644
--- a/crates/polars-python/src/utils.rs
+++ b/crates/polars-python/src/utils.rs
@@ -1,3 +1,7 @@
+use pyo3::PyErr;
+
+use crate::error::PyPolarsErr;
+
 // was redefined because I could not get feature flags activated?
 #[macro_export]
 macro_rules! apply_method_all_arrow_series2 {
@@ -24,3 +28,9 @@ macro_rules! apply_method_all_arrow_series2 {
         }
     }
 }
+
+/// Boilerplate for `|e| PyPolarsErr::from(e).into()`
+#[allow(unused)]
+pub(crate) fn to_py_err<E: Into<PyPolarsErr>>(e: E) -> PyErr {
+    e.into().into()
+}
diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml
index fdf7217a5d85..c99066bc24fa 100644
--- a/py-polars/Cargo.toml
+++ b/py-polars/Cargo.toml
@@ -35,6 +35,7 @@ sql = ["polars-python/sql"]
 trigonometry = ["polars-python/trigonometry"]
 parquet = ["polars-python/parquet"]
 ipc = ["polars-python/ipc"]
+catalog = ["polars-python/catalog"]

 # Features passed through to the polars-python crate
 avro = ["polars-python/avro"]
@@ -103,6 +104,7 @@ all = [
     "trigonometry",
     "parquet",
     "ipc",
+    "catalog",
     "polars-python/all",
     "performant",
 ]
diff --git a/py-polars/docs/source/reference/catalog.rst b/py-polars/docs/source/reference/catalog.rst
new file mode 100644
index 000000000000..921ee66a25d0
--- /dev/null
+++ b/py-polars/docs/source/reference/catalog.rst
@@ -0,0 +1,20 @@
+=======
+Catalog
+=======
+.. currentmodule:: polars
+
+
+Unity Catalog
+~~~~~~~~~~~~~
+
+Interface with Unity catalogs.
+
+.. autosummary::
+   :toctree: api/
+
+    Catalog
+    Catalog.list_catalogs
+    Catalog.list_schemas
+    Catalog.list_tables
+    Catalog.get_table_info
+    Catalog.scan_table
diff --git a/py-polars/docs/source/reference/index.rst b/py-polars/docs/source/reference/index.rst
index 1b5116eea4b5..fa8a64c3ad82 100644
--- a/py-polars/docs/source/reference/index.rst
+++ b/py-polars/docs/source/reference/index.rst
@@ -65,6 +65,13 @@ methods. All classes and functions exposed in the ``polars.*`` namespace are pub

         io

+    .. grid-item-card::
+
+        .. toctree::
+           :maxdepth: 2
+
+           catalog
+
     .. grid-item-card::

         .. toctree::
diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py
index 67b4413d2b61..7a502aaf0ebc 100644
--- a/py-polars/polars/__init__.py
+++ b/py-polars/polars/__init__.py
@@ -23,6 +23,7 @@
 # TODO: remove need for importing wrap utils at top level
 from polars._utils.wrap import wrap_df, wrap_s  # noqa: F401
+from polars.catalog import Catalog
 from polars.config import Config
 from polars.convert import (
     from_arrow,
@@ -278,6 +279,7 @@
     "scan_ndjson",
     "scan_parquet",
     "scan_pyarrow_dataset",
+    "Catalog",
     # polars.io.cloud
     "CredentialProvider",
     "CredentialProviderAWS",
diff --git a/py-polars/polars/catalog.py b/py-polars/polars/catalog.py
new file mode 100644
index 000000000000..d57d01b783f8
--- /dev/null
+++ b/py-polars/polars/catalog.py
@@ -0,0 +1,338 @@
+from __future__ import annotations
+
+import importlib
+import os
+from typing import TYPE_CHECKING, Any, Literal, TypedDict
+
+from polars._utils.wrap import wrap_ldf
+
+if TYPE_CHECKING:
+    from datetime import datetime
+
+    from polars.io.cloud import CredentialProviderFunction
+    from polars.lazyframe import LazyFrame
+
+
+class Catalog:
+    """
+    Unity catalog client.
+
+    .. warning::
+        This functionality is considered **unstable**. It may be changed
+        at any point without it being considered a breaking change.
+    """
+
+    def __init__(
+        self,
+        workspace_url: str,
+        *,
+        bearer_token: str | None = "auto",
+    ) -> None:
+        """
+        Initialize a catalog client.
+
+        .. warning::
+            This functionality is considered **unstable**. It may be changed
+            at any point without it being considered a breaking change.
+
+        Parameters
+        ----------
+        workspace_url
+            URL of the workspace, or alternatively the URL of the Unity catalog
+            API endpoint.
+        bearer_token
+            Bearer token to authenticate with. This can also be set to:
+
+            * "auto": Automatically retrieve bearer tokens from the environment.
+ * "databricks-sdk": Use the Databricks SDK to retrieve and use the + bearer token from the environment. + """ + from polars.polars import PyCatalogClient + + if bearer_token == "databricks-sdk" or ( + bearer_token == "auto" + # For security, in "auto" mode, only retrieve/use the token if: + # * We are running inside a Databricks environment + # * The `workspace_url` is pointing to Databricks + and "DATABRICKS_RUNTIME_VERSION" in os.environ + and ( + workspace_url.removeprefix("https://") + .split("/", 1)[0] + .endswith(".cloud.databricks.com") + ) + ): + bearer_token = self._get_databricks_token() + + if bearer_token == "auto": + bearer_token = None + + self._client = PyCatalogClient.new(workspace_url, bearer_token) + + def list_catalogs(self) -> list[CatalogInfo]: + """ + List the available catalogs. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + """ + return self._client.list_catalogs() + + def list_schemas(self, catalog_name: str) -> list[SchemaInfo]: + """ + List the available schemas under the specified catalog. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + + Parameters + ---------- + catalog_name + Name of the catalog. + """ + return self._client.list_schemas(catalog_name) + + def list_tables(self, catalog_name: str, schema_name: str) -> list[TableInfo]: + """ + List the available tables under the specified schema. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + + Parameters + ---------- + catalog_name + Name of the catalog. + schema_name + Name of the schema. + """ + return self._client.list_tables(catalog_name, schema_name) + + def get_table_info( + self, catalog_name: str, schema_name: str, table_name: str + ) -> TableInfo: + """ + Retrieve the metadata of the specified table. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + + Parameters + ---------- + catalog_name + Name of the catalog. + schema_name + Name of the schema. + table_name + Name of the table. + """ + return self._client.get_table_info(catalog_name, schema_name, table_name) + + def scan_table( + self, + catalog_name: str, + schema_name: str, + table_name: str, + *, + delta_table_version: int | str | datetime | None = None, + delta_table_options: dict[str, Any] | None = None, + storage_options: dict[str, Any] | None = None, + credential_provider: ( + CredentialProviderFunction | Literal["auto"] | None + ) = "auto", + retries: int = 2, + ) -> LazyFrame: + """ + Retrieve the metadata of the specified table. + + .. warning:: + This functionality is considered **unstable**. It may be changed + at any point without it being considered a breaking change. + + Parameters + ---------- + catalog_name + Name of the catalog. + schema_name + Name of the schema. + table_name + Name of the table. + delta_table_version + Version of the table to scan (Deltalake only). + delta_table_options + Additional keyword arguments while reading a Deltalake table. + storage_options + Options that indicate how to connect to a cloud provider. + + The cloud providers currently supported are AWS, GCP, and Azure. 
+            See supported keys here:
+
+            * `aws <https://docs.rs/object_store/latest/object_store/aws/enum.AmazonS3ConfigKey.html>`_
+            * `gcp <https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html>`_
+            * `azure <https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html>`_
+            * Hugging Face (`hf://`): Accepts an API key under the `token` parameter: \
+              `{'token': '...'}`, or by setting the `HF_TOKEN` environment variable.
+
+            If `storage_options` is not provided, Polars will try to infer the
+            information from environment variables.
+        credential_provider
+            Provide a function that can be called to provide cloud storage
+            credentials. The function is expected to return a dictionary of
+            credential keys along with an optional credential expiry time.
+
+            .. warning::
+                This functionality is considered **unstable**. It may be changed
+                at any point without it being considered a breaking change.
+        retries
+            Number of retries if accessing a cloud instance fails.
+
+        """
+        table_info = self.get_table_info(catalog_name, schema_name, table_name)
+
+        if (source := table_info.get("storage_location")) is None:
+            msg = "cannot scan catalog table: no storage_location found"
+            raise ValueError(msg)
+
+        if (data_source_format := table_info.get("data_source_format")) is None:
+            msg = "cannot scan catalog table: no data_source_format found"
+            raise ValueError(msg)
+
+        if data_source_format in ["DELTA", "DELTA_SHARING"]:
+            from polars.io.delta import scan_delta
+
+            if credential_provider is not None and credential_provider != "auto":
+                msg = "credential_provider when scanning DELTA"
+                raise NotImplementedError(msg)
+
+            return scan_delta(
+                source,
+                version=delta_table_version,
+                delta_table_options=delta_table_options,
+                storage_options=storage_options,
+            )
+
+        if delta_table_version is not None:
+            msg = (
+                "cannot apply delta_table_version for table of type "
+                f"{data_source_format}"
+            )
+            raise ValueError(msg)
+
+        if delta_table_options is not None:
+            msg = (
+                "cannot apply delta_table_options for table of type "
+                f"{data_source_format}"
+            )
+            raise ValueError(msg)
+
+        from polars.io.cloud.credential_provider import _maybe_init_credential_provider
+
+        credential_provider = _maybe_init_credential_provider(
+            credential_provider, source, storage_options, "Catalog.scan_table"
+        )
+
+        if storage_options:
+            storage_options = list(storage_options.items())  # type: ignore[assignment]
+        else:
+            # Handle empty dict input
+            storage_options = None
+
+        return wrap_ldf(
+            self._client.scan_table(
+                catalog_name,
+                schema_name,
+                table_name,
+                credential_provider=credential_provider,
+                cloud_options=storage_options,
+                retries=retries,
+            )
+        )
+
+    @classmethod
+    def _get_databricks_token(cls) -> str:
+        cls._ensure_databricks_sdk_available()
+
+        # We code like this to bypass linting
+        m = importlib.import_module("databricks.sdk.core").__dict__
+
+        return m["DefaultCredentials"]()(m["Config"]())()["Authorization"][7:]
+
+    @staticmethod
+    def _ensure_databricks_sdk_available() -> None:
+        if importlib.util.find_spec("databricks.sdk") is None:
+            msg = "could not get Databricks token: databricks-sdk is not installed"
+            raise ImportError(msg)
+
+
+class CatalogInfo(TypedDict):
+    """Information for a catalog within a metastore."""
+
+    name: str
+    comment: str | None
+
+
+class SchemaInfo(TypedDict):
+    """Information for a schema within a catalog."""
+
+    name: str
+    comment: str | None
+
+
+class TableInfo(TypedDict):
+    """Information for a catalog table."""
+
+    name: str
+    comment: str | None
+    table_id: str
+    table_type: TableType
+    storage_location: str | None
+    data_source_format: DataSourceFormat | None
+    columns: list[ColumnInfo] | None
+
+
+class ColumnInfo(TypedDict):
+    """Information for a column within a catalog table."""
+
+    name: str
+    type_text: str
+    type_interval_type: str | None
+    position: int | None
+    comment: str | None
+    partition_index: int | None
+
+
+TableType = Literal[
+    "MANAGED",
+    "EXTERNAL",
+    "VIEW",
+    "MATERIALIZED_VIEW",
+    "STREAMING_TABLE",
+    "MANAGED_SHALLOW_CLONE",
+    "FOREIGN",
+    "EXTERNAL_SHALLOW_CLONE",
+]
+
+DataSourceFormat = Literal[
+    "DELTA",
+    "CSV",
+    "JSON",
+    "AVRO",
+    "PARQUET",
+    "ORC",
+    "TEXT",
+    "UNITY_CATALOG",
+    "DELTA_SHARING",
+    "DATABRICKS_FORMAT",
+    "REDSHIFT_FORMAT",
+    "SNOWFLAKE_FORMAT",
+    "SQLDW_FORMAT",
+    "SALESFORCE_FORMAT",
+    "BIGQUERY_FORMAT",
+    "NETSUITE_FORMAT",
+    "WORKDAY_RAAS_FORMAT",
+    "HIVE_SERDE",
+    "HIVE_CUSTOM",
+    "VECTOR_INDEX_FORMAT",
+]
diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs
index 381a56dd7153..879c03335597 100644
--- a/py-polars/src/lib.rs
+++ b/py-polars/src/lib.rs
@@ -10,6 +10,8 @@ mod memory;
 use allocator::create_allocator_capsule;
 #[cfg(feature = "csv")]
 use polars_python::batched_csv::PyBatchedCsv;
+#[cfg(feature = "catalog")]
+use polars_python::catalog::PyCatalogClient;
 #[cfg(feature = "polars_cloud")]
 use polars_python::cloud;
 use polars_python::dataframe::PyDataFrame;
@@ -238,6 +240,8 @@ fn polars(py: Python, m: &Bound<PyModule>) -> PyResult<()> {
     #[cfg(feature = "clipboard")]
     m.add_wrapped(wrap_pyfunction!(functions::write_clipboard_string))
         .unwrap();
+    #[cfg(feature = "catalog")]
+    m.add_class::<PyCatalogClient>().unwrap();

     // Functions - meta
     m.add_wrapped(wrap_pyfunction!(functions::get_index_type))
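
Hypothetical usage sketch of the Python API added by this diff. The workspace URL, token handling, and catalog/schema/table names below are placeholders, not values taken from the change itself:

    import polars as pl

    # Connect to a Unity catalog workspace (URL and names are made up).
    catalog = pl.Catalog(
        "https://my-workspace.cloud.databricks.com", bearer_token="auto"
    )

    # Browse the metastore; each call returns a list of dicts matching the
    # CatalogInfo / SchemaInfo / TableInfo TypedDicts defined in catalog.py.
    print(catalog.list_catalogs())
    print(catalog.list_schemas("my_catalog"))
    print(catalog.list_tables("my_catalog", "my_schema"))

    # Inspect a single table, then scan it lazily.
    info = catalog.get_table_info("my_catalog", "my_schema", "my_table")
    print(info["data_source_format"], info["storage_location"])

    df = catalog.scan_table("my_catalog", "my_schema", "my_table").head(10).collect()
    print(df)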
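For readers of crates/polars-io/src/catalog/schema.rs above, a minimal Python sketch of the same nesting-aware type-string parsing (decimal/array/struct only). The helper names and the tuple-based output are illustrative; the Rust code returns Polars DataTypes and also handles map and primitive aliases:

    def split_top_level(s: str) -> list[str]:
        # Split on commas that are not nested inside (...) or <...>,
        # mirroring split_comma_nesting_aware in the Rust code.
        parts, depth, start = [], 0, 0
        for i, c in enumerate(s):
            if c in "(<":
                depth += 1
            elif c in ")>":
                depth -= 1
            elif c == "," and depth == 0:
                parts.append(s[start:i])
                start = i + 1
        parts.append(s[start:])
        return parts

    def parse_type_text(t: str):
        # Reduced version of parse_type_str: decimal(p,s), array<...>, struct<...>.
        if t.startswith("decimal("):
            precision, scale = t[len("decimal("):-1].split(",")
            return ("decimal", int(precision), int(scale))
        if t.startswith("array<"):
            return ("list", parse_type_text(t[len("array<"):-1]))
        if t.startswith("struct<"):
            fields = []
            for field in split_top_level(t[len("struct<"):-1]):
                name, dtype = field.split(":", 1)
                fields.append((name, parse_type_text(dtype)))
            return ("struct", fields)
        return t  # primitive names (int, string, ...) pass through unchanged

    print(parse_type_text("struct<default:decimal(38,18),tags:array<string>>"))
    # ('struct', [('default', ('decimal', 38, 18)), ('tags', ('list', 'string'))])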