Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
edvardm authored Nov 8, 2024
2 parents 98cbf46 + 4072bfb commit 0967cb1
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 49 deletions.
41 changes: 27 additions & 14 deletions pypika/dialects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import warnings
from copy import copy
from typing import Any, Optional, Union, Tuple as TypedTuple

Expand Down Expand Up @@ -87,7 +88,7 @@ class MySQLQueryBuilder(QueryBuilder):
QUERY_CLS = MySQLQuery

def __init__(self, **kwargs: Any) -> None:
super().__init__(dialect=Dialects.MYSQL, wrap_set_operation_queries=False, **kwargs)
super().__init__(dialect=Dialects.MYSQL, **kwargs)
self._duplicate_updates = []
self._ignore_duplicates = False
self._modifiers = []
Expand Down Expand Up @@ -347,6 +348,19 @@ def __str__(self) -> str:
return self.get_sql()


class FetchNextAndOffsetRowsQueryBuilder(QueryBuilder):
def _limit_sql(self) -> str:
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit)

def _offset_sql(self) -> str:
return " OFFSET {offset} ROWS".format(offset=self._offset or 0)

@builder
def fetch_next(self, limit: int):
warnings.warn("`fetch_next` is deprecated - please use the `limit` method", DeprecationWarning)
self._limit = limit


class OracleQuery(Query):
"""
Defines a query class for use with Oracle.
Expand All @@ -357,7 +371,7 @@ def _builder(cls, **kwargs: Any) -> "OracleQueryBuilder":
return OracleQueryBuilder(**kwargs)


class OracleQueryBuilder(QueryBuilder):
class OracleQueryBuilder(FetchNextAndOffsetRowsQueryBuilder):
QUOTE_CHAR = None
QUERY_CLS = OracleQuery

Expand All @@ -370,6 +384,16 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str:
kwargs['groupby_alias'] = False
return super().get_sql(*args, **kwargs)

def _apply_pagination(self, querystring: str) -> str:
# Note: Overridden as Oracle specifies offset before the fetch next limit
if self._offset:
querystring += self._offset_sql()

if self._limit is not None:
querystring += self._limit_sql()

return querystring


class PostgreSQLQuery(Query):
"""
Expand Down Expand Up @@ -670,7 +694,7 @@ def _builder(cls, **kwargs: Any) -> "MSSQLQueryBuilder":
return MSSQLQueryBuilder(**kwargs)


class MSSQLQueryBuilder(QueryBuilder):
class MSSQLQueryBuilder(FetchNextAndOffsetRowsQueryBuilder):
QUERY_CLS = MSSQLQuery

def __init__(self, **kwargs: Any) -> None:
Expand All @@ -695,17 +719,6 @@ def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = F
self._top_percent: bool = percent
self._top_with_ties: bool = with_ties

@builder
def fetch_next(self, limit: int) -> "MSSQLQueryBuilder":
# Overridden to provide a more domain-specific API for T-SQL users
self._limit = limit

def _offset_sql(self) -> str:
return " OFFSET {offset} ROWS".format(offset=self._offset or 0)

def _limit_sql(self) -> str:
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit)

def _apply_pagination(self, querystring: str) -> str:
# Note: Overridden as MSSQL specifies offset before the fetch next limit
if self._limit is not None or self._offset:
Expand Down
152 changes: 130 additions & 22 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
import uuid
from datetime import date
from enum import Enum
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
)

from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order
from pypika.utils import (
Expand Down Expand Up @@ -288,57 +302,111 @@ def get_sql(self, **kwargs: Any) -> str:
raise NotImplementedError()


def idx_placeholder_gen(idx: int) -> str:
return str(idx + 1)


def named_placeholder_gen(idx: int) -> str:
return f'param{idx + 1}'


class Parameter(Term):
is_aggregate = None

def __init__(self, placeholder: Union[str, int]) -> None:
super().__init__()
self.placeholder = placeholder
self._placeholder = placeholder

@property
def placeholder(self):
return self._placeholder

def get_sql(self, **kwargs: Any) -> str:
return str(self.placeholder)

def update_parameters(self, param_key: Any, param_value: Any, **kwargs):
pass

class QmarkParameter(Parameter):
"""Question mark style, e.g. ...WHERE name=?"""
def get_param_key(self, placeholder: Any, **kwargs):
return placeholder

def __init__(self) -> None:
pass

def get_sql(self, **kwargs: Any) -> str:
return "?"
class ListParameter(Parameter):
def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None:
super().__init__(placeholder=placeholder)
self._parameters = list()

@property
def placeholder(self) -> str:
if callable(self._placeholder):
return self._placeholder(len(self._parameters))

class NumericParameter(Parameter):
"""Numeric, positional style, e.g. ...WHERE name=:1"""
return str(self._placeholder)

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)
def get_parameters(self, **kwargs):
return self._parameters

def update_parameters(self, value: Any, **kwargs):
self._parameters.append(value)

class NamedParameter(Parameter):
"""Named style, e.g. ...WHERE name=:name"""

class DictParameter(Parameter):
def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None:
super().__init__(placeholder=placeholder)
self._parameters = dict()

@property
def placeholder(self) -> str:
if callable(self._placeholder):
return self._placeholder(len(self._parameters))

return str(self._placeholder)

def get_parameters(self, **kwargs):
return self._parameters

def get_param_key(self, placeholder: Any, **kwargs):
return placeholder[1:]

def update_parameters(self, param_key: Any, value: Any, **kwargs):
self._parameters[param_key] = value


class QmarkParameter(ListParameter):
def get_sql(self, **kwargs):
return '?'


class NumericParameter(ListParameter):
"""Numeric, positional style, e.g. ...WHERE name=:1"""

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)


class FormatParameter(Parameter):
class FormatParameter(ListParameter):
"""ANSI C printf format codes, e.g. ...WHERE name=%s"""

def __init__(self) -> None:
pass

def get_sql(self, **kwargs: Any) -> str:
return "%s"


class PyformatParameter(Parameter):
class NamedParameter(DictParameter):
"""Named style, e.g. ...WHERE name=:name"""

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)


class PyformatParameter(DictParameter):
"""Python extended format codes, e.g. ...WHERE name=%(name)s"""

def get_sql(self, **kwargs: Any) -> str:
return "%({placeholder})s".format(placeholder=self.placeholder)

def get_param_key(self, placeholder: Any, **kwargs):
return placeholder[2:-2]


class Negative(Term):
def __init__(self, term: Term) -> None:
Expand Down Expand Up @@ -385,9 +453,44 @@ def get_formatted_value(cls, value: Any, **kwargs):
return "null"
return str(value)

def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str:
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
param_sql = parameter.get_sql(**kwargs)
param_key = parameter.get_param_key(placeholder=param_sql)

return param_sql, param_key

def get_sql(
self,
quote_char: Optional[str] = None,
secondary_quote_char: str = "'",
parameter: Parameter = None,
**kwargs: Any,
) -> str:
if parameter is None:
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)

# Don't stringify numbers when using a parameter
if isinstance(self.value, (int, float)):
value_sql = self.value
else:
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
param_sql, param_key = self._get_param_data(parameter, **kwargs)
parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs)

return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)


class ParameterValueWrapper(ValueWrapper):
def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None:
super().__init__(value, alias)
self._parameter = parameter

def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
param_sql = self._parameter.get_sql(**kwargs)
param_key = self._parameter.get_param_key(placeholder=param_sql)

return param_sql, param_key


class JSON(Term):
Expand Down Expand Up @@ -548,6 +651,11 @@ def __init__(
) -> None:
super().__init__(alias=alias)
self.name = name
if isinstance(table, str):
# avoid circular import at load time
from pypika.queries import Table

table = Table(table)
self.table = table

def nodes_(self) -> Iterator[NodeT]:
Expand Down
9 changes: 2 additions & 7 deletions pypika/tests/dialects/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,13 @@ def test_limit(self):

self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY', str(q))

def test_fetch_next(self):
q = MSSQLQuery.from_("abc").select("def").orderby("def").fetch_next(10)

self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY', str(q))

def test_offset(self):
q = MSSQLQuery.from_("abc").select("def").orderby("def").offset(10)

self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 10 ROWS', str(q))

def test_fetch_next_with_offset(self):
q = MSSQLQuery.from_("abc").select("def").orderby("def").fetch_next(10).offset(10)
def test_limit_with_offset(self):
q = MSSQLQuery.from_("abc").select("def").orderby("def").limit(10).offset(10)

self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 10 ROWS FETCH NEXT 10 ROWS ONLY', str(q))

Expand Down
30 changes: 30 additions & 0 deletions pypika/tests/dialects/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,33 @@ def test_groupby_alias_False_does_not_group_by_alias_when_subqueries_are_present
q = OracleQuery.from_(subquery).select(col, Count('*')).groupby(col)

self.assertEqual('SELECT sq0.abc a,COUNT(\'*\') FROM (SELECT abc FROM table1) sq0 GROUP BY sq0.abc', str(q))

def test_limit_query(self):
t = Table('table1')
limit = 5
q = OracleQuery.from_(t).select(t.test).limit(limit)

self.assertEqual(f'SELECT test FROM table1 FETCH NEXT {limit} ROWS ONLY', str(q))

def test_offset_query(self):
t = Table('table1')
offset = 5
q = OracleQuery.from_(t).select(t.test).offset(offset)

self.assertEqual(f'SELECT test FROM table1 OFFSET {offset} ROWS', str(q))

def test_limit_offset_query(self):
t = Table('table1')
limit = 5
offset = 5
q = OracleQuery.from_(t).select(t.test).limit(limit).offset(offset)

self.assertEqual(f'SELECT test FROM table1 OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY', str(q))

def test_fetch_next_method_deprecated(self):
with self.assertWarns(DeprecationWarning):
t = Table('table1')
limit = 5
q = OracleQuery.from_(t).select(t.test).fetch_next(limit)

self.assertEqual(f'SELECT test FROM table1 FETCH NEXT {limit} ROWS ONLY', str(q))
Loading

0 comments on commit 0967cb1

Please sign in to comment.