Skip to content

Commit

Permalink
feat!: improve compatibility with databricks sql client (#252)
Browse files Browse the repository at this point in the history
* feat!: add working databricks-cli fallback credentials strategy

* test: add tests for external databricks helpers
  • Loading branch information
dbkegley authored Aug 7, 2024
1 parent 8648f80 commit 37bc1dc
Show file tree
Hide file tree
Showing 13 changed files with 231 additions and 109 deletions.
2 changes: 1 addition & 1 deletion examples/connect/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1 @@
.posit
**/rsconnect-python/
22 changes: 12 additions & 10 deletions examples/connect/dash/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import pandas as pd
from dash import Dash, Input, Output, dash_table, html
from databricks import sql
from databricks.sdk.core import ApiClient, Config
from databricks.sdk.core import ApiClient, Config, databricks_cli
from databricks.sdk.service.iam import CurrentUserAPI
from posit.connect.external.databricks import viewer_credentials_provider
from posit.connect.external.databricks import PositCredentialsStrategy

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}"
Expand Down Expand Up @@ -38,14 +38,16 @@ def update_page(_):
session_token = flask.request.headers.get(
"Posit-Connect-User-Session-Token"
)
credentials_provider = viewer_credentials_provider(
user_session_token=session_token
)
posit_strategy = PositCredentialsStrategy(
local_strategy=databricks_cli,
user_session_token=session_token)
cfg = Config(
host=DATABRICKS_HOST_URL,
# uses Posit's custom credential_strategy if running on Connect,
# otherwise falls back to the strategy defined by local_strategy
credentials_strategy=posit_strategy)

def get_greeting():
cfg = Config(
host=DATABRICKS_HOST_URL, credentials_provider=credentials_provider
)
databricks_user_info = CurrentUserAPI(ApiClient(cfg)).me()
return f"Hello, {databricks_user_info.display_name}!"

Expand All @@ -58,8 +60,8 @@ def get_table():
with sql.connect(
server_hostname=DATABRICKS_HOST,
http_path=SQL_HTTP_PATH,
auth_type="databricks-oauth",
credentials_provider=credentials_provider,
# https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
credentials_provider=posit_strategy.sql_credentials_provider(cfg)
) as connection:
with connection.cursor() as cursor:
cursor.execute(query)
Expand Down
6 changes: 3 additions & 3 deletions examples/connect/dash/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
databricks-sql-connector==3.0.1
databricks-sdk==0.20.0
databricks-sql-connector==3.3.0
databricks-sdk==0.29.0
dash==2.15.0
git+https://github.com/posit-dev/posit-sdk-py.git
posit-sdk>=0.4.0
18 changes: 12 additions & 6 deletions examples/connect/fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from typing import Annotated

from databricks import sql
from databricks.sdk.core import Config, databricks_cli
from fastapi import FastAPI, Header
from fastapi.responses import JSONResponse
from posit.connect.external.databricks import viewer_credentials_provider
from posit.connect.external.databricks import PositCredentialsStrategy

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}"
Expand All @@ -26,18 +27,23 @@ async def get_fares(
"""
global rows

credentials_provider = viewer_credentials_provider(
user_session_token=posit_connect_user_session_token
)
posit_strategy = PositCredentialsStrategy(
local_strategy=databricks_cli,
user_session_token=posit_connect_user_session_token)
cfg = Config(
host=DATABRICKS_HOST_URL,
# uses Posit's custom credential_strategy if running on Connect,
# otherwise falls back to the strategy defined by local_strategy
credentials_strategy=posit_strategy)

if rows is None:
query = "SELECT * FROM samples.nyctaxi.trips LIMIT 10;"

with sql.connect(
server_hostname=DATABRICKS_HOST,
http_path=SQL_HTTP_PATH,
auth_type="databricks-oauth",
credentials_provider=credentials_provider,
# https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
credentials_provider=posit_strategy.sql_credentials_provider(cfg)
) as connection:
with connection.cursor() as cursor:
cursor.execute(query)
Expand Down
6 changes: 3 additions & 3 deletions examples/connect/fastapi/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
databricks-sql-connector==3.0.1
databricks-sdk==0.20.0
databricks-sql-connector==3.3.0
databricks-sdk==0.29.0
fastapi==0.110.0
git+https://github.com/posit-dev/posit-sdk-py.git
posit-sdk>=0.4.0
18 changes: 12 additions & 6 deletions examples/connect/flask/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import os

from databricks import sql
from databricks.sdk.core import Config, databricks_cli
from flask import Flask, request
from posit.connect.external.databricks import viewer_credentials_provider
from posit.connect.external.databricks import PositCredentialsStrategy

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}"
Expand All @@ -28,18 +29,23 @@ def get_fares():
global rows

session_token = request.headers.get("Posit-Connect-User-Session-Token")
credentials_provider = viewer_credentials_provider(
user_session_token=session_token
)
posit_strategy = PositCredentialsStrategy(
local_strategy=databricks_cli,
user_session_token=session_token)
cfg = Config(
host=DATABRICKS_HOST_URL,
# uses Posit's custom credential_strategy if running on Connect,
# otherwise falls back to the strategy defined by local_strategy
credentials_strategy=posit_strategy)

if rows is None:
query = "SELECT * FROM samples.nyctaxi.trips LIMIT 10;"

with sql.connect(
server_hostname=DATABRICKS_HOST,
http_path=SQL_HTTP_PATH,
auth_type="databricks-oauth",
credentials_provider=credentials_provider,
# https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
credentials_provider=posit_strategy.sql_credentials_provider(cfg)
) as connection:
with connection.cursor() as cursor:
cursor.execute(query)
Expand Down
6 changes: 3 additions & 3 deletions examples/connect/flask/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
databricks-sql-connector==3.0.1
databricks-sdk==0.20.0
databricks-sql-connector==3.3.0
databricks-sdk==0.29.0
flask==3.0.2
git+https://github.com/posit-dev/posit-sdk-py.git
posit-sdk>=0.4.0
22 changes: 12 additions & 10 deletions examples/connect/shiny-python/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import pandas as pd
from databricks import sql
from databricks.sdk.core import ApiClient, Config
from databricks.sdk.core import ApiClient, Config, databricks_cli
from databricks.sdk.service.iam import CurrentUserAPI
from posit.connect.external.databricks import viewer_credentials_provider
from posit.connect.external.databricks import PositCredentialsStrategy
from shiny import App, Inputs, Outputs, Session, render, ui

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
Expand All @@ -24,9 +24,14 @@ def server(i: Inputs, o: Outputs, session: Session):
session_token = session.http_conn.headers.get(
"Posit-Connect-User-Session-Token"
)
credentials_provider = viewer_credentials_provider(
user_session_token=session_token
)
posit_strategy = PositCredentialsStrategy(
local_strategy=databricks_cli,
user_session_token=session_token)
cfg = Config(
host=DATABRICKS_HOST_URL,
# uses Posit's custom credential_strategy if running on Connect,
# otherwise falls back to the strategy defined by local_strategy
credentials_strategy=posit_strategy)

@render.data_frame
def result():
Expand All @@ -35,8 +40,8 @@ def result():
with sql.connect(
server_hostname=DATABRICKS_HOST,
http_path=SQL_HTTP_PATH,
auth_type="databricks-oauth",
credentials_provider=credentials_provider,
# https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
credentials_provider=posit_strategy.sql_credentials_provider(cfg),
) as connection:
with connection.cursor() as cursor:
cursor.execute(query)
Expand All @@ -48,9 +53,6 @@ def result():

@render.text
def text():
cfg = Config(
host=DATABRICKS_HOST_URL, credentials_provider=credentials_provider
)
databricks_user_info = CurrentUserAPI(ApiClient(cfg)).me()
return f"Hello, {databricks_user_info.display_name}!"

Expand Down
6 changes: 3 additions & 3 deletions examples/connect/shiny-python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
databricks-sql-connector==3.0.1
databricks-sdk==0.20.0
databricks-sql-connector==3.3.0
databricks-sdk==0.29.0
shiny==0.7.1
git+https://github.com/posit-dev/posit-sdk-py.git
posit-sdk>=0.4.0
28 changes: 13 additions & 15 deletions examples/connect/streamlit/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,32 @@
import pandas as pd
import streamlit as st
from databricks import sql
from databricks.sdk.core import ApiClient, Config
from databricks.sdk.core import ApiClient, Config, databricks_cli
from databricks.sdk.service.iam import CurrentUserAPI
from posit.connect.external.databricks import viewer_credentials_provider
from streamlit.web.server.websocket_headers import _get_websocket_headers
from posit.connect.external.databricks import PositCredentialsStrategy

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
DATABRICKS_HOST_URL = f"https://{DATABRICKS_HOST}"
SQL_HTTP_PATH = os.getenv("DATABRICKS_PATH")

session_token = _get_websocket_headers().get(
"Posit-Connect-User-Session-Token"
)

credentials_provider = viewer_credentials_provider(
user_session_token=session_token
)

session_token = st.context.headers.get("Posit-Connect-User-Session-Token")
posit_strategy = PositCredentialsStrategy(
local_strategy=databricks_cli,
user_session_token=session_token)
cfg = Config(
host=DATABRICKS_HOST_URL, credentials_provider=credentials_provider
)
host=DATABRICKS_HOST_URL,
# uses Posit's custom credential_strategy if running on Connect,
# otherwise falls back to the strategy defined by local_strategy
credentials_strategy=posit_strategy)

databricks_user = CurrentUserAPI(ApiClient(cfg)).me()
st.write(f"Hello, {databricks_user.display_name}!")

with sql.connect(
server_hostname=DATABRICKS_HOST,
http_path=SQL_HTTP_PATH,
auth_type="databricks-oauth",
credentials_provider=credentials_provider,
# https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
credentials_provider=posit_strategy.sql_credentials_provider(cfg)
) as connection:
with connection.cursor() as cursor:
cursor.execute("SELECT * FROM samples.nyctaxi.trips LIMIT 10;")
Expand Down
8 changes: 4 additions & 4 deletions examples/connect/streamlit/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
databricks-sql-connector==3.0.1
databricks-sdk==0.20.0
streamlit==1.31.1
git+https://github.com/posit-dev/posit-sdk-py.git
databricks-sql-connector==3.3.0
databricks-sdk==0.29.0
streamlit==1.37.0
posit-sdk>=0.4.0
Loading

0 comments on commit 37bc1dc

Please sign in to comment.