diff --git a/plaid/security.py b/plaid/security.py index 3c140fce3fb5f..2f46b1f0879ce 100644 --- a/plaid/security.py +++ b/plaid/security.py @@ -15,6 +15,8 @@ from flask_appbuilder import Model from flask_appbuilder.security.manager import AUTH_OID, AUTH_OAUTH from authlib.integrations.flask_client import OAuth +# from authlib.integrations.flask_client.apps import FlaskOAuth2App +from authlib.integrations.flask_client import token_update from requests.exceptions import HTTPError from plaidcloud.rpc.connection.jsonrpc import SimpleRPC @@ -32,6 +34,16 @@ log = logging.getLogger(__name__) +# class PlaidFlaskOAuth2App(FlaskOAuth2App): +def ensure_active_token(self): + with self._get_oauth_client() as client: + return client.ensure_active_token() + +# OAuth.ensure_active_token = ensure_active_token + +# class PlaidOAuth(OAuth): +# oauth2_client_cls = PlaidFlaskOAuth2App + def get_project_role_name(project_id: str) -> str: """Fetch the datasource role name by project ID. @@ -64,9 +76,27 @@ def __init__(self, appbuilder): client_kwargs=self.oidc_params['client_kwargs'], ) self.authoidview = AuthOIDCView + if self.auth_type == AUTH_OAUTH: self.authoauthview = PlaidAuthOAuthView + @token_update.connect_via(appbuilder) + def on_token_update(sender, name, token, refresh_token=None, access_token=None): + # if refresh_token: + # item = OAuth2Token.find(name=name, refresh_token=refresh_token) + # elif access_token: + # item = OAuth2Token.find(name=name, access_token=access_token) + # else: + # return + # + # # update old token + # item.access_token = token['access_token'] + # item.refresh_token = token.get('refresh_token') + # item.expires_at = token['expires_at'] + # item.save() + log.info(f'Updated token for {name} - {repr(token)}') + self.appbuilder.sm.set_oauth_session(name, token) + def oauth_user_info(self, provider, response=None): # logging.debug("Oauth2 provider: {0}.".format(provider)) if provider == 'plaid-keycloak': @@ -370,3 +400,55 @@ def add_user_to_project(self, user, project_id): log.debug( "Appended %s to %s roles list.", role.name, user.username ) + + def has_access(self, permission_name: str, view_name: str) -> bool: + # check token expiry and logout + if self.auth_type == AUTH_OAUTH: + if 'oauth' not in session: + return False + token, secret = session['oauth'] + provider = session["oauth_provider"] + # if not self.oauth.plaidkeycloak.ensure_active_token(token): + # if not self.oauth[provider].ensure_active_token(token): + # if not self.appbuilder.sm.oauth_remotes[provider].ensure_active_token(token): + remote = self.appbuilder.sm.oauth_remotes[provider] + with remote._get_oauth_client() as client: + return client.ensure_active_token() + # return False + # logout and redirect + elif self.auth_type == AUTH_OID: + if 'token' not in session: + return False + token = session['token'] + if not self.oauth.plaid.ensure_active_token(token): + return False + + + # session["oauth"] = ( + # oauth_response[token_key], + # oauth_response.get(token_secret, ""), + # ) + # session["oauth_provider"] = provider + + + return super().has_access(permission_name, view_name) + + + +# @token_update.connect_via(app) +# def on_token_update(sender, name, token, refresh_token=None, access_token=None): +# if refresh_token: +# item = OAuth2Token.find(name=name, refresh_token=refresh_token) +# elif access_token: +# item = OAuth2Token.find(name=name, access_token=access_token) +# else: +# return +# +# # update old token +# item.access_token = token['access_token'] +# item.refresh_token = token.get('refresh_token') +# item.expires_at = token['expires_at'] +# item.save() +# +# set_oauth_session(name, token) +