Skip to content

Commit

Permalink
Merge branch 'main' into fix-readme-typos
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jvasquezrojas authored Nov 21, 2024
2 parents 7995480 + 65754a4 commit 8296a34
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 25 deletions.
3 changes: 3 additions & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ Source code is also available at:

# Release Notes

- (Unreleased)
- Add support for partition by to copy into <location>

- v1.7.0(November 22, 2024)

- Add support for dynamic tables and required options
Expand Down
26 changes: 20 additions & 6 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from sqlalchemy.schema import Sequence, Table
from sqlalchemy.sql import compiler, expression, functions
from sqlalchemy.sql.base import CompileState
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import BindParameter, quoted_name
from sqlalchemy.sql.expression import Executable
from sqlalchemy.sql.selectable import Lateral, SelectState

from snowflake.sqlalchemy._constants import DIALECT_NAME
Expand Down Expand Up @@ -563,9 +564,8 @@ def visit_copy_into(self, copy_into, **kw):
if isinstance(copy_into.into, Table)
else copy_into.into._compiler_dispatch(self, **kw)
)
from_ = None
if isinstance(copy_into.from_, Table):
from_ = copy_into.from_
from_ = copy_into.from_.name
# this is intended to catch AWSBucket and AzureContainer
elif (
isinstance(copy_into.from_, AWSBucket)
Expand All @@ -576,6 +576,21 @@ def visit_copy_into(self, copy_into, **kw):
# everything else (selects, etc.)
else:
from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})"

partition_by_value = None
if isinstance(copy_into.partition_by, (BindParameter, Executable)):
partition_by_value = copy_into.partition_by.compile(
compile_kwargs={"literal_binds": True}
)
elif copy_into.partition_by is not None:
partition_by_value = copy_into.partition_by

partition_by = (
f"PARTITION BY {partition_by_value}"
if partition_by_value is not None and partition_by_value != ""
else ""
)

credentials, encryption = "", ""
if isinstance(into, tuple):
into, credentials, encryption = into
Expand All @@ -586,8 +601,7 @@ def visit_copy_into(self, copy_into, **kw):
options_list.sort(key=operator.itemgetter(0))
options = (
(
" "
+ " ".join(
" ".join(
[
"{} = {}".format(
n,
Expand All @@ -608,7 +622,7 @@ def visit_copy_into(self, copy_into, **kw):
options += f" {credentials}"
if encryption:
options += f" {encryption}"
return f"COPY INTO {into} FROM {from_} {formatter}{options}"
return f"COPY INTO {into} FROM {' '.join([from_, partition_by, formatter, options])}"

def visit_copy_formatter(self, formatter, **kw):
options_list = list(formatter.options.items())
Expand Down
9 changes: 7 additions & 2 deletions src/snowflake/sqlalchemy/custom_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,23 @@ class CopyInto(UpdateBase):
__visit_name__ = "copy_into"
_bind = None

def __init__(self, from_, into, formatter=None):
def __init__(self, from_, into, partition_by=None, formatter=None):
self.from_ = from_
self.into = into
self.formatter = formatter
self.copy_options = {}
self.partition_by = partition_by

def __repr__(self):
"""
repr for debugging / logging purposes only. For compilation logic, see
the corresponding visitor in base.py
"""
return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})"
val = f"COPY INTO {self.into} FROM {repr(self.from_)}"
if self.partition_by is not None:
val += f" PARTITION BY {self.partition_by}"

return val + f" {repr(self.formatter)} ({self.copy_options})"

def bind(self):
return None
Expand Down
65 changes: 48 additions & 17 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table
from sqlalchemy.sql import select, text
from sqlalchemy.sql import functions, select, text

from snowflake.sqlalchemy import (
AWSBucket,
Expand Down Expand Up @@ -58,8 +58,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
)
assert (
sql_compiler(copy_stmt_1)
== "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv "
"ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
== "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv "
"ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
"(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')"
)
copy_stmt_2 = CopyIntoStorage(
Expand All @@ -73,8 +73,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
sql_compiler(copy_stmt_2)
== "COPY INTO 's3://backup' FROM (SELECT python_tests_foods.id, "
"python_tests_foods.name, python_tests_foods.quantity FROM python_tests_foods "
"WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' "
"FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') "
"WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' "
"FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') "
"ENCRYPTION=(TYPE='AWS_SSE_S3')"
)
copy_stmt_3 = CopyIntoStorage(
Expand All @@ -87,15 +87,15 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
assert (
sql_compiler(copy_stmt_3)
== "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' "
"FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
"FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
"CREDENTIALS=(AZURE_SAS_TOKEN='token')"
)

copy_stmt_3.maxfilesize(50000000)
assert (
sql_compiler(copy_stmt_3)
== "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' "
"FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
"FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
"MAX_FILE_SIZE = 50000000 "
"CREDENTIALS=(AZURE_SAS_TOKEN='token')"
)
Expand All @@ -112,8 +112,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
)
assert (
sql_compiler(copy_stmt_4)
== "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
"ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
== "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
"ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
"(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')"
)

Expand All @@ -126,8 +126,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
)
assert (
sql_compiler(copy_stmt_5)
== "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
"FIELD_DELIMITER=',') ENCRYPTION="
== "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
"FIELD_DELIMITER=',') ENCRYPTION="
"(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')"
)

Expand All @@ -138,7 +138,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
)
assert (
sql_compiler(copy_stmt_6)
== "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv)"
== "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv) "
)

copy_stmt_7 = CopyIntoStorage(
Expand All @@ -148,7 +148,38 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
)
assert (
sql_compiler(copy_stmt_7)
== "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv)"
== "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv) "
)

copy_stmt_8 = CopyIntoStorage(
from_=food_items,
into=ExternalStage(name="stage_name"),
partition_by=text("('YEAR=' || year)"),
)
assert (
sql_compiler(copy_stmt_8)
== "COPY INTO @stage_name FROM python_tests_foods PARTITION BY ('YEAR=' || year) "
)

copy_stmt_9 = CopyIntoStorage(
from_=food_items,
into=ExternalStage(name="stage_name"),
partition_by=functions.concat(
text("'YEAR='"), text(food_items.columns["name"].name)
),
)
assert (
sql_compiler(copy_stmt_9)
== "COPY INTO @stage_name FROM python_tests_foods PARTITION BY concat('YEAR=', name) "
)

copy_stmt_10 = CopyIntoStorage(
from_=food_items,
into=ExternalStage(name="stage_name"),
partition_by="",
)
assert (
sql_compiler(copy_stmt_10) == "COPY INTO @stage_name FROM python_tests_foods "
)

# NOTE Other than expect known compiled text, submit it to RegressionTests environment and expect them to fail, but
Expand Down Expand Up @@ -231,7 +262,7 @@ def test_copy_into_storage_csv_extended(sql_compiler):
result = sql_compiler(copy_into)
expected = (
r"COPY INTO TEST_IMPORT "
r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata "
r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata "
r"FILE_FORMAT=(TYPE=csv COMPRESSION='auto' DATE_FORMAT='AUTO' "
r"ERROR_ON_COLUMN_COUNT_MISMATCH=True ESCAPE=None "
r"ESCAPE_UNENCLOSED_FIELD='\134' FIELD_DELIMITER=',' "
Expand Down Expand Up @@ -288,7 +319,7 @@ def test_copy_into_storage_parquet_named_format(sql_compiler):
expected = (
"COPY INTO TEST_IMPORT "
"FROM (SELECT $1:COL1::number, $1:COL2::varchar "
"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) "
"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) "
"FILE_FORMAT=(format_name = parquet_file_format) force = TRUE"
)
assert result == expected
Expand Down Expand Up @@ -350,7 +381,7 @@ def test_copy_into_storage_parquet_files(sql_compiler):
"COPY INTO TEST_IMPORT "
"FROM (SELECT $1:COL1::number, $1:COL2::varchar "
"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet "
"(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') "
"(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') "
"FORCE = true"
)
assert result == expected
Expand Down Expand Up @@ -412,6 +443,6 @@ def test_copy_into_storage_parquet_pattern(sql_compiler):
"COPY INTO TEST_IMPORT "
"FROM (SELECT $1:COL1::number, $1:COL2::varchar "
"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet "
"(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'"
"(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'"
)
assert result == expected

0 comments on commit 8296a34

Please sign in to comment.