diff --git a/neon_hana/app/routers/auth.py b/neon_hana/app/routers/auth.py index 4cf78e2..efee45e 100644 --- a/neon_hana/app/routers/auth.py +++ b/neon_hana/app/routers/auth.py @@ -28,6 +28,7 @@ from neon_hana.app.dependencies import client_manager from neon_hana.schema.auth_requests import * +from neon_users_service.models import User auth_route = APIRouter(prefix="/auth", tags=["authentication"]) @@ -42,3 +43,8 @@ async def check_login(auth_request: AuthenticationRequest, @auth_route.post("/refresh") async def check_refresh(request: RefreshRequest) -> AuthenticationResponse: return client_manager.check_refresh_request(**dict(request)) + + +@auth_route.post("/register") +async def register_user(request: RegistrationRequest) -> User: + return client_manager.check_registration_request(**dict(request)) diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index 1e89529..89904b3 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -38,7 +38,14 @@ from neon_hana.auth.permissions import ClientPermissions from neon_hana.mq_service_api import MQServiceManager -from neon_users_service.models import User, AccessRoles, TokenConfig +from neon_users_service.models import User, AccessRoles, TokenConfig, NeonUserConfig, PermissionsConfig + +_DEFAULT_USER_PERMISSIONS = PermissionsConfig(klat=AccessRoles.USER, + core=AccessRoles.USER, + diana=AccessRoles.USER, + node=AccessRoles.USER, + hub=AccessRoles.USER, + llm=AccessRoles.USER) class ClientManager: @@ -69,7 +76,7 @@ def _create_tokens(self, encode_data: dict) -> TokenConfig: token_expiration = encode_data['expire'] token = jwt.encode(encode_data, self._access_secret, self._jwt_algo) - encode_data['expire'] = time() + self._refresh_token_lifetime + encode_data['expire'] = round(time()) + self._refresh_token_lifetime encode_data['access_token'] = token refresh = jwt.encode(encode_data, self._refresh_secret, self._jwt_algo) return TokenConfig(**{"username": encode_data['username'], @@ -78,8 +85,11 @@ def _create_tokens(self, encode_data: dict) -> TokenConfig: "access_token": token, "refresh_token": refresh, "expiration": token_expiration, + "refresh_expiration": encode_data['expire'], "token_name": encode_data['name'], - "refresh_expiration": encode_data['expire']}) + "creation_timestamp": encode_data['create'], + "last_refresh_timestamp": encode_data['last_refresh_timestamp'] + }) def get_permissions(self, client_id: str) -> ClientPermissions: """ @@ -116,6 +126,15 @@ def disconnect_stream(self): with self._stream_check_lock: self._connected_streams -= 1 + def check_registration_request(self, username: str, password: str, + user_config: NeonUserConfig) -> User: + """ + Handle a request to register a new user. + """ + new_user = User(username=username, password_hash=password, + neon=user_config, permissions=_DEFAULT_USER_PERMISSIONS) + return self._mq_connector.create_user(new_user) + def check_auth_request(self, client_id: str, username: str, password: Optional[str] = None, token_name: Optional[str] = None, @@ -158,7 +177,6 @@ def check_auth_request(self, client_id: str, username: str, else: user = self._mq_connector.get_user_profile(username, password) username = user.username - password = user.password_hash # Boolean permissions allow access for any role, including `NODE`. # Specific endpoints may enforce more granular controls/limits based on @@ -167,13 +185,12 @@ def check_auth_request(self, client_id: str, username: str, node=user.permissions.node != AccessRoles.NONE, assist=user.permissions.core != AccessRoles.NONE, backend=user.permissions.diana != AccessRoles.NONE) - create_time = time() + create_time = round(time()) expiration = create_time + self._access_token_lifetime encode_data = {"client_id": client_id, "sub": username, # Added for Klat token compat. "name": token_name, "username": username, - "password": password, "permissions": permissions.as_dict(), "create": create_time, "expire": expiration, @@ -208,12 +225,17 @@ def check_refresh_request(self, access_token: str, refresh_token: str, detail="Access token does not match client_id") encode_data = {k: token_data[k] for k in ("client_id", "username", "password")} - refresh_time = time() + + user = self._mq_connector.get_user_profile(username=token_data['username'], + access_token=refresh_token) + if not user.password_hash: + # This should not be possible, but don't let an error in the + # users service allow for injecting a new valid token to the db + raise HTTPException(status_code=500, detail="Error Fetching User") + refresh_time = round(time()) encode_data['last_refresh_timestamp'] = refresh_time encode_data["expire"] = refresh_time + self._access_token_lifetime new_auth = self._create_tokens(encode_data) - user = self._mq_connector.get_user_profile(username=token_data['username'], - password=token_data['password']) self._add_token_to_userdb(user, new_auth) return new_auth.model_dump() diff --git a/neon_hana/mq_service_api.py b/neon_hana/mq_service_api.py index 3dfcf93..69c1349 100644 --- a/neon_hana/mq_service_api.py +++ b/neon_hana/mq_service_api.py @@ -79,26 +79,30 @@ def _validate_api_proxy_response(response: dict, query_params: dict): raise APIError(status_code=code, detail=response['content']) @staticmethod - def _query_users_api(operation: str, username: Optional[str] = None, + def _query_users_api(operation: str, username: str, password: Optional[str] = None, - user: Optional[User] = None) -> (bool, Union[User, int, str]): + access_token: Optional[str] = None, + user: Optional[User] = None) -> (bool, int, Union[User, str]): """ Query the users API and return a status code and either a valid User or - a string error message + a string error message. Authentication may use EITHER a password or + a token. @param operation: Operation to perform (create, read, update, delete) @param username: Optional username to include @param password: Optional password to include + @param access_token: Optional auth token to include @param user: Optional user object to include - @return: success bool, User object or string error message + @return: success bool, HTTP status code User object or string error message """ response = send_mq_request("/neon_users", {"operation": operation, "username": username, "password": password, - "user": user}, + "access_token": access_token, + "user": user.model_dump() if user else None}, "neon_users_input") if response.get("success"): - return True, 200, response.get("user") + return True, 200, User(**response.get("user")) return False, response.get("code", 500), response.get("error", "") def get_session(self, node_data: NodeData) -> dict: @@ -113,17 +117,21 @@ def get_session(self, node_data: NodeData) -> dict: "site_id": node_data.location.site_id}) return self.sessions_by_id[session_id] - def get_user_profile(self, username: str, password: str) -> User: + def get_user_profile(self, username: str, password: Optional[str] = None, + access_token: Optional[str] = None) -> User: """ - Get a User object for a user. This requires that a valid password be - provided to prevent arbitrary users from reading private profile info. + Get a User object for a user. This requires that a valid password OR + access token be provided to prevent arbitrary users from reading + private profile info. @param username: Valid username to get a User object for - @param password: Valid password for the input username + @param password: Valid password to use for authentication + @param access_token: Valid access token to use for authentication @returns: User object from the Users service. """ stat, code, err_or_user = self._query_users_api("read", username=username, - password=password) + password=password, + access_token=access_token) if not stat: raise HTTPException(status_code=code, detail=err_or_user) return err_or_user @@ -134,7 +142,10 @@ def create_user(self, user: User) -> User: @param user: User object to add to the users service database @returns: User object added to the database """ - stat, code, err_or_user = self._query_users_api("create", user=user) + stat, code, err_or_user = self._query_users_api("create", + username=user.username, + password=user.password_hash, + user=user) if not stat: raise HTTPException(status_code=code, detail=err_or_user) return err_or_user @@ -145,7 +156,10 @@ def update_user(self, user: User) -> User: @param user: Updated user object to write @returns: User as read from the database """ - stat, code, err_or_user = self._query_users_api("update", user=user) + stat, code, err_or_user = self._query_users_api("update", + username=user.username, + password=user.password_hash, + user=user) if not stat: raise HTTPException(status_code=code, detail=err_or_user) return err_or_user diff --git a/neon_hana/schema/auth_requests.py b/neon_hana/schema/auth_requests.py index c889639..46463ee 100644 --- a/neon_hana/schema/auth_requests.py +++ b/neon_hana/schema/auth_requests.py @@ -30,6 +30,8 @@ from pydantic import BaseModel, Field +from neon_users_service.models import NeonUserConfig + class AuthenticationRequest(BaseModel): username: str = "guest" @@ -68,3 +70,19 @@ class RefreshRequest(BaseModel): access_token: str refresh_token: str client_id: str + + +class RegistrationRequest(BaseModel): + username: str + password: str + user_config: NeonUserConfig = NeonUserConfig() + + model_config = { + "json_schema_extra": { + "examples": [{ + "username": "guest", + "password": "password", + "user_config": NeonUserConfig().model_dump() + }, {"username": "guest", + "password": "password"} + ]}}