Skip to content

Commit

Permalink
Try to check token for all requests
Browse files Browse the repository at this point in the history
  • Loading branch information
rad-pat committed Jun 17, 2024
1 parent ddce60c commit 6c5fb20
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 2 deletions.
2 changes: 1 addition & 1 deletion plaid/auth_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,5 @@ class PlaidAuthOAuthView(AuthOAuthView):
@expose("/login/<provider>")
def login(self, provider=None):
if provider is None:
return super().login(provider='plaid-keycloak')
return super().login(provider='plaidkeycloak')
return super().login(provider=provider)
84 changes: 83 additions & 1 deletion plaid/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from flask_appbuilder import Model
from flask_appbuilder.security.manager import AUTH_OID, AUTH_OAUTH
from authlib.integrations.flask_client import OAuth
# from authlib.integrations.flask_client.apps import FlaskOAuth2App
from authlib.integrations.flask_client import token_update
from requests.exceptions import HTTPError

from plaidcloud.rpc.connection.jsonrpc import SimpleRPC
Expand All @@ -32,6 +34,16 @@

log = logging.getLogger(__name__)

# class PlaidFlaskOAuth2App(FlaskOAuth2App):
def ensure_active_token(self):
with self._get_oauth_client() as client:
return client.ensure_active_token()

# OAuth.ensure_active_token = ensure_active_token

# class PlaidOAuth(OAuth):
# oauth2_client_cls = PlaidFlaskOAuth2App


def get_project_role_name(project_id: str) -> str:
"""Fetch the datasource role name by project ID.
Expand Down Expand Up @@ -64,12 +76,30 @@ def __init__(self, appbuilder):
client_kwargs=self.oidc_params['client_kwargs'],
)
self.authoidview = AuthOIDCView

if self.auth_type == AUTH_OAUTH:
self.authoauthview = PlaidAuthOAuthView

@token_update.connect_via(appbuilder)
def on_token_update(sender, name, token, refresh_token=None, access_token=None):
# if refresh_token:
# item = OAuth2Token.find(name=name, refresh_token=refresh_token)
# elif access_token:
# item = OAuth2Token.find(name=name, access_token=access_token)
# else:
# return
#
# # update old token
# item.access_token = token['access_token']
# item.refresh_token = token.get('refresh_token')
# item.expires_at = token['expires_at']
# item.save()
log.info(f'Updated token for {name} - {repr(token)}')
self.appbuilder.sm.set_oauth_session(name, token)

def oauth_user_info(self, provider, response=None):
# logging.debug("Oauth2 provider: {0}.".format(provider))
if provider == 'plaid-keycloak':
if provider == 'plaidkeycloak':
me = self.appbuilder.sm.oauth_remotes[provider].get("userinfo")
me.raise_for_status()
data = me.json()
Expand Down Expand Up @@ -370,3 +400,55 @@ def add_user_to_project(self, user, project_id):
log.debug(
"Appended %s to %s roles list.", role.name, user.username
)

def has_access(self, permission_name: str, view_name: str) -> bool:
# check token expiry and logout
if self.auth_type == AUTH_OAUTH:
if 'oauth' not in session:
return False
token, secret = session['oauth']
provider = session["oauth_provider"]
# if not self.oauth.plaidkeycloak.ensure_active_token(token):
# if not self.oauth[provider].ensure_active_token(token):
# if not self.appbuilder.sm.oauth_remotes[provider].ensure_active_token(token):
# remote = self.appbuilder.sm.oauth_remotes[provider]
with self.oauth.plaidkeycloak._get_oauth_client() as client:
return client.ensure_active_token(token)
# return False
# logout and redirect
elif self.auth_type == AUTH_OID:
if 'token' not in session:
return False
token = session['token']
if not self.oauth.plaid.ensure_active_token(token):
return False


# session["oauth"] = (
# oauth_response[token_key],
# oauth_response.get(token_secret, ""),
# )
# session["oauth_provider"] = provider


return super().has_access(permission_name, view_name)



# @token_update.connect_via(app)
# def on_token_update(sender, name, token, refresh_token=None, access_token=None):
# if refresh_token:
# item = OAuth2Token.find(name=name, refresh_token=refresh_token)
# elif access_token:
# item = OAuth2Token.find(name=name, access_token=access_token)
# else:
# return
#
# # update old token
# item.access_token = token['access_token']
# item.refresh_token = token.get('refresh_token')
# item.expires_at = token['expires_at']
# item.save()
#
# set_oauth_session(name, token)

0 comments on commit 6c5fb20

Please sign in to comment.