Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Account for SurrealDB Python API updates (handle both SurrealDB and AsyncSurrealDB classes) in read_database #20799

Merged
merged 2 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions py-polars/polars/io/database/_cursor_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING, Any

from polars.dependencies import import_optional
from polars.io.database._utils import _run_async

if TYPE_CHECKING:
Expand Down Expand Up @@ -71,17 +72,19 @@ def fetch_record_batches(


class SurrealDBCursorProxy:
"""Cursor proxy for SurrealDB connections (requires `surrealdb`)."""
"""Cursor proxy for both SurrealDB and AsyncSurrealDB connections."""

_cached_result: list[dict[str, Any]] | None = None

def __init__(self, client: Any) -> None:
self.client = client
surrealdb = import_optional("surrealdb")
self.is_async = isinstance(client, surrealdb.AsyncSurrealDB)
self.execute_options: dict[str, Any] = {}
self.client = client
self.query: str = None # type: ignore[assignment]

@staticmethod
async def _unpack_result(
async def _unpack_result_async(
result: Coroutine[Any, Any, list[dict[str, Any]]],
) -> Coroutine[Any, Any, list[dict[str, Any]]]:
"""Unpack the async query result."""
Expand All @@ -90,6 +93,16 @@ async def _unpack_result(
raise RuntimeError(response["result"])
return response["result"]

@staticmethod
def _unpack_result(
result: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Unpack the query result."""
response = result[0]
if response["status"] != "OK":
raise RuntimeError(response["result"])
return response["result"]

def close(self) -> None:
"""Close the cursor."""
# no-op; never close a user's Surreal session
Expand All @@ -103,23 +116,32 @@ def execute(self, query: str, **execute_options: Any) -> Self:

def fetchall(self) -> list[dict[str, Any]]:
"""Fetch all results (as a list of dictionaries)."""
return _run_async(
self._unpack_result(
return (
_run_async(
self._unpack_result_async(
result=self.client.query(
query=self.query,
variables=(self.execute_options or None),
),
)
)
if self.is_async
else self._unpack_result(
result=self.client.query(
sql=self.query,
vars=(self.execute_options or None),
query=self.query,
variables=(self.execute_options or None),
),
)
)

def fetchmany(self, size: int) -> list[dict[str, Any]]:
"""Fetch results in batches (simulated)."""
# first 'fetchmany' call acquires/caches the result
# first 'fetchmany' call acquires/caches the result object
if self._cached_result is None:
self._cached_result = self.fetchall()

# return batches of the cached result; remove from the cache as
# we go, so as not to hold on to additional copies when done
# return batches from the result, actively removing from the cache
# as we go, so as not to hold on to additional copies when done
result = self._cached_result[:size]
del self._cached_result[:size]
return result
15 changes: 8 additions & 7 deletions py-polars/polars/io/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,19 @@ def read_database(
... connection=async_engine,
... ) # doctest: +SKIP

Load data from an asynchronous SurrealDB client connection object; note that
both the WS (`Surreal`) and HTTP (`SurrealHTTP`) clients are supported:
Load data from an `AsyncSurrealDB` client connection object; note that both the "ws"
and "http" protocols are supported, as is the synchronous `SurrealDB` client. The
async loop can be run with standard `asyncio` or with `uvloop`:

>>> import asyncio
>>> import asyncio # (or uvloop)
>>> async def surreal_query_to_frame(query: str, url: str):
... async with Surreal(url) as client:
... async with AsyncSurrealDB(url) as client:
... await client.use(namespace="test", database="test")
... return pl.read_database(query=query, connection=client)
>>> df = asyncio.run(
... surreal_query_to_frame(
... query="SELECT * FROM test_data",
... url="ws://localhost:8000/rpc",
... query="SELECT * FROM test",
... url="http://localhost:8000",
... )
... ) # doctest: +SKIP

Expand All @@ -236,7 +237,7 @@ def read_database(
connection = ODBCCursorProxy(connection)
elif "://" in connection:
# otherwise looks like a mistaken call to read_database_uri
msg = "use of string URI is invalid here; call `read_database_uri` instead"
msg = "string URI is invalid here; call `read_database_uri` instead"
raise ValueError(msg)
else:
msg = "unable to identify string connection as valid ODBC (no driver)"
Expand Down
32 changes: 32 additions & 0 deletions py-polars/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import time
import tracemalloc
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, cast

import numpy as np
Expand All @@ -17,6 +18,7 @@

if TYPE_CHECKING:
from collections.abc import Generator
from types import ModuleType
from typing import Any

FixtureRequest = Any
Expand Down Expand Up @@ -260,3 +262,33 @@ def test_global_and_local(
yield
else:
yield


@contextmanager
def mock_module_import(
    name: str, module: ModuleType, *, replace_if_exists: bool = False
) -> Generator[None, None, None]:
    """
    Temporarily register a mock module in `sys.modules` for the context's duration.

    Parameters
    ----------
    name
        The name of the module to mock.
    module
        A ModuleType instance representing the mocked module.
    replace_if_exists
        Whether to replace the module if it already exists in `sys.modules` (defaults to
        False, meaning that if the module is already imported, it will not be replaced).
    """
    original = sys.modules.get(name, None)
    if original is not None and not replace_if_exists:
        # a real module is already loaded and we were told not to clobber it
        yield
        return
    sys.modules[name] = module
    try:
        yield
    finally:
        # restore the pre-existing module, or remove our mock entirely
        if original is None:
            del sys.modules[name]
        else:
            sys.modules[name] = original
39 changes: 24 additions & 15 deletions py-polars/tests/unit/io/database/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
from math import ceil
from types import ModuleType
from typing import TYPE_CHECKING, Any, overload

import pytest
Expand All @@ -11,6 +12,7 @@
import polars as pl
from polars._utils.various import parse_version
from polars.testing import assert_frame_equal
from tests.unit.conftest import mock_module_import

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -64,11 +66,17 @@ async def use(self, namespace: str, database: str) -> None:
pass

async def query(
self, sql: str, vars: dict[str, Any] | None = None
self, query: str, variables: dict[str, Any] | None = None
) -> list[dict[str, Any]]:
return [{"result": self._mock_data, "status": "OK", "time": "32.083µs"}]


class MockedSurrealModule(ModuleType):
    """Mock SurrealDB module; enables internal `isinstance` check for AsyncSurrealDB."""

    # Exposed under the real SDK's class name so that code doing
    # `isinstance(client, surrealdb.AsyncSurrealDB)` matches the mock connection.
    AsyncSurrealDB = MockSurrealConnection


@pytest.mark.skipif(
parse_version(sqlalchemy.__version__) < (2, 0),
reason="SQLAlchemy 2.0+ required for async tests",
Expand Down Expand Up @@ -159,18 +167,19 @@ async def _surreal_query_as_frame(

@pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 4])
def test_surrealdb_fetchall(batch_size: int | None) -> None:
df_expected = pl.DataFrame(SURREAL_MOCK_DATA)
res = asyncio.run(
_surreal_query_as_frame(
url="ws://localhost:8000/rpc",
query="SELECT * FROM item",
batch_size=batch_size,
with mock_module_import("surrealdb", MockedSurrealModule("surrealdb")):
df_expected = pl.DataFrame(SURREAL_MOCK_DATA)
res = asyncio.run(
_surreal_query_as_frame(
url="ws://localhost:8000/rpc",
query="SELECT * FROM item",
batch_size=batch_size,
)
)
)
if batch_size:
frames = list(res) # type: ignore[call-overload]
n_mock_rows = len(SURREAL_MOCK_DATA)
assert len(frames) == ceil(n_mock_rows / batch_size)
assert_frame_equal(df_expected[:batch_size], frames[0])
else:
assert_frame_equal(df_expected, res) # type: ignore[arg-type]
if batch_size:
frames = list(res) # type: ignore[call-overload]
n_mock_rows = len(SURREAL_MOCK_DATA)
assert len(frames) == ceil(n_mock_rows / batch_size)
assert_frame_equal(df_expected[:batch_size], frames[0])
else:
assert_frame_equal(df_expected, res) # type: ignore[arg-type]
Loading