From fa2cc221e2c628a8ef26281baf9852cc62b9e1f4 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 14 Jan 2025 13:58:45 +0400 Subject: [PATCH] feat: Add SQL support for the `NORMALIZE` string function --- Cargo.lock | 3 -- crates/polars-sql/Cargo.toml | 3 -- crates/polars-sql/src/functions.rs | 41 +++++++++++++++++++++++- py-polars/tests/unit/sql/test_strings.py | 15 +++++++++ 4 files changed, 55 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 591eaae4e817..67f867188808 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3482,8 +3482,6 @@ name = "polars-sql" version = "0.45.1" dependencies = [ "hex", - "once_cell", - "polars-arrow", "polars-core", "polars-error", "polars-lazy", @@ -3493,7 +3491,6 @@ dependencies = [ "polars-utils", "rand", "serde", - "serde_json", "sqlparser", ] diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index 27fff7d868fb..3cd2cafabb91 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -9,7 +9,6 @@ repository = { workspace = true } description = "SQL transpiler for Polars. Converts SQL to Polars logical plans" [dependencies] -arrow = { workspace = true } polars-core = { workspace = true, features = ["rows"] } polars-error = { workspace = true } polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "dtype-struct", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_normalize", "string_reverse", "strings", "timezones", "trigonometry"] } @@ -19,10 +18,8 @@ polars-time = { workspace = true } polars-utils = { workspace = true } hex = { workspace = true } -once_cell = { workspace = true } rand = { workspace = true } serde = { workspace = true } -serde_json = { workspace = true } sqlparser = { workspace = true } [dev-dependencies] diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index f8c7adde01ef..2712a44a2e27 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -8,6 +8,7 @@ use polars_core::prelude::{ use polars_lazy::dsl::Expr; #[cfg(feature = "list_eval")] use polars_lazy::dsl::ListNameSpaceExtension; +use polars_ops::chunked_array::UnicodeForm; use polars_plan::dsl::{coalesce, concat_str, len, max_horizontal, min_horizontal, when}; use polars_plan::plans::{typed_lit, LiteralValue}; use polars_plan::prelude::LiteralValue::Null; @@ -376,6 +377,13 @@ pub(crate) enum PolarsSQLFunctions { /// SELECT LTRIM(column_1) FROM df; /// ``` LTrim, + /// SQL 'normalize' function + /// Convert string to Unicode normalization form + /// (one of "NFC", "NFKC", "NFD", or "NFKD"). + /// ```sql + /// SELECT NORMALIZE(column_1, 'NFC') FROM df; + /// ``` + Normalize, /// SQL 'octet_length' function /// Returns the length of a given string in bytes. /// ```sql @@ -391,7 +399,7 @@ pub(crate) enum PolarsSQLFunctions { /// SQL 'replace' function /// Replace a given substring with another string. /// ```sql - /// SELECT REPLACE(column_1,'old','new') FROM df; + /// SELECT REPLACE(column_1, 'old', 'new') FROM df; /// ``` Replace, /// SQL 'reverse' function @@ -859,6 +867,7 @@ impl PolarsSQLFunctions { "left" => Self::Left, "lower" => Self::Lower, "ltrim" => Self::LTrim, + "normalize" => Self::Normalize, "octet_length" => Self::OctetLength, "strpos" => Self::StrPos, "regexp_like" => Self::RegexpLike, @@ -1152,6 +1161,36 @@ impl SQLFunctionVisitor<'_> { }, } }, + Normalize => { + let args = extract_args(function)?; + match args.len() { + 1 => self.visit_unary(|e| e.str().normalize(UnicodeForm::NFC)), + 2 => { + let form = if let FunctionArgExpr::Expr(SQLExpr::Identifier(Ident { + value: s, + quote_style: None, + span: _, + })) = args[1] + { + match s.to_uppercase().as_str() { + "NFC" => UnicodeForm::NFC, + "NFD" => UnicodeForm::NFD, + "NFKC" => UnicodeForm::NFKC, + "NFKD" => UnicodeForm::NFKD, + _ => { + polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", s) + }, + } + } else { + polars_bail!(SQLSyntax: "invalid 'form' for NORMALIZE (found {})", args[1]) + }; + self.try_visit_binary(|e, _form: Expr| Ok(e.str().normalize(form.clone()))) + }, + _ => { + polars_bail!(SQLSyntax: "NORMALIZE expects 1-2 arguments (found {})", args.len()) + }, + } + }, OctetLength => self.visit_unary(|e| e.str().len_bytes()), StrPos => { // // note: SQL is 1-indexed; returns zero if no match found diff --git a/py-polars/tests/unit/sql/test_strings.py b/py-polars/tests/unit/sql/test_strings.py index 0405a47e665e..63f16126a362 100644 --- a/py-polars/tests/unit/sql/test_strings.py +++ b/py-polars/tests/unit/sql/test_strings.py @@ -275,6 +275,21 @@ def test_string_like_multiline() -> None: assert df.sql(f"SELECT txt FROM self WHERE txt LIKE '{s}'").item() == s +@pytest.mark.parametrize("form", ["NFKC", "NFKD"]) +def test_string_normalize(form: str) -> None: + df = pl.DataFrame({"txt": ["Test", "𝕋𝕖𝕀π•₯", "π•Ώπ–Šπ–˜π–™", "π—§π—²π˜€π˜", "Ⓣⓔⓒⓣ"]}) # noqa: RUF001 + res = df.sql( + f""" + SELECT txt, NORMALIZE(txt,{form}) AS norm_txt + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "txt": ["Test", "𝕋𝕖𝕀π•₯", "π•Ώπ–Šπ–˜π–™", "π—§π—²π˜€π˜", "Ⓣⓔⓒⓣ"], # noqa: RUF001 + "norm_txt": ["Test", "Test", "Test", "Test", "Test"], + } + + def test_string_position() -> None: df = pl.Series( name="city",