diff --git a/pypika/dialects.py b/pypika/dialects.py index 8894e562..dab72e07 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -1,4 +1,5 @@ import itertools +import warnings from copy import copy from typing import Any, Optional, Union, Tuple as TypedTuple @@ -87,7 +88,7 @@ class MySQLQueryBuilder(QueryBuilder): QUERY_CLS = MySQLQuery def __init__(self, **kwargs: Any) -> None: - super().__init__(dialect=Dialects.MYSQL, wrap_set_operation_queries=False, **kwargs) + super().__init__(dialect=Dialects.MYSQL, **kwargs) self._duplicate_updates = [] self._ignore_duplicates = False self._modifiers = [] @@ -347,6 +348,19 @@ def __str__(self) -> str: return self.get_sql() +class FetchNextAndOffsetRowsQueryBuilder(QueryBuilder): + def _limit_sql(self) -> str: + return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit) + + def _offset_sql(self) -> str: + return " OFFSET {offset} ROWS".format(offset=self._offset or 0) + + @builder + def fetch_next(self, limit: int): + warnings.warn("`fetch_next` is deprecated - please use the `limit` method", DeprecationWarning) + self._limit = limit + + class OracleQuery(Query): """ Defines a query class for use with Oracle. @@ -357,7 +371,7 @@ def _builder(cls, **kwargs: Any) -> "OracleQueryBuilder": return OracleQueryBuilder(**kwargs) -class OracleQueryBuilder(QueryBuilder): +class OracleQueryBuilder(FetchNextAndOffsetRowsQueryBuilder): QUOTE_CHAR = None QUERY_CLS = OracleQuery @@ -370,6 +384,16 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str: kwargs['groupby_alias'] = False return super().get_sql(*args, **kwargs) + def _apply_pagination(self, querystring: str) -> str: + # Note: Overridden as Oracle specifies offset before the fetch next limit + if self._offset: + querystring += self._offset_sql() + + if self._limit is not None: + querystring += self._limit_sql() + + return querystring + class PostgreSQLQuery(Query): """ @@ -670,7 +694,7 @@ def _builder(cls, **kwargs: Any) -> "MSSQLQueryBuilder": return MSSQLQueryBuilder(**kwargs) -class MSSQLQueryBuilder(QueryBuilder): +class MSSQLQueryBuilder(FetchNextAndOffsetRowsQueryBuilder): QUERY_CLS = MSSQLQuery def __init__(self, **kwargs: Any) -> None: @@ -695,17 +719,6 @@ def top(self, value: Union[str, int], percent: bool = False, with_ties: bool = F self._top_percent: bool = percent self._top_with_ties: bool = with_ties - @builder - def fetch_next(self, limit: int) -> "MSSQLQueryBuilder": - # Overridden to provide a more domain-specific API for T-SQL users - self._limit = limit - - def _offset_sql(self) -> str: - return " OFFSET {offset} ROWS".format(offset=self._offset or 0) - - def _limit_sql(self) -> str: - return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit) - def _apply_pagination(self, querystring: str) -> str: # Note: Overridden as MSSQL specifies offset before the fetch next limit if self._limit is not None or self._offset: diff --git a/pypika/terms.py b/pypika/terms.py index b25af5c5..a277e1a5 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -3,7 +3,21 @@ import uuid from datetime import date from enum import Enum -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -288,57 +302,111 @@ def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() +def idx_placeholder_gen(idx: int) -> str: + return str(idx + 1) + + +def named_placeholder_gen(idx: int) -> str: + return f'param{idx + 1}' + + class Parameter(Term): is_aggregate = None def __init__(self, placeholder: Union[str, int]) -> None: super().__init__() - self.placeholder = placeholder + self._placeholder = placeholder + + @property + def placeholder(self): + return self._placeholder def get_sql(self, **kwargs: Any) -> str: return str(self.placeholder) + def update_parameters(self, param_key: Any, param_value: Any, **kwargs): + pass -class QmarkParameter(Parameter): - """Question mark style, e.g. ...WHERE name=?""" + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder - def __init__(self) -> None: - pass - def get_sql(self, **kwargs: Any) -> str: - return "?" +class ListParameter(Parameter): + def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None: + super().__init__(placeholder=placeholder) + self._parameters = list() + @property + def placeholder(self) -> str: + if callable(self._placeholder): + return self._placeholder(len(self._parameters)) -class NumericParameter(Parameter): - """Numeric, positional style, e.g. ...WHERE name=:1""" + return str(self._placeholder) - def get_sql(self, **kwargs: Any) -> str: - return ":{placeholder}".format(placeholder=self.placeholder) + def get_parameters(self, **kwargs): + return self._parameters + def update_parameters(self, value: Any, **kwargs): + self._parameters.append(value) -class NamedParameter(Parameter): - """Named style, e.g. ...WHERE name=:name""" + +class DictParameter(Parameter): + def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None: + super().__init__(placeholder=placeholder) + self._parameters = dict() + + @property + def placeholder(self) -> str: + if callable(self._placeholder): + return self._placeholder(len(self._parameters)) + + return str(self._placeholder) + + def get_parameters(self, **kwargs): + return self._parameters + + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder[1:] + + def update_parameters(self, param_key: Any, value: Any, **kwargs): + self._parameters[param_key] = value + + +class QmarkParameter(ListParameter): + def get_sql(self, **kwargs): + return '?' + + +class NumericParameter(ListParameter): + """Numeric, positional style, e.g. ...WHERE name=:1""" def get_sql(self, **kwargs: Any) -> str: return ":{placeholder}".format(placeholder=self.placeholder) -class FormatParameter(Parameter): +class FormatParameter(ListParameter): """ANSI C printf format codes, e.g. ...WHERE name=%s""" - def __init__(self) -> None: - pass - def get_sql(self, **kwargs: Any) -> str: return "%s" -class PyformatParameter(Parameter): +class NamedParameter(DictParameter): + """Named style, e.g. ...WHERE name=:name""" + + def get_sql(self, **kwargs: Any) -> str: + return ":{placeholder}".format(placeholder=self.placeholder) + + +class PyformatParameter(DictParameter): """Python extended format codes, e.g. ...WHERE name=%(name)s""" def get_sql(self, **kwargs: Any) -> str: return "%({placeholder})s".format(placeholder=self.placeholder) + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder[2:-2] + class Negative(Term): def __init__(self, term: Term) -> None: @@ -385,9 +453,44 @@ def get_formatted_value(cls, value: Any, **kwargs): return "null" return str(value) - def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str: - sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) - return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: + param_sql = parameter.get_sql(**kwargs) + param_key = parameter.get_param_key(placeholder=param_sql) + + return param_sql, param_key + + def get_sql( + self, + quote_char: Optional[str] = None, + secondary_quote_char: str = "'", + parameter: Parameter = None, + **kwargs: Any, + ) -> str: + if parameter is None: + sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) + return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + + # Don't stringify numbers when using a parameter + if isinstance(self.value, (int, float)): + value_sql = self.value + else: + value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) + param_sql, param_key = self._get_param_data(parameter, **kwargs) + parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) + + return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) + + +class ParameterValueWrapper(ValueWrapper): + def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None: + super().__init__(value, alias) + self._parameter = parameter + + def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: + param_sql = self._parameter.get_sql(**kwargs) + param_key = self._parameter.get_param_key(placeholder=param_sql) + + return param_sql, param_key class JSON(Term): @@ -548,6 +651,11 @@ def __init__( ) -> None: super().__init__(alias=alias) self.name = name + if isinstance(table, str): + # avoid circular import at load time + from pypika.queries import Table + + table = Table(table) self.table = table def nodes_(self) -> Iterator[NodeT]: diff --git a/pypika/tests/dialects/test_mssql.py b/pypika/tests/dialects/test_mssql.py index f8940b08..7cf53b22 100644 --- a/pypika/tests/dialects/test_mssql.py +++ b/pypika/tests/dialects/test_mssql.py @@ -53,18 +53,13 @@ def test_limit(self): self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY', str(q)) - def test_fetch_next(self): - q = MSSQLQuery.from_("abc").select("def").orderby("def").fetch_next(10) - - self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 0 ROWS FETCH NEXT 10 ROWS ONLY', str(q)) - def test_offset(self): q = MSSQLQuery.from_("abc").select("def").orderby("def").offset(10) self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 10 ROWS', str(q)) - def test_fetch_next_with_offset(self): - q = MSSQLQuery.from_("abc").select("def").orderby("def").fetch_next(10).offset(10) + def test_limit_with_offset(self): + q = MSSQLQuery.from_("abc").select("def").orderby("def").limit(10).offset(10) self.assertEqual('SELECT "def" FROM "abc" ORDER BY "def" OFFSET 10 ROWS FETCH NEXT 10 ROWS ONLY', str(q)) diff --git a/pypika/tests/dialects/test_oracle.py b/pypika/tests/dialects/test_oracle.py index c3a757a5..21b54f16 100644 --- a/pypika/tests/dialects/test_oracle.py +++ b/pypika/tests/dialects/test_oracle.py @@ -19,3 +19,33 @@ def test_groupby_alias_False_does_not_group_by_alias_when_subqueries_are_present q = OracleQuery.from_(subquery).select(col, Count('*')).groupby(col) self.assertEqual('SELECT sq0.abc a,COUNT(\'*\') FROM (SELECT abc FROM table1) sq0 GROUP BY sq0.abc', str(q)) + + def test_limit_query(self): + t = Table('table1') + limit = 5 + q = OracleQuery.from_(t).select(t.test).limit(limit) + + self.assertEqual(f'SELECT test FROM table1 FETCH NEXT {limit} ROWS ONLY', str(q)) + + def test_offset_query(self): + t = Table('table1') + offset = 5 + q = OracleQuery.from_(t).select(t.test).offset(offset) + + self.assertEqual(f'SELECT test FROM table1 OFFSET {offset} ROWS', str(q)) + + def test_limit_offset_query(self): + t = Table('table1') + limit = 5 + offset = 5 + q = OracleQuery.from_(t).select(t.test).limit(limit).offset(offset) + + self.assertEqual(f'SELECT test FROM table1 OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY', str(q)) + + def test_fetch_next_method_deprecated(self): + with self.assertWarns(DeprecationWarning): + t = Table('table1') + limit = 5 + q = OracleQuery.from_(t).select(t.test).fetch_next(limit) + + self.assertEqual(f'SELECT test FROM table1 FETCH NEXT {limit} ROWS ONLY', str(q)) diff --git a/pypika/tests/test_joins.py b/pypika/tests/test_joins.py index 6e54f883..03142d58 100644 --- a/pypika/tests/test_joins.py +++ b/pypika/tests/test_joins.py @@ -845,12 +845,12 @@ def test_require_equal_number_of_fields(self): with self.assertRaises(SetOperationException): str(query1 + query2) - def test_mysql_query_does_not_wrap_unioned_queries_with_params(self): + def test_mysql_query_wraps_unioned_queries(self): query1 = MySQLQuery.from_(self.table1).select(self.table1.foo) query2 = Query.from_(self.table2).select(self.table2.bar) self.assertEqual( - "SELECT `foo` FROM `abc` UNION SELECT `bar` FROM `efg`", + "(SELECT `foo` FROM `abc`) UNION (SELECT `bar` FROM `efg`)", str(query1 + query2), ) @@ -968,12 +968,12 @@ def test_require_equal_number_of_fields_intersect(self): with self.assertRaises(SetOperationException): str(query1.intersect(query2)) - def test_mysql_query_does_not_wrap_intersected_queries_with_params(self): + def test_mysql_query_wraps_intersected_queries(self): query1 = MySQLQuery.from_(self.table1).select(self.table1.foo) query2 = Query.from_(self.table2).select(self.table2.bar) self.assertEqual( - "SELECT `foo` FROM `abc` INTERSECT SELECT `bar` FROM `efg`", + "(SELECT `foo` FROM `abc`) INTERSECT (SELECT `bar` FROM `efg`)", str(query1.intersect(query2)), ) @@ -1064,12 +1064,12 @@ def test_require_equal_number_of_fields(self): with self.assertRaises(SetOperationException): str(query1.minus(query2)) - def test_mysql_query_does_not_wrap_minus_queries_with_params(self): + def test_mysql_query_wraps_minus_queries(self): query1 = MySQLQuery.from_(self.table1).select(self.table1.foo) query2 = Query.from_(self.table2).select(self.table2.bar) self.assertEqual( - "SELECT `foo` FROM `abc` MINUS SELECT `bar` FROM `efg`", + "(SELECT `foo` FROM `abc`) MINUS (SELECT `bar` FROM `efg`)", str(query1 - query2), ) diff --git a/pypika/tests/test_parameter.py b/pypika/tests/test_parameter.py index e19666a0..c11e9afc 100644 --- a/pypika/tests/test_parameter.py +++ b/pypika/tests/test_parameter.py @@ -1,4 +1,5 @@ import unittest +from datetime import date from pypika import ( FormatParameter, @@ -10,6 +11,7 @@ Query, Tables, ) +from pypika.terms import ListParameter, ParameterValueWrapper class ParametrizedTests(unittest.TestCase): @@ -92,3 +94,113 @@ def test_format_parameter(self): def test_pyformat_parameter(self): self.assertEqual('%(buz)s', PyformatParameter('buz').get_sql()) + + +class ParametrizedTestsWithValues(unittest.TestCase): + table_abc, table_efg = Tables("abc", "efg") + + def test_param_insert(self): + q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo') + + parameter = QmarkParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql) + self.assertEqual([1, 2.2, 'foo'], parameter.get_parameters()) + + def test_param_select_join(self): + q = ( + Query.from_(self.table_abc) + .select("*") + .where(self.table_abc.category == 'foobar') + .join(self.table_efg) + .on(self.table_abc.id == self.table_efg.abc_id) + .where(self.table_efg.date >= date(2024, 2, 22)) + .limit(10) + ) + + parameter = FormatParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10', + sql, + ) + self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters()) + + def test_param_select_subquery(self): + q = ( + Query.from_(self.table_abc) + .select("*") + .where(self.table_abc.category == 'foobar') + .where( + self.table_abc.id.isin( + Query.from_(self.table_efg) + .select(self.table_efg.abc_id) + .where(self.table_efg.date >= date(2024, 2, 22)) + ) + ) + .limit(10) + ) + + parameter = ListParameter(placeholder=lambda idx: f'&{idx+1}') + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT * FROM "abc" WHERE "category"=&1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=&2) LIMIT 10', + sql, + ) + self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters()) + + def test_join(self): + subquery = ( + Query.from_(self.table_efg) + .select(self.table_efg.fiz, self.table_efg.buz) + .where(self.table_efg.buz == 'buz') + ) + + q = ( + Query.from_(self.table_abc) + .join(subquery) + .on(self.table_abc.bar == subquery.buz) + .select(self.table_abc.foo, subquery.fiz) + .where(self.table_abc.bar == 'bar') + ) + + parameter = NamedParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:param1)' + ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:param2', + sql, + ) + self.assertEqual({'param1': 'buz', 'param2': 'bar'}, parameter.get_parameters()) + + def test_join_with_parameter_value_wrapper(self): + subquery = ( + Query.from_(self.table_efg) + .select(self.table_efg.fiz, self.table_efg.buz) + .where(self.table_efg.buz == ParameterValueWrapper(Parameter(':buz'), 'buz')) + ) + + q = ( + Query.from_(self.table_abc) + .join(subquery) + .on(self.table_abc.bar == subquery.buz) + .select(self.table_abc.foo, subquery.fiz) + .where(self.table_abc.bar == ParameterValueWrapper(NamedParameter('bar'), 'bar')) + ) + + parameter = NamedParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:buz)' + ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:bar', + sql, + ) + self.assertEqual({':buz': 'buz', 'bar': 'bar'}, parameter.get_parameters()) + + def test_pyformat_parameter(self): + q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo') + + parameter = PyformatParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql) + self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters()) diff --git a/pypika/tests/test_terms.py b/pypika/tests/test_terms.py index 562e0d28..4c7590df 100644 --- a/pypika/tests/test_terms.py +++ b/pypika/tests/test_terms.py @@ -15,6 +15,13 @@ def test_when_alias_specified(self): self.assertEqual('bar', str(c1.alias)) +class FieldInitTests(TestCase): + def test_init_with_str_table(self): + test_table_name = "test_table" + field = Field(name="name", table=test_table_name) + self.assertEqual(field.table, Table(name=test_table_name)) + + class FieldHashingTests(TestCase): def test_tabled_eq_fields_equally_hashed(self): client_name1 = Field(name="name", table=Table("clients"))