From 0e9cf02d26438ee5bc7bb950110a6dbd980392ed Mon Sep 17 00:00:00 2001 From: Taylor Steinberg Date: Mon, 7 Oct 2024 13:51:30 -0400 Subject: [PATCH] feat: add version check support (#299) Adds a new decorator, `@context.requires`, which asserts version compatibility when the server version is known. The check is skipped if the server version is unknown (e.g., the Connect configuration disables version information). Also marks the OAuth API with a '2024.08.0' requirement. Closes #272 --- pyproject.toml | 5 +- requirements.txt | 1 + src/posit/connect/client.py | 11 ++- src/posit/connect/context.py | 45 ++++++++++ .../posit/connect/external/test_databricks.py | 2 + .../posit/connect/external/test_snowflake.py | 2 + .../posit/connect/oauth/test_associations.py | 4 + .../posit/connect/oauth/test_integrations.py | 9 +- tests/posit/connect/oauth/test_oauth.py | 5 +- tests/posit/connect/oauth/test_sessions.py | 16 ++-- tests/posit/connect/test_client.py | 10 +++ tests/posit/connect/test_context.py | 90 +++++++++++++++++++ 12 files changed, 185 insertions(+), 15 deletions(-) create mode 100644 src/posit/connect/context.py create mode 100644 tests/posit/connect/test_context.py diff --git a/pyproject.toml b/pyproject.toml index d7f15f16..0f1b19a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,10 @@ classifiers = [ "Typing :: Typed", ] dynamic = ["version"] -dependencies = ["requests>=2.31.0,<3"] +dependencies = [ + "requests>=2.31.0,<3", + "packaging" +] [project.urls] Source = "https://github.com/posit-dev/posit-sdk-py" diff --git a/requirements.txt b/requirements.txt index 6e421681..8f2d86fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ requests==2.32.2 +packaging==24.1 diff --git a/src/posit/connect/client.py b/src/posit/connect/client.py index d33652c9..0f42f475 100644 --- a/src/posit/connect/client.py +++ b/src/posit/connect/client.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import overload +from typing import Optional, overload from requests import Response, Session @@ -10,6 +10,7 @@ from .auth import Auth from .config import Config from .content import Content +from .context import Context, ContextManager, requires from .groups import Groups from .metrics import Metrics from .oauth import OAuth @@ -18,7 +19,7 @@ from .users import User, Users -class Client: +class Client(ContextManager): """ Client connection for Posit Connect. @@ -156,9 +157,10 @@ def __init__(self, *args, **kwargs) -> None: session.hooks["response"].append(hooks.handle_errors) self.session = session self.resource_params = ResourceParameters(session, self.cfg.url) + self.ctx = Context(self.session, self.cfg.url) @property - def version(self) -> str: + def version(self) -> Optional[str]: """ The server version. @@ -167,7 +169,7 @@ def version(self) -> str: str The version of the Posit Connect server. """ - return self.get("server_settings").json()["version"] + return self.ctx.version @property def me(self) -> User: @@ -257,6 +259,7 @@ def metrics(self) -> Metrics: return Metrics(self.resource_params) @property + @requires(version="2024.08.0") def oauth(self) -> OAuth: """ The OAuth API interface. diff --git a/src/posit/connect/context.py b/src/posit/connect/context.py new file mode 100644 index 00000000..14b7e330 --- /dev/null +++ b/src/posit/connect/context.py @@ -0,0 +1,45 @@ +import functools +from typing import Optional, Protocol + +from packaging.version import Version + + +def requires(version: str): + def decorator(func): + @functools.wraps(func) + def wrapper(instance: ContextManager, *args, **kwargs): + ctx = instance.ctx + if ctx.version and Version(ctx.version) < Version(version): + raise RuntimeError( + f"This API is not available in Connect version {ctx.version}. Please upgrade to version {version} or later.", + ) + return func(instance, *args, **kwargs) + + return wrapper + + return decorator + + +class Context(dict): + def __init__(self, session, url): + self.session = session + self.url = url + + @property + def version(self) -> Optional[str]: + try: + value = self["version"] + except KeyError: + endpoint = self.url + "server_settings" + response = self.session.get(endpoint) + result = response.json() + value = self["version"] = result.get("version") + return value + + @version.setter + def version(self, value: str): + self["version"] = value + + +class ContextManager(Protocol): + ctx: Context diff --git a/tests/posit/connect/external/test_databricks.py b/tests/posit/connect/external/test_databricks.py index 27030188..2be75527 100644 --- a/tests/posit/connect/external/test_databricks.py +++ b/tests/posit/connect/external/test_databricks.py @@ -48,6 +48,7 @@ def test_posit_credentials_provider(self): register_mocks() client = Client(api_key="12345", url="https://connect.example/") + client.ctx.version = None cp = PositCredentialsProvider(client=client, user_session_token="cit") assert cp() == {"Authorization": f"Bearer dynamic-viewer-access-token"} @@ -57,6 +58,7 @@ def test_posit_credentials_strategy(self): register_mocks() client = Client(api_key="12345", url="https://connect.example/") + client.ctx.version = None cs = PositCredentialsStrategy( local_strategy=mock_strategy(), user_session_token="cit", diff --git a/tests/posit/connect/external/test_snowflake.py b/tests/posit/connect/external/test_snowflake.py index 59680c53..a544d860 100644 --- a/tests/posit/connect/external/test_snowflake.py +++ b/tests/posit/connect/external/test_snowflake.py @@ -33,6 +33,7 @@ def test_posit_authenticator(self): register_mocks() client = Client(api_key="12345", url="https://connect.example/") + client.ctx.version = None auth = PositAuthenticator( local_authenticator="SNOWFLAKE", user_session_token="cit", @@ -44,6 +45,7 @@ def test_posit_authenticator(self): def test_posit_authenticator_fallback(self): # local_authenticator is used when the content is running locally client = Client(api_key="12345", url="https://connect.example/") + client.ctx.version = None auth = PositAuthenticator( local_authenticator="SNOWFLAKE", user_session_token="cit", diff --git a/tests/posit/connect/oauth/test_associations.py b/tests/posit/connect/oauth/test_associations.py index c5febf82..7ea13a70 100644 --- a/tests/posit/connect/oauth/test_associations.py +++ b/tests/posit/connect/oauth/test_associations.py @@ -55,6 +55,7 @@ def test(self): # setup c = Client("https://connect.example", "12345") + c.ctx.version = None # invoke associations = c.oauth.integrations.get(guid).associations.find() @@ -83,6 +84,7 @@ def test(self): # setup c = Client("https://connect.example", "12345") + c.ctx.version = None # invoke associations = c.content.get(guid).oauth.associations.find() @@ -115,6 +117,7 @@ def test(self): # setup c = Client("https://connect.example", "12345") + c.ctx.version = None # invoke c.content.get(guid).oauth.associations.update(new_integration_guid) @@ -142,6 +145,7 @@ def test(self): # setup c = Client("https://connect.example", "12345") + c.ctx.version = None # invoke c.content.get(guid).oauth.associations.delete() diff --git a/tests/posit/connect/oauth/test_integrations.py b/tests/posit/connect/oauth/test_integrations.py index efd6b051..797e8008 100644 --- a/tests/posit/connect/oauth/test_integrations.py +++ b/tests/posit/connect/oauth/test_integrations.py @@ -73,6 +73,7 @@ def test(self): # setup c = Client("https://connect.example", "12345") + c.ctx.version = None integration = c.oauth.integrations.get(guid) # invoke @@ -93,6 +94,7 @@ def test(self): ) c = Client("https://connect.example", "12345") + c.ctx.version = None integration = c.oauth.integrations.get(guid) assert integration.guid == guid @@ -137,6 +139,7 @@ def test(self): # setup c = Client("https://connect.example", "12345") + c.ctx.version = None # invoke integration = c.oauth.integrations.create( @@ -164,10 +167,11 @@ def test(self): ) # setup - client = Client("https://connect.example", "12345") + c = Client("https://connect.example", "12345") + c.ctx.version = None # invoke - integrations = client.oauth.integrations.find() + integrations = c.oauth.integrations.find() # assert assert mock_get.call_count == 1 @@ -189,6 +193,7 @@ def test(self): # setup c = Client("https://connect.example", "12345") + c.ctx.version = None integration = c.oauth.integrations.get(guid) assert mock_get.call_count == 1 diff --git a/tests/posit/connect/oauth/test_oauth.py b/tests/posit/connect/oauth/test_oauth.py index a851ea9f..8e10ec8a 100644 --- a/tests/posit/connect/oauth/test_oauth.py +++ b/tests/posit/connect/oauth/test_oauth.py @@ -23,5 +23,6 @@ def test_get_credentials(self): "token_type": "Bearer", }, ) - con = Client(api_key="12345", url="https://connect.example/") - assert con.oauth.get_credentials("cit")["access_token"] == "viewer-token" + c = Client(api_key="12345", url="https://connect.example/") + c.ctx.version = None + assert c.oauth.get_credentials("cit")["access_token"] == "viewer-token" diff --git a/tests/posit/connect/oauth/test_sessions.py b/tests/posit/connect/oauth/test_sessions.py index c9e857b3..09cd8300 100644 --- a/tests/posit/connect/oauth/test_sessions.py +++ b/tests/posit/connect/oauth/test_sessions.py @@ -53,6 +53,7 @@ def test(self): # setup c = Client("https://connect.example", "12345") + c.ctx.version = None session = c.oauth.sessions.get(guid) # invoke @@ -72,10 +73,11 @@ def test(self): ) # setup - client = Client("https://connect.example", "12345") + c = Client("https://connect.example", "12345") + c.ctx.version = None # invoke - sessions = client.oauth.sessions.find() + sessions = c.oauth.sessions.find() # assert assert mock_get.call_count == 1 @@ -94,10 +96,11 @@ def test_params_all(self): ) # setup - client = Client("https://connect.example", "12345") + c = Client("https://connect.example", "12345") + c.ctx.version = None # invoke - client.oauth.sessions.find(all=True) + c.oauth.sessions.find(all=True) # assert assert mock_get.call_count == 1 @@ -115,10 +118,11 @@ def test(self): ) # setup - client = Client("https://connect.example", "12345") + c = Client("https://connect.example", "12345") + c.ctx.version = None # invoke - session = client.oauth.sessions.get(guid=guid) + session = c.oauth.sessions.get(guid=guid) # assert assert mock_get.call_count == 1 diff --git a/tests/posit/connect/test_client.py b/tests/posit/connect/test_client.py index 7af64103..217f359c 100644 --- a/tests/posit/connect/test_client.py +++ b/tests/posit/connect/test_client.py @@ -172,3 +172,13 @@ def test_delete(self, MockSession): client = Client(api_key=api_key, url=url) client.delete("/foo") client.session.delete.assert_called_once_with("https://connect.example.com/__api__/foo") + + +class TestClientOAuth: + def test_required_version(self): + api_key = "12345" + url = "https://connect.example.com" + client = Client(api_key=api_key, url=url) + client.ctx.version = "2024.07.0" + with pytest.raises(RuntimeError): + client.oauth diff --git a/tests/posit/connect/test_context.py b/tests/posit/connect/test_context.py new file mode 100644 index 00000000..6ec6275f --- /dev/null +++ b/tests/posit/connect/test_context.py @@ -0,0 +1,90 @@ +from email.contentmanager import ContentManager +from unittest.mock import MagicMock, Mock + +import pytest +import requests +import responses + +from posit.connect.context import Context, requires +from posit.connect.urls import Url + + +class TestRequires: + def test_version_unsupported(self): + class Stub(ContentManager): + def __init__(self, ctx): + self.ctx = ctx + + @requires("1.0.0") + def fail(self): + pass + + ctx = MagicMock() + ctx.version = "0.0.0" + instance = Stub(ctx) + + with pytest.raises(RuntimeError): + instance.fail() + + def test_version_supported(self): + class Stub(ContentManager): + def __init__(self, ctx): + self.ctx = ctx + + @requires("1.0.0") + def success(self): + pass + + ctx = MagicMock() + ctx.version = "1.0.0" + instance = Stub(ctx) + + instance.success() + + def test_version_missing(self): + class Stub(ContentManager): + def __init__(self, ctx): + self.ctx = ctx + + @requires("1.0.0") + def success(self): + pass + + ctx = MagicMock() + ctx.version = None + instance = Stub(ctx) + + instance.success() + + +class TestContextVersion: + @responses.activate + def test_unknown(self): + responses.get( + f"http://connect.example/__api__/server_settings", + json={}, + ) + + session = requests.Session() + url = Url("http://connect.example") + ctx = Context(session, url) + + assert ctx.version is None + + @responses.activate + def test_known(self): + responses.get( + f"http://connect.example/__api__/server_settings", + json={"version": "2024.09.24"}, + ) + + session = requests.Session() + url = Url("http://connect.example") + ctx = Context(session, url) + + assert ctx.version == "2024.09.24" + + def test_setter(self): + ctx = Context(Mock(), Mock()) + ctx.version = "2024.09.24" + assert ctx.version == "2024.09.24"