Skip to content

Commit

Permalink
Add data type to athena query runner (getredash#7112)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
zachliu and github-actions[bot] authored Aug 7, 2024
1 parent b1fe2d4 commit 285c2b6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
22 changes: 16 additions & 6 deletions redash/query_runner/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,20 @@ def __get_schema_from_glue(self, catalog_id=""):
logger.warning("Glue table doesn't have StorageDescriptor: %s", table_name)
continue
if table_name not in schema:
column = [columns["Name"] for columns in table["StorageDescriptor"]["Columns"]]
schema[table_name] = {"name": table_name, "columns": column}
for partition in table.get("PartitionKeys", []):
schema[table_name]["columns"].append(partition["Name"])
schema[table_name] = {"name": table_name, "columns": []}

for column_data in table["StorageDescriptor"]["Columns"]:
column = {
"name": column_data["Name"],
"type": column_data["Type"] if "Type" in column_data else None,
}
schema[table_name]["columns"].append(column)
for partition in table.get("PartitionKeys", []):
partition_column = {
"name": partition["Name"],
"type": partition["Type"] if "Type" in partition else None,
}
schema[table_name]["columns"].append(partition_column)
return list(schema.values())

def get_schema(self, get_stats=False):
Expand All @@ -212,7 +222,7 @@ def get_schema(self, get_stats=False):

schema = {}
query = """
SELECT table_schema, table_name, column_name
SELECT table_schema, table_name, column_name, data_type
FROM information_schema.columns
WHERE table_schema NOT IN ('information_schema')
"""
Expand All @@ -225,7 +235,7 @@ def get_schema(self, get_stats=False):
table_name = "{0}.{1}".format(row["table_schema"], row["table_name"])
if table_name not in schema:
schema[table_name] = {"name": table_name, "columns": []}
schema[table_name]["columns"].append(row["column_name"])
schema[table_name]["columns"].append({"name": row["column_name"], "type": row["data_type"]})

return list(schema.values())

Expand Down
21 changes: 15 additions & 6 deletions tests/query_runner/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def test_external_table(self):
{"DatabaseName": "test1"},
)
with self.stubber:
assert query_runner.get_schema() == [{"columns": ["row_id"], "name": "test1.jdbc_table"}]
assert query_runner.get_schema() == [
{"columns": [{"name": "row_id", "type": "int"}], "name": "test1.jdbc_table"}
]

def test_partitioned_table(self):
"""
Expand Down Expand Up @@ -124,7 +126,12 @@ def test_partitioned_table(self):
{"DatabaseName": "test1"},
)
with self.stubber:
assert query_runner.get_schema() == [{"columns": ["sk", "category"], "name": "test1.partitioned_table"}]
assert query_runner.get_schema() == [
{
"columns": [{"name": "sk", "type": "int"}, {"name": "category", "type": "int"}],
"name": "test1.partitioned_table",
}
]

def test_view(self):
query_runner = Athena({"glue": True, "region": "mars-east-1"})
Expand Down Expand Up @@ -156,7 +163,7 @@ def test_view(self):
{"DatabaseName": "test1"},
)
with self.stubber:
assert query_runner.get_schema() == [{"columns": ["sk"], "name": "test1.view"}]
assert query_runner.get_schema() == [{"columns": [{"name": "sk", "type": "int"}], "name": "test1.view"}]

def test_dodgy_table_does_not_break_schema_listing(self):
"""
Expand Down Expand Up @@ -196,7 +203,9 @@ def test_dodgy_table_does_not_break_schema_listing(self):
{"DatabaseName": "test1"},
)
with self.stubber:
assert query_runner.get_schema() == [{"columns": ["region"], "name": "test1.csv"}]
assert query_runner.get_schema() == [
{"columns": [{"name": "region", "type": "string"}], "name": "test1.csv"}
]

def test_no_storage_descriptor_table(self):
"""
Expand Down Expand Up @@ -312,6 +321,6 @@ def test_multi_catalog_tables(self):
)
with self.stubber:
assert query_runner.get_schema() == [
{"columns": ["row_id"], "name": "test1.jdbc_table"},
{"columns": ["row_id"], "name": "test2.jdbc_table"},
{"columns": [{"name": "row_id", "type": "int"}], "name": "test1.jdbc_table"},
{"columns": [{"name": "row_id", "type": "int"}], "name": "test2.jdbc_table"},
]

0 comments on commit 285c2b6

Please sign in to comment.