new MongoDB driver using Motor
phenobarbital committed Jan 10, 2025
1 parent 11583c4 commit 9a0a288
Showing 23 changed files with 740 additions and 249 deletions.
2 changes: 1 addition & 1 deletion asyncdb/drivers/aioch.py
@@ -29,7 +29,7 @@ class aioch(SQLDriver):

     _provider: str = "clickhouse"
     _syntax: str = "sql"
-    _dsn: str = "{database}"
+    _dsn_template: str = "{database}"
     _test_query: str = "SELECT version()"

     def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None:
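Most files in this commit make the same change: the class attribute `_dsn` is renamed to `_dsn_template`, making explicit that the value is a format string to be filled in from connection parameters rather than a finished DSN. A minimal sketch of the presumed expansion, assuming plain `str.format` substitution (asyncdb's actual DSN construction may differ):

    # A sketch of how a _dsn_template is presumably expanded; not asyncdb's code.
    template = "{database}"                      # the aioch/duckdb template above
    params = {"database": "/tmp/analytics.db"}   # illustrative parameters
    dsn = template.format(**params)              # -> "/tmp/analytics.db"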
76 changes: 55 additions & 21 deletions asyncdb/drivers/bigquery.py
@@ -11,7 +11,7 @@
 import pandas as pd
 from google.cloud import storage
 from google.cloud import bigquery as bq
-from google.cloud.exceptions import Conflict
+from google.cloud.exceptions import Conflict, NotFound
 from google.cloud.bigquery import LoadJobConfig, SourceFormat
 from google.oauth2 import service_account
 from .sql import SQLDriver
@@ -25,6 +25,7 @@ class bigquery(SQLDriver, ModelBackend):
     _provider = "bigquery"
     _syntax = "sql"
     _test_query = "SELECT 1"
+    _dsn_template: str = ""

     def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None:
         self._credentials = params.get("credentials", None)
@@ -157,33 +158,67 @@ async def create_table(self, dataset_id, table_id, schema):
             self._logger.info(f"Created table {table.project}.{table.dataset_id}.{table.table_id}")
             return table
         except Conflict:
-            self._logger.warning(f"Table {table.project}.{table.dataset_id}.{table.table_id} already exists")
+            self._logger.warning(
+                f"Table {table.project}.{table.dataset_id}.{table.table_id} already exists"
+            )
             return table
         except Exception as e:
-            raise DriverError(f"BigQuery: Error creating table: {e}")
+            raise DriverError(f"BigQuery: Error creating table: {e}") from e

     async def truncate_table(self, table_id: str, dataset_id: str):
         """
-        Truncate a BigQuery table by overwriting with an empty table.
+        Truncate a BigQuery table by overwriting it with an empty table.
+        Parameters:
+        dataset_id (str): The ID of the dataset containing the table.
+        table_id (str): The ID of the table to truncate.
+        Raises:
+        DriverError: If there is an issue truncating the table.
         """
         if not self._connection:
             await self.connection()

-        # Construct a reference to the dataset
-        dataset_ref = bq.DatasetReference(self._connection.project, dataset_id)
-        table_ref = dataset_ref.table(table_id)
-        table = self._connection.get_table(table_ref)  # API request to fetch the table schema
+        try:
+            # Reference to the dataset and table
+            dataset_ref = self._connection.dataset(dataset_id)
+            table_ref = dataset_ref.table(table_id)
+
+            # Ensure the table exists
+            try:
+                table = self._connection.get_table(table_ref)
+            except NotFound:
+                raise DriverError(
+                    f"BigQuery: Table `{dataset_id}.{table_id}` does not exist."
+                )

-        # Create an empty table with the same schema
-        job_config = bq.QueryJobConfig(destination=table_ref)
-        job_config.write_disposition = bq.WriteDisposition.WRITE_TRUNCATE
+            # Configure the query job to overwrite the table
+            job_config = bq.QueryJobConfig(
+                destination=table_ref,
+                write_disposition=bq.WriteDisposition.WRITE_TRUNCATE,
+                allow_large_results=True
+            )

-        try:
-            job = self._connection.query(f"SELECT * FROM `{table_ref}` WHERE FALSE", job_config=job_config)
-            job.result()  # Wait for the job to finish
-            self._logger.info(f"Truncated table {dataset_id}.{table_id}")
+            # Execute a query that selects no rows, effectively truncating the table
+            query = f"SELECT * FROM `{self._project_id}.{dataset_id}.{table_id}` WHERE FALSE"
+
+            self._logger.debug(f"Truncating table with query: {query}")
+            job = self._connection.query(query, job_config=job_config)
+
+            # Wait for the job to complete
+            await asyncio.get_event_loop().run_in_executor(None, job.result)
+
+            self._logger.info(f"Successfully truncated table `{dataset_id}.{table_id}`.")
+            return True
+        except DriverError:
+            raise
         except Exception as e:
-            raise DriverError(f"BigQuery: Error truncating table: {e}")
+            self._logger.error(
+                f"BigQuery: Error truncating table `{dataset_id}.{table_id}`: {e}"
+            )
+            raise DriverError(
+                f"BigQuery: Error truncating table `{dataset_id}.{table_id}`: {e}"
+            ) from e

     async def query(self, sentence: str, **kwargs):
         if not self._connection:
@@ -244,8 +279,7 @@ async def fetch_all(self, query, *args):
         """
         Fetch all results from a BigQuery query
         """
-        results = await self.execute(query, *args)
-        return results
+        return await self.execute(query, *args)

     async def fetch_one(self, query, *args):
         """
@@ -273,7 +307,7 @@ async def write(
         table = f"{self._project_id}.{dataset_id}.{table_id}"
         try:
             if isinstance(data, pd.DataFrame):
-                if use_pandas is True:
+                if use_pandas:
                     job = await self._thread_func(self._connection.load_table_from_dataframe, data, table, **kwargs)
                 else:
                     object_cols = data.select_dtypes(include=["object"]).columns
@@ -293,7 +327,7 @@ async def write(
             dataset_ref = self._connection.dataset(dataset_id)
             table_ref = dataset_ref.table(table_id)
             table = bq.Table(table_ref)
-            if use_streams is True:
+            if use_streams:
                 errors = await self._thread_func(self._connection.insert_rows_json, table, data, **kwargs)
                 if errors:
                     raise RuntimeError(f"Errors occurred while inserting rows: {errors}")
@@ -314,7 +348,7 @@ async def write(
             # return Job object
             return job
         except Exception as e:
-            raise DriverError(f"BigQuery: Error writing to table: {e}")
+            raise DriverError(f"BigQuery: Error writing to table: {e}") from e

     async def load_table_from_uri(
         self,
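The reworked `truncate_table` keeps the table's schema by running a `SELECT ... WHERE FALSE` query with `WRITE_TRUNCATE` as the write disposition, and it now hands the blocking `job.result()` call to `run_in_executor` so the event loop stays responsive while the job runs. A hedged usage sketch; the constructor parameters below are illustrative, since credential handling depends on the deployment:

    import asyncio
    from asyncdb.drivers.bigquery import bigquery

    async def main():
        # Illustrative params; __init__ reads "credentials" from params.
        db = bigquery(params={"credentials": "service-account.json", "project_id": "my-project"})
        await db.connection()
        # Raises DriverError if analytics.events does not exist (NotFound is mapped to DriverError).
        ok = await db.truncate_table(table_id="events", dataset_id="analytics")
        print(ok)  # True on success

    asyncio.run(main())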
2 changes: 1 addition & 1 deletion asyncdb/drivers/clickhouse.py
@@ -29,7 +29,7 @@ class clickhouse(SQLDriver):

     _provider: str = "clickhouse"
     _syntax: str = "sql"
-    _dsn: str = ""
+    _dsn_template: str = ""
     _test_query: str = "SELECT now(), version()"

     def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None:
2 changes: 1 addition & 1 deletion asyncdb/drivers/duckdb.py
@@ -45,7 +45,7 @@ async def fetch_all(self) -> Iterable[Sequence]:
 class duckdb(SQLDriver, DBCursorBackend):
     _provider: str = "duckdb"
     _syntax: str = "sql"
-    _dsn: str = "{database}"
+    _dsn_template: str = "{database}"

     def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None:
         SQLDriver.__init__(self, dsn, loop, params, **kwargs)
3 changes: 2 additions & 1 deletion asyncdb/drivers/dummy.py
@@ -9,10 +9,11 @@
 class dummy(BaseDriver):
     _provider = "dummy"
     _syntax = "sql"
+    _dsn_template: str = "test:/{host}:{port}/{db}"
+

     def __init__(self, dsn: Union[str, None] = None, loop=None, params: dict = None, **kwargs):
         self._test_query = "SELECT 1"
-        self._dsn = "test:/{host}:{port}/{db}"
         if not params:
             params = {"host": "127.0.0.1", "port": "0", "db": 0}
         try:
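Moving the template from `__init__` (the removed `self._dsn = ...` line) to a class attribute also lets subclasses override it declaratively. A small sketch; the subclass below is hypothetical:

    from asyncdb.drivers.dummy import dummy

    # Hypothetical subclass: overriding the class-level template needs no
    # __init__ changes, which an instance-level self._dsn assignment would force.
    class dummy_ipv6(dummy):
        _dsn_template: str = "test:/[{host}]:{port}/{db}"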
3 changes: 2 additions & 1 deletion asyncdb/drivers/elastic.py
@@ -33,14 +33,15 @@ def get_dsn(self) -> str:
 class elastic(BaseDriver):
     _provider = "elasticsearch"
     _syntax = "json"
+    _dsn_template: str = "{protocol}://{host}:{port}/"
+

     def __init__(self, dsn: str = None, loop=None, params: Union[dict, ElasticConfig] = None, **kwargs):
         # self._dsn = "{protocol}://{user}:{password}@{host}:{port}/{database}"
         if isinstance(params, ElasticConfig):
             self._database = params.database
         else:
             self._database = params.pop("db", "default")
-        self._dsn = "{protocol}://{host}:{port}/"
         super(elastic, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs)

     def create_dsn(self, params: Union[dict, dataclass]):
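The elastic driver accepts `params` as either a plain dict or an `ElasticConfig` dataclass, so `create_dsn` has to cope with both shapes. A sketch of one way to normalize them before formatting the template; this is a hypothetical helper based only on the signature shown above, not asyncdb's implementation:

    from dataclasses import asdict, is_dataclass

    # Hypothetical normalization; the real create_dsn may add defaults or validation.
    def build_dsn(template: str, params) -> str:
        values = asdict(params) if is_dataclass(params) else dict(params)
        return template.format(**values)

    build_dsn("{protocol}://{host}:{port}/", {"protocol": "https", "host": "es01", "port": 9200})
    # -> "https://es01:9200/"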
2 changes: 1 addition & 1 deletion asyncdb/drivers/influx.py
@@ -39,12 +39,12 @@ def retry(self, conf: tuple[str, str, str], data: str, exception: InfluxDBError)
 class influx(InitDriver, ConnectionDSNBackend):
     _provider = "influxdb"
     _syntax = "sql"
+    _dsn_template: str = "{protocol}://{host}:{port}"

     def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params: dict = None, **kwargs) -> None:
         self._test_query = "SELECT 1"
         self._query_raw = "SELECT {fields} FROM {table} {where_cond}"
         self._version: str = None
-        self._dsn = "{protocol}://{host}:{port}"
         self._client = InfluxDBClientAsync
         self._enable_gzip = kwargs.get("enable_gzip", True)
         self._retries = Retry(connect=5, read=2, redirect=5)
