Remove changes from 3.1 that look wrong
Some look like the result of an incorrect merge, and some are leftover debugging code.
adamsdarlingtower committed Apr 10, 2024
1 parent 6d85ec3 commit a98fb9a
Showing 19 changed files with 97 additions and 132 deletions.
2 changes: 1 addition & 1 deletion superset/commands/annotation_layer/create.py
@@ -36,7 +36,7 @@ class CreateAnnotationLayerCommand(BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()

def run(self) -> None:
def run(self) -> Model:
self.validate()
try:
return AnnotationLayerDAO.create(attributes=self._properties)
4 changes: 2 additions & 2 deletions superset/commands/annotation_layer/update.py
@@ -40,9 +40,9 @@ def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model: Optional[AnnotationLayer] = None

def run(self) -> None:
def run(self) -> Model:
self.validate()
assert self._models
assert self._model

try:
annotation_layer = AnnotationLayerDAO.update(self._model, self._properties)
5 changes: 4 additions & 1 deletion superset/commands/dataset/duplicate.py
@@ -70,7 +70,10 @@ def run(self) -> Model:
table.normalize_columns = self._base_model.normalize_columns
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.is_sqllab_view = True
table.sql = ParsedQuery(self._base_model.sql).stripped()
table.sql = ParsedQuery(
self._base_model.sql,
engine=database.db_engine_spec.engine,
).stripped()
db.session.add(table)
cols = []
for config_ in self._base_model.columns:
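Most of the restored hunks follow the same pattern: pass the database's engine name into ParsedQuery so parsing can pick the matching SQLGlot dialect instead of falling back to generic parsing. A minimal sketch of the call shape, assuming the ParsedQuery constructor shown further down in this diff; the query and engine name are made up for illustration:

from superset.sql_parse import ParsedQuery

# Hypothetical query and engine name, purely for illustration.
sql = "SELECT name FROM some_schema.some_table ;"
parsed = ParsedQuery(sql, engine="postgresql")

# stripped() trims trailing whitespace and semicolons, which is how
# table.sql is populated in the duplicate-dataset command above.
print(parsed.stripped())  # SELECT name FROM some_schema.some_table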
5 changes: 4 additions & 1 deletion superset/commands/sql_lab/export.py
@@ -115,7 +115,10 @@ def run(
limit = None
else:
sql = self._query.executed_sql
limit = ParsedQuery(sql).limit
limit = ParsedQuery(
sql,
engine=self._query.database.db_engine_spec.engine,
).limit
if limit is not None and self._query.limiting_factor in {
LimitingFactor.QUERY,
LimitingFactor.DROPDOWN,
5 changes: 4 additions & 1 deletion superset/connectors/sqla/models.py
@@ -1461,7 +1461,10 @@ def get_from_clause(
return self.get_sqla_table(), None

from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
parsed_query = ParsedQuery(
from_sql,
engine=self.db_engine_spec.engine,
)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
2 changes: 1 addition & 1 deletion superset/connectors/sqla/utils.py
@@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
parsed_query = ParsedQuery(sql)
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
if not db_engine_spec.is_readonly_query(parsed_query):
raise SupersetSecurityException(
SupersetError(
3 changes: 1 addition & 2 deletions superset/daos/dataset.py
@@ -75,8 +75,7 @@ def validate_table_exists(
database.get_table(table_name, schema=schema)
return True
except SQLAlchemyError as ex: # pragma: no cover
# logger.warning("Got an error %s validating table: %s", str(ex), table_name)
logger.exception("Got an error %s validating table: %s", str(ex), table_name)
logger.warning("Got an error %s validating table: %s", str(ex), table_name)
return False

@staticmethod
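For context on the DAO hunk above: logger.exception logs at ERROR level and appends the traceback of the active exception, which is why the leftover debugging line was noisy on every failed table validation, while logger.warning restores the original single-line message. A small standalone illustration (not Superset code):

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

try:
    raise ValueError("table does not exist")
except ValueError as ex:
    # One WARNING line, no traceback
    logger.warning("Got an error %s validating table: %s", str(ex), "my_table")
    # ERROR line plus the full traceback of the exception being handled
    logger.exception("Got an error %s validating table: %s", str(ex), "my_table")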
4 changes: 0 additions & 4 deletions superset/datasets/api.py
@@ -277,10 +277,6 @@ class DatasetRestApi(BaseSupersetModelRestApi):
list_outer_default_load = True
show_outer_default_load = True

def response_400(self, message=None):
logger.error(f"Error from datasets api: {message}")
return super().response_400(message=message)

@expose("/", methods=("POST",))
@protect()
@safe
12 changes: 6 additions & 6 deletions superset/db_engine_specs/base.py
@@ -900,7 +900,7 @@ def apply_limit_to_sql(
return database.compile_sqla_query(qry)

if cls.limit_method == LimitMethod.FORCE_LIMIT:
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
sql = parsed_query.set_or_update_query_limit(limit, force=force)

return sql
@@ -981,7 +981,7 @@ def get_limit_from_sql(cls, sql: str) -> int | None:
:param sql: SQL query
:return: Value of limit clause in query
"""
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.limit

@classmethod
@@ -993,7 +993,7 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str:
:param limit: New limit to insert/replace into query
:return: Query with new limit
"""
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.set_or_update_query_limit(limit)

@classmethod
@@ -1490,7 +1490,7 @@ def process_statement(cls, statement: str, database: Database) -> str:
:param database: Database instance
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement)
parsed_query = ParsedQuery(statement, engine=cls.engine)
sql = parsed_query.stripped()
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
@@ -1525,7 +1525,7 @@ def estimate_query_cost(
"Database does not support cost estimation"
)

parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()

costs = []
@@ -1586,7 +1586,7 @@ def execute( # pylint: disable=unused-argument
:return:
"""
if not cls.allows_sql_comments:
query = sql_parse.strip_comments_from_sql(query)
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)

if cls.arraysize:
cursor.arraysize = cls.arraysize
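The base engine-spec hunks above thread cls.engine into the sql_parse helpers used for limit handling. A rough sketch of what those limit helpers do with a dialect-aware ParsedQuery; the query is illustrative and the exact output formatting may differ:

from superset.sql_parse import ParsedQuery

sql = "SELECT * FROM logs LIMIT 1000"

# Roughly what get_limit_from_sql() returns
print(ParsedQuery(sql, engine="postgresql").limit)  # 1000

# Roughly what set_or_update_query_limit() produces
print(ParsedQuery(sql, engine="postgresql").set_or_update_query_limit(100))
# SELECT * FROM logs LIMIT 100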
2 changes: 1 addition & 1 deletion superset/db_engine_specs/bigquery.py
@@ -435,7 +435,7 @@ def estimate_query_cost(
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")

parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
costs = []
for statement in statements:
5 changes: 0 additions & 5 deletions superset/explore/form_data/api.py
@@ -106,11 +106,6 @@ def post(self) -> Response:
return self.response(201, key=key)
except ValidationError as ex:
return self.response(400, message=ex.messages)
# except (
# ChartAccessDeniedError,
# DatasetAccessDeniedError,
# TemporaryCacheAccessDeniedError,
# ) as ex:
except TemporaryCacheAccessDeniedError as ex:
return self.response(403, message=str(ex))
except TemporaryCacheResourceNotFoundError as ex:
12 changes: 3 additions & 9 deletions superset/models/core.py
@@ -351,9 +351,7 @@ def get_password_masked_url_from_uri( # pylint: disable=invalid-name
cls, uri: str
) -> URL:
sqlalchemy_url = make_url_safe(uri)
#TODO: turn this back on after I'm done debugging stuff!
# return cls.get_password_masked_url(sqlalchemy_url)
return sqlalchemy_url
return cls.get_password_masked_url(sqlalchemy_url)

@classmethod
def get_password_masked_url(cls, masked_url: URL) -> URL:
@@ -367,9 +365,7 @@ def set_sqlalchemy_uri(self, uri: str) -> None:
if conn.password != PASSWORD_MASK and not custom_password_store:
# do not over-write the password with the password mask
self.password = conn.password
#TODO: turn this back on after I'm done debugging stuff!
# conn = conn.set(password=PASSWORD_MASK if conn.password else None)
self.sqlalchemy_uri = str(conn) # hides the password
conn = conn.set(password=PASSWORD_MASK if conn.password else None)

def get_effective_user(self, object_url: URL) -> str | None:
"""
@@ -481,9 +477,7 @@ def _get_sqla_engine(
effective_username,
)

#TODO: switch this back on once I'm done debugging stuff!
# masked_url = self.get_password_masked_url(sqlalchemy_url)
masked_url = sqlalchemy_url
masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))

if self.impersonate_user:
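The three core.py hunks restore password masking that had been bypassed by leftover debugging code. A minimal sketch of what masking a SQLAlchemy URL looks like, using plain SQLAlchemy rather than Superset's make_url_safe/PASSWORD_MASK helpers, with an obviously fake URI:

from sqlalchemy.engine.url import make_url

url = make_url("postgresql://scott:tiger@localhost:5432/examples")

# Swap the password component for a mask, as get_password_masked_url() does;
# "XXXXXXXXXX" stands in for Superset's PASSWORD_MASK constant here.
masked = url.set(password="XXXXXXXXXX")

print(str(masked))  # postgresql://scott:XXXXXXXXXX@localhost:5432/examples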
2 changes: 1 addition & 1 deletion superset/models/helpers.py
@@ -1094,7 +1094,7 @@ def get_from_clause(
"""

from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
5 changes: 3 additions & 2 deletions superset/sql_lab.py
@@ -208,7 +208,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
database: Database = query.database
db_engine_spec = database.db_engine_spec

parsed_query = ParsedQuery(sql_statement)
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
if is_feature_enabled("RLS_IN_SQLLAB"):
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
@@ -228,7 +228,8 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
database.id,
query.schema,
)
)
),
engine=db_engine_spec.engine,
)

sql = parsed_query.stripped()
91 changes: 13 additions & 78 deletions superset/sql_parse.py
@@ -20,7 +20,8 @@

import logging
import re
from collections.abc import Iterator
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, TYPE_CHECKING

@@ -66,7 +67,7 @@

try:
from sqloxide import parse_sql as sqloxide_parse
except: # pylint: disable=bare-except
except (ImportError, ModuleNotFoundError):
sqloxide_parse = None

if TYPE_CHECKING:
@@ -227,7 +228,11 @@ def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
:param statement: A string with the SQL statement
:return: SQL statement without comments
"""
return ParsedQuery(statement).strip_comments() if "--" in statement else statement
return (
ParsedQuery(statement, engine=engine).strip_comments()
if "--" in statement
else statement
)


@dataclass(eq=True, frozen=True)
@@ -246,7 +251,7 @@ def __str__(self) -> str:
"""

return ".".join(
parse.quote(part, safe="").replace(".", "%2E")
urllib.parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
@@ -266,6 +271,7 @@ def __init__(
sql_statement = sqlparse.format(sql_statement, strip_comments=True)

self.sql: str = sql_statement
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
self._limit: int | None = None
@@ -278,12 +284,7 @@ def __init__(
@property
def tables(self) -> set[Table]:
if not self._tables:
for statement in self._parsed:
self._extract_from_token(statement)

self._tables = {
table for table in self._tables if str(table) not in self._alias_names
}
self._tables = self._extract_tables_from_sql()
return self._tables

def _extract_tables_from_sql(self) -> set[Table]:
@@ -572,28 +573,6 @@ def get_table(tlist: TokenList) -> Table | None:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))

def _process_tokenlist(self, token_list: TokenList) -> None:
"""
Add table names to table set
:param token_list: TokenList to be processed
"""
# exclude subselects
if "(" not in str(token_list):
table = self.get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return

# store aliases
if token_list.has_alias():
self._alias_names.add(token_list.get_alias())

# some aliases are not parsed properly
if token_list.tokens[0].ttype == Name:
self._alias_names.add(token_list.tokens[0].value)
self._extract_from_token(token_list)

def as_create_table(
self,
table_name: str,
@@ -620,50 +599,6 @@ def as_create_table(
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
return exec_sql

def _extract_from_token(self, token: Token) -> None:
"""
<Identifier> store a list of subtokens and <IdentifierList> store lists of
subtoken list.
It extracts <IdentifierList> and <Identifier> from :param token: and loops
through all subtokens recursively. It finds table_name_preceding_token and
passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
self._tables.
:param token: instance of Token or child class, e.g. TokenList, to be processed
"""
if not hasattr(token, "tokens"):
return

table_name_preceding_token = False

for item in token.tokens:
if item.is_group and (
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
):
self._extract_from_token(item)

if item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME
or item.normalized.endswith(" JOIN")
):
table_name_preceding_token = True
continue

if item.ttype in Keyword:
table_name_preceding_token = False
continue
if table_name_preceding_token:
if isinstance(item, Identifier):
self._process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token2 in item.get_identifiers():
if isinstance(token2, TokenList):
self._process_tokenlist(token2)
elif isinstance(item, IdentifierList):
if any(not self._is_identifier(token2) for token2 in item.tokens):
self._extract_from_token(item)

def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
"""Returns the query with the specified limit.
@@ -1060,7 +995,7 @@ def insert_rls_in_predicate(


# mapping between sqloxide and SQLAlchemy dialects
SQLOXITE_DIALECTS = {
SQLOXIDE_DIALECTS = {
"ansi": {"trino", "trinonative", "presto"},
"hive": {"hive", "databricks"},
"ms": {"mssql"},
@@ -1093,7 +1028,7 @@ def extract_table_references(
tree = None

if sqloxide_parse:
for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
if sqla_dialect in sqla_dialects:
break
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
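The bulk of the sql_parse.py hunk removes the old sqlparse token-walking helpers (_process_tokenlist, _extract_from_token) in favour of the dialect-aware _extract_tables_from_sql path. A rough standalone illustration of that style of table extraction using sqlglot directly (not the actual Superset implementation):

import sqlglot
from sqlglot import exp

sql = (
    "SELECT a.id FROM schema_one.accounts a "
    "JOIN schema_two.events e ON e.account_id = a.id"
)

# Parse with an explicit dialect, then collect every table reference
# from the expression tree.
tables = {
    (table.db or None, table.name)
    for table in sqlglot.parse_one(sql, read="postgres").find_all(exp.Table)
}
print(tables)  # {('schema_one', 'accounts'), ('schema_two', 'events')}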