Skip to content

Commit

Permalink
Update client_manager to use mq_connector for authentication via …
Browse files Browse the repository at this point in the history
…`neon-users-service`

Update tokens to include more data, maintaining backwards-compat and adding `TokenConfig` compat.
Update tokens for Klat token compat
Update permissions handling to respect user configuration values
Update auth request to include token_name for User database integration
Add UserProfile.from_user_config for database compat.
Update MQ connector to integrate with users service
  • Loading branch information
NeonDaniel committed Oct 30, 2024
1 parent da161ee commit dc1abd2
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 28 deletions.
2 changes: 1 addition & 1 deletion neon_hana/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@

config = Configuration().get("hana") or dict()
mq_connector = MQServiceManager(config)
client_manager = ClientManager(config)
client_manager = ClientManager(config, mq_connector)
jwt_bearer = UserTokenAuth(client_manager)
82 changes: 59 additions & 23 deletions neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@
from token_throttler.storage import RuntimeStorage

from neon_hana.auth.permissions import ClientPermissions
from neon_hana.mq_service_api import MQServiceManager
from neon_users_service.models import User, AccessRoles, TokenConfig


class ClientManager:
def __init__(self, config: dict):
def __init__(self, config: dict, mq_connector: MQServiceManager):
self.rate_limiter = TokenThrottler(cost=1, storage=RuntimeStorage())

self.authorized_clients: Dict[str, dict] = dict()
Expand All @@ -58,8 +60,9 @@ def __init__(self, config: dict):
self._jwt_algo = "HS256"
self._connected_streams = 0
self._stream_check_lock = Lock()
self._mq_connector = mq_connector

def _create_tokens(self, encode_data: dict) -> dict:
def _create_tokens(self, encode_data: dict) -> TokenConfig:
# Permissions were not included in old tokens, allow refreshing with
# default permissions
encode_data.setdefault("permissions", ClientPermissions().as_dict())
Expand All @@ -69,13 +72,14 @@ def _create_tokens(self, encode_data: dict) -> dict:
encode_data['expire'] = time() + self._refresh_token_lifetime
encode_data['access_token'] = token
refresh = jwt.encode(encode_data, self._refresh_secret, self._jwt_algo)
# TODO: Store refresh token on server to allow invalidating clients
return {"username": encode_data['username'],
"client_id": encode_data['client_id'],
"permissions": encode_data['permissions'],
"access_token": token,
"refresh_token": refresh,
"expiration": token_expiration}
return TokenConfig(**{"username": encode_data['username'],
"client_id": encode_data['client_id'],
"permissions": encode_data['permissions'],
"access_token": token,
"refresh_token": refresh,
"expiration": token_expiration,
"token_name": encode_data['name'],
"refresh_expiration": encode_data['expire']})

def get_permissions(self, client_id: str) -> ClientPermissions:
"""
Expand Down Expand Up @@ -114,13 +118,15 @@ def disconnect_stream(self):

def check_auth_request(self, client_id: str, username: str,
password: Optional[str] = None,
token_name: Optional[str] = None,
origin_ip: str = "127.0.0.1") -> dict:
"""
Authenticate and Authorize a new client connection with the specified
username, password, and origin IP address.
@param client_id: Client ID of the connection to auth
@param username: Supplied username to authenticate
@param password: Supplied password to authenticate
@param token_name: Token name to add to user database
@param origin_ip: Origin IP address of request
@return: response tokens, permissions, and other metadata
"""
Expand All @@ -142,23 +148,40 @@ def check_auth_request(self, client_id: str, username: str,
detail=f"Too many auth requests from: "
f"{origin_ip}. Wait {wait_time}s.")

node_access = False
if username != "guest":
# TODO: Validate password here
pass
if all((self._node_username, username == self._node_username,
password == self._node_password)):
node_access = True
permissions = ClientPermissions(node=node_access)
expiration = time() + self._access_token_lifetime
# TODO: disable "guest" access?
if username == "guest":
user = User(username=username, password=password)
elif all((self._node_username, username == self._node_username,
password == self._node_password)):
user = User(username=username, password=password)
user.permissions.node = AccessRoles.USER
else:
user = self._mq_connector.get_user_profile(username, password)
username = user.username
password = user.password_hash

# Boolean permissions allow access for any role, including `NODE`.
# Specific endpoints may enforce more granular controls/limits based on
# specific user.permissions values.
permissions = ClientPermissions(
node=user.permissions.node != AccessRoles.NONE,
assist=user.permissions.core != AccessRoles.NONE,
backend=user.permissions.diana != AccessRoles.NONE)
create_time = time()
expiration = create_time + self._access_token_lifetime
encode_data = {"client_id": client_id,
"sub": username, # Added for Klat token compat.
"name": token_name,
"username": username,
"password": password,
"permissions": permissions.as_dict(),
"expire": expiration}
"create": create_time,
"expire": expiration,
"last_refresh_timestamp": create_time}
auth = self._create_tokens(encode_data)
self.authorized_clients[client_id] = auth
return auth
self._add_token_to_userdb(user, auth)
self.authorized_clients[client_id] = auth.model_dump()
return auth.model_dump()

def check_refresh_request(self, access_token: str, refresh_token: str,
client_id: str):
Expand All @@ -185,9 +208,22 @@ def check_refresh_request(self, access_token: str, refresh_token: str,
detail="Access token does not match client_id")
encode_data = {k: token_data[k] for k in
("client_id", "username", "password")}
encode_data["expire"] = time() + self._access_token_lifetime
refresh_time = time()
encode_data['last_refresh_timestamp'] = refresh_time
encode_data["expire"] = refresh_time + self._access_token_lifetime
new_auth = self._create_tokens(encode_data)
return new_auth
user = self._mq_connector.get_user_profile(username=token_data['username'],
password=token_data['password'])
self._add_token_to_userdb(user, new_auth)
return new_auth.model_dump()

def _add_token_to_userdb(self, user: User, token_data: TokenConfig):
# Enforce unique `creation_timestamp` values to avoid duplicate entries
for idx, token in enumerate(user.tokens):
if token.creation_timestamp == token_data.creation_timestamp:
user.tokens.remove(token)
user.tokens.append(token_data)
self._mq_connector.update_user(user)

def get_client_id(self, token: str) -> str:
"""
Expand Down
63 changes: 62 additions & 1 deletion neon_hana/mq_service_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
import json

from time import time
from typing import Optional, Dict, Any, List
from typing import Optional, Dict, Any, List, Union
from uuid import uuid4
from fastapi import HTTPException

from neon_hana.schema.node_model import NodeData
from neon_hana.schema.user_profile import UserProfile
from neon_mq_connector.utils.client_utils import send_mq_request
from neon_users_service.models import User


class APIError(HTTPException):
Expand Down Expand Up @@ -77,6 +78,29 @@ def _validate_api_proxy_response(response: dict, query_params: dict):
code = response['status_code'] if response['status_code'] > 200 else 500
raise APIError(status_code=code, detail=response['content'])

@staticmethod
def _query_users_api(operation: str, username: Optional[str] = None,
password: Optional[str] = None,
user: Optional[User] = None) -> (bool, Union[User, int, str]):
"""
Query the users API and return a status code and either a valid User or
a string error message
@param operation: Operation to perform (create, read, update, delete)
@param username: Optional username to include
@param password: Optional password to include
@param user: Optional user object to include
@return: success bool, User object or string error message
"""
response = send_mq_request("/neon_users",
{"operation": operation,
"username": username,
"password": password,
"user": user},
"neon_users_input")
if response.get("success"):
return True, 200, response.get("user")
return False, response.get("code", 500), response.get("error", "")

def get_session(self, node_data: NodeData) -> dict:
"""
Get a serialized Session object for the specified Node.
Expand All @@ -89,6 +113,43 @@ def get_session(self, node_data: NodeData) -> dict:
"site_id": node_data.location.site_id})
return self.sessions_by_id[session_id]

def get_user_profile(self, username: str, password: str) -> User:
"""
Get a User object for a user. This requires that a valid password be
provided to prevent arbitrary users from reading private profile info.
@param username: Valid username to get a User object for
@param password: Valid password for the input username
@returns: User object from the Users service.
"""
stat, code, err_or_user = self._query_users_api("read",
username=username,
password=password)
if not stat:
raise HTTPException(status_code=code, detail=err_or_user)
return err_or_user

def create_user(self, user: User) -> User:
"""
Create a new user.
@param user: User object to add to the users service database
@returns: User object added to the database
"""
stat, code, err_or_user = self._query_users_api("create", user=user)
if not stat:
raise HTTPException(status_code=code, detail=err_or_user)
return err_or_user

def update_user(self, user: User) -> User:
"""
Update an existing user in the database.
@param user: Updated user object to write
@returns: User as read from the database
"""
stat, code, err_or_user = self._query_users_api("update", user=user)
if not stat:
raise HTTPException(status_code=code, detail=err_or_user)
return err_or_user

def query_api_proxy(self, service_name: str, query_params: dict,
timeout: int = 10):
query_params['service'] = service_name
Expand Down
2 changes: 1 addition & 1 deletion neon_hana/mq_websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from neon_iris.client import NeonAIClient
from ovos_bus_client.message import Message
from threading import RLock
from ovos_utils import LOG
from ovos_utils.log import LOG


class ClientNotKnown(RuntimeError):
Expand Down
5 changes: 4 additions & 1 deletion neon_hana/schema/auth_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from datetime import datetime
from typing import Optional
from uuid import uuid4

Expand All @@ -33,13 +34,15 @@
class AuthenticationRequest(BaseModel):
username: str = "guest"
password: Optional[str] = None
token_name: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
client_id: str = Field(default_factory=lambda: str(uuid4()))

model_config = {
"json_schema_extra": {
"examples": [{
"username": "guest",
"password": "password"
"password": "password",
"token_name": "My Client"
}]}}


Expand Down
62 changes: 62 additions & 0 deletions neon_hana/schema/user_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pytz
import datetime

from typing import Optional, List
from pydantic import BaseModel

from neon_users_service.models import User


class ProfileUser(BaseModel):
first_name: str = ""
Expand Down Expand Up @@ -102,3 +107,60 @@ class UserProfile(BaseModel):
location: ProfileLocation = ProfileLocation()
response_mode: ProfileResponseMode = ProfileResponseMode()
privacy: ProfilePrivacy = ProfilePrivacy()

@classmethod
def from_user_config(cls, user: User):
user_config = user.neon
today = datetime.date.today()
if user_config.user.dob:
dob = user_config.user.dob
age = today.year - dob.year - (
(today.month, today.day) < (dob.month, dob.day))
dob = dob.strftime("%Y/%m/%d")
else:
age = ""
dob = "YYYY/MM/DD"
full_name = " ".join((n for n in (user_config.user.first_name,
user_config.user.middle_name,
user_config.user.last_name) if n))
user = ProfileUser(about=user_config.user.about,
age=age, dob=dob,
email=user_config.user.email,
email_verified=user_config.user.email_verified,
first_name=user_config.user.first_name,
full_name=full_name,
last_name=user_config.user.last_name,
middle_name=user_config.user.middle_name,
password=user.password_hash or "",
phone=user_config.user.phone,
phone_verified=user_config.user.phone_verified,
picture=user_config.user.avatar_url,
preferred_name=user_config.user.preferred_name,
username=user.username
)
alt_stt = [lang.split('-')[0] for lang in
user_config.language.input_languages[1:]]
secondary_tts_lang = user_config.language.output_languages[1] if (
len(user_config.language.output_languages) > 1) else None
speech = ProfileSpeech(
alt_langs=alt_stt,
secondary_tts_gender=user_config.response_mode.tts_gender,
secondary_tts_language=secondary_tts_lang,
speed_multiplier=user_config.response_mode.tts_speed_multiplier,
stt_language=user_config.language.input_languages[0].split('-')[0],
tts_gender=user_config.response_mode.tts_gender,
tts_language=user_config.language.output_languages[0])
units = ProfileUnits(**user_config.units.model_dump())
utc_hours = (pytz.timezone(user_config.location.timezone or "UTC")
.utcoffset(datetime.datetime.now()).total_seconds() / 3600)
# TODO: Get city, state, country from lat/lon
location = ProfileLocation(lat=user_config.location.latitude,
lng=user_config.location.longitude,
tz=user_config.location.timezone,
utc=utc_hours)
response_mode = ProfileResponseMode(**user_config.response_mode.model_dump())
privacy = ProfilePrivacy(**user_config.privacy.model_dump())

return UserProfile(location=location, privacy=privacy,
response_mode=response_mode, speech=speech,
units=units, user=user)
4 changes: 3 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ pydantic~=2.5
pyjwt~=2.8
token-throttler~=1.4
neon-mq-connector~=0.7
ovos-config~=0.0.12
ovos-config~=0.0,>=0.0.12
ovos-utils~=0.0,>=0.0.38
neon-users-service@git+https://github.com/neongeckocom/neon-users-service@FEAT_InitialImplementation
18 changes: 18 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from unittest import TestCase

from neon_hana.schema.user_profile import UserProfile

from neon_users_service.models import User


class TestUserProfile(TestCase):
def test_user_profile(self):
# Test default
profile = UserProfile()
self.assertIsInstance(profile, UserProfile)

# Test from User
default_user = User(username="test_user")
profile = UserProfile.from_user_config(default_user)
self.assertIsInstance(profile, UserProfile)
self.assertEqual(default_user.username, profile.user.username)

0 comments on commit dc1abd2

Please sign in to comment.