diff --git a/plaid/auth_oidc.py b/plaid/auth_oidc.py index 19a424a8f0050..ee921d5a5db7e 100644 --- a/plaid/auth_oidc.py +++ b/plaid/auth_oidc.py @@ -5,11 +5,13 @@ # from uuid import uuid4 from urllib.parse import urljoin, urlparse from flask import request, redirect, url_for, session, make_response, Response -from flask_appbuilder.security.views import AuthOIDView +from flask_appbuilder.security.views import AuthOIDView, AuthOAuthView from flask_appbuilder import expose from flask_login import login_user, logout_user log = logging.getLogger(__name__) + + class AuthOIDCView(AuthOIDView): @expose('/login/', methods=['GET', 'POST']) @@ -25,7 +27,7 @@ def login(self, flag:bool=True) -> Response: def authorize(self) -> Response: oauth = self.appbuilder.sm.oauth token = oauth.plaid.authorize_access_token() - userinfo = oauth.plaid.parse_id_token(token) + userinfo = oauth.plaid.parse_id_token(token, None) log.info(f"Fetched user info from token: {userinfo}") user_email = userinfo['email'].lower() if user_email.endswith('tartansolutions.com') or user_email.endswith('plaidcloud.com'): @@ -60,3 +62,12 @@ def logout(self) -> Response: def throwaway_password() -> str: random_bytes = os.urandom(64) return b64encode(random_bytes).decode('utf-8') + + +class PlaidAuthOAuthView(AuthOAuthView): + @expose("/login/") + @expose("/login/") + def login(self, provider=None): + if provider is None: + return super().login(provider='plaid-keycloak') + return super().login(provider=provider) diff --git a/plaid/security.py b/plaid/security.py index 9c458e8c2d902..8352bff2ebf6d 100644 --- a/plaid/security.py +++ b/plaid/security.py @@ -13,12 +13,12 @@ from flask import session from flask_login import logout_user from flask_appbuilder import Model -from flask_appbuilder.security.manager import AUTH_OID +from flask_appbuilder.security.manager import AUTH_OID, AUTH_OAUTH from authlib.integrations.flask_client import OAuth from requests.exceptions import HTTPError from plaidcloud.rpc.connection.jsonrpc import SimpleRPC -from plaid.auth_oidc import AuthOIDCView +from plaid.auth_oidc import AuthOIDCView, PlaidAuthOAuthView from superset.security import SupersetSecurityManager @@ -44,12 +44,13 @@ class PlaidSecurityManager(SupersetSecurityManager): """ def __init__(self, appbuilder): - super(PlaidSecurityManager, self).__init__(appbuilder) - # engine = self.get_session.get_bind(mapper=None, clause=None) - # metadata = MetaData(bind=engine, reflect=True) - # self.plaiduser_user = metadata.tables['plaiduser_user'] + app = appbuilder.get_app + app.config['AUTH_TYPE'] = AUTH_OAUTH + app.config['AUTH_USER_REGISTRATION'] = True + app.config['AUTH_ROLES_SYNC_AT_LOGIN'] = True + super().__init__(appbuilder) if self.auth_type == AUTH_OID: - self.oidc_params = self.appbuilder.app.config.get("OIDC_PARAMS") + self.oidc_params = app.config.get("OIDC_PARAMS") self.oauth = OAuth(app=appbuilder.get_app) self.oauth.register( 'plaid', @@ -61,11 +62,96 @@ def __init__(self, appbuilder): jwks_uri=self.oidc_params['jwks_uri'], client_kwargs=self.oidc_params['client_kwargs'], ) - self.authoidview = AuthOIDCView + self.authoidview = AuthOIDCView + if self.auth_type == AUTH_OAUTH: + self.authoauthview = PlaidAuthOAuthView + + def oauth_user_info(self, provider, response=None): + # logging.debug("Oauth2 provider: {0}.".format(provider)) + if provider == 'plaid-keycloak': + me = self.appbuilder.sm.oauth_remotes[provider].get("userinfo") + me.raise_for_status() + data = me.json() + log.debug("User info from Keycloak: %s", data) + + user_email = data['email'].lower() + role_keys = ["superset-plaid", "superset-gamma"] + if user_email.endswith('tartansolutions.com') or user_email.endswith('plaidcloud.com'): + role_keys.append("superset-admin") + + # possibly use AUTH_ROLES_MAPPING and AUTH_ROLES_SYNC_AT_LOGIN = True, with roles given in Keycloak + # https://github.com/apache/superset/blob/ff903486a851760e108b2e841e6a17348b3a9523/docs/src/pages/docs/installation/configuring.mdx + + return { + "username": data.get("name", data["preferred_username"]), # this matches OIDC + "first_name": data.get("given_name", ""), + "last_name": data.get("family_name", ""), + "email": data.get("email", ""), + "role_keys": role_keys # data.get("groups", []), + } + + def auth_user_oauth(self, userinfo): + """ + Method for authenticating user with OAuth. + N.B. This is the overridden to use email as the key instead of username + This is as per OIDC registration + :userinfo: dict with user information + (keys are the same as User model columns) + """ + # extract the email from `userinfo` + if "email" in userinfo and userinfo["email"]: + email = userinfo["email"] + else: + log.error("OAUTH userinfo does not have email %s", userinfo) + return None + + if "username" not in userinfo or not userinfo["username"]: + log.error("OAUTH userinfo does not have username %s", userinfo) + return None + + # Search the DB for this user by email + user = self.find_user(email=email) + + # If user is not active, go away + if user and (not user.is_active): + log.debug("User is not active: %s", email) + return None + + # If user is not registered, and not self-registration, go away + if (not user) and (not self.auth_user_registration): + return None + + # Sync the user's roles + if user and self.auth_roles_sync_at_login: + user.roles = self._oauth_calculate_user_roles(userinfo) + log.debug("Calculated new roles for user='%s' as: %s", email, user.roles) + + # If the user is new, register them + if (not user) and self.auth_user_registration: + user = self.add_user( + username=userinfo["username"], + first_name=userinfo.get("first_name", ""), + last_name=userinfo.get("last_name", ""), + email=email, + role=self._oauth_calculate_user_roles(userinfo), + ) + log.debug("New user registered: %s", user) + + # If user registration failed, go away + if not user: + log.error("Error creating a new OAuth user %s", email) + return None + + # LOGIN SUCCESS (only if user is now registered) + if user: + self.update_user_auth_stat(user) + return user + else: + return None def sync_role_definitions(self): - """PlaidSecurityManager contructor. + """PlaidSecurityManager constructor. Establishes a Plaid role (and Public, if configured to do so) after invoking the super constructor. @@ -110,12 +196,12 @@ def get_rpc(self) -> SimpleRPC: base_url = f"http://{self.appbuilder.app.config.get('PLAID_RPC')}" rpc_url = urljoin(base_url, "json-rpc/") - if 'workspace' in session: - temp_token = f"{session['token']['access_token']}_ws{session['workspace']}" + if self.auth_type == AUTH_OAUTH: + rpc_token, secret = session['oauth'] else: - temp_token = session['token']['access_token'] + rpc_token = session['token']['access_token'] - rpc = SimpleRPC(session['token']['access_token'], uri=rpc_url, verify_ssl=False) + rpc = SimpleRPC(rpc_token, uri=rpc_url, verify_ssl=False) try: rpc.identity.me.scopes() # Just checking authentication diff --git a/requirements/base.txt b/requirements/base.txt index d81bb06b088a5..94b8ae89348f9 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -25,7 +25,7 @@ attrs==23.2.0 # referencing # requests-cache # trio -authlib==0.13 +authlib==1.3.0 # via apache-superset babel==2.14.0 # via flask-babel diff --git a/setup.py b/setup.py index 3251ab1fe9cb1..6813368844f01 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ def get_git_sha() -> str: ], }, install_requires=[ - "Authlib==0.13", + "Authlib==1.3.0", "backoff>=1.8.0", "celery>=5.2.2, <6.0.0", "click>=8.0.3",