diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 5578517ec359..0f45d51835f4 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -28,16 +28,18 @@ runs: - name: Install Build Dependencies shell: bash run: | - apt-get update - apt-get install -y protobuf-compiler + RETRY="ci/scripts/retry" + "${RETRY}" apt-get update + "${RETRY}" apt-get install -y protobuf-compiler - name: Setup Rust toolchain shell: bash # rustfmt is needed for the substrait build script run: | + RETRY="ci/scripts/retry" echo "Installing ${{ inputs.rust-version }}" - rustup toolchain install ${{ inputs.rust-version }} - rustup default ${{ inputs.rust-version }} - rustup component add rustfmt + "${RETRY}" rustup toolchain install ${{ inputs.rust-version }} + "${RETRY}" rustup default ${{ inputs.rust-version }} + "${RETRY}" rustup component add rustfmt - name: Configure rust runtime env uses: ./.github/actions/setup-rust-runtime - name: Fixup git permissions diff --git a/Cargo.toml b/Cargo.toml index b8bf83a5ab53..448607257ca1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,22 +70,22 @@ version = "42.0.0" ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -arrow = { version = "53.0.0", features = [ +arrow = { version = "53.1.0", features = [ "prettyprint", ] } -arrow-array = { version = "53.0.0", default-features = false, features = [ +arrow-array = { version = "53.1.0", default-features = false, features = [ "chrono-tz", ] } -arrow-buffer = { version = "53.0.0", default-features = false } -arrow-flight = { version = "53.0.0", features = [ +arrow-buffer = { version = "53.1.0", default-features = false } +arrow-flight = { version = "53.1.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "53.0.0", default-features = false, features = [ +arrow-ipc = { version = "53.1.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "53.0.0", default-features = false } -arrow-schema = { version = "53.0.0", default-features = false } -arrow-string = { version = "53.0.0", default-features = false } +arrow-ord = { version = "53.1.0", default-features = false } +arrow-schema = { version = "53.1.0", default-features = false } +arrow-string = { version = "53.1.0", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" @@ -126,7 +126,7 @@ log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.11.0", default-features = false } parking_lot = "0.12" -parquet = { version = "53.0.0", default-features = false, features = [ +parquet = { version = "53.1.0", default-features = false, features = [ "arrow", "async", "object_store", diff --git a/ci/scripts/retry b/ci/scripts/retry new file mode 100755 index 000000000000..0569dea58c94 --- /dev/null +++ b/ci/scripts/retry @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +set -euo pipefail + +x() { + echo "+ $*" >&2 + "$@" +} + +max_retry_time_seconds=$(( 3 * 60 )) +retry_delay_seconds=10 + +END=$(( $(date +%s) + ${max_retry_time_seconds} )) + +while (( $(date +%s) < $END )); do + x "$@" && exit 0 + sleep "${retry_delay_seconds}" +done + +echo "$0: retrying [$*] timed out" >&2 +exit 1 diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 8c77ea8a2557..aa64e14fca8e 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.24.1" +version = "0.24.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5fb1d8e4442bd405fdfd1dacb42792696b0cf9cb15882e5d097b742a676d375" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] @@ -173,9 +173,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45aef0d9cf9a039bf6cd1acc451b137aca819977b0928dece52bd92811b640ba" +checksum = "a9ba0d7248932f4e2a12fb37f0a2e3ec82b3bdedbac2a1dce186e036843b8f8c" dependencies = [ "arrow-arith", "arrow-array", @@ -194,9 +194,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03675e42d1560790f3524800e41403b40d0da1c793fe9528929fde06d8c7649a" +checksum = "d60afcdc004841a5c8d8da4f4fa22d64eb19c0c01ef4bcedd77f175a7cf6e38f" dependencies = [ "arrow-array", "arrow-buffer", @@ -209,9 +209,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd2bf348cf9f02a5975c5962c7fa6dee107a2009a7b41ac5fb1a027e12dc033f" +checksum = "7f16835e8599dbbb1659fd869d865254c4cf32c6c2bb60b6942ac9fc36bfa5da" dependencies = [ "ahash", "arrow-buffer", @@ -220,15 +220,15 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown", + "hashbrown 0.14.5", "num", ] [[package]] name = "arrow-buffer" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3092e37715f168976012ce52273c3989b5793b0db5f06cbaa246be25e5f0924d" +checksum = "1a1f34f0faae77da6b142db61deba2cb6d60167592b178be317b341440acba80" dependencies = [ "bytes", "half", @@ -237,9 +237,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ce1018bb710d502f9db06af026ed3561552e493e989a79d0d0f5d9cf267a785" +checksum = "450e4abb5775bca0740bec0bcf1b1a5ae07eff43bd625661c4436d8e8e4540c4" dependencies = [ "arrow-array", "arrow-buffer", @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd178575f45624d045e4ebee714e246a05d9652e41363ee3f57ec18cca97f740" +checksum = "d3a4e4d63830a341713e35d9a42452fbc6241d5f42fa5cf6a4681b8ad91370c4" dependencies = [ "arrow-array", "arrow-buffer", @@ -277,9 +277,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4ac0c4ee79150afe067dc4857154b3ee9c1cd52b5f40d59a77306d0ed18d65" +checksum = "2b1e618bbf714c7a9e8d97203c806734f012ff71ae3adc8ad1b075689f540634" dependencies = [ "arrow-buffer", "arrow-schema", @@ -289,9 +289,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb307482348a1267f91b0912e962cd53440e5de0f7fb24c5f7b10da70b38c94a" +checksum = "f98e983549259a2b97049af7edfb8f28b8911682040e99a94e4ceb1196bd65c2" dependencies = [ "arrow-array", "arrow-buffer", @@ -304,9 +304,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "53.0.0" +version = "53.1.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d24805ba326758effdd6f2cbdd482fcfab749544f21b134701add25b33f474e6" +checksum = "b198b9c6fcf086501730efbbcb483317b39330a116125af7bb06467d04b352a3" dependencies = [ "arrow-array", "arrow-buffer", @@ -324,9 +324,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "644046c479d80ae8ed02a7f1e1399072ea344ca6a7b0e293ab2d5d9ed924aa3b" +checksum = "2427f37b4459a4b9e533045abe87a5183a5e0995a3fc2c2fd45027ae2cc4ef3f" dependencies = [ "arrow-array", "arrow-buffer", @@ -339,9 +339,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a29791f8eb13b340ce35525b723f5f0df17ecb955599e11f65c2a94ab34e2efb" +checksum = "15959657d92e2261a7a323517640af87f5afd9fd8a6492e424ebee2203c567f6" dependencies = [ "ahash", "arrow-array", @@ -353,15 +353,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85320a3a2facf2b2822b57aa9d6d9d55edb8aee0b6b5d3b8df158e503d10858" +checksum = "fbf0388a18fd7f7f3fe3de01852d30f54ed5182f9004db700fbe3ba843ed2794" [[package]] name = "arrow-select" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cc7e6b582e23855fd1625ce46e51647aa440c20ea2e71b1d748e0839dd73cba" +checksum = "b83e5723d307a38bf00ecd2972cd078d1339c7fd3eb044f609958a9a24463f3a" dependencies = [ "ahash", "arrow-array", @@ -373,9 +373,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0775b6567c66e56ded19b87a954b6b1beffbdd784ef95a3a2b03f59570c1d230" +checksum = "7ab3db7c09dd826e74079661d84ed01ed06547cf75d52c2818ef776d0d852305" dependencies = [ "arrow-array", "arrow-buffer", @@ -406,9 +406,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fec134f64e2bc57411226dfc4e52dec859ddfc7e711fc5e07b612584f000e4aa" +checksum = "7e614738943d3f68c628ae3dbce7c3daffb196665f82f8c8ea6b65de73c79429" dependencies = [ "bzip2", "flate2", @@ -456,9 +456,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.7" +version = "1.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8191fb3091fa0561d1379ef80333c3c7191c6f0435d986e85821bcf7acbd1126" +checksum = "7198e6f03240fdceba36656d8be440297b6b82270325908c7381f37d826a74f6" dependencies = [ "aws-credential-types", "aws-runtime", @@ -523,9 +523,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.44.0" +version = "1.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b90cfe6504115e13c41d3ea90286ede5aa14da294f3fe077027a6e83850843c" +checksum = "e33ae899566f3d395cbf42858e433930682cc9c1889fa89318896082fef45efb" dependencies = [ "aws-credential-types", "aws-runtime", @@ -545,9 +545,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.45.0" +version = "1.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "167c0fad1f212952084137308359e8e4c4724d1c643038ce163f06de9662c1d0" +checksum = 
"f39c09e199ebd96b9f860b0fce4b6625f211e064ad7c8693b72ecf7ef03881e0" dependencies = [ "aws-credential-types", "aws-runtime", @@ -567,9 +567,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.44.0" +version = "1.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cb5f98188ec1435b68097daa2a37d74b9d17c9caa799466338a8d1544e71b9d" +checksum = "3d95f93a98130389eb6233b9d615249e543f6c24a68ca1f109af9ca5164a8765" dependencies = [ "aws-credential-types", "aws-runtime", @@ -917,9 +917,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.22" +version = "1.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9540e661f81799159abee814118cc139a2004b3a3aa3ea37724a1b66530b90e0" +checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" dependencies = [ "jobserver", "libc", @@ -953,9 +953,9 @@ dependencies = [ [[package]] name = "chrono-tz" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" +checksum = "cd6dd8046d00723a59a2f8c5f295c515b9bb9a331ee4f8f3d4dd49e428acd3b6" dependencies = [ "chrono", "chrono-tz-build", @@ -964,20 +964,19 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c088aee841df9c3041febbb73934cfc39708749bf96dc827e3359cd39ef11b1" +checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" dependencies = [ "parse-zoneinfo", - "phf", "phf_codegen", ] [[package]] name = "clap" -version = "4.5.18" +version = "4.5.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0956a43b323ac1afaffc053ed5c4b7c1f1800bacd1683c353aabbb752515dd3" +checksum = "7be5744db7978a28d9df86a214130d106a89ce49644cbc4e3f0c22c3fba30615" dependencies = [ "clap_builder", "clap_derive", @@ -985,9 +984,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.18" +version = "4.5.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d72166dd41634086d5803a47eb71ae740e61d84709c36f3c34110173db3961b" +checksum = "a5fbc17d3ef8278f55b282b2a2e75ae6f6c7d4bb70ed3d0382375104bfafdb4b" dependencies = [ "anstream", "anstyle", @@ -1175,7 +1174,7 @@ checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" dependencies = [ "cfg-if", "crossbeam-utils", - "hashbrown", + "hashbrown 0.14.5", "lock_api", "once_cell", "parking_lot_core", @@ -1216,7 +1215,7 @@ dependencies = [ "futures", "glob", "half", - "hashbrown", + "hashbrown 0.14.5", "indexmap", "itertools", "log", @@ -1293,7 +1292,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "hashbrown", + "hashbrown 0.14.5", "instant", "libc", "num_cpus", @@ -1322,7 +1321,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "futures", - "hashbrown", + "hashbrown 0.14.5", "log", "object_store", "parking_lot", @@ -1375,7 +1374,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", - "hashbrown", + "hashbrown 0.14.5", "hex", "itertools", "log", @@ -1456,6 +1455,7 @@ name = "datafusion-functions-window-common" version = "42.0.0" dependencies = [ "datafusion-common", + "datafusion-physical-expr-common", ] [[package]] @@ -1468,7 +1468,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown", + "hashbrown 0.14.5", "indexmap", "itertools", "log", @@ 
-1487,23 +1487,19 @@ dependencies = [ "arrow-ord", "arrow-schema", "arrow-string", - "base64 0.22.1", "chrono", "datafusion-common", - "datafusion-execution", "datafusion-expr", "datafusion-expr-common", "datafusion-functions-aggregate-common", "datafusion-physical-expr-common", "half", - "hashbrown", - "hex", + "hashbrown 0.14.5", "indexmap", "itertools", "log", "paste", "petgraph", - "regex", ] [[package]] @@ -1514,7 +1510,7 @@ dependencies = [ "arrow", "datafusion-common", "datafusion-expr-common", - "hashbrown", + "hashbrown 0.14.5", "rand", ] @@ -1522,9 +1518,11 @@ dependencies = [ name = "datafusion-physical-optimizer" version = "42.0.0" dependencies = [ + "arrow", "arrow-schema", "datafusion-common", "datafusion-execution", + "datafusion-expr-common", "datafusion-physical-expr", "datafusion-physical-plan", "itertools", @@ -1546,14 +1544,13 @@ dependencies = [ "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-functions-window-common", "datafusion-physical-expr", "datafusion-physical-expr-common", "futures", "half", - "hashbrown", + "hashbrown 0.14.5", "indexmap", "itertools", "log", @@ -1758,9 +1755,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1773,9 +1770,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -1783,15 +1780,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -1800,15 +1797,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -1817,15 +1814,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -1835,9 +1832,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -1874,9 +1871,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.31.0" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32085ea23f3234fc7846555e85283ba4de91e21016dc0455a16286d87a292d64" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" @@ -1943,6 +1940,12 @@ dependencies = [ "allocator-api2", ] +[[package]] +name = "hashbrown" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" + [[package]] name = "heck" version = "0.4.1" @@ -2043,9 +2046,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.9.4" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" +checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" [[package]] name = "httpdate" @@ -2129,7 +2132,7 @@ dependencies = [ "http 1.1.0", "hyper 1.4.1", "hyper-util", - "rustls 0.23.13", + "rustls 0.23.14", "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", @@ -2191,12 +2194,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.0", ] [[package]] @@ -2219,9 +2222,9 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "ipnet" -version = "2.10.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" [[package]] name = "is_terminal_polyfill" @@ -2270,9 +2273,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "lexical-core" -version = "0.8.5" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +checksum = "0431c65b318a590c1de6b8fd6e72798c92291d27762d94c9e6c37ed7a73d8458" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -2283,9 +2286,9 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "0.8.5" +version = "1.0.2" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +checksum = "eb17a4bdb9b418051aa59d41d65b1c9be5affab314a872e5ad7f06231fb3b4e0" dependencies = [ "lexical-parse-integer", "lexical-util", @@ -2294,9 +2297,9 @@ dependencies = [ [[package]] name = "lexical-parse-integer" -version = "0.8.6" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +checksum = "5df98f4a4ab53bf8b175b363a34c7af608fe31f93cc1fb1bf07130622ca4ef61" dependencies = [ "lexical-util", "static_assertions", @@ -2304,18 +2307,18 @@ dependencies = [ [[package]] name = "lexical-util" -version = "0.8.5" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +checksum = "85314db53332e5c192b6bca611fb10c114a80d1b831ddac0af1e9be1b9232ca0" dependencies = [ "static_assertions", ] [[package]] name = "lexical-write-float" -version = "0.8.5" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +checksum = "6e7c3ad4e37db81c1cbe7cf34610340adc09c322871972f74877a712abc6c809" dependencies = [ "lexical-util", "lexical-write-integer", @@ -2324,9 +2327,9 @@ dependencies = [ [[package]] name = "lexical-write-integer" -version = "0.8.5" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +checksum = "eb89e9f6958b83258afa3deed90b5de9ef68eef090ad5086c791cd2345610162" dependencies = [ "lexical-util", "static_assertions", @@ -2358,7 +2361,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e0d73b369f386f1c44abd9c570d5318f55ccde816ff4b562fa452e5182863d" dependencies = [ "core2", - "hashbrown", + "hashbrown 0.14.5", "rle-decode-fast", ] @@ -2601,9 +2604,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.4" +version = "0.36.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" dependencies = [ "memchr", ] @@ -2629,7 +2632,7 @@ dependencies = [ "rand", "reqwest", "ring", - "rustls-pemfile 2.1.3", + "rustls-pemfile 2.2.0", "serde", "serde_json", "snafu", @@ -2641,9 +2644,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "openssl-probe" @@ -2697,9 +2700,9 @@ dependencies = [ [[package]] name = "parquet" -version = "53.0.0" +version = "53.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0fbf928021131daaa57d334ca8e3904fe9ae22f73c56244fc7db9b04eedc3d8" +checksum = "310c46a70a3ba90d98fec39fa2da6d9d731e544191da6fb56c9d199484d0dd3e" dependencies = [ "ahash", "arrow-array", @@ -2716,7 +2719,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown", + "hashbrown 0.14.5", "lz4_flex", "num", "num-bigint", @@ -2874,9 +2877,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.86" +version = 
"1.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" dependencies = [ "unicode-ident", ] @@ -2908,7 +2911,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.13", + "rustls 0.23.14", "socket2", "thiserror", "tokio", @@ -2925,7 +2928,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.13", + "rustls 0.23.14", "slab", "thiserror", "tinyvec", @@ -2996,9 +2999,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.6" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "355ae415ccd3a04315d3f8246e86d67689ea74d88d915576e1589a351062a13b" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" dependencies = [ "bitflags 2.6.0", ] @@ -3016,9 +3019,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.6" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" dependencies = [ "aho-corasick", "memchr", @@ -3028,9 +3031,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", @@ -3045,9 +3048,9 @@ checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "relative-path" @@ -3057,9 +3060,9 @@ checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" [[package]] name = "reqwest" -version = "0.12.7" +version = "0.12.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" +checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" dependencies = [ "base64 0.22.1", "bytes", @@ -3080,9 +3083,9 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.13", - "rustls-native-certs 0.7.3", - "rustls-pemfile 2.1.3", + "rustls 0.23.14", + "rustls-native-certs 0.8.0", + "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", "serde_json", @@ -3199,9 +3202,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.13" +version = "0.23.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +checksum = "415d9944693cb90382053259f89fbb077ea730ad7273047ec63b19bc9b160ba8" dependencies = [ "once_cell", "ring", @@ -3223,19 +3226,6 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-native-certs" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" -dependencies = [ - "openssl-probe", - "rustls-pemfile 2.1.3", - 
"rustls-pki-types", - "schannel", - "security-framework", -] - [[package]] name = "rustls-native-certs" version = "0.8.0" @@ -3243,7 +3233,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.3", + "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", "security-framework", @@ -3260,11 +3250,10 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.3" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" dependencies = [ - "base64 0.22.1", "rustls-pki-types", ] @@ -3340,9 +3329,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.24" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" +checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" dependencies = [ "windows-sys 0.59.0", ] @@ -3781,7 +3770,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.13", + "rustls 0.23.14", "rustls-pki-types", "tokio", ] @@ -3897,9 +3886,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" -version = "0.3.15" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" +checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" [[package]] name = "unicode-ident" diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index f430a87e190d..e2432abdc138 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -62,6 +62,7 @@ dashmap = { workspace = true } datafusion = { workspace = true, default-features = true, features = ["avro"] } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions-window-common = { workspace = true } datafusion-optimizer = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true, default-features = true } datafusion-proto = { workspace = true } diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index fd1b84070cf6..1c20e292f091 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -30,6 +30,7 @@ use datafusion_expr::function::WindowUDFFieldArgs; use datafusion_expr::{ PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl, }; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// This example shows how to use the full WindowUDFImpl API to implement a user /// defined window function. As in the `simple_udwf.rs` example, this struct implements @@ -74,7 +75,10 @@ impl WindowUDFImpl for SmoothItUdf { /// Create a `PartitionEvaluator` to evaluate this function on a new /// partition. 
-    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+    fn partition_evaluator(
+        &self,
+        _partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
         Ok(Box::new(MyPartitionEvaluator::new()))
     }
diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs
index 1ff629eef196..d95f1147bc37 100644
--- a/datafusion-examples/examples/simplify_udwf_expression.rs
+++ b/datafusion-examples/examples/simplify_udwf_expression.rs
@@ -27,6 +27,7 @@ use datafusion_expr::{
     expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature,
     Volatility, WindowUDF, WindowUDFImpl,
 };
+use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
 
 /// This UDWF will show how to use the WindowUDFImpl::simplify() API
 #[derive(Debug, Clone)]
@@ -60,7 +61,10 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
         &self.signature
     }
 
-    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+    fn partition_evaluator(
+        &self,
+        _partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
         todo!()
     }
diff --git a/datafusion/catalog/README.md b/datafusion/catalog/README.md
new file mode 100644
index 000000000000..5b201e736fdc
--- /dev/null
+++ b/datafusion/catalog/README.md
@@ -0,0 +1,26 @@
+<!---
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied.  See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+
+# DataFusion Catalog
+
+[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format.
+
+This crate is a submodule of DataFusion that provides catalog management functionality, including catalogs, schemas, and tables.
+
+[df]: https://crates.io/crates/datafusion
diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index 69cdf866cf98..9a1fe9bba267 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -406,33 +406,6 @@ impl DFSchema {
         }
     }
 
-    /// Check whether the column reference is ambiguous
-    pub fn check_ambiguous_name(
-        &self,
-        qualifier: Option<&TableReference>,
-        name: &str,
-    ) -> Result<()> {
-        let count = self
-            .iter()
-            .filter(|(field_q, f)| match (field_q, qualifier) {
-                (Some(q1), Some(q2)) => q1.resolved_eq(q2) && f.name() == name,
-                (None, None) => f.name() == name,
-                _ => false,
-            })
-            .take(2)
-            .count();
-        if count > 1 {
-            _schema_err!(SchemaError::AmbiguousReference {
-                field: Column {
-                    relation: None,
-                    name: name.to_string(),
-                },
-            })
-        } else {
-            Ok(())
-        }
-    }
-
     /// Find the qualified field with the given name
     pub fn qualified_field_with_name(
         &self,
diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index 3356a85fb6d4..f543dc5f0ca6 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -3577,9 +3577,8 @@ impl fmt::Display for ScalarValue {
                 columns
                     .iter()
                     .zip(fields.iter())
-                    .enumerate()
-                    .map(|(index, (column, field))| {
-                        if nulls.is_some_and(|b| b.is_null(index)) {
+                    .map(|(column, field)| {
+                        if nulls.is_some_and(|b| b.is_null(0)) {
                             format!("{}:NULL", field.name())
                         } else if let DataType::Struct(_) = field.data_type() {
                             let sv = ScalarValue::Struct(Arc::new(
@@ -3875,7 +3874,7 @@ mod tests {
     use arrow::compute::{is_null, kernels};
     use arrow::error::ArrowError;
     use arrow::util::pretty::pretty_format_columns;
-    use arrow_buffer::Buffer;
+    use arrow_buffer::{Buffer, NullBuffer};
     use arrow_schema::Fields;
     use chrono::NaiveDate;
     use rand::Rng;
@@ -6589,6 +6588,43 @@ mod tests {
         assert_batches_eq!(&expected, &[batch]);
     }
 
+    #[test]
+    fn test_null_bug() {
+        let field_a = Field::new("a", DataType::Int32, true);
+        let field_b = Field::new("b", DataType::Int32, true);
+        let fields = Fields::from(vec![field_a, field_b]);
+
+        let array_a = Arc::new(Int32Array::from_iter_values([1]));
+        let array_b = Arc::new(Int32Array::from_iter_values([2]));
+        let arrays: Vec<ArrayRef> = vec![array_a, array_b];
+
+        let mut not_nulls = BooleanBufferBuilder::new(1);
+        not_nulls.append(true);
+        let not_nulls = not_nulls.finish();
+        let not_nulls = Some(NullBuffer::new(not_nulls));
+
+        let ar = unsafe { StructArray::new_unchecked(fields, arrays, not_nulls) };
+        let s = ScalarValue::Struct(Arc::new(ar));
+
+        assert_eq!(s.to_string(), "{a:1,b:2}");
+        assert_eq!(format!("{s:?}"), r#"Struct({a:1,b:2})"#);
+
+        let ScalarValue::Struct(arr) = s else {
+            panic!("Expected struct");
+        };
+
+        // verify compared to arrow display
+        let batch = RecordBatch::try_from_iter(vec![("s", arr as _)]).unwrap();
+        let expected = [
+            "+--------------+",
+            "| s            |",
+            "+--------------+",
+            "| {a: 1, b: 2} |",
+            "+--------------+",
+        ];
+        assert_batches_eq!(&expected, &[batch]);
+    }
+
     #[test]
     fn test_struct_display_null() {
         let fields = vec![Field::new("a", DataType::Int32, false)];
diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index 88300e3edd0e..b4d3251fd263 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -681,7 +681,7 @@ impl<T> Transformed<T> {
         }
     }
 
-    /// Create a `Transformed` with `transformed and [`TreeNodeRecursion::Continue`].
+    /// Create a `Transformed` with `transformed` and [`TreeNodeRecursion::Continue`].
     pub fn new_transformed(data: T, transformed: bool) -> Self {
         Self::new(data, transformed, TreeNodeRecursion::Continue)
     }
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index 01ba90ee5de8..28d0d136bd05 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -67,8 +67,6 @@ math_expressions = ["datafusion-functions/math_expressions"]
 parquet = ["datafusion-common/parquet", "dep:parquet"]
 pyarrow = ["datafusion-common/pyarrow", "parquet"]
 regex_expressions = [
-    "datafusion-physical-expr/regex_expressions",
-    "datafusion-optimizer/regex_expressions",
     "datafusion-functions/regex_expressions",
 ]
 serde = ["arrow-schema/serde"]
diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs
index 92737b244a64..53cfe94ecab3 100644
--- a/datafusion/core/src/bin/print_functions_docs.rs
+++ b/datafusion/core/src/bin/print_functions_docs.rs
@@ -130,13 +130,14 @@ fn print_docs(
         .find(|f| f.get_name() == name || f.get_aliases().contains(&name))
         .unwrap();
 
-    let name = f.get_name();
     let aliases = f.get_aliases();
     let documentation = f.get_documentation();
 
     // if this name is an alias we need to display what it's an alias of
     if aliases.contains(&name) {
-        let _ = write!(docs, "_Alias of [{name}](#{name})._");
+        let fname = f.get_name();
+        let _ = writeln!(docs, r#"### `{name}`"#);
+        let _ = writeln!(docs, "_Alias of [{fname}](#{fname})._");
         continue;
     }
 
@@ -183,10 +184,10 @@ fn print_docs(
 
     // next, aliases
     if !f.get_aliases().is_empty() {
-        let _ = write!(docs, "#### Aliases");
+        let _ = writeln!(docs, "#### Aliases");
 
         for alias in f.get_aliases() {
-            let _ = writeln!(docs, "- {alias}");
+            let _ = writeln!(docs, "- {}", alias.replace("_", r#"\_"#));
         }
     }
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index f5867881da13..67e2a4780d06 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -535,9 +535,26 @@ impl DataFrame {
         group_expr: Vec<Expr>,
         aggr_expr: Vec<Expr>,
     ) -> Result<DataFrame> {
+        let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
+        let aggr_expr_len = aggr_expr.len();
         let plan = LogicalPlanBuilder::from(self.plan)
             .aggregate(group_expr, aggr_expr)?
             .build()?;
+        let plan = if is_grouping_set {
+            let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len;
+            // For grouping sets we do a project to not expose the internal grouping id
+            let exprs = plan
+                .schema()
+                .columns()
+                .into_iter()
+                .enumerate()
+                .filter(|(idx, _)| *idx != grouping_id_pos)
+                .map(|(_, column)| Expr::Column(column))
+                .collect::<Vec<_>>();
+            LogicalPlanBuilder::from(plan).project(exprs)?.build()?
+        } else {
+            plan
+        };
         Ok(DataFrame {
             session_state: self.session_state,
             plan,
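With this change, grouping-set aggregates built through the DataFrame API no longer expose the internal grouping id column in their output schema. A minimal sketch of a call that takes the new projection path; table and column names are illustrative, and the import paths assume DataFusion 42-era APIs:

```rust
use datafusion::functions_aggregate::expr_fn::sum;
use datafusion::prelude::*;
use datafusion_expr::GroupingSet;

async fn rollup_example(ctx: &SessionContext) -> datafusion::error::Result<DataFrame> {
    let df = ctx.table("t").await?;
    // A single Expr::GroupingSet group expression now triggers the extra
    // projection above, hiding the internal grouping id column.
    let group = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")]));
    df.aggregate(vec![group], vec![sum(col("c"))])
}
```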
diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs
index e821fa806fce..f235c3b628a0 100644
--- a/datafusion/core/src/datasource/file_format/csv.rs
+++ b/datafusion/core/src/datasource/file_format/csv.rs
@@ -771,7 +771,7 @@ mod tests {
             "c7: Int64",
             "c8: Int64",
             "c9: Int64",
-            "c10: Int64",
+            "c10: Utf8",
             "c11: Float64",
             "c12: Float64",
             "c13: Utf8"
@@ -907,7 +907,7 @@ mod tests {
         Field::new("c7", DataType::Int64, true),
         Field::new("c8", DataType::Int64, true),
         Field::new("c9", DataType::Int64, true),
-        Field::new("c10", DataType::Int64, true),
+        Field::new("c10", DataType::Utf8, true),
         Field::new("c11", DataType::Float64, true),
         Field::new("c12", DataType::Float64, true),
         Field::new("c13", DataType::Utf8, true),
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 98ae0ce14bd7..8647b5df90be 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -20,6 +20,7 @@
 use std::any::Any;
 use std::fmt;
 use std::fmt::Debug;
+use std::ops::Range;
 use std::sync::Arc;
 
 use super::write::demux::start_demuxer_task;
@@ -47,7 +48,7 @@ use datafusion_common::file_options::parquet_writer::ParquetWriterOptions;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
 use datafusion_common::{
-    exec_err, internal_datafusion_err, not_impl_err, DataFusionError, GetExt,
+    internal_datafusion_err, not_impl_err, DataFusionError, GetExt,
     DEFAULT_PARQUET_EXTENSION,
 };
 use datafusion_common_runtime::SpawnedTask;
@@ -60,7 +61,7 @@ use datafusion_physical_expr::PhysicalExpr;
 use datafusion_physical_plan::metrics::MetricsSet;
 
 use async_trait::async_trait;
-use bytes::{BufMut, BytesMut};
+use bytes::Bytes;
 use hashbrown::HashMap;
 use log::debug;
 use object_store::buffered::BufWriter;
@@ -71,8 +72,7 @@ use parquet::arrow::arrow_writer::{
 use parquet::arrow::{
     arrow_to_parquet_schema, parquet_to_arrow_schema, AsyncArrowWriter,
 };
-use parquet::file::footer::{decode_footer, decode_metadata};
-use parquet::file::metadata::{ParquetMetaData, RowGroupMetaData};
+use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader, RowGroupMetaData};
 use parquet::file::properties::WriterProperties;
 use parquet::file::writer::SerializedFileWriter;
 use parquet::format::FileMetaData;
@@ -84,10 +84,13 @@ use crate::datasource::physical_plan::parquet::{
     can_expr_be_pushed_down_with_schemas, ParquetExecBuilder,
 };
 use datafusion_physical_expr_common::sort_expr::LexRequirement;
-use futures::{StreamExt, TryStreamExt};
+use futures::future::BoxFuture;
+use futures::{FutureExt, StreamExt, TryStreamExt};
 use object_store::path::Path;
 use object_store::{ObjectMeta, ObjectStore};
 use parquet::arrow::arrow_reader::statistics::StatisticsConverter;
+use parquet::arrow::async_reader::MetadataFetch;
+use parquet::errors::ParquetError;
 
 /// Initial writing buffer size. Note this is just a size hint for efficiency. It
 /// will grow beyond the set value if needed.
@@ -441,6 +444,33 @@ impl FileFormat for ParquetFormat {
     }
 }
 
+/// [`MetadataFetch`] adapter for reading bytes from an [`ObjectStore`]
+struct ObjectStoreFetch<'a> {
+    store: &'a dyn ObjectStore,
+    meta: &'a ObjectMeta,
+}
+
+impl<'a> ObjectStoreFetch<'a> {
+    fn new(store: &'a dyn ObjectStore, meta: &'a ObjectMeta) -> Self {
+        Self { store, meta }
+    }
+}
+
+impl<'a> MetadataFetch for ObjectStoreFetch<'a> {
+    fn fetch(
+        &mut self,
+        range: Range<usize>,
+    ) -> BoxFuture<'_, Result<Bytes, ParquetError>> {
+        async {
+            self.store
+                .get_range(&self.meta.location, range)
+                .await
+                .map_err(ParquetError::from)
+        }
+        .boxed()
+    }
+}
+
 /// Fetches parquet metadata from ObjectStore for given object
 ///
 /// This component is a subject to **change** in near future and is exposed for low level integrations
@@ -452,57 +482,14 @@ pub async fn fetch_parquet_metadata(
     store: &dyn ObjectStore,
     meta: &ObjectMeta,
     size_hint: Option<usize>,
 ) -> Result<ParquetMetaData> {
-    if meta.size < 8 {
-        return exec_err!("file size of {} is less than footer", meta.size);
-    }
-
-    // If a size hint is provided, read more than the minimum size
-    // to try and avoid a second fetch.
-    let footer_start = if let Some(size_hint) = size_hint {
-        meta.size.saturating_sub(size_hint)
-    } else {
-        meta.size - 8
-    };
-
-    let suffix = store
-        .get_range(&meta.location, footer_start..meta.size)
-        .await?;
-
-    let suffix_len = suffix.len();
-
-    let mut footer = [0; 8];
-    footer.copy_from_slice(&suffix[suffix_len - 8..suffix_len]);
-
-    let length = decode_footer(&footer)?;
+    let file_size = meta.size;
+    let fetch = ObjectStoreFetch::new(store, meta);
 
-    if meta.size < length + 8 {
-        return exec_err!(
-            "file size of {} is less than footer + metadata {}",
-            meta.size,
-            length + 8
-        );
-    }
-
-    // Did not fetch the entire file metadata in the initial read, need to make a second request
-    if length > suffix_len - 8 {
-        let metadata_start = meta.size - length - 8;
-        let remaining_metadata = store
-            .get_range(&meta.location, metadata_start..footer_start)
-            .await?;
-
-        let mut metadata = BytesMut::with_capacity(length);
-
-        metadata.put(remaining_metadata.as_ref());
-        metadata.put(&suffix[..suffix_len - 8]);
-
-        Ok(decode_metadata(metadata.as_ref())?)
-    } else {
-        let metadata_start = meta.size - length - 8;
-
-        Ok(decode_metadata(
-            &suffix[metadata_start - footer_start..suffix_len - 8],
-        )?)
-    }
+    ParquetMetaDataReader::new()
+        .with_prefetch_hint(size_hint)
+        .load_and_finish(fetch, file_size)
+        .await
+        .map_err(DataFusionError::from)
 }
 
 /// Read and parse the schema of the Parquet file at location `path`
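The rewritten `fetch_parquet_metadata` delegates footer discovery and the optional second fetch to `ParquetMetaDataReader`. The same reader also works outside an async `ObjectStore` context; a minimal sketch against a local file, assuming the `parquet` 53.1 `ParquetMetaDataReader` API (the path is illustrative):

```rust
use std::fs::File;

use parquet::errors::Result;
use parquet::file::metadata::ParquetMetaDataReader;

fn read_local_metadata() -> Result<()> {
    // A local File implements ChunkReader, so no MetadataFetch adapter is
    // needed; the reader locates and decodes the footer itself.
    let file = File::open("data.parquet")?; // illustrative path
    let metadata = ParquetMetaDataReader::new().parse_and_finish(&file)?;
    println!("row groups: {}", metadata.num_row_groups());
    Ok(())
}
```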
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
index 6afb66cc7c02..743dd5896986 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
@@ -166,6 +166,33 @@ pub use writer::plan_to_parquet;
 /// [`RowFilter`]: parquet::arrow::arrow_reader::RowFilter
 /// [Parquet PageIndex]: https://github.com/apache/parquet-format/blob/master/PageIndex.md
 ///
+/// # Example: rewriting `ParquetExec`
+///
+/// You can modify a `ParquetExec` using [`ParquetExecBuilder`], for example
+/// to change files or add a predicate.
+///
+/// ```no_run
+/// # use std::sync::Arc;
+/// # use arrow::datatypes::Schema;
+/// # use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
+/// # use datafusion::datasource::listing::PartitionedFile;
+/// # fn parquet_exec() -> ParquetExec { unimplemented!() }
+/// // Split a single ParquetExec into multiple ParquetExecs, one for each file
+/// let exec = parquet_exec();
+/// let existing_file_groups = &exec.base_config().file_groups;
+/// let new_execs = existing_file_groups
+///     .iter()
+///     .map(|file_group| {
+///         // create a new exec by copying the existing exec into a builder
+///         let new_exec = exec.clone()
+///             .into_builder()
+///             .with_file_groups(vec![file_group.clone()])
+///             .build();
+///         new_exec
+///     })
+///     .collect::<Vec<_>>();
+/// ```
+///
 /// # Implementing External Indexes
 ///
 /// It is possible to restrict the row groups and selections within those row
@@ -257,6 +284,12 @@ pub struct ParquetExec {
     schema_adapter_factory: Option<Arc<dyn SchemaAdapterFactory>>,
 }
 
+impl From<ParquetExec> for ParquetExecBuilder {
+    fn from(exec: ParquetExec) -> Self {
+        exec.into_builder()
+    }
+}
+
 /// [`ParquetExecBuilder`], builder for [`ParquetExec`].
 ///
 /// See example on [`ParquetExec`].
@@ -291,6 +324,12 @@ impl ParquetExecBuilder {
         }
     }
 
+    /// Update the list of file groups to read
+    pub fn with_file_groups(mut self, file_groups: Vec<Vec<PartitionedFile>>) -> Self {
+        self.file_scan_config.file_groups = file_groups;
+        self
+    }
+
     /// Set the filter predicate when reading.
     ///
     /// See the "Predicate Pushdown" section of the [`ParquetExec`] documentation
@@ -459,6 +498,34 @@ impl ParquetExec {
         ParquetExecBuilder::new(file_scan_config)
     }
 
+    /// Convert this `ParquetExec` into a builder for modification
+    pub fn into_builder(self) -> ParquetExecBuilder {
+        // list out fields so it is clear what is being dropped
+        // (note the fields which are dropped are re-created as part of calling
+        // `build` on the builder)
+        let Self {
+            base_config,
+            projected_statistics: _,
+            metrics: _,
+            predicate,
+            pruning_predicate: _,
+            page_pruning_predicate: _,
+            metadata_size_hint,
+            parquet_file_reader_factory,
+            cache: _,
+            table_parquet_options,
+            schema_adapter_factory,
+        } = self;
+        ParquetExecBuilder {
+            file_scan_config: base_config,
+            predicate,
+            metadata_size_hint,
+            table_parquet_options,
+            parquet_file_reader_factory,
+            schema_adapter_factory,
+        }
+    }
+
     /// [`FileScanConfig`] that controls this scan (such as which files to read)
     pub fn base_config(&self) -> &FileScanConfig {
         &self.base_config
@@ -479,9 +546,15 @@ impl ParquetExec {
         self.pruning_predicate.as_ref()
     }
 
+    /// return the optional file reader factory
+    pub fn parquet_file_reader_factory(
+        &self,
+    ) -> Option<&Arc<dyn ParquetFileReaderFactory>> {
+        self.parquet_file_reader_factory.as_ref()
+    }
+
     /// Optional user defined parquet file reader factory.
     ///
-    /// See documentation on [`ParquetExecBuilder::with_parquet_file_reader_factory`]
     pub fn with_parquet_file_reader_factory(
         mut self,
         parquet_file_reader_factory: Arc<dyn ParquetFileReaderFactory>,
     ) -> Self {
@@ -490,6 +563,11 @@
         self.parquet_file_reader_factory = Some(parquet_file_reader_factory);
         self
     }
 
+    /// return the optional schema adapter factory
+    pub fn schema_adapter_factory(&self) -> Option<&Arc<dyn SchemaAdapterFactory>> {
+        self.schema_adapter_factory.as_ref()
+    }
+
     /// Optional schema adapter factory.
     ///
     /// See documentation on [`ParquetExecBuilder::with_schema_adapter_factory`]
@@ -586,7 +664,14 @@ impl ParquetExec {
         )
     }
 
-    fn with_file_groups(mut self, file_groups: Vec<Vec<PartitionedFile>>) -> Self {
+    /// Updates the file groups to read and recalculates the output partitioning
+    ///
+    /// Note this function does not update statistics or other properties
+    /// that depend on the file groups.
+    fn with_file_groups_and_update_partitioning(
+        mut self,
+        file_groups: Vec<Vec<PartitionedFile>>,
+    ) -> Self {
         self.base_config.file_groups = file_groups;
         // Changing file groups may invalidate output partitioning. Update it also
         let output_partitioning = Self::output_partitioning_helper(&self.base_config);
@@ -679,7 +764,8 @@ impl ExecutionPlan for ParquetExec {
         let mut new_plan = self.clone();
         if let Some(repartitioned_file_groups) = repartitioned_file_groups_option {
-            new_plan = new_plan.with_file_groups(repartitioned_file_groups);
+            new_plan = new_plan
+                .with_file_groups_and_update_partitioning(repartitioned_file_groups);
         }
         Ok(Some(Arc::new(new_plan)))
     }
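Since the new `From` impl simply forwards to `into_builder`, either spelling round-trips an exec through the builder. A small sketch, assuming `ParquetExecBuilder` is re-exported alongside `ParquetExec`:

```rust
use datafusion::datasource::physical_plan::{ParquetExec, ParquetExecBuilder};

// A sketch assuming an existing `exec: ParquetExec`.
fn rebuild(exec: ParquetExec) -> ParquetExec {
    // The new From impl is sugar for exec.into_builder():
    let builder: ParquetExecBuilder = ParquetExecBuilder::from(exec);
    // build() recreates the derived state (pruning predicates, metrics,
    // plan properties) that into_builder() intentionally drops.
    builder.build()
}
```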
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 78c70606bf68..cf2a157b04b6 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -692,10 +692,6 @@ impl DefaultPhysicalPlanner {
             physical_input_schema.clone(),
         )?);
 
-        // update group column indices based on partial aggregate plan evaluation
-        let final_group: Vec<Arc<dyn PhysicalExpr>> =
-            initial_aggr.output_group_expr();
-
         let can_repartition = !groups.is_empty()
             && session_state.config().target_partitions() > 1
             && session_state.config().repartition_aggregations();
@@ -716,13 +712,7 @@ impl DefaultPhysicalPlanner {
             AggregateMode::Final
         };
 
-        let final_grouping_set = PhysicalGroupBy::new_single(
-            final_group
-                .iter()
-                .enumerate()
-                .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
-                .collect(),
-        );
+        let final_grouping_set = initial_aggr.group_expr().as_final();
 
         Arc::new(AggregateExec::try_new(
             next_partition_mode,
@@ -2345,7 +2335,7 @@ mod tests {
             .expect("hash aggregate");
         assert_eq!(
             "sum(aggregate_test_100.c3)",
-            final_hash_agg.schema().field(2).name()
+            final_hash_agg.schema().field(3).name()
         );
         // we need access to the input to the partial aggregate so that other projects can
         // implement serde
diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs
index cbd892672152..81a33361008f 100644
--- a/datafusion/core/tests/expr_api/mod.rs
+++ b/datafusion/core/tests/expr_api/mod.rs
@@ -37,14 +37,14 @@ mod simplification;
 fn test_octet_length() {
     #[rustfmt::skip]
     evaluate_expr_test(
-        octet_length(col("list")),
+        octet_length(col("id")),
         vec![
             "+------+",
             "| expr |",
             "+------+",
-            "| 5    |",
-            "| 18   |",
-            "| 6    |",
+            "| 1    |",
+            "| 1    |",
+            "| 1    |",
             "+------+",
         ],
     );
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 62e9be63983c..b0852501415e 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -44,6 +44,301 @@ use rand::rngs::StdRng;
 use rand::{Rng, SeedableRng};
 use tokio::task::JoinSet;
 
+use crate::fuzz_cases::aggregation_fuzzer::{
+    AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig,
+};
+
+// ========================================================================
+// The new aggregation fuzz tests based on [`AggregationFuzzer`]
+// ========================================================================
+
+// TODO: write more test cases to cover more `group by`s and `aggregation function`s
+// TODO: maybe we can use a macro to simplify creating these cases
+// (see the sketch after the first case below)
+
+/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `no group by`
+#[tokio::test(flavor = "multi_thread")]
+async fn test_basic_prim_aggr_no_group() {
+    let builder = AggregationFuzzerBuilder::default();
+
+    // Define data generator config
+    let columns = vec![ColumnDescr::new("a", DataType::Int32)];
+
+    let data_gen_config = DatasetGeneratorConfig {
+        columns,
+        rows_num_range: (512, 1024),
+        sort_keys_set: Vec::new(),
+    };
+
+    // Build fuzzer
+    let fuzzer = builder
+        .data_gen_config(data_gen_config)
+        .data_gen_rounds(16)
+        .add_sql("SELECT sum(a) FROM fuzz_table")
+        .add_sql("SELECT sum(distinct a) FROM fuzz_table")
+        .add_sql("SELECT max(a) FROM fuzz_table")
+        .add_sql("SELECT min(a) FROM fuzz_table")
+        .add_sql("SELECT count(a) FROM fuzz_table")
+        .add_sql("SELECT count(distinct a) FROM fuzz_table")
+        .add_sql("SELECT avg(a) FROM fuzz_table")
+        .table_name("fuzz_table")
+        .build();
+
+    fuzzer.run().await;
+}
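As a sketch of what the second TODO above could look like, a hypothetical macro that stamps out these near-identical cases; it is illustrative only and not part of this change, and it assumes the builder methods consume and return `Self`, as their chained use above suggests:

```rust
// Hypothetical helper: build a fuzzer from a data-gen config, a round
// count, and a list of queries against the fixed `fuzz_table` name.
macro_rules! aggr_fuzzer {
    ($data_gen:expr, $rounds:expr, [$($sql:expr),+ $(,)?]) => {{
        let mut builder = AggregationFuzzerBuilder::default()
            .data_gen_config($data_gen)
            .data_gen_rounds($rounds);
        $(
            builder = builder.add_sql($sql);
        )+
        builder.table_name("fuzz_table").build()
    }};
}
```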
+} + +/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by string + int64` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_prim_aggr_group_by_mixed_string_int64() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ + ColumnDescr::new("a", DataType::Int32), + ColumnDescr::new("b", DataType::Utf8), + ColumnDescr::new("c", DataType::Int64), + ColumnDescr::new("d", DataType::Int32), + ]; + let sort_keys_set = vec![ + vec!["b".to_string(), "c".to_string()], + vec!["d".to_string(), "b".to_string(), "c".to_string()], + ]; + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set, + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(16) + .add_sql("SELECT b, c, sum(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, sum(distinct a) FROM fuzz_table GROUP BY b,c") + .add_sql("SELECT b, c, max(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, min(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, count(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, count(distinct a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, avg(a) FROM fuzz_table GROUP BY b, c") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `no group by` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_string_aggr_no_group() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ColumnDescr::new("a", DataType::Utf8)]; + + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set: Vec::new(), + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(8) + .add_sql("SELECT max(a) FROM fuzz_table") + .add_sql("SELECT min(a) FROM fuzz_table") + .add_sql("SELECT count(a) FROM fuzz_table") + .add_sql("SELECT count(distinct a) FROM fuzz_table") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by single int64` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_string_aggr_group_by_single_int64() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ + ColumnDescr::new("a", DataType::Utf8), + ColumnDescr::new("b", DataType::Int64), + ColumnDescr::new("c", DataType::Int64), + ]; + let sort_keys_set = vec![ + vec!["b".to_string()], + vec!["c".to_string(), "b".to_string()], + ]; + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set, + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(8) + .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by single string` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_string_aggr_group_by_single_string() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = 
vec![ + ColumnDescr::new("a", DataType::Utf8), + ColumnDescr::new("b", DataType::Utf8), + ColumnDescr::new("c", DataType::Int64), + ]; + let sort_keys_set = vec![ + vec!["b".to_string()], + vec!["c".to_string(), "b".to_string()], + ]; + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set, + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(16) + .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by string + int64` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_string_aggr_group_by_mixed_string_int64() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ + ColumnDescr::new("a", DataType::Utf8), + ColumnDescr::new("b", DataType::Utf8), + ColumnDescr::new("c", DataType::Int64), + ColumnDescr::new("d", DataType::Int32), + ]; + let sort_keys_set = vec![ + vec!["b".to_string(), "c".to_string()], + vec!["d".to_string(), "b".to_string(), "c".to_string()], + ]; + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set, + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(16) + .add_sql("SELECT b, c, max(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, min(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, count(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, count(distinct a) FROM fuzz_table GROUP BY b, c") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +// ======================================================================== +// The old aggregation fuzz tests +// ======================================================================== /// Tests that streaming aggregate and batch (non streaming) aggregate produce /// same results #[tokio::test(flavor = "multi_thread")] @@ -311,6 +606,7 @@ async fn group_by_string_test( let actual = extract_result_counts(results); assert_eq!(expected, actual); } + async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { struct Visitor { expected_sort: bool, diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs new file mode 100644 index 000000000000..af454bee7ce8 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{cmp, sync::Arc}; + +use datafusion::{ + datasource::MemTable, + prelude::{SessionConfig, SessionContext}, +}; +use datafusion_catalog::TableProvider; +use datafusion_common::error::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::col; +use rand::{thread_rng, Rng}; + +use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; + +/// SessionContext generator +/// +/// During testing, `generate_baseline` is called first to generate a standard [`SessionContext`], +/// and we run `sql` on it to get the `expected result`. Then `generate` is called several times to +/// generate random [`SessionContext`]s, and we run the same `sql` on them to get the `actual results`. +/// Finally, we compare the `actual results` with the `expected result`; the test succeeds only if +/// all of them match the expected one. +/// +/// The following parameters of the [`SessionContext`] used in query running are generated randomly: +/// - `batch_size` +/// - `target_partitions` +/// - `skip_partial` parameters +/// - hint `sorted` or not +/// - `spilling` or not (TODO, I think a special `MemoryPool` may be needed +/// to support this) +/// +pub struct SessionContextGenerator { + /// Current testing dataset + dataset: Arc<Dataset>, + + /// Table name of the test table + table_name: String, + + /// Used to generate the random `batch_size` + /// + /// The generated `batch_size` is between (0, total_rows_num] + max_batch_size: usize, + + /// Candidate `SkipPartialParams`, one of which is picked randomly + candidate_skip_partial_params: Vec<SkipPartialParams>, + + /// The upper bound of the randomly generated target partitions, + /// and the lower bound will be 1 + max_target_partitions: usize, +} + +impl SessionContextGenerator { + pub fn new(dataset_ref: Arc<Dataset>, table_name: &str) -> Self { + let candidate_skip_partial_params = vec![ + SkipPartialParams::ensure_trigger(), + SkipPartialParams::ensure_not_trigger(), + ]; + + let max_batch_size = cmp::max(1, dataset_ref.total_rows_num); + let max_target_partitions = num_cpus::get(); + + Self { + dataset: dataset_ref, + table_name: table_name.to_string(), + max_batch_size, + candidate_skip_partial_params, + max_target_partitions, + } + } +} + +impl SessionContextGenerator { + /// Generate the `SessionContext` for the baseline run + pub fn generate_baseline(&self) -> Result<SessionContextWithParams> { + let schema = self.dataset.batches[0].schema(); + let batches = self.dataset.batches.clone(); + let provider = MemTable::try_new(schema, vec![batches])?; + + // The baseline context should try its best to disable all optimizations, + // pursuing correctness.
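+ // For example (the numbers here are illustrative, not part of this PR): with a + // 1000-row dataset the baseline runs with `batch_size = 1000`, + // `target_partitions = 1`, and partial-aggregation skipping disabled, so the + // whole input is aggregated in one partition with no shortcut paths.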
+ let batch_size = self.max_batch_size; + let target_partitions = 1; + let skip_partial_params = SkipPartialParams::ensure_not_trigger(); + + let builder = GeneratedSessionContextBuilder { + batch_size, + target_partitions, + skip_partial_params, + sort_hint: false, + table_name: self.table_name.clone(), + table_provider: Arc::new(provider), + }; + + builder.build() + } + + /// Randomly generate a session context + pub fn generate(&self) -> Result<SessionContextWithParams> { + let mut rng = thread_rng(); + let schema = self.dataset.batches[0].schema(); + let batches = self.dataset.batches.clone(); + let provider = MemTable::try_new(schema, vec![batches])?; + + // We will randomly generate the following options: + // - `batch_size`, from range: [1, `total_rows_num`] + // - `target_partitions`, from range: [1, cpu_num] + // - `skip_partial`, trigger or not trigger, currently for simplicity + // - `sorted`, if the dataset is sorted, randomly decide whether to push down this information + // - `spilling`(TODO) + let batch_size = rng.gen_range(1..=self.max_batch_size); + + let target_partitions = rng.gen_range(1..=self.max_target_partitions); + + let skip_partial_params_idx = rng.gen_range(0..self.candidate_skip_partial_params.len()); + let skip_partial_params = self.candidate_skip_partial_params[skip_partial_params_idx]; + + let (provider, sort_hint) = + if rng.gen_bool(0.5) && !self.dataset.sort_keys.is_empty() { + // Sort keys exist, and we randomly push them down + let sort_exprs = self .dataset .sort_keys .iter() .map(|key| col(key).sort(true, true)) .collect::<Vec<_>>(); + (provider.with_sort_order(vec![sort_exprs]), true) + } else { + (provider, false) + }; + + let builder = GeneratedSessionContextBuilder { + batch_size, + target_partitions, + sort_hint, + skip_partial_params, + table_name: self.table_name.clone(), + table_provider: Arc::new(provider), + }; + + builder.build() + } +} + +/// The generated [`SessionContext`] with its params +/// +/// Storing the generated `params` is necessary for +/// reporting the broken test case.
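+/// +/// A sketch of the intended use (assuming the APIs defined in this file): +/// +/// ```ignore +/// let wrapped = generator.generate()?; +/// let results = wrapped.ctx.sql(sql).await?.collect().await?; +/// // on a failure, `wrapped.params` is printed so the run can be reproduced +/// ```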
+pub struct SessionContextWithParams { + pub ctx: SessionContext, + pub params: SessionContextParams, +} + +/// Collect the generated params, and build the [`SessionContext`] +struct GeneratedSessionContextBuilder { + batch_size: usize, + target_partitions: usize, + sort_hint: bool, + skip_partial_params: SkipPartialParams, + table_name: String, + table_provider: Arc<dyn TableProvider>, +} + +impl GeneratedSessionContextBuilder { + fn build(self) -> Result<SessionContextWithParams> { + // Build session context + let mut session_config = SessionConfig::default(); + session_config = session_config.set( + "datafusion.execution.batch_size", + &ScalarValue::UInt64(Some(self.batch_size as u64)), + ); + session_config = session_config.set( + "datafusion.execution.target_partitions", + &ScalarValue::UInt64(Some(self.target_partitions as u64)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &ScalarValue::UInt64(Some(self.skip_partial_params.rows_threshold as u64)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", + &ScalarValue::Float64(Some(self.skip_partial_params.ratio_threshold)), + ); + + let ctx = SessionContext::new_with_config(session_config); + ctx.register_table(self.table_name, self.table_provider)?; + + let params = SessionContextParams { + batch_size: self.batch_size, + target_partitions: self.target_partitions, + sort_hint: self.sort_hint, + skip_partial_params: self.skip_partial_params, + }; + + Ok(SessionContextWithParams { ctx, params }) + } +} + +/// The generated params for [`SessionContext`] +#[derive(Debug)] +#[allow(dead_code)] +pub struct SessionContextParams { + batch_size: usize, + target_partitions: usize, + sort_hint: bool, + skip_partial_params: SkipPartialParams, +} + +/// Partial skipping parameters +#[derive(Debug, Clone, Copy)] +pub struct SkipPartialParams { + /// Related to `skip_partial_aggregation_probe_ratio_threshold` in `ExecutionOptions` + pub ratio_threshold: f64, + + /// Related to `skip_partial_aggregation_probe_rows_threshold` in `ExecutionOptions` + pub rows_threshold: usize, +} + +impl SkipPartialParams { + /// Generate `SkipPartialParams` ensuring to trigger partial skipping + pub fn ensure_trigger() -> Self { + Self { + ratio_threshold: 0.0, + rows_threshold: 0, + } + } + + /// Generate `SkipPartialParams` ensuring not to trigger partial skipping + pub fn ensure_not_trigger() -> Self { + Self { + ratio_threshold: 1.0, + rows_threshold: usize::MAX, + } + } +} + +#[cfg(test)] +mod test { + use arrow_array::{RecordBatch, StringArray, UInt32Array}; + use arrow_schema::{DataType, Field, Schema}; + + use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + + use super::*; + + #[tokio::test] + async fn test_generated_context() { + // 1.
Define a test dataset first + let a_col: StringArray = [ + Some("rust"), + Some("java"), + Some("cpp"), + Some("go"), + Some("go1"), + Some("python"), + Some("python1"), + Some("python2"), + ] + .into_iter() + .collect(); + // Sort by "b" + let b_col: UInt32Array = [ + Some(1), + Some(2), + Some(4), + Some(8), + Some(8), + Some(16), + Some(16), + Some(16), + ] + .into_iter() + .collect(); + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::UInt32, true), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(a_col), Arc::new(b_col)], + ) + .unwrap(); + + // One row per batch, to create multiple batches + let mut batches = Vec::with_capacity(batch.num_rows()); + for start in 0..batch.num_rows() { + let sub_batch = batch.slice(start, 1); + batches.push(sub_batch); + } + + let dataset = Dataset::new(batches, vec!["b".to_string()]); + + // 2. Generate a baseline context, and some random session contexts. + // Run the same query on them; every random context's result should equal the baseline's + let ctx_generator = SessionContextGenerator::new(Arc::new(dataset), "fuzz_table"); + + let query = "select b, count(a) from fuzz_table group by b"; + let baseline_wrapped_ctx = ctx_generator.generate_baseline().unwrap(); + let mut random_wrapped_ctxs = Vec::with_capacity(8); + for _ in 0..8 { + let ctx = ctx_generator.generate().unwrap(); + random_wrapped_ctxs.push(ctx); + } + + let base_result = baseline_wrapped_ctx + .ctx + .sql(query) + .await + .unwrap() + .collect() + .await + .unwrap(); + + for wrapped_ctx in random_wrapped_ctxs { + let random_result = wrapped_ctx + .ctx + .sql(query) + .await + .unwrap() + .collect() + .await + .unwrap(); + check_equality_of_batches(&base_result, &random_result).unwrap(); + } + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs new file mode 100644 index 000000000000..9d45779295e7 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -0,0 +1,459 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +use std::sync::Arc; + +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; +use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; +use datafusion_physical_plan::sorts::sort::sort_batch; +use rand::{ + rngs::{StdRng, ThreadRng}, + thread_rng, Rng, SeedableRng, +}; +use test_utils::{ + array_gen::{PrimitiveArrayGenerator, StringArrayGenerator}, + stagger_batch, +}; + +/// Config for the datasets generator +/// +/// # Parameters +/// - `columns`, you just need to define the `column name`s and `column data type`s +/// for the test datasets, and the columns will then be randomly generated by the +/// generator when you call the `generate` function +/// +/// - `rows_num_range`, the rows num of the datasets will be randomly picked +/// from this range +/// +/// - `sort_keys_set`, if sort keys are defined, when you call `generate`, the generator +/// will generate one `base dataset` first. Then the `base dataset` will be sorted +/// by each `sort_key` respectively. Finally, `len(sort_keys) + 1` datasets +/// will be returned +/// +#[derive(Debug, Clone)] +pub struct DatasetGeneratorConfig { + // Descriptions of columns in datasets, it's `required` + pub columns: Vec<ColumnDescr>, + + // Rows num range of the generated datasets, it's `required` + pub rows_num_range: (usize, usize), + + // Sort keys used to generate the sorted data set, it's optional + pub sort_keys_set: Vec<Vec<String>>, +} + +/// Dataset generator +/// +/// It generates random [`Dataset`]s when the `generate` function is called. +/// +/// The generation logic in `generate`: +/// +/// - Randomly generate a base record batch from `batch_generator` first. +/// The `columns` and `rows_num_range` in `config` (see `DatasetGeneratorConfig` +/// for details) are used in the generation. +/// +/// - Sort the batch according to each `sort_keys` entry in `config` to generate another +/// `len(sort_keys)` sorted batches. +/// +/// - Split each batch into multiple sub-batches with random row counts; +/// these sub-batches are used to create the `Dataset`.
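+/// +/// For example (a sketch using only the types defined in this file), a config +/// with two entries in `sort_keys_set` yields three datasets per call: +/// +/// ```ignore +/// let generator = DatasetGenerator::new(config); +/// let datasets = generator.generate()?; +/// // one unsorted dataset + one dataset per sort-key set +/// assert_eq!(datasets.len(), num_sort_key_sets + 1); +/// ```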
+/// +pub struct DatasetGenerator { + batch_generator: RecordBatchGenerator, + sort_keys_set: Vec<Vec<String>>, +} + +impl DatasetGenerator { + pub fn new(config: DatasetGeneratorConfig) -> Self { + let batch_generator = RecordBatchGenerator::new( + config.rows_num_range.0, + config.rows_num_range.1, + config.columns, + ); + + Self { + batch_generator, + sort_keys_set: config.sort_keys_set, + } + } + + pub fn generate(&self) -> Result<Vec<Dataset>> { + let mut datasets = Vec::with_capacity(self.sort_keys_set.len() + 1); + + // Generate the base batch + let base_batch = self.batch_generator.generate()?; + let batches = stagger_batch(base_batch.clone()); + let dataset = Dataset::new(batches, Vec::new()); + datasets.push(dataset); + + // Generate the related sorted batches + let schema = base_batch.schema_ref(); + for sort_keys in self.sort_keys_set.clone() { + let sort_exprs = sort_keys + .iter() + .map(|key| { + let col_expr = col(key, schema)?; + Ok(PhysicalSortExpr::new_default(col_expr)) + }) + .collect::<Result<Vec<_>>>()?; + let sorted_batch = sort_batch(&base_batch, &sort_exprs, None)?; + + let batches = stagger_batch(sorted_batch); + let dataset = Dataset::new(batches, sort_keys); + datasets.push(dataset); + } + + Ok(datasets) + } +} + +/// Single test data set +#[derive(Debug)] +pub struct Dataset { + pub batches: Vec<RecordBatch>, + pub total_rows_num: usize, + pub sort_keys: Vec<String>, +} + +impl Dataset { + pub fn new(batches: Vec<RecordBatch>, sort_keys: Vec<String>) -> Self { + let total_rows_num = batches.iter().map(|batch| batch.num_rows()).sum::<usize>(); + + Self { + batches, + total_rows_num, + sort_keys, + } + } +} + +#[derive(Debug, Clone)] +pub struct ColumnDescr { + // Column name + name: String, + + // Data type of this column + column_type: DataType, +} + +impl ColumnDescr { + #[inline] + pub fn new(name: &str, column_type: DataType) -> Self { + Self { + name: name.to_string(), + column_type, + } + } +} + +/// Record batch generator +struct RecordBatchGenerator { + min_rows_num: usize, + + max_rows_num: usize, + + columns: Vec<ColumnDescr>, + + candidate_null_pcts: Vec<f64>, +} + +macro_rules! generate_string_array { + ($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $OFFSET_TYPE:ty) => {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + let max_len = $BATCH_GEN_RNG.gen_range(1..50); + let num_distinct_strings = if $NUM_ROWS > 1 { + $BATCH_GEN_RNG.gen_range(1..$NUM_ROWS) + } else { + $NUM_ROWS + }; + + let mut generator = StringArrayGenerator { + max_len, + num_strings: $NUM_ROWS, + num_distinct_strings, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$OFFSET_TYPE>() + }}; +} + +macro_rules! generate_primitive_array { + ($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $DATA_TYPE:ident) => { + paste::paste!
{{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + let num_distinct_primitives = if $NUM_ROWS > 1 { + $BATCH_GEN_RNG.gen_range(1..$NUM_ROWS) + } else { + $NUM_ROWS + }; + + let mut generator = PrimitiveArrayGenerator { + num_primitives: $NUM_ROWS, + num_distinct_primitives, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.[< gen_data_ $DATA_TYPE >]() + }}} +} + +impl RecordBatchGenerator { + fn new(min_rows_num: usize, max_rows_num: usize, columns: Vec<ColumnDescr>) -> Self { + let candidate_null_pcts = vec![0.0, 0.01, 0.1, 0.5]; + + Self { + min_rows_num, + max_rows_num, + columns, + candidate_null_pcts, + } + } + + fn generate(&self) -> Result<RecordBatch> { + let mut rng = thread_rng(); + let num_rows = rng.gen_range(self.min_rows_num..=self.max_rows_num); + let array_gen_rng = StdRng::from_seed(rng.gen()); + + // Build arrays + let mut arrays = Vec::with_capacity(self.columns.len()); + for col in self.columns.iter() { + let array = self.generate_array_of_type( + col.column_type.clone(), + num_rows, + &mut rng, + array_gen_rng.clone(), + ); + arrays.push(array); + } + + // Build schema + let fields = self + .columns + .iter() + .map(|col| Field::new(col.name.clone(), col.column_type.clone(), true)) + .collect::<Vec<_>>(); + let schema = Arc::new(Schema::new(fields)); + + RecordBatch::try_new(schema, arrays).map_err(|e| arrow_datafusion_err!(e)) + } + + fn generate_array_of_type( + &self, + data_type: DataType, + num_rows: usize, + batch_gen_rng: &mut ThreadRng, + array_gen_rng: StdRng, + ) -> ArrayRef { + match data_type { + DataType::Int8 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + i8 + ) + } + DataType::Int16 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + i16 + ) + } + DataType::Int32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + i32 + ) + } + DataType::Int64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + i64 + ) + } + DataType::UInt8 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + u8 + ) + } + DataType::UInt16 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + u16 + ) + } + DataType::UInt32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + u32 + ) + } + DataType::UInt64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + u64 + ) + } + DataType::Float32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + f32 + ) + } + DataType::Float64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + f64 + ) + } + DataType::Utf8 => { + generate_string_array!(self, num_rows, batch_gen_rng, array_gen_rng, i32) + } + DataType::LargeUtf8 => { + generate_string_array!(self, num_rows, batch_gen_rng, array_gen_rng, i64) + } + _ => unreachable!(), + } + } +} + +#[cfg(test)] +mod test { + use arrow_array::UInt32Array; + + use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + + use super::*; + + #[test] + fn test_generated_datasets() { + // The test datasets generation config + // We expect that after calling `generate` + // - Generate 2 datasets + // - They have 2 columns "a" and "b", + // "a"'s type is `Utf8`, and "b"'s type is `UInt32` + // - One of them is unsorted, another is sorted by column "b" + // - Their rows num
should be the same and within [16, 32] + let config = DatasetGeneratorConfig { + columns: vec![ + ColumnDescr { + name: "a".to_string(), + column_type: DataType::Utf8, + }, + ColumnDescr { + name: "b".to_string(), + column_type: DataType::UInt32, + }, + ], + rows_num_range: (16, 32), + sort_keys_set: vec![vec!["b".to_string()]], + }; + + let gen = DatasetGenerator::new(config); + let datasets = gen.generate().unwrap(); + + // Should generate 2 datasets + assert_eq!(datasets.len(), 2); + + // Should have 2 columns "a" and "b", + // "a"'s type is `Utf8`, and "b"'s type is `UInt32` + let check_fields = |batch: &RecordBatch| { + assert_eq!(batch.num_columns(), 2); + let fields = batch.schema().fields().clone(); + assert_eq!(fields[0].name(), "a"); + assert_eq!(*fields[0].data_type(), DataType::Utf8); + assert_eq!(fields[1].name(), "b"); + assert_eq!(*fields[1].data_type(), DataType::UInt32); + }; + + let batch = &datasets[0].batches[0]; + check_fields(batch); + let batch = &datasets[1].batches[0]; + check_fields(batch); + + // The second dataset's batches should be sorted by "b" + let sorted_batches = &datasets[1].batches; + let b_vals = sorted_batches.iter().flat_map(|batch| { + let uint_array = batch + .column(1) + .as_any() + .downcast_ref::<UInt32Array>() + .unwrap(); + uint_array.iter() + }); + let mut prev_b_val = u32::MIN; + for b_val in b_vals { + let b_val = b_val.unwrap_or(u32::MIN); + assert!(b_val >= prev_b_val); + prev_b_val = b_val; + } + + // The two datasets should be the same after sorting + check_equality_of_batches(&datasets[0].batches, &datasets[1].batches).unwrap(); + + // Rows num should be within [16, 32] + let rows_num0 = datasets[0] + .batches + .iter() + .map(|batch| batch.num_rows()) + .sum::<usize>(); + let rows_num1 = datasets[1] + .batches + .iter() + .map(|batch| batch.num_rows()) + .sum::<usize>(); + assert_eq!(rows_num0, rows_num1); + assert!(rows_num0 >= 16); + assert!(rows_num0 <= 32); + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs new file mode 100644 index 000000000000..abb34048284d --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -0,0 +1,281 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
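+ +//! A sketch of the overall flow implemented below (names as in this file): +//! for each data-generation round, `DatasetGenerator::generate` produces one +//! unsorted dataset plus one per sort-key set; each dataset is paired with a +//! randomly chosen candidate sql, a baseline result is computed once, and then +//! `CTX_GEN_ROUNDS` randomized contexts re-run the same sql and are checked +//! against that baseline.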
+ +use std::sync::Arc; + +use arrow::util::pretty::pretty_format_batches; +use arrow_array::RecordBatch; +use rand::{thread_rng, Rng}; +use tokio::task::JoinSet; + +use crate::fuzz_cases::aggregation_fuzzer::{ + check_equality_of_batches, + context_generator::{SessionContextGenerator, SessionContextWithParams}, + data_generator::{Dataset, DatasetGenerator, DatasetGeneratorConfig}, + run_sql, +}; + +/// Rounds to call `generate` of [`SessionContextGenerator`] +/// in [`AggregationFuzzer`]; `ctx_gen_rounds` random [`SessionContext`]s +/// will be generated for each dataset for testing. +const CTX_GEN_ROUNDS: usize = 16; + +/// Aggregation fuzzer's builder +pub struct AggregationFuzzerBuilder { + /// See `candidate_sqls` in [`AggregationFuzzer`], no default, and required to set + candidate_sqls: Vec<Arc<str>>, + + /// See `table_name` in [`AggregationFuzzer`], no default, and required to set + table_name: Option<Arc<str>>, + + /// Used to generate `dataset_generator` in [`AggregationFuzzer`], + /// no default, and required to set + data_gen_config: Option<DatasetGeneratorConfig>, + + /// See `data_gen_rounds` in [`AggregationFuzzer`], default 16 + data_gen_rounds: usize, +} + +impl AggregationFuzzerBuilder { + fn new() -> Self { + Self { + candidate_sqls: Vec::new(), + table_name: None, + data_gen_config: None, + data_gen_rounds: 16, + } + } + + pub fn add_sql(mut self, sql: &str) -> Self { + self.candidate_sqls.push(Arc::from(sql)); + self + } + + pub fn table_name(mut self, table_name: &str) -> Self { + self.table_name = Some(Arc::from(table_name)); + self + } + + pub fn data_gen_config(mut self, data_gen_config: DatasetGeneratorConfig) -> Self { + self.data_gen_config = Some(data_gen_config); + self + } + + pub fn data_gen_rounds(mut self, data_gen_rounds: usize) -> Self { + self.data_gen_rounds = data_gen_rounds; + self + } + + pub fn build(self) -> AggregationFuzzer { + assert!(!self.candidate_sqls.is_empty()); + let candidate_sqls = self.candidate_sqls; + let table_name = self.table_name.expect("table_name is required"); + let data_gen_config = self.data_gen_config.expect("data_gen_config is required"); + let data_gen_rounds = self.data_gen_rounds; + + let dataset_generator = DatasetGenerator::new(data_gen_config); + + AggregationFuzzer { + candidate_sqls, + table_name, + dataset_generator, + data_gen_rounds, + } + } +} + +impl Default for AggregationFuzzerBuilder { + fn default() -> Self { + Self::new() + } +} + +/// AggregationFuzzer randomly generates multiple [`AggregationFuzzTestTask`]s, +/// and runs them to check the correctness of the optimizations +/// (e.g. sorted, partial skipping, spilling...) +pub struct AggregationFuzzer { + /// Candidate test queries represented by sqls + candidate_sqls: Vec<Arc<str>>, + + /// The queried table name + table_name: Arc<str>, + + /// Dataset generator used to randomly generate datasets + dataset_generator: DatasetGenerator, + + /// Rounds to call `generate` of [`DatasetGenerator`], + /// `len(sort_keys_set) + 1` datasets will be generated for testing. + /// + /// It is suggested to set this value to 2x or more the number of + /// `candidate_sqls` for better test coverage.
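+ /// + /// For example (illustrative arithmetic): with `data_gen_rounds = 16` and + /// two entries in `sort_keys_set`, each round yields 3 datasets, so 48 + /// query groups are tested, each under `CTX_GEN_ROUNDS` (16) random + /// contexts, i.e. 768 randomized runs per fuzz test.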
+ data_gen_rounds: usize, +} + +/// Query group including the tested dataset and its sql query +struct QueryGroup { + dataset: Dataset, + sql: Arc<str>, +} + +impl AggregationFuzzer { + pub async fn run(&self) { + let mut join_set = JoinSet::new(); + let mut rng = thread_rng(); + + // Loop to generate datasets and their queries + for _ in 0..self.data_gen_rounds { + // Generate datasets first + let datasets = self + .dataset_generator + .generate() + .expect("should succeed to generate dataset"); + + // Then for each of them, we randomly select a test sql + let query_groups = datasets + .into_iter() + .map(|dataset| { + let sql_idx = rng.gen_range(0..self.candidate_sqls.len()); + let sql = self.candidate_sqls[sql_idx].clone(); + + QueryGroup { dataset, sql } + }) + .collect::<Vec<_>>(); + + let tasks = self.generate_fuzz_tasks(query_groups).await; + for task in tasks { + join_set.spawn(async move { + task.run().await; + }); + } + } + + while let Some(join_handle) = join_set.join_next().await { + // propagate errors + join_handle.unwrap(); + } + } + + async fn generate_fuzz_tasks( + &self, + query_groups: Vec<QueryGroup>, + ) -> Vec<AggregationFuzzTestTask> { + let mut tasks = Vec::with_capacity(query_groups.len() * CTX_GEN_ROUNDS); + for QueryGroup { dataset, sql } in query_groups { + let dataset_ref = Arc::new(dataset); + let ctx_generator = + SessionContextGenerator::new(dataset_ref.clone(), &self.table_name); + + // Generate the baseline context, and get the baseline result first + let baseline_ctx_with_params = ctx_generator + .generate_baseline() + .expect("should succeed to generate baseline session context"); + let baseline_result = run_sql(&sql, &baseline_ctx_with_params.ctx) + .await + .expect("should succeed to run baseline sql"); + let baseline_result = Arc::new(baseline_result); + // Generate test tasks + for _ in 0..CTX_GEN_ROUNDS { + let ctx_with_params = ctx_generator + .generate() + .expect("should succeed to generate session context"); + let task = AggregationFuzzTestTask { + dataset_ref: dataset_ref.clone(), + expected_result: baseline_result.clone(), + sql: sql.clone(), + ctx_with_params, + }; + + tasks.push(task); + } + } + tasks + } +} + +/// One test task generated by [`AggregationFuzzer`] +/// +/// It includes: +/// - `expected_result`, the expected result generated by the baseline [`SessionContext`] +/// (with all possible optimizations disabled to ensure correctness). +/// +/// - `ctx`, a randomly generated [`SessionContext`]; `sql` will be run +/// on it afterwards, and the result is checked against the expected one. +/// +/// - `sql`, the selected test sql +/// +/// - `dataset_ref`, the input dataset, stored for error reporting when an +/// inconsistency is found between the `ctx` result and the `expected result`. +/// +struct AggregationFuzzTestTask { + /// Generated session context in current test case + ctx_with_params: SessionContextWithParams, + + /// Expected result in current test case + /// It is generated from `query` + `baseline session context` + expected_result: Arc<Vec<RecordBatch>>, + + /// The test query + /// Use sql to represent it currently. + sql: Arc<str>, + + /// The test dataset for error reporting + dataset_ref: Arc<Dataset>, +} + +impl AggregationFuzzTestTask { + async fn run(&self) { + let task_result = run_sql(&self.sql, &self.ctx_with_params.ctx) + .await + .expect("should succeed to run sql"); + self.check_result(&task_result, &self.expected_result); + } + + // TODO: maybe we should persist the `expected_result` and `task_result`, + // because the readability is not so good if we just print it.
+ fn check_result(&self, task_result: &[RecordBatch], expected_result: &[RecordBatch]) { + let result = check_equality_of_batches(task_result, expected_result); + if let Err(e) = result { + // If we find an inconsistent result, we print the test details first so the failure can be reproduced + println!( + "##### AggregationFuzzer error report ##### + ### Sql:\n{}\n\ + ### Schema:\n{}\n\ + ### Session context params:\n{:?}\n\ + ### Inconsistent row:\n\ + - row_idx:{}\n\ + - task_row:{}\n\ + - expected_row:{}\n\ + ### Task total result:\n{}\n\ + ### Expected total result:\n{}\n\ + ### Input:\n{}\n\ + ", + self.sql, + self.dataset_ref.batches[0].schema_ref(), + self.ctx_with_params.params, + e.row_idx, + e.lhs_row, + e.rhs_row, + pretty_format_batches(task_result).unwrap(), + pretty_format_batches(expected_result).unwrap(), + pretty_format_batches(&self.dataset_ref.batches).unwrap(), + ); + + // Then we just panic + panic!(); + } + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs new file mode 100644 index 000000000000..d93a5b7b9360 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
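+ +//! Shared helpers for the aggregation fuzzer. Note that +//! `check_equality_of_batches` below compares the pretty-formatted rows of +//! both sides after sorting the lines, so it is insensitive to row order; +//! e.g. a baseline run with `target_partitions = 1` and a randomized +//! multi-partition run compare equal as long as they produce the same rows.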
+ +use arrow::util::pretty::pretty_format_batches; +use arrow_array::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::error::Result; + +mod context_generator; +mod data_generator; +mod fuzzer; + +pub use data_generator::{ColumnDescr, DatasetGeneratorConfig}; +pub use fuzzer::*; + +#[derive(Debug)] +pub(crate) struct InconsistentResult { + pub row_idx: usize, + pub lhs_row: String, + pub rhs_row: String, +} + +pub(crate) fn check_equality_of_batches( + lhs: &[RecordBatch], + rhs: &[RecordBatch], +) -> std::result::Result<(), InconsistentResult> { + let lhs_formatted_batches = pretty_format_batches(lhs).unwrap().to_string(); + let mut lhs_formatted_batches_sorted: Vec<&str> = + lhs_formatted_batches.trim().lines().collect(); + lhs_formatted_batches_sorted.sort_unstable(); + let rhs_formatted_batches = pretty_format_batches(rhs).unwrap().to_string(); + let mut rhs_formatted_batches_sorted: Vec<&str> = + rhs_formatted_batches.trim().lines().collect(); + rhs_formatted_batches_sorted.sort_unstable(); + + for (row_idx, (lhs_row, rhs_row)) in lhs_formatted_batches_sorted + .iter() + .zip(&rhs_formatted_batches_sorted) + .enumerate() + { + if lhs_row != rhs_row { + return Err(InconsistentResult { + row_idx, + lhs_row: lhs_row.to_string(), + rhs_row: rhs_row.to_string(), + }); + } + } + + Ok(()) +} + +pub(crate) async fn run_sql(sql: &str, ctx: &SessionContext) -> Result<Vec<RecordBatch>> { + ctx.sql(sql).await?.collect().await +} diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index c1cd41906023..49db0d31a8e9 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -21,7 +21,9 @@ mod join_fuzz; mod merge_fuzz; mod sort_fuzz; +mod aggregation_fuzzer; mod equivalence; + mod limit_fuzz; mod sort_preserving_repartition_fuzz; mod window_fuzz; diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index a6c2cf700cc4..b9881c9f23cf 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -45,6 +45,8 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use test_utils::add_empty_batches; use datafusion::functions_window::row_number::row_number_udwf; +use datafusion_functions_window::dense_rank::dense_rank_udwf; +use datafusion_functions_window::rank::rank_udwf; use hashbrown::HashMap; use rand::distributions::Alphanumeric; use rand::rngs::StdRng; @@ -224,9 +226,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + WindowFunctionDefinition::WindowUDF(rank_udwf()), // its name - "RANK", + "rank", // no argument vec![], // Expected causality, for None cases causality will be determined from window frame boundaries @@ -238,11 +240,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::DenseRank, - ), + WindowFunctionDefinition::WindowUDF(dense_rank_udwf()), // its name - "DENSE_RANK", + "dense_rank", // no argument vec![], // Expected causality, for None cases causality will be determined from window frame boundaries @@ -382,19 +382,12 @@ fn get_random_function( ); window_fn_map.insert( "rank", - ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Rank, - ), - vec![], - ), + (WindowFunctionDefinition::WindowUDF(rank_udwf()),
vec![]), ); window_fn_map.insert( "dense_rank", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::DenseRank, - ), + WindowFunctionDefinition::WindowUDF(dense_rank_udwf()), vec![], ), ); diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs deleted file mode 100644 index bbf4dcd2b799..000000000000 --- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs +++ /dev/null @@ -1,325 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Tests for the physical optimizer - -use datafusion_common::config::ConfigOptions; -use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics; -use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::aggregates::AggregateExec; -use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::ExecutionPlan; -use std::sync::Arc; - -use datafusion::error::Result; -use datafusion::logical_expr::Operator; -use datafusion::prelude::SessionContext; -use datafusion::test_util::TestAggregate; -use datafusion_physical_plan::aggregates::PhysicalGroupBy; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; -use datafusion_physical_plan::common; -use datafusion_physical_plan::filter::FilterExec; -use datafusion_physical_plan::memory::MemoryExec; - -use arrow::array::Int32Array; -use arrow::datatypes::{DataType, Field, Schema}; -use arrow::record_batch::RecordBatch; -use datafusion_common::cast::as_int64_array; -use datafusion_physical_expr::expressions::{self, cast}; -use datafusion_physical_plan::aggregates::AggregateMode; - -/// Mock data using a MemoryExec which has an exact count statistic -fn mock_data() -> Result<Arc<dyn ExecutionPlan>> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ])); - - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), - Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])), - ], - )?; - - Ok(Arc::new(MemoryExec::try_new( - &[vec![batch]], - Arc::clone(&schema), - None, - )?)) -} - -/// Checks that the count optimization was applied and we still get the right result -async fn assert_count_optim_success( - plan: AggregateExec, - agg: TestAggregate, -) -> Result<()> { - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); - let plan: Arc<dyn ExecutionPlan> = Arc::new(plan); - - let optimized = - AggregateStatistics::new().optimize(Arc::clone(&plan), state.config_options())?; - - // A ProjectionExec is a sign that the count optimization was applied - assert!(optimized.as_any().is::<ProjectionExec>()); - - // run both the optimized and
nonoptimized plan - let optimized_result = - common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?; - let nonoptimized_result = - common::collect(plan.execute(0, session_ctx.task_ctx())?).await?; - assert_eq!(optimized_result.len(), nonoptimized_result.len()); - - // and validate the results are the same and expected - assert_eq!(optimized_result.len(), 1); - check_batch(optimized_result.into_iter().next().unwrap(), &agg); - // check the non optimized one too to ensure types and names remain the same - assert_eq!(nonoptimized_result.len(), 1); - check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg); - - Ok(()) -} - -fn check_batch(batch: RecordBatch, agg: &TestAggregate) { - let schema = batch.schema(); - let fields = schema.fields(); - assert_eq!(fields.len(), 1); - - let field = &fields[0]; - assert_eq!(field.name(), agg.column_name()); - assert_eq!(field.data_type(), &DataType::Int64); - // note that nullabiolity differs - - assert_eq!( - as_int64_array(batch.column(0)).unwrap().values(), - &[agg.expected_count()] - ); -} - -#[tokio::test] -async fn test_count_partial_direct_child() -> Result<()> { - // basic test case with the aggregation applied on a source with exact statistics - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_partial_with_nulls_direct_child() -> Result<()> { - // basic test case with the aggregation applied on a source with exact statistics - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_partial_indirect_child() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - // We introduce an intermediate optimization step between the partial and final aggregtator - let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(coalesce), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - 
let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - source, - Arc::clone(&schema), - )?; - - // We introduce an intermediate optimization step between the partial and final aggregtator - let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(coalesce), - Arc::clone(&schema), - )?; - - assert_count_optim_success(final_agg, agg).await?; - - Ok(()) -} - -#[tokio::test] -async fn test_count_inexact_stat() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - // adding a filter makes the statistics inexact - let filter = Arc::new(FilterExec::try_new( - expressions::binary( - expressions::col("a", &schema)?, - Operator::Gt, - cast(expressions::lit(1u32), &schema, DataType::Int32)?, - &schema, - )?, - source, - )?); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - filter, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - let conf = ConfigOptions::new(); - let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; - - // check that the original ExecutionPlan was not replaced - assert!(optimized.as_any().is::<AggregateExec>()); - - Ok(()) -} - -#[tokio::test] -async fn test_count_with_nulls_inexact_stat() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_column(&schema); - - // adding a filter makes the statistics inexact - let filter = Arc::new(FilterExec::try_new( - expressions::binary( - expressions::col("a", &schema)?, - Operator::Gt, - cast(expressions::lit(1u32), &schema, DataType::Int32)?, - &schema, - )?, - source, - )?); - - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - filter, - Arc::clone(&schema), - )?; - - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::default(), - vec![agg.count_expr(&schema)], - vec![None], - Arc::new(partial_agg), - Arc::clone(&schema), - )?; - - let conf = ConfigOptions::new(); - let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; - - // check that the original ExecutionPlan was not replaced - assert!(optimized.as_any().is::<AggregateExec>()); - - Ok(()) -} diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index 4ec981bf2a74..c06783aa0277 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License.
-mod aggregate_statistics; mod combine_partial_final_agg; mod limit_pushdown; mod limited_distinct_aggregation; diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index d96bb23953ae..3760328934bc 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -36,6 +36,7 @@ use datafusion_expr::{ PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// A query with a window function evaluated over the entire partition const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ @@ -552,7 +553,10 @@ impl OddCounter { &self.signature } - fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result<Box<dyn PartitionEvaluator>> { Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state)))) } diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index dcd59acbd49e..d87ce1ebfed7 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -68,11 +68,35 @@ pub use pool::*; /// Note that a `MemoryPool` can be shared by concurrently executing plans, /// which can be used to control memory usage in a multi-tenant system. /// +/// # How MemoryPool works by example +/// +/// Scenario 1: +/// For the `Filter` operator, `RecordBatch`es will stream through it, so it +/// doesn't have to keep track of memory usage through [`MemoryPool`]. +/// +/// Scenario 2: +/// For the `CrossJoin` operator, if the input size gets larger, the intermediate +/// state will also grow. So the `CrossJoin` operator will use [`MemoryPool`] to +/// limit the memory usage. +/// 2.1 `CrossJoin` operator has read a new batch, and asked the memory pool for +/// additional memory. The memory pool updates the usage and returns success. +/// 2.2 `CrossJoin` has read another batch, and tries to reserve more memory +/// again, but the memory pool does not have enough memory. Since the `CrossJoin` operator +/// has not implemented spilling, it will stop execution and return an error. +/// +/// Scenario 3: +/// For the `Aggregate` operator, its intermediate states will also accumulate as +/// the input size gets larger, but it has spilling capability. When it tries to +/// reserve more memory from the memory pool, and the memory pool has already +/// reached the memory limit, it will return an error. Then, the `Aggregate` +/// operator will spill the intermediate buffers to disk, release memory +/// from the memory pool, and continue to retry the memory reservation. +/// /// # Implementing `MemoryPool` /// /// You can implement a custom allocation policy by implementing the /// [`MemoryPool`] trait and configuring a `SessionContext` appropriately.
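+/// +/// For example, a fixed-size pool can be wired in roughly like this (a sketch; +/// the exact builder names depend on the DataFusion version in use): +/// +/// ```ignore +/// let runtime = RuntimeEnvBuilder::new() +/// .with_memory_pool(Arc::new(GreedyMemoryPool::new(64 * 1024 * 1024))) +/// .build_arc()?; +/// let ctx = SessionContext::new_with_config_rt(SessionConfig::new(), runtime); +/// ```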
-/// However, mDataFusion comes with the following simple memory pool implementations that +/// However, DataFusion comes with the following simple memory pool implementations that /// handle many common cases: /// /// * [`UnboundedMemoryPool`]: no memory limits (the default) diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 8e81c51d8460..d19d8e469688 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -90,6 +90,11 @@ impl EmitTo { /// faster for queries with many group values. See the [Aggregating Millions of /// Groups Fast blog] for more background. /// +/// [`NullState`] can help keep the state for groups that have not seen any +/// values and produce the correct output for those groups. +/// +/// [`NullState`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.NullState.html +/// /// # Details /// Each group is assigned a `group_index` by the hash table and each /// accumulator manages the specific state, one per `group_index`. @@ -117,6 +122,11 @@ pub trait GroupsAccumulator: Send { /// /// Note that subsequent calls to update_batch may have larger /// total_num_groups as new groups are seen. + /// + /// See [`NullState`] to help keep the state for groups that have not seen any + /// values and produce the correct output for those groups. + /// + /// [`NullState`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.NullState.html fn update_batch( &mut self, values: &[ArrayRef], diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 6424888c896a..e76453d91a8e 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -1753,7 +1753,7 @@ impl NullableInterval { } _ => Ok(Self::MaybeNull { values }), } - } else if op.is_comparison_operator() { + } else if op.supports_propagation() { Ok(Self::Null { datatype: DataType::Boolean, }) diff --git a/datafusion/expr-common/src/operator.rs b/datafusion/expr-common/src/operator.rs index e013b6fafa22..6ca0f04897ac 100644 --- a/datafusion/expr-common/src/operator.rs +++ b/datafusion/expr-common/src/operator.rs @@ -142,10 +142,11 @@ impl Operator { ) } - /// Return true if the operator is a comparison operator. + /// Return true if the comparison operator can be used in interval arithmetic and constraint + /// propagation /// - /// For example, 'Binary(a, >, b)' would be a comparison expression. - pub fn is_comparison_operator(&self) -> bool { + /// For example, 'Binary(a, >, b)' expression supports propagation. + pub fn supports_propagation(&self) -> bool { matches!( self, Operator::Eq @@ -163,6 +164,15 @@ impl Operator { ) } + /// Return true if the comparison operator can be used in interval arithmetic and constraint + /// propagation + /// + /// For example, 'Binary(a, >, b)' expression supports propagation. + #[deprecated(since = "43.0.0", note = "please use `supports_propagation` instead")] + pub fn is_comparison_operator(&self) -> bool { + self.supports_propagation() + } + /// Return true if the operator is a logic operator. 
/// /// For example, 'Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))' would diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index d1553b3315e7..320e1303a21b 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -125,6 +125,11 @@ pub enum TypeSignature { /// Fixed number of arguments of numeric types. /// See to know which type is considered numeric Numeric(usize), + /// Fixed number of arguments of all the same string types. + /// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8. + /// Null is considered as Utf8 by default + /// Dictionary with string value type is also handled. + String(usize), } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -190,8 +195,11 @@ impl TypeSignature { .collect::<Vec<String>>() .join(", ")] } + TypeSignature::String(num) => { + vec![format!("String({num})")] + } TypeSignature::Numeric(num) => { - vec![format!("Numeric({})", num)] + vec![format!("Numeric({num})")] } TypeSignature::Exact(types) | TypeSignature::Coercible(types) => { vec![Self::join_types(types, ", ")] @@ -280,6 +288,14 @@ impl Signature { } } + /// A specified number of string arguments + pub fn string(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::String(arg_count), + volatility, + } + } + /// An arbitrary number of arguments of any type. pub fn variadic_any(volatility: Volatility) -> Self { Self { diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index e66a9ae1ea98..e042dd5d3ac6 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -25,8 +25,8 @@ use crate::operator::Operator; use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, Field, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result}; @@ -370,6 +370,8 @@ impl From<&DataType> for TypeCategory { /// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules /// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted /// decimal precision and scale when coercing decimal types. +/// +/// This function doesn't preserve the correct field names and nullability for struct types; we only care about the data types.
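+/// +/// For example (behavior sketched from the rules below): unioning `Int32` and +/// `Float64` resolves to `Float64`, while unioning two structs resolves field +/// types pairwise by field name and rebuilds the struct with positional names +/// (`c0`, `c1`, ...).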
pub fn type_union_resolution(data_types: &[DataType]) -> Option { if data_types.is_empty() { return None; @@ -471,10 +473,56 @@ fn type_union_resolution_coercion( let new_value_type = type_union_resolution_coercion(value_type, other_type); new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t))) } + (DataType::List(lhs), DataType::List(rhs)) => { + let new_item_type = + type_union_resolution_coercion(lhs.data_type(), rhs.data_type()); + new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true)))) + } + (DataType::Struct(lhs), DataType::Struct(rhs)) => { + if lhs.len() != rhs.len() { + return None; + } + + // Search the field in the right hand side with the SAME field name + fn search_corresponding_coerced_type( + lhs_field: &FieldRef, + rhs: &Fields, + ) -> Option { + for rhs_field in rhs.iter() { + if lhs_field.name() == rhs_field.name() { + if let Some(t) = type_union_resolution_coercion( + lhs_field.data_type(), + rhs_field.data_type(), + ) { + return Some(t); + } else { + return None; + } + } + } + + None + } + + let types = lhs + .iter() + .map(|lhs_field| search_corresponding_coerced_type(lhs_field, rhs)) + .collect::>>()?; + + let fields = types + .into_iter() + .enumerate() + .map(|(i, datatype)| { + Arc::new(Field::new(format!("c{i}"), datatype, true)) + }) + .collect::>(); + Some(DataType::Struct(fields.into())) + } _ => { // numeric coercion is the same as comparison coercion, both find the narrowest type // that can accommodate both types binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) } @@ -507,22 +555,6 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { - if lhs_type == rhs_type { - // same type => equality is possible - return Some(lhs_type.clone()); - } - binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) - .or_else(|| string_coercion(lhs_type, rhs_type)) - .or_else(|| binary_coercion(lhs_type, rhs_type)) -} - /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation /// where one is numeric and one is `Utf8`/`LargeUtf8`. fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { @@ -969,7 +1001,7 @@ fn string_concat_internal_coercion( /// based on the observation that StringArray to StringViewArray is cheap but not vice versa. /// /// Between Utf8 and LargeUtf8, we coerce to LargeUtf8. -fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { +pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { // If Utf8View is in any side, we coerce to Utf8View. 
diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index b136d6cacec8..117ff08253b6 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -40,12 +40,6 @@ impl fmt::Display for BuiltInWindowFunction { /// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum BuiltInWindowFunction { - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) CumeDist, /// integer ranging from 1 to the argument value, dividing the partition as equally as possible @@ -72,9 +66,6 @@ impl BuiltInWindowFunction { pub fn name(&self) -> &str { use BuiltInWindowFunction::*; match self { - Rank => "RANK", - DenseRank => "DENSE_RANK", - PercentRank => "PERCENT_RANK", CumeDist => "CUME_DIST", Ntile => "NTILE", Lag => "LAG", @@ -90,9 +81,6 @@ impl FromStr for BuiltInWindowFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { Ok(match name.to_uppercase().as_str() { - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, "CUME_DIST" => BuiltInWindowFunction::CumeDist, "NTILE" => BuiltInWindowFunction::Ntile, "LAG" => BuiltInWindowFunction::Lag, @@ -127,12 +115,8 @@ impl BuiltInWindowFunction { })?; match self { - BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } + BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), + BuiltInWindowFunction::CumeDist => Ok(DataType::Float64), BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead | BuiltInWindowFunction::FirstValue @@ -145,10 +129,7 @@ impl BuiltInWindowFunction { pub fn signature(&self) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. 
match self {
-            BuiltInWindowFunction::Rank
-            | BuiltInWindowFunction::DenseRank
-            | BuiltInWindowFunction::PercentRank
-            | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable),
+            BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable),
             BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
                 Signature::one_of(
                     vec![
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 02a2edb98016..723433f57341 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -2598,15 +2598,6 @@ mod test {
         Ok(())
     }
 
-    #[test]
-    fn test_percent_rank_return_type() -> Result<()> {
-        let fun = find_df_window_func("percent_rank").unwrap();
-        let observed = fun.return_type(&[], &[], "")?;
-        assert_eq!(DataType::Float64, observed);
-
-        Ok(())
-    }
-
     #[test]
     fn test_cume_dist_return_type() -> Result<()> {
         let fun = find_df_window_func("cume_dist").unwrap();
@@ -2628,9 +2619,6 @@ mod test {
     #[test]
     fn test_window_function_case_insensitive() -> Result<()> {
         let names = vec![
-            "rank",
-            "dense_rank",
-            "percent_rank",
             "cume_dist",
             "ntile",
             "lag",
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 2975e36488dc..7fd4e64e0e62 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -27,8 +27,8 @@ use crate::function::{
 };
 use crate::{
     conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
-    AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF,
-    Signature, Volatility,
+    AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator,
+    ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
 };
 use crate::{
     AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
@@ -39,6 +39,7 @@ use arrow::compute::kernels::cast_utils::{
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference};
 use datafusion_functions_window_common::field::WindowUDFFieldArgs;
+use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
 use sqlparser::ast::NullTreatment;
 use std::any::Any;
 use std::fmt::Debug;
@@ -658,7 +659,10 @@ impl WindowUDFImpl for SimpleWindowUDF {
         &self.signature
     }
 
-    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+    fn partition_evaluator(
+        &self,
+        _partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
         (self.partition_evaluator_factory)()
     }
 
@@ -697,7 +701,6 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr {
 /// # use datafusion_expr::test::function_stub::count;
 /// # use sqlparser::ast::NullTreatment;
 /// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col};
-/// # use datafusion_expr::window_function::percent_rank;
 /// # // first_value is an aggregate function in another crate
 /// # fn first_value(_arg: Expr) -> Expr {
 /// unimplemented!() }
@@ -717,6 +720,9 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr {
 /// // Create a window expression for percent rank partitioned on column a
 /// // equivalent to:
 /// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)`
+/// // percent_rank is a udwf function in another crate
+/// # fn percent_rank() -> Expr {
+/// unimplemented!() }
 /// let window = percent_rank()
 ///     .partition_by(vec![col("a")])
 ///     .order_by(vec![col("b").sort(true, true)])
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index cc8ddf8ec8e8..da2a96327ce5 100644
---
a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -35,7 +35,6 @@ use crate::logical_plan::{ Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; -use crate::type_coercion::binary::values_coercion; use crate::utils::{ can_hash, columnize_expr, compare_sort_expr, expr_to_columns, find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, @@ -53,6 +52,7 @@ use datafusion_common::{ plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; +use datafusion_expr_common::type_coercion::binary::type_union_resolution; use super::dml::InsertOp; use super::plan::{ColumnUnnestList, ColumnUnnestType}; @@ -209,7 +209,8 @@ impl LogicalPlanBuilder { } if let Some(prev_type) = common_type { // get common type of each column values. - let Some(new_type) = values_coercion(&data_type, &prev_type) else { + let data_types = vec![prev_type.clone(), data_type.clone()]; + let Some(new_type) = type_union_resolution(&data_types) else { return plan_err!("Inconsistent data type across values list at row {i} column {j}. Was {prev_type} but found {data_type}"); }; common_type = Some(new_type); diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 19e73140b75c..0292274e57ee 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -21,7 +21,7 @@ use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use super::dml::CopyTo; use super::DdlStatement; @@ -2965,6 +2965,15 @@ impl Aggregate { .into_iter() .map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into())) .collect::>(); + qualified_fields.push(( + None, + Field::new( + Self::INTERNAL_GROUPING_ID, + Self::grouping_id_type(qualified_fields.len()), + false, + ) + .into(), + )); } qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?); @@ -3016,9 +3025,19 @@ impl Aggregate { }) } + fn is_grouping_set(&self) -> bool { + matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)]) + } + /// Get the output expressions. fn output_expressions(&self) -> Result> { + static INTERNAL_ID_EXPR: OnceLock = OnceLock::new(); let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?; + if self.is_grouping_set() { + exprs.push(INTERNAL_ID_EXPR.get_or_init(|| { + Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID)) + })); + } exprs.extend(self.aggr_expr.iter()); debug_assert!(exprs.len() == self.schema.fields().len()); Ok(exprs) @@ -3030,6 +3049,41 @@ impl Aggregate { pub fn group_expr_len(&self) -> Result { grouping_set_expr_count(&self.group_expr) } + + /// Returns the data type of the grouping id. + /// The grouping ID value is a bitmask where each set bit + /// indicates that the corresponding grouping expression is + /// null + pub fn grouping_id_type(group_exprs: usize) -> DataType { + if group_exprs <= 8 { + DataType::UInt8 + } else if group_exprs <= 16 { + DataType::UInt16 + } else if group_exprs <= 32 { + DataType::UInt32 + } else { + DataType::UInt64 + } + } + + /// Internal column used when the aggregation is a grouping set. + /// + /// This column contains a bitmask where each bit represents a grouping + /// expression. The least significant bit corresponds to the rightmost + /// grouping expression. 
A bit value of 0 indicates that the corresponding + /// column is included in the grouping set, while a value of 1 means it is excluded. + /// + /// For example, for the grouping expressions CUBE(a, b), the grouping ID + /// column will have the following values: + /// 0b00: Both `a` and `b` are included + /// 0b01: `b` is excluded + /// 0b10: `a` is excluded + /// 0b11: Both `a` and `b` are excluded + /// + /// This internal column is necessary because excluded columns are replaced + /// with `NULL` values. To handle these cases correctly, we must distinguish + /// between an actual `NULL` value in a column and a column being excluded from the set. + pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id"; } // Manual implementation needed because of `schema` field. Comparison excludes this field. diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 8ac6ad372482..85f8e20ba4a5 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -26,8 +26,9 @@ use datafusion_common::{ utils::{coerced_fixed_size_list_to_list, list_ndims}, Result, }; -use datafusion_expr_common::signature::{ - ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, +use datafusion_expr_common::{ + signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, + type_coercion::binary::string_coercion, }; use std::sync::Arc; @@ -167,6 +168,21 @@ pub fn data_types( try_coerce_types(valid_types, current_types, &signature.type_signature) } +fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { + if let TypeSignature::OneOf(signatures) = type_signature { + return signatures.iter().all(is_well_supported_signature); + } + + matches!( + type_signature, + TypeSignature::UserDefined + | TypeSignature::Numeric(_) + | TypeSignature::String(_) + | TypeSignature::Coercible(_) + | TypeSignature::Any(_) + ) +} + fn try_coerce_types( valid_types: Vec>, current_types: &[DataType], @@ -175,14 +191,7 @@ fn try_coerce_types( let mut valid_types = valid_types; // Well-supported signature that returns exact valid types. 
- if !valid_types.is_empty() - && matches!( - type_signature, - TypeSignature::UserDefined - | TypeSignature::Numeric(_) - | TypeSignature::Coercible(_) - ) - { + if !valid_types.is_empty() && is_well_supported_signature(type_signature) { // exact valid types assert_eq!(valid_types.len(), 1); let valid_types = valid_types.swap_remove(0); @@ -212,20 +221,37 @@ fn get_valid_types_with_scalar_udf( current_types: &[DataType], func: &ScalarUDF, ) -> Result>> { - let valid_types = match signature { + match signature { TypeSignature::UserDefined => match func.coerce_types(current_types) { - Ok(coerced_types) => vec![coerced_types], - Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), + Ok(coerced_types) => Ok(vec![coerced_types]), + Err(e) => exec_err!("User-defined coercion failed with {:?}", e), }, - TypeSignature::OneOf(signatures) => signatures - .iter() - .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok()) - .flatten() - .collect::>(), - _ => get_valid_types(signature, current_types)?, - }; + TypeSignature::OneOf(signatures) => { + let mut res = vec![]; + let mut errors = vec![]; + for sig in signatures { + match get_valid_types_with_scalar_udf(sig, current_types, func) { + Ok(valid_types) => { + res.extend(valid_types); + } + Err(e) => { + errors.push(e.to_string()); + } + } + } - Ok(valid_types) + // Every signature failed, return the joined error + if res.is_empty() { + internal_err!( + "Failed to match any signature, errors: {}", + errors.join(",") + ) + } else { + Ok(res) + } + } + _ => get_valid_types(signature, current_types), + } } fn get_valid_types_with_aggregate_udf( @@ -374,6 +400,67 @@ fn get_valid_types( .iter() .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) .collect(), + TypeSignature::String(number) => { + if *number < 1 { + return plan_err!( + "The signature expected at least one argument but received {}", + current_types.len() + ); + } + if *number != current_types.len() { + return plan_err!( + "The signature expected {} arguments but received {}", + number, + current_types.len() + ); + } + + fn coercion_rule( + lhs_type: &DataType, + rhs_type: &DataType, + ) -> Result { + match (lhs_type, rhs_type) { + (DataType::Null, DataType::Null) => Ok(DataType::Utf8), + (DataType::Null, data_type) | (data_type, DataType::Null) => { + coercion_rule(data_type, &DataType::Utf8) + } + (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => { + coercion_rule(lhs, rhs) + } + (DataType::Dictionary(_, v), other) + | (other, DataType::Dictionary(_, v)) => coercion_rule(v, other), + _ => { + if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) { + Ok(coerced_type) + } else { + plan_err!( + "{} and {} are not coercible to a common string type", + lhs_type, + rhs_type + ) + } + } + } + } + + // Length checked above, safe to unwrap + let mut coerced_type = current_types.first().unwrap().to_owned(); + for t in current_types.iter().skip(1) { + coerced_type = coercion_rule(&coerced_type, t)?; + } + + fn base_type_or_default_type(data_type: &DataType) -> DataType { + if data_type.is_null() { + DataType::Utf8 + } else if let DataType::Dictionary(_, v) = data_type { + base_type_or_default_type(v) + } else { + data_type.to_owned() + } + } + + vec![vec![base_type_or_default_type(&coerced_type); *number]] + } TypeSignature::Numeric(number) => { if *number < 1 { return plan_err!( diff --git a/datafusion/expr/src/udf_docs.rs b/datafusion/expr/src/udf_docs.rs index 280910b87199..e8245588d945 100644 --- 
a/datafusion/expr/src/udf_docs.rs
+++ b/datafusion/expr/src/udf_docs.rs
@@ -131,6 +131,9 @@ impl DocumentationBuilder {
         self
     }
 
+    /// Adds documentation for a specific argument.
+    ///
+    /// Arguments are displayed in the order they are added.
     pub fn with_argument(
         mut self,
         arg_name: impl Into<String>,
@@ -142,6 +145,27 @@ impl DocumentationBuilder {
         self
     }
 
+    /// Add a standard "expression" argument to the documentation
+    ///
+    /// This is similar to [`Self::with_argument`] except that a standard
+    /// description is appended to the end: `"Can be a constant, column, or
+    /// function, and any combination of operators."`
+    ///
+    /// The argument is rendered like
+    ///
+    /// ```text
+    /// <arg_name>:
+    ///   <expression_type> expression to operate on. Can be a constant, column, or function, and any combination of operators.
+    /// ```
+    pub fn with_standard_argument(
+        self,
+        arg_name: impl Into<String>,
+        expression_type: impl AsRef<str>,
+    ) -> Self {
+        let expression_type = expression_type.as_ref();
+        self.with_argument(arg_name, format!("{expression_type} expression to operate on. Can be a constant, column, or function, and any combination of operators."))
+    }
+
     pub fn with_related_udf(mut self, related_udf: impl Into<String>) -> Self {
         let mut related = self.related_udfs.unwrap_or_default();
         related.push(related_udf.into());
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 6459e8f3f7d1..6d8f2be97e02 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -28,14 +28,14 @@ use std::{
 
 use arrow::datatypes::{DataType, Field};
 
-use datafusion_common::{not_impl_err, Result};
-use datafusion_functions_window_common::field::WindowUDFFieldArgs;
-
 use crate::expr::WindowFunction;
 use crate::{
     function::WindowFunctionSimplification, Documentation, Expr, PartitionEvaluator,
     Signature,
 };
+use datafusion_common::{not_impl_err, Result};
+use datafusion_functions_window_common::field::WindowUDFFieldArgs;
+use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
 
 /// Logical representation of a user-defined window function (UDWF)
 /// A UDWF is different from a UDF in that it is stateful across batches.
@@ -150,8 +150,11 @@ impl WindowUDF {
     }
 
     /// Return a `PartitionEvaluator` for evaluating this window function
-    pub fn partition_evaluator_factory(&self) -> Result<Box<dyn PartitionEvaluator>> {
-        self.inner.partition_evaluator()
+    pub fn partition_evaluator_factory(
+        &self,
+        partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
+        self.inner.partition_evaluator(partition_evaluator_args)
     }
 
     /// Returns the field of the final result of evaluating this window function.
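As a hedged illustration of the new `with_standard_argument` helper above (the `total` function, static, and getter are hypothetical, but they follow the `OnceLock` documentation pattern used elsewhere in this diff):

```
use std::sync::OnceLock;
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
use datafusion_expr::Documentation;

static TOTAL_DOC: OnceLock<Documentation> = OnceLock::new();

fn get_total_doc() -> &'static Documentation {
    TOTAL_DOC.get_or_init(|| {
        Documentation::builder()
            .with_doc_section(DOC_SECTION_GENERAL)
            .with_description("Computes the sum of all non-null input values.")
            .with_syntax_example("total(expression)")
            // Expands to: "expression: Numeric expression to operate on.
            // Can be a constant, column, or function, and any combination of operators."
            .with_standard_argument("expression", "Numeric")
            .build()
            .unwrap()
    })
}
```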
@@ -218,8 +221,9 @@ where /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt, Documentation}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; -/// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; /// /// #[derive(Debug, Clone)] /// struct SmoothIt { @@ -254,7 +258,12 @@ where /// fn name(&self) -> &str { "smooth_it" } /// fn signature(&self) -> &Signature { &self.signature } /// // The actual implementation would smooth the window -/// fn partition_evaluator(&self) -> Result> { unimplemented!() } +/// fn partition_evaluator( +/// &self, +/// _partition_evaluator_args: PartitionEvaluatorArgs, +/// ) -> Result> { +/// unimplemented!() +/// } /// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { /// if let Some(DataType::Int32) = field_args.get_input_type(0) { /// Ok(Field::new(field_args.name(), DataType::Int32, false)) @@ -294,7 +303,10 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn signature(&self) -> &Signature; /// Invoke the function, returning the [`PartitionEvaluator`] instance - fn partition_evaluator(&self) -> Result>; + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result>; /// Returns any aliases (alternate names) for this function. /// @@ -355,6 +367,10 @@ pub trait WindowUDFImpl: Debug + Send + Sync { } /// The [`Field`] of the final result of evaluating this window function. + /// + /// Call `field_args.name()` to get the fully qualified name for defining + /// the [`Field`]. For a complete example see the implementation in the + /// [Basic Example](WindowUDFImpl#basic-example) section. fn field(&self, field_args: WindowUDFFieldArgs) -> Result; /// Allows the window UDF to define a custom result ordering. 
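A hedged sketch of how an implementation might consume the new argument (`MyEvaluator` is a hypothetical type assumed to implement `PartitionEvaluator`; it is not part of this diff):

```
fn partition_evaluator(
    &self,
    partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
    // The args expose the input expressions and their types, plus the
    // reversal and IGNORE NULLS flags, at evaluator-construction time.
    let ignore_nulls = partition_evaluator_args.ignore_nulls();
    let input_types = partition_evaluator_args.input_types().to_vec();
    Ok(Box::new(MyEvaluator {
        ignore_nulls,
        input_types,
    }))
}
```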
@@ -464,8 +480,11 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
         self.inner.signature()
     }
 
-    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
-        self.inner.partition_evaluator()
+    fn partition_evaluator(
+        &self,
+        partition_evaluator_args: PartitionEvaluatorArgs,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
+        self.inner.partition_evaluator(partition_evaluator_args)
     }
 
     fn aliases(&self) -> &[String] {
@@ -546,6 +565,7 @@ mod test {
     use datafusion_common::Result;
     use datafusion_expr_common::signature::{Signature, Volatility};
     use datafusion_functions_window_common::field::WindowUDFFieldArgs;
+    use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
     use std::any::Any;
     use std::cmp::Ordering;
@@ -577,7 +597,10 @@ mod test {
         fn signature(&self) -> &Signature {
             &self.signature
         }
-        fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+        fn partition_evaluator(
+            &self,
+            _partition_evaluator_args: PartitionEvaluatorArgs,
+        ) -> Result<Box<dyn PartitionEvaluator>> {
             unimplemented!()
         }
         fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<Field> {
@@ -613,7 +636,10 @@ mod test {
         fn signature(&self) -> &Signature {
             &self.signature
         }
-        fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+        fn partition_evaluator(
+            &self,
+            _partition_evaluator_args: PartitionEvaluatorArgs,
+        ) -> Result<Box<dyn PartitionEvaluator>> {
             unimplemented!()
         }
         fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<Field> {
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 9bb53a1d04a0..02b36d0feab9 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -19,6 +19,7 @@
 use std::cmp::Ordering;
 use std::collections::{HashMap, HashSet};
+use std::ops::Deref;
 use std::sync::Arc;
 
 use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction};
@@ -60,7 +61,17 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()>
 /// Count the number of distinct exprs in a list of group by expressions. If the
 /// first element is a `GroupingSet` expression then it must be the only expr.
 pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
-    grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
+    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
+        if group_expr.len() > 1 {
+            return plan_err!(
+                "Invalid group by expressions, GroupingSet must be the only expression"
+            );
+        }
+        // Grouping sets have an additional internal column for the grouping id
+        Ok(grouping_set.distinct_expr().len() + 1)
+    } else {
+        grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
+    }
 }
 
 /// The [power set] (or powerset) of a set S is the set of all subsets of S, \
@@ -754,6 +765,15 @@ pub fn find_base_plan(input: &LogicalPlan) -> &LogicalPlan {
     match input {
         LogicalPlan::Window(window) => find_base_plan(&window.input),
         LogicalPlan::Aggregate(agg) => find_base_plan(&agg.input),
+        // [SqlToRel::try_process_unnest] will convert Expr(Unnest(Expr)) to Projection/Unnest/Projection
+        // We should expand the wildcard expression based on the input plan of the inner Projection.
+        LogicalPlan::Unnest(unnest) => {
+            if let LogicalPlan::Projection(projection) = unnest.input.deref() {
+                find_base_plan(&projection.input)
+            } else {
+                input
+            }
+        }
         LogicalPlan::Filter(filter) => {
             if filter.having {
                 // If a filter is used for a having clause, its input plan is an aggregation.
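To make the grouping-id bookkeeping above concrete, here is a hedged worked example (derived from the docs in this diff, not additional behavior):

```
// GROUP BY CUBE(a, b) produces the grouping sets {(a, b), (a), (b), ()}.
// grouping_set_expr_count sees the distinct exprs [a, b] and returns
// 2 + 1 = 3, counting the internal __grouping_id column. With two grouping
// expressions, Aggregate::grouping_id_type(2) is DataType::UInt8, and the
// bitmask values (least significant bit = rightmost expression, here `b`) are:
//   0b00 -> (a, b)  both included
//   0b01 -> (a)     b excluded, output as NULL
//   0b10 -> (b)     a excluded, output as NULL
//   0b11 -> ()      both excluded
```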
diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index a80718147c3a..7ac6fb7d167c 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -19,27 +19,6 @@ use datafusion_common::ScalarValue; use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; -/// Create an expression to represent the `rank` window function -pub fn rank() -> Expr { - Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Rank, vec![])) -} - -/// Create an expression to represent the `dense_rank` window function -pub fn dense_rank() -> Expr { - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::DenseRank, - vec![], - )) -} - -/// Create an expression to represent the `percent_rank` window function -pub fn percent_rank() -> Expr { - Expr::WindowFunction(WindowFunction::new( - BuiltInWindowFunction::PercentRank, - vec![], - )) -} - /// Create an expression to represent the `cume_dist` window function pub fn cume_dist() -> Expr { Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![])) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index fbbf4d303515..b03df0224089 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -23,6 +23,7 @@ pub mod bool_op; pub mod nulls; pub mod prim_op; +use arrow::array::new_empty_array; use arrow::{ array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, compute, @@ -405,6 +406,18 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { ) -> Result> { let num_rows = values[0].len(); + // If there are no rows, return empty arrays + if num_rows == 0 { + // create empty accumulator to get the state types + let empty_state = (self.factory)()?.state()?; + let empty_arrays = empty_state + .into_iter() + .map(|state_val| new_empty_array(&state_val.data_type())) + .collect::>(); + + return Ok(empty_arrays); + } + // Each row has its respective group let mut results = vec![]; for row_idx in 0..num_rows { diff --git a/datafusion/functions-aggregate-common/src/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs index 620a68e83ecd..e6723b54b372 100644 --- a/datafusion/functions-aggregate-common/src/tdigest.rs +++ b/datafusion/functions-aggregate-common/src/tdigest.rs @@ -644,7 +644,9 @@ impl TDigest { let max = cast_scalar_f64!(&state[3]); let min = cast_scalar_f64!(&state[4]); - assert!(max.total_cmp(&min).is_ge()); + if min.is_finite() && max.is_finite() { + assert!(max.total_cmp(&min).is_ge()); + } Self { max_size, diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index ce36e09bc25b..c5382c168f17 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -142,10 +142,7 @@ fn get_bit_and_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_GENERAL) .with_description("Computes the bitwise AND of all non-null input values.") .with_syntax_example("bit_and(expression)") - .with_argument( - "expression", - "Expression to operate on. 
Can be a constant, column, or function, and any combination of arithmetic operators.", - ) + .with_standard_argument("expression", "Integer") .build() .unwrap() }) @@ -159,10 +156,7 @@ fn get_bit_or_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_GENERAL) .with_description("Computes the bitwise OR of all non-null input values.") .with_syntax_example("bit_or(expression)") - .with_argument( - "expression", - "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.", - ) + .with_standard_argument("expression", "Integer") .build() .unwrap() }) @@ -174,12 +168,11 @@ fn get_bit_xor_doc() -> &'static Documentation { BIT_XOR_DOC.get_or_init(|| { Documentation::builder() .with_doc_section(DOC_SECTION_GENERAL) - .with_description("Computes the bitwise exclusive OR of all non-null input values.") - .with_syntax_example("bit_xor(expression)") - .with_argument( - "expression", - "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.", + .with_description( + "Computes the bitwise exclusive OR of all non-null input values.", ) + .with_syntax_example("bit_xor(expression)") + .with_standard_argument("expression", "Integer") .build() .unwrap() }) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 1ce1abe09ea8..7d2e8e66e2cd 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -17,21 +17,6 @@ //! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function //! [`Min`] and [`MinAccumulator`] accumulator for the `min` function -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - use arrow::array::{ ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array, @@ -1240,26 +1225,24 @@ impl Accumulator for SlidingMinAccumulator { } } -// -// Moving min and moving max -// The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs. - -// Keep track of the minimum or maximum value in a sliding window. -// -// `moving min max` provides one data structure for keeping track of the -// minimum value and one for keeping track of the maximum value in a sliding -// window. -// -// Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty, -// push to this stack all elements popped from first stack while updating their current min/max. Now pop from -// the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue, -// look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values. 
-//
-// The complexity of the operations are
-//  - O(1) for getting the minimum/maximum
-//  - O(1) for push
-//  - amortized O(1) for pop
-
+/// Keep track of the minimum value in a sliding window.
+///
+/// The implementation is taken from
+/// <https://github.com/spebern/moving_min_max/blob/master/src/lib.rs>
+///
+/// `moving min max` provides one data structure for keeping track of the
+/// minimum value and one for keeping track of the maximum value in a sliding
+/// window.
+///
+/// Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty,
+/// push to this stack all elements popped from first stack while updating their current min/max. Now pop from
+/// the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue,
+/// look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values.
+///
+/// The complexities of the operations are
+///  - O(1) for getting the minimum/maximum
+///  - O(1) for push
+///  - amortized O(1) for pop
+///
 /// ```
 /// # use datafusion_functions_aggregate::min_max::MovingMin;
 /// let mut moving_min = MovingMin::<i32>::new();
@@ -1375,6 +1358,11 @@ impl MovingMin {
         self.len() == 0
     }
 }
+
+/// Keep track of the maximum value in a sliding window.
+///
+/// See [`MovingMin`] for more details.
+///
 /// ```
 /// # use datafusion_functions_aggregate::min_max::MovingMax;
 /// let mut moving_max = MovingMax::<i32>::new();
diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs
index 3648ec0d1312..49a30344c212 100644
--- a/datafusion/functions-aggregate/src/variance.rs
+++ b/datafusion/functions-aggregate/src/variance.rs
@@ -18,22 +18,24 @@
 //! [`VarianceSample`]: variance sample aggregations.
 //! [`VariancePopulation`]: variance population aggregations.
-use std::{fmt::Debug, sync::Arc}; - use arrow::{ array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, buffer::NullBuffer, compute::kernels::cast, datatypes::{DataType, Field}, }; +use std::sync::OnceLock; +use std::{fmt::Debug, sync::Arc}; use datafusion_common::{ downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, - Accumulator, AggregateUDFImpl, GroupsAccumulator, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, + Volatility, }; use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, @@ -135,6 +137,26 @@ impl AggregateUDFImpl for VarianceSample { ) -> Result> { Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_variance_sample_doc()) + } +} + +static VARIANCE_SAMPLE_DOC: OnceLock = OnceLock::new(); + +fn get_variance_sample_doc() -> &'static Documentation { + VARIANCE_SAMPLE_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the statistical sample variance of a set of numbers.", + ) + .with_syntax_example("var(expression)") + .with_standard_argument("expression", "Numeric") + .build() + .unwrap() + }) } pub struct VariancePopulation { @@ -222,6 +244,25 @@ impl AggregateUDFImpl for VariancePopulation { StatsType::Population, ))) } + fn documentation(&self) -> Option<&Documentation> { + Some(get_variance_population_doc()) + } +} + +static VARIANCE_POPULATION_DOC: OnceLock = OnceLock::new(); + +fn get_variance_population_doc() -> &'static Documentation { + VARIANCE_POPULATION_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the statistical population variance of a set of numbers.", + ) + .with_syntax_example("var_pop(expression)") + .with_standard_argument("expression", "Numeric") + .build() + .unwrap() + }) } /// An accumulator to compute variance diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 79858041d3ca..cafa073f9191 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -17,6 +17,7 @@ //! [`ScalarUDFImpl`] definitions for `make_array` function. 
+use std::vec;
 use std::{any::Any, sync::Arc};
 
 use arrow::array::{ArrayData, Capacities, MutableArrayData};
@@ -26,11 +27,12 @@ use arrow_array::{
 use arrow_buffer::OffsetBuffer;
 use arrow_schema::DataType::{LargeList, List, Null};
 use arrow_schema::{DataType, Field};
-use datafusion_common::internal_err;
+use datafusion_common::{exec_err, internal_err};
 use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result};
-use datafusion_expr::type_coercion::binary::comparison_coercion;
+use datafusion_expr::binary::type_union_resolution;
 use datafusion_expr::TypeSignature;
 use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use itertools::Itertools;
 
 use crate::utils::make_scalar_function;
 
@@ -82,19 +84,12 @@ impl ScalarUDFImpl for MakeArray {
         match arg_types.len() {
             0 => Ok(empty_array_type()),
             _ => {
-                let mut expr_type = DataType::Null;
-                for arg_type in arg_types {
-                    if !arg_type.equals_datatype(&DataType::Null) {
-                        expr_type = arg_type.clone();
-                        break;
-                    }
-                }
-
-                if expr_type.is_null() {
-                    expr_type = DataType::Int64;
-                }
-
-                Ok(List(Arc::new(Field::new("item", expr_type, true))))
+                // At this point, all the types in the array should already have been coerced to a common type
+                Ok(List(Arc::new(Field::new(
+                    "item",
+                    arg_types[0].to_owned(),
+                    true,
+                ))))
             }
         }
     }
@@ -112,23 +107,68 @@ impl ScalarUDFImpl for MakeArray {
     }
 
     fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        let new_type = arg_types.iter().skip(1).try_fold(
-            arg_types.first().unwrap().clone(),
-            |acc, x| {
-                // The coerced types found by `comparison_coercion` are not guaranteed to be
-                // coercible for the arguments. `comparison_coercion` returns more loose
-                // types that can be coerced to both `acc` and `x` for comparison purpose.
-                // See `maybe_data_types` for the actual coercion.
-                let coerced_type = comparison_coercion(&acc, x);
-                if let Some(coerced_type) = coerced_type {
-                    Ok(coerced_type)
+        if let Some(new_type) = type_union_resolution(arg_types) {
+            // TODO: Move the logic to type_union_resolution if this applies to other functions as well
+            // Handle struct where we only change the data type but preserve the field name and nullability.
+            // Since the field name is the key of the struct, it shouldn't be updated to a common column name like "c0" or "c1"
+            let is_struct_and_has_same_key = are_all_struct_and_have_same_key(arg_types)?;
+            if is_struct_and_has_same_key {
+                let data_types: Vec<_> = if let DataType::Struct(fields) = &arg_types[0] {
+                    fields.iter().map(|f| f.data_type().to_owned()).collect()
                 } else {
-                    internal_err!("Coercion from {acc:?} to {x:?} failed.")
+                    return internal_err!("Struct type is checked in the previous function, so this should be unreachable");
+                };
+
+                let mut final_struct_types = vec![];
+                for s in arg_types {
+                    let mut new_fields = vec![];
+                    if let DataType::Struct(fields) = s {
+                        for (i, f) in fields.iter().enumerate() {
+                            let field = Arc::unwrap_or_clone(Arc::clone(f))
+                                .with_data_type(data_types[i].to_owned());
+                            new_fields.push(Arc::new(field));
+                        }
+                    }
+                    final_struct_types.push(DataType::Struct(new_fields.into()))
+                }
+                return Ok(final_struct_types);
+            }
+
+            if let DataType::FixedSizeList(field, _) = new_type {
+                Ok(vec![DataType::List(field); arg_types.len()])
+            } else if new_type.is_null() {
+                Ok(vec![DataType::Int64; arg_types.len()])
+            } else {
+                Ok(vec![new_type; arg_types.len()])
+            }
+        } else {
+            plan_err!(
+                "Failed to find a valid type among {:?} for {}",
+                arg_types,
+                self.name()
+            )
+        }
+    }
+}
+
+fn are_all_struct_and_have_same_key(data_types: &[DataType]) -> Result<bool> {
+    let mut keys_string: Option<String> = None;
+    for data_type in data_types {
+        if let DataType::Struct(fields) = data_type {
+            let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
+            if let Some(ref k) = keys_string {
+                if *k != keys {
+                    return exec_err!("Expected the same keys for struct types but got a mismatched pair: {} and {}", *k, keys);
                 }
-            },
-        )?;
-        Ok(vec![new_type; arg_types.len()])
+            } else {
+                keys_string = Some(keys);
+            }
+        } else {
+            return Ok(false);
+        }
     }
+
+    Ok(true)
 }
 
 // Empty array is a special case that is useful for many other array functions
diff --git a/datafusion/functions-window-common/Cargo.toml b/datafusion/functions-window-common/Cargo.toml
index 98b6f8c6dba5..b5df212b7d2a 100644
--- a/datafusion/functions-window-common/Cargo.toml
+++ b/datafusion/functions-window-common/Cargo.toml
@@ -39,3 +39,4 @@ path = "src/lib.rs"
 [dependencies]
 datafusion-common = { workspace = true }
+datafusion-physical-expr-common = { workspace = true }
diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs
index 2e4bcbbc83b9..53f9eb1c9ac6 100644
--- a/datafusion/functions-window-common/src/lib.rs
+++ b/datafusion/functions-window-common/src/lib.rs
@@ -19,3 +19,4 @@
 //!
 //! [DataFusion]: <https://crates.io/crates/datafusion>
 pub mod field;
+pub mod partition;
diff --git a/datafusion/functions-window-common/src/partition.rs b/datafusion/functions-window-common/src/partition.rs
new file mode 100644
index 000000000000..64786d2fe7c7
--- /dev/null
+++ b/datafusion/functions-window-common/src/partition.rs
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::arrow::datatypes::DataType;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::sync::Arc;
+
+/// Arguments passed when creating the state of a user-defined window
+/// function during physical execution.
+#[derive(Debug, Default)]
+pub struct PartitionEvaluatorArgs<'a> {
+    /// The expressions passed as arguments to the user-defined window
+    /// function.
+    input_exprs: &'a [Arc<dyn PhysicalExpr>],
+    /// The corresponding data types of expressions passed as arguments
+    /// to the user-defined window function.
+    input_types: &'a [DataType],
+    /// Set to `true` if the user-defined window function is reversed.
+    is_reversed: bool,
+    /// Set to `true` if `IGNORE NULLS` is specified.
+    ignore_nulls: bool,
+}
+
+impl<'a> PartitionEvaluatorArgs<'a> {
+    /// Create an instance of [`PartitionEvaluatorArgs`].
+    ///
+    /// # Arguments
+    ///
+    /// * `input_exprs` - The expressions passed as arguments
+    ///   to the user-defined window function.
+    /// * `input_types` - The data types corresponding to the
+    ///   arguments to the user-defined window function.
+    /// * `is_reversed` - Set to `true` if and only if the user-defined
+    ///   window function is reversible and is reversed.
+    /// * `ignore_nulls` - Set to `true` when `IGNORE NULLS` is
+    ///   specified.
+    ///
+    pub fn new(
+        input_exprs: &'a [Arc<dyn PhysicalExpr>],
+        input_types: &'a [DataType],
+        is_reversed: bool,
+        ignore_nulls: bool,
+    ) -> Self {
+        Self {
+            input_exprs,
+            input_types,
+            is_reversed,
+            ignore_nulls,
+        }
+    }
+
+    /// Returns the expressions passed as arguments to the user-defined
+    /// window function.
+    pub fn input_exprs(&self) -> &'a [Arc<dyn PhysicalExpr>] {
+        self.input_exprs
+    }
+
+    /// Returns the [`DataType`]s corresponding to the input expressions
+    /// to the user-defined window function.
+    pub fn input_types(&self) -> &'a [DataType] {
+        self.input_types
+    }
+
+    /// Returns `true` when the user-defined window function is
+    /// reversed, otherwise returns `false`.
+    pub fn is_reversed(&self) -> bool {
+        self.is_reversed
+    }
+
+    /// Returns `true` when `IGNORE NULLS` is specified, otherwise
+    /// returns `false`.
+    pub fn ignore_nulls(&self) -> bool {
+        self.ignore_nulls
+    }
+}
diff --git a/datafusion/functions-window/src/dense_rank.rs b/datafusion/functions-window/src/dense_rank.rs
new file mode 100644
index 000000000000..c969a7c46fc6
--- /dev/null
+++ b/datafusion/functions-window/src/dense_rank.rs
@@ -0,0 +1,205 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `dense_rank` window function implementation + +use std::any::Any; +use std::fmt::Debug; +use std::iter; +use std::ops::Range; +use std::sync::Arc; + +use crate::define_udwf_and_expr; +use crate::rank::RankState; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::array::UInt64Array; +use datafusion_common::arrow::compute::SortOptions; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; +use datafusion_common::utils::get_row_at_idx; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; + +define_udwf_and_expr!( + DenseRank, + dense_rank, + "Returns rank of the current row without gaps. This function counts peer groups" +); + +/// dense_rank expression +#[derive(Debug)] +pub struct DenseRank { + signature: Signature, +} + +impl DenseRank { + /// Create a new `dense_rank` function + pub fn new() -> Self { + Self { + signature: Signature::any(0, Volatility::Immutable), + } + } +} + +impl Default for DenseRank { + fn default() -> Self { + Self::new() + } +} + +impl WindowUDFImpl for DenseRank { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dense_rank" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + Ok(Box::::default()) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false)) + } + + fn sort_options(&self) -> Option { + Some(SortOptions { + descending: false, + nulls_first: false, + }) + } +} + +/// State for the `dense_rank` built-in window function. +#[derive(Debug, Default)] +struct DenseRankEvaluator { + state: RankState, +} + +impl PartitionEvaluator for DenseRankEvaluator { + fn is_causal(&self) -> bool { + // The dense_rank function doesn't need "future" values to emit results: + true + } + + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &Range, + ) -> Result { + let row_idx = range.start; + // There is no argument, values are order by column values (where rank is calculated) + let range_columns = values; + let last_rank_data = get_row_at_idx(range_columns, row_idx)?; + let new_rank_encountered = + if let Some(state_last_rank_data) = &self.state.last_rank_data { + // if rank data changes, new rank is encountered + state_last_rank_data != &last_rank_data + } else { + // First rank seen + true + }; + + if new_rank_encountered { + self.state.last_rank_data = Some(last_rank_data); + self.state.last_rank_boundary += self.state.current_group_count; + self.state.current_group_count = 1; + self.state.n_rank += 1; + } else { + // data is still in the same rank + self.state.current_group_count += 1; + } + + Ok(ScalarValue::UInt64(Some(self.state.n_rank as u64))) + } + + fn evaluate_all_with_rank( + &self, + _num_rows: usize, + ranks_in_partition: &[Range], + ) -> Result { + let result = Arc::new(UInt64Array::from_iter_values( + ranks_in_partition + .iter() + .zip(1u64..) 
+ .flat_map(|(range, rank)| { + let len = range.end - range.start; + iter::repeat(rank).take(len) + }), + )); + + Ok(result) + } + + fn supports_bounded_execution(&self) -> bool { + true + } + + fn include_rank(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::cast::as_uint64_array; + + fn test_with_rank(expr: &DenseRank, expected: Vec) -> Result<()> { + test_i32_result(expr, vec![0..2, 2..3, 3..6, 6..7, 7..8], expected) + } + + #[allow(clippy::single_range_in_vec_init)] + fn test_without_rank(expr: &DenseRank, expected: Vec) -> Result<()> { + test_i32_result(expr, vec![0..8], expected) + } + + fn test_i32_result( + expr: &DenseRank, + ranks: Vec>, + expected: Vec, + ) -> Result<()> { + let args = PartitionEvaluatorArgs::default(); + let result = expr + .partition_evaluator(args)? + .evaluate_all_with_rank(8, &ranks)?; + let result = as_uint64_array(&result)?; + let result = result.values(); + assert_eq!(expected, *result); + Ok(()) + } + + #[test] + fn test_dense_rank() -> Result<()> { + let r = DenseRank::default(); + test_without_rank(&r, vec![1; 8])?; + test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?; + Ok(()) + } +} diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs index 6e98bb091446..b72780990841 100644 --- a/datafusion/functions-window/src/lib.rs +++ b/datafusion/functions-window/src/lib.rs @@ -31,16 +31,27 @@ use datafusion_expr::WindowUDF; #[macro_use] pub mod macros; +pub mod dense_rank; +pub mod percent_rank; +pub mod rank; pub mod row_number; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::dense_rank::dense_rank; + pub use super::percent_rank::percent_rank; + pub use super::rank::rank; pub use super::row_number::row_number; } /// Returns all default window functions pub fn all_default_window_functions() -> Vec> { - vec![row_number::row_number_udwf()] + vec![ + row_number::row_number_udwf(), + rank::rank_udwf(), + dense_rank::dense_rank_udwf(), + percent_rank::percent_rank_udwf(), + ] } /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all( diff --git a/datafusion/functions-window/src/macros.rs b/datafusion/functions-window/src/macros.rs index 843d8ecb38cc..e934f883b101 100644 --- a/datafusion/functions-window/src/macros.rs +++ b/datafusion/functions-window/src/macros.rs @@ -45,6 +45,7 @@ /// # /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; /// # use datafusion_functions_window::get_or_init_udwf; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// /// Defines the `simple_udwf()` user-defined window function. /// get_or_init_udwf!( @@ -80,6 +81,7 @@ /// # } /// # fn partition_evaluator( /// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -145,6 +147,8 @@ macro_rules! get_or_init_udwf { /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf}; /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// /// # get_or_init_udwf!( /// # RowNumber, /// # row_number, @@ -193,6 +197,7 @@ macro_rules! 
get_or_init_udwf { /// # } /// # fn partition_evaluator( /// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -216,6 +221,7 @@ macro_rules! get_or_init_udwf { /// # use datafusion_common::arrow::datatypes::Field; /// # use datafusion_common::ScalarValue; /// # use datafusion_expr::{col, lit}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// # get_or_init_udwf!(Lead, lead, "user-defined window function"); /// # @@ -278,6 +284,7 @@ macro_rules! get_or_init_udwf { /// # } /// # fn partition_evaluator( /// # &self, +/// # partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -355,6 +362,7 @@ macro_rules! create_udwf_expr { /// # /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; /// # use datafusion_functions_window::{define_udwf_and_expr, get_or_init_udwf, create_udwf_expr}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// /// 1. Defines the `simple_udwf()` user-defined window function. /// /// @@ -397,6 +405,7 @@ macro_rules! create_udwf_expr { /// # } /// # fn partition_evaluator( /// # &self, +/// # partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -415,6 +424,7 @@ macro_rules! create_udwf_expr { /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; /// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// /// 1. Defines the `row_number_udwf()` user-defined window function. /// /// @@ -459,6 +469,7 @@ macro_rules! create_udwf_expr { /// # } /// # fn partition_evaluator( /// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -484,6 +495,7 @@ macro_rules! create_udwf_expr { /// # use datafusion_common::arrow::datatypes::Field; /// # use datafusion_common::ScalarValue; /// # use datafusion_expr::{col, lit}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// /// 1. Defines the `lead_udwf()` user-defined window function. /// /// @@ -543,6 +555,7 @@ macro_rules! create_udwf_expr { /// # } /// # fn partition_evaluator( /// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } @@ -570,6 +583,7 @@ macro_rules! create_udwf_expr { /// # use datafusion_common::arrow::datatypes::Field; /// # use datafusion_common::ScalarValue; /// # use datafusion_expr::{col, lit}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// # /// /// 1. Defines the `lead_udwf()` user-defined window function. /// /// @@ -630,6 +644,7 @@ macro_rules! 
create_udwf_expr { /// # } /// # fn partition_evaluator( /// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } diff --git a/datafusion/functions-window/src/percent_rank.rs b/datafusion/functions-window/src/percent_rank.rs new file mode 100644 index 000000000000..147959f69be9 --- /dev/null +++ b/datafusion/functions-window/src/percent_rank.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `percent_rank` window function implementation + +use std::any::Any; +use std::fmt::Debug; +use std::iter; +use std::ops::Range; +use std::sync::Arc; + +use crate::define_udwf_and_expr; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::array::Float64Array; +use datafusion_common::arrow::compute::SortOptions; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; + +define_udwf_and_expr!( + PercentRank, + percent_rank, + "Returns the relative rank of the current row: (rank - 1) / (total rows - 1)" +); + +/// percent_rank expression +#[derive(Debug)] +pub struct PercentRank { + signature: Signature, +} + +impl PercentRank { + /// Create a new `percent_rank` function + pub fn new() -> Self { + Self { + signature: Signature::any(0, Volatility::Immutable), + } + } +} + +impl Default for PercentRank { + fn default() -> Self { + Self::new() + } +} + +impl WindowUDFImpl for PercentRank { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "percent_rank" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + Ok(Box::::default()) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, false)) + } + + fn sort_options(&self) -> Option { + Some(SortOptions { + descending: false, + nulls_first: false, + }) + } +} + +/// State for the `percent_rank` built-in window function. 
+#[derive(Debug, Default)]
+struct PercentRankEvaluator {}
+
+impl PartitionEvaluator for PercentRankEvaluator {
+    fn is_causal(&self) -> bool {
+        // The percent_rank function needs the total row count of the
+        // partition for its denominator, so it cannot emit a result until
+        // the whole partition (including "future" values) has been seen:
+        false
+    }
+
+    fn evaluate(
+        &mut self,
+        _values: &[ArrayRef],
+        _range: &Range<usize>,
+    ) -> Result<ScalarValue> {
+        exec_err!("Can not execute PERCENT_RANK in a streaming fashion")
+    }
+
+    fn evaluate_all_with_rank(
+        &self,
+        num_rows: usize,
+        ranks_in_partition: &[Range<usize>],
+    ) -> Result<ArrayRef> {
+        let denominator = num_rows as f64;
+        let result =
+            // Returns the relative rank of the current row, that is
+            // (rank - 1) / (total partition rows - 1). The value thus
+            // ranges from 0 to 1 inclusive.
+            Arc::new(Float64Array::from_iter_values(
+                ranks_in_partition
+                    .iter()
+                    .scan(0_u64, |acc, range| {
+                        let len = range.end - range.start;
+                        let value = (*acc as f64) / (denominator - 1.0).max(1.0);
+                        let result = iter::repeat(value).take(len);
+                        *acc += len as u64;
+                        Some(result)
+                    })
+                    .flatten(),
+            ));
+
+        Ok(result)
+    }
+
+    fn supports_bounded_execution(&self) -> bool {
+        false
+    }
+
+    fn include_rank(&self) -> bool {
+        true
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion_common::cast::as_float64_array;
+
+    fn test_f64_result(
+        expr: &PercentRank,
+        num_rows: usize,
+        ranks: Vec<Range<usize>>,
+        expected: Vec<f64>,
+    ) -> Result<()> {
+        let args = PartitionEvaluatorArgs::default();
+        let result = expr
+            .partition_evaluator(args)?
+            .evaluate_all_with_rank(num_rows, &ranks)?;
+        let result = as_float64_array(&result)?;
+        let result = result.values();
+        assert_eq!(expected, *result);
+        Ok(())
+    }
+
+    #[test]
+    #[allow(clippy::single_range_in_vec_init)]
+    fn test_percent_rank() -> Result<()> {
+        let r = PercentRank::default();
+
+        // empty case
+        let expected = vec![0.0; 0];
+        test_f64_result(&r, 0, vec![0..0; 0], expected)?;
+
+        // singleton case
+        let expected = vec![0.0];
+        test_f64_result(&r, 1, vec![0..1], expected)?;
+
+        // uniform case
+        let expected = vec![0.0; 7];
+        test_f64_result(&r, 7, vec![0..7], expected)?;
+
+        // non-trivial case
+        let expected = vec![0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5];
+        test_f64_result(&r, 7, vec![0..3, 3..7], expected)?;
+
+        Ok(())
+    }
+}
diff --git a/datafusion/functions-window/src/rank.rs b/datafusion/functions-window/src/rank.rs
new file mode 100644
index 000000000000..c52dec9061ba
--- /dev/null
+++ b/datafusion/functions-window/src/rank.rs
@@ -0,0 +1,220 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//!
`rank` window function implementation + +use std::any::Any; +use std::fmt::Debug; +use std::iter; +use std::ops::Range; +use std::sync::Arc; + +use crate::define_udwf_and_expr; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::array::UInt64Array; +use datafusion_common::arrow::compute::SortOptions; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; +use datafusion_common::utils::get_row_at_idx; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; + +define_udwf_and_expr!( + Rank, + rank, + "Returns rank of the current row with gaps. Same as `row_number` of its first peer" +); + +/// rank expression +#[derive(Debug)] +pub struct Rank { + signature: Signature, +} + +impl Rank { + /// Create a new `rank` function + pub fn new() -> Self { + Self { + signature: Signature::any(0, Volatility::Immutable), + } + } +} + +impl Default for Rank { + fn default() -> Self { + Self::new() + } +} + +impl WindowUDFImpl for Rank { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "rank" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + Ok(Box::::default()) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false)) + } + + fn sort_options(&self) -> Option { + Some(SortOptions { + descending: false, + nulls_first: false, + }) + } +} + +/// State for the RANK(rank) built-in window function. +#[derive(Debug, Clone, Default)] +pub struct RankState { + /// The last values for rank as these values change, we increase n_rank + pub last_rank_data: Option>, + /// The index where last_rank_boundary is started + pub last_rank_boundary: usize, + /// Keep the number of entries in current rank + pub current_group_count: usize, + /// Rank number kept from the start + pub n_rank: usize, +} + +/// State for the `rank` built-in window function. 
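Before the evaluator, a standalone sketch of the "rank with gaps" rule stated in `define_udwf_and_expr!` above: every row in a peer group receives the row number of the group's first row (plain Rust over peer-group ranges, not the evaluator API):

```rust
use std::ops::Range;

fn rank_with_gaps(peer_groups: &[Range<usize>]) -> Vec<u64> {
    let mut out = Vec::new();
    let mut next_rank = 1u64; // row number of the first row in the next group
    for group in peer_groups {
        let len = (group.end - group.start) as u64;
        out.extend(std::iter::repeat(next_rank).take(len as usize));
        next_rank += len; // skipping over the peers is what creates the gaps
    }
    out
}

fn main() {
    // Matches the `test_rank` expectation below: peer groups of 2, 1, 3, 1, 1 rows.
    assert_eq!(
        rank_with_gaps(&[0..2, 2..3, 3..6, 6..7, 7..8]),
        vec![1, 1, 3, 4, 4, 4, 7, 8]
    );
}
```

The `next_rank += len` step is what produces the 1, 1, 3, ... sequence that distinguishes `rank` from `dense_rank`.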
+#[derive(Debug, Default)] +struct RankEvaluator { + state: RankState, +} + +impl PartitionEvaluator for RankEvaluator { + fn is_causal(&self) -> bool { + // The rank function doesn't need "future" values to emit results: + true + } + + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &Range, + ) -> Result { + let row_idx = range.start; + // There is no argument, values are order by column values (where rank is calculated) + let range_columns = values; + let last_rank_data = get_row_at_idx(range_columns, row_idx)?; + let new_rank_encountered = + if let Some(state_last_rank_data) = &self.state.last_rank_data { + // if rank data changes, new rank is encountered + state_last_rank_data != &last_rank_data + } else { + // First rank seen + true + }; + if new_rank_encountered { + self.state.last_rank_data = Some(last_rank_data); + self.state.last_rank_boundary += self.state.current_group_count; + self.state.current_group_count = 1; + self.state.n_rank += 1; + } else { + // data is still in the same rank + self.state.current_group_count += 1; + } + + Ok(ScalarValue::UInt64(Some( + self.state.last_rank_boundary as u64 + 1, + ))) + } + + fn evaluate_all_with_rank( + &self, + _num_rows: usize, + ranks_in_partition: &[Range], + ) -> Result { + let result = Arc::new(UInt64Array::from_iter_values( + ranks_in_partition + .iter() + .scan(1_u64, |acc, range| { + let len = range.end - range.start; + let result = iter::repeat(*acc).take(len); + *acc += len as u64; + Some(result) + }) + .flatten(), + )); + + Ok(result) + } + + fn supports_bounded_execution(&self) -> bool { + true + } + + fn include_rank(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::cast::as_uint64_array; + + fn test_with_rank(expr: &Rank, expected: Vec) -> Result<()> { + test_i32_result(expr, vec![0..2, 2..3, 3..6, 6..7, 7..8], expected) + } + + #[allow(clippy::single_range_in_vec_init)] + fn test_without_rank(expr: &Rank, expected: Vec) -> Result<()> { + test_i32_result(expr, vec![0..8], expected) + } + + fn test_i32_result( + expr: &Rank, + ranks: Vec>, + expected: Vec, + ) -> Result<()> { + let args = PartitionEvaluatorArgs::default(); + let result = expr + .partition_evaluator(args)? + .evaluate_all_with_rank(8, &ranks)?; + let result = as_uint64_array(&result)?; + let result = result.values(); + assert_eq!(expected, *result); + Ok(()) + } + + #[test] + fn test_rank() -> Result<()> { + let r = Rank::default(); + test_without_rank(&r, vec![1; 8])?; + test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?; + Ok(()) + } +} diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs index c903f6778ae8..56af14fb84ae 100644 --- a/datafusion/functions-window/src/row_number.rs +++ b/datafusion/functions-window/src/row_number.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expression for `row_number` that can evaluated at runtime during query execution +//! 
`row_number` window function implementation use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::array::UInt64Array; @@ -28,6 +28,7 @@ use datafusion_expr::{ Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use field::WindowUDFFieldArgs; use std::any::Any; use std::fmt::Debug; @@ -89,7 +90,10 @@ impl WindowUDFImpl for RowNumber { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { Ok(Box::::default()) } @@ -162,7 +166,7 @@ mod tests { let num_rows = values.len(); let actual = RowNumber::default() - .partition_evaluator()? + .partition_evaluator(PartitionEvaluatorArgs::default())? .evaluate_all(&[values], num_rows)?; let actual = as_uint64_array(&actual)?; @@ -178,7 +182,7 @@ mod tests { let num_rows = values.len(); let actual = RowNumber::default() - .partition_evaluator()? + .partition_evaluator(PartitionEvaluatorArgs::default())? .evaluate_all(&[values], num_rows)?; let actual = as_uint64_array(&actual)?; diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 1c69f9c9b2f3..cf64c03766cb 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -107,5 +107,6 @@ pub fn functions() -> Vec> { get_field(), coalesce(), version(), + r#struct(), ] } diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index bdddbb81beab..5eea7dd48de8 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -57,6 +57,7 @@ fn struct_expr(args: &[ColumnarValue]) -> Result { #[derive(Debug)] pub struct StructFunc { signature: Signature, + aliases: Vec, } impl Default for StructFunc { @@ -69,6 +70,7 @@ impl StructFunc { pub fn new() -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![String::from("row")], } } } @@ -81,6 +83,10 @@ impl ScalarUDFImpl for StructFunc { "struct" } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn signature(&self) -> &Signature { &self.signature } diff --git a/datafusion/functions/src/crypto/digest.rs b/datafusion/functions/src/crypto/digest.rs index c9dd3c1f56a2..9ec07b1cab53 100644 --- a/datafusion/functions/src/crypto/digest.rs +++ b/datafusion/functions/src/crypto/digest.rs @@ -19,10 +19,12 @@ use super::basic::{digest, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, TypeSignature::*, Volatility, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::*, Volatility, }; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct DigestFunc { @@ -69,4 +71,48 @@ impl ScalarUDFImpl for DigestFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { digest(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_digest_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_digest_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description( + "Computes the binary hash of an expression using the specified algorithm.", + ) + 
.with_syntax_example("digest(expression, algorithm)") + .with_sql_example( + r#"```sql +> select digest('foo', 'sha256'); ++------------------------------------------+ +| digest(Utf8("foo"), Utf8("sha256")) | ++------------------------------------------+ +| | ++------------------------------------------+ +```"#, + ) + .with_standard_argument( + "expression", "String") + .with_argument( + "algorithm", + "String expression specifying algorithm to use. Must be one of: + +- md5 +- sha224 +- sha256 +- sha384 +- sha512 +- blake2s +- blake2b +- blake3", + ) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index ccb6fbba80aa..f273c9d28c23 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -19,8 +19,12 @@ use crate::crypto::basic::md5; use arrow::datatypes::DataType; use datafusion_common::{plan_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct Md5Func { @@ -84,4 +88,32 @@ impl ScalarUDFImpl for Md5Func { fn invoke(&self, args: &[ColumnarValue]) -> Result { md5(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_md5_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_md5_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes an MD5 128-bit checksum for a string expression.") + .with_syntax_example("md5(expression)") + .with_sql_example( + r#"```sql +> select md5('foo'); ++-------------------------------------+ +| md5(Utf8("foo")) | ++-------------------------------------+ +| | ++-------------------------------------+ +```"#, + ) + .with_standard_argument("expression", "String") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/sha224.rs b/datafusion/functions/src/crypto/sha224.rs index df3045f22cf5..868c8cdc3558 100644 --- a/datafusion/functions/src/crypto/sha224.rs +++ b/datafusion/functions/src/crypto/sha224.rs @@ -58,8 +58,17 @@ fn get_sha224_doc() -> &'static Documentation { .with_doc_section(DOC_SECTION_HASHING) .with_description("Computes the SHA-224 hash of a binary string.") .with_syntax_example("sha224(expression)") - .with_argument("expression", - "String expression to operate on. 
Can be a constant, column, or function, and any combination of string operators.") + .with_sql_example( + r#"```sql +> select sha224('foo'); ++------------------------------------------+ +| sha224(Utf8("foo")) | ++------------------------------------------+ +| | ++------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", "String") .build() .unwrap() }) diff --git a/datafusion/functions/src/crypto/sha256.rs b/datafusion/functions/src/crypto/sha256.rs index 0a3f3b26e431..99a470efbc1f 100644 --- a/datafusion/functions/src/crypto/sha256.rs +++ b/datafusion/functions/src/crypto/sha256.rs @@ -19,8 +19,12 @@ use super::basic::{sha256, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA256Func { @@ -60,7 +64,36 @@ impl ScalarUDFImpl for SHA256Func { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result { sha256(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha256_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_sha256_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-256 hash of a binary string.") + .with_syntax_example("sha256(expression)") + .with_sql_example( + r#"```sql +> select sha256('foo'); ++--------------------------------------+ +| sha256(Utf8("foo")) | ++--------------------------------------+ +| | ++--------------------------------------+ +```"#, + ) + .with_standard_argument("expression", "String") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/sha384.rs b/datafusion/functions/src/crypto/sha384.rs index c3f7845ce7bd..afe2db7478f7 100644 --- a/datafusion/functions/src/crypto/sha384.rs +++ b/datafusion/functions/src/crypto/sha384.rs @@ -19,8 +19,12 @@ use super::basic::{sha384, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA384Func { @@ -60,7 +64,36 @@ impl ScalarUDFImpl for SHA384Func { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result { sha384(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha384_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_sha384_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-384 hash of a binary string.") + .with_syntax_example("sha384(expression)") + .with_sql_example( + r#"```sql +> select sha384('foo'); ++-----------------------------------------+ +| sha384(Utf8("foo")) | 
++-----------------------------------------+ +| | ++-----------------------------------------+ +```"#, + ) + .with_standard_argument("expression", "String") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/sha512.rs b/datafusion/functions/src/crypto/sha512.rs index dc3bfac9d8bd..c88579fd08ee 100644 --- a/datafusion/functions/src/crypto/sha512.rs +++ b/datafusion/functions/src/crypto/sha512.rs @@ -19,8 +19,12 @@ use super::basic::{sha512, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA512Func { @@ -60,7 +64,36 @@ impl ScalarUDFImpl for SHA512Func { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result { sha512(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha512_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_sha512_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-512 hash of a binary string.") + .with_syntax_example("sha512(expression)") + .with_sql_example( + r#"```sql +> select sha512('foo'); ++-------------------------------------------+ +| sha512(Utf8("foo")) | ++-------------------------------------------+ +| | ++-------------------------------------------+ +```"#, + ) + .with_argument("expression", "String") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 176d7f8bbcbf..2803fd042b99 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -111,10 +111,7 @@ Note: `to_date` returns Date32, which represents its values as the number of day Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) "#) - .with_argument( - "expression", - "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.", - ) + .with_standard_argument("expression", "String") .with_argument( "format_n", "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index bb680f3c67de..81be5552666d 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -92,9 +92,6 @@ pub mod macros; pub mod string; make_stub_package!(string, "string_expressions"); -#[cfg(feature = "string_expressions")] -mod regexp_common; - /// Core datafusion expressions /// Enabled via feature flag `core_expressions` #[cfg(feature = "core_expressions")] diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index e47818bc86a4..e850673ef8af 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -284,7 +284,7 @@ macro_rules! 
make_math_binary_udf { use arrow::datatypes::DataType; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; - use datafusion_expr::TypeSignature::*; + use datafusion_expr::TypeSignature; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] @@ -298,8 +298,8 @@ macro_rules! make_math_binary_udf { Self { signature: Signature::one_of( vec![ - Exact(vec![Float32, Float32]), - Exact(vec![Float64, Float64]), + TypeSignature::Exact(vec![Float32, Float32]), + TypeSignature::Exact(vec![Float64, Float64]), ], Volatility::Immutable, ), diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 889e3761d26c..07ff8e2166ff 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -57,10 +57,8 @@ fn get_log_doc() -> &'static Documentation { .with_description("Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number.") .with_syntax_example(r#"log(base, numeric_expression) log(numeric_expression)"#) - .with_argument("base", - "Base numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.") - .with_argument("numeric_expression", - "Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.") + .with_standard_argument("base", "Base numeric") + .with_standard_argument("numeric_expression", "Numeric") .build() .unwrap() }) diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 2bd704a7de2e..b02839b40bd9 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -19,10 +19,9 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, TypeSignature}; use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; -use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; @@ -43,7 +42,10 @@ impl IsNanFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Float32]), Exact(vec![Float64])], + vec![ + TypeSignature::Exact(vec![Float32]), + TypeSignature::Exact(vec![Float64]), + ], Volatility::Immutable, ), } diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 5b790fb56ddf..831f983d5916 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -25,10 +25,9 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDF}; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDF, TypeSignature}; use arrow::array::{ArrayRef, Float64Array, Int64Array}; -use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; @@ -52,7 +51,10 @@ impl PowerFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], + vec![ + TypeSignature::Exact(vec![Int64, Int64]), + TypeSignature::Exact(vec![Float64, Float64]), + ], Volatility::Immutable, ), aliases: vec![String::from("pow")], diff --git 
a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 2abb4a9376c5..a698913fff54 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -16,7 +16,7 @@ // under the License. //! Regx expressions -use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; +use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; use datafusion_common::exec_err; @@ -26,8 +26,7 @@ use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -67,10 +66,8 @@ SELECT regexp_like('aBc', '(b|d)', 'i'); ``` Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) "#) - .with_argument("str", - "String expression to operate on. Can be a constant, column, or function, and any combination of string operators.") - .with_argument("regexp", - "Regular expression to test against the string expression. Can be a constant, column, or function.") + .with_standard_argument("str", "String") + .with_standard_argument("regexp","Regular") .with_argument("flags", r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case @@ -89,10 +86,10 @@ impl RegexpLikeFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]), + TypeSignature::Exact(vec![Utf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]), ], Volatility::Immutable, ), @@ -208,7 +205,8 @@ pub fn regexp_like(args: &[ArrayRef]) -> Result { 2 => { let values = as_generic_string_array::(&args[0])?; let regex = as_generic_string_array::(&args[1])?; - let array = regexp::regexp_is_match_utf8(values, regex, None) + let flags: Option<&GenericStringArray> = None; + let array = regexp::regexp_is_match(values, regex, flags) .map_err(|e| arrow_datafusion_err!(e))?; Ok(Arc::new(array) as ArrayRef) @@ -222,7 +220,7 @@ pub fn regexp_like(args: &[ArrayRef]) -> Result { return plan_err!("regexp_like() does not support the \"global\" option"); } - let array = regexp::regexp_is_match_utf8(values, regex, Some(flags)) + let array = regexp::regexp_is_match(values, regex, Some(flags)) .map_err(|e| arrow_datafusion_err!(e))?; Ok(Arc::new(array) as ArrayRef) diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 498b591620ee..bfec97f92c36 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -26,8 +26,7 @@ use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, TypeSignature}; 
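The `regexp_like` change above swaps DataFusion's private `regexp_is_match_utf8` helper for the `regexp_is_match` kernel that arrow-rs 53.1 now exports. Because that kernel is generic over the (optional) flags array type, a bare `None` leaves the type parameter unconstrained, which is why the diff pins it with an explicit annotation. A hedged sketch of the call shape, assuming the arrow-rs 53.1 signature used above (the concrete array types are illustrative):

```rust
use arrow::array::{BooleanArray, GenericStringArray, StringArray};
use arrow::compute::kernels::regexp;

fn main() {
    let values = StringArray::from(vec!["aBc", "abc"]);
    let patterns = StringArray::from(vec!["(b|d)", "(b|d)"]);

    // Pin the generic flags parameter even though no flags are passed,
    // as the updated `regexp_like` does.
    let flags: Option<&GenericStringArray<i32>> = None;
    let matched = regexp::regexp_is_match(&values, &patterns, flags)
        .expect("patterns are valid regexes");

    // Case-sensitive without flags: only "abc" contains a lowercase 'b'.
    assert_eq!(matched, BooleanArray::from(vec![false, true]));
}
```

With the `'i'` flag from the SQL example in the doc above, the first row would match as well.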
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; @@ -53,10 +52,10 @@ impl RegexpMatchFunc { // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`. // If that fails, it proceeds to `(LargeUtf8, Utf8)`. // TODO: Native support Utf8View for regexp_match. - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]), + TypeSignature::Exact(vec![Utf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]), ], Volatility::Immutable, ), diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 3eb72a1fb5f5..bce8752af28b 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -33,7 +33,7 @@ use datafusion_common::{ }; use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; +use datafusion_expr::TypeSignature; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use regex::Regex; use std::any::Any; @@ -56,10 +56,10 @@ impl RegexpReplaceFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![Utf8View, Utf8, Utf8]), - Exact(vec![Utf8, Utf8, Utf8, Utf8]), - Exact(vec![Utf8View, Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8View, Utf8, Utf8, Utf8]), ], Volatility::Immutable, ), diff --git a/datafusion/functions/src/regexp_common.rs b/datafusion/functions/src/regexp_common.rs deleted file mode 100644 index 748c1a294f97..000000000000 --- a/datafusion/functions/src/regexp_common.rs +++ /dev/null @@ -1,123 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Common utilities for implementing regex functions - -use crate::string::common::StringArrayType; - -use arrow::array::{Array, ArrayDataBuilder, BooleanArray}; -use arrow::datatypes::DataType; -use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; -use datafusion_common::DataFusionError; -use regex::Regex; - -use std::collections::HashMap; - -#[cfg(doc)] -use arrow::array::{LargeStringArray, StringArray, StringViewArray}; -/// Perform SQL `array ~ regex_array` operation on -/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. -/// -/// If `regex_array` element has an empty value, the corresponding result value is always true. 
-/// -/// `flags_array` are optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] flag, -/// which allow special search modes, such as case-insensitive and multi-line mode. -/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) -/// for more information. -/// -/// It is inspired / copied from `regexp_is_match_utf8` [arrow-rs]. -/// -/// Can remove when is implemented upstream -/// -/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 -pub fn regexp_is_match_utf8<'a, S1, S2, S3>( - array: &'a S1, - regex_array: &'a S2, - flags_array: Option<&'a S3>, -) -> datafusion_common::Result -where - &'a S1: StringArrayType<'a>, - &'a S2: StringArrayType<'a>, - &'a S3: StringArrayType<'a>, -{ - if array.len() != regex_array.len() { - return Err(DataFusionError::Execution( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); - - let mut patterns: HashMap = HashMap::new(); - let mut result = BooleanBufferBuilder::new(array.len()); - - let complete_pattern = match flags_array { - Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( - |(pattern, flags)| { - pattern.map(|pattern| match flags { - Some(flag) => format!("(?{flag}){pattern}"), - None => pattern.to_string(), - }) - }, - )) as Box>>, - None => Box::new( - regex_array - .iter() - .map(|pattern| pattern.map(|pattern| pattern.to_string())), - ), - }; - - array - .iter() - .zip(complete_pattern) - .map(|(value, pattern)| { - match (value, pattern) { - (Some(_), Some(pattern)) if pattern == *"" => { - result.append(true); - } - (Some(value), Some(pattern)) => { - let existing_pattern = patterns.get(&pattern); - let re = match existing_pattern { - Some(re) => re, - None => { - let re = Regex::new(pattern.as_str()).map_err(|e| { - DataFusionError::Execution(format!( - "Regular expression did not compile: {e:?}" - )) - })?; - patterns.entry(pattern).or_insert(re) - } - }; - result.append(re.is_match(value)); - } - _ => result.append(false), - } - Ok(()) - }) - .collect::, DataFusionError>>()?; - - let data = unsafe { - ArrayDataBuilder::new(DataType::Boolean) - .len(array.len()) - .buffers(vec![result.into()]) - .nulls(nulls) - .build_unchecked() - }; - - Ok(BooleanArray::from(data)) -} diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index d01c6631e9dd..8d61661f97b8 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -26,24 +26,6 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_ascii_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder() - .with_doc_section(DOC_SECTION_STRING) - .with_description("Returns the ASCII value of the first character in a string.") - .with_syntax_example("ascii(str)") - .with_argument( - "str", - "String expression to operate on. 
Can be a constant, column, or function that evaluates to or can be coerced to a Utf8, LargeUtf8 or a Utf8View.", - ) - .with_related_udf("chr") - .build() - .unwrap() - }) -} - #[derive(Debug)] pub struct AsciiFunc { signature: Signature, @@ -57,13 +39,8 @@ impl Default for AsciiFunc { impl AsciiFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8, Utf8View], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -96,6 +73,39 @@ impl ScalarUDFImpl for AsciiFunc { } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_ascii_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description( + "Returns the Unicode character code of the first character in a string.", + ) + .with_syntax_example("ascii(str)") + .with_sql_example( + r#"```sql +> select ascii('abc'); ++--------------------+ +| ascii(Utf8("abc")) | ++--------------------+ +| 97 | ++--------------------+ +> select ascii('🚀'); ++-------------------+ +| ascii(Utf8("🚀")) | ++-------------------+ +| 128640 | ++-------------------+ +```"#, + ) + .with_standard_argument("str", "String") + .with_related_udf("chr") + .build() + .unwrap() + }) +} + fn calculate_ascii<'a, V>(array: V) -> Result where V: ArrayAccessor, diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 65ec1a4a7734..7d162e7d411b 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -15,17 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::compute::kernels::length::bit_length; use arrow::datatypes::DataType; +use std::any::Any; +use std::sync::OnceLock; +use crate::utils::utf8_to_int_type; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::utf8_to_int_type; - #[derive(Debug)] pub struct BitLengthFunc { signature: Signature, @@ -39,13 +39,8 @@ impl Default for BitLengthFunc { impl BitLengthFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -88,4 +83,34 @@ impl ScalarUDFImpl for BitLengthFunc { }, } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_bit_length_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_bit_length_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the bit length of a string.") + .with_syntax_example("bit_length(str)") + .with_sql_example( + r#"```sql +> select bit_length('datafusion'); ++--------------------------------+ +| bit_length(Utf8("datafusion")) | ++--------------------------------+ +| 80 | ++--------------------------------+ +```"#, + ) + .with_standard_argument("str", "String") + .with_related_udf("length") + .with_related_udf("octet_length") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 
0e992ff27fd3..82b7599f0735 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -15,18 +15,18 @@ // specific language governing permissions and limitations // under the License. +use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; -use std::any::Any; - use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; - -use crate::string::common::*; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use std::any::Any; +use std::sync::OnceLock; /// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' @@ -49,18 +49,9 @@ impl Default for BTrimFunc { impl BTrimFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![ - // Planner attempts coercion to the target type starting with the most preferred candidate. - // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`. - // If that fails, it proceeds to `(Utf8, Utf8)`. - Exact(vec![Utf8View, Utf8View]), - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8View]), - Exact(vec![Utf8]), - ], + vec![TypeSignature::String(2), TypeSignature::String(1)], Volatility::Immutable, ), aliases: vec![String::from("trim")], @@ -109,6 +100,35 @@ impl ScalarUDFImpl for BTrimFunc { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_btrim_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_btrim_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string.") + .with_syntax_example("btrim(str[, trim_str])") + .with_sql_example(r#"```sql +> select btrim('__datafusion____', '_'); ++-------------------------------------------+ +| btrim(Utf8("__datafusion____"),Utf8("_")) | ++-------------------------------------------+ +| datafusion | ++-------------------------------------------+ +```"#) + .with_standard_argument("str", "String") + .with_argument("trim_str", "String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._") + .with_related_udf("ltrim") + .with_related_udf("rtrim") + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index 4da7dc01594d..ae0900af37d3 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -16,7 +16,7 @@ // under the License. 
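Several signatures in this diff (`btrim` just above, plus `ascii`, `contains`, `ends_with`, and others below) collapse hand-written `Exact` coercion lists into the `TypeSignature::String` / `Signature::string` shorthand, which accepts any string-coercible argument (Utf8, LargeUtf8, Utf8View). A short sketch of the two shapes, assuming the `datafusion-expr` constructors used in this diff:

```rust
use datafusion_expr::{Signature, TypeSignature, Volatility};

fn main() {
    // btrim(str) / btrim(str, trim_str): one or two string-coercible arguments,
    // replacing the four hand-written Exact combinations it had before.
    let btrim_like = Signature::one_of(
        vec![TypeSignature::String(2), TypeSignature::String(1)],
        Volatility::Immutable,
    );
    assert!(matches!(btrim_like.type_signature, TypeSignature::OneOf(_)));

    // ascii(str), contains(str, search_str): fixed-arity shorthand.
    let unary = Signature::string(1, Volatility::Immutable);
    assert_eq!(unary.type_signature, TypeSignature::String(1));
}
```

Besides being shorter, the shorthand lets the planner coerce both arguments to a common string type, which is what allows the cross-type match arms in `contains` below to be deleted.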
use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::ArrayRef; use arrow::array::StringArray; @@ -24,13 +24,13 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use arrow::datatypes::DataType::Utf8; +use crate::utils::make_scalar_function; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::make_scalar_function; - /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { @@ -99,4 +99,35 @@ impl ScalarUDFImpl for ChrFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(chr, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_chr_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_chr_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description( + "Returns the character with the specified ASCII or Unicode code value.", + ) + .with_syntax_example("chr(expression)") + .with_sql_example( + r#"```sql +> select chr(128640); ++--------------------+ +| chr(Int64(128640)) | ++--------------------+ +| 🚀 | ++--------------------+ +```"#, + ) + .with_standard_argument("expression", "String") + .with_related_udf("ascii") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 98f57efef90d..228fcd460c97 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -18,18 +18,18 @@ use arrow::array::{as_largestring_array, Array}; use arrow::datatypes::DataType; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; +use crate::string::common::*; +use crate::string::concat; use datafusion_common::cast::{as_string_array, as_string_view_array}; use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{lit, ColumnarValue, Expr, Volatility}; +use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; -use crate::string::concat; - #[derive(Debug)] pub struct ConcatFunc { signature: Signature, @@ -244,6 +244,36 @@ impl ScalarUDFImpl for ConcatFunc { ) -> Result { simplify_concat(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_concat_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_concat_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Concatenates multiple strings together.") + .with_syntax_example("concat(str[, ..., str_n])") + .with_sql_example( + r#"```sql +> select concat('data', 'f', 'us', 'ion'); ++-------------------------------------------------------+ +| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) | ++-------------------------------------------------------+ +| 
datafusion                                            |
++-------------------------------------------------------+
+```"#,
+            )
+            .with_standard_argument("str", "String")
+            .with_argument("str_n", "Subsequent string expressions to concatenate.")
+            .with_related_udf("concat_ws")
+            .build()
+            .unwrap()
+    })
 }
 
 pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs
index 1134c525cfca..a20cbf1a16f5 100644
--- a/datafusion/functions/src/string/concat_ws.rs
+++ b/datafusion/functions/src/string/concat_ws.rs
@@ -17,7 +17,7 @@
 use arrow::array::{as_largestring_array, Array, StringArray};
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::datatypes::DataType;
 
@@ -27,8 +27,9 @@ use crate::string::concat_ws;
 use datafusion_common::cast::{as_string_array, as_string_view_array};
 use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue};
 use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
+use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
 #[derive(Debug)]
@@ -264,6 +265,45 @@ impl ScalarUDFImpl for ConcatWsFunc {
             _ => Ok(ExprSimplifyResult::Original(args)),
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_concat_ws_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_concat_ws_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description(
+                "Concatenates multiple strings together with a specified separator.",
+            )
+            .with_syntax_example("concat_ws(separator, str[, ..., str_n])")
+            .with_sql_example(
+                r#"```sql
+> select concat_ws('_', 'data', 'fusion');
++--------------------------------------------------+
+| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) |
++--------------------------------------------------+
+| data_fusion                                      |
++--------------------------------------------------+
+```"#,
+            )
+            .with_argument(
+                "separator",
+                "Separator to insert between concatenated strings.",
+            )
+            .with_standard_argument("str", "String")
+            .with_argument(
+                "str_n",
+                "Subsequent string expressions to concatenate.",
+            )
+            .with_related_udf("concat")
+            .build()
+            .unwrap()
+    })
 }
 
 fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<ExprSimplifyResult> {
diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs
index c319f80661c3..0f75731aa1c3 100644
--- a/datafusion/functions/src/string/contains.rs
+++ b/datafusion/functions/src/string/contains.rs
@@ -15,21 +15,20 @@
 // specific language governing permissions and limitations
 // under the License.
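Nearly every function touched by this diff gains the same documentation plumbing seen above: a `static DOCUMENTATION: OnceLock<Documentation>` plus a `get_*_doc()` accessor returning `&'static Documentation`. A dependency-free sketch of just that lazy-initialization pattern (the local `Doc` struct stands in for DataFusion's `Documentation` builder):

```rust
use std::sync::OnceLock;

#[derive(Debug)]
struct Doc {
    description: &'static str,
}

static DOCUMENTATION: OnceLock<Doc> = OnceLock::new();

fn get_doc() -> &'static Doc {
    // Built once, on the first call; every later call is just a load.
    DOCUMENTATION.get_or_init(|| Doc {
        description: "Concatenates multiple strings together with a separator.",
    })
}

fn main() {
    let a = get_doc();
    let b = get_doc();
    assert!(std::ptr::eq(a, b)); // one shared instance, never re-constructed
    println!("{}", a.description);
}
```

Using `std::sync::OnceLock` keeps the construction cost off every query path while still returning a `&'static` reference, which is exactly what the `documentation()` trait method wants.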
-use crate::regexp_common::regexp_is_match_utf8; use crate::utils::make_scalar_function; - use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray}; +use arrow::compute::regexp_is_match; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View}; use datafusion_common::exec_err; use datafusion_common::DataFusionError; use datafusion_common::Result; -use datafusion_expr::ScalarUDFImpl; -use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, Signature, Volatility}; - +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct ContainsFunc { @@ -44,22 +43,8 @@ impl Default for ContainsFunc { impl ContainsFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - Exact(vec![Utf8View, Utf8View]), - Exact(vec![Utf8View, Utf8]), - Exact(vec![Utf8View, LargeUtf8]), - Exact(vec![Utf8, Utf8View]), - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8View]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - Volatility::Immutable, - ), + signature: Signature::string(2, Volatility::Immutable), } } } @@ -84,6 +69,37 @@ impl ScalarUDFImpl for ContainsFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(contains, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_contains_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_contains_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description( + "Return true if search_str is found within string (case-sensitive).", + ) + .with_syntax_example("contains(str, search_str)") + .with_sql_example( + r#"```sql +> select contains('the quick brown fox', 'row'); ++---------------------------------------------------+ +| contains(Utf8("the quick brown fox"),Utf8("row")) | ++---------------------------------------------------+ +| true | ++---------------------------------------------------+ +```"#, + ) + .with_standard_argument("str", "String") + .with_argument("search_str", "The string to search for in str.") + .build() + .unwrap() + }) } /// use regexp_is_match_utf8_scalar to do the calculation for contains @@ -92,41 +108,8 @@ pub fn contains(args: &[ArrayRef]) -> Result { (Utf8View, Utf8View) => { let mod_str = args[0].as_string_view(); let match_str = args[1].as_string_view(); - let res = regexp_is_match_utf8::< - StringViewArray, + let res = regexp_is_match::< StringViewArray, - GenericStringArray, - >(mod_str, match_str, None)?; - - Ok(Arc::new(res) as ArrayRef) - } - (Utf8View, Utf8) => { - let mod_str = args[0].as_string_view(); - let match_str = args[1].as_string::(); - let res = regexp_is_match_utf8::< - StringViewArray, - GenericStringArray, - GenericStringArray, - >(mod_str, match_str, None)?; - - Ok(Arc::new(res) as ArrayRef) - } - (Utf8View, LargeUtf8) => { - let mod_str = args[0].as_string_view(); - let match_str = args[1].as_string::(); - let res = regexp_is_match_utf8::< - StringViewArray, - GenericStringArray, - GenericStringArray, - >(mod_str, match_str, None)?; - - Ok(Arc::new(res) as ArrayRef) - } - (Utf8, Utf8View) => { - let mod_str = args[0].as_string::(); - let match_str = 
args[1].as_string_view(); - let res = regexp_is_match_utf8::< - GenericStringArray, StringViewArray, GenericStringArray, >(mod_str, match_str, None)?; @@ -136,7 +119,7 @@ pub fn contains(args: &[ArrayRef]) -> Result { (Utf8, Utf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match_utf8::< + let res = regexp_is_match::< GenericStringArray, GenericStringArray, GenericStringArray, @@ -144,43 +127,10 @@ pub fn contains(args: &[ArrayRef]) -> Result { Ok(Arc::new(res) as ArrayRef) } - (Utf8, LargeUtf8) => { - let mod_str = args[0].as_string::(); - let match_str = args[1].as_string::(); - let res = regexp_is_match_utf8::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >(mod_str, match_str, None)?; - - Ok(Arc::new(res) as ArrayRef) - } - (LargeUtf8, Utf8View) => { - let mod_str = args[0].as_string::(); - let match_str = args[1].as_string_view(); - let res = regexp_is_match_utf8::< - GenericStringArray, - StringViewArray, - GenericStringArray, - >(mod_str, match_str, None)?; - - Ok(Arc::new(res) as ArrayRef) - } - (LargeUtf8, Utf8) => { - let mod_str = args[0].as_string::(); - let match_str = args[1].as_string::(); - let res = regexp_is_match_utf8::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >(mod_str, match_str, None)?; - - Ok(Arc::new(res) as ArrayRef) - } (LargeUtf8, LargeUtf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match_utf8::< + let res = regexp_is_match::< GenericStringArray, GenericStringArray, GenericStringArray, @@ -193,95 +143,3 @@ pub fn contains(args: &[ArrayRef]) -> Result { } } } - -#[cfg(test)] -mod tests { - use crate::string::contains::ContainsFunc; - use crate::utils::test::test_function; - use arrow::array::Array; - use arrow::{array::BooleanArray, datatypes::DataType::Boolean}; - use datafusion_common::Result; - use datafusion_common::ScalarValue; - use datafusion_expr::ColumnarValue; - use datafusion_expr::ScalarUDFImpl; - #[test] - fn test_functions() -> Result<()> { - test_function!( - ContainsFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("alph")), - ], - Ok(Some(true)), - bool, - Boolean, - BooleanArray - ); - test_function!( - ContainsFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("dddddd")), - ], - Ok(Some(false)), - bool, - Boolean, - BooleanArray - ); - test_function!( - ContainsFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("pha")), - ], - Ok(Some(true)), - bool, - Boolean, - BooleanArray - ); - - test_function!( - ContainsFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "Apache" - )))), - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("pac")))), - ], - Ok(Some(true)), - bool, - Boolean, - BooleanArray - ); - test_function!( - ContainsFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "Apache" - )))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ap")))), - ], - Ok(Some(false)), - bool, - Boolean, - BooleanArray - ); - test_function!( - ContainsFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "Apache" - )))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( - "DataFusion" - )))), - ], - Ok(Some(false)), - bool, - Boolean, - BooleanArray - ); - - 
Ok(()) - } -} diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index 03a1795954d0..8c90cbc3b146 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -16,18 +16,17 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::ArrayRef; use arrow::datatypes::DataType; +use crate::utils::make_scalar_function; use datafusion_common::{internal_err, Result}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::make_scalar_function; - #[derive(Debug)] pub struct EndsWithFunc { signature: Signature, @@ -42,17 +41,7 @@ impl Default for EndsWithFunc { impl EndsWithFunc { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - // Planner attempts coercion to the target type starting with the most preferred candidate. - // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`. - // If that fails, it proceeds to `(Utf8, Utf8)`. - Exact(vec![DataType::Utf8View, DataType::Utf8View]), - Exact(vec![DataType::Utf8, DataType::Utf8]), - Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), - ], - Volatility::Immutable, - ), + signature: Signature::string(2, Volatility::Immutable), } } } @@ -84,6 +73,41 @@ impl ScalarUDFImpl for EndsWithFunc { } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_ends_with_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_ends_with_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Tests if a string ends with a substring.") + .with_syntax_example("ends_with(str, substr)") + .with_sql_example( + r#"```sql +> select ends_with('datafusion', 'soin'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("soin")) | ++--------------------------------------------+ +| false | ++--------------------------------------------+ +> select ends_with('datafusion', 'sion'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("sion")) | ++--------------------------------------------+ +| true | ++--------------------------------------------+ +```"#, + ) + .with_standard_argument("str", "String") + .with_argument("substr", "Substring to test for.") + .build() + .unwrap() + }) } /// Returns true if string ends with suffix. diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs index 4e1eb213ef57..78c95b9a5e35 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/string/initcap.rs @@ -16,18 +16,18 @@ // under the License. 
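The two `ends_with` SQL examples documented above, restated as plain Rust assertions (the DataFusion kernel applies the same suffix test element-wise over Arrow arrays):

```rust
fn main() {
    // ends_with('datafusion', 'soin') -> false
    assert!(!"datafusion".ends_with("soin"));
    // ends_with('datafusion', 'sion') -> true
    assert!("datafusion".ends_with("sion"));
}
```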
diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs
index 4e1eb213ef57..78c95b9a5e35 100644
--- a/datafusion/functions/src/string/initcap.rs
+++ b/datafusion/functions/src/string/initcap.rs
@@ -16,18 +16,18 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
 use arrow::datatypes::DataType;
 
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 #[derive(Debug)]
 pub struct InitcapFunc {
     signature: Signature,
@@ -41,13 +41,8 @@ impl Default for InitcapFunc {
 impl InitcapFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::uniform(
-                1,
-                vec![Utf8, LargeUtf8, Utf8View],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(1, Volatility::Immutable),
         }
     }
 }
@@ -79,6 +74,34 @@ impl ScalarUDFImpl for InitcapFunc {
             }
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_initcap_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_initcap_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters.")
+            .with_syntax_example("initcap(str)")
+            .with_sql_example(r#"```sql
+> select initcap('apache datafusion');
++------------------------------------+
+| initcap(Utf8("apache datafusion")) |
++------------------------------------+
+| Apache Datafusion |
++------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_related_udf("lower")
+            .with_related_udf("upper")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.
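Editor's note: `Signature::string(n, volatility)` is the shorthand these files switch to for "n string-typed arguments"; it replaces hand-enumerated Utf8/LargeUtf8/Utf8View combinations. A sketch of the before/after shape, using only types that appear in this diff:

```rust
use arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};

// Before: every accepted string type spelled out explicitly.
let old = Signature::uniform(
    1,
    vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View],
    Volatility::Immutable,
);

// After: "one string-coercible argument"; the planner chooses the
// concrete type (Utf8View is preferred when the input already is one).
let new = Signature::string(1, Volatility::Immutable);
```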
diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs
index 430c402a50c5..558e71239f84 100644
--- a/datafusion/functions/src/string/levenshtein.rs
+++ b/datafusion/functions/src/string/levenshtein.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait};
 use arrow::datatypes::DataType;
@@ -25,8 +25,8 @@ use crate::utils::{make_scalar_function, utf8_to_int_type};
 use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
 use datafusion_common::utils::datafusion_strsim;
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::ColumnarValue;
-use datafusion_expr::TypeSignature::*;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 
 #[derive(Debug)]
@@ -43,14 +43,7 @@ impl Default for LevenshteinFunc {
 impl LevenshteinFunc {
     pub fn new() -> Self {
         Self {
-            signature: Signature::one_of(
-                vec![
-                    Exact(vec![DataType::Utf8View, DataType::Utf8View]),
-                    Exact(vec![DataType::Utf8, DataType::Utf8]),
-                    Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
-                ],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(2, Volatility::Immutable),
         }
     }
 }
@@ -83,6 +76,33 @@ impl ScalarUDFImpl for LevenshteinFunc {
             }
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_levenshtein_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_levenshtein_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings.")
+            .with_syntax_example("levenshtein(str1, str2)")
+            .with_sql_example(r#"```sql
+> select levenshtein('kitten', 'sitting');
++---------------------------------------------+
+| levenshtein(Utf8("kitten"),Utf8("sitting")) |
++---------------------------------------------+
+| 3 |
++---------------------------------------------+
+```"#)
+            .with_argument("str1", "String expression to compute Levenshtein distance with str2.")
+            .with_argument("str2", "String expression to compute Levenshtein distance with str1.")
+            .build()
+            .unwrap()
+    })
 }
 
 ///Returns the Levenshtein distance between the two given strings.
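Editor's note: the kernel itself delegates to `datafusion_strsim` (imported above), so for readers unfamiliar with the metric, here is the standard dynamic-programming recurrence the doc example refers to. This is a minimal reimplementation for illustration, not the crate's code:

```rust
// Classic Levenshtein distance: min number of single-character
// insertions, deletions, and substitutions turning `a` into `b`.
fn levenshtein(a: &str, b: &str) -> usize {
    let b_chars: Vec<char> = b.chars().collect();
    // prev[j] = distance between the first i chars of `a` and first j of `b`
    let mut prev: Vec<usize> = (0..=b_chars.len()).collect();
    for (i, ca) in a.chars().enumerate() {
        let mut cur = vec![i + 1]; // deleting all of a[..=i]
        for (j, &cb) in b_chars.iter().enumerate() {
            let cost = if ca == cb { 0 } else { 1 };
            cur.push(
                (prev[j] + cost)      // substitute (or match)
                    .min(prev[j + 1] + 1) // delete from a
                    .min(cur[j] + 1),     // insert into a
            );
        }
        prev = cur;
    }
    prev[b_chars.len()]
}

// Matches the SQL example above: levenshtein("kitten", "sitting") == 3.
```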
diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs
index ca324e69c0d2..f82b11ca9051 100644
--- a/datafusion/functions/src/string/lower.rs
+++ b/datafusion/functions/src/string/lower.rs
@@ -15,16 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::any::Any;
-
 use arrow::datatypes::DataType;
-
-use datafusion_common::Result;
-use datafusion_expr::ColumnarValue;
-use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::OnceLock;
 
 use crate::string::common::to_lower;
 use crate::utils::utf8_to_str_type;
+use datafusion_common::Result;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation};
+use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 
 #[derive(Debug)]
 pub struct LowerFunc {
@@ -39,13 +39,8 @@ impl Default for LowerFunc {
 impl LowerFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::uniform(
-                1,
-                vec![Utf8, LargeUtf8, Utf8View],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(1, Volatility::Immutable),
         }
     }
 }
@@ -70,8 +65,37 @@ impl ScalarUDFImpl for LowerFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         to_lower(args, "lower")
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_lower_doc())
+    }
 }
 
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_lower_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Converts a string to lower-case.")
+            .with_syntax_example("lower(str)")
+            .with_sql_example(
+                r#"```sql
+> select lower('Ångström');
++-------------------------+
+| lower(Utf8("Ångström")) |
++-------------------------+
+| ångström |
++-------------------------+
+```"#,
+            )
+            .with_standard_argument("str", "String")
+            .with_related_udf("initcap")
+            .with_related_udf("upper")
+            .build()
+            .unwrap()
+    })
+}
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs
index 0ddb5a205bac..b64dcda7218e 100644
--- a/datafusion/functions/src/string/ltrim.rs
+++ b/datafusion/functions/src/string/ltrim.rs
@@ -15,20 +15,19 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::any::Any;
-
 use arrow::array::{ArrayRef, OffsetSizeTrait};
 use arrow::datatypes::DataType;
+use std::any::Any;
+use std::sync::OnceLock;
 
+use crate::string::common::*;
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::{exec_err, Result};
 use datafusion_expr::function::Hint;
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::string::common::*;
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed.
 /// ltrim('zzzytest', 'xyz') = 'test'
 fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
@@ -49,18 +48,9 @@ impl Default for LtrimFunc {
 impl LtrimFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
             signature: Signature::one_of(
-                vec![
-                    // Planner attempts coercion to the target type starting with the most preferred candidate.
-                    // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
-                    // If that fails, it proceeds to `(Utf8, Utf8)`.
-                    Exact(vec![Utf8View, Utf8View]),
-                    Exact(vec![Utf8, Utf8]),
-                    Exact(vec![Utf8View]),
-                    Exact(vec![Utf8]),
-                ],
+                vec![TypeSignature::String(2), TypeSignature::String(1)],
                 Volatility::Immutable,
             ),
         }
@@ -104,6 +94,41 @@ impl ScalarUDFImpl for LtrimFunc {
             ),
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_ltrim_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_ltrim_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string.")
+            .with_syntax_example("ltrim(str[, trim_str])")
+            .with_sql_example(r#"```sql
+> select ltrim(' datafusion ');
++-------------------------------+
+| ltrim(Utf8(" datafusion ")) |
++-------------------------------+
+| datafusion |
++-------------------------------+
+> select ltrim('___datafusion___', '_');
++-------------------------------------------+
+| ltrim(Utf8("___datafusion___"),Utf8("_")) |
++-------------------------------------------+
+| datafusion___ |
++-------------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_argument("trim_str", "String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._")
+            .with_related_udf("btrim")
+            .with_related_udf("rtrim")
+            .build()
+            .unwrap()
+    })
 }
 
 #[cfg(test)]
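Editor's note: ltrim (and rtrim below) keep `Signature::one_of` because they have two arities, but the candidate list collapses to `TypeSignature::String(n)` entries. A minimal sketch of that shape, with types exactly as used in this diff:

```rust
use datafusion_expr::{Signature, TypeSignature, Volatility};

// ltrim accepts either (str) or (str, trim_str). one_of tries candidates
// in order, so the two-argument form is matched first when it can be.
let sig = Signature::one_of(
    vec![TypeSignature::String(2), TypeSignature::String(1)],
    Volatility::Immutable,
);
```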
diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs
index f792914d862e..04094396fadc 100644
--- a/datafusion/functions/src/string/octet_length.rs
+++ b/datafusion/functions/src/string/octet_length.rs
@@ -15,17 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::any::Any;
-
 use arrow::compute::kernels::length::length;
 use arrow::datatypes::DataType;
+use std::any::Any;
+use std::sync::OnceLock;
 
+use crate::utils::utf8_to_int_type;
 use datafusion_common::{exec_err, Result, ScalarValue};
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::utils::utf8_to_int_type;
-
 #[derive(Debug)]
 pub struct OctetLengthFunc {
     signature: Signature,
@@ -39,13 +39,8 @@ impl Default for OctetLengthFunc {
 impl OctetLengthFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::uniform(
-                1,
-                vec![Utf8, LargeUtf8, Utf8View],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(1, Volatility::Immutable),
         }
     }
 }
@@ -91,6 +86,36 @@ impl ScalarUDFImpl for OctetLengthFunc {
             },
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_octet_length_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_octet_length_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns the length of a string in bytes.")
+            .with_syntax_example("octet_length(str)")
+            .with_sql_example(
+                r#"```sql
+> select octet_length('Ångström');
++--------------------------------+
+| octet_length(Utf8("Ångström")) |
++--------------------------------+
+| 10 |
++--------------------------------+
+```"#,
+            )
+            .with_standard_argument("str", "String")
+            .with_related_udf("bit_length")
+            .with_related_udf("length")
+            .build()
+            .unwrap()
+    })
 }
 
 #[cfg(test)]
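Editor's note: the doc example returns 10 rather than 8 because `Å` and `ö` each take two bytes in UTF-8; `octet_length` counts bytes while `character_length` (further down in this PR) counts characters. A plain-Rust illustration of the same distinction:

```rust
fn main() {
    let s = "Ångström";
    assert_eq!(s.len(), 10);          // octet_length: UTF-8 bytes
    assert_eq!(s.chars().count(), 8); // character_length: Unicode scalar values
}
```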
diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs
index e285bd85b197..3b31bc360851 100644
--- a/datafusion/functions/src/string/overlay.rs
+++ b/datafusion/functions/src/string/overlay.rs
@@ -16,21 +16,20 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
 use arrow::datatypes::DataType;
 
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::cast::{
     as_generic_string_array, as_int64_array, as_string_view_array,
 };
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 #[derive(Debug)]
 pub struct OverlayFunc {
     signature: Signature,
@@ -48,12 +47,12 @@ impl OverlayFunc {
         Self {
             signature: Signature::one_of(
                 vec![
-                    Exact(vec![Utf8View, Utf8View, Int64, Int64]),
-                    Exact(vec![Utf8, Utf8, Int64, Int64]),
-                    Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
-                    Exact(vec![Utf8View, Utf8View, Int64]),
-                    Exact(vec![Utf8, Utf8, Int64]),
-                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8View, Int64, Int64]),
+                    TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64]),
+                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]),
+                    TypeSignature::Exact(vec![Utf8, Utf8, Int64]),
+                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]),
                 ],
                 Volatility::Immutable,
             ),
@@ -87,6 +86,35 @@ impl ScalarUDFImpl for OverlayFunc {
             other => exec_err!("Unsupported data type {other:?} for function overlay"),
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_overlay_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_overlay_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns the string which is replaced by another string from the specified position and specified count length.")
+            .with_syntax_example("overlay(str PLACING substr FROM pos [FOR count])")
+            .with_sql_example(r#"```sql
+> select overlay('Txxxxas' placing 'hom' from 2 for 4);
++--------------------------------------------------------+
+| overlay(Utf8("Txxxxas"),Utf8("hom"),Int64(2),Int64(4)) |
++--------------------------------------------------------+
+| Thomas |
++--------------------------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_argument("substr", "Substring to replace in str.")
+            .with_argument("pos", "The start position to start the replace in str.")
+            .with_argument("count", "The count of characters to be replaced from start position of str. If not specified, will use substr length instead.")
+            .build()
+            .unwrap()
+    })
 }
 
 macro_rules! process_overlay {
diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs
index 20e4462784b8..fda9c7a13df6 100644
--- a/datafusion/functions/src/string/repeat.rs
+++ b/datafusion/functions/src/string/repeat.rs
@@ -16,24 +16,22 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
+use crate::string::common::StringArrayType;
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use arrow::array::{
     ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
     OffsetSizeTrait, StringViewArray,
 };
 use arrow::datatypes::DataType;
 use arrow::datatypes::DataType::{Int64, LargeUtf8, Utf8, Utf8View};
-
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::string::common::StringArrayType;
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 #[derive(Debug)]
 pub struct RepeatFunc {
     signature: Signature,
@@ -53,9 +51,9 @@ impl RepeatFunc {
                     // Planner attempts coercion to the target type starting with the most preferred candidate.
                     // For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`.
                     // If that fails, it proceeds to `(Utf8, Int64)`.
-                    Exact(vec![Utf8View, Int64]),
-                    Exact(vec![Utf8, Int64]),
-                    Exact(vec![LargeUtf8, Int64]),
+                    TypeSignature::Exact(vec![Utf8View, Int64]),
+                    TypeSignature::Exact(vec![Utf8, Int64]),
+                    TypeSignature::Exact(vec![LargeUtf8, Int64]),
                 ],
                 Volatility::Immutable,
             ),
@@ -83,6 +81,37 @@ impl ScalarUDFImpl for RepeatFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         make_scalar_function(repeat, vec![])(args)
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_repeat_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_repeat_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description(
+                "Returns a string with an input string repeated a specified number.",
+            )
+            .with_syntax_example("repeat(str, n)")
+            .with_sql_example(
+                r#"```sql
+> select repeat('data', 3);
++-------------------------------+
+| repeat(Utf8("data"),Int64(3)) |
++-------------------------------+
+| datadatadata |
++-------------------------------+
+```"#,
+            )
+            .with_standard_argument("str", "String")
+            .with_argument("n", "Number of times to repeat the input string.")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Repeats string the specified number of times.
diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs
index 13fa3d55672d..612cd7276bab 100644
--- a/datafusion/functions/src/string/replace.rs
+++ b/datafusion/functions/src/string/replace.rs
@@ -16,19 +16,18 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
 use arrow::datatypes::DataType;
 
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 #[derive(Debug)]
 pub struct ReplaceFunc {
     signature: Signature,
@@ -42,16 +41,8 @@ impl Default for ReplaceFunc {
 impl ReplaceFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::one_of(
-                vec![
-                    Exact(vec![Utf8View, Utf8View, Utf8View]),
-                    Exact(vec![Utf8, Utf8, Utf8]),
-                    Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
-                ],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(3, Volatility::Immutable),
         }
     }
 }
@@ -83,6 +74,34 @@ impl ScalarUDFImpl for ReplaceFunc {
             }
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_replace_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_replace_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Replaces all occurrences of a specified substring in a string with a new substring.")
+            .with_syntax_example("replace(str, substr, replacement)")
+            .with_sql_example(r#"```sql
+> select replace('ABabbaBA', 'ab', 'cd');
++-------------------------------------------------+
+| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
++-------------------------------------------------+
+| ABcdbaBA |
++-------------------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_standard_argument("substr", "Substring expression to replace in the input string. Substring expression")
+            .with_standard_argument("replacement", "Replacement substring")
+            .build()
+            .unwrap()
+    })
 }
 
 fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs
index a1aa5568babb..1a27502a2082 100644
--- a/datafusion/functions/src/string/rtrim.rs
+++ b/datafusion/functions/src/string/rtrim.rs
@@ -16,19 +16,18 @@
 // under the License.
 
 use arrow::array::{ArrayRef, OffsetSizeTrait};
-use std::any::Any;
-
 use arrow::datatypes::DataType;
+use std::any::Any;
+use std::sync::OnceLock;
 
+use crate::string::common::*;
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::{exec_err, Result};
 use datafusion_expr::function::Hint;
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
-use crate::string::common::*;
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 /// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed.
 /// rtrim('testxxzx', 'xyz') = 'test'
 fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
@@ -49,18 +48,9 @@ impl Default for RtrimFunc {
 impl RtrimFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
             signature: Signature::one_of(
-                vec![
-                    // Planner attempts coercion to the target type starting with the most preferred candidate.
-                    // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
-                    // If that fails, it proceeds to `(Utf8, Utf8)`.
-                    Exact(vec![Utf8View, Utf8View]),
-                    Exact(vec![Utf8, Utf8]),
-                    Exact(vec![Utf8View]),
-                    Exact(vec![Utf8]),
-                ],
+                vec![TypeSignature::String(2), TypeSignature::String(1)],
                 Volatility::Immutable,
             ),
         }
@@ -104,6 +94,41 @@ impl ScalarUDFImpl for RtrimFunc {
             ),
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_rtrim_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_rtrim_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string.")
+            .with_syntax_example("rtrim(str[, trim_str])")
+            .with_sql_example(r#"```sql
+> select rtrim(' datafusion ');
++-------------------------------+
+| rtrim(Utf8(" datafusion ")) |
++-------------------------------+
+| datafusion |
++-------------------------------+
+> select rtrim('___datafusion___', '_');
++-------------------------------------------+
+| rtrim(Utf8("___datafusion___"),Utf8("_")) |
++-------------------------------------------+
+| ___datafusion |
++-------------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_argument("trim_str", "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._")
+            .with_related_udf("btrim")
+            .with_related_udf("ltrim")
+            .build()
+            .unwrap()
+    })
 }
 
 #[cfg(test)]
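Editor's note: as the `rtrim('___datafusion___', '_')` example shows, only one side is trimmed; DataFusion follows PostgreSQL in treating `trim_str` as a set of characters to strip. A plain-Rust analogue of the single-character case in the examples above (illustrative only, not the Arrow kernel):

```rust
fn main() {
    let s = "___datafusion___";
    // rtrim: strip from the end only.
    assert_eq!(s.trim_end_matches('_'), "___datafusion");
    // ltrim: strip from the start only.
    assert_eq!(s.trim_start_matches('_'), "datafusion___");
}
```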
diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs
index 8d292315a35a..2441798c38d4 100644
--- a/datafusion/functions/src/string/split_part.rs
+++ b/datafusion/functions/src/string/split_part.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::utils::utf8_to_str_type;
 use arrow::array::{
     ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringViewArray,
 };
@@ -23,13 +24,11 @@
 use arrow::datatypes::DataType;
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::ScalarValue;
 use datafusion_common::{exec_err, DataFusionError, Result};
-use datafusion_expr::TypeSignature::*;
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 use std::any::Any;
-use std::sync::Arc;
-
-use crate::utils::utf8_to_str_type;
+use std::sync::{Arc, OnceLock};
 
 use super::common::StringArrayType;
 
@@ -50,15 +49,15 @@ impl SplitPartFunc {
         Self {
             signature: Signature::one_of(
                 vec![
-                    Exact(vec![Utf8View, Utf8View, Int64]),
-                    Exact(vec![Utf8View, Utf8, Int64]),
-                    Exact(vec![Utf8View, LargeUtf8, Int64]),
-                    Exact(vec![Utf8, Utf8View, Int64]),
-                    Exact(vec![Utf8, Utf8, Int64]),
-                    Exact(vec![LargeUtf8, Utf8View, Int64]),
-                    Exact(vec![LargeUtf8, Utf8, Int64]),
-                    Exact(vec![Utf8, LargeUtf8, Int64]),
-                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]),
+                    TypeSignature::Exact(vec![Utf8View, Utf8, Int64]),
+                    TypeSignature::Exact(vec![Utf8View, LargeUtf8, Int64]),
+                    TypeSignature::Exact(vec![Utf8, Utf8View, Int64]),
+                    TypeSignature::Exact(vec![Utf8, Utf8, Int64]),
+                    TypeSignature::Exact(vec![LargeUtf8, Utf8View, Int64]),
+                    TypeSignature::Exact(vec![LargeUtf8, Utf8, Int64]),
+                    TypeSignature::Exact(vec![Utf8, LargeUtf8, Int64]),
+                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]),
                 ],
                 Volatility::Immutable,
             ),
@@ -178,6 +177,34 @@ impl ScalarUDFImpl for SplitPartFunc {
             result.map(ColumnarValue::Array)
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_split_part_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_split_part_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Splits a string based on a specified delimiter and returns the substring in the specified position.")
+            .with_syntax_example("split_part(str, delimiter, pos)")
+            .with_sql_example(r#"```sql
+> select split_part('1.2.3.4.5', '.', 3);
++--------------------------------------------------+
+| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) |
++--------------------------------------------------+
+| 3 |
++--------------------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_argument("delimiter", "String or character to split on.")
+            .with_argument("pos", "Position of the part to return.")
+            .build()
+            .unwrap()
+    })
 }
 
 /// impl
diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs
index 8450697cbf30..713b642d5e91 100644
--- a/datafusion/functions/src/string/starts_with.rs
+++ b/datafusion/functions/src/string/starts_with.rs
@@ -16,18 +16,17 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::ArrayRef;
 use arrow::datatypes::DataType;
 
+use crate::utils::make_scalar_function;
 use datafusion_common::{internal_err, Result};
-use datafusion_expr::ColumnarValue;
-use datafusion_expr::TypeSignature::*;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 
-use crate::utils::make_scalar_function;
-
 /// Returns true if string starts with prefix.
 /// starts_with('alphabet', 'alph') = 't'
 pub fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
@@ -49,17 +48,7 @@ impl Default for StartsWithFunc {
 impl StartsWithFunc {
     pub fn new() -> Self {
         Self {
-            signature: Signature::one_of(
-                vec![
-                    // Planner attempts coercion to the target type starting with the most preferred candidate.
-                    // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
-                    // If that fails, it proceeds to `(Utf8, Utf8)`.
-                    Exact(vec![DataType::Utf8View, DataType::Utf8View]),
-                    Exact(vec![DataType::Utf8, DataType::Utf8]),
-                    Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
-                ],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(2, Volatility::Immutable),
         }
     }
 }
@@ -89,6 +78,35 @@ impl ScalarUDFImpl for StartsWithFunc {
             _ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?,
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_starts_with_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_starts_with_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Tests if a string starts with a substring.")
+            .with_syntax_example("starts_with(str, substr)")
+            .with_sql_example(
+                r#"```sql
+> select starts_with('datafusion','data');
++----------------------------------------------+
+| starts_with(Utf8("datafusion"),Utf8("data")) |
++----------------------------------------------+
+| true |
++----------------------------------------------+
+```"#,
+            )
+            .with_standard_argument("str", "String")
+            .with_argument("substr", "Substring to test for.")
+            .build()
+            .unwrap()
+    })
 }
 
 #[cfg(test)]
diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs
index 79aa9254f9b1..72cd4fbffa33 100644
--- a/datafusion/functions/src/string/to_hex.rs
+++ b/datafusion/functions/src/string/to_hex.rs
@@ -16,21 +16,21 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
 use arrow::datatypes::{
     ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type,
 };
 
+use crate::utils::make_scalar_function;
 use datafusion_common::cast::as_primitive_array;
 use datafusion_common::Result;
 use datafusion_common::{exec_err, plan_err};
-use datafusion_expr::ColumnarValue;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 
-use crate::utils::make_scalar_function;
-
 /// Converts the number to its equivalent hexadecimal representation.
 /// to_hex(2147483647) = '7fffffff'
 pub fn to_hex<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
@@ -110,6 +110,34 @@ impl ScalarUDFImpl for ToHexFunc {
             other => exec_err!("Unsupported data type {other:?} for function to_hex"),
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_to_hex_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_to_hex_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Converts an integer to a hexadecimal string.")
+            .with_syntax_example("to_hex(int)")
+            .with_sql_example(
+                r#"```sql
+> select to_hex(12345689);
++-------------------------+
+| to_hex(Int64(12345689)) |
++-------------------------+
+| bc6159 |
++-------------------------+
+```"#,
+            )
+            .with_standard_argument("int", "Integer")
+            .build()
+            .unwrap()
+    })
 }
 
 #[cfg(test)]
diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs
index 593e33ab6bb4..bfcb2a86994d 100644
--- a/datafusion/functions/src/string/upper.rs
+++ b/datafusion/functions/src/string/upper.rs
@@ -19,9 +19,11 @@
 use crate::string::common::to_upper;
 use crate::utils::utf8_to_str_type;
 use arrow::datatypes::DataType;
 use datafusion_common::Result;
-use datafusion_expr::ColumnarValue;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation};
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
 use std::any::Any;
+use std::sync::OnceLock;
 
 #[derive(Debug)]
 pub struct UpperFunc {
@@ -36,13 +38,8 @@ impl Default for UpperFunc {
 impl UpperFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::uniform(
-                1,
-                vec![Utf8, LargeUtf8, Utf8View],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(1, Volatility::Immutable),
        }
    }
}
@@ -67,6 +64,36 @@ impl ScalarUDFImpl for UpperFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         to_upper(args, "upper")
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_upper_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_upper_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Converts a string to upper-case.")
+            .with_syntax_example("upper(str)")
+            .with_sql_example(
+                r#"```sql
+> select upper('dataFusion');
++---------------------------+
+| upper(Utf8("dataFusion")) |
++---------------------------+
+| DATAFUSION |
++---------------------------+
+```"#,
+            )
+            .with_standard_argument("str", "String")
+            .with_related_udf("initcap")
+            .with_related_udf("lower")
+            .build()
+            .unwrap()
+    })
 }
 
 #[cfg(test)]
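Editor's note: the `to_hex` doc example is ordinary lowercase hexadecimal formatting; the following one-liner reproduces it in plain Rust (illustrative only, the kernel works over Arrow primitive arrays):

```rust
fn main() {
    // Matches the SQL example: to_hex(12345689) = 'bc6159'.
    assert_eq!(format!("{:x}", 12345689), "bc6159");
}
```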
diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs
index 3ddc320fcec1..0fbdce16ccd1 100644
--- a/datafusion/functions/src/string/uuid.rs
+++ b/datafusion/functions/src/string/uuid.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::GenericStringArray;
 use arrow::datatypes::DataType;
@@ -24,7 +24,8 @@
 use arrow::datatypes::DataType::Utf8;
 use uuid::Uuid;
 
 use datafusion_common::{not_impl_err, Result};
-use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{ColumnarValue, Documentation, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
 #[derive(Debug)]
@@ -74,4 +75,29 @@ impl ScalarUDFImpl for UuidFunc {
         let array = GenericStringArray::<i32>::from_iter_values(values);
         Ok(ColumnarValue::Array(Arc::new(array)))
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_uuid_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_uuid_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns [`UUID v4`](https://en.wikipedia.org/wiki/Universally_unique_identifier#Version_4_(random)) string value which is unique per row.")
+            .with_syntax_example("uuid()")
+            .with_sql_example(r#"```sql
+> select uuid();
++--------------------------------------+
+| uuid() |
++--------------------------------------+
+| 6ec17ef8-1934-41cc-8d59-d0c8f9eea1f0 |
++--------------------------------------+
+```"#)
+            .build()
+            .unwrap()
+    })
 }
diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs
index c9dc96b2a935..bfb60bfbe259 100644
--- a/datafusion/functions/src/unicode/character_length.rs
+++ b/datafusion/functions/src/unicode/character_length.rs
@@ -22,9 +22,12 @@ use arrow::array::{
 };
 use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
 use datafusion_common::Result;
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 #[derive(Debug)]
 pub struct CharacterLengthFunc {
@@ -76,6 +79,36 @@ impl ScalarUDFImpl for CharacterLengthFunc {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_character_length_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_character_length_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns the number of characters in a string.")
+            .with_syntax_example("character_length(str)")
+            .with_sql_example(
+                r#"```sql
+> select character_length('Ångström');
++------------------------------------+
+| character_length(Utf8("Ångström")) |
++------------------------------------+
+| 8 |
++------------------------------------+
+```"#,
+            )
+            .with_standard_argument("str", "String")
+            .with_related_udf("bit_length")
+            .with_related_udf("octet_length")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Returns number of characters in the string.
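Editor's note: as the doc text says, `uuid()` produces a value that is "unique per row", i.e. a fresh random v4 UUID is generated per output row rather than one constant per query. A minimal sketch using the same `uuid` crate this file imports:

```rust
use uuid::Uuid;

fn main() {
    // One fresh v4 UUID per "row"; new_v4() draws from a CSPRNG.
    let ids: Vec<String> = (0..3).map(|_| Uuid::new_v4().to_string()).collect();
    assert_eq!(ids.len(), 3);
    // With overwhelming probability all three differ.
    assert_ne!(ids[0], ids[1]);
}
```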
diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs
index 41a2b9d9e72d..cad860e41088 100644
--- a/datafusion/functions/src/unicode/find_in_set.rs
+++ b/datafusion/functions/src/unicode/find_in_set.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{
     ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
@@ -24,11 +24,13 @@
 };
 use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
 
+use crate::utils::{make_scalar_function, utf8_to_int_type};
 use datafusion_common::{exec_err, Result};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
 use datafusion_expr::TypeSignature::Exact;
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
-
-use crate::utils::{make_scalar_function, utf8_to_int_type};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 
 #[derive(Debug)]
 pub struct FindInSetFunc {
@@ -77,6 +79,33 @@ impl ScalarUDFImpl for FindInSetFunc {
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         make_scalar_function(find_in_set, vec![])(args)
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_find_in_set_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_find_in_set_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.")
+            .with_syntax_example("find_in_set(str, strlist)")
+            .with_sql_example(r#"```sql
+> select find_in_set('b', 'a,b,c,d');
++----------------------------------------+
+| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
++----------------------------------------+
+| 2 |
++----------------------------------------+
+```"#)
+            .with_argument("str", "String expression to find in strlist.")
+            .with_argument("strlist", "A string list is a string composed of substrings separated by , characters.")
+            .build()
+            .unwrap()
+    })
 }
 
 ///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings
diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs
index c49784948dd0..6610cfb25e79 100644
--- a/datafusion/functions/src/unicode/left.rs
+++ b/datafusion/functions/src/unicode/left.rs
@@ -17,7 +17,7 @@
 
 use std::any::Any;
 use std::cmp::Ordering;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{
     Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array,
@@ -25,15 +25,17 @@
 };
 use arrow::datatypes::DataType;
 
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::cast::{
     as_generic_string_array, as_int64_array, as_string_view_array,
 };
 use datafusion_common::exec_err;
 use datafusion_common::Result;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
 use datafusion_expr::TypeSignature::Exact;
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
-
-use crate::utils::{make_scalar_function, utf8_to_str_type};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 
 #[derive(Debug)]
 pub struct LeftFunc {
@@ -91,6 +93,34 @@ impl ScalarUDFImpl for LeftFunc {
             ),
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_left_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_left_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns a specified number of characters from the left side of a string.")
+            .with_syntax_example("left(str, n)")
+            .with_sql_example(r#"```sql
+> select left('datafusion', 4);
++-----------------------------------+
+| left(Utf8("datafusion"),Int64(4)) |
++-----------------------------------+
+| data |
++-----------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_argument("n", "Number of characters to return.")
+            .with_related_udf("right")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters.
diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs
index e102673c4253..48bd583720aa 100644
--- a/datafusion/functions/src/unicode/lpad.rs
+++ b/datafusion/functions/src/unicode/lpad.rs
@@ -17,7 +17,7 @@
 
 use std::any::Any;
 use std::fmt::Write;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{
     Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
@@ -27,13 +27,15 @@
 use arrow::datatypes::DataType;
 use unicode_segmentation::UnicodeSegmentation;
 use DataType::{LargeUtf8, Utf8, Utf8View};
 
+use crate::string::common::StringArrayType;
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::{exec_err, Result};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
 use datafusion_expr::TypeSignature::Exact;
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
-
-use crate::string::common::StringArrayType;
-use crate::utils::{make_scalar_function, utf8_to_str_type};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 
 #[derive(Debug)]
 pub struct LPadFunc {
@@ -95,6 +97,35 @@ impl ScalarUDFImpl for LPadFunc {
             other => exec_err!("Unsupported data type {other:?} for function lpad"),
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_lpad_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_lpad_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Pads the left side of a string with another string to a specified string length.")
+            .with_syntax_example("lpad(str, n[, padding_str])")
+            .with_sql_example(r#"```sql
+> select lpad('Dolly', 10, 'hello');
++---------------------------------------------+
+| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) |
++---------------------------------------------+
+| helloDolly |
++---------------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_argument("n", "String length to pad to.")
+            .with_argument("padding_str", "Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._")
+            .with_related_udf("rpad")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Extends the string to length 'length' by prepending the characters fill (a space by default).
diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs
index da16d3ee3752..32872c28a613 100644
--- a/datafusion/functions/src/unicode/reverse.rs
+++ b/datafusion/functions/src/unicode/reverse.rs
@@ -16,19 +16,21 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use arrow::array::{
     Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray,
     OffsetSizeTrait,
 };
 use arrow::datatypes::DataType;
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 use DataType::{LargeUtf8, Utf8, Utf8View};
 
-use crate::utils::{make_scalar_function, utf8_to_str_type};
-
 #[derive(Debug)]
 pub struct ReverseFunc {
     signature: Signature,
@@ -79,6 +81,34 @@ impl ScalarUDFImpl for ReverseFunc {
             }
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_reverse_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_reverse_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Reverses the character order of a string.")
+            .with_syntax_example("reverse(str)")
+            .with_sql_example(
+                r#"```sql
+> select reverse('datafusion');
++-----------------------------+
+| reverse(Utf8("datafusion")) |
++-----------------------------+
+| noisufatad |
++-----------------------------+
+```"#,
+            )
+            .with_standard_argument("str", "String")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Reverses the order of the characters in the string.
diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs
index 9d542bb2c006..585611fe60e4 100644
--- a/datafusion/functions/src/unicode/right.rs
+++ b/datafusion/functions/src/unicode/right.rs
@@ -17,7 +17,7 @@
 
 use std::any::Any;
 use std::cmp::{max, Ordering};
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{
     Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array,
@@ -31,8 +31,11 @@
 use datafusion_common::cast::{
     as_generic_string_array, as_int64_array, as_string_view_array,
 };
 use datafusion_common::exec_err;
 use datafusion_common::Result;
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
 use datafusion_expr::TypeSignature::Exact;
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 
 #[derive(Debug)]
 pub struct RightFunc {
@@ -90,6 +93,34 @@ impl ScalarUDFImpl for RightFunc {
             ),
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_right_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_right_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns a specified number of characters from the right side of a string.")
+            .with_syntax_example("right(str, n)")
+            .with_sql_example(r#"```sql
+> select right('datafusion', 6);
++------------------------------------+
+| right(Utf8("datafusion"),Int64(6)) |
++------------------------------------+
+| fusion |
++------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_argument("n", "Number of characters to return")
+            .with_related_udf("left")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Returns last n characters in the string, or when n is negative, returns all but first |n| characters.
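Editor's note: the doc comment above spells out the negative-n semantics shared by left/right. A plain-Rust sketch of that behavior for `right`, character-based as the kernels are; this is an illustration, not the Arrow implementation:

```rust
// right(s, n): last n characters; for negative n, all but the first |n|.
fn right(s: &str, n: i64) -> String {
    let len = s.chars().count() as i64;
    let skip = if n >= 0 { (len - n).max(0) } else { (-n).min(len) };
    s.chars().skip(skip as usize).collect()
}

fn main() {
    assert_eq!(right("datafusion", 6), "fusion");  // matches the SQL example
    assert_eq!(right("datafusion", -4), "fusion"); // all but the first 4 chars
}
```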
diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs
index ce221b44f42b..9ca65e229c0c 100644
--- a/datafusion/functions/src/unicode/rpad.rs
+++ b/datafusion/functions/src/unicode/rpad.rs
@@ -47,27 +47,6 @@ impl Default for RPadFunc {
     }
 }
 
-static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
-
-fn get_rpad_doc() -> &'static Documentation {
-    DOCUMENTATION.get_or_init(|| {
-        Documentation::builder()
-            .with_doc_section(DOC_SECTION_STRING)
-            .with_description("Pads the right side of a string with another string to a specified string length.")
-            .with_syntax_example("rpad(str, n[, padding_str])")
-            .with_argument(
-                "str",
-                "String expression to operate on. Can be a constant, column, or function, and any combination of string operators.",
-            )
-            .with_argument("n", "String length to pad to.")
-            .with_argument("padding_str",
-                "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._")
-            .with_related_udf("lpad")
-            .build()
-            .unwrap()
-    })
-}
-
 impl RPadFunc {
     pub fn new() -> Self {
         use DataType::*;
@@ -143,6 +122,35 @@ impl ScalarUDFImpl for RPadFunc {
     }
 }
 
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_rpad_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Pads the right side of a string with another string to a specified string length.")
+            .with_syntax_example("rpad(str, n[, padding_str])")
+            .with_sql_example(r#"```sql
+> select rpad('datafusion', 20, '_-');
++-----------------------------------------------+
+| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) |
++-----------------------------------------------+
+| datafusion_-_-_-_-_- |
++-----------------------------------------------+
+```"#)
+            .with_standard_argument(
+                "str",
+                "String",
+            )
+            .with_argument("n", "String length to pad to.")
+            .with_argument("padding_str",
+                "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._")
+            .with_related_udf("lpad")
+            .build()
+            .unwrap()
+    })
+}
+
 pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
     args: &[ArrayRef],
) -> Result<ArrayRef> {
diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs
index 6da67c8a2798..660adc7578a5 100644
--- a/datafusion/functions/src/unicode/strpos.rs
+++ b/datafusion/functions/src/unicode/strpos.rs
@@ -16,16 +16,17 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
-
-use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
-use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
+use std::sync::{Arc, OnceLock};
 
 use crate::string::common::StringArrayType;
 use crate::utils::{make_scalar_function, utf8_to_int_type};
+use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
+use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
 use datafusion_common::{exec_err, Result};
-use datafusion_expr::TypeSignature::Exact;
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 
 #[derive(Debug)]
 pub struct StrposFunc {
@@ -41,20 +42,8 @@ impl Default for StrposFunc {
 impl StrposFunc {
     pub fn new() -> Self {
-        use DataType::*;
         Self {
-            signature: Signature::one_of(
-                vec![
-                    Exact(vec![Utf8, Utf8]),
-                    Exact(vec![Utf8, LargeUtf8]),
-                    Exact(vec![LargeUtf8, Utf8]),
-                    Exact(vec![LargeUtf8, LargeUtf8]),
-                    Exact(vec![Utf8View, Utf8View]),
-                    Exact(vec![Utf8View, Utf8]),
-                    Exact(vec![Utf8View, LargeUtf8]),
-                ],
-                Volatility::Immutable,
-            ),
+            signature: Signature::string(2, Volatility::Immutable),
             aliases: vec![String::from("instr"), String::from("position")],
         }
     }
@@ -84,6 +73,33 @@ impl ScalarUDFImpl for StrposFunc {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_strpos_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_strpos_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.")
+            .with_syntax_example("strpos(str, substr)")
+            .with_sql_example(r#"```sql
+> select strpos('datafusion', 'fus');
++----------------------------------------+
+| strpos(Utf8("datafusion"),Utf8("fus")) |
++----------------------------------------+
+| 5 |
++----------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_argument("substr", "Substring expression to search for.")
+            .build()
+            .unwrap()
+    })
 }
 
 fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs
index 205de0b30b9c..969969ef2f6f 100644
--- a/datafusion/functions/src/unicode/substr.rs
+++ b/datafusion/functions/src/unicode/substr.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use crate::string::common::{make_and_append_view, StringArrayType};
 use crate::utils::{make_scalar_function, utf8_to_str_type};
@@ -28,7 +28,10 @@
 use arrow::datatypes::DataType;
 use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::{exec_err, plan_err, Result};
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 
 #[derive(Debug)]
 pub struct SubstrFunc {
@@ -81,6 +84,13 @@ impl ScalarUDFImpl for SubstrFunc {
     }
 
     fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        if arg_types.len() < 2 || arg_types.len() > 3 {
+            return plan_err!(
+                "The {} function requires 2 or 3 arguments, but got {}.",
+                self.name(),
+                arg_types.len()
+            );
+        }
         let first_data_type = match &arg_types[0] {
             DataType::Null => Ok(DataType::Utf8),
             DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(arg_types[0].clone()),
@@ -138,6 +148,34 @@ impl ScalarUDFImpl for SubstrFunc {
             ])
         }
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_substr_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_substr_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description("Extracts a substring of a specified number of characters from a specific starting position in a string.")
+            .with_syntax_example("substr(str, start_pos[, length])")
+            .with_sql_example(r#"```sql
+> select substr('datafusion', 5, 3);
++----------------------------------------------+
+| substr(Utf8("datafusion"),Int64(5),Int64(3)) |
++----------------------------------------------+
+| fus |
++----------------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_argument("start_pos", "Character position to start the substring at. The first character in the string has a position of 1.")
+            .with_argument("length", "Number of characters to extract. If not specified, returns the rest of the string after the start position.")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
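Editor's note: besides the docs, this hunk adds an arity guard to `coerce_types` so a bad argument count fails at planning time instead of indexing `arg_types[0]` out of bounds. For readers, a plain-Rust sketch of the 1-based character semantics in the doc example (the real kernel works on Arrow arrays and also handles negative start positions, which this sketch ignores):

```rust
// substr(str, start_pos, length): start_pos counts characters from 1.
fn substr(s: &str, start_pos: i64, length: usize) -> String {
    let skip = (start_pos - 1).max(0) as usize;
    s.chars().skip(skip).take(length).collect()
}

fn main() {
    assert_eq!(substr("datafusion", 5, 3), "fus"); // matches the SQL example
}
```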
diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs
index 6591ee26403a..436d554a49f7 100644
--- a/datafusion/functions/src/unicode/substrindex.rs
+++ b/datafusion/functions/src/unicode/substrindex.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use arrow::array::{
     ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
@@ -24,11 +24,13 @@
 };
 use arrow::datatypes::{DataType, Int32Type, Int64Type};
 
+use crate::utils::{make_scalar_function, utf8_to_str_type};
 use datafusion_common::{exec_err, Result};
+use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
 use datafusion_expr::TypeSignature::Exact;
-use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
-
-use crate::utils::{make_scalar_function, utf8_to_str_type};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
+};
 
 #[derive(Debug)]
 pub struct SubstrIndexFunc {
@@ -83,6 +85,42 @@ impl ScalarUDFImpl for SubstrIndexFunc {
     fn aliases(&self) -> &[String] {
         &self.aliases
     }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        Some(get_substr_index_doc())
+    }
+}
+
+static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
+
+fn get_substr_index_doc() -> &'static Documentation {
+    DOCUMENTATION.get_or_init(|| {
+        Documentation::builder()
+            .with_doc_section(DOC_SECTION_STRING)
+            .with_description(r#"Returns the substring from str before count occurrences of the delimiter delim.
+If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
+If count is negative, everything to the right of the final delimiter (counting from the right) is returned."#)
+            .with_syntax_example("substr_index(str, delim, count)")
+            .with_sql_example(r#"```sql
+> select substr_index('www.apache.org', '.', 1);
++---------------------------------------------------------+
+| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) |
++---------------------------------------------------------+
+| www |
++---------------------------------------------------------+
+> select substr_index('www.apache.org', '.', -1);
++----------------------------------------------------------+
+| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) |
++----------------------------------------------------------+
+| org |
++----------------------------------------------------------+
+```"#)
+            .with_standard_argument("str", "String")
+            .with_argument("delim", "The string to find in str to split str.")
+            .with_argument("count", "The number of times to search for the delimiter. Can be either a positive or negative number.")
+            .build()
+            .unwrap()
+    })
 }
 
 /// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, @@ -27,8 +27,11 @@ use unicode_segmentation::UnicodeSegmentation; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct TranslateFunc { @@ -76,6 +79,34 @@ impl ScalarUDFImpl for TranslateFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { make_scalar_function(invoke_translate, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_translate_doc()) + } +} + +static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); + +fn get_translate_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Translates characters in a string to specified translation characters.") + .with_syntax_example("translate(str, chars, translation)") + .with_sql_example(r#"```sql +> select translate('twice', 'wic', 'her'); ++--------------------------------------------------+ +| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) | ++--------------------------------------------------+ +| there | ++--------------------------------------------------+ +```"#) + .with_standard_argument("str", "String") + .with_argument("chars", "Characters to translate.") + .with_argument("translation", "Translation characters. Translation characters replace only characters at the same position in the **chars** string.") + .build() + .unwrap() + }) } fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> { diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 337a24ffae20..79a5bb24e918 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -35,10 +35,6 @@ workspace = true name = "datafusion_optimizer" path = "src/lib.rs" -[features] -default = ["regex_expressions"] -regex_expressions = ["datafusion-physical-expr/regex_expressions"] - [dependencies] arrow = { workspace = true } async-trait = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4dc34284c719..e5d280289342 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -456,7 +456,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { self.schema, &func, )?; - let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &func)?; Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf(func, new_expr), ))) @@ -756,30 +755,6 @@ fn coerce_arguments_for_signature_with_aggregate_udf( .collect() } -fn coerce_arguments_for_fun( - expressions: Vec<Expr>, - schema: &DFSchema, - fun: &Arc<ScalarUDF>, -) -> Result<Vec<Expr>> { - // Cast Fixedsizelist to List for array functions - if fun.name() == "make_array" { - expressions - .into_iter() - .map(|expr| { - let data_type = expr.get_type(schema).unwrap(); - if let DataType::FixedSizeList(field, _) = data_type { - let to_type = DataType::List(Arc::clone(&field)); - expr.cast_to(&to_type, schema) - } else { - Ok(expr) - } - }) - .collect() - } else { - Ok(expressions) - } -} - fn
coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Expr> { // Given expressions like: // diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index a78a54a57123..e720188ae364 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -838,22 +838,18 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if expr_contains(&right, &left, Or) => Transformed::yes(*right), - // A OR (A AND B) --> A (if B not null) + // A OR (A AND B) --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => { - Transformed::yes(*left) - } - // (A AND B) OR A --> A (if B not null) + }) if is_op_with(And, &right, &left) => Transformed::yes(*left), + // (A AND B) OR A --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => { - Transformed::yes(*right) - } + }) if is_op_with(And, &left, &right) => Transformed::yes(*right), // // Rules for AND @@ -911,22 +907,18 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: And, right, }) if expr_contains(&right, &left, And) => Transformed::yes(*right), - // A AND (A OR B) --> A (if B not null) + // A AND (A OR B) --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => { - Transformed::yes(*left) - } - // (A OR B) AND A --> A (if B not null) + }) if is_op_with(Or, &right, &left) => Transformed::yes(*left), + // (A OR B) AND A --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&left)?
&& is_op_with(Or, &left, &right) => { - Transformed::yes(*right) - } + }) if is_op_with(Or, &left, &right) => Transformed::yes(*right), // // Rules for Multiply @@ -1789,6 +1781,8 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> { #[cfg(test)] mod tests { + use crate::simplify_expressions::SimplifyContext; + use crate::test::test_table_scan_with_name; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ function::{ @@ -1799,15 +1793,13 @@ mod tests { *, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; + use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, sync::Arc, }; - use crate::simplify_expressions::SimplifyContext; - use crate::test::test_table_scan_with_name; - use super::*; // ------------------------------ @@ -2609,15 +2601,11 @@ mod tests { // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) let expr = or(l.clone(), r.clone()); - // no rewrites if c1 can be null - let expected = expr.clone(); + let expected = l.clone(); assert_eq!(simplify(expr), expected); // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) - let expr = or(l, r); - - // no rewrites if c1 can be null - let expected = expr.clone(); + let expr = or(r, l); assert_eq!(simplify(expr), expected); } @@ -2648,13 +2636,11 @@ mod tests { // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5 let expr = and(l.clone(), r.clone()); - // no rewrites if c1 can be null - let expected = expr.clone(); + let expected = l.clone(); assert_eq!(simplify(expr), expected); // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5 - let expr = and(l, r); - let expected = expr.clone(); + let expr = and(r, l); assert_eq!(simplify(expr), expected); } @@ -3223,7 +3209,7 @@ mod tests { )], Some(Box::new(col("c2").eq(lit(true)))), )))), - col("c2").or(col("c2").not().and(col("c2"))) // #1716 + col("c2") ); // CASE WHEN ISNULL(c2) THEN true ELSE c2 @@ -3910,7 +3896,10 @@ mod tests { } } - fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result<Box<dyn PartitionEvaluator>> { unimplemented!("not needed for tests") } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 1c22c2a4375a..74251e5caad2 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -355,7 +355,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -373,7 +373,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -392,7 +392,7 @@ mod tests { .build()?; // Should not be optimized - let expected =
"Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 22e3c0ddd076..31e21d08b569 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -146,7 +146,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }; is_supported_type(&left_type) && is_supported_type(&right_type) - && op.is_comparison_operator() + && op.supports_propagation() } => { match (left.as_mut(), right.as_mut()) { diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index f320ebcc06b5..03ac4769d9d9 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -104,8 +104,9 @@ impl ArrowBytesSet { /// `Binary`, and `LargeBinary`) values that can produce the set of keys on /// output as `GenericBinaryArray` without copies. /// -/// Equivalent to `HashSet` but with better performance for arrow -/// data. +/// Equivalent to `HashSet` but with better performance if you need +/// to emit the keys as an Arrow `StringArray` / `BinaryArray`. For other +/// purposes it is the same as a `HashMap` /// /// # Generic Arguments /// diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index bdcf7bbacc69..c6768a19d30e 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -88,8 +88,9 @@ impl ArrowBytesViewSet { /// values that can produce the set of keys on /// output as `GenericBinaryViewArray` without copies. /// -/// Equivalent to `HashSet` but with better performance for arrow -/// data. +/// Equivalent to `HashSet` but with better performance if you need +/// to emit the keys as an Arrow `StringViewArray` / `BinaryViewArray`. 
For other +/// purposes it is the same as a `HashMap` /// /// # Generic Arguments /// diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 704cb291335f..6c4bf156ce56 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -120,6 +120,13 @@ impl PhysicalSortExpr { } } +/// Access the PhysicalSortExpr as a PhysicalExpr +impl AsRef<dyn PhysicalExpr> for PhysicalSortExpr { + fn as_ref(&self) -> &(dyn PhysicalExpr + 'static) { + self.expr.as_ref() + } +} + impl PartialEq for PhysicalSortExpr { fn eq(&self, other: &PhysicalSortExpr) -> bool { self.options == other.options && self.expr.eq(&other.expr) diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index c53f7a6c4771..4195e684381f 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -35,14 +35,6 @@ workspace = true name = "datafusion_physical_expr" path = "src/lib.rs" -[features] -default = [ - "regex_expressions", - "encoding_expressions", -] -encoding_expressions = ["base64", "hex"] -regex_expressions = ["regex"] - [dependencies] ahash = { workspace = true } arrow = { workspace = true } @@ -51,23 +43,19 @@ arrow-buffer = { workspace = true } arrow-ord = { workspace = true } arrow-schema = { workspace = true } arrow-string = { workspace = true } -base64 = { version = "0.22", optional = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } -datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } -hex = { version = "0.4", optional = true } indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" petgraph = "0.6.2" -regex = { workspace = true, optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 00708b4540aa..c1851ddb22b5 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -30,7 +30,6 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; -#[derive(Debug, Clone)] /// A structure representing an expression known to be constant in a physical execution plan. /// /// The `ConstExpr` struct encapsulates an expression that is constant during the execution /// /// - `expr`: Constant expression for a node in the physical plan. /// -/// - `across_partitions`: A boolean flag indicating whether the constant expression is -/// valid across partitions. If set to `true`, the constant expression has same value for all partitions. -/// If set to `false`, the constant expression may have different values for different partitions. +/// - `across_partitions`: A boolean flag indicating whether the constant +/// expression is the same across partitions. If set to `true`, the constant +/// expression has the same value for all partitions.
If set to `false`, the +/// constant expression may have different values for different partitions. /// /// # Example /// @@ -56,11 +56,22 @@ use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; /// // create a constant expression from a physical expression /// let const_expr = ConstExpr::from(col); /// ``` +#[derive(Debug, Clone)] pub struct ConstExpr { + /// The expression that is known to be constant (e.g. a `Column`) expr: Arc<dyn PhysicalExpr>, + /// Does the constant have the same value across all partitions? See + /// struct docs for more details across_partitions: bool, } +impl PartialEq for ConstExpr { + fn eq(&self, other: &Self) -> bool { + self.across_partitions == other.across_partitions + && self.expr.eq(other.expr.as_any()) + } +} + impl ConstExpr { /// Create a new constant expression from a physical expression. /// @@ -74,11 +85,17 @@ impl ConstExpr { } } + /// Set the `across_partitions` flag + /// + /// See struct docs for more details pub fn with_across_partitions(mut self, across_partitions: bool) -> Self { self.across_partitions = across_partitions; self } + /// Is the expression the same across all partitions? + /// + /// See struct docs for more details pub fn across_partitions(&self) -> bool { self.across_partitions } @@ -101,6 +118,31 @@ impl ConstExpr { across_partitions: self.across_partitions, }) } + + /// Returns true if this constant expression is equal to the given expression + pub fn eq_expr(&self, other: impl AsRef<dyn PhysicalExpr>) -> bool { + self.expr.eq(other.as_ref().as_any()) + } + + /// Returns a [`Display`]able list of `ConstExpr`. + pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ { + struct DisplayableList<'a>(&'a [ConstExpr]); + impl<'a> Display for DisplayableList<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let mut first = true; + for const_expr in self.0 { + if first { + first = false; + } else { + write!(f, ",")?; + } + write!(f, "{}", const_expr)?; + } + Ok(()) + } + } + DisplayableList(input) + } } /// Display implementation for `ConstExpr` diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 558565e523de..a0cc29685f77 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::fmt; use std::fmt::Display; use std::hash::{Hash, Hasher}; +use std::iter::Peekable; +use std::slice::Iter; use std::sync::Arc; use super::ordering::collapse_lex_ordering; @@ -279,6 +282,12 @@ impl EquivalenceProperties { self.with_constants(constants) } + /// Remove the specified constant + pub fn remove_constant(mut self, c: &ConstExpr) -> Self { + self.constants.retain(|existing| existing != c); + self + } + /// Track/register physical expressions with constant values.
pub fn with_constants( mut self, @@ -701,7 +710,7 @@ impl EquivalenceProperties { /// c ASC: Node {None, HashSet{a ASC}} /// ``` fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { - let mut dependency_map = IndexMap::new(); + let mut dependency_map = DependencyMap::new(); for ordering in self.normalized_oeq_class().iter() { for (idx, sort_expr) in ordering.iter().enumerate() { let target_sort_expr = @@ -723,13 +732,11 @@ impl EquivalenceProperties { let dependency = idx.checked_sub(1).map(|a| &ordering[a]); // Add sort expressions that can be projected or referred to // by any of the projection expressions to the dependency map: - dependency_map - .entry(sort_expr.clone()) - .or_insert_with(|| DependencyNode { - target_sort_expr: target_sort_expr.clone(), - dependencies: IndexSet::new(), - }) - .insert_dependency(dependency); + dependency_map.insert( + sort_expr, + target_sort_expr.as_ref(), + dependency, + ); } if !is_projected { // If we can not project, stop constructing the dependency @@ -1120,15 +1127,7 @@ impl Display for EquivalenceProperties { write!(f, ", eq: {}", self.eq_group)?; } if !self.constants.is_empty() { - write!(f, ", const: [")?; - let mut iter = self.constants.iter(); - if let Some(c) = iter.next() { - write!(f, "{}", c)?; - } - for c in iter { - write!(f, ", {}", c)?; - } - write!(f, "]")?; + write!(f, ", const: [{}]", ConstExpr::format_list(&self.constants))?; } Ok(()) } @@ -1257,7 +1256,7 @@ fn referred_dependencies( // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: let mut expr_to_sort_exprs = IndexMap::<ExprWrapper, Dependencies>::new(); for sort_expr in dependency_map - .keys() + .sort_exprs() .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) { let key = ExprWrapper(Arc::clone(&sort_expr.expr)); @@ -1270,10 +1269,16 @@ fn referred_dependencies( // Generate all valid dependencies for the source. For example, if the source // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. - expr_to_sort_exprs - .values() + let dependencies = expr_to_sort_exprs + .into_values() + .map(Dependencies::into_inner) + .collect::<Vec<_>>(); + dependencies + .iter() .multi_cartesian_product() - .map(|referred_deps| referred_deps.into_iter().cloned().collect()) + .map(|referred_deps| { + Dependencies::new_from_iter(referred_deps.into_iter().cloned()) + }) .collect() } @@ -1295,21 +1300,32 @@ fn construct_prefix_orderings( relevant_sort_expr: &PhysicalSortExpr, dependency_map: &DependencyMap, ) -> Vec<LexOrdering> { - dependency_map[relevant_sort_expr] + let mut dep_enumerator = DependencyEnumerator::new(); + dependency_map + .get(relevant_sort_expr) + .expect("no relevant sort expr found") .dependencies .iter() - .flat_map(|dep| construct_orderings(dep, dependency_map)) + .flat_map(|dep| dep_enumerator.construct_orderings(dep, dependency_map)) .collect() } -/// Given a set of relevant dependencies (`relevant_deps`) and a map of dependencies -/// (`dependency_map`), this function generates all possible prefix orderings -/// based on the given dependencies. +/// Generates all possible orderings where dependencies are satisfied for the +/// current projection expression.
+/// +/// # Example +/// If `dependencies` is `a + b ASC` and the dependency map holds dependencies +/// * `a ASC` --> `[c ASC]` +/// * `b ASC` --> `[d DESC]`, +/// +/// This function generates these two sort orders +/// * `[c ASC, d DESC, a + b ASC]` +/// * `[d DESC, c ASC, a + b ASC]` /// /// # Parameters /// -/// * `dependencies` - A reference to the dependencies. -/// * `dependency_map` - A reference to the map of dependencies for expressions. +/// * `dependencies` - Set of relevant expressions. +/// * `dependency_map` - Map of dependencies for expressions that may appear in `dependencies` /// /// # Returns /// @@ -1335,11 +1351,6 @@ fn generate_dependency_orderings( return vec![vec![]]; } - // Generate all possible orderings where dependencies are satisfied for the - // current projection expression. For example, if expression is `a + b ASC`, - // and the dependency for `a ASC` is `[c ASC]`, the dependency for `b ASC` - // is `[d DESC]`, then we generate `[c ASC, d DESC, a + b ASC]` and - // `[d DESC, c ASC, a + b ASC]`. relevant_prefixes .into_iter() .multi_cartesian_product() @@ -1421,7 +1432,7 @@ struct DependencyNode { } impl DependencyNode { - // Insert dependency to the state (if exists). + /// Insert the dependency into the state (if it exists). fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { if let Some(dep) = dependency { self.dependencies.insert(dep.clone()); @@ -1429,46 +1440,229 @@ impl DependencyNode { } } -// Using `IndexMap` and `IndexSet` makes sure to generate consistent results across different executions for the same query. -// We could have used `HashSet`, `HashMap` in place of them without any loss of functionality. -// As an example, if existing orderings are `[a ASC, b ASC]`, `[c ASC]` for output ordering -// both `[a ASC, b ASC, c ASC]` and `[c ASC, a ASC, b ASC]` are valid (e.g. concatenated version of the alternative orderings). -// When using `HashSet`, `HashMap` it is not guaranteed to generate consistent result, among the possible 2 results in the example above. -type DependencyMap = IndexMap<PhysicalSortExpr, DependencyNode>; -type Dependencies = IndexSet<PhysicalSortExpr>; - -/// This function recursively analyzes the dependencies of the given sort -/// expression within the given dependency map to construct lexicographical -/// orderings that include the sort expression and its dependencies. +impl Display for DependencyNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(target) = &self.target_sort_expr { + write!(f, "(target: {}, ", target)?; + } else { + write!(f, "(")?; + } + write!(f, "dependencies: [{}])", self.dependencies) + } +} + +/// Maps an expression --> DependencyNode /// -/// # Parameters +/// # Debugging / displaying `DependencyMap` /// -/// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) -/// for which lexicographical orderings satisfying its dependencies are to be -/// constructed. -/// - `dependency_map`: A reference to the `DependencyMap` that contains -/// dependencies for different `PhysicalSortExpr`s. +/// This structure implements `Display` to assist debugging.
For example: /// -/// # Returns +/// ```text +/// DependencyMap: { +/// a@0 ASC --> (target: a@0 ASC, dependencies: [[]]) +/// b@1 ASC --> (target: b@1 ASC, dependencies: [[a@0 ASC, c@2 ASC]]) +/// c@2 ASC --> (target: c@2 ASC, dependencies: [[b@1 ASC, a@0 ASC]]) +/// d@3 ASC --> (target: d@3 ASC, dependencies: [[c@2 ASC, b@1 ASC]]) +/// } +/// ``` /// -/// A vector of lexicographical orderings (`Vec<LexOrdering>`) based on the given -/// sort expression and its dependencies. -fn construct_orderings( - referred_sort_expr: &PhysicalSortExpr, - dependency_map: &DependencyMap, -) -> Vec<LexOrdering> { - // We are sure that `referred_sort_expr` is inside `dependency_map`. - let node = &dependency_map[referred_sort_expr]; - // Since we work on intermediate nodes, we are sure `val.target_sort_expr` - // exists. - let target_sort_expr = node.target_sort_expr.clone().unwrap(); - if node.dependencies.is_empty() { - vec![vec![target_sort_expr]] - } else { +/// # Note on IndexMap Rationale +/// +/// We use `IndexMap` (which preserves insert order) to ensure consistent results +/// across different executions for the same query. We could have used +/// `HashSet` and `HashMap` in place of them without any loss of functionality. +/// +/// As an example, if existing orderings are +/// 1. `[a ASC, b ASC]` +/// 2. `[c ASC]` +/// +/// Then both the following output orderings are valid +/// 1. `[a ASC, b ASC, c ASC]` +/// 2. `[c ASC, a ASC, b ASC]` +/// +/// (these are both valid as they are concatenated versions of the alternative +/// orderings). When using `HashSet` or `HashMap`, either of the 2 possible +/// results in the example above could be generated, with no consistency guarantee. +#[derive(Debug)] +struct DependencyMap { + inner: IndexMap<PhysicalSortExpr, DependencyNode>, +} + +impl DependencyMap { + fn new() -> Self { + Self { + inner: IndexMap::new(), + } + } + + /// Insert a new dependency `sort_expr` --> `dependency` into the map. + /// + /// If `target_sort_expr` is none, a new entry is created with empty dependencies. + fn insert( + &mut self, + sort_expr: &PhysicalSortExpr, + target_sort_expr: Option<&PhysicalSortExpr>, + dependency: Option<&PhysicalSortExpr>, + ) { + self.inner + .entry(sort_expr.clone()) + .or_insert_with(|| DependencyNode { + target_sort_expr: target_sort_expr.cloned(), + dependencies: Dependencies::new(), + }) + .insert_dependency(dependency) + } + + /// Iterator over (sort_expr, DependencyNode) pairs + fn iter(&self) -> impl Iterator<Item = (&PhysicalSortExpr, &DependencyNode)> { + self.inner.iter() + } + + /// Iterator over all sort exprs + fn sort_exprs(&self) -> impl Iterator<Item = &PhysicalSortExpr> { + self.inner.keys() + } + + /// Return the dependency node for the given sort expression, if any + fn get(&self, sort_expr: &PhysicalSortExpr) -> Option<&DependencyNode> { + self.inner.get(sort_expr) + } +} + +impl Display for DependencyMap { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "DependencyMap: {{")?; + for (sort_expr, node) in self.inner.iter() { + writeln!(f, " {sort_expr} --> {node}")?; + } + writeln!(f, "}}") + } +} + +/// A list of sort expressions that can be calculated from a known set of +/// dependencies. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +struct Dependencies { + inner: IndexSet<PhysicalSortExpr>, +} + +impl Display for Dependencies { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + let mut iter = self.inner.iter(); + if let Some(dep) = iter.next() { + write!(f, "{}", dep)?; + } + for dep in iter { + write!(f, ", {}", dep)?; + } + write!(f, "]") + } +} + +impl Dependencies { + /// Create a new empty `Dependencies` instance.
+ fn new() -> Self { + Self { + inner: IndexSet::new(), + } + } + + /// Create a new `Dependencies` from an iterator of `PhysicalSortExpr`. + fn new_from_iter(iter: impl IntoIterator<Item = PhysicalSortExpr>) -> Self { + Self { + inner: iter.into_iter().collect(), + } + } + + /// Insert a new dependency into the set. + fn insert(&mut self, sort_expr: PhysicalSortExpr) { + self.inner.insert(sort_expr); + } + + /// Iterator over dependencies in the set + fn iter(&self) -> impl Iterator<Item = &PhysicalSortExpr> + Clone { + self.inner.iter() + } + + /// Return the inner set of dependencies + fn into_inner(self) -> IndexSet<PhysicalSortExpr> { + self.inner + } + + /// Returns true if there are no dependencies + fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} + +/// Contains a mapping of all dependencies we have processed for each sort expr +struct DependencyEnumerator<'a> { + /// Maps `expr` --> `[exprs]` that have previously been processed + seen: IndexMap<&'a PhysicalSortExpr, IndexSet<&'a PhysicalSortExpr>>, +} + +impl<'a> DependencyEnumerator<'a> { + fn new() -> Self { + Self { + seen: IndexMap::new(), + } + } + + /// Insert a new dependency. + /// + /// Returns `false` if the dependency was already in the map, and `true` + /// if it was newly inserted. + fn insert( + &mut self, + target: &'a PhysicalSortExpr, + dep: &'a PhysicalSortExpr, + ) -> bool { + self.seen.entry(target).or_default().insert(dep) + } + + /// This function recursively analyzes the dependencies of the given sort + /// expression within the given dependency map to construct lexicographical + /// orderings that include the sort expression and its dependencies. + /// + /// # Parameters + /// + /// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) + /// for which lexicographical orderings satisfying its dependencies are to be + /// constructed. + /// - `dependency_map`: A reference to the `DependencyMap` that contains + /// dependencies for different `PhysicalSortExpr`s. + /// + /// # Returns + /// + /// A vector of lexicographical orderings (`Vec<LexOrdering>`) based on the given + /// sort expression and its dependencies. + fn construct_orderings( + &mut self, + referred_sort_expr: &'a PhysicalSortExpr, + dependency_map: &'a DependencyMap, + ) -> Vec<LexOrdering> { + let node = dependency_map + .get(referred_sort_expr) + .expect("`referred_sort_expr` should be inside `dependency_map`"); + // Since we work on intermediate nodes, we are sure `val.target_sort_expr` + // exists. + let target_sort_expr = node.target_sort_expr.as_ref().unwrap(); + // An empty dependency means the referred_sort_expr represents a global ordering. + // Return its projected version, which is the target_expression. + if node.dependencies.is_empty() { + return vec![vec![target_sort_expr.clone()]]; + }; + node.dependencies .iter() .flat_map(|dep| { - let mut orderings = construct_orderings(dep, dependency_map); + let mut orderings = if self.insert(target_sort_expr, dep) { + self.construct_orderings(dep, dependency_map) + } else { + vec![] + }; + for ordering in orderings.iter_mut() { ordering.push(target_sort_expr.clone()) } @@ -1611,58 +1805,62 @@ impl Hash for ExprWrapper { /// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` /// of `lhs` and `rhs` according to the schema of `lhs`. +/// +/// Rules: The UnionExec does not interleave its inputs: instead it passes each +/// input partition from the children as its own output.
+/// +/// Since the output equivalence properties are properties that are true for +/// *all* output partitions, that is the same as being true for all *input* +/// partitions fn calculate_union_binary( - lhs: EquivalenceProperties, + mut lhs: EquivalenceProperties, mut rhs: EquivalenceProperties, ) -> Result { - // TODO: In some cases, we should be able to preserve some equivalence - // classes. Add support for such cases. - // Harmonize the schema of the rhs with the schema of the lhs (which is the accumulator schema): if !rhs.schema.eq(&lhs.schema) { rhs = rhs.with_new_schema(Arc::clone(&lhs.schema))?; } - // First, calculate valid constants for the union. A quantity is constant - // after the union if it is constant in both sides. - let constants = lhs + // First, calculate valid constants for the union. An expression is constant + // at the output of the union if it is constant in both sides. + let constants: Vec<_> = lhs .constants() .iter() .filter(|const_expr| const_exprs_contains(rhs.constants(), const_expr.expr())) .map(|const_expr| { - // TODO: When both sides' constants are valid across partitions, - // the union's constant should also be valid if values are - // the same. However, we do not have the capability to - // check this yet. + // TODO: When both sides have a constant column, and the actual + // constant value is the same, then the output properties could + // reflect the constant is valid across all partitions. However we + // don't track the actual value that the ConstExpr takes on, so we + // can't determine that yet ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false) }) .collect(); + // remove any constants that are shared in both outputs (avoid double counting them) + for c in &constants { + lhs = lhs.remove_constant(c); + rhs = rhs.remove_constant(c); + } + // Next, calculate valid orderings for the union by searching for prefixes // in both sides. 
- let mut orderings = vec![]; - for mut ordering in lhs.normalized_oeq_class().orderings { - // Progressively shorten the ordering to search for a satisfied prefix: - while !rhs.ordering_satisfy(&ordering) { - ordering.pop(); - } - // There is a non-trivial satisfied prefix, add it as a valid ordering: - if !ordering.is_empty() { - orderings.push(ordering); - } - } - for mut ordering in rhs.normalized_oeq_class().orderings { - // Progressively shorten the ordering to search for a satisfied prefix: - while !lhs.ordering_satisfy(&ordering) { - ordering.pop(); - } - // There is a non-trivial satisfied prefix, add it as a valid ordering: - if !ordering.is_empty() { - orderings.push(ordering); - } - } - let mut eq_properties = EquivalenceProperties::new(lhs.schema); - eq_properties.constants = constants; + let mut orderings = UnionEquivalentOrderingBuilder::new(); + orderings.add_satisfied_orderings( + lhs.normalized_oeq_class().orderings, + lhs.constants(), + &rhs, + ); + orderings.add_satisfied_orderings( + rhs.normalized_oeq_class().orderings, + rhs.constants(), + &lhs, + ); + let orderings = orderings.build(); + + let mut eq_properties = + EquivalenceProperties::new(lhs.schema).with_constants(constants); + eq_properties.add_new_orderings(orderings); Ok(eq_properties) } @@ -1695,6 +1893,206 @@ pub fn calculate_union( Ok(acc) } +#[derive(Debug)] +enum AddedOrdering { + /// The ordering was added to the in progress result + Yes, + /// The ordering was not added + No(LexOrdering), +} + +/// Builds valid output orderings of a `UnionExec` +#[derive(Debug)] +struct UnionEquivalentOrderingBuilder { + orderings: Vec<LexOrdering>, +} + +impl UnionEquivalentOrderingBuilder { + fn new() -> Self { + Self { orderings: vec![] } + } + + /// Add all orderings from `orderings` that satisfy `properties`, + /// potentially augmented with `constants`. + /// + /// Note: any column that is known to be constant can be inserted into the + /// ordering without changing its meaning + /// + /// For example: + /// * `orderings` contains `[a ASC, c ASC]` and `constants` contains `b` + /// * `properties` has required ordering `[a ASC, b ASC]` + /// + /// Then this will add `[a ASC, b ASC]` to the `orderings` list (as `a` was + /// in the sort order and `b` was a constant). + fn add_satisfied_orderings( + &mut self, + orderings: impl IntoIterator<Item = LexOrdering>, + constants: &[ConstExpr], + properties: &EquivalenceProperties, + ) { + for mut ordering in orderings.into_iter() { + // Progressively shorten the ordering to search for a satisfied prefix: + loop { + match self.try_add_ordering(ordering, constants, properties) { + AddedOrdering::Yes => break, + AddedOrdering::No(o) => { + ordering = o; + ordering.pop(); + } + } + } + } + } + + /// Adds `ordering`, potentially augmented with constants, if it satisfies + /// the target `properties`. + /// + /// Returns + /// + /// * [`AddedOrdering::Yes`] if the ordering was added (either directly or + /// augmented), or was empty. + /// + /// * [`AddedOrdering::No`] if the ordering was not added + fn try_add_ordering( + &mut self, + ordering: LexOrdering, + constants: &[ConstExpr], + properties: &EquivalenceProperties, + ) -> AddedOrdering { + if ordering.is_empty() { + AddedOrdering::Yes + } else if constants.is_empty() && properties.ordering_satisfy(&ordering) { + // If the ordering satisfies the target properties, no need to + // augment it with constants.
+ self.orderings.push(ordering); + AddedOrdering::Yes + } else { + // Did not satisfy target properties, try and augment with constants + // to match the properties + if self.try_find_augmented_ordering(&ordering, constants, properties) { + AddedOrdering::Yes + } else { + AddedOrdering::No(ordering) + } + } + } + + /// Attempts to add `constants` to `ordering` to satisfy the properties. + /// + /// Returns `true` if any orderings were added, `false` otherwise + fn try_find_augmented_ordering( + &mut self, + ordering: &LexOrdering, + constants: &[ConstExpr], + properties: &EquivalenceProperties, + ) -> bool { + // can't augment if there is nothing to augment with + if constants.is_empty() { + return false; + } + let start_num_orderings = self.orderings.len(); + + // for each equivalent ordering in properties, try and augment + // `ordering` with the constants to match + for existing_ordering in &properties.oeq_class.orderings { + if let Some(augmented_ordering) = self.augment_ordering( + ordering, + constants, + existing_ordering, + &properties.constants, + ) { + if !augmented_ordering.is_empty() { + assert!(properties.ordering_satisfy(&augmented_ordering)); + self.orderings.push(augmented_ordering); + } + } + } + + self.orderings.len() > start_num_orderings + } + + /// Attempts to augment the ordering with constants to match the + /// `existing_ordering` + /// + /// Returns Some(ordering) if an augmented ordering was found, None otherwise + fn augment_ordering( + &mut self, + ordering: &LexOrdering, + constants: &[ConstExpr], + existing_ordering: &LexOrdering, + existing_constants: &[ConstExpr], + ) -> Option<LexOrdering> { + let mut augmented_ordering = vec![]; + let mut sort_expr_iter = ordering.iter().peekable(); + let mut existing_sort_expr_iter = existing_ordering.iter().peekable(); + + // walk in parallel down the two orderings, trying to match them up + while sort_expr_iter.peek().is_some() || existing_sort_expr_iter.peek().is_some() + { + // If the next expressions are equal, add the next match + // otherwise try and match with a constant + if let Some(expr) = + advance_if_match(&mut sort_expr_iter, &mut existing_sort_expr_iter) + { + augmented_ordering.push(expr); + } else if let Some(expr) = + advance_if_matches_constant(&mut sort_expr_iter, existing_constants) + { + augmented_ordering.push(expr); + } else if let Some(expr) = + advance_if_matches_constant(&mut existing_sort_expr_iter, constants) + { + augmented_ordering.push(expr); + } else { + // no match, can't continue the ordering, return what we have + break; + } + } + + Some(augmented_ordering) + } + + fn build(self) -> Vec<LexOrdering> { + self.orderings + } +} + +/// Advances two iterators in parallel +/// +/// If the next expressions are equal, the iterators are advanced and the +/// matched expression is returned.
+/// +/// Otherwise, the iterators are left unchanged and return `None` +fn advance_if_match( + iter1: &mut Peekable<Iter<PhysicalSortExpr>>, + iter2: &mut Peekable<Iter<PhysicalSortExpr>>, +) -> Option<PhysicalSortExpr> { + if matches!((iter1.peek(), iter2.peek()), (Some(expr1), Some(expr2)) if expr1.eq(expr2)) + { + iter1.next().unwrap(); + iter2.next().cloned() + } else { + None + } +} + +/// Advances the iterator with a constant +/// +/// If the next expression matches one of the constants, advances the iterator +/// returning the matched expression +/// +/// Otherwise, the iterator is left unchanged and returns `None` +fn advance_if_matches_constant( + iter: &mut Peekable<Iter<PhysicalSortExpr>>, + constants: &[ConstExpr], +) -> Option<PhysicalSortExpr> { + let expr = iter.peek()?; + let const_expr = constants.iter().find(|c| c.eq_expr(expr))?; + let found_expr = PhysicalSortExpr::new(Arc::clone(const_expr.expr()), expr.options); + iter.next(); + Some(found_expr) +} + #[cfg(test)] mod tests { use std::ops::Not; @@ -1760,6 +2158,51 @@ mod tests { Ok(()) } + #[test] + fn project_equivalence_properties_test_multi() -> Result<()> { + // test multiple input orderings with equivalence properties + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("d", DataType::Int64, true), + ])); + + let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); + // add equivalent ordering [a, b, c, d] + input_properties.add_new_ordering(vec![ + parse_sort_expr("a", &input_schema), + parse_sort_expr("b", &input_schema), + parse_sort_expr("c", &input_schema), + parse_sort_expr("d", &input_schema), + ]); + + // add equivalent ordering [a, c, b, d] + input_properties.add_new_ordering(vec![ + parse_sort_expr("a", &input_schema), + parse_sort_expr("c", &input_schema), + parse_sort_expr("b", &input_schema), // NB b and c are swapped + parse_sort_expr("d", &input_schema), + ]); + + // simply project all the columns in order + let proj_exprs = vec![ + (col("a", &input_schema)?, "a".to_string()), + (col("b", &input_schema)?, "b".to_string()), + (col("c", &input_schema)?, "c".to_string()), + (col("d", &input_schema)?, "d".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let out_properties = input_properties.project(&projection_mapping, input_schema); + + assert_eq!( + out_properties.to_string(), + "order: [[a@0 ASC,c@2 ASC,b@1 ASC,d@3 ASC], [a@0 ASC,b@1 ASC,c@2 ASC,d@3 ASC]]" + ); + + Ok(()) + } + #[test] fn test_join_equivalence_properties() -> Result<()> { let schema = create_test_schema()?; @@ -2714,7 +3157,7 @@ mod tests { } #[test] - fn test_union_equivalence_properties_constants_1() { + fn test_union_equivalence_properties_constants_common_constants() { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -2738,10 +3181,9 @@ mod tests { } #[test] - fn test_union_equivalence_properties_constants_2() { + fn test_union_equivalence_properties_constants_prefix() { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) - // Meet ordering between [a ASC], [a ASC, b ASC] should be [a ASC] .with_child_sort_and_const_exprs( // First child: [a ASC], const [] vec![vec!["a"]], vec![], &schema, ) @@ -2763,10 +3205,9 @@ mod tests { } #[test] - fn test_union_equivalence_properties_constants_3() { + fn test_union_equivalence_properties_constants_asc_desc_mismatch() { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) - // Meet
ordering between [a ASC], [a DESC] should be [] .with_child_sort_and_const_exprs( // First child: [a ASC], const [] vec![vec!["a"]], @@ -2788,7 +3229,7 @@ mod tests { } #[test] - fn test_union_equivalence_properties_constants_4() { + fn test_union_equivalence_properties_constants_different_schemas() { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); UnionEquivalenceTest::new(&schema) @@ -2805,11 +3246,10 @@ mod tests { &schema2, ) .with_expected_sort_and_const_exprs( - // Union orderings: - // should be [a ASC] + // Union orderings: [a ASC] // - // Where a, and a1 ath the same index for their corresponding - // schemas. + // Note that a, and a1 are at the same index for their + // corresponding schemas. vec![vec!["a"]], vec![], ) @@ -2817,9 +3257,7 @@ mod tests { } #[test] - #[ignore] - // ignored due to https://github.com/apache/datafusion/issues/12446 - fn test_union_equivalence_properties_constants() { + fn test_union_equivalence_properties_constants_fill_gaps() { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -2846,13 +3284,58 @@ mod tests { } #[test] - #[ignore] - // ignored due to https://github.com/apache/datafusion/issues/12446 - fn test_union_equivalence_properties_constants_desc() { + fn test_union_equivalence_properties_constants_no_fill_gaps() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child orderings: [a ASC, c ASC], const [d] // some other constant + vec![vec!["a", "c"]], + vec!["d"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [b ASC, c ASC], const [a] + vec![vec!["b", "c"]], + vec!["a"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [[a]] (only a is constant) + vec![vec!["a"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_fill_some_gaps() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child orderings: [c ASC], const [a, b] // some other constant + vec![vec!["c"]], + vec!["a", "b"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [a DESC, b], const [] + vec![vec!["a DESC", "b"]], + vec![], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [[a, b]] (can fill in the a/b with constants) + vec![vec!["a DESC", "b"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( - // NB `b DESC` in the second child // First child orderings: [a ASC, c ASC], const [b] vec![vec!["a", "c"]], vec!["b"], @@ -2876,9 +3359,7 @@ mod tests { } #[test] - #[ignore] - // ignored due to https://github.com/apache/datafusion/issues/12446 - fn test_union_equivalence_properties_constants_middle() { + fn test_union_equivalence_properties_constants_gap_fill_symmetric() { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -2895,8 +3376,8 @@ mod tests { ) .with_expected_sort_and_const_exprs( // Union orderings: - // [a, b, d] (c constant) - // [a, c, d] (b constant) + // [a, b, c, d] + // [a, c, b, d] vec![vec!["a", "c", "b", "d"], vec!["a", "b", "c", "d"]], vec![], ) @@ -2904,8 +3385,31 @@ mod tests { } #[test] - #[ignore] - // 
ignored due to https://github.com/apache/datafusion/issues/12446 + fn test_union_equivalence_properties_constants_gap_fill_and_common() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a DESC, d ASC], const [b, c] + vec![vec!["a DESC", "d"]], + vec!["b", "c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a DESC, c ASC, d ASC], const [b] + vec![vec!["a DESC", "c", "d"]], + vec!["b"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: + // [a DESC, c, d] [b] + vec![vec!["a DESC", "c", "d"]], + vec!["b"], + ) + .run() + } + + #[test] fn test_union_equivalence_properties_constants_middle_desc() { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) @@ -3016,18 +3520,32 @@ mod tests { child_properties, expected_properties, } = self; + let expected_properties = expected_properties.expect("expected_properties not set"); - let actual_properties = - calculate_union(child_properties, Arc::clone(&output_schema)) - .expect("failed to calculate union equivalence properties"); - assert_eq_properties_same( - &actual_properties, - &expected_properties, - format!( - "expected: {expected_properties:?}\nactual: {actual_properties:?}" - ), - ); + + // try all permutations of the children + // as the code treats lhs and rhs differently + for child_properties in child_properties + .iter() + .cloned() + .permutations(child_properties.len()) + { + println!("--- permutation ---"); + for c in &child_properties { + println!("{c}"); + } + let actual_properties = + calculate_union(child_properties, Arc::clone(&output_schema)) + .expect("failed to calculate union equivalence properties"); + assert_eq_properties_same( + &actual_properties, + &expected_properties, + format!( + "expected: {expected_properties:?}\nactual: {actual_properties:?}" + ), + ); + } } /// Make equivalence properties for the specified columns named in orderings and constants diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 236b24dd4094..47b04a876b37 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -27,9 +27,7 @@ use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; use arrow::compute::kernels::cmp::*; -use arrow::compute::kernels::comparison::{ - regexp_is_match_utf8, regexp_is_match_utf8_scalar, -}; +use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar}; use arrow::compute::kernels::concat_elements::concat_elements_utf8; use arrow::compute::{cast, ilike, like, nilike, nlike}; use arrow::datatypes::*; @@ -179,7 +177,7 @@ macro_rules! compute_utf8_flag_op { } else { None }; - let mut array = paste::expr! {[<$OP _utf8>]}(&ll, &rr, flag.as_ref())?; + let mut array = $OP(ll, rr, flag.as_ref())?; if $NOT { array = not(&array).unwrap(); } @@ -188,7 +186,9 @@ macro_rules! compute_utf8_flag_op { } macro_rules! 
binary_string_array_flag_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ + ($LEFT:ident, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ + // This macro is slightly different from binary_string_array_flag_op because, when comparing with a scalar value, + // the query can be optimized in such a way that operands will be dicts, so we need to support it here let result: Result<Arc<dyn Array>> = match $LEFT.data_type() { DataType::Utf8View | DataType::Utf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG) @@ -196,6 +196,27 @@ macro_rules! binary_string_array_flag_op_scalar { DataType::LargeUtf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) }, + DataType::Dictionary(_, _) => { + let values = $LEFT.as_any_dictionary().values(); + + match values.data_type() { + DataType::Utf8View | DataType::Utf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, StringArray, $NOT, $FLAG), + DataType::LargeUtf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG), + other => internal_err!( + "Data type {:?} not supported as a dictionary value type for binary_string_array_flag_op_scalar operation '{}' on string array", + other, stringify!($OP) + ), + }.map( + // downcast_dictionary_array duplicates code per possible key type, so we aim to do all prep work before + |evaluated_values| downcast_dictionary_array! { + $LEFT => { + let unpacked_dict = evaluated_values.take_iter($LEFT.keys().iter().map(|opt| opt.map(|v| v as _))).collect::<BooleanArray>(); + Arc::new(unpacked_dict) as _ + }, + _ => unreachable!(), + } + ) + }, other => internal_err!( "Data type {:?} not supported for binary_string_array_flag_op_scalar operation '{}' on string array", other, stringify!($OP) @@ -213,20 +234,32 @@ macro_rules! compute_utf8_flag_op_scalar { .downcast_ref::<$ARRAYTYPE>() .expect("compute_utf8_flag_op_scalar failed to downcast array"); - if let ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) = $RIGHT { - let flag = $FLAG.then_some("i"); - let mut array = - paste::expr! {[<$OP _utf8_scalar>]}(&ll, &string_value, flag)?; - if $NOT { - array = not(&array).unwrap(); - } - Ok(Arc::new(array)) - } else { - internal_err!( + let string_value = match $RIGHT { + ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value, + ScalarValue::Dictionary(_, value) => { + match *value { + ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value, + other => return internal_err!( + "compute_utf8_flag_op_scalar failed to cast dictionary value {} for operation '{}'", + other, stringify!($OP) + ) + } + }, + _ => return internal_err!( + "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'", + $RIGHT, stringify!($OP) ) + + }; + + let flag = $FLAG.then_some("i"); + let mut array = + paste::expr! {[<$OP _scalar>]}(ll, &string_value, flag)?; + if $NOT { + array = not(&array).unwrap(); } + + Ok(Arc::new(array)) }}; } @@ -431,7 +464,7 @@ impl PhysicalExpr for BinaryExpr { // end-points of its children. Ok(Some(vec![])) } - } else if self.op.is_comparison_operator() { + } else if self.op.supports_propagation() { Ok( propagate_comparison(&self.op, interval, left_interval, right_interval)?
.map(|(left, right)| vec![left, right]), diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 58559352d44c..cbab7d0c9d1f 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -73,7 +73,7 @@ impl PhysicalExpr for IsNotNullExpr { let arg = self.arg.evaluate(batch)?; match arg { ColumnarValue::Array(array) => { - let is_not_null = super::is_null::compute_is_not_null(array)?; + let is_not_null = arrow::compute::is_not_null(&array)?; Ok(ColumnarValue::Array(Arc::new(is_not_null))) } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 3cdb49bcab42..1c8597d3fdea 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -20,14 +20,10 @@ use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use arrow::compute; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use arrow_array::{Array, ArrayRef, BooleanArray, Int8Array, UnionArray}; -use arrow_buffer::{BooleanBuffer, ScalarBuffer}; -use arrow_ord::cmp; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; @@ -77,9 +73,9 @@ impl PhysicalExpr for IsNullExpr { fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { let arg = self.arg.evaluate(batch)?; match arg { - ColumnarValue::Array(array) => { - Ok(ColumnarValue::Array(Arc::new(compute_is_null(array)?))) - } + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new( + arrow::compute::is_null(&array)?, + ))), ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( ScalarValue::Boolean(Some(scalar.is_null())), )), @@ -103,65 +99,6 @@ impl PhysicalExpr for IsNullExpr { } } -/// workaround , -/// this can be replaced with a direct call to `arrow::compute::is_null` once it's fixed. -pub(crate) fn compute_is_null(array: ArrayRef) -> Result<BooleanArray> { - if let Some(union_array) = array.as_any().downcast_ref::<UnionArray>() { - if let Some(offsets) = union_array.offsets() { - dense_union_is_null(union_array, offsets) - } else { - sparse_union_is_null(union_array) - } - } else { - compute::is_null(array.as_ref()).map_err(Into::into) - } -} - -/// workaround , -/// this can be replaced with a direct call to `arrow::compute::is_not_null` once it's fixed.
-pub(crate) fn compute_is_not_null(array: ArrayRef) -> Result<BooleanArray> { - if array.as_any().is::<UnionArray>() { - compute::not(&compute_is_null(array)?).map_err(Into::into) - } else { - compute::is_not_null(array.as_ref()).map_err(Into::into) - } -} - -fn dense_union_is_null( - union_array: &UnionArray, - offsets: &ScalarBuffer<i32>, -) -> Result<BooleanArray> { - let child_arrays = (0..union_array.type_names().len()) - .map(|type_id| { - compute::is_null(&union_array.child(type_id as i8)).map_err(Into::into) - }) - .collect::<Result<Vec<BooleanArray>>>()?; - - let buffer: BooleanBuffer = offsets - .iter() - .zip(union_array.type_ids()) - .map(|(offset, type_id)| child_arrays[*type_id as usize].value(*offset as usize)) - .collect(); - - Ok(BooleanArray::new(buffer, None)) -} - -fn sparse_union_is_null(union_array: &UnionArray) -> Result<BooleanArray> { - let type_ids = Int8Array::new(union_array.type_ids().clone(), None); - - let mut union_is_null = - BooleanArray::new(BooleanBuffer::new_unset(union_array.len()), None); - for type_id in 0..union_array.type_names().len() { - let type_id = type_id as i8; - let union_is_child = cmp::eq(&type_ids, &Int8Array::new_scalar(type_id))?; - let child = union_array.child(type_id); - let child_array_is_null = compute::is_null(&child)?; - let child_is_null = compute::and(&union_is_child, &child_array_is_null)?; - union_is_null = compute::or(&union_is_null, &child_is_null)?; - } - Ok(union_is_null) -} - impl PartialEq for IsNullExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) @@ -184,7 +121,7 @@ mod tests { array::{BooleanArray, StringArray}, datatypes::*, }; - use arrow_array::{Float64Array, Int32Array}; + use arrow_array::{Array, Float64Array, Int32Array, UnionArray}; use arrow_buffer::ScalarBuffer; use datafusion_common::cast::as_boolean_array; @@ -243,8 +180,7 @@ mod tests { let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); - let array_ref = Arc::new(array) as ArrayRef; - let result = compute_is_null(array_ref).unwrap(); + let result = arrow::compute::is_null(&array).unwrap(); let expected = &BooleanArray::from(vec![false, true, false, false, true, true, false]); @@ -272,8 +208,7 @@ mod tests { let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children) .unwrap(); - let array_ref = Arc::new(array) as ArrayRef; - let result = compute_is_null(array_ref).unwrap(); + let result = arrow::compute::is_null(&array).unwrap(); let expected = &BooleanArray::from(vec![false, true, false, true, false, true]); assert_eq!(expected, &result); diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 177fd799ae79..e07e11e43199 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -39,7 +39,6 @@ pub use crate::window::cume_dist::{cume_dist, CumeDist}; pub use crate::window::lead_lag::{lag, lead, WindowShift}; pub use crate::window::nth_value::NthValue; pub use crate::window::ntile::Ntile; -pub use crate::window::rank::{dense_rank, percent_rank, rank, Rank, RankType}; pub use crate::PhysicalSortExpr; pub use binary::{binary, similar_to, BinaryExpr}; diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 2aeb05333102..938bdac50f97 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -22,7 +22,6 @@ pub(crate) mod cume_dist; pub(crate) mod lead_lag; pub(crate) mod nth_value; pub(crate) mod ntile; -pub(crate) mod rank; mod sliding_aggregate; mod
window_expr; diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs deleted file mode 100644 index fa3d4e487f14..000000000000 --- a/datafusion/physical-expr/src/window/rank.rs +++ /dev/null @@ -1,313 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expression for `rank`, `dense_rank`, and `percent_rank` that can evaluated -//! at runtime during query execution - -use crate::expressions::Column; -use crate::window::window_expr::RankState; -use crate::window::BuiltInWindowFunctionExpr; -use crate::{PhysicalExpr, PhysicalSortExpr}; - -use arrow::array::ArrayRef; -use arrow::array::{Float64Array, UInt64Array}; -use arrow::datatypes::{DataType, Field}; -use arrow_schema::{SchemaRef, SortOptions}; -use datafusion_common::utils::get_row_at_idx; -use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::PartitionEvaluator; - -use std::any::Any; -use std::iter; -use std::ops::Range; -use std::sync::Arc; - -/// Rank calculates the rank in the window function with order by -#[derive(Debug)] -pub struct Rank { - name: String, - rank_type: RankType, - /// Output data type - data_type: DataType, -} - -impl Rank { - /// Get rank_type of the rank in window function with order by - pub fn get_type(&self) -> RankType { - self.rank_type - } -} - -#[derive(Debug, Copy, Clone)] -pub enum RankType { - Basic, - Dense, - Percent, -} - -/// Create a rank window function -pub fn rank(name: String, data_type: &DataType) -> Rank { - Rank { - name, - rank_type: RankType::Basic, - data_type: data_type.clone(), - } -} - -/// Create a dense rank window function -pub fn dense_rank(name: String, data_type: &DataType) -> Rank { - Rank { - name, - rank_type: RankType::Dense, - data_type: data_type.clone(), - } -} - -/// Create a percent rank window function -pub fn percent_rank(name: String, data_type: &DataType) -> Rank { - Rank { - name, - rank_type: RankType::Percent, - data_type: data_type.clone(), - } -} - -impl BuiltInWindowFunctionExpr for Rank { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - Ok(Box::new(RankEvaluator { - state: RankState::default(), - rank_type: self.rank_type, - })) - } - - fn get_result_ordering(&self, schema: &SchemaRef) -> Option { - // The built-in RANK window function (in all modes) introduces a new ordering: - schema.column_with_name(self.name()).map(|(idx, field)| { - let expr = Arc::new(Column::new(field.name(), idx)); - let 
options = SortOptions { - descending: false, - nulls_first: false, - }; // ASC, NULLS LAST - PhysicalSortExpr { expr, options } - }) - } -} - -#[derive(Debug)] -pub(crate) struct RankEvaluator { - state: RankState, - rank_type: RankType, -} - -impl PartitionEvaluator for RankEvaluator { - fn is_causal(&self) -> bool { - matches!(self.rank_type, RankType::Basic | RankType::Dense) - } - - /// Evaluates the window function inside the given range. - fn evaluate( - &mut self, - values: &[ArrayRef], - range: &Range, - ) -> Result { - let row_idx = range.start; - // There is no argument, values are order by column values (where rank is calculated) - let range_columns = values; - let last_rank_data = get_row_at_idx(range_columns, row_idx)?; - let new_rank_encountered = - if let Some(state_last_rank_data) = &self.state.last_rank_data { - // if rank data changes, new rank is encountered - state_last_rank_data != &last_rank_data - } else { - // First rank seen - true - }; - if new_rank_encountered { - self.state.last_rank_data = Some(last_rank_data); - self.state.last_rank_boundary += self.state.current_group_count; - self.state.current_group_count = 1; - self.state.n_rank += 1; - } else { - // data is still in the same rank - self.state.current_group_count += 1; - } - match self.rank_type { - RankType::Basic => Ok(ScalarValue::UInt64(Some( - self.state.last_rank_boundary as u64 + 1, - ))), - RankType::Dense => Ok(ScalarValue::UInt64(Some(self.state.n_rank as u64))), - RankType::Percent => { - exec_err!("Can not execute PERCENT_RANK in a streaming fashion") - } - } - } - - fn evaluate_all_with_rank( - &self, - num_rows: usize, - ranks_in_partition: &[Range], - ) -> Result { - // see https://www.postgresql.org/docs/current/functions-window.html - let result: ArrayRef = match self.rank_type { - RankType::Dense => Arc::new(UInt64Array::from_iter_values( - ranks_in_partition - .iter() - .zip(1u64..) - .flat_map(|(range, rank)| { - let len = range.end - range.start; - iter::repeat(rank).take(len) - }), - )), - RankType::Percent => { - // Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1). The value thus ranges from 0 to 1 inclusive. 
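The arithmetic in the removed block below follows that definition: every peer group of equal order-by values shares the value `(rows before the group) / max(num_rows - 1, 1)`. A standalone sketch of the same computation (`percent_ranks` is a hypothetical helper name, not DataFusion API):

```rust
use std::ops::Range;

// value = (rank - 1) / (num_rows - 1), guarded for single-row partitions.
fn percent_ranks(num_rows: usize, ranks: &[Range<usize>]) -> Vec<f64> {
    let denominator = ((num_rows as f64) - 1.0).max(1.0);
    let mut rows_before = 0u64; // rows preceding the current peer group
    let mut out = Vec::with_capacity(num_rows);
    for range in ranks {
        let len = range.end - range.start;
        out.extend(std::iter::repeat(rows_before as f64 / denominator).take(len));
        rows_before += len as u64;
    }
    out
}

fn main() {
    // Peer groups of sizes 3 and 4 over 7 rows, as in the deleted tests below:
    assert_eq!(
        percent_ranks(7, &[0..3, 3..7]),
        vec![0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5]
    );
}
```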
- let denominator = num_rows as f64; - Arc::new(Float64Array::from_iter_values( - ranks_in_partition - .iter() - .scan(0_u64, |acc, range| { - let len = range.end - range.start; - let value = (*acc as f64) / (denominator - 1.0).max(1.0); - let result = iter::repeat(value).take(len); - *acc += len as u64; - Some(result) - }) - .flatten(), - )) - } - RankType::Basic => Arc::new(UInt64Array::from_iter_values( - ranks_in_partition - .iter() - .scan(1_u64, |acc, range| { - let len = range.end - range.start; - let result = iter::repeat(*acc).take(len); - *acc += len as u64; - Some(result) - }) - .flatten(), - )), - }; - Ok(result) - } - - fn supports_bounded_execution(&self) -> bool { - matches!(self.rank_type, RankType::Basic | RankType::Dense) - } - - fn include_rank(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - use datafusion_common::cast::{as_float64_array, as_uint64_array}; - - fn test_with_rank(expr: &Rank, expected: Vec) -> Result<()> { - test_i32_result(expr, vec![0..2, 2..3, 3..6, 6..7, 7..8], expected) - } - - #[allow(clippy::single_range_in_vec_init)] - fn test_without_rank(expr: &Rank, expected: Vec) -> Result<()> { - test_i32_result(expr, vec![0..8], expected) - } - - fn test_f64_result( - expr: &Rank, - num_rows: usize, - ranks: Vec>, - expected: Vec, - ) -> Result<()> { - let result = expr - .create_evaluator()? - .evaluate_all_with_rank(num_rows, &ranks)?; - let result = as_float64_array(&result)?; - let result = result.values(); - assert_eq!(expected, *result); - Ok(()) - } - - fn test_i32_result( - expr: &Rank, - ranks: Vec>, - expected: Vec, - ) -> Result<()> { - let result = expr.create_evaluator()?.evaluate_all_with_rank(8, &ranks)?; - let result = as_uint64_array(&result)?; - let result = result.values(); - assert_eq!(expected, *result); - Ok(()) - } - - #[test] - fn test_dense_rank() -> Result<()> { - let r = dense_rank("arr".into(), &DataType::UInt64); - test_without_rank(&r, vec![1; 8])?; - test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?; - Ok(()) - } - - #[test] - fn test_rank() -> Result<()> { - let r = rank("arr".into(), &DataType::UInt64); - test_without_rank(&r, vec![1; 8])?; - test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?; - Ok(()) - } - - #[test] - #[allow(clippy::single_range_in_vec_init)] - fn test_percent_rank() -> Result<()> { - let r = percent_rank("arr".into(), &DataType::Float64); - - // empty case - let expected = vec![0.0; 0]; - test_f64_result(&r, 0, vec![0..0; 0], expected)?; - - // singleton case - let expected = vec![0.0]; - test_f64_result(&r, 1, vec![0..1], expected)?; - - // uniform case - let expected = vec![0.0; 7]; - test_f64_result(&r, 7, vec![0..7], expected)?; - - // non-trivial case - let expected = vec![0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5]; - test_f64_result(&r, 7, vec![0..3, 3..7], expected)?; - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 8f6f78df8cb8..46c46fab68c5 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -530,19 +530,6 @@ pub enum WindowFn { Aggregate(Box), } -/// State for the RANK(percent_rank, rank, dense_rank) built-in window function. 
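The `RankState` struct removed below carried the incremental bookkeeping that the streaming `evaluate` above relied on. A self-contained sketch of that logic, substituting a single `i64` order-by column for the `Vec<ScalarValue>` row data (field names mirror the deleted struct):

```rust
#[derive(Default)]
struct RankState {
    last_rank_data: Option<i64>, // last order-by value seen
    last_rank_boundary: usize,   // start index of the current peer group
    current_group_count: usize,  // rows in the current peer group
    n_rank: usize,               // number of distinct peer groups so far
}

fn main() {
    let order_by = [10i64, 10, 20, 30, 30];
    let mut state = RankState::default();
    let (mut rank, mut dense_rank) = (Vec::new(), Vec::new());
    for value in order_by {
        if state.last_rank_data != Some(value) {
            // New peer group: advance the boundary past the previous group.
            state.last_rank_data = Some(value);
            state.last_rank_boundary += state.current_group_count;
            state.current_group_count = 1;
            state.n_rank += 1;
        } else {
            state.current_group_count += 1;
        }
        rank.push(state.last_rank_boundary as u64 + 1); // RANK
        dense_rank.push(state.n_rank as u64); // DENSE_RANK
    }
    assert_eq!(rank, vec![1, 1, 3, 4, 4]);
    assert_eq!(dense_rank, vec![1, 1, 2, 3, 3]);
}
```

PERCENT_RANK admits no such incremental form, which is why the evaluator above rejected it in streaming mode: the denominator needs the total partition row count.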
-#[derive(Debug, Clone, Default)] -pub struct RankState { - /// The last values for rank as these values change, we increase n_rank - pub last_rank_data: Option>, - /// The index where last_rank_boundary is started - pub last_rank_boundary: usize, - /// Keep the number of entries in current rank - pub current_group_count: usize, - /// Rank number kept from the start - pub n_rank: usize, -} - /// Tag to differentiate special use cases of the NTH_VALUE built-in window function. #[derive(Debug, Copy, Clone)] pub enum NthValueKind { diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index acf3eee105d4..e7bf4a80fc45 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -32,9 +32,15 @@ rust-version = { workspace = true } workspace = true [dependencies] +arrow = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } +datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } itertools = { workspace = true } + +[dev-dependencies] +datafusion-functions-aggregate = { workspace = true } +tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index a11b498b955c..fd21362fd3eb 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -20,15 +20,15 @@ use std::sync::Arc; use datafusion_common::config::ConfigOptions; use datafusion_common::scalar::ScalarValue; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; use datafusion_physical_plan::{expressions, ExecutionPlan}; use crate::PhysicalOptimizerRule; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; /// Optimizer that uses available statistics for aggregate functions #[derive(Default, Debug)] @@ -146,3 +146,373 @@ fn take_optimizable_value_from_statistics( let value = agg_expr.fun().value_from_stats(statistics_args); value.map(|val| (val, agg_expr.name().to_string())) } + +#[cfg(test)] +mod tests { + use crate::aggregate_statistics::AggregateStatistics; + use crate::PhysicalOptimizerRule; + use datafusion_common::config::ConfigOptions; + use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; + use datafusion_execution::TaskContext; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_plan::aggregates::AggregateExec; + use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::udaf::AggregateFunctionExpr; + use datafusion_physical_plan::ExecutionPlan; + use std::sync::Arc; + + use datafusion_common::Result; + use datafusion_expr_common::operator::Operator; + + use 
datafusion_physical_plan::aggregates::PhysicalGroupBy; + use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion_physical_plan::common; + use datafusion_physical_plan::filter::FilterExec; + use datafusion_physical_plan::memory::MemoryExec; + + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::cast::as_int64_array; + use datafusion_physical_expr::expressions::{self, cast}; + use datafusion_physical_plan::aggregates::AggregateMode; + + /// Describe the type of aggregate being tested + pub enum TestAggregate { + /// Testing COUNT(*) type aggregates + CountStar, + + /// Testing for COUNT(column) aggregate + ColumnA(Arc), + } + + impl TestAggregate { + /// Create a new COUNT(*) aggregate + pub fn new_count_star() -> Self { + Self::CountStar + } + + /// Create a new COUNT(column) aggregate + pub fn new_count_column(schema: &Arc) -> Self { + Self::ColumnA(Arc::clone(schema)) + } + + /// Return appropriate expr depending if COUNT is for col or table (*) + pub fn count_expr(&self, schema: &Schema) -> AggregateFunctionExpr { + AggregateExprBuilder::new(count_udaf(), vec![self.column()]) + .schema(Arc::new(schema.clone())) + .alias(self.column_name()) + .build() + .unwrap() + } + + /// what argument would this aggregate need in the plan? + fn column(&self) -> Arc { + match self { + Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), + Self::ColumnA(s) => expressions::col("a", s).unwrap(), + } + } + + /// What name would this aggregate produce in a plan? + pub fn column_name(&self) -> &'static str { + match self { + Self::CountStar => "COUNT(*)", + Self::ColumnA(_) => "COUNT(a)", + } + } + + /// What is the expected count? + pub fn expected_count(&self) -> i64 { + match self { + TestAggregate::CountStar => 3, + TestAggregate::ColumnA(_) => 2, + } + } + } + + /// Mock data using a MemoryExec which has an exact count statistic + fn mock_data() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), + Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])), + ], + )?; + + Ok(Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?)) + } + + /// Checks that the count optimization was applied and we still get the right result + async fn assert_count_optim_success( + plan: AggregateExec, + agg: TestAggregate, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let plan: Arc = Arc::new(plan); + + let config = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::clone(&plan), &config)?; + + // A ProjectionExec is a sign that the count optimization was applied + assert!(optimized.as_any().is::()); + + // run both the optimized and nonoptimized plan + let optimized_result = + common::collect(optimized.execute(0, Arc::clone(&task_ctx))?).await?; + let nonoptimized_result = common::collect(plan.execute(0, task_ctx)?).await?; + assert_eq!(optimized_result.len(), nonoptimized_result.len()); + + // and validate the results are the same and expected + assert_eq!(optimized_result.len(), 1); + check_batch(optimized_result.into_iter().next().unwrap(), &agg); + // check the non optimized one too to ensure types and names remain the same + assert_eq!(nonoptimized_result.len(), 1); + 
check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg); + + Ok(()) + } + + fn check_batch(batch: RecordBatch, agg: &TestAggregate) { + let schema = batch.schema(); + let fields = schema.fields(); + assert_eq!(fields.len(), 1); + + let field = &fields[0]; + assert_eq!(field.name(), agg.column_name()); + assert_eq!(field.data_type(), &DataType::Int64); + // note that nullabiolity differs + + assert_eq!( + as_int64_array(batch.column(0)).unwrap().values(), + &[agg.expected_count()] + ); + } + + #[tokio::test] + async fn test_count_partial_direct_child() -> Result<()> { + // basic test case with the aggregation applied on a source with exact statistics + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + source, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_partial_with_nulls_direct_child() -> Result<()> { + // basic test case with the aggregation applied on a source with exact statistics + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + source, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_partial_indirect_child() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + source, + Arc::clone(&schema), + )?; + + // We introduce an intermediate optimization step between the partial and final aggregtator + let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(coalesce), + Arc::clone(&schema), + )?; + + assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + source, + Arc::clone(&schema), + )?; + + // We introduce an intermediate optimization step between the partial and final aggregtator + let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(coalesce), + Arc::clone(&schema), + )?; + + 
assert_count_optim_success(final_agg, agg).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_count_inexact_stat() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + // adding a filter makes the statistics inexact + let filter = Arc::new(FilterExec::try_new( + expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?, + source, + )?); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + filter, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + let conf = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; + + // check that the original ExecutionPlan was not replaced + assert!(optimized.as_any().is::()); + + Ok(()) + } + + #[tokio::test] + async fn test_count_with_nulls_inexact_stat() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); + + // adding a filter makes the statistics inexact + let filter = Arc::new(FilterExec::try_new( + expressions::binary( + expressions::col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?, + source, + )?); + + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + filter, + Arc::clone(&schema), + )?; + + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + vec![agg.count_expr(&schema)], + vec![None], + Arc::new(partial_agg), + Arc::clone(&schema), + )?; + + let conf = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; + + // check that the original ExecutionPlan was not replaced + assert!(optimized.as_any().is::()); + + Ok(()) + } +} diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index 804dd165d335..c8a28ed0ec0b 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -19,21 +19,17 @@ use std::sync::Arc; -use datafusion_physical_plan::aggregates::AggregateExec; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::filter::FilterExec; -use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::ExecutionPlan; - -use arrow_schema::DataType; +use crate::PhysicalOptimizerRule; +use arrow::datatypes::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::PhysicalSortExpr; - -use crate::PhysicalOptimizerRule; +use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::execution_plan::CardinalityEffect; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use 
datafusion_physical_plan::ExecutionPlan; use itertools::Itertools; /// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed @@ -48,12 +44,13 @@ impl TopKAggregation { fn transform_agg( aggr: &AggregateExec, - order: &PhysicalSortExpr, + order_by: &str, + order_desc: bool, limit: usize, ) -> Option> { // ensure the sort direction matches aggregate function let (field, desc) = aggr.get_minmax_desc()?; - if desc != order.options.descending { + if desc != order_desc { return None; } let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?; @@ -66,8 +63,7 @@ impl TopKAggregation { } // ensure the sort is on the same field as the aggregate output - let col = order.expr.as_any().downcast_ref::()?; - if col.name() != field.name() { + if order_by != field.name() { return None; } @@ -92,16 +88,11 @@ impl TopKAggregation { let child = children.into_iter().exactly_one().ok()?; let order = sort.properties().output_ordering()?; let order = order.iter().exactly_one().ok()?; + let order_desc = order.options.descending; + let order = order.expr.as_any().downcast_ref::()?; + let mut cur_col_name = order.name().to_string(); let limit = sort.fetch()?; - let is_cardinality_preserving = |plan: Arc| { - plan.as_any() - .downcast_ref::() - .is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - }; - let mut cardinality_preserved = true; let closure = |plan: Arc| { if !cardinality_preserved { @@ -109,14 +100,27 @@ impl TopKAggregation { } if let Some(aggr) = plan.as_any().downcast_ref::() { // either we run into an Aggregate and transform it - match Self::transform_agg(aggr, order, limit) { + match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) { None => cardinality_preserved = false, Some(plan) => return Ok(Transformed::yes(plan)), } + } else if let Some(proj) = plan.as_any().downcast_ref::() { + // track renames due to successive projections + for (src_expr, proj_name) in proj.expr() { + let Some(src_col) = src_expr.as_any().downcast_ref::() else { + continue; + }; + if *proj_name == cur_col_name { + cur_col_name = src_col.name().to_string(); + } + } } else { - // or we continue down whitelisted nodes of other types - if !is_cardinality_preserving(Arc::clone(&plan)) { - cardinality_preserved = false; + // or we continue down through types that don't reduce cardinality + match plan.cardinality_effect() { + CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {} + CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => { + cardinality_preserved = false; + } } } Ok(Transformed::no(plan)) diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index c3f1b7eb0e95..7fcd719539ec 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -51,7 +51,6 @@ datafusion-common = { workspace = true, default-features = true } datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-functions-aggregate = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } @@ -69,6 +68,7 @@ rand = { workspace = true } tokio = { workspace = true } [dev-dependencies] +datafusion-functions-aggregate = { workspace = true } rstest = { workspace = true } rstest_reuse = "0.7.0" tokio = { 
workspace = true, features = [ diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs index 15c93262968e..5d00f300e960 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -91,15 +91,17 @@ impl GroupColumn for PrimitiveGroupValueBuilder { fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { - // Perf: skip null check (by short circuit) if input is not ullable - let null_match = if NULLABLE { - self.nulls.is_null(lhs_row) == array.is_null(rhs_row) - } else { - true - }; + // Perf: skip null check (by short circuit) if input is not nullable + if NULLABLE { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + // Otherwise, we need to check their values + } - null_match - && self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) + self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) } fn append_val(&mut self, array: &ArrayRef, row: usize) { @@ -211,9 +213,14 @@ where where B: ByteArrayType, { - let arr = array.as_bytes::(); - self.nulls.is_null(lhs_row) == arr.is_null(rhs_row) - && self.value(lhs_row) == (arr.value(rhs_row).as_ref() as &[u8]) + let array = array.as_bytes::(); + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + // Otherwise, we need to check their values + self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8]) } /// return the current value of the specified row irrespective of null @@ -369,13 +376,31 @@ where } } +/// Determines if the nullability of the existing and new input array can be used +/// to short-circuit the comparison of the two values. +/// +/// Returns `Some(result)` if the result of the comparison can be determined +/// from the nullness of the two values, and `None` if the comparison must be +/// done on the values themselves. 
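To make the short-circuit concrete, here is a hedged sketch of how the `nulls_equal_to` helper defined immediately below composes with a value comparison; `equal_to_opt` is a hypothetical standalone function, not the builder method itself:

```rust
// Same shape as the helper below: equal iff both null, unequal if exactly
// one is null; `None` means the values themselves must be compared.
fn nulls_equal_to(lhs_null: bool, rhs_null: bool) -> Option<bool> {
    match (lhs_null, rhs_null) {
        (true, true) => Some(true),
        (true, false) | (false, true) => Some(false),
        (false, false) => None,
    }
}

fn equal_to_opt(lhs: Option<i64>, rhs: Option<i64>) -> bool {
    match nulls_equal_to(lhs.is_none(), rhs.is_none()) {
        Some(result) => result, // decided by nullness alone; values never read
        None => lhs.unwrap() == rhs.unwrap(),
    }
}

fn main() {
    assert!(equal_to_opt(None, None)); // both null: equal
    assert!(!equal_to_opt(Some(1), None)); // exactly one null: unequal
    assert!(equal_to_opt(Some(2), Some(2))); // neither null: compare values
}
```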
+fn nulls_equal_to(lhs_null: bool, rhs_null: bool) -> Option { + match (lhs_null, rhs_null) { + (true, true) => Some(true), + (false, true) | (true, false) => Some(false), + _ => None, + } +} + #[cfg(test)] mod tests { use std::sync::Arc; - use arrow_array::{ArrayRef, StringArray}; + use arrow::datatypes::Int64Type; + use arrow_array::{ArrayRef, Int64Array, StringArray}; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; use datafusion_physical_expr::binary_map::OutputType; + use crate::aggregates::group_values::group_column::PrimitiveGroupValueBuilder; + use super::{ByteGroupValueBuilder, GroupColumn}; #[test] @@ -422,4 +447,136 @@ mod tests { ])) as ArrayRef; assert_eq!(&output, &array); } + + #[test] + fn test_nullable_primitive_equal_to() { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = PrimitiveGroupValueBuilder::::new(); + let builder_array = Arc::new(Int64Array::from(vec![ + None, + None, + None, + Some(1), + Some(2), + Some(3), + ])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + builder.append_val(&builder_array, 2); + builder.append_val(&builder_array, 3); + builder.append_val(&builder_array, 4); + builder.append_val(&builder_array, 5); + + // Define input array + let (_nulls, values, _) = + Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some(2) to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef; + + // Check + assert!(!builder.equal_to(0, &input_array, 0)); + assert!(builder.equal_to(1, &input_array, 1)); + assert!(builder.equal_to(2, &input_array, 2)); + assert!(!builder.equal_to(3, &input_array, 3)); + assert!(!builder.equal_to(4, &input_array, 4)); + assert!(builder.equal_to(5, &input_array, 5)); + } + + #[test] + fn test_not_nullable_primitive_equal_to() { + // Will cover such cases: + // - values equal + // - values not equal + + // Define PrimitiveGroupValueBuilder + let mut builder = PrimitiveGroupValueBuilder::::new(); + let builder_array = + Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + + // Define input array + let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef; + + // Check + assert!(builder.equal_to(0, &input_array, 0)); + assert!(!builder.equal_to(1, &input_array, 1)); + } + + #[test] + fn test_byte_array_equal_to() { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define 
PrimitiveGroupValueBuilder + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + let builder_array = Arc::new(StringArray::from(vec![ + None, + None, + None, + Some("foo"), + Some("bar"), + Some("baz"), + ])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + builder.append_val(&builder_array, 2); + builder.append_val(&builder_array, 3); + builder.append_val(&builder_array, 4); + builder.append_val(&builder_array, 5); + + // Define input array + let (offsets, buffer, _nulls) = StringArray::from(vec![ + Some("foo"), + Some("bar"), + None, + None, + Some("foo"), + Some("baz"), + ]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some("bar") to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = + Arc::new(StringArray::new(offsets, buffer, Some(nulls))) as ArrayRef; + + // Check + assert!(!builder.equal_to(0, &input_array, 0)); + assert!(builder.equal_to(1, &input_array, 1)); + assert!(builder.equal_to(2, &input_array, 2)); + assert!(!builder.equal_to(3, &input_array, 3)); + assert!(!builder.equal_to(4, &input_array, 4)); + assert!(builder.equal_to(5, &input_array, 5)); + } } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 9466ff6dd459..d6f16fb0fdd3 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -36,10 +36,11 @@ use crate::{ use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_array::{UInt16Array, UInt32Array, UInt64Array, UInt8Array}; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_execution::TaskContext; -use datafusion_expr::Accumulator; +use datafusion_expr::{Accumulator, Aggregate}; use datafusion_physical_expr::{ equivalence::{collapse_lex_req, ProjectionMapping}, expressions::Column, @@ -47,6 +48,7 @@ use datafusion_physical_expr::{ PhysicalExpr, PhysicalSortRequirement, }; +use crate::execution_plan::CardinalityEffect; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use itertools::Itertools; @@ -211,13 +213,99 @@ impl PhysicalGroupBy { .collect() } + /// The number of expressions in the output schema. + fn num_output_exprs(&self) -> usize { + let mut num_exprs = self.expr.len(); + if !self.is_single() { + num_exprs += 1 + } + num_exprs + } + /// Return grouping expressions as they occur in the output schema. 
pub fn output_exprs(&self) -> Vec> { - self.expr - .iter() - .enumerate() - .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _) - .collect() + let num_output_exprs = self.num_output_exprs(); + let mut output_exprs = Vec::with_capacity(num_output_exprs); + output_exprs.extend( + self.expr + .iter() + .enumerate() + .take(num_output_exprs) + .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _), + ); + if !self.is_single() { + output_exprs.push(Arc::new(Column::new( + Aggregate::INTERNAL_GROUPING_ID, + self.expr.len(), + )) as _); + } + output_exprs + } + + /// Returns the number expression as grouping keys. + fn num_group_exprs(&self) -> usize { + if self.is_single() { + self.expr.len() + } else { + self.expr.len() + 1 + } + } + + /// Returns the fields that are used as the grouping keys. + fn group_fields(&self, input_schema: &Schema) -> Result> { + let mut fields = Vec::with_capacity(self.num_group_exprs()); + for ((expr, name), group_expr_nullable) in + self.expr.iter().zip(self.exprs_nullable().into_iter()) + { + fields.push( + Field::new( + name, + expr.data_type(input_schema)?, + group_expr_nullable || expr.nullable(input_schema)?, + ) + .with_metadata( + get_field_metadata(expr, input_schema).unwrap_or_default(), + ), + ); + } + if !self.is_single() { + fields.push(Field::new( + Aggregate::INTERNAL_GROUPING_ID, + Aggregate::grouping_id_type(self.expr.len()), + false, + )); + } + Ok(fields) + } + + /// Returns the output fields of the group by. + /// + /// This might be different from the `group_fields` that might contain internal expressions that + /// should not be part of the output schema. + fn output_fields(&self, input_schema: &Schema) -> Result> { + let mut fields = self.group_fields(input_schema)?; + fields.truncate(self.num_output_exprs()); + Ok(fields) + } + + /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial + /// aggregation. + pub fn as_final(&self) -> PhysicalGroupBy { + let expr: Vec<_> = + self.output_exprs() + .into_iter() + .zip( + self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once( + Aggregate::INTERNAL_GROUPING_ID.to_owned(), + )), + ) + .collect(); + let num_exprs = expr.len(); + Self { + expr, + null_expr: vec![], + groups: vec![vec![false; num_exprs]], + } } } @@ -321,13 +409,7 @@ impl AggregateExec { input: Arc, input_schema: SchemaRef, ) -> Result { - let schema = create_schema( - &input.schema(), - &group_by.expr, - &aggr_expr, - group_by.exprs_nullable(), - mode, - )?; + let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?; let schema = Arc::new(schema); AggregateExec::try_new_with_schema( @@ -785,29 +867,20 @@ impl ExecutionPlan for AggregateExec { } } } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual + } } fn create_schema( input_schema: &Schema, - group_expr: &[(Arc, String)], + group_by: &PhysicalGroupBy, aggr_expr: &[AggregateFunctionExpr], - group_expr_nullable: Vec, mode: AggregateMode, ) -> Result { - let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); - for (index, (expr, name)) in group_expr.iter().enumerate() { - fields.push( - Field::new( - name, - expr.data_type(input_schema)?, - // In cases where we have multiple grouping sets, we will use NULL expressions in - // order to align the grouping sets. So the field must be nullable even if the underlying - // schema field is not. 
- group_expr_nullable[index] || expr.nullable(input_schema)?, - ) - .with_metadata(get_field_metadata(expr, input_schema).unwrap_or_default()), - ) - } + let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); + fields.extend(group_by.output_fields(input_schema)?); match mode { AggregateMode::Partial => { @@ -833,9 +906,8 @@ fn create_schema( )) } -fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { - let group_fields = schema.fields()[0..group_count].to_vec(); - Arc::new(Schema::new(group_fields)) +fn group_schema(input_schema: &Schema, group_by: &PhysicalGroupBy) -> Result { + Ok(Arc::new(Schema::new(group_by.group_fields(input_schema)?))) } /// Determines the lexical ordering requirement for an aggregate expression. @@ -1142,6 +1214,27 @@ fn evaluate_optional( .collect() } +fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result { + if group.len() > 64 { + return not_impl_err!( + "Grouping sets with more than 64 columns are not supported" + ); + } + let group_id = group.iter().fold(0u64, |acc, &is_null| { + (acc << 1) | if is_null { 1 } else { 0 } + }); + let num_rows = batch.num_rows(); + if group.len() <= 8 { + Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows]))) + } else if group.len() <= 16 { + Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows]))) + } else if group.len() <= 32 { + Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows]))) + } else { + Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows]))) + } +} + /// Evaluate a group by expression against a `RecordBatch` /// /// Arguments: @@ -1174,23 +1267,24 @@ pub(crate) fn evaluate_group_by( }) .collect::>>()?; - Ok(group_by + group_by .groups .iter() .map(|group| { - group - .iter() - .enumerate() - .map(|(idx, is_null)| { - if *is_null { - Arc::clone(&null_exprs[idx]) - } else { - Arc::clone(&exprs[idx]) - } - }) - .collect() + let mut group_values = Vec::with_capacity(group_by.num_group_exprs()); + group_values.extend(group.iter().enumerate().map(|(idx, is_null)| { + if *is_null { + Arc::clone(&null_exprs[idx]) + } else { + Arc::clone(&exprs[idx]) + } + })); + if !group_by.is_single() { + group_values.push(group_id_array(group, batch)?); + } + Ok(group_values) }) - .collect()) + .collect() } #[cfg(test)] @@ -1348,21 +1442,21 @@ mod tests { ) -> Result<()> { let input_schema = input.schema(); - let grouping_set = PhysicalGroupBy { - expr: vec![ + let grouping_set = PhysicalGroupBy::new( + vec![ (col("a", &input_schema)?, "a".to_string()), (col("b", &input_schema)?, "b".to_string()), ], - null_expr: vec![ + vec![ (lit(ScalarValue::UInt32(None)), "a".to_string()), (lit(ScalarValue::Float64(None)), "b".to_string()), ], - groups: vec![ + vec![ vec![false, true], // (a, NULL) vec![true, false], // (NULL, b) vec![false, false], // (a,b) ], - }; + ); let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) .schema(Arc::clone(&input_schema)) @@ -1392,63 +1486,56 @@ mod tests { // In spill mode, we test with the limited memory, if the mem usage exceeds, // we trigger the early emit rule, which turns out the partial aggregate result. 
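The `__grouping_id` column in the expected output below follows the encoding of the `group_id_array` helper added above: one bit per grouping column, most significant bit first, with 1 marking a column that is NULL-ed out in that grouping set. A standalone sketch of the fold:

```rust
// Mirrors the fold in `group_id_array` above.
fn group_id(group: &[bool]) -> u64 {
    group
        .iter()
        .fold(0u64, |acc, &is_null| (acc << 1) | u64::from(is_null))
}

fn main() {
    // GROUPING SETS over columns [a, b], as in the test's grouping_set:
    assert_eq!(group_id(&[false, false]), 0); // (a, b)
    assert_eq!(group_id(&[false, true]), 1);  // (a, NULL)
    assert_eq!(group_id(&[true, false]), 2);  // (NULL, b)
}
```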
vec![ - "+---+-----+-----------------+", - "| a | b | COUNT(1)[count] |", - "+---+-----+-----------------+", - "| | 1.0 | 1 |", - "| | 1.0 | 1 |", - "| | 2.0 | 1 |", - "| | 2.0 | 1 |", - "| | 3.0 | 1 |", - "| | 3.0 | 1 |", - "| | 4.0 | 1 |", - "| | 4.0 | 1 |", - "| 2 | | 1 |", - "| 2 | | 1 |", - "| 2 | 1.0 | 1 |", - "| 2 | 1.0 | 1 |", - "| 3 | | 1 |", - "| 3 | | 2 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 1 |", - "| 4 | | 2 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+-----------------+", + "+---+-----+---------------+-----------------+", + "| a | b | __grouping_id | COUNT(1)[count] |", + "+---+-----+---------------+-----------------+", + "| | 1.0 | 2 | 1 |", + "| | 1.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 3 | | 1 | 1 |", + "| 3 | | 1 | 2 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 1 |", + "| 4 | | 1 | 2 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+-----------------+", ] } else { vec![ - "+---+-----+-----------------+", - "| a | b | COUNT(1)[count] |", - "+---+-----+-----------------+", - "| | 1.0 | 2 |", - "| | 2.0 | 2 |", - "| | 3.0 | 2 |", - "| | 4.0 | 2 |", - "| 2 | | 2 |", - "| 2 | 1.0 | 2 |", - "| 3 | | 3 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 3 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+-----------------+", + "+---+-----+---------------+-----------------+", + "| a | b | __grouping_id | COUNT(1)[count] |", + "+---+-----+---------------+-----------------+", + "| | 1.0 | 2 | 2 |", + "| | 2.0 | 2 | 2 |", + "| | 3.0 | 2 | 2 |", + "| | 4.0 | 2 | 2 |", + "| 2 | | 1 | 2 |", + "| 2 | 1.0 | 0 | 2 |", + "| 3 | | 1 | 3 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 3 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+-----------------+", ] }; assert_batches_sorted_eq!(expected, &result); - let groups = partial_aggregate.group_expr().expr().to_vec(); - let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); - let final_group: Vec<(Arc, String)> = groups - .iter() - .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) - .collect::>()?; - - let final_grouping_set = PhysicalGroupBy::new_single(final_group); + let final_grouping_set = grouping_set.as_final(); let task_ctx = if spill { new_spill_ctx(4, 3160) @@ -1468,26 +1555,26 @@ mod tests { let result = common::collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let batch = concat_batches(&result[0].schema(), &result)?; - assert_eq!(batch.num_columns(), 3); + assert_eq!(batch.num_columns(), 4); assert_eq!(batch.num_rows(), 12); let expected = vec![ - "+---+-----+----------+", - "| a | b | COUNT(1) |", - "+---+-----+----------+", - "| | 1.0 | 2 |", - "| | 2.0 | 2 |", - "| | 3.0 | 2 |", - "| | 4.0 | 2 |", - "| 2 | | 2 |", - "| 2 | 1.0 | 2 |", - "| 3 | | 3 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 3 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+----------+", + "+---+-----+---------------+----------+", + "| a | b | __grouping_id | COUNT(1) |", + "+---+-----+---------------+----------+", + "| | 1.0 | 2 | 2 |", + "| | 2.0 | 2 | 2 |", + "| | 3.0 | 2 | 2 |", + "| | 4.0 | 2 | 2 |", + "| 2 | | 1 | 2 |", + "| 2 | 1.0 | 0 | 2 |", + "| 3 | | 1 | 3 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", 
+ "| 4 | | 1 | 3 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+----------+", ]; assert_batches_sorted_eq!(&expected, &result); @@ -1503,11 +1590,11 @@ mod tests { async fn check_aggregates(input: Arc, spill: bool) -> Result<()> { let input_schema = input.schema(); - let grouping_set = PhysicalGroupBy { - expr: vec![(col("a", &input_schema)?, "a".to_string())], - null_expr: vec![], - groups: vec![vec![false]], - }; + let grouping_set = PhysicalGroupBy::new( + vec![(col("a", &input_schema)?, "a".to_string())], + vec![], + vec![vec![false]], + ); let aggregates: Vec = vec![ @@ -1563,13 +1650,7 @@ mod tests { let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); - let final_group: Vec<(Arc, String)> = grouping_set - .expr - .iter() - .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) - .collect::>()?; - - let final_grouping_set = PhysicalGroupBy::new_single(final_group); + let final_grouping_set = grouping_set.as_final(); let merged_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, @@ -1825,11 +1906,11 @@ mod tests { let task_ctx = Arc::new(task_ctx); let groups_none = PhysicalGroupBy::default(); - let groups_some = PhysicalGroupBy { - expr: vec![(col("a", &input_schema)?, "a".to_string())], - null_expr: vec![], - groups: vec![vec![false]], - }; + let groups_some = PhysicalGroupBy::new( + vec![(col("a", &input_schema)?, "a".to_string())], + vec![], + vec![vec![false]], + ); // something that allocates within the aggregator let aggregates_v0: Vec = @@ -2306,7 +2387,7 @@ mod tests { )?); let aggregate_exec = Arc::new(AggregateExec::try_new( - AggregateMode::Partial, + AggregateMode::Single, groups, aggregates.clone(), vec![None], @@ -2318,13 +2399,13 @@ mod tests { collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?; let expected = [ - "+-----+-----+-------+----------+", - "| a | b | const | 1[count] |", - "+-----+-----+-------+----------+", - "| | 0.0 | | 32768 |", - "| 0.0 | | | 32768 |", - "| | | 1 | 32768 |", - "+-----+-----+-------+----------+", + "+-----+-----+-------+---------------+-------+", + "| a | b | const | __grouping_id | 1 |", + "+-----+-----+-------+---------------+-------+", + "| | | 1 | 6 | 32768 |", + "| | 0.0 | | 5 | 32768 |", + "| 0.0 | | | 3 | 32768 |", + "+-----+-----+-------+---------------+-------+", ]; assert_batches_sorted_eq!(expected, &output); @@ -2638,30 +2719,30 @@ mod tests { .build()?, ]; - let grouping_set = PhysicalGroupBy { - expr: vec![ + let grouping_set = PhysicalGroupBy::new( + vec![ (col("a", &input_schema)?, "a".to_string()), (col("b", &input_schema)?, "b".to_string()), ], - null_expr: vec![ + vec![ (lit(ScalarValue::Float32(None)), "a".to_string()), (lit(ScalarValue::Float32(None)), "b".to_string()), ], - groups: vec![ + vec![ vec![false, true], // (a, NULL) vec![false, false], // (a,b) ], - }; + ); let aggr_schema = create_schema( &input_schema, - &grouping_set.expr, + &grouping_set, &aggr_expr, - grouping_set.exprs_nullable(), AggregateMode::Final, )?; let expected_schema = Schema::new(vec![ Field::new("a", DataType::Float32, false), Field::new("b", DataType::Float32, true), + Field::new("__grouping_id", DataType::UInt8, false), Field::new("COUNT(a)", DataType::Int64, false), ]); assert_eq!(aggr_schema, expected_schema); diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 9e4968f1123e..5121e6cc3b35 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ 
b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -449,13 +449,13 @@ impl GroupedHashAggregateStream { let aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &agg.mode, - agg_group_by.expr.len(), + agg_group_by.num_group_exprs(), )?; // arguments for aggregating spilled data is the same as the one for final aggregation let merging_aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &AggregateMode::Final, - agg_group_by.expr.len(), + agg_group_by.num_group_exprs(), )?; let filter_expressions = match agg.mode { @@ -473,7 +473,7 @@ impl GroupedHashAggregateStream { .map(create_group_accumulator) .collect::>()?; - let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); + let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?; let spill_expr = group_schema .fields .into_iter() diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 7caf5b8ab65a..e1a2f32d8a38 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -34,6 +34,7 @@ use datafusion_common::Result; use datafusion_execution::TaskContext; use crate::coalesce::{BatchCoalescer, CoalescerState}; +use crate::execution_plan::CardinalityEffect; use futures::ready; use futures::stream::{Stream, StreamExt}; @@ -199,6 +200,10 @@ impl ExecutionPlan for CoalesceBatchesExec { fn fetch(&self) -> Option { self.fetch } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } /// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details. diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 486ae41901db..2ab6e3de1add 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -30,6 +30,7 @@ use super::{ use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; +use crate::execution_plan::CardinalityEffect; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; @@ -178,6 +179,10 @@ impl ExecutionPlan for CoalescePartitionsExec { fn supports_limit_pushdown(&self) -> bool { true } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } #[cfg(test)] diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index b14021f4a99b..a89e265ad2f8 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -416,6 +416,11 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { fn fetch(&self) -> Option { None } + + /// Gets the effect on cardinality, if known + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Unknown + } } /// Extension trait provides an easy API to fetch various properties of @@ -898,6 +903,20 @@ pub fn get_plan_string(plan: &Arc) -> Vec { actual.iter().map(|elem| elem.to_string()).collect() } +/// Indicates the effect an execution plan operator will have on the cardinality +/// of its input stream +pub enum CardinalityEffect { + /// Unknown effect. 
This is the default + Unknown, + /// The operator is guaranteed to produce exactly one row for + /// each input row + Equal, + /// The operator may produce fewer output rows than it receives input rows + LowerEqual, + /// The operator may produce more output rows than it receives input rows + GreaterEqual, +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 417d2098b083..c39a91e251b7 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -48,6 +48,7 @@ use datafusion_physical_expr::{ analyze, split_conjunction, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr, }; +use crate::execution_plan::CardinalityEffect; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -372,6 +373,10 @@ impl ExecutionPlan for FilterExec { fn statistics(&self) -> Result { Self::statistics_helper(&self.input, self.predicate(), self.default_selectivity) } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual + } } /// This function ensures that all bounds in the `ExprBoundaries` vector are diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 48d648c89a35..74a45a7e4761 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -415,6 +415,12 @@ impl HashJoinExec { &self.join_type } + /// The schema after join. Please be careful when using this schema, + /// if there is a projection, the schema isn't the same as the output schema. + pub fn join_schema(&self) -> &SchemaRef { + &self.join_schema + } + /// The partitioning mode of this hash join pub fn partition_mode(&self) -> &PartitionMode { &self.mode diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 360e942226d2..a42e2da60587 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -34,6 +34,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; +use crate::execution_plan::CardinalityEffect; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -336,6 +337,10 @@ impl ExecutionPlan for LocalLimitExec { fn supports_limit_pushdown(&self) -> bool { true } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual + } } /// A Limit stream skips `skip` rows, and then fetch up to `fetch` rows. 
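Together with the `TopKAggregation` rewrite earlier in this diff, this hint replaces the old whitelist of operator types: a rule can now ask any node for its cardinality effect instead of downcasting. A minimal sketch of a consumer, assuming nothing beyond the enum as defined above (`limit_safe_to_push_through` is a hypothetical helper):

```rust
pub enum CardinalityEffect {
    Unknown,
    Equal,
    LowerEqual,
    GreaterEqual,
}

/// A limit hint may only travel through nodes that never drop rows,
/// mirroring the match in the rewritten `TopKAggregation` closure.
fn limit_safe_to_push_through(effect: &CardinalityEffect) -> bool {
    matches!(
        effect,
        CardinalityEffect::Equal | CardinalityEffect::GreaterEqual
    )
}

fn main() {
    assert!(limit_safe_to_push_through(&CardinalityEffect::Equal)); // e.g. ProjectionExec
    assert!(!limit_safe_to_push_through(&CardinalityEffect::LowerEqual)); // e.g. FilterExec
    assert!(!limit_safe_to_push_through(&CardinalityEffect::Unknown)); // conservative default
}
```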
diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 4c889d1fc88c..49bf05964226 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -42,6 +42,7 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Literal; +use crate::execution_plan::CardinalityEffect; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -233,6 +234,10 @@ impl ExecutionPlan for ProjectionExec { fn supports_limit_pushdown(&self) -> bool { true } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } /// If e is a direct column reference, returns the field level diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index f0f198319ee3..d9368cf86d45 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -48,6 +48,7 @@ use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; +use crate::execution_plan::CardinalityEffect; use futures::stream::Stream; use futures::{FutureExt, StreamExt, TryStreamExt}; use hashbrown::HashMap; @@ -669,6 +670,10 @@ impl ExecutionPlan for RepartitionExec { fn statistics(&self) -> Result { self.input.statistics() } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } impl RepartitionExec { diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 50f6f4a93097..5d86c2183b9e 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -55,6 +55,7 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::LexOrdering; use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement; +use crate::execution_plan::CardinalityEffect; use futures::{StreamExt, TryStreamExt}; use log::{debug, trace}; @@ -972,6 +973,14 @@ impl ExecutionPlan for SortExec { fn fetch(&self) -> Option { self.fetch } + + fn cardinality_effect(&self) -> CardinalityEffect { + if self.fetch.is_none() { + CardinalityEffect::Equal + } else { + CardinalityEffect::LowerEqual + } + } } #[cfg(test)] diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 6aafaad0ad77..6f7d95bf95f5 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -21,10 +21,7 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::{ - expressions::{ - cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, - PhysicalSortExpr, - }, + expressions::{cume_dist, lag, lead, Literal, NthValue, Ntile, PhysicalSortExpr}, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, }; @@ -52,6 +49,7 @@ mod window_agg_exec; pub use bounded_window_agg_exec::BoundedWindowAggExec; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr::expressions::Column; pub use datafusion_physical_expr::window::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowExpr, @@ -230,9 +228,6 @@ fn create_built_in_window_expr( let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type(); Ok(match fun { 
- BuiltInWindowFunction::Rank => Arc::new(rank(name, out_data_type)), - BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, out_data_type)), - BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, out_data_type)), BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)), BuiltInWindowFunction::Ntile => { let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { @@ -385,7 +380,13 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn create_evaluator(&self) -> Result> { - self.fun.partition_evaluator_factory() + self.fun + .partition_evaluator_factory(PartitionEvaluatorArgs::new( + &self.args, + &self.input_types, + self.is_reversed, + self.ignore_nulls, + )) } fn name(&self) -> &str { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e36c91e7d004..803cb49919ee 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -508,9 +508,9 @@ message ScalarUDFExprNode { enum BuiltInWindowFunction { UNSPECIFIED = 0; // https://protobuf.dev/programming-guides/dos-donts/#unspecified-enum // ROW_NUMBER = 0; - RANK = 1; - DENSE_RANK = 2; - PERCENT_RANK = 3; + // RANK = 1; + // DENSE_RANK = 2; + // PERCENT_RANK = 3; CUME_DIST = 4; NTILE = 5; LAG = 6; @@ -739,7 +739,7 @@ message FileSinkConfig { datafusion_common.Schema output_schema = 4; repeated PartitionColumn table_partition_cols = 5; bool keep_partition_by_columns = 9; - InsertOp insert_op = 10; + InsertOp insert_op = 10; } enum InsertOp { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 004798b3ba93..c7d4c4561a1b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1662,9 +1662,6 @@ impl serde::Serialize for BuiltInWindowFunction { { let variant = match self { Self::Unspecified => "UNSPECIFIED", - Self::Rank => "RANK", - Self::DenseRank => "DENSE_RANK", - Self::PercentRank => "PERCENT_RANK", Self::CumeDist => "CUME_DIST", Self::Ntile => "NTILE", Self::Lag => "LAG", @@ -1684,9 +1681,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { { const FIELDS: &[&str] = &[ "UNSPECIFIED", - "RANK", - "DENSE_RANK", - "PERCENT_RANK", "CUME_DIST", "NTILE", "LAG", @@ -1735,9 +1729,6 @@ impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { { match value { "UNSPECIFIED" => Ok(BuiltInWindowFunction::Unspecified), - "RANK" => Ok(BuiltInWindowFunction::Rank), - "DENSE_RANK" => Ok(BuiltInWindowFunction::DenseRank), - "PERCENT_RANK" => Ok(BuiltInWindowFunction::PercentRank), "CUME_DIST" => Ok(BuiltInWindowFunction::CumeDist), "NTILE" => Ok(BuiltInWindowFunction::Ntile), "LAG" => Ok(BuiltInWindowFunction::Lag), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 436347330d92..8cba6f84f7eb 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1819,9 +1819,9 @@ pub enum BuiltInWindowFunction { /// Unspecified = 0, /// ROW_NUMBER = 0; - Rank = 1, - DenseRank = 2, - PercentRank = 3, + /// RANK = 1; + /// DENSE_RANK = 2; + /// PERCENT_RANK = 3; CumeDist = 4, Ntile = 5, Lag = 6, @@ -1838,9 +1838,6 @@ impl BuiltInWindowFunction { pub fn as_str_name(&self) -> &'static str { match self { Self::Unspecified => "UNSPECIFIED", - Self::Rank => "RANK", - Self::DenseRank => "DENSE_RANK", - Self::PercentRank => "PERCENT_RANK", Self::CumeDist => "CUME_DIST", Self::Ntile => "NTILE", Self::Lag => "LAG", @@ -1854,9 +1851,6 @@ impl 
BuiltInWindowFunction { pub fn from_str_name(value: &str) -> ::core::option::Option { match value { "UNSPECIFIED" => Some(Self::Unspecified), - "RANK" => Some(Self::Rank), - "DENSE_RANK" => Some(Self::DenseRank), - "PERCENT_RANK" => Some(Self::PercentRank), "CUME_DIST" => Some(Self::CumeDist), "NTILE" => Some(Self::Ntile), "LAG" => Some(Self::Lag), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 893255ccc8ce..32e1b68203ce 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -142,9 +142,6 @@ impl From for BuiltInWindowFunction { fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self { match built_in_function { protobuf::BuiltInWindowFunction::Unspecified => todo!(), - protobuf::BuiltInWindowFunction::Rank => Self::Rank, - protobuf::BuiltInWindowFunction::PercentRank => Self::PercentRank, - protobuf::BuiltInWindowFunction::DenseRank => Self::DenseRank, protobuf::BuiltInWindowFunction::Lag => Self::Lag, protobuf::BuiltInWindowFunction::Lead => Self::Lead, protobuf::BuiltInWindowFunction::FirstValue => Self::FirstValue, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 63d1a007c1e5..07823b422f71 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -119,11 +119,8 @@ impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { BuiltInWindowFunction::NthValue => Self::NthValue, BuiltInWindowFunction::Ntile => Self::Ntile, BuiltInWindowFunction::CumeDist => Self::CumeDist, - BuiltInWindowFunction::PercentRank => Self::PercentRank, - BuiltInWindowFunction::Rank => Self::Rank, BuiltInWindowFunction::Lag => Self::Lag, BuiltInWindowFunction::Lead => Self::Lead, - BuiltInWindowFunction::DenseRank => Self::DenseRank, } } } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 6f6065a1c284..85d4fe8a16d0 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,8 +24,8 @@ use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, - IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, Ntile, Rank, RankType, - TryCastExpr, WindowShift, + IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, Ntile, TryCastExpr, + WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -109,14 +109,7 @@ pub fn serialize_physical_window_expr( let expr = built_in_window_expr.get_built_in_func_expr(); let built_in_fn_expr = expr.as_any(); - let builtin_fn = if let Some(rank_expr) = built_in_fn_expr.downcast_ref::() - { - match rank_expr.get_type() { - RankType::Basic => protobuf::BuiltInWindowFunction::Rank, - RankType::Dense => protobuf::BuiltInWindowFunction::DenseRank, - RankType::Percent => protobuf::BuiltInWindowFunction::PercentRank, - } - } else if built_in_fn_expr.downcast_ref::().is_some() { + let builtin_fn = if built_in_fn_expr.downcast_ref::().is_some() { protobuf::BuiltInWindowFunction::CumeDist } else if let Some(ntile_expr) = built_in_fn_expr.downcast_ref::() { args.insert( diff --git 
a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index cd789e06dc3b..8c9b368598de 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -47,6 +47,9 @@ use datafusion::functions_aggregate::expr_fn::{ }; use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_nested::map::map; +use datafusion::functions_window::dense_rank::dense_rank; +use datafusion::functions_window::percent_rank::percent_rank; +use datafusion::functions_window::rank::{rank, rank_udwf}; use datafusion::functions_window::row_number::row_number; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -75,6 +78,7 @@ use datafusion_functions_aggregate::expr_fn::{ }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, @@ -937,6 +941,9 @@ async fn roundtrip_expr_api() -> Result<()> { vec![lit(10), lit(20), lit(30)], ), row_number(), + rank(), + dense_rank(), + percent_rank(), nth_value(col("b"), 1, vec![]), nth_value( col("b"), @@ -2304,9 +2311,7 @@ fn roundtrip_window() { // 1. without window_frame let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - datafusion_expr::BuiltInWindowFunction::Rank, - ), + WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) .partition_by(vec![col("col1")]) @@ -2317,9 +2322,7 @@ fn roundtrip_window() { // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - datafusion_expr::BuiltInWindowFunction::Rank, - ), + WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) .partition_by(vec![col("col1")]) @@ -2336,9 +2339,7 @@ fn roundtrip_window() { ); let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - datafusion_expr::BuiltInWindowFunction::Rank, - ), + WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) .partition_by(vec![col("col1")]) @@ -2459,7 +2460,10 @@ fn roundtrip_window() { &self.signature } - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { make_partition_evaluator() } diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 4c380f0b37a3..c288d6ca7067 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -98,8 +98,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - // Each recursive CTE consists from two parts in the logical plan: - // 1. A static term (the left hand side on the SQL, where the + // Each recursive CTE consists of two parts in the logical plan: + // 1. A static term (the left-hand side on the SQL, where the // referencing to the same CTE is not allowed) // // 2. 
A recursive term (the right hand side, and the recursive diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 20a772cdd088..619eadcf0fb8 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -237,7 +237,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - // user-defined function (UDF) should have precedence + // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); @@ -260,12 +260,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ); } - // then, window function + // Then, window function if let Some(WindowType::WindowSpec(window)) = over { let partition_by = window .partition_by .into_iter() - // ignore window spec PARTITION BY for scalar values + // Ignore window spec PARTITION BY for scalar values // as they do not change and thus do not generate new partitions .filter(|e| !matches!(e, sqlparser::ast::Expr::Value { .. },)) .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) @@ -383,7 +383,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, name: &str, ) -> Result { - // check udaf first + // Check udaf first let udaf = self.context_provider.get_aggregate_meta(name); // Use the builtin window function instead of the user-defined aggregate function if udaf.as_ref().is_some_and(|udaf| { @@ -434,7 +434,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }), FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(object_name)) => { let qualifier = self.object_name_to_table_reference(object_name)?; - // sanity check on qualifier with schema + // Sanity check on qualifier with schema let qualified_indices = schema.fields_indices_with_qualified(&qualifier); if qualified_indices.is_empty() { return plan_err!("Invalid qualifier {qualifier}"); diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index b2c8312709ef..e103f68fc927 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -115,9 +115,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let search_result = search_dfschema(&ids, schema); match search_result { - // found matching field with spare identifier(s) for nested field(s) in structure + // Found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if !nested_names.is_empty() => { - // found matching field with spare identifier(s) for nested field(s) in structure + // Found matching field with spare identifier(s) for nested field(s) in structure for planner in self.context_provider.get_expr_planners() { if let Ok(planner_result) = planner.plan_compound_identifier( field, @@ -126,9 +126,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) { match planner_result { PlannerResult::Planned(expr) => { - // sanity check on column - schema - .check_ambiguous_name(qualifier, field.name())?; return Ok(expr); } PlannerResult::Original(_args) => {} @@ -137,23 +134,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } plan_err!("could not parse compound identifier from {ids:?}") } - // found matching field with no spare identifier(s) + // Found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { - // sanity check on column - schema.check_ambiguous_name(qualifier, field.name())?; 
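Note on the compound-identifier resolution above: the planner tries the most specific interpretation first, treating up to three leading identifiers as a table reference and anything after the column as nested field names (this is the contract of generate_schema_search_terms further down in this file). A self-contained sketch of that ordering; the function name search_terms and the toy return type are illustrative, not DataFusion's API:

// Toy model of compound-identifier search-term generation: try the longest
// possible qualifier first, leaving the remaining identifiers as nested names.
fn search_terms(ids: &[String]) -> Vec<(Vec<&String>, &String, &[String])> {
    // At most 4 identifiers can form (qualifier..., column); the rest are nested.
    let bound = ids.len().min(4);
    (0..bound)
        .rev()
        .map(|i| {
            let qualifier = ids[..i].iter().collect::<Vec<_>>();
            let column = &ids[i];
            let nested = &ids[i + 1..];
            (qualifier, column, nested)
        })
        .collect()
}

fn main() {
    let ids: Vec<String> = ["t", "c", "field"].iter().map(|s| s.to_string()).collect();
    for (qualifier, column, nested) in search_terms(&ids) {
        println!("qualifier={qualifier:?} column={column} nested={nested:?}");
    }
    // Prints, in order:
    // qualifier=["t", "c"] column=field nested=[]
    // qualifier=["t"] column=c nested=["field"]
    // qualifier=[] column=t nested=["c", "field"]
}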
Ok(Expr::Column(Column::from((qualifier, field)))) } None => { - // return default where use all identifiers to not have a nested field + // Return default where use all identifiers to not have a nested field // this len check is because at 5 identifiers will have to have a nested field if ids.len() == 5 { not_impl_err!("compound identifier: {ids:?}") } else { - // check the outer_query_schema and try to find a match + // Check the outer_query_schema and try to find a match if let Some(outer) = planner_context.outer_query_schema() { let search_result = search_dfschema(&ids, outer); match search_result { - // found matching field with spare identifier(s) for nested field(s) in structure + // Found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if !nested_names.is_empty() => { @@ -163,15 +158,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Column::from((qualifier, field)).quoted_flat_name() ) } - // found matching field with no spare identifier(s) + // Found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { - // found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column Ok(Expr::OuterReferenceColumn( field.data_type().clone(), Column::from((qualifier, field)), )) } - // found no matching field, will return a default + // Found no matching field, will return a default None => { let s = &ids[0..ids.len()]; // safe unwrap as s can never be empty or exceed the bounds @@ -182,11 +177,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } else { let s = &ids[0..ids.len()]; - // safe unwrap as s can never be empty or exceed the bounds + // Safe unwrap as s can never be empty or exceed the bounds let (relation, column_name) = form_identifier(s).unwrap(); - // sanity check on column - schema - .check_ambiguous_name(relation.as_ref(), column_name)?; Ok(Expr::Column(Column::new(relation, column_name))) } } @@ -319,15 +311,15 @@ fn search_dfschema<'ids, 'schema>( fn generate_schema_search_terms( ids: &[String], ) -> impl Iterator, &String, &[String])> { - // take at most 4 identifiers to form a Column to search with + // Take at most 4 identifiers to form a Column to search with // - 1 for the column name // - 0 to 3 for the TableReference let bound = ids.len().min(4); - // search terms from most specific to least specific + // Search terms from most specific to least specific (0..bound).rev().map(|i| { let nested_names_index = i + 1; let qualifier_and_column = &ids[0..nested_names_index]; - // safe unwrap as qualifier_and_column can never be empty or exceed the bounds + // Safe unwrap as qualifier_and_column can never be empty or exceed the bounds let (relation, column_name) = form_identifier(qualifier_and_column).unwrap(); (relation, column_name, &ids[nested_names_index..]) }) @@ -339,7 +331,7 @@ mod test { #[test] // testing according to documentation of generate_schema_search_terms function - // where ensure generated search terms are in correct order with correct values + // where it ensures generated search terms are in correct order with correct values fn test_generate_schema_search_terms() -> Result<()> { type ExpectedItem = ( Option, diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 6a3a4d6ccbb7..00289806876f 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ 
b/datafusion/sql/src/expr/order_by.rs @@ -102,7 +102,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr_vec.push(Sort::new( expr, asc, - // when asc is true, by default nulls last to be consistent with postgres + // When asc is true, by default nulls last to be consistent with postgres // postgres rule: https://www.postgresql.org/docs/current/queries-order.html nulls_first.unwrap_or(!asc), )) diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index 2a341fb7c446..3c547050380d 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -37,7 +37,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } UnaryOperator::Minus => { match expr { - // optimization: if it's a number literal, we apply the negative operator + // Optimization: if it's a number literal, we apply the negative operator // here directly to calculate the new literal. SQLExpr::Value(Value::Number(n, _)) => { self.parse_sql_number(&n, true) @@ -45,7 +45,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Interval(interval) => { self.sql_interval_to_expr(true, interval) } - // not a literal, apply negative operator on expression + // Not a literal, apply negative operator on expression _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr( expr, schema, diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index be0909b58468..7dc15de7ad71 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -235,7 +235,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let value = interval_literal(*interval.value, negative)?; // leading_field really means the unit if specified - // for example, "month" in `INTERVAL '5' month` + // For example, "month" in `INTERVAL '5' month` let value = match interval.leading_field.as_ref() { Some(leading_field) => format!("{value} {leading_field}"), None => value, @@ -323,9 +323,9 @@ const fn try_decode_hex_char(c: u8) -> Option { fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { // remove leading zeroes let trimmed = unsigned_number.trim_start_matches('0'); - // parse precision and scale, remove decimal point if exists + // Parse precision and scale, remove decimal point if exists let (precision, scale, replaced_str) = if trimmed == "." { - // special cases for numbers such as “0.”, “000.”, and so on. + // Special cases for numbers such as “0.”, “000.”, and so on. 
(1, 0, Cow::Borrowed("0")) } else if let Some(i) = trimmed.find('.') { ( @@ -334,7 +334,7 @@ fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { Cow::Owned(trimmed.replace('.', "")), ) } else { - // no decimal point, keep as is + // No decimal point, keep as is (trimmed.len(), 0, Cow::Borrowed(trimmed)) }; @@ -344,7 +344,7 @@ fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { ))) })?; - // check precision overflow + // Check precision overflow if precision as u8 > DECIMAL128_MAX_PRECISION { return Err(DataFusionError::from(ParserError(format!( "Cannot parse {replaced_str} as i128 when building decimal: precision overflow" diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 6d130647a49f..a68d8491856d 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -761,10 +761,10 @@ impl<'a> DFParser<'a> { // Note that mixing both names and definitions is not allowed let peeked = self.parser.peek_nth_token(2); if peeked == Token::Comma || peeked == Token::RParen { - // list of column names + // List of column names builder.table_partition_cols = Some(self.parse_partitions()?) } else { - // list of column defs + // List of column defs let (cols, cons) = self.parse_columns()?; builder.table_partition_cols = Some( cols.iter().map(|col| col.name.to_string()).collect(), @@ -850,7 +850,7 @@ impl<'a> DFParser<'a> { options.push((key, value)); let comma = self.parser.consume_token(&Token::Comma); if self.parser.consume_token(&Token::RParen) { - // allow a trailing comma, even though it's not in standard + // Allow a trailing comma, even though it's not in standard break; } else if !comma { return self.expected( diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index e8defedddf2c..66e360a9ade9 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -117,7 +117,7 @@ impl ValueNormalizer { /// CTEs, Views, subqueries and PREPARE statements. The states include /// Common Table Expression (CTE) provided with WITH clause and /// Parameter Data Types provided with PREPARE statement and the query schema of the -/// outer query plan +/// outer query plan. 
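Note on PlannerContext: set_outer_query_schema below returns the previous schema precisely so callers can use a save/plan/restore idiom around nested queries. A minimal sketch of that idiom, with a simplified context type standing in for the real PlannerContext and a hypothetical plan_subquery caller:

use std::sync::Arc;

// Simplified stand-ins for DFSchema / DFSchemaRef.
type Schema = Vec<String>;
type SchemaRef = Arc<Schema>;

#[derive(Default)]
struct PlannerContext {
    outer_query_schema: Option<SchemaRef>,
}

impl PlannerContext {
    // Set the outer query schema, returning the previous one so the caller
    // can restore it after the nested query has been planned.
    fn set_outer_query_schema(&mut self, schema: Option<SchemaRef>) -> Option<SchemaRef> {
        std::mem::replace(&mut self.outer_query_schema, schema)
    }
}

fn plan_subquery(ctx: &mut PlannerContext, input_schema: SchemaRef) {
    // Save the previous outer schema and install the current input schema.
    let old = ctx.set_outer_query_schema(Some(input_schema));
    // ... plan the subquery; correlated columns resolve against the outer schema ...
    // Restore the previous outer schema on the way out.
    ctx.set_outer_query_schema(old);
}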
/// /// # Cloning /// @@ -166,12 +166,12 @@ impl PlannerContext { self } - // return a reference to the outer queries schema + // Return a reference to the outer query's schema pub fn outer_query_schema(&self) -> Option<&DFSchema> { self.outer_query_schema.as_ref().map(|s| s.as_ref()) } - /// sets the outer query schema, returning the existing one, if + /// Sets the outer query schema, returning the existing one, if /// any pub fn set_outer_query_schema( &mut self, @@ -181,12 +181,12 @@ impl PlannerContext { schema } - // return a clone of the outer FROM schema + // Return a clone of the outer FROM schema pub fn outer_from_schema(&self) -> Option> { self.outer_from_schema.clone() } - /// sets the outer FROM schema, returning the existing one, if any + /// Sets the outer FROM schema, returning the existing one, if any pub fn set_outer_from_schema( &mut self, mut schema: Option, @@ -195,7 +195,7 @@ impl PlannerContext { schema } - /// extends the FROM schema, returning the existing one, if any + /// Extends the FROM schema, returning the existing one, if any pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> { match self.outer_from_schema.as_mut() { Some(from_schema) => Arc::make_mut(from_schema).merge(schema), @@ -209,7 +209,7 @@ impl PlannerContext { &self.prepare_param_data_types } - /// returns true if there is a Common Table Expression (CTE) / + /// Returns true if there is a Common Table Expression (CTE) / /// Subquery for the specified name pub fn contains_cte(&self, cte_name: &str) -> bool { self.ctes.contains_key(cte_name) @@ -387,12 +387,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { match sql_type { - SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) - | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type, _)) => { + SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => { // Arrays may be multi-dimensional. 
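The new SquareBracket arm just below distinguishes sized arrays from unsized ones: INT[3] becomes an Arrow FixedSizeList while INT[] stays a plain List. A small sketch of that mapping using the same arrow_schema constructors the patch calls; the helper name sql_array_to_arrow is hypothetical:

use arrow_schema::DataType;

// Illustrative mapping (mirrors the arm below): `None` size means `INT[]`,
// `Some(n)` means `INT[n]`.
fn sql_array_to_arrow(inner: DataType, size: Option<i64>) -> DataType {
    match size {
        // INT[3] -> FixedSizeList<Int32, 3>
        Some(n) => DataType::new_fixed_size_list(inner, n as i32, true),
        // INT[]  -> List<Int32>
        None => DataType::new_list(inner, true),
    }
}

fn main() {
    println!("{:?}", sql_array_to_arrow(DataType::Int32, Some(3)));
    println!("{:?}", sql_array_to_arrow(DataType::Int32, None));
}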
let inner_data_type = self.convert_data_type(inner_sql_type)?; Ok(DataType::new_list(inner_data_type, true)) } + SQLDataType::Array(ArrayElemTypeDef::SquareBracket( + inner_sql_type, + maybe_array_size, + )) => { + let inner_data_type = self.convert_data_type(inner_sql_type)?; + if let Some(array_size) = maybe_array_size { + Ok(DataType::new_fixed_size_list( + inner_data_type, + *array_size as i32, + true, + )) + } else { + Ok(DataType::new_list(inner_data_type, true)) + } + } SQLDataType::Array(ArrayElemTypeDef::None) => { not_impl_err!("Arrays with unspecified type is not supported") } @@ -506,9 +520,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::CharVarying(_) | SQLDataType::CharacterLargeObject(_) | SQLDataType::CharLargeObject(_) - // precision is not supported + // Precision is not supported | SQLDataType::Timestamp(Some(_), _) - // precision is not supported + // Precision is not supported | SQLDataType::Time(Some(_), _) | SQLDataType::Dec(_) | SQLDataType::BigNumeric(_) @@ -572,7 +586,7 @@ pub fn object_name_to_table_reference( object_name: ObjectName, enable_normalization: bool, ) -> Result { - // use destructure to make it clear no fields on ObjectName are ignored + // Use destructure to make it clear no fields on ObjectName are ignored let ObjectName(idents) = object_name; idents_to_table_reference(idents, enable_normalization) } @@ -583,7 +597,7 @@ pub(crate) fn idents_to_table_reference( enable_normalization: bool, ) -> Result { struct IdentTaker(Vec); - /// take the next identifier from the back of idents, panic'ing if + /// Take the next identifier from the back of idents, panic'ing if /// there are none left impl IdentTaker { fn take(&mut self, enable_normalization: bool) -> String { diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 71328cfd018c..125259d2276f 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -205,7 +205,7 @@ fn convert_usize_with_check(n: i64, arg_name: &str) -> Result { /// Returns the order by expressions from the query. fn to_order_by_exprs(order_by: Option) -> Result> { let Some(OrderBy { exprs, interpolate }) = order_by else { - // if no order by, return an empty array + // If no order by, return an empty array. return Ok(vec![]); }; if let Some(_interpolate) = interpolate { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index f8ebb04f3810..256cc58e71dc 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -70,7 +70,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build()?; (plan, alias) } else { - // normalize name and alias + // Normalize name and alias let table_ref = self.object_name_to_table_reference(name)?; let table_name = table_ref.to_string(); let cte = planner_context.get_cte(&table_name); @@ -163,7 +163,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { subquery: TableFactor, planner_context: &mut PlannerContext, ) -> Result { - // At this point for a syntacitally valid query the outer_from_schema is + // At this point for a syntactically valid query the outer_from_schema is // guaranteed to be set, so the `.unwrap()` call will never panic. This // is the case because we only call this method for lateral table // factors, and those can never be the first factor in a FROM list. 
This diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index c93d9e6fc435..c029fe2a23d8 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -52,7 +52,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_by: Vec, planner_context: &mut PlannerContext, ) -> Result { - // check for unsupported syntax first + // Check for unsupported syntax first if !select.cluster_by.is_empty() { return not_impl_err!("CLUSTER BY"); } @@ -69,17 +69,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return not_impl_err!("SORT BY"); } - // process `from` clause + // Process `from` clause let plan = self.plan_from_tables(select.from, planner_context)?; let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); - // process `where` clause + // Process `where` clause let base_plan = self.plan_selection(select.selection, plan, planner_context)?; - // handle named windows before processing the projection expression + // Handle named windows before processing the projection expression check_conflicting_windows(&select.named_window)?; match_window_definitions(&mut select.projection, &select.named_window)?; - // process the SELECT expressions + // Process the SELECT expressions let select_exprs = self.prepare_select_exprs( &base_plan, select.projection, @@ -87,7 +87,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?; - // having and group by clause may reference aliases defined in select projection + // Having and group by clause may reference aliases defined in select projection let projected_plan = self.project(base_plan.clone(), select_exprs.clone())?; // Place the fields of the base plan at the front so that when there are references @@ -107,7 +107,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?; let order_by_rex = normalize_sorts(order_by_rex, &projected_plan)?; - // this alias map is resolved and looked up in both having exprs and group by exprs + // This alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); // Optionally the HAVING expression. 
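Note on the alias handling above: HAVING and GROUP BY may reference aliases defined in the projection (SELECT a + b AS s ... GROUP BY s HAVING s > 1), so the planner extracts an alias map from the select list and substitutes it before resolving columns. A self-contained sketch of the substitution step; the toy Expr and resolve_aliases here only model what DataFusion's extract_aliases / resolve_aliases_to_exprs pair does:

use std::collections::HashMap;

// Toy expression type; DataFusion's real Expr is far more general, this
// only illustrates the alias-substitution step.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Column(String),
    Literal(i64),
    Add(Box<Expr>, Box<Expr>),
    Gt(Box<Expr>, Box<Expr>),
}

// Replace column references that match a projection alias with the aliased
// expression, so `HAVING s > 1` becomes `HAVING a + b > 1`.
fn resolve_aliases(expr: Expr, aliases: &HashMap<String, Expr>) -> Expr {
    match expr {
        Expr::Column(name) => aliases.get(&name).cloned().unwrap_or(Expr::Column(name)),
        Expr::Add(l, r) => Expr::Add(
            Box::new(resolve_aliases(*l, aliases)),
            Box::new(resolve_aliases(*r, aliases)),
        ),
        Expr::Gt(l, r) => Expr::Gt(
            Box::new(resolve_aliases(*l, aliases)),
            Box::new(resolve_aliases(*r, aliases)),
        ),
        other => other,
    }
}

fn main() {
    // SELECT a + b AS s ... HAVING s > 1
    let mut aliases = HashMap::new();
    aliases.insert(
        "s".to_string(),
        Expr::Add(
            Box::new(Expr::Column("a".into())),
            Box::new(Expr::Column("b".into())),
        ),
    );
    let having = Expr::Gt(Box::new(Expr::Column("s".into())), Box::new(Expr::Literal(1)));
    println!("{:?}", resolve_aliases(having, &aliases));
}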
@@ -159,7 +159,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?; - // aliases from the projection can conflict with same-named expressions in the input + // Aliases from the projection can conflict with same-named expressions in the input let mut alias_map = alias_map.clone(); for f in base_plan.schema().fields() { alias_map.remove(f.name()); @@ -192,7 +192,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect() }; - // process group by, aggregation or having + // Process group by, aggregation or having let (plan, mut select_exprs_post_aggr, having_expr_post_aggr) = if !group_by_exprs .is_empty() || !aggr_exprs.is_empty() @@ -219,7 +219,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - // process window function + // Process window function let window_func_exprs = find_window_exprs(&select_exprs_post_aggr); let plan = if window_func_exprs.is_empty() { @@ -227,7 +227,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { let plan = LogicalPlanBuilder::window_plan(plan, window_func_exprs.clone())?; - // re-write the projection + // Re-write the projection select_exprs_post_aggr = select_exprs_post_aggr .iter() .map(|expr| rebase_expr(expr, &window_func_exprs, &plan)) @@ -236,10 +236,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - // try process unnest expression or do the final projection + // Try processing unnest expression or do the final projection let plan = self.try_process_unnest(plan, select_exprs_post_aggr)?; - // process distinct clause + // Process distinct clause let plan = match select.distinct { None => Ok(plan), Some(Distinct::Distinct) => { @@ -304,7 +304,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Each expr in select_exprs can contains multiple unnest stage // The transformation happen bottom up, one at a time for each iteration - // Only exaust the loop if no more unnest transformation is found + // Only exhaust the loop if no more unnest transformation is found for i in 0.. { let mut unnest_columns = vec![]; // from which column used for projection, before the unnest happen @@ -390,7 +390,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .. } = agg; - // process unnest of group_by_exprs, and input of agg will be rewritten + // Process unnest of group_by_exprs, and input of agg will be rewritten // for example: // // ``` diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 656d72d07ba2..3111fab9a2ff 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -878,14 +878,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => None, }; - // at the moment functions can't be qualified `schema.name` + // At the moment functions can't be qualified `schema.name` let name = match &name.0[..] { [] => exec_err!("Function should have name")?, [n] => n.value.clone(), [..] => not_impl_err!("Qualified functions are not supported")?, }; // - // convert resulting expression to data fusion expression + // Convert resulting expression to data fusion expression // let arg_types = args.as_ref().map(|arg| { arg.iter().map(|t| t.data_type.clone()).collect::>() @@ -933,10 +933,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { func_desc, .. 
} => { - // according to postgresql documentation it can be only one function + // According to postgresql documentation it can be only one function // specified in drop statement if let Some(desc) = func_desc.first() { - // at the moment functions can't be qualified `schema.name` + // At the moment functions can't be qualified `schema.name` let name = match &desc.name.0[..] { [] => exec_err!("Function should have name")?, [n] => n.value.clone(), @@ -1028,7 +1028,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { filter: Option, ) -> Result { if self.has_table("information_schema", "tables") { - // we only support the basic "SHOW TABLES" + // We only support the basic "SHOW TABLES" // https://github.com/apache/datafusion/issues/3188 if db_name.is_some() || filter.is_some() || full || extended { plan_err!("Unsupported parameters to SHOW TABLES") @@ -1059,7 +1059,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } fn copy_to_plan(&self, statement: CopyToStatement) -> Result { - // determine if source is table or query and handle accordingly + // Determine if source is table or query and handle accordingly let copy_source = statement.source; let (input, input_schema, table_ref) = match copy_source { CopyToSource::Relation(object_name) => { @@ -1100,7 +1100,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .to_string(), ) }; - // try to infer file format from file extension + // Try to infer file format from file extension let extension: &str = &Path::new(&statement.target) .extension() .ok_or_else(e)? @@ -1406,11 +1406,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut variable_lower = variable.to_lowercase(); if variable_lower == "timezone" || variable_lower == "time.zone" { - // we could introduce alias in OptionDefinition if this string matching thing grows + // We could introduce alias in OptionDefinition if this string matching thing grows variable_lower = "datafusion.execution.time_zone".to_string(); } - // parse value string from Expr + // Parse value string from Expr let value_string = match &value[0] { SQLExpr::Identifier(i) => ident_to_string(i), SQLExpr::Value(v) => match crate::utils::value_to_string(v) { @@ -1419,7 +1419,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Some(v) => v, }, - // for capture signed number e.g. +8, -8 + // For capture signed number e.g. +8, -8 SQLExpr::UnaryOp { op, expr } => match op { UnaryOperator::Plus => format!("+{expr}"), UnaryOperator::Minus => format!("-{expr}"), @@ -1614,10 +1614,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Get insert fields and target table's value indices // - // if value_indices[i] = Some(j), it means that the value of the i-th target table's column is + // If value_indices[i] = Some(j), it means that the value of the i-th target table's column is // derived from the j-th output of the source. // - // if value_indices[i] = None, it means that the value of the i-th target table's column is + // If value_indices[i] = None, it means that the value of the i-th target table's column is // not provided, and should be filled with a default value later. 
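A worked example of the mapping described in the two comments above, which the code following them computes: for a table t(a, b, c) and INSERT INTO t (c, a) VALUES (...), column a takes source output 1, b gets a default, and c takes source output 0. A minimal sketch; the helper value_indices is a simplified stand-in for the real computation:

// Toy version of the value-index computation: for each target table column,
// record which source output (if any) provides its value.
fn value_indices(table_cols: &[&str], insert_cols: &[&str]) -> Vec<Option<usize>> {
    table_cols
        .iter()
        .map(|col| insert_cols.iter().position(|c| c == col))
        .collect()
}

fn main() {
    // Table t(a, b, c); INSERT INTO t (c, a) VALUES (...)
    let indices = value_indices(&["a", "b", "c"], &["c", "a"]);
    // a <- source column 1, b <- default, c <- source column 0
    assert_eq!(indices, vec![Some(1), None, Some(0)]);
}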
let (fields, value_indices) = if columns.is_empty() { // Empty means we're inserting into all columns of the table @@ -1749,7 +1749,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let table_ref = self.object_name_to_table_reference(sql_table_name)?; let _ = self.context_provider.get_table_source(table_ref)?; - // treat both FULL and EXTENDED as the same + // Treat both FULL and EXTENDED as the same let select_list = if full || extended { "*" } else { diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index d8a4fb254264..aef3b0dfabbc 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -18,12 +18,17 @@ use std::sync::Arc; use arrow_schema::TimeUnit; +use datafusion_expr::Expr; use regex::Regex; use sqlparser::{ - ast::{self, Ident, ObjectName, TimezoneInfo}, + ast::{self, Function, Ident, ObjectName, TimezoneInfo}, keywords::ALL_KEYWORDS, }; +use datafusion_common::Result; + +use super::{utils::date_part_to_sql, Unparser}; + /// `Dialect` to use for Unparsing /// /// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`) @@ -108,6 +113,18 @@ pub trait Dialect: Send + Sync { fn supports_column_alias_in_table_alias(&self) -> bool { true } + + /// Allows the dialect to override scalar function unparsing if the dialect has specific rules. + /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is + /// a custom implementation for the function. + fn scalar_function_to_sql_overrides( + &self, + _unparser: &Unparser, + _func_name: &str, + _args: &[Expr], + ) -> Result> { + Ok(None) + } } /// `IntervalStyle` to use for unparsing @@ -145,7 +162,7 @@ impl Dialect for DefaultDialect { fn identifier_quote_style(&self, identifier: &str) -> Option { let identifier_regex = Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_]*$").unwrap(); let id_upper = identifier.to_uppercase(); - // special case ignore "ID", see https://github.com/sqlparser-rs/sqlparser-rs/issues/1382 + // Special case ignore "ID", see https://github.com/sqlparser-rs/sqlparser-rs/issues/1382 // ID is a keyword in ClickHouse, but we don't want to quote it when unparsing SQL here if (id_upper != "ID" && ALL_KEYWORDS.contains(&id_upper.as_str())) || !identifier_regex.is_match(identifier) @@ -171,6 +188,67 @@ impl Dialect for PostgreSqlDialect { fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { sqlparser::ast::DataType::DoublePrecision } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "round" { + return Ok(Some( + self.round_to_sql_enforce_numeric(unparser, func_name, args)?, + )); + } + + Ok(None) + } +} + +impl PostgreSqlDialect { + fn round_to_sql_enforce_numeric( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result { + let mut args = unparser.function_args_to_sql(args)?; + + // Enforce the first argument to be Numeric + if let Some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr))) = + args.first_mut() + { + if let ast::Expr::Cast { data_type, .. 
} = expr { + // Don't create an additional cast wrapper if we can update the existing one + *data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None); + } else { + // Wrap the expression in a new cast + *expr = ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(expr.clone()), + data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None), + format: None, + }; + } + } + + Ok(ast::Expr::Function(Function { + name: ast::ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } } pub struct MySqlDialect {} @@ -211,6 +289,19 @@ impl Dialect for MySqlDialect { ) -> ast::DataType { ast::DataType::Datetime(None) } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } } pub struct SqliteDialect {} @@ -231,6 +322,19 @@ impl Dialect for SqliteDialect { fn supports_column_alias_in_table_alias(&self) -> bool { false } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } } pub struct CustomDialect { @@ -273,7 +377,7 @@ impl Default for CustomDialect { } impl CustomDialect { - // create a CustomDialect + // Create a CustomDialect #[deprecated(note = "please use `CustomDialectBuilder` instead")] pub fn new(identifier_quote_style: Option) -> Self { Self { @@ -339,6 +443,19 @@ impl Dialect for CustomDialect { fn supports_column_alias_in_table_alias(&self) -> bool { self.supports_column_alias_in_table_alias } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -424,7 +541,7 @@ impl CustomDialectBuilder { self } - /// Customize the dialect to supports `NULLS FIRST` in `ORDER BY` clauses + /// Customize the dialect to support `NULLS FIRST` in `ORDER BY` clauses pub fn with_supports_nulls_first_in_sort( mut self, supports_nulls_first_in_sort: bool, @@ -503,7 +620,7 @@ impl CustomDialectBuilder { self } - /// Customize the dialect to supports column aliases as part of alias table definition + /// Customize the dialect to support column aliases as part of alias table definition pub fn with_supports_column_alias_in_table_alias( mut self, supports_column_alias_in_table_alias: bool, diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index b924268a7657..b7491d1f88ce 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License. 
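Why the round override above exists: Postgres only accepts the two-argument round for NUMERIC, not DOUBLE PRECISION, so the unparser rewrites an existing cast (or inserts a new one) on the first argument. A usage sketch mirroring the round-trip test added later in this diff; the external crate paths here are assumed:

use std::sync::Arc;

use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{col, lit, Expr, ScalarUDF};
use datafusion_sql::unparser::dialect::PostgreSqlDialect;
use datafusion_sql::unparser::Unparser;

fn main() -> datafusion_common::Result<()> {
    let dialect = PostgreSqlDialect {};
    let unparser = Unparser::new(&dialect);
    let expr = Expr::ScalarFunction(ScalarFunction {
        func: Arc::new(ScalarUDF::from(
            datafusion_functions::math::round::RoundFunc::new(),
        )),
        args: vec![col("a"), lit(2_i64)],
    });
    // Expected to print: round(CAST("a" AS NUMERIC), 2)
    println!("{}", unparser.expr_to_sql(&expr)?);
    Ok(())
}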
-use datafusion_expr::ScalarUDF; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval, - ObjectName, TimezoneInfo, UnaryOperator, + self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, + TimezoneInfo, UnaryOperator, }; use std::sync::Arc; use std::vec; -use super::dialect::{DateFieldExtractStyle, IntervalStyle}; +use super::dialect::IntervalStyle; use super::Unparser; use arrow::datatypes::{Decimal128Type, Decimal256Type, DecimalType}; use arrow::util::display::array_value_to_string; @@ -83,7 +82,7 @@ pub fn sort_to_sql(sort: &Sort) -> Result { } const LOWEST: &BinaryOperator = &BinaryOperator::Or; -// closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs +// Closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs // (https://www.postgresql.org/docs/7.2/sql-precedence.html) const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd; @@ -116,47 +115,14 @@ impl Unparser<'_> { Expr::ScalarFunction(ScalarFunction { func, args }) => { let func_name = func.name(); - if let Some(expr) = - self.scalar_function_to_sql_overrides(func_name, func, args) + if let Some(expr) = self + .dialect + .scalar_function_to_sql_overrides(self, func_name, args)? { return Ok(expr); } - let args = args - .iter() - .map(|e| { - if matches!( - e, - Expr::Wildcard { - qualifier: None, - .. - } - ) { - Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) - } else { - self.expr_to_sql_inner(e).map(|e| { - FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) - }) - } - }) - .collect::>>()?; - - Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { - value: func_name.to_string(), - quote_style: None, - }]), - args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: None, - args, - clauses: vec![], - }), - filter: None, - null_treatment: None, - over: None, - within_group: vec![], - parameters: ast::FunctionArguments::None, - })) + self.scalar_function_to_sql(func_name, args) } Expr::Between(Between { expr, @@ -508,6 +474,30 @@ impl Unparser<'_> { } } + pub fn scalar_function_to_sql( + &self, + func_name: &str, + args: &[Expr], + ) -> Result { + let args = self.function_args_to_sql(args)?; + Ok(ast::Expr::Function(Function { + name: ast::ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } + pub fn sort_to_sql(&self, sort: &Sort) -> Result { let Sort { expr, @@ -530,87 +520,6 @@ impl Unparser<'_> { }) } - fn scalar_function_to_sql_overrides( - &self, - func_name: &str, - _func: &Arc, - args: &[Expr], - ) -> Option { - if func_name.to_lowercase() == "date_part" { - match (self.dialect.date_field_extract_style(), args.len()) { - (DateFieldExtractStyle::Extract, 2) => { - let date_expr = self.expr_to_sql(&args[1]).ok()?; - - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] { - let field = match field.to_lowercase().as_str() { - "year" => ast::DateTimeField::Year, - "month" => ast::DateTimeField::Month, - "day" => ast::DateTimeField::Day, - "hour" => ast::DateTimeField::Hour, - "minute" => ast::DateTimeField::Minute, - "second" => ast::DateTimeField::Second, - _ => return None, - }; - - return Some(ast::Expr::Extract 
{ - field, - expr: Box::new(date_expr), - syntax: ast::ExtractSyntax::From, - }); - } - } - (DateFieldExtractStyle::Strftime, 2) => { - let column = self.expr_to_sql(&args[1]).ok()?; - - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] { - let field = match field.to_lowercase().as_str() { - "year" => "%Y", - "month" => "%m", - "day" => "%d", - "hour" => "%H", - "minute" => "%M", - "second" => "%S", - _ => return None, - }; - - return Some(ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident { - value: "strftime".to_string(), - quote_style: None, - }]), - args: ast::FunctionArguments::List( - ast::FunctionArgumentList { - duplicate_treatment: None, - args: vec![ - ast::FunctionArg::Unnamed( - ast::FunctionArgExpr::Expr(ast::Expr::Value( - ast::Value::SingleQuotedString( - field.to_string(), - ), - )), - ), - ast::FunctionArg::Unnamed( - ast::FunctionArgExpr::Expr(column), - ), - ], - clauses: vec![], - }, - ), - filter: None, - null_treatment: None, - over: None, - within_group: vec![], - parameters: ast::FunctionArguments::None, - })); - } - } - _ => {} // no overrides for DateFieldExtractStyle::DatePart, because it's already a date_part - } - } - - None - } - fn ast_type_for_date64_in_cast(&self) -> ast::DataType { if self.dialect.use_timestamp_for_date64() { ast::DataType::Timestamp(None, ast::TimezoneInfo::None) @@ -665,7 +574,10 @@ impl Unparser<'_> { } } - fn function_args_to_sql(&self, args: &[Expr]) -> Result> { + pub(crate) fn function_args_to_sql( + &self, + args: &[Expr], + ) -> Result> { args.iter() .map(|e| { if matches!( @@ -786,7 +698,7 @@ impl Unparser<'_> { match expr { ast::Expr::Nested(_) | ast::Expr::Identifier(_) | ast::Expr::Value(_) => 100, ast::Expr::BinaryOp { op, .. } => self.sql_op_precedence(op), - // closest precedence we currently have to Between is PGLikeMatch + // Closest precedence we currently have to Between is PGLikeMatch // (https://www.postgresql.org/docs/7.2/sql-precedence.html) ast::Expr::Between { .. 
} => { self.sql_op_precedence(&ast::BinaryOperator::PGLikeMatch) @@ -1229,7 +1141,7 @@ impl Unparser<'_> { return Ok(ast::Expr::Interval(interval)); } - // calculate the best single interval to represent the provided days and microseconds + // Calculate the best single interval to represent the provided days and microseconds let microseconds = microseconds + (days as i64 * 24 * 60 * 60 * 1_000_000); @@ -1554,7 +1466,10 @@ mod tests { use datafusion_functions_aggregate::expr_fn::sum; use datafusion_functions_window::row_number::row_number_udwf; - use crate::unparser::dialect::{CustomDialect, CustomDialectBuilder}; + use crate::unparser::dialect::{ + CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect, + PostgreSqlDialect, + }; use super::*; @@ -2428,4 +2343,39 @@ mod tests { assert_eq!(actual, expected); } } + + #[test] + fn test_round_scalar_fn_to_expr() -> Result<()> { + let default_dialect: Arc = Arc::new( + CustomDialectBuilder::new() + .with_identifier_quote_style('"') + .build(), + ); + let postgres_dialect: Arc = Arc::new(PostgreSqlDialect {}); + + for (dialect, identifier) in + [(default_dialect, "DOUBLE"), (postgres_dialect, "NUMERIC")] + { + let unparser = Unparser::new(dialect.as_ref()); + let expr = Expr::ScalarFunction(ScalarFunction { + func: Arc::new(ScalarUDF::from( + datafusion_functions::math::round::RoundFunc::new(), + )), + args: vec![ + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Float64, + }), + Expr::Literal(ScalarValue::Int64(Some(2))), + ], + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"round(CAST("a" AS {identifier}), 2)"#); + + assert_eq!(actual, expected); + } + Ok(()) + } } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 9b4eaca834f1..304a02f037e6 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -227,13 +227,13 @@ pub(super) fn subquery_alias_inner_query_and_columns( return (plan, vec![]); }; - // check if it's projection inside projection + // Check if it's projection inside projection let Some(inner_projection) = find_projection(outer_projections.input.as_ref()) else { return (plan, vec![]); }; let mut columns: Vec = vec![]; - // check if the inner projection and outer projection have a matching pattern like + // Check if the inner projection and outer projection have a matching pattern like // Projection: j1.j1_id AS id // Projection: j1.j1_id for (i, inner_expr) in inner_projection.expr.iter().enumerate() { @@ -241,7 +241,7 @@ pub(super) fn subquery_alias_inner_query_and_columns( return (plan, vec![]); }; - // inner projection schema fields store the projection name which is used in outer + // Inner projection schema fields store the projection name which is used in outer // projection expr let inner_expr_string = match inner_expr { Expr::Column(_) => inner_expr.to_string(), diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 0059aba25738..e8c4eca569b1 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -15,14 +15,19 @@ // specific language governing permissions and limitations // under the License. 
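The date_part_to_sql helper defined just below gives each dialect override a single entry point: the same logical date_part('year', d) call unparses in three shapes depending on DateFieldExtractStyle. A simplified standalone model of those outputs (string building only, not the real sqlparser AST construction):

// Simplified model of the three extraction styles handled below:
//   DatePart -> date_part('year', d)   (dialects where date_part is native)
//   Extract  -> EXTRACT(YEAR FROM d)   (e.g. MySQL-style dialects)
//   Strftime -> strftime('%Y', d)      (e.g. SQLite-style dialects)
enum DateFieldExtractStyle {
    DatePart,
    Extract,
    Strftime,
}

fn date_part_sql(style: &DateFieldExtractStyle, field: &str, expr: &str) -> Option<String> {
    match style {
        DateFieldExtractStyle::DatePart => Some(format!("date_part('{field}', {expr})")),
        DateFieldExtractStyle::Extract => {
            Some(format!("EXTRACT({} FROM {expr})", field.to_uppercase()))
        }
        DateFieldExtractStyle::Strftime => {
            let fmt = match field {
                "year" => "%Y",
                "month" => "%m",
                "day" => "%d",
                "hour" => "%H",
                "minute" => "%M",
                "second" => "%S",
                // Unsupported fields fall back to default unparsing.
                _ => return None,
            };
            Some(format!("strftime('{fmt}', {expr})"))
        }
    }
}

fn main() {
    for style in [
        DateFieldExtractStyle::DatePart,
        DateFieldExtractStyle::Extract,
        DateFieldExtractStyle::Strftime,
    ] {
        println!("{:?}", date_part_sql(&style, "year", "d"));
    }
}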
+use std::cmp::Ordering; + use datafusion_common::{ internal_err, tree_node::{Transformed, TreeNode}, - Column, DataFusionError, Result, + Column, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window, }; +use sqlparser::ast; + +use super::{dialect::DateFieldExtractStyle, Unparser}; /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). @@ -166,10 +171,17 @@ fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result Ok(grouping_expr.into_iter().nth(index)), + Ordering::Equal => { + internal_err!( + "Tried to unproject column referring to internal grouping id" + ) + } + Ordering::Greater => { + Ok(agg.aggr_expr.get(index - grouping_expr.len() - 1)) + } + } } else { Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)) } @@ -187,3 +199,80 @@ fn find_window_expr<'a>( .flat_map(|w| w.window_expr.iter()) .find(|expr| expr.schema_name().to_string() == column_name) } + +/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style. +pub(crate) fn date_part_to_sql( + unparser: &Unparser, + style: DateFieldExtractStyle, + date_part_args: &[Expr], +) -> Result> { + match (style, date_part_args.len()) { + (DateFieldExtractStyle::Extract, 2) => { + let date_expr = unparser.expr_to_sql(&date_part_args[1])?; + if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + let field = match field.to_lowercase().as_str() { + "year" => ast::DateTimeField::Year, + "month" => ast::DateTimeField::Month, + "day" => ast::DateTimeField::Day, + "hour" => ast::DateTimeField::Hour, + "minute" => ast::DateTimeField::Minute, + "second" => ast::DateTimeField::Second, + _ => return Ok(None), + }; + + return Ok(Some(ast::Expr::Extract { + field, + expr: Box::new(date_expr), + syntax: ast::ExtractSyntax::From, + })); + } + } + (DateFieldExtractStyle::Strftime, 2) => { + let column = unparser.expr_to_sql(&date_part_args[1])?; + + if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + let field = match field.to_lowercase().as_str() { + "year" => "%Y", + "month" => "%m", + "day" => "%d", + "hour" => "%H", + "minute" => "%M", + "second" => "%S", + _ => return Ok(None), + }; + + return Ok(Some(ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident { + value: "strftime".to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + ast::Expr::Value(ast::Value::SingleQuotedString( + field.to_string(), + )), + )), + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(column)), + ], + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + }))); + } + } + (DateFieldExtractStyle::DatePart, _) => { + return Ok(Some( + unparser.scalar_function_to_sql("date_part", date_part_args)?, + )); + } + _ => {} + }; + + Ok(None) +} diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 656e4b851aa8..787bc6634355 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -203,7 +203,7 @@ pub(crate) fn resolve_aliases_to_exprs( .data() } -/// given a slice of window expressions sharing the same sort key, find their common partition +/// Given a slice of window 
expressions sharing the same sort key, find their common partition /// keys. pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> { let all_partition_keys = window_exprs @@ -322,7 +322,7 @@ A full example of how the transformation works: struct RecursiveUnnestRewriter<'a> { input_schema: &'a DFSchemaRef, root_expr: &'a Expr, - // useful to detect which child expr is a part of/ not a part of unnest operation + // Useful to detect which child expr is a part of/ not a part of unnest operation top_most_unnest: Option, consecutive_unnest: Vec>, inner_projection_exprs: &'a mut Vec, @@ -399,14 +399,14 @@ impl<'a> RecursiveUnnestRewriter<'a> { expr_in_unnest.clone().alias(placeholder_name.clone()), ); - // let post_unnest_column = Column::from_name(post_unnest_name); + // Let post_unnest_column = Column::from_name(post_unnest_name); let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name); match self .columns_unnestings .iter_mut() .find(|(inner_col, _)| inner_col == &placeholder_column) { - // there is not unnesting done on this column yet + // There is not unnesting done on this column yet None => { self.columns_unnestings.push(( Column::from_name(placeholder_name.clone()), @@ -416,7 +416,7 @@ impl<'a> RecursiveUnnestRewriter<'a> { }]), )); } - // some unnesting(at some level) has been done on this column + // Some unnesting(at some level) has been done on this column // e.g select unnest(column3), unnest(unnest(column3)) Some((_, unnesting)) => match unnesting { ColumnUnnestType::List(list) => { @@ -512,7 +512,7 @@ impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { if traversing_unnest == self.top_most_unnest.as_ref().unwrap() { self.top_most_unnest = None; } - // find inside consecutive_unnest, the sequence of continous unnest exprs + // Find inside consecutive_unnest, the sequence of continous unnest exprs // Get the latest consecutive unnest exprs // and check if current upward traversal is the returning to the root expr @@ -619,7 +619,9 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( } = original_expr.clone().rewrite(&mut rewriter)?; if !transformed { - if matches!(&transformed_expr, Expr::Column(_)) { + if matches!(&transformed_expr, Expr::Column(_)) + || matches!(&transformed_expr, Expr::Wildcard { .. 
@@ -698,7 +700,7 @@ mod tests {
             &mut inner_projection_exprs,
             &original_expr,
         )?;
-        // only the bottom most unnest exprs are transformed
+        // Only the bottom most unnest exprs are transformed
         assert_eq!(
             transformed_exprs,
             vec![col("unnest_placeholder(3d_col,depth=2)")
@@ -717,7 +719,7 @@ mod tests {
             &unnest_placeholder_columns,
         );
 
-        // still reference struct_col in original schema but with alias,
+        // Still reference struct_col in original schema but with alias,
         // to avoid colliding with the projection on the column itself if any
         assert_eq!(
             inner_projection_exprs,
@@ -749,7 +751,7 @@ mod tests {
             ],
             &unnest_placeholder_columns,
         );
-        // still reference struct_col in original schema but with alias,
+        // Still reference struct_col in original schema but with alias,
         // to avoid colliding with the projection on the column itself if any
         assert_eq!(
             inner_projection_exprs,
@@ -814,7 +816,7 @@ mod tests {
             vec![("unnest_placeholder(struct_col)", "Struct")],
             &unnest_placeholder_columns,
         );
-        // still reference struct_col in original schema but with alias,
+        // Still reference struct_col in original schema but with alias,
         // to avoid colliding with the projection on the column itself if any
         assert_eq!(
             inner_projection_exprs,
@@ -839,7 +841,7 @@ mod tests {
             ],
             &unnest_placeholder_columns,
         );
-        // only transform the unnest children
+        // Only transform the unnest children
         assert_eq!(
             transformed_exprs,
             vec![col("unnest_placeholder(array_col,depth=1)")
@@ -847,8 +849,8 @@ mod tests {
                 .add(lit(1i64))]
         );
 
-        // keep appending to the current vector
-        // still reference array_col in original schema but with alias,
+        // Keep appending to the current vector
+        // Still reference array_col in original schema but with alias,
         // to avoid colliding with the projection on the column itself if any
         assert_eq!(
             inner_projection_exprs,
@@ -858,7 +860,7 @@
             ]
         );
 
-        // a nested structure struct[[]]
+        // A nested structure struct[[]]
         let schema = Schema::new(vec![
             Field::new(
                 "struct_col", // {array_col: [1,2,3]}
diff --git a/datafusion/sql/src/values.rs b/datafusion/sql/src/values.rs
index 9efb75bd60e4..cd33ddb3cfe7 100644
--- a/datafusion/sql/src/values.rs
+++ b/datafusion/sql/src/values.rs
@@ -31,7 +31,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             rows,
         } = values;
 
-        // values should not be based on any other schema
+        // Values should not be based on any other schema
         let schema = DFSchema::empty();
         let values = rows
             .into_iter()
diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs
index fe0e5f7283a4..47caeec78dc7 100644
--- a/datafusion/sql/tests/common/mod.rs
+++ b/datafusion/sql/tests/common/mod.rs
@@ -54,6 +54,7 @@ pub(crate) struct MockSessionState {
     scalar_functions: HashMap<String, Arc<ScalarUDF>>,
     aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
     expr_planners: Vec<Arc<dyn ExprPlanner>>,
+    window_functions: HashMap<String, Arc<WindowUDF>>,
     pub config_options: ConfigOptions,
 }
 
@@ -80,6 +81,12 @@ impl MockSessionState {
         );
         self
     }
+
+    pub fn with_window_function(mut self, window_function: Arc<WindowUDF>) -> Self {
+        self.window_functions
+            .insert(window_function.name().to_string(), window_function);
+        self
+    }
 }
 
 pub(crate) struct MockContextProvider {
@@ -217,8 +224,8 @@ impl ContextProvider for MockContextProvider {
         unimplemented!()
     }
 
-    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
-        None
+    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
+        self.state.window_functions.get(name).cloned()
     }
 
     fn options(&self) -> &ConfigOptions {
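With the registry above, get_window_meta resolves window UDFs that tests register through with_window_function instead of always returning None; that is what lets the integration tests below register rank_udwf() and plan rank() window expressions. A minimal sketch of the same builder-plus-lookup pattern, with a toy stand-in for WindowUDF:

// ToyWindowUdf stands in for WindowUDF; only the name matters for lookup.
use std::collections::HashMap;
use std::sync::Arc;

struct ToyWindowUdf {
    name: &'static str,
}

#[derive(Default)]
struct ToyState {
    window_functions: HashMap<String, Arc<ToyWindowUdf>>,
}

impl ToyState {
    // Builder-style registration, mirroring with_window_function above.
    fn with_window_function(mut self, udf: Arc<ToyWindowUdf>) -> Self {
        self.window_functions.insert(udf.name.to_string(), udf);
        self
    }

    // Name-based lookup, mirroring get_window_meta above.
    fn get_window_meta(&self, name: &str) -> Option<Arc<ToyWindowUdf>> {
        self.window_functions.get(name).cloned()
    }
}

fn main() {
    let state =
        ToyState::default().with_window_function(Arc::new(ToyWindowUdf { name: "rank" }));
    assert!(state.get_window_meta("rank").is_some());
    assert!(state.get_window_meta("dense_rank").is_none());
}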
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 44b591fedef8..19f3d31321ce 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -48,6 +48,7 @@ use datafusion_functions_aggregate::{
     min_max::min_udaf,
 };
 use datafusion_functions_aggregate::{average::avg_udaf, grouping::grouping_udaf};
+use datafusion_functions_window::rank::rank_udwf;
 use rstest::rstest;
 use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
 
@@ -2633,6 +2634,7 @@ fn logical_plan_with_dialect_and_options(
         .with_aggregate_function(min_udaf())
         .with_aggregate_function(max_udaf())
         .with_aggregate_function(grouping_udaf())
+        .with_window_function(rank_udwf())
         .with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
 
     let context = MockContextProvider { state };
@@ -3059,8 +3061,8 @@ fn rank_partition_grouping() {
             from person
             group by rollup(state, last_name)";
 
-    let expected = "Projection: sum(person.age) AS total_sum, person.state, person.last_name, grouping(person.state) + grouping(person.last_name) AS x, RANK() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\
-        \n  WindowAggr: windowExpr=[[RANK() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\
+    let expected = "Projection: sum(person.age) AS total_sum, person.state, person.last_name, grouping(person.state) + grouping(person.last_name) AS x, rank() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\
+        \n  WindowAggr: windowExpr=[[rank() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\
         \n    Aggregate: groupBy=[[ROLLUP (person.state, person.last_name)]], aggr=[[sum(person.age), grouping(person.state), grouping(person.last_name)]]\
         \n      TableScan: person";
     quick_test(sql, expected);
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index a78ade81eeba..54ce91b8e71b 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1385,6 +1385,24 @@ NaN
 statement ok
 DROP TABLE tmp_percentile_cont;
 
+# Test that approx_percentile_cont_with_weight does not panic on NaN input
+
+statement ok
+CREATE TABLE t1(v1 BOOL);
+
+statement ok
+INSERT INTO t1 VALUES (TRUE);
+
+# ISSUE: https://github.com/apache/datafusion/issues/12716
+# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf'
+query R
+SELECT approx_percentile_cont_with_weight('NaN'::DOUBLE, 0, 0) FROM t1 WHERE t1.v1;
+----
+Infinity
+
+statement ok
+DROP TABLE t1;
+
 # csv_query_cube_avg
 query TIR
 SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE (c1, c2) ORDER BY c1, c2
@@ -3520,6 +3538,18 @@ SELECT MIN(value), MAX(value) FROM integers_with_nulls
 ----
 1 5
 
+# grouping_sets with null values
+query II rowsort
+SELECT value, min(value) FROM integers_with_nulls GROUP BY CUBE(value)
+----
+1 1
+3 3
+4 4
+5 5
+NULL 1
+NULL NULL
+
+
 statement ok
 DROP TABLE integers_with_nulls;
 
@@ -4879,16 +4909,18 @@ query TT
 EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;
 ----
 logical_plan
-01)Limit: skip=0, fetch=3
-02)--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
-03)----TableScan: aggregate_test_100 projection=[c2, c3]
+01)Projection: aggregate_test_100.c2, aggregate_test_100.c3
+02)--Limit: skip=0, fetch=3
+03)----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
+04)------TableScan: aggregate_test_100 projection=[c2, c3]
 physical_plan
-01)GlobalLimitExec: skip=0, fetch=3
-02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3]
-03)----CoalescePartitionsExec
-04)------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[]
-05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true
+01)ProjectionExec: expr=[c2@0 as c2, c3@1 as c3]
+02)--GlobalLimitExec: skip=0, fetch=3
+03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, __grouping_id@2 as __grouping_id], aggr=[], lim=[3]
+04)------CoalescePartitionsExec
+05)--------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[]
+06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true
 
 query II
 SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;
diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt
index 2209edc5d1fc..a67fec695f6c 100644
--- a/datafusion/sqllogictest/test_files/aggregates_topk.slt
+++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt
@@ -53,6 +53,11 @@ physical_plan
 07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)]
 08)--------------MemoryExec: partitions=1, partition_sizes=[1]
 
+query TI
+select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where trace_id != 'b' order by max_ts desc limit 3;
+----
+c 4
+a 1
 
 query TI
 select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
@@ -89,6 +94,12 @@ c 1 2
 statement ok
 set datafusion.optimizer.enable_topk_aggregation = true;
 
+query TI
+select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where max_ts != 3 order by max_ts desc limit 2;
+----
+c 4
+a 1
+
 query TT
 explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
 ----
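The new plans above thread DataFusion's internal __grouping_id column through ROLLUP/CUBE aggregations and add a Projection on top to strip it from the final output. Conceptually the id is a bitmask recording which group-by columns are masked out (NULL-filled) in each grouping set; the bit layout in this sketch is assumed for illustration only:

// Toy encoding: one bit per group-by column, set when the column is masked
// (replaced by NULL) in the grouping set. The bit order here is an assumption.
fn grouping_id(masked: &[bool]) -> u32 {
    masked.iter().fold(0, |acc, &m| (acc << 1) | u32::from(m))
}

fn main() {
    // ROLLUP (c2, c3) expands to three grouping sets:
    assert_eq!(grouping_id(&[false, false]), 0b00); // GROUP BY c2, c3
    assert_eq!(grouping_id(&[false, true]), 0b01); // GROUP BY c2 only
    assert_eq!(grouping_id(&[true, true]), 0b11); // grand total row
}

diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index b7d60b50586d..69f62057c761 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -6595,7 +6595,7 @@ select make_array(1, 2.0, null, 3)
 query ?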
select make_array(1.0, '2', null) ---- -[1.0, 2, ] +[1.0, 2.0, ] ### FixedSizeListArray @@ -7097,6 +7097,19 @@ select array_has(a, 1) from values_all_empty; false false +# Test create table with fixed sized array +statement ok +create table fixed_size_col_table (a int[3]) as values ([1,2,3]), ([4,5,6]); + +query T +select arrow_typeof(a) from fixed_size_col_table; +---- +FixedSizeList(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3) +FixedSizeList(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3) + +statement error +create table varying_fixed_size_col_table (a int[3]) as values ([1,2,3]), ([4,5]); + ### Delete tables statement ok @@ -7272,3 +7285,7 @@ drop table test_create_array_table; statement ok drop table values_all_empty; + +statement ok +drop table fixed_size_col_table; + diff --git a/datafusion/sqllogictest/test_files/cse.slt b/datafusion/sqllogictest/test_files/cse.slt index 19b47fa50e41..9f0f654179e9 100644 --- a/datafusion/sqllogictest/test_files/cse.slt +++ b/datafusion/sqllogictest/test_files/cse.slt @@ -179,8 +179,8 @@ physical_plan # Surely only once but also conditionally evaluated expressions query TT EXPLAIN SELECT - (a = 1 OR random() = 0) AND a = 1 AS c1, - (a = 2 AND random() = 0) OR a = 2 AS c2, + (a = 1 OR random() = 0) AND a = 2 AS c1, + (a = 2 AND random() = 0) OR a = 1 AS c2, CASE WHEN a + 3 = 0 THEN a + 3 ELSE 0 END AS c3, CASE WHEN a + 4 = 0 THEN 0 WHEN a + 4 THEN 0 ELSE 0 END AS c4, CASE WHEN a + 5 = 0 THEN 0 WHEN random() = 0 THEN a + 5 ELSE 0 END AS c5, @@ -188,11 +188,11 @@ EXPLAIN SELECT FROM t1 ---- logical_plan -01)Projection: (__common_expr_1 OR random() = Float64(0)) AND __common_expr_1 AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_2 AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Int64(0) WHEN CAST(__common_expr_4 AS Boolean) THEN Int64(0) ELSE Int64(0) END AS c4, CASE WHEN __common_expr_5 = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN __common_expr_5 ELSE Float64(0) END AS c5, CASE WHEN __common_expr_6 = Float64(0) THEN Float64(0) ELSE __common_expr_6 END AS c6 +01)Projection: (__common_expr_1 OR random() = Float64(0)) AND __common_expr_2 AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_1 AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Int64(0) WHEN CAST(__common_expr_4 AS Boolean) THEN Int64(0) ELSE Int64(0) END AS c4, CASE WHEN __common_expr_5 = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN __common_expr_5 ELSE Float64(0) END AS c5, CASE WHEN __common_expr_6 = Float64(0) THEN Float64(0) ELSE __common_expr_6 END AS c6 02)--Projection: t1.a = Float64(1) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a + Float64(3) AS __common_expr_3, t1.a + Float64(4) AS __common_expr_4, t1.a + Float64(5) AS __common_expr_5, t1.a + Float64(6) AS __common_expr_6 03)----TableScan: t1 projection=[a] physical_plan -01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND __common_expr_1@0 as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_2@1 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 WHEN CAST(__common_expr_4@3 AS Boolean) THEN 0 ELSE 0 END as c4, CASE WHEN __common_expr_5@4 = 0 THEN 0 WHEN random() = 0 
THEN __common_expr_5@4 ELSE 0 END as c5, CASE WHEN __common_expr_6@5 = 0 THEN 0 ELSE __common_expr_6@5 END as c6] +01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND __common_expr_2@1 as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_1@0 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 WHEN CAST(__common_expr_4@3 AS Boolean) THEN 0 ELSE 0 END as c4, CASE WHEN __common_expr_5@4 = 0 THEN 0 WHEN random() = 0 THEN __common_expr_5@4 ELSE 0 END as c5, CASE WHEN __common_expr_6@5 = 0 THEN 0 ELSE __common_expr_6@5 END as c6] 02)--ProjectionExec: expr=[a@0 = 1 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 + 3 as __common_expr_3, a@0 + 4 as __common_expr_4, a@0 + 5 as __common_expr_5, a@0 + 6 as __common_expr_6] 03)----MemoryExec: partitions=1, partition_sizes=[0] @@ -217,8 +217,8 @@ physical_plan # Only conditionally evaluated expressions query TT EXPLAIN SELECT - (random() = 0 OR a = 1) AND a = 1 AS c1, - (random() = 0 AND a = 2) OR a = 2 AS c2, + (random() = 0 OR a = 1) AND a = 2 AS c1, + (random() = 0 AND a = 2) OR a = 1 AS c2, CASE WHEN random() = 0 THEN a + 3 ELSE a + 3 END AS c3, CASE WHEN random() = 0 THEN 0 WHEN a + 4 = 0 THEN a + 4 ELSE 0 END AS c4, CASE WHEN random() = 0 THEN 0 WHEN a + 5 = 0 THEN 0 ELSE a + 5 END AS c5, @@ -226,8 +226,8 @@ EXPLAIN SELECT FROM t1 ---- logical_plan -01)Projection: (random() = Float64(0) OR t1.a = Float64(1)) AND t1.a = Float64(1) AS c1, random() = Float64(0) AND t1.a = Float64(2) OR t1.a = Float64(2) AS c2, CASE WHEN random() = Float64(0) THEN t1.a + Float64(3) ELSE t1.a + Float64(3) END AS c3, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(4) = Float64(0) THEN t1.a + Float64(4) ELSE Float64(0) END AS c4, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(5) = Float64(0) THEN Float64(0) ELSE t1.a + Float64(5) END AS c5, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN t1.a + Float64(6) ELSE t1.a + Float64(6) END AS c6 +01)Projection: (random() = Float64(0) OR t1.a = Float64(1)) AND t1.a = Float64(2) AS c1, random() = Float64(0) AND t1.a = Float64(2) OR t1.a = Float64(1) AS c2, CASE WHEN random() = Float64(0) THEN t1.a + Float64(3) ELSE t1.a + Float64(3) END AS c3, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(4) = Float64(0) THEN t1.a + Float64(4) ELSE Float64(0) END AS c4, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(5) = Float64(0) THEN Float64(0) ELSE t1.a + Float64(5) END AS c5, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN t1.a + Float64(6) ELSE t1.a + Float64(6) END AS c6 02)--TableScan: t1 projection=[a] physical_plan -01)ProjectionExec: expr=[(random() = 0 OR a@0 = 1) AND a@0 = 1 as c1, random() = 0 AND a@0 = 2 OR a@0 = 2 as c2, CASE WHEN random() = 0 THEN a@0 + 3 ELSE a@0 + 3 END as c3, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 4 = 0 THEN a@0 + 4 ELSE 0 END as c4, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 5 = 0 THEN 0 ELSE a@0 + 5 END as c5, CASE WHEN random() = 0 THEN 0 WHEN random() = 0 THEN a@0 + 6 ELSE a@0 + 6 END as c6] +01)ProjectionExec: expr=[(random() = 0 OR a@0 = 1) AND a@0 = 2 as c1, random() = 0 AND a@0 = 2 OR a@0 = 1 as c2, CASE WHEN random() = 0 THEN a@0 + 3 ELSE a@0 + 3 END as c3, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 4 = 0 THEN a@0 + 4 ELSE 0 END as c4, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 5 = 0 THEN 0 ELSE a@0 + 5 END as c5, CASE WHEN random() = 0 THEN 0 WHEN random() = 0 THEN a@0 + 
6 ELSE a@0 + 6 END as c6] 02)--MemoryExec: partitions=1, partition_sizes=[0] diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index be7fdac71b57..ce0947525344 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -128,5 +128,8 @@ from aggregate_test_100 order by c9 -statement error Inconsistent data type across values list at row 1 column 0. Was Int64 but found Utf8 +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Int64 type create table foo as values (1), ('foo'); + +query error No function matches +select 1 group by substr(''); diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index f561fa9e9ac8..a80a0891e977 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5152,8 +5152,6 @@ drop table test_case_expr statement ok drop table t; -# TODO: Current grouping set result is not align with Postgres and DuckDB, we might want to change the result -# See https://github.com/apache/datafusion/issues/12570 # test multi group by for binary type with nulls statement ok create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, null), (null, 0xb), (null, 0xb); @@ -5162,11 +5160,14 @@ query I?I select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)); ---- 1 0a 2 -2 NULL 2 -NULL 0b 4 +2 NULL 1 +NULL 0b 2 1 NULL 2 -NULL NULL 3 +2 NULL 1 +NULL NULL 2 NULL 0a 2 +NULL NULL 1 +NULL 0b 2 statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt.temp b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt.temp deleted file mode 100644 index 00e74a207b33..000000000000 --- a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt.temp +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -########## -## Join Tests -########## - -# turn off repartition_joins -statement ok -set datafusion.optimizer.repartition_joins = false; - -include ./join.slt diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 45e1b51a09b4..726de75b5141 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -148,18 +148,17 @@ SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']); {[1, 2]: [a, b], [3, 4]: [b]} query ? 
-SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); +SELECT MAKE_MAP('POST', 41, 'HEAD', 53, 'PATCH', 30); ---- -{POST: 41, HEAD: ab, PATCH: 30} +{POST: 41, HEAD: 53, PATCH: 30} + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'ab' to value of Int64 type +SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); +# Map keys can not be NULL query error SELECT MAKE_MAP('POST', 41, 'HEAD', 33, null, 30); -query ? -SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); ----- -{POST: 41, HEAD: ab, PATCH: 30} - query ? SELECT MAKE_MAP() ---- @@ -517,9 +516,12 @@ query error SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }[NULL]; query ? -SELECT MAP { 'a': 1, 2: 3 }; +SELECT MAP { 'a': 1, 'b': 3 }; ---- -{a: 1, 2: 3} +{a: 1, b: 3} + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +SELECT MAP { 'a': 1, 2: 3 }; # TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key # query ? @@ -610,9 +612,12 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) # Tests for map_keys query ? -SELECT map_keys(MAP { 'a': 1, 2: 3 }); +SELECT map_keys(MAP { 'a': 1, 'b': 3 }); ---- -[a, 2] +[a, b] + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +SELECT map_keys(MAP { 'a': 1, 2: 3 }); query ? SELECT map_keys(MAP {'a':1, 'b':2, 'c':3 }) FROM t; @@ -657,8 +662,11 @@ SELECT map_keys(column1) from map_array_table_1; # Tests for map_values -query ? +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type SELECT map_values(MAP { 'a': 1, 2: 3 }); + +query ? +SELECT map_values(MAP { 'a': 1, 'b': 3 }); ---- [1, 3] diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index f53363b6eb38..6cc7ee0403f2 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -958,6 +958,24 @@ drop table foo; statement ok drop table ambiguity_test; +## reproducer for https://github.com/apache/datafusion/issues/12446 +# Ensure union ordering calculations with constants can be optimized + +statement ok +create table t(a0 int, a int, b int, c int) as values (1, 2, 3, 4), (5, 6, 7, 8); + +# expect this query to run successfully, not error +query III +select * from (select c, a, NULL::int as a0 from t order by a, c) t1 +union all +select * from (select c, NULL::int as a, a0 from t order by a0, c) t2 +order by c, a, a0, b +limit 2; +---- +4 2 NULL +4 NULL 1 + + # Casting from numeric to string types breaks the ordering statement ok CREATE EXTERNAL TABLE ordered_table ( @@ -1189,3 +1207,48 @@ physical_plan 02)--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 03)----SortExec: TopK(fetch=1), expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], has_header=true + + +# Test: inputs into union with different orderings +query TT +explain select * from (select b, c, a, NULL::int as a0 from ordered_table order by a, c) t1 +union all +select * from (select b, c, NULL::int as a, a0 from ordered_table order by a0, c) t2 +order by d, c, a, a0, b +limit 2; +---- +logical_plan +01)Projection: t1.b, t1.c, t1.a, t1.a0 +02)--Sort: t1.d ASC NULLS LAST, t1.c ASC NULLS LAST, t1.a ASC NULLS LAST, t1.a0 ASC NULLS LAST, t1.b ASC NULLS LAST, fetch=2 
+03)----Union +04)------SubqueryAlias: t1 +05)--------Projection: ordered_table.b, ordered_table.c, ordered_table.a, Int32(NULL) AS a0, ordered_table.d +06)----------TableScan: ordered_table projection=[a, b, c, d] +07)------SubqueryAlias: t2 +08)--------Projection: ordered_table.b, ordered_table.c, Int32(NULL) AS a, ordered_table.a0, ordered_table.d +09)----------TableScan: ordered_table projection=[a0, b, c, d] +physical_plan +01)ProjectionExec: expr=[b@0 as b, c@1 as c, a@2 as a, a0@3 as a0] +02)--SortPreservingMergeExec: [d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,a@2 ASC NULLS LAST,a0@3 ASC NULLS LAST,b@0 ASC NULLS LAST], fetch=2 +03)----UnionExec +04)------SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,a@2 ASC NULLS LAST,b@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------ProjectionExec: expr=[b@1 as b, c@2 as c, a@0 as a, NULL as a0, d@3 as d] +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[c@2 ASC NULLS LAST], has_header=true +07)------SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST,c@1 ASC NULLS LAST,a0@3 ASC NULLS LAST,b@0 ASC NULLS LAST], preserve_partitioning=[false] +08)--------ProjectionExec: expr=[b@1 as b, c@2 as c, NULL as a, a0@0 as a0, d@3 as d] +09)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, b, c, d], output_ordering=[c@2 ASC NULLS LAST], has_header=true + +# Test: run the query from above +query IIII +select * from (select b, c, a, NULL::int as a0 from ordered_table order by a, c) t1 +union all +select * from (select b, c, NULL::int as a, a0 from ordered_table order by a0, c) t2 +order by d, c, a, a0, b +limit 2; +---- +0 0 0 NULL +0 0 NULL 1 + + +statement ok +drop table ordered_table; diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 4c86312f9e51..858e42106221 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..87], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:87..174], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:174..261], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:261..347]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..88], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:88..176], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:176..264], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:264..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = 
column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..87], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:87..174], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:174..261], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:261..347]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..88], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:88..176], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:176..264], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:264..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -102,7 +102,7 @@ physical_plan 02)--SortExec: expr=[column1@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=8192 04)------FilterExec: column1@0 != 42 -05)--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..172], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:172..338, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..178], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:178..347]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +05)--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..174], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:174..342, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..180], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:180..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 
OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -138,7 +138,7 @@ physical_plan 01)SortPreservingMergeExec: [column1@0 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: column1@0 != 42 -04)------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..169], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..173], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:173..347], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:169..338]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +04)------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..171], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..175], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:175..351], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:171..342]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # Cleanup statement ok diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 8820fffaeb47..0c2fa41e5bf8 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1906,12 +1906,8 @@ select position('' in '') ---- 1 - -query I +query error DataFusion error: Error during planning: Error during planning: Int64 and Int64 are not coercible to a common string select position(1 in 1) ----- -1 - query I select strpos('abc', 'c'); diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 5df5f313af3c..0fef56aeea5c 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -348,8 +348,11 @@ VALUES (1),() statement error DataFusion error: Error during planning: Inconsistent data length across values list: got 2 values in row 1 but expected 1 VALUES (1),(1,2) -statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 0 +query I VALUES (1),('2') +---- +1 +2 query R VALUES (1),(2.0) @@ -357,8 +360,11 @@ VALUES (1),(2.0) 1 2 -statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 1 +query II VALUES (1,2), (1,'2') +---- +1 2 +1 2 query IT VALUES (1,'a'),(NULL,'b'),(3,'c') diff --git a/datafusion/sqllogictest/test_files/string/init_data.slt.part b/datafusion/sqllogictest/test_files/string/init_data.slt.part index d99401f10d20..096e3bb3b330 100644 --- a/datafusion/sqllogictest/test_files/string/init_data.slt.part +++ b/datafusion/sqllogictest/test_files/string/init_data.slt.part @@ -30,4 +30,3 @@ 
statement ok create table test_substr_base ( col1 VARCHAR ) as values ('foo'), ('hello🌏世界'), ('💩'), ('ThisIsAVeryLongASCIIString'), (''), (NULL); - diff --git a/datafusion/sqllogictest/test_files/string/large_string.slt b/datafusion/sqllogictest/test_files/string/large_string.slt index 169c658e5ac1..8d8a5711bdb8 100644 --- a/datafusion/sqllogictest/test_files/string/large_string.slt +++ b/datafusion/sqllogictest/test_files/string/large_string.slt @@ -59,36 +59,6 @@ Xiangpeng datafusion数据融合 false true false true Raphael datafusionДатаФусион false false false false NULL NULL NULL NULL NULL NULL -# TODO: move it back to `string_query.slt.part` after fixing the issue -# https://github.com/apache/datafusion/issues/12618 -query BB -SELECT - ascii_1 ~* '^a.{3}e', - unicode_1 ~* '^d.*Фу' -FROM test_basic_operator; ----- -true false -false false -false true -NULL NULL - -# TODO: move it back to `string_query.slt.part` after fixing the issue -# see detail: https://github.com/apache/datafusion/issues/12670 -query IIIIII -SELECT - STRPOS(ascii_1, 'e'), - STRPOS(ascii_1, 'ang'), - STRPOS(ascii_1, NULL), - STRPOS(unicode_1, 'и'), - STRPOS(unicode_1, 'ион'), - STRPOS(unicode_1, NULL) -FROM test_basic_operator; ----- -5 0 NULL 0 0 NULL -7 3 NULL 0 0 NULL -6 0 NULL 18 18 NULL -NULL NULL NULL NULL NULL NULL - # # common test for string-like functions and operators # diff --git a/datafusion/sqllogictest/test_files/string/string.slt b/datafusion/sqllogictest/test_files/string/string.slt index f4e83966f78f..e84342abd3df 100644 --- a/datafusion/sqllogictest/test_files/string/string.slt +++ b/datafusion/sqllogictest/test_files/string/string.slt @@ -34,19 +34,6 @@ statement ok create table test_substr as select arrow_cast(col1, 'Utf8') as c1 from test_substr_base; -# TODO: move it back to `string_query.slt.part` after fixing the issue -# https://github.com/apache/datafusion/issues/12618 -query BB -SELECT - ascii_1 ~* '^a.{3}e', - unicode_1 ~* '^d.*Фу' -FROM test_basic_operator; ----- -true false -false false -false true -NULL NULL - # TODO: move it back to `string_query.slt.part` after fixing the issue # see detail: https://github.com/apache/datafusion/issues/12637 # Test pattern with wildcard characters @@ -63,23 +50,6 @@ Xiangpeng datafusion数据融合 false true false true Raphael datafusionДатаФусион false false false false NULL NULL NULL NULL NULL NULL -# TODO: move it back to `string_query.slt.part` after fixing the issue -# see detail: https://github.com/apache/datafusion/issues/12670 -query IIIIII -SELECT - STRPOS(ascii_1, 'e'), - STRPOS(ascii_1, 'ang'), - STRPOS(ascii_1, NULL), - STRPOS(unicode_1, 'и'), - STRPOS(unicode_1, 'ион'), - STRPOS(unicode_1, NULL) -FROM test_basic_operator; ----- -5 0 NULL 0 0 NULL -7 3 NULL 0 0 NULL -6 0 NULL 18 18 NULL -NULL NULL NULL NULL NULL NULL - # # common test for string-like functions and operators # diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part index 3ba2b31bbab2..dc5626b7d573 100644 --- a/datafusion/sqllogictest/test_files/string/string_query.slt.part +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -642,18 +642,16 @@ true false false true NULL NULL -# TODO: DictionaryString does not support ~* operator. 
Enable this after fixing the issue -# see issue: https://github.com/apache/datafusion/issues/12618 -#query BB -#SELECT -# ascii_1 ~* '^a.{3}e', -# unicode_1 ~* '^d.*Фу' -#FROM test_basic_operator; -#---- -#true false -#false false -#false true -#NULL NULL +query BB +SELECT + ascii_1 ~* '^a.{3}e', + unicode_1 ~* '^d.*Фу' +FROM test_basic_operator; +---- +true false +false false +false true +NULL NULL query BB SELECT @@ -951,22 +949,20 @@ NULL NULL # Test STRPOS # -------------------------------------- -# TODO: DictionaryString does not support STRPOS. Enable this after fixing the issue -# see issue: https://github.com/apache/datafusion/issues/12670 -#query IIIIII -#SELECT -# STRPOS(ascii_1, 'e'), -# STRPOS(ascii_1, 'ang'), -# STRPOS(ascii_1, NULL), -# STRPOS(unicode_1, 'и'), -# STRPOS(unicode_1, 'ион'), -# STRPOS(unicode_1, NULL) -#FROM test_basic_operator; -#---- -#5 0 NULL 0 0 NULL -#7 3 NULL 0 0 NULL -#6 0 NULL 18 18 NULL -#NULL NULL NULL NULL NULL NULL +query IIIIII +SELECT + STRPOS(ascii_1, 'e'), + STRPOS(ascii_1, 'ang'), + STRPOS(ascii_1, NULL), + STRPOS(unicode_1, 'и'), + STRPOS(unicode_1, 'ион'), + STRPOS(unicode_1, NULL) +FROM test_basic_operator; +---- +5 0 NULL 0 0 NULL +7 3 NULL 0 0 NULL +6 0 NULL 18 18 NULL +NULL NULL NULL NULL NULL NULL # -------------------------------------- # Test SUBSTR_INDEX diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 4e7857ad804b..19bea3bf6bd0 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -37,36 +37,6 @@ select arrow_cast(col1, 'Utf8View') as c1 from test_substr_base; statement ok drop table test_source -# TODO: move it back to `string_query.slt.part` after fixing the issue -# https://github.com/apache/datafusion/issues/12618 -query BB -SELECT - ascii_1 ~* '^a.{3}e', - unicode_1 ~* '^d.*Фу' -FROM test_basic_operator; ----- -true false -false false -false true -NULL NULL - -# TODO: move it back to `string_query.slt.part` after fixing the issue -# see detail: https://github.com/apache/datafusion/issues/12670 -query IIIIII -SELECT - STRPOS(ascii_1, 'e'), - STRPOS(ascii_1, 'ang'), - STRPOS(ascii_1, NULL), - STRPOS(unicode_1, 'и'), - STRPOS(unicode_1, 'ион'), - STRPOS(unicode_1, NULL) -FROM test_basic_operator; ----- -5 0 NULL 0 0 NULL -7 3 NULL 0 0 NULL -6 0 NULL 18 18 NULL -NULL NULL NULL NULL NULL NULL - # # common test for string-like functions and operators # @@ -109,6 +79,21 @@ FROM test_source; statement ok drop table test_source +######## +## StringView Function test +######## + +query error DataFusion error: Arrow error: Compute error: bit_length not supported for Utf8View +select bit_length(column1_utf8view) from test; + +query T +select btrim(column1_large_utf8) from test; +---- +Andrew +Xiangpeng +Raphael +NULL + ######## ## StringView to Other Types column ######## @@ -316,9 +301,8 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: starts_with(__common_expr_1, test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(__common_expr_1, CAST(test.column2_large_utf8 AS Utf8View)) AS c4 -02)--Projection: CAST(test.column1_utf8 AS Utf8View) AS __common_expr_1, test.column1_utf8, test.column2_utf8, test.column2_large_utf8, test.column2_utf8view -03)----TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view] +01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), 
test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(CAST(test.column1_utf8 AS LargeUtf8), test.column2_large_utf8) AS c4
+02)--TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view]
 
 query BBB
 SELECT
@@ -608,7 +592,7 @@ EXPLAIN SELECT
 FROM test;
 ----
 logical_plan
-01)Projection: contains(test.column1_utf8view, Utf8("foo")) AS c1, contains(test.column1_utf8view, test.column2_utf8view) AS c2, contains(test.column1_utf8view, test.column2_large_utf8) AS c3, contains(test.column1_utf8, test.column2_utf8view) AS c4, contains(test.column1_utf8, test.column2_utf8) AS c5, contains(test.column1_utf8, test.column2_large_utf8) AS c6, contains(test.column1_large_utf8, test.column1_utf8view) AS c7, contains(test.column1_large_utf8, test.column2_utf8) AS c8, contains(test.column1_large_utf8, test.column2_large_utf8) AS c9
+01)Projection: contains(test.column1_utf8view, Utf8View("foo")) AS c1, contains(test.column1_utf8view, test.column2_utf8view) AS c2, contains(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3, contains(CAST(test.column1_utf8 AS Utf8View), test.column2_utf8view) AS c4, contains(test.column1_utf8, test.column2_utf8) AS c5, contains(CAST(test.column1_utf8 AS LargeUtf8), test.column2_large_utf8) AS c6, contains(CAST(test.column1_large_utf8 AS Utf8View), test.column1_utf8view) AS c7, contains(test.column1_large_utf8, CAST(test.column2_utf8 AS LargeUtf8)) AS c8, contains(test.column1_large_utf8, test.column2_large_utf8) AS c9
 02)--TableScan: test projection=[column1_utf8, column2_utf8, column1_large_utf8, column2_large_utf8, column1_utf8view, column2_utf8view]
 
 ## Ensure no casts for ENDS_WITH
@@ -852,7 +836,7 @@ EXPLAIN SELECT
 FROM test;
 ----
 logical_plan
-01)Projection: strpos(test.column1_utf8view, Utf8("f")) AS c, strpos(test.column1_utf8view, test.column2_utf8view) AS c2
+01)Projection: strpos(test.column1_utf8view, Utf8View("f")) AS c, strpos(test.column1_utf8view, test.column2_utf8view) AS c2
 02)--TableScan: test projection=[column1_utf8view, column2_utf8view]
 
 ## Ensure no casts for SUBSTR
diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt
index d2e7160d0010..b76c78396aed 100644
--- a/datafusion/sqllogictest/test_files/struct.slt
+++ b/datafusion/sqllogictest/test_files/struct.slt
@@ -373,3 +373,159 @@ You reached the bottom!
 
 statement ok
 drop view complex_view;
+
+# struct with different keys r1 and r2 is not valid
+statement ok
+create table t(a struct<r1 varchar, c int>, b struct<r2 varchar, c float>) as values (struct('red', 1), struct('blue', 2.3));
+
+# Expect same keys for struct type but got mismatched pair r1,c and r2,c
+query error
+select [a, b] from t;
+
+statement ok
+drop table t;
+
+# struct with the same key
+statement ok
+create table t(a struct<r varchar, c int>, b struct<r varchar, c float>) as values (struct('red', 1), struct('blue', 2.3));
+
+query T
+select arrow_typeof([a, b]) from t;
+----
+List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+
+query ?
+select [a, b] from t;
+----
+[{r: red, c: 1}, {r: blue, c: 2}]
+
+statement ok
+drop table t;
+
+# Test row alias
+
+query ?
+select row('a', 'b'); +---- +{c0: a, c1: b} + +################################## +# Switch Dialect to DuckDB +################################## + +statement ok +set datafusion.sql_parser.dialect = 'DuckDB'; + +statement ok +CREATE TABLE struct_values ( + s1 struct(a int, b varchar), + s2 struct(a int, b varchar) +) AS VALUES + (row(1, 'red'), row(1, 'string1')), + (row(2, 'blue'), row(2, 'string2')), + (row(3, 'green'), row(3, 'string3')) +; + +statement ok +drop table struct_values; + +statement ok +create table t (c1 struct(r varchar, b int), c2 struct(r varchar, b float)) as values ( + row('red', 2), + row('blue', 2.3) +); + +query ?? +select * from t; +---- +{r: red, b: 2} {r: blue, b: 2.3} + +query T +select arrow_typeof(c1) from t; +---- +Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +query T +select arrow_typeof(c2) from t; +---- +Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +################################## +## Test Coalesce with Struct +################################## + +statement ok +CREATE TABLE t ( + s1 struct(a int, b varchar), + s2 struct(a float, b varchar) +) AS VALUES + (row(1, 'red'), row(1.1, 'string1')), + (row(2, 'blue'), row(2.2, 'string2')), + (row(3, 'green'), row(33.2, 'string3')) +; + +query ? +select coalesce(s1) from t; +---- +{a: 1, b: red} +{a: 2, b: blue} +{a: 3, b: green} + +# TODO: a's type should be float +query T +select arrow_typeof(coalesce(s1)) from t; +---- +Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +statement ok +CREATE TABLE t ( + s1 struct(a int, b varchar), + s2 struct(a float, b varchar) +) AS VALUES + (row(1, 'red'), row(1.1, 'string1')), + (null, row(2.2, 'string2')), + (row(3, 'green'), row(33.2, 'string3')) +; + +# TODO: second column should not be null +query ? 
+select coalesce(s1) from t; +---- +{a: 1, b: red} +NULL +{a: 3, b: green} + +statement ok +drop table t; + +# row() with incorrect order +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'blue' to value of Float64 type +create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values + (row('red', 1), row(2.3, 'blue')), + (row('purple', 1), row('green', 2.3)); + +# out of order struct literal +# TODO: This query should not fail +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, r: 'b'}); + +################################## +## Test Array of Struct +################################## + +query ? +select [{r: 'a', c: 1}, {r: 'b', c: 2}]; +---- +[{r: a, c: 1}, {r: b, c: 2}] + +# Can't create a list of struct with different field types +query error +select [{r: 'a', c: 1}, {c: 2, r: 'b'}]; diff --git a/datafusion/sqllogictest/test_files/subquery_sort.slt b/datafusion/sqllogictest/test_files/subquery_sort.slt index 17affbc0acad..e4360a9269ca 100644 --- a/datafusion/sqllogictest/test_files/subquery_sort.slt +++ b/datafusion/sqllogictest/test_files/subquery_sort.slt @@ -93,14 +93,14 @@ logical_plan 02)--Sort: t2.c1 ASC NULLS LAST, t2.c3 ASC NULLS LAST, t2.c9 ASC NULLS LAST 03)----SubqueryAlias: t2 04)------Sort: sink_table.c1 ASC NULLS LAST, sink_table.c3 ASC NULLS LAST, fetch=2 -05)--------Projection: sink_table.c1, RANK() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS r, sink_table.c3, sink_table.c9 -06)----------WindowAggr: windowExpr=[[RANK() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------Projection: sink_table.c1, rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS r, sink_table.c3, sink_table.c9 +06)----------WindowAggr: windowExpr=[[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 07)------------TableScan: sink_table projection=[c1, c3, c9] physical_plan 01)ProjectionExec: expr=[c1@0 as c1, r@1 as r] 02)--SortExec: TopK(fetch=2), expr=[c1@0 ASC NULLS LAST,c3@2 ASC NULLS LAST,c9@3 ASC NULLS LAST], preserve_partitioning=[false] -03)----ProjectionExec: expr=[c1@0 as c1, RANK() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] -04)------BoundedWindowAggExec: wdw=[RANK() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "RANK() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----ProjectionExec: expr=[c1@0 as c1, rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] +04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, 
start_bound: Preceding(Utf8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 05)--------SortExec: expr=[c1@0 DESC], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3, c9], has_header=true diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 63ca74e9714c..b923e94fc819 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -33,7 +33,7 @@ AS VALUES statement ok CREATE TABLE nested_unnest_table AS VALUES - (struct('a', 'b', struct('c')), (struct('a', 'b', [10,20])), [struct('a', 'b')]), + (struct('a', 'b', struct('c')), (struct('a', 'b', [10,20])), [struct('a', 'b')]), (struct('d', 'e', struct('f')), (struct('x', 'y', [30,40, 50])), null) ; @@ -511,10 +511,19 @@ x y [30, 40, 50] query error DataFusion error: type_coercion\ncaused by\nThis feature is not implemented: Unnest should be rewritten to LogicalPlan::Unnest before type coercion select sum(unnest(generate_series(1,10))); -## TODO: support unnest as a child expr query error DataFusion error: Internal error: unnest on struct can only be applied at the root level of select expression select arrow_typeof(unnest(column5)) from unnest_table; +query T +select arrow_typeof(unnest(column1)) from unnest_table; +---- +Int64 +Int64 +Int64 +Int64 +Int64 +Int64 +Int64 ## unnest from a result of a logical plan with limit and offset query I @@ -524,10 +533,19 @@ select unnest(column1) from (select * from (values([1,2,3]), ([4,5,6])) limit 1 5 6 -## FIXME: https://github.com/apache/datafusion/issues/11198 query error DataFusion error: Error during planning: Projections require unique expression names but the expression "UNNEST\(unnest_table.column1\)" at position 0 and "UNNEST\(unnest_table.column1\)" at position 1 have the same name. Consider aliasing \("AS"\) one of them. select unnest(column1), unnest(column1) from unnest_table; +query II +select unnest(column1), unnest(column1) u1 from unnest_table; +---- +1 1 +2 2 +3 3 +4 4 +5 5 +6 6 +12 12 ## the same unnest expr is referened multiple times (unnest is the bottom-most expr) query ??II @@ -777,6 +795,61 @@ select unnest(unnest(column2)) c2, count(column3) from recursive_unnest_table gr [, 6] 1 NULL 1 -### TODO: group by unnest struct query error DataFusion error: Error during planning: Projection references non\-aggregate values select unnest(column1) c1 from nested_unnest_table group by c1.c0; + +# TODO: this query should work. see issue: https://github.com/apache/datafusion/issues/12794 +query error DataFusion error: Internal error: unnest on struct can only be applied at the root level of select expression +select unnest(column1) c1 from nested_unnest_table + +query II??I?? +select unnest(column5), * from unnest_table; +---- +1 2 [1, 2, 3] [7] 1 [13, 14] {c0: 1, c1: 2} +3 4 [4, 5] [8, 9, 10] 2 [15, 16] {c0: 3, c1: 4} +NULL NULL [6] [11, 12] 3 NULL NULL +7 8 [12] [, 42, ] NULL NULL {c0: 7, c1: 8} +NULL NULL NULL NULL 4 [17, 18] NULL + +query TT???? +select unnest(column1), * from nested_unnest_table +---- +a b {c0: c} {c0: a, c1: b, c2: {c0: c}} {c0: a, c1: b, c2: [10, 20]} [{c0: a, c1: b}] +d e {c0: f} {c0: d, c1: e, c2: {c0: f}} {c0: x, c1: y, c2: [30, 40, 50]} NULL + +query ????? 
+select unnest(unnest(column3)), * from recursive_unnest_table +---- +[1] [[1, 2]] {c0: [1], c1: a} [[[1], [2]], [[1, 1]]] [{c0: [1], c1: [[1, 2]]}] +[2] [[3], [4]] {c0: [2], c1: b} [[[3, 4], [5]], [[, 6], , [7, 8]]] [{c0: [2], c1: [[3], [4]]}] + +statement ok +CREATE TABLE join_table +AS VALUES + (1, 2, 3), + (2, 3, 4), + (4, 5, 6) +; + +query IIIII +select unnest(u.column5), j.* from unnest_table u join join_table j on u.column3 = j.column1 +---- +1 2 1 2 3 +3 4 2 3 4 +NULL NULL 4 5 6 + +query II?I? +select unnest(column5), * except (column5, column1) from unnest_table; +---- +1 2 [7] 1 [13, 14] +3 4 [8, 9, 10] 2 [15, 16] +NULL NULL [11, 12] 3 NULL +7 8 [, 42, ] NULL NULL +NULL NULL NULL 4 [17, 18] + +query III +select unnest(u.column5), j.* except(column2, column3) from unnest_table u join join_table j on u.column3 = j.column1 +---- +1 2 1 +3 4 2 +NULL NULL 4 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index cb6c6a5ace76..40309a1f2de9 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -2636,14 +2636,14 @@ EXPLAIN SELECT ---- logical_plan 01)Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 -02)--Projection: annotated_data_finite.ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, first_value(annotated_data_finite.inc_col) ORDER BY 
[annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 -03)----WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +02)--Projection: annotated_data_finite.ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, first_value(annotated_data_finite.inc_col) ORDER BY 
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 +03)----WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 
FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan 01)SortExec: TopK(fetch=5), expr=[ts@0 DESC], preserve_partitioning=[false] -02)--ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, 
last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] -03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", 
data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, 
start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS 
BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +02)--ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) 
ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] +03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, row_number() ORDER BY 
[annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, 
metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS 
FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index e6bfc67eda81..888480774996 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -34,6 +34,7 @@ use datafusion::logical_expr::{ ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values, }; use substrait::proto::expression::subquery::set_predicate::PredicateOp; +use substrait::proto::expression_reference::ExprType; use url::Url; use crate::extensions::Extensions; @@ -96,7 +97,7 @@ use substrait::proto::{ sort_field::{SortDirection, SortKind::*}, AggregateFunction, Expression, NamedStruct, Plan, Rel, Type, }; -use substrait::proto::{FunctionArgument, SortField}; +use substrait::proto::{ExtendedExpression, FunctionArgument, SortField}; // Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which // is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone @@ -251,6 +252,81 @@ pub async fn from_substrait_plan( } } +/// An ExprContainer is a container for a collection of expressions with a common input schema +/// +/// In addition, each expression is associated with a field, which defines the +/// expression's output. The data type and nullability of the field are calculated from the +/// expression and the input schema. 
However the names of the field (and its nested fields) are
+/// derived from the Substrait message.
+pub struct ExprContainer {
+    /// The input schema for the expressions
+    pub input_schema: DFSchemaRef,
+    /// The expressions
+    ///
+    /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output
+    pub exprs: Vec<(Expr, Field)>,
+}
+
+/// Convert Substrait ExtendedExpression to ExprContainer
+///
+/// A Substrait ExtendedExpression message contains one or more expressions,
+/// with names for the outputs, and an input schema. These pieces are all included
+/// in the ExprContainer.
+///
+/// This is a top-level message and can be used to send expressions (not plans)
+/// between systems. This is often useful for scenarios like pushdown where filter
+/// expressions need to be sent to remote systems.
+pub async fn from_substrait_extended_expr(
+    ctx: &SessionContext,
+    extended_expr: &ExtendedExpression,
+) -> Result<ExprContainer> {
+    // Register function extension
+    let extensions = Extensions::try_from(&extended_expr.extensions)?;
+    if !extensions.type_variations.is_empty() {
+        return not_impl_err!("Type variation extensions are not supported");
+    }
+
+    let input_schema = DFSchemaRef::new(match &extended_expr.base_schema {
+        Some(base_schema) => from_substrait_named_struct(base_schema, &extensions),
+        None => {
+            plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message")
+        }
+    }?);
+
+    // Parse expressions
+    let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len());
+    for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() {
+        let scalar_expr = match &substrait_expr.expr_type {
+            Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr),
+            Some(ExprType::Measure(_)) => {
+                not_impl_err!("Measure expressions are not yet supported")
+            }
+            None => {
+                plan_err!("required property `expr_type` missing from Substrait ExpressionReference message")
+            }
+        }?;
+        let expr =
+            from_substrait_rex(ctx, scalar_expr, &input_schema, &extensions).await?;
+        let (output_type, expected_nullability) =
+            expr.data_type_and_nullable(&input_schema)?;
+        let output_field = Field::new("", output_type, expected_nullability);
+        let mut names_idx = 0;
+        let output_field = rename_field(
+            &output_field,
+            &substrait_expr.output_names,
+            expr_idx,
+            &mut names_idx,
+            /*rename_self=*/ true,
+        )?;
+        exprs.push((expr, output_field));
+    }
+
+    Ok(ExprContainer {
+        input_schema,
+        exprs,
+    })
+}
+
 /// parse projection
 pub fn extract_projection(
     t: LogicalPlan,
@@ -334,6 +410,68 @@ fn rename_expressions(
         .collect()
 }
 
+fn rename_field(
+    field: &Field,
+    dfs_names: &Vec<String>,
+    unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}"
+    name_idx: &mut usize,        // Index into dfs_names
+    rename_self: bool, // Some fields (e.g. list items) don't have names in Substrait and this will be false to keep old name
+) -> Result<Field> {
+    let name = if rename_self {
+        next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)?
+    } else {
+        field.name().to_string()
+    };
+    match field.data_type() {
+        DataType::Struct(children) => {
+            let children = children
+                .iter()
+                .enumerate()
+                .map(|(child_idx, f)| {
+                    rename_field(
+                        f.as_ref(),
+                        dfs_names,
+                        child_idx,
+                        name_idx,
+                        /*rename_self=*/ true,
+                    )
+                })
+                .collect::<Result<_>>()?;
+            Ok(field
+                .to_owned()
+                .with_name(name)
+                .with_data_type(DataType::Struct(children)))
+        }
+        DataType::List(inner) => {
+            let renamed_inner = rename_field(
+                inner.as_ref(),
+                dfs_names,
+                0,
+                name_idx,
+                /*rename_self=*/ false,
+            )?;
+            Ok(field
+                .to_owned()
+                .with_data_type(DataType::List(FieldRef::new(renamed_inner)))
+                .with_name(name))
+        }
+        DataType::LargeList(inner) => {
+            let renamed_inner = rename_field(
+                inner.as_ref(),
+                dfs_names,
+                0,
+                name_idx,
+                /*rename_self= */ false,
+            )?;
+            Ok(field
+                .to_owned()
+                .with_data_type(DataType::LargeList(FieldRef::new(renamed_inner)))
+                .with_name(name))
+        }
+        _ => Ok(field.to_owned().with_name(name)),
+    }
+}
+
 /// Produce a version of the given schema with names matching the given list of names.
 /// Substrait doesn't deal with column (incl. nested struct field) names within the schema,
 /// but it does give us the list of expected names at the end of the plan, so we use this
@@ -342,59 +480,20 @@ fn make_renamed_schema(
     schema: &DFSchemaRef,
     dfs_names: &Vec<String>,
 ) -> Result<DFSchema> {
-    fn rename_inner_fields(
-        dtype: &DataType,
-        dfs_names: &Vec<String>,
-        name_idx: &mut usize,
-    ) -> Result<DataType> {
-        match dtype {
-            DataType::Struct(fields) => {
-                let fields = fields
-                    .iter()
-                    .map(|f| {
-                        let name = next_struct_field_name(0, dfs_names, name_idx)?;
-                        Ok((**f).to_owned().with_name(name).with_data_type(
-                            rename_inner_fields(f.data_type(), dfs_names, name_idx)?,
-                        ))
-                    })
-                    .collect::<Result<_>>()?;
-                Ok(DataType::Struct(fields))
-            }
-            DataType::List(inner) => Ok(DataType::List(FieldRef::new(
-                (**inner).to_owned().with_data_type(rename_inner_fields(
-                    inner.data_type(),
-                    dfs_names,
-                    name_idx,
-                )?),
-            ))),
-            DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new(
-                (**inner).to_owned().with_data_type(rename_inner_fields(
-                    inner.data_type(),
-                    dfs_names,
-                    name_idx,
-                )?),
-            ))),
-            _ => Ok(dtype.to_owned()),
-        }
-    }
-
     let mut name_idx = 0;
     let (qualifiers, fields): (_, Vec<Field>) = schema
         .iter()
-        .map(|(q, f)| {
-            let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
-            Ok((
-                q.cloned(),
-                (**f)
-                    .to_owned()
-                    .with_name(name)
-                    .with_data_type(rename_inner_fields(
-                        f.data_type(),
-                        dfs_names,
-                        &mut name_idx,
-                    )?),
-            ))
+        .enumerate()
+        .map(|(field_idx, (q, f))| {
+            let renamed_f = rename_field(
+                f.as_ref(),
+                dfs_names,
+                field_idx,
+                &mut name_idx,
+                /*rename_self=*/ true,
+            )?;
+            Ok((q.cloned(), renamed_f))
         })
         .collect::<Result<Vec<_>>>()?
         .into_iter()
@@ -788,6 +887,17 @@ pub async fn from_substrait_rel(
                 not_impl_err!("Union relation requires at least one input")
             }
         }
+        set_rel::SetOp::IntersectionPrimary => {
+            if set.inputs.len() == 2 {
+                LogicalPlanBuilder::intersect(
+                    from_substrait_rel(ctx, &set.inputs[0], extensions).await?,
+                    from_substrait_rel(ctx, &set.inputs[1], extensions).await?,
+                    false,
+                )
+            } else {
+                not_impl_err!("Primary Intersect relation with more than two inputs isn't supported")
+            }
+        }
         _ => not_impl_err!("Unsupported set operator: {set_op:?}"),
     },
     Err(e) => not_impl_err!("Invalid set operation type {}: {e}", set.op),
@@ -1681,14 +1791,14 @@ fn from_substrait_struct_type(
 }
 
 fn next_struct_field_name(
-    i: usize,
+    column_idx: usize,
     dfs_names: &[String],
     name_idx: &mut usize,
 ) -> Result<String> {
     if dfs_names.is_empty() {
         // If names are not given, create dummy names
         // c0, c1, ... align with e.g. SqlToRel::create_named_struct
-        Ok(format!("c{i}"))
+        Ok(format!("c{column_idx}"))
     } else {
         let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| {
             substrait_datafusion_err!("Named schema must contain names for all fields")
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index fada827875b0..1165ce13d236 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -15,11 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use itertools::Itertools;
 use std::sync::Arc;
+use substrait::proto::expression_reference::ExprType;
 
 use arrow_buffer::ToByteSlice;
-use datafusion::arrow::datatypes::IntervalUnit;
+use datafusion::arrow::datatypes::{Field, IntervalUnit};
 use datafusion::logical_expr::{
     CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits,
 };
@@ -63,7 +63,9 @@ use substrait::proto::expression::window_function::BoundsType;
 use substrait::proto::read_rel::VirtualTable;
 use substrait::proto::rel_common::EmitKind;
 use substrait::proto::rel_common::EmitKind::Emit;
-use substrait::proto::{rel_common, CrossRel, ExchangeRel, RelCommon};
+use substrait::proto::{
+    rel_common, CrossRel, ExchangeRel, ExpressionReference, ExtendedExpression, RelCommon,
+};
 use substrait::{
     proto::{
         aggregate_function::AggregationInvocation,
@@ -119,6 +121,56 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result
 
+pub fn to_substrait_extended_expr(
+    exprs: &[(&Expr, &Field)],
+    schema: &DFSchemaRef,
+    ctx: &SessionContext,
+) -> Result<Box<ExtendedExpression>> {
+    let mut extensions = Extensions::default();
+
+    let substrait_exprs = exprs
+        .iter()
+        .map(|(expr, field)| {
+            let substrait_expr = to_substrait_rex(
+                ctx,
+                expr,
+                schema,
+                /*col_ref_offset=*/ 0,
+                &mut extensions,
+            )?;
+            let mut output_names = Vec::new();
+            flatten_names(field, false, &mut output_names)?;
+            Ok(ExpressionReference {
+                output_names,
+                expr_type: Some(ExprType::Expression(substrait_expr)),
+            })
+        })
+        .collect::<Result<Vec<_>>>()?;
+    let substrait_schema = to_substrait_named_struct(schema, &mut extensions)?;
+
+    Ok(Box::new(ExtendedExpression {
+        advanced_extensions: None,
+        expected_type_urls: vec![],
+        extension_uris: vec![],
+        extensions: extensions.into(),
+        version: Some(version::version_with_producer("datafusion")),
+        referred_expr: substrait_exprs,
+        base_schema: Some(substrait_schema),
+    }))
+}
+
 /// Convert DataFusion LogicalPlan to Substrait Rel
 pub fn to_substrait_rel(
     plan: &LogicalPlan,
@@ -580,50 +632,43 @@ fn create_project_remapping(expr_count: usize, input_field_count: usize) -> Emit
     Emit(rel_common::Emit { output_mapping })
 }
 
+// Substrait wants a list of all field names, including nested fields from structs,
+// also from within e.g. lists and maps. However, it does not want the list and map field names
+// themselves - only proper struct fields are considered to have useful names.
+fn flatten_names(field: &Field, skip_self: bool, names: &mut Vec<String>) -> Result<()> {
+    if !skip_self {
+        names.push(field.name().to_string());
+    }
+    match field.data_type() {
+        DataType::Struct(fields) => {
+            for field in fields {
+                flatten_names(field, false, names)?;
+            }
+            Ok(())
+        }
+        DataType::List(l) => flatten_names(l, true, names),
+        DataType::LargeList(l) => flatten_names(l, true, names),
+        DataType::Map(m, _) => match m.data_type() {
+            DataType::Struct(key_and_value) if key_and_value.len() == 2 => {
+                flatten_names(&key_and_value[0], true, names)?;
+                flatten_names(&key_and_value[1], true, names)
+            }
+            _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"),
+        },
+        _ => Ok(()),
+    }?;
+    Ok(())
+}
+
 fn to_substrait_named_struct(
     schema: &DFSchemaRef,
     extensions: &mut Extensions,
 ) -> Result<NamedStruct> {
-    // Substrait wants a list of all field names, including nested fields from structs,
-    // also from within e.g. lists and maps. However, it does not want the list and map field names
-    // themselves - only proper structs fields are considered to have useful names.
-    fn names_dfs(dtype: &DataType) -> Result<Vec<String>> {
-        match dtype {
-            DataType::Struct(fields) => {
-                let mut names = Vec::new();
-                for field in fields {
-                    names.push(field.name().to_string());
-                    names.extend(names_dfs(field.data_type())?);
-                }
-                Ok(names)
-            }
-            DataType::List(l) => names_dfs(l.data_type()),
-            DataType::LargeList(l) => names_dfs(l.data_type()),
-            DataType::Map(m, _) => match m.data_type() {
-                DataType::Struct(key_and_value) if key_and_value.len() == 2 => {
-                    let key_names =
-                        names_dfs(key_and_value.first().unwrap().data_type())?;
-                    let value_names =
-                        names_dfs(key_and_value.last().unwrap().data_type())?;
-                    Ok([key_names, value_names].concat())
-                }
-                _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"),
-            },
-            _ => Ok(Vec::new()),
-        }
-    }
+    let mut names = Vec::with_capacity(schema.fields().len());
+    for field in schema.fields() {
+        flatten_names(field, false, &mut names)?;
+    }
 
-    let names = schema
-        .fields()
-        .iter()
-        .map(|f| {
-            let mut names = vec![f.name().to_string()];
-            names.extend(names_dfs(f.data_type())?);
-            Ok(names)
-        })
-        .flatten_ok()
-        .collect::<Result<_>>()?;
-
     let field_types = r#type::Struct {
         types: schema
             .fields()
@@ -2178,14 +2223,16 @@ fn substrait_field_ref(index: usize) -> Result<Expression> {
 mod test {
     use super::*;
     use crate::logical_plan::consumer::{
-        from_substrait_literal_without_names, from_substrait_type_without_names,
+        from_substrait_extended_expr, from_substrait_literal_without_names,
+        from_substrait_named_struct, from_substrait_type_without_names,
     };
     use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
     use datafusion::arrow::array::{
         GenericListArray, Int64Builder, MapBuilder, StringBuilder,
     };
-    use datafusion::arrow::datatypes::Field;
+    use datafusion::arrow::datatypes::{Field, Fields, Schema};
     use datafusion::common::scalar::ScalarStructBuilder;
+    use datafusion::common::DFSchema;
     use std::collections::HashMap;
 
     #[test]
@@ -2461,4 +2508,101 @@ mod test {
 
         Ok(())
     }
+
+    #[test]
+    fn named_struct_names() -> Result<()> {
+        let mut extensions = Extensions::default();
+        let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![
+            Field::new("int", DataType::Int32, true),
+            Field::new(
+                "struct",
+                DataType::Struct(Fields::from(vec![Field::new(
+                    "inner",
DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + )])), + true, + ), + Field::new("trailer", DataType::Float64, true), + ]))?); + + let named_struct = to_substrait_named_struct(&schema, &mut extensions)?; + + // Struct field names should be flattened DFS style + // List field names should be omitted + assert_eq!( + named_struct.names, + vec!["int", "struct", "inner", "trailer"] + ); + + let roundtrip_schema = from_substrait_named_struct(&named_struct, &extensions)?; + assert_eq!(schema.as_ref(), &roundtrip_schema); + Ok(()) + } + + #[tokio::test] + async fn extended_expressions() -> Result<()> { + let ctx = SessionContext::new(); + + // One expression, empty input schema + let expr = Expr::Literal(ScalarValue::Int32(Some(42))); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let substrait = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx)?; + let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, empty_schema); + assert_eq!(roundtrip_expr.exprs.len(), 1); + + let (rt_expr, rt_field) = roundtrip_expr.exprs.first().unwrap(); + assert_eq!(rt_field, &field); + assert_eq!(rt_expr, &expr); + + // Multiple expressions, with column references + let expr1 = Expr::Column("c0".into()); + let expr2 = Expr::Column("c1".into()); + let out1 = Field::new("out1", DataType::Int32, true); + let out2 = Field::new("out2", DataType::Utf8, true); + let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ]))?); + + let substrait = to_substrait_extended_expr( + &[(&expr1, &out1), (&expr2, &out2)], + &input_schema, + &ctx, + )?; + let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, input_schema); + assert_eq!(roundtrip_expr.exprs.len(), 2); + + let mut exprs = roundtrip_expr.exprs.into_iter(); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out1); + assert_eq!(rt_expr, expr1); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out2); + assert_eq!(rt_expr, expr2); + + Ok(()) + } + + #[tokio::test] + async fn invalid_extended_expression() { + let ctx = SessionContext::new(); + + // Not ok if input schema is missing field referenced by expr + let expr = Expr::Column("missing".into()); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + + let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx); + + assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); + } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 3b7d0fd29610..80caaafad6f6 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
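Taken together, the new producer entry point `to_substrait_extended_expr` and consumer entry point `from_substrait_extended_expr` give a plan-free round trip for expressions, which is the pushdown scenario the consumer's doc comment mentions. The tests above exercise this in memory; the sketch below wires the two halves together over serialized bytes. The prost `Message` encode/decode step and the column and field names ("a", "filtered") are illustrative assumptions, not code from this diff:

```rust
// Sketch: shipping a pushdown predicate as a Substrait ExtendedExpression,
// using the consumer/producer APIs added in this diff. The prost
// encode/decode step and the names below are illustrative assumptions.
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::{DFSchema, DFSchemaRef};
use datafusion::error::{DataFusionError, Result};
use datafusion::prelude::{col, lit, SessionContext};
use datafusion_substrait::logical_plan::consumer::from_substrait_extended_expr;
use datafusion_substrait::logical_plan::producer::to_substrait_extended_expr;
use prost::Message;
use substrait::proto::ExtendedExpression;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![
        Field::new("a", DataType::Int64, true),
    ]))?);

    // Producer side: serialize `a > 5` plus a field naming its output.
    let expr = col("a").gt(lit(5i64));
    let out = Field::new("filtered", DataType::Boolean, true);
    let proto = to_substrait_extended_expr(&[(&expr, &out)], &input_schema, &ctx)?;
    let bytes = proto.encode_to_vec(); // ExtendedExpression is a prost message

    // Consumer side (e.g. a remote system): decode and rebuild the expression.
    let decoded = ExtendedExpression::decode(bytes.as_slice())
        .map_err(|e| DataFusionError::External(Box::new(e)))?;
    let container = from_substrait_extended_expr(&ctx, &decoded).await?;
    for (expr, field) in &container.exprs {
        println!("{}: {expr}", field.name());
    }
    Ok(())
}
```

Only the expression list and its input schema cross the wire; the receiving side recomputes each output field's type and nullability from the expression itself, while the (possibly nested) output names come from the message, as handled by `rename_field` in the consumer.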
+use crate::utils::test::read_json; use datafusion::arrow::array::ArrayRef; use datafusion::physical_plan::Accumulator; use datafusion::scalar::ScalarValue; @@ -294,8 +295,9 @@ async fn aggregate_grouping_sets() -> Result<()> { async fn aggregate_grouping_rollup() -> Result<()> { assert_expected_plan( "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)", - "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\ - \n TableScan: data projection=[a, b, c, e]", + "Projection: data.a, data.c, data.e, avg(data.b)\ + \n Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\ + \n TableScan: data projection=[a, b, c, e]", true ).await } @@ -662,6 +664,17 @@ async fn simple_intersect() -> Result<()> { .await } +#[tokio::test] +async fn simple_intersect_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT SELECT a FROM data2", + ) + .await +} + #[tokio::test] async fn simple_intersect_table_reuse() -> Result<()> { // Substrait does currently NOT maintain the alias of the tables. @@ -1110,6 +1123,21 @@ async fn assert_expected_plan( Ok(()) } +async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { + let ctx = create_context().await?; + + let expected = ctx.sql(sql).await?.into_optimized_plan()?; + + let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + let plan = ctx.state().optimize(&plan)?; + + let planstr = format!("{plan}"); + let expectedstr = format!("{expected}"); + assert_eq!(planstr, expectedstr); + + Ok(()) +} + async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index da0898d222c4..72d685817d7d 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -117,8 +117,8 @@ mod tests { let datafusion_plan = df.into_optimized_plan()?; assert_eq!( format!("{}", datafusion_plan), - "Projection: data.b, RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\ - \n WindowAggr: windowExpr=[[RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + "Projection: data.b, rank() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: data projection=[a, b, c]", ); diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect.substrait.json new file mode 100644 index 000000000000..b9a2e4ad1403 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect.substrait.json @@ -0,0 +1,118 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + 
"selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 46e157aecfd9..2440244d08c3 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -60,4 +60,5 @@ wasm-bindgen = "0.2.87" wasm-bindgen-futures = "0.4.40" [dev-dependencies] -wasm-bindgen-test = "0.3" +tokio = { workspace = true } +wasm-bindgen-test = "0.3.44" diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 0af07e176f0f..37512e8278a7 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -1095,9 +1095,9 @@ } }, "node_modules/cookie": { - "version": "0.6.0", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", - "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", + "version": "0.7.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz", + "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", "dev": true, "engines": { "node": ">= 0.6" @@ -1459,9 +1459,9 @@ } }, "node_modules/express": { - "version": "4.21.0", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.0.tgz", - "integrity": "sha512-VqcNGcj/Id5ZT1LZ/cfihi3ttTn+NJmkli2eZADigjq29qTlWi/hAQ43t/VLPq8+UX06FCEx3ByOYet6ZFblng==", + "version": "4.21.1", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", + "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", "dev": true, "dependencies": { "accepts": "~1.3.8", @@ -1469,7 +1469,7 @@ "body-parser": "1.20.3", "content-disposition": "0.5.4", "content-type": "~1.0.4", - "cookie": "0.6.0", + "cookie": "0.7.1", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", @@ -5247,9 +5247,9 @@ "dev": true }, "cookie": { - "version": "0.6.0", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", - "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", + "version": "0.7.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz", + "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", "dev": true }, "cookie-signature": { @@ -5524,9 +5524,9 @@ } }, "express": { - "version": "4.21.0", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.0.tgz", - "integrity": "sha512-VqcNGcj/Id5ZT1LZ/cfihi3ttTn+NJmkli2eZADigjq29qTlWi/hAQ43t/VLPq8+UX06FCEx3ByOYet6ZFblng==", + "version": "4.21.1", + "resolved": 
"https://registry.npmjs.org/express/-/express-4.21.1.tgz", + "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", "dev": true, "requires": { "accepts": "~1.3.8", @@ -5534,7 +5534,7 @@ "body-parser": "1.20.3", "content-disposition": "0.5.4", "content-type": "~1.0.4", - "cookie": "0.6.0", + "cookie": "0.7.1", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index 0f24449cbed3..085064d16d94 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -87,13 +87,14 @@ mod test { wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - #[wasm_bindgen_test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + #[cfg_attr(not(target_arch = "wasm32"), allow(dead_code))] fn datafusion_test() { basic_exprs(); basic_parse(); } - #[wasm_bindgen_test] + #[wasm_bindgen_test(unsupported = tokio::test)] async fn basic_execute() { let sql = "SELECT 2 + 2;"; diff --git a/dev/update_function_docs.sh b/dev/update_function_docs.sh index a4236eefc8c8..f1f26c8b2f58 100755 --- a/dev/update_function_docs.sh +++ b/dev/update_function_docs.sh @@ -58,7 +58,11 @@ dev/update_function_docs.sh file for updating surrounding text. # Aggregate Functions (NEW) -This page is a WIP and will replace the Aggregate Functions page once completed. +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Aggregate Functions (old)](aggregate_functions.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 Aggregate functions operate on a set of values to compute a single result. EOF @@ -105,7 +109,12 @@ dev/update_function_docs.sh file for updating surrounding text. # Scalar Functions (NEW) -This page is a WIP and will replace the Scalar Functions page once completed. +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Scalar Functions (old)](aggregate_functions.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 + EOF echo "Running CLI and inserting scalar function docs table" @@ -151,9 +160,16 @@ dev/update_function_docs.sh file for updating surrounding text. # Window Functions (NEW) -This page is a WIP and will replace the Window Functions page once completed. +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Window Functions (Old)](window_functions.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 -A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. This is comparable to the type of calculation that can be done with an aggregate function. However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result +A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. 
+This is comparable to the type of calculation that can be done with an aggregate function. +However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. +Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result. Here is an example that shows how to compare each employee's salary with the average salary in his or her department: diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index edb0e1d0c9f0..1c25874c0806 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -21,12 +21,15 @@ Aggregate functions operate on a set of values to compute a single result. +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Aggregate Functions (new)](aggregate_functions_new.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 + ## General - [avg](#avg) -- [bit_and](#bit_and) -- [bit_or](#bit_or) -- [bit_xor](#bit_xor) - [bool_and](#bool_and) - [bool_or](#bool_or) - [count](#count) @@ -56,45 +59,6 @@ avg(expression) - `mean` -### `bit_and` - -Computes the bitwise AND of all non-null input values. - -``` -bit_and(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `bit_or` - -Computes the bitwise OR of all non-null input values. - -``` -bit_or(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `bit_xor` - -Computes the bitwise exclusive OR of all non-null input values. - -``` -bit_xor(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - ### `bool_and` Returns true if all non-null input values are true, otherwise false. @@ -240,9 +204,6 @@ last_value(expression [ORDER BY expression]) - [stddev](#stddev) - [stddev_pop](#stddev_pop) - [stddev_samp](#stddev_samp) -- [var](#var) -- [var_pop](#var_pop) -- [var_samp](#var_samp) - [regr_avgx](#regr_avgx) - [regr_avgy](#regr_avgy) - [regr_count](#regr_count) @@ -349,45 +310,6 @@ stddev_samp(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `var` - -Returns the statistical variance of a set of numbers. - -``` -var(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `var_pop` - -Returns the statistical population variance of a set of numbers. - -``` -var_pop(expression) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `var_samp` - -Returns the statistical sample variance of a set of numbers. - -``` -var_samp(expression) -``` - -#### Arguments - - **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.
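The reflowed paragraph above ends by promising a salary-comparison example; that example is unchanged context, so it does not appear in this diff. For orientation, a sketch of the classic form of that query run through DataFusion's Rust API (the table name, CSV file, and columns are assumed sample data, not part of this change):

```rust
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // Assumed sample table with depname, empno, and salary columns.
    ctx.register_csv("empsalary", "empsalary.csv", CsvReadOptions::new())
        .await?;
    // The window function computes a per-department average while each
    // employee row keeps its own identity in the output.
    ctx.sql(
        "SELECT depname, empno, salary, \
                avg(salary) OVER (PARTITION BY depname) AS dept_avg \
         FROM empsalary",
    )
    .await?
    .show()
    .await?;
    Ok(())
}
```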
diff --git a/docs/source/user-guide/sql/aggregate_functions_new.md b/docs/source/user-guide/sql/aggregate_functions_new.md index 8303c50c2471..08cdb0a9867c 100644 --- a/docs/source/user-guide/sql/aggregate_functions_new.md +++ b/docs/source/user-guide/sql/aggregate_functions_new.md @@ -27,7 +27,11 @@ dev/update_function_docs.sh file for updating surrounding text. # Aggregate Functions (NEW) -This page is a WIP and will replace the Aggregate Functions page once completed. +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Aggregate Functions (old)](aggregate_functions.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 Aggregate functions operate on a set of values to compute a single result. @@ -36,6 +40,11 @@ Aggregate functions operate on a set of values to compute a single result. - [bit_and](#bit_and) - [bit_or](#bit_or) - [bit_xor](#bit_xor) +- [var](#var) +- [var_pop](#var_pop) +- [var_population](#var_population) +- [var_samp](#var_samp) +- [var_sample](#var_sample) ### `bit_and` @@ -47,7 +56,7 @@ bit_and(expression) #### Arguments -- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `bit_or` @@ -59,7 +68,7 @@ bit_or(expression) #### Arguments -- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `bit_xor` @@ -71,4 +80,49 @@ bit_xor(expression) #### Arguments -- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `var` + +Returns the statistical sample variance of a set of numbers. + +``` +var(expression) +``` + +#### Arguments + +- **expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Aliases + +- var_sample +- var_samp + +### `var_pop` + +Returns the statistical population variance of a set of numbers. + +``` +var_pop(expression) +``` + +#### Arguments + +- **expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Aliases + +- var_population + +### `var_population` + +_Alias of [var_pop](#var_pop)._ + +### `var_samp` + +_Alias of [var](#var)._ + +### `var_sample` + +_Alias of [var](#var)._ diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 469fb705b71f..480767389086 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -19,6 +19,14 @@ # Scalar Functions +Scalar functions operate on a single row at a time and return a single value. + +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Scalar Functions (new)](scalar_functions_new.md) page for +the rest of the documentation. 
+ +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 + ## Math Functions - [abs](#abs) @@ -42,7 +50,6 @@ - [iszero](#iszero) - [lcm](#lcm) - [ln](#ln) -- [log](#log) - [log10](#log10) - [log2](#log2) - [nanvl](#nanvl) @@ -339,23 +346,6 @@ ln(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. -### `log` - -Returns the base-x logarithm of a number. -Can either provide a specified base, or if omitted then takes the base-10 of a number. - -``` -log(base, numeric_expression) -log(numeric_expression) -``` - -#### Arguments - -- **base**: Base numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - ### `log10` Returns the base-10 logarithm of a number. @@ -404,897 +394,229 @@ Returns an approximate value of π. ``` pi() -``` - -### `power` - -Returns a base expression raised to the power of an exponent. - -``` -power(base, exponent) -``` - -#### Arguments - -- **base**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **exponent**: Exponent numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -#### Aliases - -- pow - -### `pow` - -_Alias of [power](#power)._ - -### `radians` - -Converts degrees to radians. - -``` -radians(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `random` - -Returns a random float value in the range [0, 1). -The random seed is unique to each row. - -``` -random() -``` - -### `round` - -Rounds a number to the nearest integer. - -``` -round(numeric_expression[, decimal_places]) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **decimal_places**: Optional. The number of decimal places to round to. - Defaults to 0. - -### `signum` - -Returns the sign of a number. -Negative numbers return `-1`. -Zero and positive numbers return `1`. - -``` -signum(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `sin` - -Returns the sine of a number. - -``` -sin(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `sinh` - -Returns the hyperbolic sine of a number. - -``` -sinh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `sqrt` - -Returns the square root of a number. - -``` -sqrt(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `tan` - -Returns the tangent of a number. 
- -``` -tan(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `tanh` - -Returns the hyperbolic tangent of a number. - -``` -tanh(numeric_expression) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `trunc` - -Truncates a number to a whole number or truncated to the specified decimal places. - -``` -trunc(numeric_expression[, decimal_places]) -``` - -#### Arguments - -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -- **decimal_places**: Optional. The number of decimal places to - truncate to. Defaults to 0 (truncate to a whole number). If - `decimal_places` is a positive integer, truncates digits to the - right of the decimal point. If `decimal_places` is a negative - integer, replaces digits to the left of the decimal point with `0`. - -## Conditional Functions - -- [coalesce](#coalesce) -- [nullif](#nullif) -- [nvl](#nvl) -- [nvl2](#nvl2) -- [ifnull](#ifnull) - -### `coalesce` - -Returns the first of its arguments that is not _null_. -Returns _null_ if all arguments are _null_. -This function is often used to substitute a default value for _null_ values. - -``` -coalesce(expression1[, ..., expression_n]) -``` - -#### Arguments - -- **expression1, expression_n**: - Expression to use if previous expressions are _null_. - Can be a constant, column, or function, and any combination of arithmetic operators. - Pass as many expression arguments as necessary. - -### `nullif` - -Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_. -This can be used to perform the inverse operation of [`coalesce`](#coalesce). - -``` -nullif(expression1, expression2) -``` - -#### Arguments - -- **expression1**: Expression to compare and return if equal to expression2. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Expression to compare to expression1. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `nvl` - -Returns _expression2_ if _expression1_ is NULL; otherwise it returns _expression1_. - -``` -nvl(expression1, expression2) -``` - -#### Arguments - -- **expression1**: return if expression1 not is NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: return if expression1 is NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `nvl2` - -Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_. - -``` -nvl2(expression1, expression2, expression3) -``` - -#### Arguments - -- **expression1**: conditional expression. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: return if expression1 is not NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression3**: return if expression1 is NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. 
- -### `ifnull` - -_Alias of [nvl](#nvl)._ - -## String Functions - -- [ascii](#ascii) -- [bit_length](#bit_length) -- [btrim](#btrim) -- [char_length](#char_length) -- [character_length](#character_length) -- [concat](#concat) -- [concat_ws](#concat_ws) -- [chr](#chr) -- [ends_with](#ends_with) -- [initcap](#initcap) -- [instr](#instr) -- [left](#left) -- [length](#length) -- [lower](#lower) -- [lpad](#lpad) -- [ltrim](#ltrim) -- [octet_length](#octet_length) -- [repeat](#repeat) -- [replace](#replace) -- [reverse](#reverse) -- [right](#right) -- [rpad](#rpad) -- [rtrim](#rtrim) -- [split_part](#split_part) -- [starts_with](#starts_with) -- [strpos](#strpos) -- [substr](#substr) -- [to_hex](#to_hex) -- [translate](#translate) -- [trim](#trim) -- [upper](#upper) -- [uuid](#uuid) -- [overlay](#overlay) -- [levenshtein](#levenshtein) -- [substr_index](#substr_index) -- [find_in_set](#find_in_set) -- [position](#position) -- [contains](#contains) - -### `ascii` - -Returns the ASCII value of the first character in a string. - -``` -ascii(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[chr](#chr) - -### `bit_length` - -Returns the bit length of a string. - -``` -bit_length(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[length](#length), -[octet_length](#octet_length) - -### `btrim` - -Trims the specified trim string from the start and end of a string. -If no trim string is provided, all whitespace is removed from the start and end -of the input string. - -``` -btrim(str[, trim_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **trim_str**: String expression to trim from the beginning and end of the input string. - Can be a constant, column, or function, and any combination of arithmetic operators. - _Default is whitespace characters._ - -**Related functions**: -[ltrim](#ltrim), -[rtrim](#rtrim) - -#### Aliases - -- trim - -### `char_length` - -_Alias of [length](#length)._ - -### `character_length` - -_Alias of [length](#length)._ - -### `concat` - -Concatenates multiple strings together. - -``` -concat(str[, ..., str_n]) -``` - -#### Arguments - -- **str**: String expression to concatenate. - Can be a constant, column, or function, and any combination of string operators. -- **str_n**: Subsequent string column or literal string to concatenate. - -**Related functions**: -[concat_ws](#concat_ws) - -### `concat_ws` - -Concatenates multiple strings together with a specified separator. - -``` -concat_ws(separator, str[, ..., str_n]) -``` - -#### Arguments - -- **separator**: Separator to insert between concatenated strings. -- **str**: String expression to concatenate. - Can be a constant, column, or function, and any combination of string operators. -- **str_n**: Subsequent string column or literal string to concatenate. - -**Related functions**: -[concat](#concat) - -### `chr` - -Returns the character with the specified ASCII or Unicode code value. - -``` -chr(expression) -``` - -#### Arguments - -- **expression**: Expression containing the ASCII or Unicode code value to operate on. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. 
- -**Related functions**: -[ascii](#ascii) - -### `ends_with` - -Tests if a string ends with a substring. - -``` -ends_with(str, substr) -``` - -#### Arguments - -- **str**: String expression to test. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring to test for. - -### `initcap` - -Capitalizes the first character in each word in the input string. -Words are delimited by non-alphanumeric characters. - -``` -initcap(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[lower](#lower), -[upper](#upper) - -### `instr` - -_Alias of [strpos](#strpos)._ - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring expression to search for. - Can be a constant, column, or function, and any combination of string operators. - -### `left` - -Returns a specified number of characters from the left side of a string. - -``` -left(str, n) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **n**: Number of characters to return. - -**Related functions**: -[right](#right) - -### `length` - -Returns the number of characters in a string. - -``` -length(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -#### Aliases - -- char_length -- character_length - -**Related functions**: -[bit_length](#bit_length), -[octet_length](#octet_length) - -### `lower` - -Converts a string to lower-case. - -``` -lower(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[initcap](#initcap), -[upper](#upper) - -### `lpad` - -Pads the left side of a string with another string to a specified string length. - -``` -lpad(str, n[, padding_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **n**: String length to pad to. -- **padding_str**: String expression to pad with. - Can be a constant, column, or function, and any combination of string operators. - _Default is a space._ - -**Related functions**: -[rpad](#rpad) - -### `ltrim` - -Trims the specified trim string from the beginning of a string. -If no trim string is provided, all whitespace is removed from the start -of the input string. - -``` -ltrim(str[, trim_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **trim_str**: String expression to trim from the beginning of the input string. - Can be a constant, column, or function, and any combination of arithmetic operators. - _Default is whitespace characters._ - -**Related functions**: -[btrim](#btrim), -[rtrim](#rtrim) - -### `octet_length` - -Returns the length of a string in bytes. - -``` -octet_length(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. 
- -**Related functions**: -[bit_length](#bit_length), -[length](#length) - -### `repeat` - -Returns a string with an input string repeated a specified number. - -``` -repeat(str, n) -``` - -#### Arguments - -- **str**: String expression to repeat. - Can be a constant, column, or function, and any combination of string operators. -- **n**: Number of times to repeat the input string. - -### `replace` - -Replaces all occurrences of a specified substring in a string with a new substring. - -``` -replace(str, substr, replacement) -``` - -#### Arguments - -- **str**: String expression to repeat. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring expression to replace in the input string. - Can be a constant, column, or function, and any combination of string operators. -- **replacement**: Replacement substring expression. - Can be a constant, column, or function, and any combination of string operators. - -### `reverse` - -Reverses the character order of a string. - -``` -reverse(str) -``` - -#### Arguments - -- **str**: String expression to repeat. - Can be a constant, column, or function, and any combination of string operators. - -### `right` - -Returns a specified number of characters from the right side of a string. - -``` -right(str, n) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **n**: Number of characters to return. - -**Related functions**: -[left](#left) - -### `rpad` - -Pads the right side of a string with another string to a specified string length. - -``` -rpad(str, n[, padding_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **n**: String length to pad to. -- **padding_str**: String expression to pad with. - Can be a constant, column, or function, and any combination of string operators. - _Default is a space._ - -**Related functions**: -[lpad](#lpad) - -### `rtrim` - -Trims the specified trim string from the end of a string. -If no trim string is provided, all whitespace is removed from the end -of the input string. - -``` -rtrim(str[, trim_str]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **trim_str**: String expression to trim from the end of the input string. - Can be a constant, column, or function, and any combination of arithmetic operators. - _Default is whitespace characters._ - -**Related functions**: -[btrim](#btrim), -[ltrim](#ltrim) - -### `split_part` - -Splits a string based on a specified delimiter and returns the substring in the -specified position. - -``` -split_part(str, delimiter, pos) -``` - -#### Arguments - -- **str**: String expression to spit. - Can be a constant, column, or function, and any combination of string operators. -- **delimiter**: String or character to split on. -- **pos**: Position of the part to return. - -### `starts_with` - -Tests if a string starts with a substring. - -``` -starts_with(str, substr) -``` - -#### Arguments - -- **str**: String expression to test. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring to test for. +``` -### `strpos` +### `power` -Returns the starting position of a specified substring in a string. -Positions begin at 1. 
-If the substring does not exist in the string, the function returns 0. +Returns a base expression raised to the power of an exponent. ``` -strpos(str, substr) +power(base, exponent) ``` #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring expression to search for. - Can be a constant, column, or function, and any combination of string operators. +- **base**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **exponent**: Exponent numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. #### Aliases -- instr +- pow + +### `pow` + +_Alias of [power](#power)._ -### `substr` +### `radians` -Extracts a substring of a specified number of characters from a specific -starting position in a string. +Converts degrees to radians. ``` -substr(str, start_pos[, length]) +radians(numeric_expression) ``` #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **start_pos**: Character position to start the substring at. - The first character in the string has a position of 1. -- **length**: Number of characters to extract. - If not specified, returns the rest of the string after the start position. - -#### Aliases +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. -- substring +### `random` -### `substring` +Returns a random float value in the range [0, 1). +The random seed is unique to each row. -_Alias of [substr](#substr)._ +``` +random() +``` -### `translate` +### `round` -Translates characters in a string to specified translation characters. +Rounds a number to the nearest integer. ``` -translate(str, chars, translation) +round(numeric_expression[, decimal_places]) ``` -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **chars**: Characters to translate. -- **translation**: Translation characters. Translation characters replace only - characters at the same position in the **chars** string. +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **decimal_places**: Optional. The number of decimal places to round to. + Defaults to 0. -### `to_hex` +### `signum` -Converts an integer to a hexadecimal string. +Returns the sign of a number. +Negative numbers return `-1`. +Zero and positive numbers return `1`. ``` -to_hex(int) +signum(numeric_expression) ``` #### Arguments -- **int**: Integer expression to convert. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. -### `trim` - -_Alias of [btrim](#btrim)._ - -### `upper` +### `sin` -Converts a string to upper-case. +Returns the sine of a number. ``` -upper(str) +sin(numeric_expression) ``` #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[initcap](#initcap), -[lower](#lower) +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. 
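The math-function entries being re-added above (power, pow, radians, random, round, signum, sin) are reference text without runnable examples on this old-style page. A quick, hedged way to exercise a few of them from Rust (exact output formatting may differ by DataFusion version):

```rust
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // power and its pow alias, signum on a negative input, and round
    // with an explicit decimal_places argument, per the entries above.
    ctx.sql("SELECT power(2, 10), pow(2, 10), signum(-42), round(3.14159, 2)")
        .await?
        .show()
        .await?;
    Ok(())
}
```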
-### `uuid` +### `sinh` -Returns UUID v4 string value which is unique per row. +Returns the hyperbolic sine of a number. ``` -uuid() +sinh(numeric_expression) ``` -### `overlay` +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `sqrt` -Returns the string which is replaced by another string from the specified position and specified count length. -For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas` +Returns the square root of a number. ``` -overlay(str PLACING substr FROM pos [FOR count]) +sqrt(numeric_expression) ``` #### Arguments -- **str**: String expression to operate on. -- **substr**: the string to replace part of str. -- **pos**: the start position to replace of str. -- **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. -### `levenshtein` +### `tan` -Returns the Levenshtein distance between the two given strings. -For example, `levenshtein('kitten', 'sitting') = 3` +Returns the tangent of a number. ``` -levenshtein(str1, str2) +tan(numeric_expression) ``` #### Arguments -- **str1**: String expression to compute Levenshtein distance with str2. -- **str2**: String expression to compute Levenshtein distance with str1. +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. -### `substr_index` +### `tanh` -Returns the substring from str before count occurrences of the delimiter delim. -If count is positive, everything to the left of the final delimiter (counting from the left) is returned. -If count is negative, everything to the right of the final delimiter (counting from the right) is returned. -For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = org` +Returns the hyperbolic tangent of a number. ``` -substr_index(str, delim, count) +tanh(numeric_expression) ``` #### Arguments -- **str**: String expression to operate on. -- **delim**: the string to find in str to split str. -- **count**: The number of times to search for the delimiter. Can be both a positive or negative number. +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. -### `find_in_set` +### `trunc` -Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. -For example, `find_in_set('b', 'a,b,c,d') = 2` +Truncates a number to a whole number or truncated to the specified decimal places. ``` -find_in_set(str, strlist) +trunc(numeric_expression[, decimal_places]) ``` #### Arguments -- **str**: String expression to find in strlist. -- **strlist**: A string list is a string composed of substrings separated by , characters. +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + +- **decimal_places**: Optional. The number of decimal places to + truncate to. Defaults to 0 (truncate to a whole number). If + `decimal_places` is a positive integer, truncates digits to the + right of the decimal point. 
If `decimal_places` is a negative + integer, replaces digits to the left of the decimal point with `0`. -## Binary String Functions +## Conditional Functions - -- [decode](#decode) -- [encode](#encode) +- [nullif](#nullif) +- [nvl](#nvl) +- [nvl2](#nvl2) +- [ifnull](#ifnull) -### `encode` +### `nullif` -Encode binary data into a textual representation. +Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_. +This can be used to perform the inverse operation of [`coalesce`](#coalesce). ``` -encode(expression, format) +nullif(expression1, expression2) ``` #### Arguments -- **expression**: Expression containing string or binary data +- **expression1**: Expression to compare and return if equal to expression2. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression2**: Expression to compare to expression1. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `nvl` + +Returns _expression2_ if _expression1_ is NULL; otherwise it returns _expression1_. + +``` +nvl(expression1, expression2) +``` -- **format**: Supported formats are: `base64`, `hex` +#### Arguments -**Related functions**: -[decode](#decode) +- **expression1**: return if expression1 is not NULL. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression2**: return if expression1 is NULL. + Can be a constant, column, or function, and any combination of arithmetic operators. -### `decode` +### `nvl2` -Decode binary data from textual representation in string. +Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_. ``` -decode(expression, format) +nvl2(expression1, expression2, expression3) ``` #### Arguments -- **expression**: Expression containing encoded string data +- **expression1**: conditional expression. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression2**: return if expression1 is not NULL. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression3**: return if expression1 is NULL. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `ifnull` + +_Alias of [nvl](#nvl)._ -- **format**: Same arguments as [encode](#encode) +## String Functions -**Related functions**: -[encode](#encode) +See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) ## Regular Expression Functions @@ -1302,57 +624,12 @@ Apache DataFusion uses a [PCRE-like] regular expression [syntax] (minus support for several features including look-around and backreferences). The following regular expression functions are supported: -- [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace) [pcre-like]: https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions [syntax]: https://docs.rs/regex/latest/regex/#syntax -### `regexp_like` - -Returns true if a [regular expression] has at least one match in a string, -false otherwise. - -[regular expression]: https://docs.rs/regex/latest/regex/#syntax - -``` -regexp_like(str, regexp[, flags]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **regexp**: Regular expression to test against the string expression. - Can be a constant, column, or function.
-- **flags**: Optional regular expression flags that control the behavior of the - regular expression. The following flags are supported: - - **i**: case-insensitive: letters match both upper and lower case - - **m**: multi-line mode: ^ and $ match begin/end of line - - **s**: allow . to match \n - - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*? - -#### Example - -```sql -select regexp_like('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); -+--------------------------------------------------------+ -| regexp_like(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | -+--------------------------------------------------------+ -| true | -+--------------------------------------------------------+ -SELECT regexp_like('aBc', '(b|d)', 'i'); -+--------------------------------------------------+ -| regexp_like(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | -+--------------------------------------------------+ -| true | -+--------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) - ### `regexp_match` Returns a list of [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. @@ -1452,19 +729,6 @@ position(substr in origstr) - **substr**: The pattern string. - **origstr**: The model string. -### `contains` - -Return true if search_string is found within string (case-sensitive). - -``` -contains(string, search_string) -``` - -#### Arguments - -- **string**: The pattern string. -- **search_string**: The model string. - ## Time and Date Functions - [now](#now) @@ -1479,7 +743,6 @@ contains(string, search_string) - [today](#today) - [make_date](#make_date) - [to_char](#to_char) -- [to_date](#to_date) - [to_local_time](#to_local_time) - [to_timestamp](#to_timestamp) - [to_timestamp_millis](#to_timestamp_millis) @@ -1727,49 +990,6 @@ Additional examples can be found [here] - date_format -### `to_date` - -Converts a value to a date (`YYYY-MM-DD`). -Supports strings, integer and double types as input. -Strings are parsed as YYYY-MM-DD (e.g. '2023-07-20') if no [Chrono format]s are provided. -Integers and doubles are interpreted as days since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding date. - -Note: `to_date` returns Date32, which represents its values as the number of days since unix epoch(`1970-01-01`) stored as signed 32 bit value. The largest supported date value is `9999-12-31`. - -``` -to_date(expression[, ..., format_n]) -``` - -#### Arguments - -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. 
- -[chrono format]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html - -#### Example - -``` -> select to_date('2023-01-31'); -+-----------------------------+ -| to_date(Utf8("2023-01-31")) | -+-----------------------------+ -| 2023-01-31 | -+-----------------------------+ -> select to_date('2023/01/31', '%Y-%m-%d', '%Y/%m/%d'); -+---------------------------------------------------------------+ -| to_date(Utf8("2023/01/31"),Utf8("%Y-%m-%d"),Utf8("%Y/%m/%d")) | -+---------------------------------------------------------------+ -| 2023-01-31 | -+---------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) - ### `to_local_time` Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or @@ -3846,104 +3066,6 @@ select map_values(map([100, 5], [42,43])); [42, 43] ``` -## Hashing Functions - -- [digest](#digest) -- [md5](#md5) -- [sha224](#sha224) -- [sha256](#sha256) -- [sha384](#sha384) -- [sha512](#sha512) - -### `digest` - -Computes the binary hash of an expression using the specified algorithm. - -``` -digest(expression, algorithm) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**algorithm**: String expression specifying algorithm to use. - Must be one of: - - - md5 - - sha224 - - sha256 - - sha384 - - sha512 - - blake2s - - blake2b - - blake3 - -### `md5` - -Computes an MD5 128-bit checksum for a string expression. - -``` -md5(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -### `sha224` - -Computes the SHA-224 hash of a binary string. - -``` -sha224(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -### `sha256` - -Computes the SHA-256 hash of a binary string. - -``` -sha256(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -### `sha384` - -Computes the SHA-384 hash of a binary string. - -``` -sha384(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -### `sha512` - -Computes the SHA-512 hash of a binary string. - -``` -sha512(expression) -``` - -#### Arguments - -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -## Other Functions - -- [arrow_cast](#arrow_cast) diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index ae2744c1650e..673c55f46b15 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -27,62 +27,692 @@ dev/update_function_docs.sh file for updating surrounding text. # Scalar Functions (NEW) -This page is a WIP and will replace the Scalar Functions page once completed. +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Scalar Functions (old)](scalar_functions.md) page for +the rest of the documentation.
+ +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 ## Math Functions - [log](#log) -### `log` +### `log` + +Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number. + +``` +log(base, numeric_expression) +log(numeric_expression) +``` + +#### Arguments + +- **base**: Base numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +## Conditional Functions + +- [coalesce](#coalesce) + +### `coalesce` + +Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values. + +``` +coalesce(expression1[, ..., expression_n]) +``` + +#### Arguments + +- **expression1, expression_n**: Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary. + +## String Functions + +- [ascii](#ascii) +- [bit_length](#bit_length) +- [btrim](#btrim) +- [char_length](#char_length) +- [character_length](#character_length) +- [chr](#chr) +- [concat](#concat) +- [concat_ws](#concat_ws) +- [contains](#contains) +- [ends_with](#ends_with) +- [find_in_set](#find_in_set) +- [initcap](#initcap) +- [instr](#instr) +- [left](#left) +- [length](#length) +- [levenshtein](#levenshtein) +- [lower](#lower) +- [lpad](#lpad) +- [ltrim](#ltrim) +- [octet_length](#octet_length) +- [position](#position) +- [repeat](#repeat) +- [replace](#replace) +- [reverse](#reverse) +- [right](#right) +- [rpad](#rpad) +- [rtrim](#rtrim) +- [split_part](#split_part) +- [starts_with](#starts_with) +- [strpos](#strpos) +- [substr](#substr) +- [substr_index](#substr_index) +- [substring](#substring) +- [substring_index](#substring_index) +- [to_hex](#to_hex) +- [translate](#translate) +- [trim](#trim) +- [upper](#upper) +- [uuid](#uuid) + +### `ascii` + +Returns the Unicode character code of the first character in a string. + +``` +ascii(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select ascii('abc'); ++--------------------+ +| ascii(Utf8("abc")) | ++--------------------+ +| 97 | ++--------------------+ +> select ascii('🚀'); ++-------------------+ +| ascii(Utf8("🚀")) | ++-------------------+ +| 128640 | ++-------------------+ +``` + +**Related functions**: + +- [chr](#chr) + +### `bit_length` + +Returns the bit length of a string. + +``` +bit_length(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select bit_length('datafusion'); ++--------------------------------+ +| bit_length(Utf8("datafusion")) | ++--------------------------------+ +| 80 | ++--------------------------------+ +``` + +**Related functions**: + +- [length](#length) +- [octet_length](#octet_length) + +### `btrim` + +Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string. + +``` +btrim(str[, trim_str]) +``` + +#### Arguments + +- **str**: String expression to operate on. 
Can be a constant, column, or function, and any combination of operators. +- **trim_str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._ + +#### Example + +```sql +> select btrim('__datafusion____', '_'); ++-------------------------------------------+ +| btrim(Utf8("__datafusion____"),Utf8("_")) | ++-------------------------------------------+ +| datafusion | ++-------------------------------------------+ +``` + +#### Aliases + +- trim + +**Related functions**: + +- [ltrim](#ltrim) +- [rtrim](#rtrim) + +### `char_length` + +_Alias of [character_length](#character_length)._ + +### `character_length` + +Returns the number of characters in a string. + +``` +character_length(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select character_length('Ångström'); ++------------------------------------+ +| character_length(Utf8("Ångström")) | ++------------------------------------+ +| 8 | ++------------------------------------+ +``` + +#### Aliases + +- length +- char_length + +**Related functions**: + +- [bit_length](#bit_length) +- [octet_length](#octet_length) + +### `chr` + +Returns the character with the specified ASCII or Unicode code value. + +``` +chr(expression) +``` + +#### Arguments + +- **expression**: Expression containing the ASCII or Unicode code value to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select chr(128640); ++--------------------+ +| chr(Int64(128640)) | ++--------------------+ +| 🚀 | ++--------------------+ +``` + +**Related functions**: + +- [ascii](#ascii) + +### `concat` + +Concatenates multiple strings together. + +``` +concat(str[, ..., str_n]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **str_n**: Subsequent string expressions to concatenate. + +#### Example + +```sql +> select concat('data', 'f', 'us', 'ion'); ++-------------------------------------------------------+ +| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) | ++-------------------------------------------------------+ +| datafusion | ++-------------------------------------------------------+ +``` + +**Related functions**: + +- [concat_ws](#concat_ws) + +### `concat_ws` + +Concatenates multiple strings together with a specified separator. + +``` +concat_ws(separator, str[, ..., str_n]) +``` + +#### Arguments + +- **separator**: Separator to insert between concatenated strings. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **str_n**: Subsequent string expressions to concatenate. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select concat_ws('_', 'data', 'fusion'); ++--------------------------------------------------+ +| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) | ++--------------------------------------------------+ +| data_fusion | ++--------------------------------------------------+ +``` + +**Related functions**: + +- [concat](#concat) + +### `contains` + +Returns true if search_str is found within string (case-sensitive). + +``` +contains(str, search_str) +``` + +#### Arguments + +- **str**: String expression to operate on.
Can be a constant, column, or function, and any combination of operators. +- **search_str**: The string to search for in str. + +#### Example + +```sql +> select contains('the quick brown fox', 'row'); ++---------------------------------------------------+ +| contains(Utf8("the quick brown fox"),Utf8("row")) | ++---------------------------------------------------+ +| true | ++---------------------------------------------------+ +``` + +### `ends_with` + +Tests if a string ends with a substring. + +``` +ends_with(str, substr) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring to test for. + +#### Example + +```sql +> select ends_with('datafusion', 'soin'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("soin")) | ++--------------------------------------------+ +| false | ++--------------------------------------------+ +> select ends_with('datafusion', 'sion'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("sion")) | ++--------------------------------------------+ +| true | ++--------------------------------------------+ +``` + +### `find_in_set` + +Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. + +``` +find_in_set(str, strlist) +``` + +#### Arguments + +- **str**: String expression to find in strlist. +- **strlist**: A string list is a string composed of substrings separated by , characters. + +#### Example + +```sql +> select find_in_set('b', 'a,b,c,d'); ++----------------------------------------+ +| find_in_set(Utf8("b"),Utf8("a,b,c,d")) | ++----------------------------------------+ +| 2 | ++----------------------------------------+ +``` + +### `initcap` + +Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters. + +``` +initcap(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select initcap('apache datafusion'); ++------------------------------------+ +| initcap(Utf8("apache datafusion")) | ++------------------------------------+ +| Apache Datafusion | ++------------------------------------+ +``` + +**Related functions**: + +- [lower](#lower) +- [upper](#upper) + +### `instr` + +_Alias of [strpos](#strpos)._ + +### `left` + +Returns a specified number of characters from the left side of a string. + +``` +left(str, n) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: Number of characters to return. + +#### Example + +```sql +> select left('datafusion', 4); ++-----------------------------------+ +| left(Utf8("datafusion"),Int64(4)) | ++-----------------------------------+ +| data | ++-----------------------------------+ +``` + +**Related functions**: + +- [right](#right) + +### `length` + +_Alias of [character_length](#character_length)._ + +### `levenshtein` + +Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings. + +``` +levenshtein(str1, str2) +``` + +#### Arguments + +- **str1**: String expression to compute Levenshtein distance with str2. +- **str2**: String expression to compute Levenshtein distance with str1. 
+
+#### Example
+
+```sql
+> select levenshtein('kitten', 'sitting');
++---------------------------------------------+
+| levenshtein(Utf8("kitten"),Utf8("sitting")) |
++---------------------------------------------+
+| 3                                           |
++---------------------------------------------+
+```
+
+### `lower`
+
+Converts a string to lower-case.
+
+```
+lower(str)
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select lower('Ångström');
++-------------------------+
+| lower(Utf8("Ångström")) |
++-------------------------+
+| ångström                |
++-------------------------+
+```
+
+**Related functions**:
+
+- [initcap](#initcap)
+- [upper](#upper)
+
+### `lpad`
+
+Pads the left side of a string with another string to a specified string length.
+
+```
+lpad(str, n[, padding_str])
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **n**: String length to pad to.
+- **padding_str**: Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._
+
+#### Example
+
+```sql
+> select lpad('Dolly', 10, 'hello');
++---------------------------------------------+
+| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) |
++---------------------------------------------+
+| helloDolly                                  |
++---------------------------------------------+
+```
+
+**Related functions**:
+
+- [rpad](#rpad)
+
+### `ltrim`
+
+Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string.
+
+```
+ltrim(str[, trim_str])
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **trim_str**: String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of string operators. _Default is whitespace characters._
+
+#### Example
+
+```sql
+> select ltrim(' datafusion ');
++-----------------------------+
+| ltrim(Utf8(" datafusion ")) |
++-----------------------------+
+| datafusion                  |
++-----------------------------+
+> select ltrim('___datafusion___', '_');
++-------------------------------------------+
+| ltrim(Utf8("___datafusion___"),Utf8("_")) |
++-------------------------------------------+
+| datafusion___                             |
++-------------------------------------------+
+```
+
+**Related functions**:
+
+- [btrim](#btrim)
+- [rtrim](#rtrim)
+
+### `octet_length`
+
+Returns the length of a string in bytes.
+
+```
+octet_length(str)
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select octet_length('Ångström');
++--------------------------------+
+| octet_length(Utf8("Ångström")) |
++--------------------------------+
+| 10                             |
++--------------------------------+
+```
+
+**Related functions**:
+
+- [bit_length](#bit_length)
+- [length](#length)
+
+### `position`
+
+_Alias of [strpos](#strpos)._
+
+### `repeat`
+
+Returns a string with an input string repeated a specified number of times.
+
+```
+repeat(str, n)
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **n**: Number of times to repeat the input string.
+
+#### Example
+
+```sql
+> select repeat('data', 3);
++-------------------------------+
+| repeat(Utf8("data"),Int64(3)) |
++-------------------------------+
+| datadatadata                  |
++-------------------------------+
+```
+
+### `replace`
 
-Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number.
+Replaces all occurrences of a specified substring in a string with a new substring.
 
 ```
-log(base, numeric_expression)
-log(numeric_expression)
+replace(str, substr, replacement)
 ```
 
 #### Arguments
 
-- **base**: Base numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.
-- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **substr**: Substring expression to replace in the input string. Can be a constant, column, or function, and any combination of operators.
+- **replacement**: Replacement substring expression to operate on. Can be a constant, column, or function, and any combination of operators.
 
-## Conditional Functions
+#### Example
 
-- [coalesce](#coalesce)
+```sql
+> select replace('ABabbaBA', 'ab', 'cd');
++-------------------------------------------------+
+| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
++-------------------------------------------------+
+| ABcdbaBA                                        |
++-------------------------------------------------+
+```
 
-### `coalesce`
+### `reverse`
 
-Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values.
+Reverses the character order of a string.
 
 ```
-coalesce(expression1[, ..., expression_n])
+reverse(str)
 ```
 
 #### Arguments
 
-- **expression1, expression_n**: Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary.
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
 
-## String Functions
+#### Example
 
-- [ascii](#ascii)
-- [rpad](#rpad)
+```sql
+> select reverse('datafusion');
++-----------------------------+
+| reverse(Utf8("datafusion")) |
++-----------------------------+
+| noisufatad                  |
++-----------------------------+
+```
 
-### `ascii`
+### `right`
 
-Returns the ASCII value of the first character in a string.
+Returns a specified number of characters from the right side of a string.
 
 ```
-ascii(str)
+right(str, n)
 ```
 
 #### Arguments
 
-- **str**: String expression to operate on. Can be a constant, column, or function that evaluates to or can be coerced to a Utf8, LargeUtf8 or a Utf8View.
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **n**: Number of characters to return.
+
+#### Example
+
+```sql
+> select right('datafusion', 6);
++------------------------------------+
+| right(Utf8("datafusion"),Int64(6)) |
++------------------------------------+
+| fusion                             |
++------------------------------------+
+```
 
 **Related functions**:
 
-- [chr](#chr)
+- [left](#left)
 
 ### `rpad`
 
@@ -94,14 +724,311 @@ rpad(str, n[, padding_str])
 
 #### Arguments
 
-- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of string operators.
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
 - **n**: String length to pad to.
 - **padding_str**: String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._
 
+#### Example
+
+```sql
+> select rpad('datafusion', 20, '_-');
++-----------------------------------------------+
+| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) |
++-----------------------------------------------+
+| datafusion_-_-_-_-_-                          |
++-----------------------------------------------+
+```
+
 **Related functions**:
 
 - [lpad](#lpad)
 
+### `rtrim`
+
+Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string.
+
+```
+rtrim(str[, trim_str])
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **trim_str**: String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of string operators. _Default is whitespace characters._
+
+#### Example
+
+```sql
+> select rtrim(' datafusion ');
++-----------------------------+
+| rtrim(Utf8(" datafusion ")) |
++-----------------------------+
+| datafusion                  |
++-----------------------------+
+> select rtrim('___datafusion___', '_');
++-------------------------------------------+
+| rtrim(Utf8("___datafusion___"),Utf8("_")) |
++-------------------------------------------+
+| ___datafusion                             |
++-------------------------------------------+
+```
+
+**Related functions**:
+
+- [btrim](#btrim)
+- [ltrim](#ltrim)
+
+### `split_part`
+
+Splits a string based on a specified delimiter and returns the substring in the specified position.
+
+```
+split_part(str, delimiter, pos)
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **delimiter**: String or character to split on.
+- **pos**: Position of the part to return.
+
+#### Example
+
+```sql
+> select split_part('1.2.3.4.5', '.', 3);
++--------------------------------------------------+
+| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) |
++--------------------------------------------------+
+| 3                                                |
++--------------------------------------------------+
+```
+
+### `starts_with`
+
+Tests if a string starts with a substring.
+
+```
+starts_with(str, substr)
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **substr**: Substring to test for.
+
+#### Example
+
+```sql
+> select starts_with('datafusion','data');
++----------------------------------------------+
+| starts_with(Utf8("datafusion"),Utf8("data")) |
++----------------------------------------------+
+| true                                         |
++----------------------------------------------+
+```
+
+### `strpos`
+
+Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.
+
+```
+strpos(str, substr)
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **substr**: Substring expression to search for.
+
+#### Example
+
+```sql
+> select strpos('datafusion', 'fus');
++----------------------------------------+
+| strpos(Utf8("datafusion"),Utf8("fus")) |
++----------------------------------------+
+| 5                                      |
++----------------------------------------+
+```
+
+#### Aliases
+
+- instr
+- position
+
+### `substr`
+
+Extracts a substring of a specified number of characters from a specific starting position in a string.
+
+```
+substr(str, start_pos[, length])
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **start_pos**: Character position to start the substring at. The first character in the string has a position of 1.
+- **length**: Number of characters to extract. If not specified, returns the rest of the string after the start position.
+
+#### Example
+
+```sql
+> select substr('datafusion', 5, 3);
++----------------------------------------------+
+| substr(Utf8("datafusion"),Int64(5),Int64(3)) |
++----------------------------------------------+
+| fus                                          |
++----------------------------------------------+
+```
+
+#### Aliases
+
+- substring
+
+### `substr_index`
+
+Returns the substring from str before count occurrences of the delimiter delim.
+If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
+If count is negative, everything to the right of the final delimiter (counting from the right) is returned.
+
+```
+substr_index(str, delim, count)
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **delim**: The string to find in str to split str.
+- **count**: The number of times to search for the delimiter. Can be either a positive or negative number.
+
+#### Example
+
+```sql
+> select substr_index('www.apache.org', '.', 1);
++---------------------------------------------------------+
+| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) |
++---------------------------------------------------------+
+| www                                                     |
++---------------------------------------------------------+
+> select substr_index('www.apache.org', '.', -1);
++----------------------------------------------------------+
+| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) |
++----------------------------------------------------------+
+| org                                                      |
++----------------------------------------------------------+
+```
+
+#### Aliases
+
+- substring_index
+
+### `substring`
+
+_Alias of [substr](#substr)._
+
+### `substring_index`
+
+_Alias of [substr_index](#substr_index)._
+
+### `to_hex`
+
+Converts an integer to a hexadecimal string.
+
+```
+to_hex(int)
+```
+
+#### Arguments
+
+- **int**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select to_hex(12345689);
++-------------------------+
+| to_hex(Int64(12345689)) |
++-------------------------+
+| bc6159                  |
++-------------------------+
+```
+
+### `translate`
+
+Translates characters in a string to specified translation characters.
+
+```
+translate(str, chars, translation)
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **chars**: Characters to translate.
+- **translation**: Translation characters. Translation characters replace only characters at the same position in the **chars** string.
+
+#### Example
+
+```sql
+> select translate('twice', 'wic', 'her');
++--------------------------------------------------+
+| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) |
++--------------------------------------------------+
+| there                                            |
++--------------------------------------------------+
+```
+
+### `trim`
+
+_Alias of [btrim](#btrim)._
+
+### `upper`
+
+Converts a string to upper-case.
+
+```
+upper(str)
+```
+
+#### Arguments
+
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select upper('dataFusion');
++---------------------------+
+| upper(Utf8("dataFusion")) |
++---------------------------+
+| DATAFUSION                |
++---------------------------+
+```
+
+**Related functions**:
+
+- [initcap](#initcap)
+- [lower](#lower)
+
+### `uuid`
+
+Returns a `UUID v4` string value, which is unique per row.
+
+```
+uuid()
+```
+
+#### Example
+
+```sql
+> select uuid();
++--------------------------------------+
+| uuid()                               |
++--------------------------------------+
+| 6ec17ef8-1934-41cc-8d59-d0c8f9eea1f0 |
++--------------------------------------+
+```
+
 ## Binary String Functions
 
 - [decode](#decode)
@@ -160,8 +1087,8 @@ regexp_like(str, regexp[, flags])
 
 #### Arguments
 
-- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of string operators.
-- **regexp**: Regular expression to test against the string expression. Can be a constant, column, or function.
+- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators.
 - **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
   - **i**: case-insensitive: letters match both upper and lower case
   - **m**: multi-line mode: ^ and $ match begin/end of line
@@ -208,7 +1135,7 @@ to_date('2017-05-31', '%Y-%m-%d')
 
 #### Arguments
 
-- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators.
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
 - **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.
@@ -234,7 +1161,67 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo
 
 ## Hashing Functions
 
+- [digest](#digest)
+- [md5](#md5)
 - [sha224](#sha224)
+- [sha256](#sha256)
+- [sha384](#sha384)
+- [sha512](#sha512)
+
+### `digest`
+
+Computes the binary hash of an expression using the specified algorithm.
+
+```
+digest(expression, algorithm)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+- **algorithm**: String expression specifying algorithm to use. Must be one of:
+- md5
+- sha224
+- sha256
+- sha384
+- sha512
+- blake2s
+- blake2b
+- blake3
+
+#### Example
+
+```sql
+> select digest('foo', 'sha256');
++------------------------------------------+
+| digest(Utf8("foo"), Utf8("sha256"))      |
++------------------------------------------+
+|                                          |
++------------------------------------------+
+```
+
+### `md5`
+
+Computes an MD5 128-bit checksum for a string expression.
+
+```
+md5(expression)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select md5('foo');
++-------------------------------------+
+| md5(Utf8("foo"))                    |
++-------------------------------------+
+|                                     |
++-------------------------------------+
+```
 
 ### `sha224`
 
@@ -246,4 +1233,84 @@ sha224(expression)
 
 #### Arguments
 
-- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of string operators.
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select sha224('foo');
++------------------------------------------+
+| sha224(Utf8("foo"))                      |
++------------------------------------------+
+|                                          |
++------------------------------------------+
+```
+
+### `sha256`
+
+Computes the SHA-256 hash of a binary string.
+
+```
+sha256(expression)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select sha256('foo');
++--------------------------------------+
+| sha256(Utf8("foo"))                  |
++--------------------------------------+
+|                                      |
++--------------------------------------+
+```
+
+### `sha384`
+
+Computes the SHA-384 hash of a binary string.
+
+```
+sha384(expression)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select sha384('foo');
++-----------------------------------------+
+| sha384(Utf8("foo"))                     |
++-----------------------------------------+
+|                                         |
++-----------------------------------------+
+```
+
+### `sha512`
+
+Computes the SHA-512 hash of a binary string.
+
+```
+sha512(expression)
+```
+
+#### Arguments
+
+- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
+
+#### Example
+
+```sql
+> select sha512('foo');
++-------------------------------------------+
+| sha512(Utf8("foo"))                       |
++-------------------------------------------+
+|                                           |
++-------------------------------------------+
+```
diff --git a/docs/source/user-guide/sql/window_functions.md b/docs/source/user-guide/sql/window_functions.md
index 4d9d2557249f..6c0de711bc0c 100644
--- a/docs/source/user-guide/sql/window_functions.md
+++ b/docs/source/user-guide/sql/window_functions.md
@@ -19,7 +19,15 @@
 
 # Window Functions
 
-A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. This is comparable to the type of calculation that can be done with an aggregate function. However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result
+A _window function_ performs a calculation across a set of table rows that are somehow related to the current row.
+
+Note: this documentation is in the process of being migrated to be [automatically created from the codebase].
+Please see the [Window Functions (new)](window_functions_new.md) page for
+the rest of the documentation.
+
+[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740
+
+Window functions are comparable to the type of calculation that can be done with an aggregate function. However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result.
 
 Here is an example that shows how to compare each employee's salary with the average salary in his or her department:
 
@@ -140,19 +148,10 @@ All [aggregate functions](aggregate_functions.md) can be used as window function
 
 ## Ranking functions
 
-- [row_number](#row_number)
 - [rank](#rank)
 - [dense_rank](#dense_rank)
 - [ntile](#ntile)
 
-### `row_number`
-
-Number of the current row within its partition, counting from 1.
-
-```sql
-row_number()
-```
-
 ### `rank`
 
 Rank of the current row with gaps; same as row_number of its first peer.
 
diff --git a/docs/source/user-guide/sql/window_functions_new.md b/docs/source/user-guide/sql/window_functions_new.md
index 1ab6740a6f87..ee981911f5cb 100644
--- a/docs/source/user-guide/sql/window_functions_new.md
+++ b/docs/source/user-guide/sql/window_functions_new.md
@@ -27,9 +27,16 @@ dev/update_function_docs.sh file for updating surrounding text.
 
 # Window Functions (NEW)
 
-This page is a WIP and will replace the Window Functions page once completed.
+Note: this documentation is in the process of being migrated to be [automatically created from the codebase].
+Please see the [Window Functions (Old)](window_functions.md) page for
+the rest of the documentation.
 
-A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. This is comparable to the type of calculation that can be done with an aggregate function. However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result
+[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740
+
+A _window function_ performs a calculation across a set of table rows that are somehow related to the current row.
+This is comparable to the type of calculation that can be done with an aggregate function.
+However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would.
+Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result.
 
 Here is an example that shows how to compare each employee's salary with the average salary in his or her department:
 
diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml
index 325a2cc2fcc4..414fa5569cfe 100644
--- a/test-utils/Cargo.toml
+++ b/test-utils/Cargo.toml
@@ -29,4 +29,5 @@ workspace = true
 arrow = { workspace = true }
 datafusion-common = { workspace = true, default-features = true }
 env_logger = { workspace = true }
+paste = "1.0.15"
 rand = { workspace = true }
diff --git a/test-utils/src/array_gen/mod.rs b/test-utils/src/array_gen/mod.rs
new file mode 100644
index 000000000000..4a799ae737d7
--- /dev/null
+++ b/test-utils/src/array_gen/mod.rs
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod primitive;
+mod string;
+
+pub use primitive::PrimitiveArrayGenerator;
+pub use string::StringArrayGenerator;
diff --git a/test-utils/src/array_gen/primitive.rs b/test-utils/src/array_gen/primitive.rs
new file mode 100644
index 000000000000..f70ebf6686d0
--- /dev/null
+++ b/test-utils/src/array_gen/primitive.rs
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, PrimitiveArray, UInt32Array};
+use arrow::datatypes::{
+    Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
+    UInt32Type, UInt64Type, UInt8Type,
+};
+use rand::rngs::StdRng;
+use rand::Rng;
+
+/// Randomly generate primitive arrays
+pub struct PrimitiveArrayGenerator {
+    /// the total number of primitives in the output
+    pub num_primitives: usize,
+    /// The number of distinct primitives in the columns
+    pub num_distinct_primitives: usize,
+    /// The percentage of nulls in the columns
+    pub null_pct: f64,
+    /// Random number generator
+    pub rng: StdRng,
+}
+
+macro_rules! impl_gen_data {
+    ($NATIVE_TYPE:ty, $ARROW_TYPE:ident) => {
+        paste::paste! {
+            pub fn [< gen_data_ $NATIVE_TYPE >](&mut self) -> ArrayRef {
+                // table of primitive values from which to draw
+                let distinct_primitives: PrimitiveArray<$ARROW_TYPE> = (0..self.num_distinct_primitives)
+                    .map(|_| Some(self.rng.gen::<$NATIVE_TYPE>()))
+                    .collect();
+
+                // pick num_primitives randomly from the distinct value table
+                let indices: UInt32Array = (0..self.num_primitives)
+                    .map(|_| {
+                        if self.rng.gen::<f64>() < self.null_pct {
+                            None
+                        } else if self.num_distinct_primitives > 1 {
+                            let range = 1..(self.num_distinct_primitives as u32);
+                            Some(self.rng.gen_range(range))
+                        } else {
+                            Some(0)
+                        }
+                    })
+                    .collect();
+
+                let options = None;
+                arrow::compute::take(&distinct_primitives, &indices, options).unwrap()
+            }
+        }
+    };
+}
+
+// TODO: support generating more primitive arrays
+impl PrimitiveArrayGenerator {
+    impl_gen_data!(i8, Int8Type);
+    impl_gen_data!(i16, Int16Type);
+    impl_gen_data!(i32, Int32Type);
+    impl_gen_data!(i64, Int64Type);
+    impl_gen_data!(u8, UInt8Type);
+    impl_gen_data!(u16, UInt16Type);
+    impl_gen_data!(u32, UInt32Type);
+    impl_gen_data!(u64, UInt64Type);
+    impl_gen_data!(f32, Float32Type);
+    impl_gen_data!(f64, Float64Type);
+}
diff --git a/test-utils/src/array_gen/string.rs b/test-utils/src/array_gen/string.rs
new file mode 100644
index 000000000000..fbfa2bb941e0
--- /dev/null
+++ b/test-utils/src/array_gen/string.rs
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array};
+use rand::rngs::StdRng;
+use rand::Rng;
+
+/// Randomly generate string arrays
+pub struct StringArrayGenerator {
+    /// The maximum length of the strings
+    pub max_len: usize,
+    /// the total number of strings in the output
+    pub num_strings: usize,
+    /// The number of distinct strings in the columns
+    pub num_distinct_strings: usize,
+    /// The percentage of nulls in the columns
+    pub null_pct: f64,
+    /// Random number generator
+    pub rng: StdRng,
+}
+
+impl StringArrayGenerator {
+    /// Creates a StringArray or LargeStringArray with random strings according
+    /// to the parameters of the generator
+    pub fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
+        // table of strings from which to draw
+        let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
+            .map(|_| Some(random_string(&mut self.rng, self.max_len)))
+            .collect();
+
+        // pick num_strings randomly from the distinct string table
+        let indices: UInt32Array = (0..self.num_strings)
+            .map(|_| {
+                if self.rng.gen::<f64>() < self.null_pct {
+                    None
+                } else if self.num_distinct_strings > 1 {
+                    let range = 1..(self.num_distinct_strings as u32);
+                    Some(self.rng.gen_range(range))
+                } else {
+                    Some(0)
+                }
+            })
+            .collect();
+
+        let options = None;
+        arrow::compute::take(&distinct_strings, &indices, options).unwrap()
+    }
+}
+
+/// Return a string of random characters of length 1..=max_len
+fn random_string(rng: &mut StdRng, max_len: usize) -> String {
+    // pick characters at random (not just ascii)
+    match max_len {
+        0 => "".to_string(),
+        1 => String::from(rng.gen::<char>()),
+        _ => {
+            let len = rng.gen_range(1..=max_len);
+            rng.sample_iter::<char, _>(rand::distributions::Standard)
+                .take(len)
+                .map(char::from)
+                .collect::<String>()
+        }
+    }
+}
diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs
index 3ddba2fec800..9db8920833ae 100644
--- a/test-utils/src/lib.rs
+++ b/test-utils/src/lib.rs
@@ -22,6 +22,7 @@ use datafusion_common::cast::as_int32_array;
 use rand::prelude::StdRng;
 use rand::{Rng, SeedableRng};
 
+pub mod array_gen;
 mod data_gen;
 mod string_gen;
 pub mod tpcds;
diff --git a/test-utils/src/string_gen.rs b/test-utils/src/string_gen.rs
index 530fc1535387..725eb22b85af 100644
--- a/test-utils/src/string_gen.rs
+++ b/test-utils/src/string_gen.rs
@@ -1,3 +1,4 @@
+use crate::array_gen::StringArrayGenerator;
 // Licensed to the Apache Software Foundation (ASF) under one
 // or more contributor license agreements. See the NOTICE file
 // distributed with this work for additional information
@@ -14,27 +15,14 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
-//
-// use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, RecordBatch, UInt32Array};
+
 use crate::stagger_batch;
-use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array};
 use arrow::record_batch::RecordBatch;
 use rand::rngs::StdRng;
 use rand::{thread_rng, Rng, SeedableRng};
 
 /// Randomly generate strings
-pub struct StringBatchGenerator {
-    //// The maximum length of the strings
-    pub max_len: usize,
-    /// the total number of strings in the output
-    pub num_strings: usize,
-    /// The number of distinct strings in the columns
-    pub num_distinct_strings: usize,
-    /// The percentage of nulls in the columns
-    pub null_pct: f64,
-    /// Random number generator
-    pub rng: StdRng,
-}
+pub struct StringBatchGenerator(StringArrayGenerator);
 
 impl StringBatchGenerator {
     /// Make batches of random strings with a random length columns "a" and "b".
@@ -44,8 +32,8 @@ impl StringBatchGenerator {
     pub fn make_input_batches(&mut self) -> Vec<RecordBatch> {
         // use a random number generator to pick a random sized output
         let batch = RecordBatch::try_from_iter(vec![
-            ("a", self.gen_data::<i32>()),
-            ("b", self.gen_data::<i64>()),
+            ("a", self.0.gen_data::<i32>()),
+            ("b", self.0.gen_data::<i64>()),
         ])
         .unwrap();
         stagger_batch(batch)
@@ -57,9 +45,9 @@ impl StringBatchGenerator {
     /// if large is true, the array is a LargeStringArray
     pub fn make_sorted_input_batches(&mut self, large: bool) -> Vec<RecordBatch> {
         let array = if large {
-            self.gen_data::<i64>()
+            self.0.gen_data::<i64>()
        } else {
-            self.gen_data::<i32>()
+            self.0.gen_data::<i32>()
        };
 
         let array = arrow::compute::sort(&array, None).unwrap();
@@ -68,32 +56,6 @@ impl StringBatchGenerator {
         stagger_batch(batch)
     }
 
-    /// Creates a StringArray or LargeStringArray with random strings according
-    /// to the parameters of the BatchGenerator
-    fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
-        // table of strings from which to draw
-        let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
-            .map(|_| Some(random_string(&mut self.rng, self.max_len)))
-            .collect();
-
-        // pick num_strings randomly from the distinct string table
-        let indicies: UInt32Array = (0..self.num_strings)
-            .map(|_| {
-                if self.rng.gen::<f64>() < self.null_pct {
-                    None
-                } else if self.num_distinct_strings > 1 {
-                    let range = 1..(self.num_distinct_strings as u32);
-                    Some(self.rng.gen_range(range))
-                } else {
-                    Some(0)
-                }
-            })
-            .collect();
-
-        let options = None;
-        arrow::compute::take(&distinct_strings, &indicies, options).unwrap()
-    }
-
     /// Return an set of `BatchGenerator`s that cover a range of interesting
     /// cases
     pub fn interesting_cases() -> Vec<Self> {
@@ -109,31 +71,15 @@ impl StringBatchGenerator {
             } else {
                 num_strings
             };
-            cases.push(StringBatchGenerator {
+            cases.push(StringBatchGenerator(StringArrayGenerator {
                 max_len,
                 num_strings,
                 num_distinct_strings,
                 null_pct,
                 rng: StdRng::from_seed(rng.gen()),
-            })
+            }))
             }
         }
         cases
     }
 }
-
-/// Return a string of random characters of length 1..=max_len
-fn random_string(rng: &mut StdRng, max_len: usize) -> String {
-    // pick characters at random (not just ascii)
-    match max_len {
-        0 => "".to_string(),
-        1 => String::from(rng.gen::<char>()),
-        _ => {
-            let len = rng.gen_range(1..=max_len);
-            rng.sample_iter::<char, _>(rand::distributions::Standard)
-                .take(len)
-                .map(char::from)
-                .collect::<String>()
-        }
-    }
-}
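For reviewers who want to try the relocated generators, here is a minimal usage sketch. It is not part of the diff: the `main` wrapper, the parameter values, and the fixed seed are illustrative assumptions, while the struct fields and method names come from the new `test_utils::array_gen` module above.

```rust
// Hypothetical usage sketch -- not part of the PR. Exercises the generators
// assuming the test-utils crate layout introduced by this diff.
use arrow::array::Array;
use rand::rngs::StdRng;
use rand::SeedableRng;
use test_utils::array_gen::{PrimitiveArrayGenerator, StringArrayGenerator};

fn main() {
    // 100 strings of up to 16 chars, ~10% nulls, drawn from 20 distinct values.
    let mut string_gen = StringArrayGenerator {
        max_len: 16,
        num_strings: 100,
        num_distinct_strings: 20,
        null_pct: 0.1,
        rng: StdRng::seed_from_u64(42), // fixed seed, chosen here for reproducibility
    };
    // `gen_data::<i32>()` builds a StringArray; `::<i64>()` would build a LargeStringArray.
    let strings = string_gen.gen_data::<i32>();
    assert_eq!(strings.len(), 100);

    // Same idea for primitives; the `impl_gen_data!` macro expands to one
    // `gen_data_*` method per supported native type (gen_data_i64, gen_data_f32, ...).
    let mut primitive_gen = PrimitiveArrayGenerator {
        num_primitives: 100,
        num_distinct_primitives: 20,
        null_pct: 0.1,
        rng: StdRng::seed_from_u64(42),
    };
    let ints = primitive_gen.gen_data_i64();
    assert_eq!(ints.len(), 100);
}
```

One design note: splitting `StringBatchGenerator` into a reusable `StringArrayGenerator` (array-level) plus a thin batch wrapper lets fuzz tests that only need arrays, such as the new primitive fuzzing, share the null/distinct-value sampling logic instead of duplicating it per type.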