Skip to content

Commit

Permalink
SNOW-1657037: Pass Azure SAS Token via params
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jrose committed Sep 19, 2024
1 parent 8b61823 commit 22bfa6c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
8 changes: 3 additions & 5 deletions src/snowflake/connector/azure_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from random import choice
from string import hexdigits
from typing import TYPE_CHECKING, Any, NamedTuple
from urllib.parse import parse_qsl

from .compat import quote
from .constants import FileHeader, ResultStatus
Expand Down Expand Up @@ -95,12 +96,9 @@ def generate_authenticated_url_and_rest_args() -> tuple[bytes, dict[str, Any]]:
sas_token = self.credentials.creds["AZURE_SAS_TOKEN"]
if sas_token and sas_token.startswith("?"):
sas_token = sas_token[1:]
if "?" in url:
_url = url + "&" + sas_token
else:
_url = url + "?" + sas_token
params = {k: v for k, v in parse_qsl(sas_token)}
headers["Date"] = timestamp
rest_args = {"headers": headers}
rest_args = {"headers": headers, "params": params}
if data:
rest_args["data"] = data
return _url, rest_args
Expand Down
15 changes: 14 additions & 1 deletion src/snowflake/connector/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def request(
method: str = "post",
client: str = "sfsql",
timeout: int | None = None,
params: dict | None = None,
_no_results: bool = False,
_include_retry_params: bool = False,
_no_retry: bool = False,
Expand Down Expand Up @@ -493,6 +494,7 @@ def request(
headers,
json.dumps(body),
token=self.token,
params=params,
_no_results=_no_results,
timeout=timeout,
_include_retry_params=_include_retry_params,
Expand All @@ -503,6 +505,7 @@ def request(
url,
headers,
token=self.token,
params=params,
timeout=timeout,
)

Expand Down Expand Up @@ -678,6 +681,7 @@ def _get_request(
url: str,
headers: dict[str, str],
token: str = None,
params: dict | None = None,
timeout: int | None = None,
is_fetch_query_status: bool = False,
) -> dict[str, Any]:
Expand All @@ -693,6 +697,7 @@ def _get_request(
headers,
timeout=timeout,
token=token,
params=params,
is_fetch_query_status=is_fetch_query_status,
)
if ret.get("code") == SESSION_EXPIRED_GS_CODE:
Expand All @@ -712,6 +717,7 @@ def _get_request(
url,
headers,
token=self.token,
params=params,
is_fetch_query_status=is_fetch_query_status,
)

Expand All @@ -723,6 +729,7 @@ def _post_request(
headers,
body,
token=None,
params: dict | None = None,
timeout: int | None = None,
socket_timeout: int | None = None,
_no_results: bool = False,
Expand All @@ -743,6 +750,7 @@ def _post_request(
data=body,
timeout=timeout,
token=token,
params=params,
no_retry=no_retry,
_include_retry_params=_include_retry_params,
socket_timeout=socket_timeout,
Expand All @@ -769,7 +777,7 @@ def _post_request(
)
if ret.get("success"):
return self._post_request(
url, headers, body, token=self.token, timeout=timeout
url, headers, body, token=self.token, params=params, timeout=timeout
)

if isinstance(ret.get("data"), dict) and ret["data"].get("queryId"):
Expand All @@ -789,6 +797,7 @@ def _post_request(
result_url,
headers,
token=self.token,
params=params,
timeout=timeout,
is_fetch_query_status=bool(
re.match(r"^/queries/.+/result$", result_url)
Expand Down Expand Up @@ -880,6 +889,7 @@ def _request_exec_wrapper(
retry_ctx,
no_retry: bool = False,
token=NO_TOKEN,
params: dict | None = None,
**kwargs,
):
conn = self._connection
Expand Down Expand Up @@ -908,6 +918,7 @@ def _request_exec_wrapper(
headers=headers,
data=data,
token=token,
params=params,
raise_raw_http_failure=raise_raw_http_failure,
**kwargs,
)
Expand Down Expand Up @@ -1041,6 +1052,7 @@ def _request_exec(
headers,
data,
token,
params,
catch_okta_unauthorized_error: bool = False,
is_raw_text: bool = False,
is_raw_binary: bool = False,
Expand Down Expand Up @@ -1077,6 +1089,7 @@ def _request_exec(
verify=True,
stream=is_raw_binary,
auth=SnowflakeAuth(token),
params=params,
)
download_end_time = get_time_millis()

Expand Down

0 comments on commit 22bfa6c

Please sign in to comment.