From 3ad911a20620a67b6a019e86da815e2a25047de7 Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Mon, 23 Sep 2024 12:53:20 -0500 Subject: [PATCH 1/9] http-api: T6736: update for deprecated/renamed in Pydantic V2 --- src/services/vyos-http-api-server | 37 ++++++++++++++++--------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server index 91100410c8..42c4cf387a 100755 --- a/src/services/vyos-http-api-server +++ b/src/services/vyos-http-api-server @@ -33,7 +33,7 @@ from fastapi import BackgroundTasks from fastapi.responses import HTMLResponse from fastapi.exceptions import RequestValidationError from fastapi.routing import APIRoute -from pydantic import BaseModel, StrictStr, validator, model_validator +from pydantic import BaseModel, StrictStr, field_validator, model_validator from starlette.middleware.cors import CORSMiddleware from starlette.datastructures import FormData from starlette.formparsers import FormParser, MultiPartParser @@ -91,9 +91,9 @@ def success(data): return HTMLResponse(resp) # Pydantic models for validation -# Pydantic will cast when possible, so use StrictStr -# validators added as needed for additional constraints -# schema_extra adds anotations to OpenAPI, to add examples +# Pydantic will cast when possible, so use StrictStr validators added as +# needed for additional constraints +# json_schema_extra adds anotations to OpenAPI to add examples class ApiModel(BaseModel): key: StrictStr @@ -102,8 +102,9 @@ class BasePathModel(BaseModel): op: StrictStr path: List[StrictStr] - @validator("path") - def check_non_empty(cls, path): + @field_validator("path") + @classmethod + def check_non_empty(cls, path: str) -> str: if not len(path) > 0: raise ValueError('path must be non-empty') return path @@ -113,7 +114,7 @@ class BaseConfigureModel(BasePathModel): class ConfigureModel(ApiModel, BaseConfigureModel): class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "set | delete | comment", @@ -125,7 +126,7 @@ class ConfigureListModel(ApiModel): commands: List[BaseConfigureModel] class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "commands": "list of commands", @@ -155,7 +156,7 @@ class RetrieveModel(ApiModel): configFormat: StrictStr = None class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "returnValue | returnValues | exists | showConfig", @@ -170,7 +171,7 @@ class ConfigFileModel(ApiModel): file: StrictStr = None class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "save | load", @@ -203,7 +204,7 @@ class ImageModel(ApiModel): return self class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "add | delete | show | set_default", @@ -218,7 +219,7 @@ class ImportPkiModel(ApiModel): passphrase: StrictStr = None class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "import_pki", @@ -233,7 +234,7 @@ class ContainerImageModel(ApiModel): name: StrictStr = None class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "add | delete | show", @@ -246,7 +247,7 @@ class GenerateModel(ApiModel): path: List[StrictStr] class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "generate", @@ -259,7 +260,7 @@ class ShowModel(ApiModel): path: List[StrictStr] class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "show", @@ -272,7 +273,7 @@ class RebootModel(ApiModel): path: List[StrictStr] class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "reboot", @@ -285,7 +286,7 @@ class ResetModel(ApiModel): path: List[StrictStr] class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "reset", @@ -298,7 +299,7 @@ class PoweroffModel(ApiModel): path: List[StrictStr] class Config: - schema_extra = { + json_schema_extra = { "example": { "key": "id_key", "op": "poweroff", From fc9885f859617bab36c971f4eaa56240741f52c4 Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Tue, 24 Sep 2024 22:48:25 -0500 Subject: [PATCH 2/9] http-api: T6736: separate REST API and GraphQL API activation The GraphQL API was implemented as an addition to the existing REST API. As there is no necessary dependency, separate the initialization of the respective endpoints. Factor out the REST Pydantic models and FastAPI routes for symmetry and clarity. --- src/services/api/__init__.py | 0 src/services/api/graphql/bindings.py | 21 +- .../graphql/graphql/auth_token_mutation.py | 8 +- src/services/api/graphql/graphql/mutations.py | 13 +- src/services/api/graphql/graphql/queries.py | 13 +- src/services/api/graphql/libs/__init__.py | 0 src/services/api/graphql/libs/key_auth.py | 22 +- src/services/api/graphql/libs/token_auth.py | 41 +- src/services/api/graphql/routers.py | 54 + src/services/api/graphql/session/session.py | 6 +- src/services/api/graphql/state.py | 4 - src/services/api/rest/__init__.py | 0 src/services/api/rest/models.py | 276 +++++ src/services/api/rest/routers.py | 685 ++++++++++++ src/services/api/session.py | 40 + src/services/vyos-http-api-server | 972 ++---------------- 16 files changed, 1206 insertions(+), 949 deletions(-) create mode 100644 src/services/api/__init__.py create mode 100644 src/services/api/graphql/libs/__init__.py create mode 100644 src/services/api/graphql/routers.py delete mode 100644 src/services/api/graphql/state.py create mode 100644 src/services/api/rest/__init__.py create mode 100644 src/services/api/rest/models.py create mode 100644 src/services/api/rest/routers.py create mode 100644 src/services/api/session.py diff --git a/src/services/api/__init__.py b/src/services/api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/services/api/graphql/bindings.py b/src/services/api/graphql/bindings.py index ef49664662..93dd0fbfb5 100644 --- a/src/services/api/graphql/bindings.py +++ b/src/services/api/graphql/bindings.py @@ -1,4 +1,4 @@ -# Copyright 2021 VyOS maintainers and contributors +# Copyright 2021-2024 VyOS maintainers and contributors # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -13,24 +13,35 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . + import vyos.defaults + +from ariadne import make_executable_schema +from ariadne import load_schema_from_path +from ariadne import snake_case_fallback_resolvers + from . graphql.queries import query from . graphql.mutations import mutation from . graphql.directives import directives_dict from . graphql.errors import op_mode_error from . graphql.auth_token_mutation import auth_token_mutation from . libs.token_auth import init_secret -from . import state -from ariadne import make_executable_schema, load_schema_from_path, snake_case_fallback_resolvers + +from .. session import SessionState + def generate_schema(): + state = SessionState() api_schema_dir = vyos.defaults.directories['api_schema'] - if state.settings['app'].state.vyos_auth_type == 'token': + if state.auth_type == 'token': init_secret() type_defs = load_schema_from_path(api_schema_dir) - schema = make_executable_schema(type_defs, query, op_mode_error, mutation, auth_token_mutation, snake_case_fallback_resolvers, directives=directives_dict) + schema = make_executable_schema(type_defs, query, op_mode_error, + mutation, auth_token_mutation, + snake_case_fallback_resolvers, + directives=directives_dict) return schema diff --git a/src/services/api/graphql/graphql/auth_token_mutation.py b/src/services/api/graphql/graphql/auth_token_mutation.py index a53fa4d60a..1649602174 100644 --- a/src/services/api/graphql/graphql/auth_token_mutation.py +++ b/src/services/api/graphql/graphql/auth_token_mutation.py @@ -21,7 +21,7 @@ from .. libs.token_auth import generate_token from .. session.session import get_user_info -from .. import state +from ... session import SessionState auth_token_mutation = ObjectType("Mutation") @@ -31,8 +31,10 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): user = data['username'] passwd = data['password'] - secret = state.settings['secret'] - exp_interval = int(state.settings['app'].state.vyos_token_exp) + state = SessionState() + + secret = getattr(state, 'secret', '') + exp_interval = int(state.token_exp) expiration = (datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta(seconds=exp_interval)) diff --git a/src/services/api/graphql/graphql/mutations.py b/src/services/api/graphql/graphql/mutations.py index d115a8e94b..62031ada3c 100644 --- a/src/services/api/graphql/graphql/mutations.py +++ b/src/services/api/graphql/graphql/mutations.py @@ -21,10 +21,10 @@ from typing import Any, Dict, Optional # pylint: disable=W0611 from graphql import GraphQLResolveInfo # pylint: disable=W0611 -from .. import state +from ... session import SessionState from .. libs import key_auth -from api.graphql.session.session import Session -from api.graphql.session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code +from .. session.session import Session +from .. session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code from vyos.opmode import Error as OpModeError mutation = ObjectType("Mutation") @@ -45,12 +45,13 @@ def make_mutation_resolver(mutation_name, class_name, session_func): func_base_name = convert_camel_case_to_snake(class_name) resolver_name = f'resolve_{func_base_name}' func_sig = '(obj: Any, info: GraphQLResolveInfo, data: Optional[Dict]=None)' + state = SessionState() @mutation.field(mutation_name) @with_signature(func_sig, func_name=resolver_name) async def func_impl(*args, **kwargs): try: - auth_type = state.settings['app'].state.vyos_auth_type + auth_type = state.auth_type if auth_type == 'key': data = kwargs['data'] @@ -86,11 +87,11 @@ async def func_impl(*args, **kwargs): } else: # AtrributeError will have already been raised if no - # vyos_auth_type; validation and defaultValue ensure it is + # auth_type; validation and defaultValue ensure it is # one of the previous cases, so this is never reached. pass - session = state.settings['app'].state.vyos_session + session = state.session # one may override the session functions with a local subclass try: diff --git a/src/services/api/graphql/graphql/queries.py b/src/services/api/graphql/graphql/queries.py index 717098259a..1e9036574b 100644 --- a/src/services/api/graphql/graphql/queries.py +++ b/src/services/api/graphql/graphql/queries.py @@ -21,10 +21,10 @@ from typing import Any, Dict, Optional # pylint: disable=W0611 from graphql import GraphQLResolveInfo # pylint: disable=W0611 -from .. import state +from ... session import SessionState from .. libs import key_auth -from api.graphql.session.session import Session -from api.graphql.session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code +from .. session.session import Session +from .. session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code from vyos.opmode import Error as OpModeError query = ObjectType("Query") @@ -45,12 +45,13 @@ def make_query_resolver(query_name, class_name, session_func): func_base_name = convert_camel_case_to_snake(class_name) resolver_name = f'resolve_{func_base_name}' func_sig = '(obj: Any, info: GraphQLResolveInfo, data: Optional[Dict]=None)' + state = SessionState() @query.field(query_name) @with_signature(func_sig, func_name=resolver_name) async def func_impl(*args, **kwargs): try: - auth_type = state.settings['app'].state.vyos_auth_type + auth_type = state.auth_type if auth_type == 'key': data = kwargs['data'] @@ -86,11 +87,11 @@ async def func_impl(*args, **kwargs): } else: # AtrributeError will have already been raised if no - # vyos_auth_type; validation and defaultValue ensure it is + # auth_type; validation and defaultValue ensure it is # one of the previous cases, so this is never reached. pass - session = state.settings['app'].state.vyos_session + session = state.session # one may override the session functions with a local subclass try: diff --git a/src/services/api/graphql/libs/__init__.py b/src/services/api/graphql/libs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/services/api/graphql/libs/key_auth.py b/src/services/api/graphql/libs/key_auth.py index 2db0f7d48b..9e49a1203e 100644 --- a/src/services/api/graphql/libs/key_auth.py +++ b/src/services/api/graphql/libs/key_auth.py @@ -1,5 +1,20 @@ +# Copyright 2021-2024 VyOS maintainers and contributors +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . -from .. import state + +from ... session import SessionState def check_auth(key_list, key): if not key_list: @@ -11,8 +26,9 @@ def check_auth(key_list, key): return key_id def auth_required(key): + state = SessionState() api_keys = None - api_keys = state.settings['app'].state.vyos_keys + api_keys = state.keys key_id = check_auth(api_keys, key) - state.settings['app'].state.vyos_id = key_id + state.id = key_id return key_id diff --git a/src/services/api/graphql/libs/token_auth.py b/src/services/api/graphql/libs/token_auth.py index 8585485c96..2d772e0357 100644 --- a/src/services/api/graphql/libs/token_auth.py +++ b/src/services/api/graphql/libs/token_auth.py @@ -1,33 +1,52 @@ +# Copyright 2021-2024 VyOS maintainers and contributors +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . + + import jwt import uuid import pam from secrets import token_hex -from .. import state +from ... session import SessionState + def _check_passwd_pam(username: str, passwd: str) -> bool: if pam.authenticate(username, passwd): return True return False + def init_secret(): - length = int(state.settings['app'].state.vyos_secret_len) + state = SessionState() + length = int(state.secret_len) secret = token_hex(length) - state.settings['secret'] = secret + state.secret = secret + def generate_token(user: str, passwd: str, secret: str, exp: int) -> dict: if user is None or passwd is None: return {} + state = SessionState() if _check_passwd_pam(user, passwd): - app = state.settings['app'] try: - users = app.state.vyos_token_users + users = state.token_users except AttributeError: - app.state.vyos_token_users = {} - users = app.state.vyos_token_users + users = state.token_users = {} user_id = uuid.uuid1().hex payload_data = {'iss': user, 'sub': user_id, 'exp': exp} - secret = state.settings.get('secret') + secret = getattr(state, 'secret', None) if secret is None: return {"errors": ['missing secret']} token = jwt.encode(payload=payload_data, key=secret, algorithm="HS256") @@ -37,10 +56,12 @@ def generate_token(user: str, passwd: str, secret: str, exp: int) -> dict: else: return {"errors": ['failed pam authentication']} + def get_user_context(request): context = {} context['request'] = request context['user'] = None + state = SessionState() if 'Authorization' in request.headers: auth = request.headers['Authorization'] scheme, token = auth.split() @@ -48,7 +69,7 @@ def get_user_context(request): return context try: - secret = state.settings.get('secret') + secret = getattr(state, 'secret', None) payload = jwt.decode(token, secret, algorithms=["HS256"]) user_id: str = payload.get('sub') if user_id is None: @@ -59,7 +80,7 @@ def get_user_context(request): except jwt.PyJWTError: return context try: - users = state.settings['app'].state.vyos_token_users + users = state.token_users except AttributeError: return context diff --git a/src/services/api/graphql/routers.py b/src/services/api/graphql/routers.py new file mode 100644 index 0000000000..d04375a49e --- /dev/null +++ b/src/services/api/graphql/routers.py @@ -0,0 +1,54 @@ +# Copyright 2024 VyOS maintainers and contributors +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . + +# pylint: disable=import-outside-toplevel + + +import typing + +from ariadne.asgi import GraphQL +from starlette.middleware.cors import CORSMiddleware + + +if typing.TYPE_CHECKING: + from fastapi import FastAPI + + +def graphql_init(app: "FastAPI"): + from .. session import SessionState + from .libs.token_auth import get_user_context + + state = SessionState() + + # import after initializaion of state + from .bindings import generate_schema + schema = generate_schema() + + in_spec = state.introspection + + if state.origins: + origins = state.origins + app.add_route('/graphql', CORSMiddleware(GraphQL(schema, + context_value=get_user_context, + debug=True, + introspection=in_spec), + allow_origins=origins, + allow_methods=("GET", "POST", "OPTIONS"), + allow_headers=("Authorization",))) + else: + app.add_route('/graphql', GraphQL(schema, + context_value=get_user_context, + debug=True, + introspection=in_spec)) diff --git a/src/services/api/graphql/session/session.py b/src/services/api/graphql/session/session.py index 6ae44b9bfc..6e2875f3c2 100644 --- a/src/services/api/graphql/session/session.py +++ b/src/services/api/graphql/session/session.py @@ -138,7 +138,6 @@ def delete_system_image(self): return res def show_user_info(self): - session = self._session data = self._data user_info = {} @@ -151,10 +150,9 @@ def show_user_info(self): return user_info def system_status(self): - import api.graphql.session.composite.system_status as system_status + from api.graphql.session.composite import system_status session = self._session - data = self._data status = {} status['host_name'] = session.show(['host', 'name']).strip() @@ -165,7 +163,6 @@ def system_status(self): return status def gen_op_query(self): - session = self._session data = self._data name = self._name op_mode_list = self._op_mode_list @@ -189,7 +186,6 @@ def gen_op_query(self): return res def gen_op_mutation(self): - session = self._session data = self._data name = self._name op_mode_list = self._op_mode_list diff --git a/src/services/api/graphql/state.py b/src/services/api/graphql/state.py deleted file mode 100644 index 63db9f4ef5..0000000000 --- a/src/services/api/graphql/state.py +++ /dev/null @@ -1,4 +0,0 @@ - -def init(): - global settings - settings = {} diff --git a/src/services/api/rest/__init__.py b/src/services/api/rest/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/services/api/rest/models.py b/src/services/api/rest/models.py new file mode 100644 index 0000000000..034e3fcdb3 --- /dev/null +++ b/src/services/api/rest/models.py @@ -0,0 +1,276 @@ +# Copyright 2024 VyOS maintainers and contributors +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . + + +# pylint: disable=too-few-public-methods + +import json +from enum import Enum +from typing import List +from typing import Union +from typing import Dict +from typing import Self + +from pydantic import BaseModel +from pydantic import StrictStr +from pydantic import field_validator +from pydantic import model_validator +from fastapi.responses import HTMLResponse + + +def error(code, msg): + resp = {"success": False, "error": msg, "data": None} + resp = json.dumps(resp) + return HTMLResponse(resp, status_code=code) + +def success(data): + resp = {"success": True, "data": data, "error": None} + resp = json.dumps(resp) + return HTMLResponse(resp) + +# Pydantic models for validation +# Pydantic will cast when possible, so use StrictStr validators added as +# needed for additional constraints +# json_schema_extra adds anotations to OpenAPI to add examples + +class ApiModel(BaseModel): + key: StrictStr + +class BasePathModel(BaseModel): + op: StrictStr + path: List[StrictStr] + + @field_validator("path") + @classmethod + def check_non_empty(cls, path: str) -> str: + if not len(path) > 0: + raise ValueError('path must be non-empty') + return path + +class BaseConfigureModel(BasePathModel): + value: StrictStr = None + +class ConfigureModel(ApiModel, BaseConfigureModel): + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "set | delete | comment", + "path": ["config", "mode", "path"], + } + } + +class ConfigureListModel(ApiModel): + commands: List[BaseConfigureModel] + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "commands": "list of commands", + } + } + +class BaseConfigSectionModel(BasePathModel): + section: Dict + +class ConfigSectionModel(ApiModel, BaseConfigSectionModel): + pass + +class ConfigSectionListModel(ApiModel): + commands: List[BaseConfigSectionModel] + +class BaseConfigSectionTreeModel(BaseModel): + op: StrictStr + mask: Dict + config: Dict + +class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel): + pass + +class RetrieveModel(ApiModel): + op: StrictStr + path: List[StrictStr] + configFormat: StrictStr = None + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "returnValue | returnValues | exists | showConfig", + "path": ["config", "mode", "path"], + "configFormat": "json (default) | json_ast | raw", + + } + } + +class ConfigFileModel(ApiModel): + op: StrictStr + file: StrictStr = None + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "save | load", + "file": "filename", + } + } + + +class ImageOp(str, Enum): + add = "add" + delete = "delete" + show = "show" + set_default = "set_default" + + +class ImageModel(ApiModel): + op: ImageOp + url: StrictStr = None + name: StrictStr = None + + @model_validator(mode='after') + def check_data(self) -> Self: + if self.op == 'add': + if not self.url: + raise ValueError("Missing required field \"url\"") + elif self.op in ['delete', 'set_default']: + if not self.name: + raise ValueError("Missing required field \"name\"") + + return self + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "add | delete | show | set_default", + "url": "imagelocation", + "name": "imagename", + } + } + +class ImportPkiModel(ApiModel): + op: StrictStr + path: List[StrictStr] + passphrase: StrictStr = None + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "import_pki", + "path": ["op", "mode", "path"], + "passphrase": "passphrase", + } + } + + +class ContainerImageModel(ApiModel): + op: StrictStr + name: StrictStr = None + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "add | delete | show", + "name": "imagename", + } + } + +class GenerateModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "generate", + "path": ["op", "mode", "path"], + } + } + +class ShowModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "show", + "path": ["op", "mode", "path"], + } + } + +class RebootModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "reboot", + "path": ["op", "mode", "path"], + } + } + +class ResetModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "reset", + "path": ["op", "mode", "path"], + } + } + +class PoweroffModel(ApiModel): + op: StrictStr + path: List[StrictStr] + + class Config: + json_schema_extra = { + "example": { + "key": "id_key", + "op": "poweroff", + "path": ["op", "mode", "path"], + } + } + + +class Success(BaseModel): + success: bool + data: Union[str, bool, Dict] + error: str + +class Error(BaseModel): + success: bool = False + data: Union[str, bool, Dict] + error: str + +responses = { + 200: {'model': Success}, + 400: {'model': Error}, + 422: {'model': Error, 'description': 'Validation Error'}, + 500: {'model': Error} +} diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py new file mode 100644 index 0000000000..1568f7d0c3 --- /dev/null +++ b/src/services/api/rest/routers.py @@ -0,0 +1,685 @@ +# Copyright 2024 VyOS maintainers and contributors +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . + + +# pylint: disable=line-too-long,raise-missing-from,invalid-name +# pylint: disable=wildcard-import,unused-wildcard-import +# pylint: disable=broad-exception-caught + +import copy +import logging +import traceback +from threading import Lock +from typing import Callable +from typing import TYPE_CHECKING + +from fastapi import Depends +from fastapi import Request +from fastapi import Response +from fastapi import HTTPException +from fastapi import APIRouter +from fastapi import BackgroundTasks +from fastapi.routing import APIRoute +from starlette.datastructures import FormData +from starlette.formparsers import FormParser +from starlette.formparsers import MultiPartParser +from starlette.formparsers import MultiPartException +from multipart.multipart import parse_options_header + +from vyos.config import Config +from vyos.configtree import ConfigTree +from vyos.configdiff import get_config_diff +from vyos.configsession import ConfigSessionError + +from .. session import SessionState +from . models import * + +if TYPE_CHECKING: + from fastapi import FastAPI + + +LOG = logging.getLogger('http_api.routers') + +lock = Lock() + + +def check_auth(key_list, key): + key_id = None + for k in key_list: + if k['key'] == key: + key_id = k['id'] + return key_id + + +def auth_required(data: ApiModel): + session = SessionState() + key = data.key + api_keys = session.keys + key_id = check_auth(api_keys, key) + if not key_id: + raise HTTPException(status_code=401, detail="Valid API key is required") + session.id = key_id + + +# override Request and APIRoute classes in order to convert form request to json; +# do all explicit validation here, for backwards compatability of error messages; +# the explicit validation may be dropped, if desired, in favor of native +# validation by FastAPI/Pydantic, as is used for application/json requests +class MultipartRequest(Request): + """Override Request class to convert form request to json""" + # pylint: disable=attribute-defined-outside-init + # pylint: disable=too-many-branches,too-many-statements + + _form_err = () + @property + def form_err(self): + return self._form_err + + @form_err.setter + def form_err(self, val): + if not self._form_err: + self._form_err = val + + @property + def orig_headers(self): + self._orig_headers = super().headers + return self._orig_headers + + @property + def headers(self): + self._headers = super().headers.mutablecopy() + self._headers['content-type'] = 'application/json' + return self._headers + + async def _get_form( + self, *, max_files: int | float = 1000, max_fields: int | float = 1000 + ) -> FormData: + if self._form is None: + assert ( + parse_options_header is not None + ), "The `python-multipart` library must be installed to use form parsing." + content_type_header = self.orig_headers.get("Content-Type") + content_type: bytes + content_type, _ = parse_options_header(content_type_header) + if content_type == b"multipart/form-data": + try: + multipart_parser = MultiPartParser( + self.orig_headers, + self.stream(), + max_files=max_files, + max_fields=max_fields, + ) + self._form = await multipart_parser.parse() + except MultiPartException as exc: + if "app" in self.scope: + raise HTTPException(status_code=400, detail=exc.message) + raise exc + elif content_type == b"application/x-www-form-urlencoded": + form_parser = FormParser(self.orig_headers, self.stream()) + self._form = await form_parser.parse() + else: + self._form = FormData() + return self._form + + async def body(self) -> bytes: + if not hasattr(self, "_body"): + forms = {} + merge = {} + body = await super().body() + self._body = body + + form_data = await self.form() + if form_data: + endpoint = self.url.path + LOG.debug("processing form data") + for k, v in form_data.multi_items(): + forms[k] = v + + if 'data' not in forms: + self.form_err = (422, "Non-empty data field is required") + return self._body + try: + tmp = json.loads(forms['data']) + except json.JSONDecodeError as e: + self.form_err = (400, f'Failed to parse JSON: {e}') + return self._body + if isinstance(tmp, list): + merge['commands'] = tmp + else: + merge = tmp + + if 'commands' in merge: + cmds = merge['commands'] + else: + cmds = copy.deepcopy(merge) + cmds = [cmds] + + for c in cmds: + if not isinstance(c, dict): + self.form_err = (400, + f"Malformed command '{c}': any command must be JSON of dict") + return self._body + if 'op' not in c: + self.form_err = (400, + f"Malformed command '{c}': missing 'op' field") + if endpoint not in ('/config-file', '/container-image', + '/image', '/configure-section'): + if 'path' not in c: + self.form_err = (400, + f"Malformed command '{c}': missing 'path' field") + elif not isinstance(c['path'], list): + self.form_err = (400, + f"Malformed command '{c}': 'path' field must be a list") + elif not all(isinstance(el, str) for el in c['path']): + self.form_err = (400, + f"Malformed command '{0}': 'path' field must be a list of strings") + if endpoint in ('/configure'): + if not c['path']: + self.form_err = (400, + f"Malformed command '{c}': 'path' list must be non-empty") + if 'value' in c and not isinstance(c['value'], str): + self.form_err = (400, + f"Malformed command '{c}': 'value' field must be a string") + if endpoint in ('/configure-section'): + if 'section' not in c and 'config' not in c: + self.form_err = (400, + f"Malformed command '{c}': missing 'section' or 'config' field") + + if 'key' not in forms and 'key' not in merge: + self.form_err = (401, "Valid API key is required") + if 'key' in forms and 'key' not in merge: + merge['key'] = forms['key'] + + new_body = json.dumps(merge) + new_body = new_body.encode() + self._body = new_body + + return self._body + + +class MultipartRoute(APIRoute): + """Override APIRoute class to convert form request to json""" + + def get_route_handler(self) -> Callable: + original_route_handler = super().get_route_handler() + + async def custom_route_handler(request: Request) -> Response: + request = MultipartRequest(request.scope, request.receive) + try: + response: Response = await original_route_handler(request) + except HTTPException as e: + return error(e.status_code, e.detail) + except Exception as e: + form_err = request.form_err + if form_err: + return error(*form_err) + raise e + + return response + + return custom_route_handler + + +router = APIRouter(route_class=MultipartRoute, + responses={**responses}, + dependencies=[Depends(auth_required)]) + + +self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" + +def call_commit(s: SessionState): + try: + s.session.commit() + except ConfigSessionError as e: + s.session.discard() + if s.debug: + LOG.warning(f"ConfigSessionError:\n {traceback.format_exc()}") + else: + LOG.warning(f"ConfigSessionError: {e}") + + +def _configure_op(data: Union[ConfigureModel, ConfigureListModel, + ConfigSectionModel, ConfigSectionListModel, + ConfigSectionTreeModel], + _request: Request, background_tasks: BackgroundTasks): + # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements + # pylint: disable=consider-using-with + + state = SessionState() + session = state.session + env = session.get_session_env() + + # Allow users to pass just one command + if not isinstance(data, (ConfigureListModel, ConfigSectionListModel)): + data = [data] + else: + data = data.commands + + # We don't want multiple people/apps to be able to commit at once, + # or modify the shared session while someone else is doing the same, + # so the lock is really global + lock.acquire() + + config = Config(session_env=env) + + status = 200 + msg = None + error_msg = None + try: + for c in data: + op = c.op + if not isinstance(c, BaseConfigSectionTreeModel): + path = c.path + + if isinstance(c, BaseConfigureModel): + if c.value: + value = c.value + else: + value = "" + # For vyos.configsession calls that have no separate value arguments, + # and for type checking too + cfg_path = " ".join(path + [value]).strip() + + elif isinstance(c, BaseConfigSectionModel): + section = c.section + + elif isinstance(c, BaseConfigSectionTreeModel): + mask = c.mask + config = c.config + + if isinstance(c, BaseConfigureModel): + if op == 'set': + session.set(path, value=value) + elif op == 'delete': + if state.strict and not config.exists(cfg_path): + raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist") + session.delete(path, value=value) + elif op == 'comment': + session.comment(path, value=value) + else: + raise ConfigSessionError(f"'{op}' is not a valid operation") + + elif isinstance(c, BaseConfigSectionModel): + if op == 'set': + session.set_section(path, section) + elif op == 'load': + session.load_section(path, section) + else: + raise ConfigSessionError(f"'{op}' is not a valid operation") + + elif isinstance(c, BaseConfigSectionTreeModel): + if op == 'set': + session.set_section_tree(config) + elif op == 'load': + session.load_section_tree(mask, config) + else: + raise ConfigSessionError(f"'{op}' is not a valid operation") + # end for + config = Config(session_env=env) + d = get_config_diff(config) + + if d.is_node_changed(['service', 'https']): + background_tasks.add_task(call_commit, state) + msg = self_ref_msg + else: + # capture non-fatal warnings + out = session.commit() + msg = out if out else msg + + LOG.info(f"Configuration modified via HTTP API using key '{state.id}'") + except ConfigSessionError as e: + session.discard() + status = 400 + if state.debug: + LOG.critical(f"ConfigSessionError:\n {traceback.format_exc()}") + error_msg = str(e) + except Exception: + session.discard() + LOG.critical(traceback.format_exc()) + status = 500 + + # Don't give the details away to the outer world + error_msg = "An internal error occured. Check the logs for details." + finally: + lock.release() + + if status != 200: + return error(status, error_msg) + + return success(msg) + + +def create_path_import_pki_no_prompt(path): + correct_paths = ['ca', 'certificate', 'key-pair'] + if path[1] not in correct_paths: + return False + path[1] = '--' + path[1].replace('-', '') + path[3] = '--key-filename' + return path[1:] + + +@router.post('/configure') +def configure_op(data: Union[ConfigureModel, + ConfigureListModel], + request: Request, background_tasks: BackgroundTasks): + return _configure_op(data, request, background_tasks) + + +@router.post('/configure-section') +def configure_section_op(data: Union[ConfigSectionModel, + ConfigSectionListModel, + ConfigSectionTreeModel], + request: Request, background_tasks: BackgroundTasks): + return _configure_op(data, request, background_tasks) + + +@router.post("/retrieve") +async def retrieve_op(data: RetrieveModel): + state = SessionState() + session = state.session + env = session.get_session_env() + config = Config(session_env=env) + + op = data.op + path = " ".join(data.path) + + try: + if op == 'returnValue': + res = config.return_value(path) + elif op == 'returnValues': + res = config.return_values(path) + elif op == 'exists': + res = config.exists(path) + elif op == 'showConfig': + config_format = 'json' + if data.configFormat: + config_format = data.configFormat + + res = session.show_config(path=data.path) + if config_format == 'json': + config_tree = ConfigTree(res) + res = json.loads(config_tree.to_json()) + elif config_format == 'json_ast': + config_tree = ConfigTree(res) + res = json.loads(config_tree.to_json_ast()) + elif config_format == 'raw': + pass + else: + return error(400, f"'{config_format}' is not a valid config format") + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, "An internal error occured. Check the logs for details.") + + return success(res) + + +@router.post('/config-file') +def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): + state = SessionState() + session = state.session + env = session.get_session_env() + op = data.op + msg = None + + try: + if op == 'save': + if data.file: + path = data.file + else: + path = '/config/config.boot' + msg = session.save_config(path) + elif op == 'load': + if data.file: + path = data.file + else: + return error(400, "Missing required field \"file\"") + + session.migrate_and_load_config(path) + + config = Config(session_env=env) + d = get_config_diff(config) + + if d.is_node_changed(['service', 'https']): + background_tasks.add_task(call_commit, state) + msg = self_ref_msg + else: + session.commit() + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, "An internal error occured. Check the logs for details.") + + return success(msg) + + +@router.post('/image') +def image_op(data: ImageModel): + state = SessionState() + session = state.session + + op = data.op + + try: + if op == 'add': + res = session.install_image(data.url) + elif op == 'delete': + res = session.remove_image(data.name) + elif op == 'show': + res = session.show(["system", "image"]) + elif op == 'set_default': + res = session.set_default_image(data.name) + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, "An internal error occured. Check the logs for details.") + + return success(res) + + +@router.post('/container-image') +def container_image_op(data: ContainerImageModel): + state = SessionState() + session = state.session + + op = data.op + + try: + if op == 'add': + if data.name: + name = data.name + else: + return error(400, "Missing required field \"name\"") + res = session.add_container_image(name) + elif op == 'delete': + if data.name: + name = data.name + else: + return error(400, "Missing required field \"name\"") + res = session.delete_container_image(name) + elif op == 'show': + res = session.show_container_image() + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, "An internal error occured. Check the logs for details.") + + return success(res) + + +@router.post('/generate') +def generate_op(data: GenerateModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'generate': + res = session.generate(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, "An internal error occured. Check the logs for details.") + + return success(res) + + +@router.post('/show') +def show_op(data: ShowModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'show': + res = session.show(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, "An internal error occured. Check the logs for details.") + + return success(res) + + +@router.post('/reboot') +def reboot_op(data: RebootModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'reboot': + res = session.reboot(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, "An internal error occured. Check the logs for details.") + + return success(res) + + +@router.post('/reset') +def reset_op(data: ResetModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'reset': + res = session.reset(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, "An internal error occured. Check the logs for details.") + + return success(res) + + +@router.post('/import-pki') +def import_pki(data: ImportPkiModel): + # pylint: disable=consider-using-with + + state = SessionState() + session = state.session + + op = data.op + path = data.path + + lock.acquire() + + try: + if op == 'import-pki': + # need to get rid or interactive mode for private key + if len(path) == 5 and path[3] in ['key-file', 'private-key']: + path_no_prompt = create_path_import_pki_no_prompt(path) + if not path_no_prompt: + return error(400, f"Invalid command: {' '.join(path)}") + if data.passphrase: + path_no_prompt += ['--passphrase', data.passphrase] + res = session.import_pki_no_prompt(path_no_prompt) + else: + res = session.import_pki(path) + if not res[0].isdigit(): + return error(400, res) + # commit changes + session.commit() + res = res.split('. ')[0] + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, "An internal error occured. Check the logs for details.") + finally: + lock.release() + + return success(res) + + +@router.post('/poweroff') +def poweroff_op(data: PoweroffModel): + state = SessionState() + session = state.session + + op = data.op + path = data.path + + try: + if op == 'poweroff': + res = session.poweroff(path) + else: + return error(400, f"'{op}' is not a valid operation") + except ConfigSessionError as e: + return error(400, str(e)) + except Exception: + LOG.critical(traceback.format_exc()) + return error(500, "An internal error occured. Check the logs for details.") + + return success(res) + + +def rest_init(app: "FastAPI"): + app.include_router(router) diff --git a/src/services/api/session.py b/src/services/api/session.py new file mode 100644 index 0000000000..dcdc7246c5 --- /dev/null +++ b/src/services/api/session.py @@ -0,0 +1,40 @@ +# Copyright 2024 VyOS maintainers and contributors +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . + +class SessionState: + # pylint: disable=attribute-defined-outside-init + # pylint: disable=too-many-instance-attributes,too-few-public-methods + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(SessionState, cls).__new__(cls) + cls._instance._initialize() + return cls._instance + + def _initialize(self): + self.session = None + self.keys = [] + self.id = None + self.rest = False + self.debug = False + self.strict = False + self.graphql = False + self.origins = [] + self.introspection = False + self.auth_type = None + self.token_exp = None + self.secret_len = None diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server index 42c4cf387a..456ea3d17b 100755 --- a/src/services/vyos-http-api-server +++ b/src/services/vyos-http-api-server @@ -17,918 +17,51 @@ import os import sys import grp -import copy import json import logging import signal -import traceback -import threading -from enum import Enum - from time import sleep -from typing import List, Union, Callable, Dict, Self -from fastapi import FastAPI, Depends, Request, Response, HTTPException -from fastapi import BackgroundTasks -from fastapi.responses import HTMLResponse +from fastapi import FastAPI from fastapi.exceptions import RequestValidationError -from fastapi.routing import APIRoute -from pydantic import BaseModel, StrictStr, field_validator, model_validator -from starlette.middleware.cors import CORSMiddleware -from starlette.datastructures import FormData -from starlette.formparsers import FormParser, MultiPartParser -from multipart.multipart import parse_options_header from uvicorn import Config as UvicornConfig from uvicorn import Server as UvicornServer -from ariadne.asgi import GraphQL - -from vyos.config import Config -from vyos.configtree import ConfigTree -from vyos.configdiff import get_config_diff from vyos.configsession import ConfigSession -from vyos.configsession import ConfigSessionError from vyos.defaults import api_config_state -import api.graphql.state +from api.session import SessionState +from api.rest.models import error CFG_GROUP = 'vyattacfg' debug = True -logger = logging.getLogger(__name__) +LOG = logging.getLogger('http_api') logs_handler = logging.StreamHandler() -logger.addHandler(logs_handler) +LOG.addHandler(logs_handler) if debug: - logger.setLevel(logging.DEBUG) + LOG.setLevel(logging.DEBUG) else: - logger.setLevel(logging.INFO) + LOG.setLevel(logging.INFO) -# Giant lock! -lock = threading.Lock() def load_server_config(): with open(api_config_state) as f: config = json.load(f) return config -def check_auth(key_list, key): - key_id = None - for k in key_list: - if k['key'] == key: - key_id = k['id'] - return key_id - -def error(code, msg): - resp = {"success": False, "error": msg, "data": None} - resp = json.dumps(resp) - return HTMLResponse(resp, status_code=code) - -def success(data): - resp = {"success": True, "data": data, "error": None} - resp = json.dumps(resp) - return HTMLResponse(resp) - -# Pydantic models for validation -# Pydantic will cast when possible, so use StrictStr validators added as -# needed for additional constraints -# json_schema_extra adds anotations to OpenAPI to add examples - -class ApiModel(BaseModel): - key: StrictStr - -class BasePathModel(BaseModel): - op: StrictStr - path: List[StrictStr] - - @field_validator("path") - @classmethod - def check_non_empty(cls, path: str) -> str: - if not len(path) > 0: - raise ValueError('path must be non-empty') - return path - -class BaseConfigureModel(BasePathModel): - value: StrictStr = None - -class ConfigureModel(ApiModel, BaseConfigureModel): - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "set | delete | comment", - "path": ['config', 'mode', 'path'], - } - } - -class ConfigureListModel(ApiModel): - commands: List[BaseConfigureModel] - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "commands": "list of commands", - } - } - -class BaseConfigSectionModel(BasePathModel): - section: Dict - -class ConfigSectionModel(ApiModel, BaseConfigSectionModel): - pass - -class ConfigSectionListModel(ApiModel): - commands: List[BaseConfigSectionModel] - -class BaseConfigSectionTreeModel(BaseModel): - op: StrictStr - mask: Dict - config: Dict - -class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel): - pass - -class RetrieveModel(ApiModel): - op: StrictStr - path: List[StrictStr] - configFormat: StrictStr = None - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "returnValue | returnValues | exists | showConfig", - "path": ['config', 'mode', 'path'], - "configFormat": "json (default) | json_ast | raw", - - } - } - -class ConfigFileModel(ApiModel): - op: StrictStr - file: StrictStr = None - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "save | load", - "file": "filename", - } - } - - -class ImageOp(str, Enum): - add = "add" - delete = "delete" - show = "show" - set_default = "set_default" - - -class ImageModel(ApiModel): - op: ImageOp - url: StrictStr = None - name: StrictStr = None - - @model_validator(mode='after') - def check_data(self) -> Self: - if self.op == 'add': - if not self.url: - raise ValueError("Missing required field \"url\"") - elif self.op in ['delete', 'set_default']: - if not self.name: - raise ValueError("Missing required field \"name\"") - - return self - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show | set_default", - "url": "imagelocation", - "name": "imagename", - } - } - -class ImportPkiModel(ApiModel): - op: StrictStr - path: List[StrictStr] - passphrase: StrictStr = None - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "import_pki", - "path": ["op", "mode", "path"], - "passphrase": "passphrase", - } - } - - -class ContainerImageModel(ApiModel): - op: StrictStr - name: StrictStr = None - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show", - "name": "imagename", - } - } - -class GenerateModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "generate", - "path": ["op", "mode", "path"], - } - } - -class ShowModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "show", - "path": ["op", "mode", "path"], - } - } - -class RebootModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "reboot", - "path": ["op", "mode", "path"], - } - } - -class ResetModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "reset", - "path": ["op", "mode", "path"], - } - } - -class PoweroffModel(ApiModel): - op: StrictStr - path: List[StrictStr] - - class Config: - json_schema_extra = { - "example": { - "key": "id_key", - "op": "poweroff", - "path": ["op", "mode", "path"], - } - } - - -class Success(BaseModel): - success: bool - data: Union[str, bool, Dict] - error: str - -class Error(BaseModel): - success: bool = False - data: Union[str, bool, Dict] - error: str - -responses = { - 200: {'model': Success}, - 400: {'model': Error}, - 422: {'model': Error, 'description': 'Validation Error'}, - 500: {'model': Error} -} - -def auth_required(data: ApiModel): - key = data.key - api_keys = app.state.vyos_keys - key_id = check_auth(api_keys, key) - if not key_id: - raise HTTPException(status_code=401, detail="Valid API key is required") - app.state.vyos_id = key_id - -# override Request and APIRoute classes in order to convert form request to json; -# do all explicit validation here, for backwards compatability of error messages; -# the explicit validation may be dropped, if desired, in favor of native -# validation by FastAPI/Pydantic, as is used for application/json requests -class MultipartRequest(Request): - _form_err = () - @property - def form_err(self): - return self._form_err - - @form_err.setter - def form_err(self, val): - if not self._form_err: - self._form_err = val - - @property - def orig_headers(self): - self._orig_headers = super().headers - return self._orig_headers - - @property - def headers(self): - self._headers = super().headers.mutablecopy() - self._headers['content-type'] = 'application/json' - return self._headers - - async def form(self) -> FormData: - if self._form is None: - assert ( - parse_options_header is not None - ), "The `python-multipart` library must be installed to use form parsing." - content_type_header = self.orig_headers.get("Content-Type") - content_type, options = parse_options_header(content_type_header) - if content_type == b"multipart/form-data": - multipart_parser = MultiPartParser(self.orig_headers, self.stream()) - self._form = await multipart_parser.parse() - elif content_type == b"application/x-www-form-urlencoded": - form_parser = FormParser(self.orig_headers, self.stream()) - self._form = await form_parser.parse() - else: - self._form = FormData() - return self._form - - async def body(self) -> bytes: - if not hasattr(self, "_body"): - forms = {} - merge = {} - body = await super().body() - self._body = body - - form_data = await self.form() - if form_data: - endpoint = self.url.path - logger.debug("processing form data") - for k, v in form_data.multi_items(): - forms[k] = v - - if 'data' not in forms: - self.form_err = (422, "Non-empty data field is required") - return self._body - else: - try: - tmp = json.loads(forms['data']) - except json.JSONDecodeError as e: - self.form_err = (400, f'Failed to parse JSON: {e}') - return self._body - if isinstance(tmp, list): - merge['commands'] = tmp - else: - merge = tmp - - if 'commands' in merge: - cmds = merge['commands'] - else: - cmds = copy.deepcopy(merge) - cmds = [cmds] - - for c in cmds: - if not isinstance(c, dict): - self.form_err = (400, - f"Malformed command '{c}': any command must be JSON of dict") - return self._body - if 'op' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'op' field") - if endpoint not in ('/config-file', '/container-image', - '/image', '/configure-section'): - if 'path' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'path' field") - elif not isinstance(c['path'], list): - self.form_err = (400, - f"Malformed command '{c}': 'path' field must be a list") - elif not all(isinstance(el, str) for el in c['path']): - self.form_err = (400, - f"Malformed command '{0}': 'path' field must be a list of strings") - if endpoint in ('/configure'): - if not c['path']: - self.form_err = (400, - f"Malformed command '{c}': 'path' list must be non-empty") - if 'value' in c and not isinstance(c['value'], str): - self.form_err = (400, - f"Malformed command '{c}': 'value' field must be a string") - if endpoint in ('/configure-section'): - if 'section' not in c and 'config' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'section' or 'config' field") - - if 'key' not in forms and 'key' not in merge: - self.form_err = (401, "Valid API key is required") - if 'key' in forms and 'key' not in merge: - merge['key'] = forms['key'] - - new_body = json.dumps(merge) - new_body = new_body.encode() - self._body = new_body - - return self._body - -class MultipartRoute(APIRoute): - def get_route_handler(self) -> Callable: - original_route_handler = super().get_route_handler() - - async def custom_route_handler(request: Request) -> Response: - request = MultipartRequest(request.scope, request.receive) - try: - response: Response = await original_route_handler(request) - except HTTPException as e: - return error(e.status_code, e.detail) - except Exception as e: - form_err = request.form_err - if form_err: - return error(*form_err) - raise e - - return response - - return custom_route_handler app = FastAPI(debug=True, title="VyOS API", - version="0.1.0", - responses={**responses}, - dependencies=[Depends(auth_required)]) - -app.router.route_class = MultipartRoute + version="0.1.0") @app.exception_handler(RequestValidationError) -async def validation_exception_handler(request, exc): +async def validation_exception_handler(_request, exc): return error(400, str(exc.errors()[0])) -self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" - -def call_commit(s: ConfigSession): - try: - s.commit() - except ConfigSessionError as e: - s.discard() - if app.state.vyos_debug: - logger.warning(f"ConfigSessionError:\n {traceback.format_exc()}") - else: - logger.warning(f"ConfigSessionError: {e}") - -def _configure_op(data: Union[ConfigureModel, ConfigureListModel, - ConfigSectionModel, ConfigSectionListModel, - ConfigSectionTreeModel], - request: Request, background_tasks: BackgroundTasks): - session = app.state.vyos_session - env = session.get_session_env() - - endpoint = request.url.path - - # Allow users to pass just one command - if not isinstance(data, (ConfigureListModel, ConfigSectionListModel)): - data = [data] - else: - data = data.commands - - # We don't want multiple people/apps to be able to commit at once, - # or modify the shared session while someone else is doing the same, - # so the lock is really global - lock.acquire() - - config = Config(session_env=env) - - status = 200 - msg = None - error_msg = None - try: - for c in data: - op = c.op - if not isinstance(c, BaseConfigSectionTreeModel): - path = c.path - - if isinstance(c, BaseConfigureModel): - if c.value: - value = c.value - else: - value = "" - # For vyos.configsession calls that have no separate value arguments, - # and for type checking too - cfg_path = " ".join(path + [value]).strip() - - elif isinstance(c, BaseConfigSectionModel): - section = c.section - - elif isinstance(c, BaseConfigSectionTreeModel): - mask = c.mask - config = c.config - - if isinstance(c, BaseConfigureModel): - if op == 'set': - session.set(path, value=value) - elif op == 'delete': - if app.state.vyos_strict and not config.exists(cfg_path): - raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist") - session.delete(path, value=value) - elif op == 'comment': - session.comment(path, value=value) - else: - raise ConfigSessionError(f"'{op}' is not a valid operation") - - elif isinstance(c, BaseConfigSectionModel): - if op == 'set': - session.set_section(path, section) - elif op == 'load': - session.load_section(path, section) - else: - raise ConfigSessionError(f"'{op}' is not a valid operation") - - elif isinstance(c, BaseConfigSectionTreeModel): - if op == 'set': - session.set_section_tree(config) - elif op == 'load': - session.load_section_tree(mask, config) - else: - raise ConfigSessionError(f"'{op}' is not a valid operation") - # end for - config = Config(session_env=env) - d = get_config_diff(config) - - if d.is_node_changed(['service', 'https']): - background_tasks.add_task(call_commit, session) - msg = self_ref_msg - else: - # capture non-fatal warnings - out = session.commit() - msg = out if out else msg - - logger.info(f"Configuration modified via HTTP API using key '{app.state.vyos_id}'") - except ConfigSessionError as e: - session.discard() - status = 400 - if app.state.vyos_debug: - logger.critical(f"ConfigSessionError:\n {traceback.format_exc()}") - error_msg = str(e) - except Exception as e: - session.discard() - logger.critical(traceback.format_exc()) - status = 500 - - # Don't give the details away to the outer world - error_msg = "An internal error occured. Check the logs for details." - finally: - lock.release() - - if status != 200: - return error(status, error_msg) - - return success(msg) - -def create_path_import_pki_no_prompt(path): - correct_paths = ['ca', 'certificate', 'key-pair'] - if path[1] not in correct_paths: - return False - path[1] = '--' + path[1].replace('-', '') - path[3] = '--key-filename' - return path[1:] - -@app.post('/configure') -def configure_op(data: Union[ConfigureModel, - ConfigureListModel], - request: Request, background_tasks: BackgroundTasks): - return _configure_op(data, request, background_tasks) - -@app.post('/configure-section') -def configure_section_op(data: Union[ConfigSectionModel, - ConfigSectionListModel, - ConfigSectionTreeModel], - request: Request, background_tasks: BackgroundTasks): - return _configure_op(data, request, background_tasks) - -@app.post("/retrieve") -async def retrieve_op(data: RetrieveModel): - session = app.state.vyos_session - env = session.get_session_env() - config = Config(session_env=env) - - op = data.op - path = " ".join(data.path) - try: - if op == 'returnValue': - res = config.return_value(path) - elif op == 'returnValues': - res = config.return_values(path) - elif op == 'exists': - res = config.exists(path) - elif op == 'showConfig': - config_format = 'json' - if data.configFormat: - config_format = data.configFormat - - res = session.show_config(path=data.path) - if config_format == 'json': - config_tree = ConfigTree(res) - res = json.loads(config_tree.to_json()) - elif config_format == 'json_ast': - config_tree = ConfigTree(res) - res = json.loads(config_tree.to_json_ast()) - elif config_format == 'raw': - pass - else: - return error(400, f"'{config_format}' is not a valid config format") - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/config-file') -def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): - session = app.state.vyos_session - env = session.get_session_env() - op = data.op - msg = None - - try: - if op == 'save': - if data.file: - path = data.file - else: - path = '/config/config.boot' - msg = session.save_config(path) - elif op == 'load': - if data.file: - path = data.file - else: - return error(400, "Missing required field \"file\"") - - session.migrate_and_load_config(path) - - config = Config(session_env=env) - d = get_config_diff(config) - - if d.is_node_changed(['service', 'https']): - background_tasks.add_task(call_commit, session) - msg = self_ref_msg - else: - session.commit() - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(msg) - -@app.post('/image') -def image_op(data: ImageModel): - session = app.state.vyos_session - - op = data.op - - try: - if op == 'add': - res = session.install_image(data.url) - elif op == 'delete': - res = session.remove_image(data.name) - elif op == 'show': - res = session.show(["system", "image"]) - elif op == 'set_default': - res = session.set_default_image(data.name) - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/container-image') -def container_image_op(data: ContainerImageModel): - session = app.state.vyos_session - - op = data.op - - try: - if op == 'add': - if data.name: - name = data.name - else: - return error(400, "Missing required field \"name\"") - res = session.add_container_image(name) - elif op == 'delete': - if data.name: - name = data.name - else: - return error(400, "Missing required field \"name\"") - res = session.delete_container_image(name) - elif op == 'show': - res = session.show_container_image() - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/generate') -def generate_op(data: GenerateModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'generate': - res = session.generate(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/show') -def show_op(data: ShowModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'show': - res = session.show(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/reboot') -def reboot_op(data: RebootModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'reboot': - res = session.reboot(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/reset') -def reset_op(data: ResetModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'reset': - res = session.reset(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - -@app.post('/import-pki') -def import_pki(data: ImportPkiModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - lock.acquire() - - try: - if op == 'import-pki': - # need to get rid or interactive mode for private key - if len(path) == 5 and path[3] in ['key-file', 'private-key']: - path_no_prompt = create_path_import_pki_no_prompt(path) - if not path_no_prompt: - return error(400, f"Invalid command: {' '.join(path)}") - if data.passphrase: - path_no_prompt += ['--passphrase', data.passphrase] - res = session.import_pki_no_prompt(path_no_prompt) - else: - res = session.import_pki(path) - if not res[0].isdigit(): - return error(400, res) - # commit changes - session.commit() - res = res.split('. ')[0] - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - finally: - lock.release() - - return success(res) - -@app.post('/poweroff') -def poweroff_op(data: PoweroffModel): - session = app.state.vyos_session - - op = data.op - path = data.path - - try: - if op == 'poweroff': - res = session.poweroff(path) - else: - return error(400, f"'{op}' is not a valid operation") - except ConfigSessionError as e: - return error(400, str(e)) - except Exception as e: - logger.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") - - return success(res) - - -### -# GraphQL integration -### - -def graphql_init(app: FastAPI = app): - from api.graphql.libs.token_auth import get_user_context - api.graphql.state.init() - api.graphql.state.settings['app'] = app - - # import after initializaion of state - from api.graphql.bindings import generate_schema - schema = generate_schema() - - in_spec = app.state.vyos_introspection - - if app.state.vyos_origins: - origins = app.state.vyos_origins - app.add_route('/graphql', CORSMiddleware(GraphQL(schema, - context_value=get_user_context, - debug=True, - introspection=in_spec), - allow_origins=origins, - allow_methods=("GET", "POST", "OPTIONS"), - allow_headers=("Authorization",))) - else: - app.add_route('/graphql', GraphQL(schema, - context_value=get_user_context, - debug=True, - introspection=in_spec)) ### # Modify uvicorn to allow reloading server within the configsession ### @@ -936,30 +69,41 @@ def graphql_init(app: FastAPI = app): server = None shutdown = False + class ApiServerConfig(UvicornConfig): pass + class ApiServer(UvicornServer): def install_signal_handlers(self): pass + def reload_handler(signum, frame): + # pylint: disable=global-statement + global server - logger.debug('Reload signal received...') + LOG.debug('Reload signal received...') if server is not None: server.handle_exit(signum, frame) server = None - logger.info('Server stopping for reload...') + LOG.info('Server stopping for reload...') else: - logger.warning('Reload called for non-running server...') + LOG.warning('Reload called for non-running server...') + def shutdown_handler(signum, frame): + # pylint: disable=global-statement + global shutdown - logger.debug('Shutdown signal received...') + LOG.debug('Shutdown signal received...') server.handle_exit(signum, frame) - logger.info('Server shutdown...') + LOG.info('Server shutdown...') shutdown = True +# end modify uvicorn + + def flatten_keys(d: dict) -> list[dict]: keys_list = [] for el in list(d['keys'].get('id', {})): @@ -968,49 +112,62 @@ def flatten_keys(d: dict) -> list[dict]: keys_list.append({'id': el, 'key': key}) return keys_list -def initialization(session: ConfigSession, app: FastAPI = app): + +def initialization(session: SessionState, app: FastAPI = app): + # pylint: disable=global-statement,broad-exception-caught,import-outside-toplevel + global server try: server_config = load_server_config() except Exception as e: - logger.critical(f'Failed to load the HTTP API server config: {e}') + LOG.critical(f'Failed to load the HTTP API server config: {e}') sys.exit(1) - app.state.vyos_session = session - app.state.vyos_keys = [] - if 'keys' in server_config: - app.state.vyos_keys = flatten_keys(server_config) + session.keys = flatten_keys(server_config) + + rest_config = server_config.get('rest', {}) + session.debug = bool('debug' in rest_config) + session.strict = bool('strict' in rest_config) + + graphql_config = server_config.get('graphql', {}) + session.origins = graphql_config.get('cors', {}).get('allow_origin', []) + + if 'rest' in server_config: + session.rest = True + else: + session.rest = False - app.state.vyos_debug = bool('debug' in server_config) - app.state.vyos_strict = bool('strict' in server_config) - app.state.vyos_origins = server_config.get('cors', {}).get('allow_origin', []) if 'graphql' in server_config: - app.state.vyos_graphql = True + session.graphql = True if isinstance(server_config['graphql'], dict): if 'introspection' in server_config['graphql']: - app.state.vyos_introspection = True + session.introspection = True else: - app.state.vyos_introspection = False + session.introspection = False # default values if not set explicitly - app.state.vyos_auth_type = server_config['graphql']['authentication']['type'] - app.state.vyos_token_exp = server_config['graphql']['authentication']['expiration'] - app.state.vyos_secret_len = server_config['graphql']['authentication']['secret_length'] + session.auth_type = server_config['graphql']['authentication']['type'] + session.token_exp = server_config['graphql']['authentication']['expiration'] + session.secret_len = server_config['graphql']['authentication']['secret_length'] else: - app.state.vyos_graphql = False + session.graphql = False + + # pass session state + app.state = session - if app.state.vyos_graphql: + # add REST routes + if session.rest: + from api.rest.routers import rest_init + rest_init(app) + + # add GraphQL route + if session.graphql: + from api.graphql.routers import graphql_init graphql_init(app) config = ApiServerConfig(app, uds="/run/api.sock", proxy_headers=True) server = ApiServer(config) -def run_server(): - try: - server.run() - except OSError as e: - logger.critical(e) - sys.exit(1) if __name__ == '__main__': # systemd's user and group options don't work, do it by hand here, @@ -1025,13 +182,14 @@ if __name__ == '__main__': signal.signal(signal.SIGHUP, reload_handler) signal.signal(signal.SIGTERM, shutdown_handler) - config_session = ConfigSession(os.getpid()) + session_state = SessionState() + session_state.session = ConfigSession(os.getpid()) while True: - logger.debug('Enter main loop...') + LOG.debug('Enter main loop...') if shutdown: break if server is None: - initialization(config_session) + initialization(session_state) server.run() sleep(1) From 954be34bc938acc9e14d9fb3b32c8f96cd999959 Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Wed, 25 Sep 2024 14:02:52 -0500 Subject: [PATCH 3/9] http-api: T6736: remove routes on config delete Avoid duplicate entries in the list of routes when adding/deleting endpoints. --- src/services/api/graphql/routers.py | 10 ++++++++++ src/services/api/rest/routers.py | 8 ++++++++ src/services/vyos-http-api-server | 10 ++++++++++ 3 files changed, 28 insertions(+) diff --git a/src/services/api/graphql/routers.py b/src/services/api/graphql/routers.py index d04375a49e..f02380cdc8 100644 --- a/src/services/api/graphql/routers.py +++ b/src/services/api/graphql/routers.py @@ -38,6 +38,10 @@ def graphql_init(app: "FastAPI"): in_spec = state.introspection + # remove route and reinstall below, for any changes; alternatively, test + # for config_diff before proceeding + graphql_clear(app) + if state.origins: origins = state.origins app.add_route('/graphql', CORSMiddleware(GraphQL(schema, @@ -52,3 +56,9 @@ def graphql_init(app: "FastAPI"): context_value=get_user_context, debug=True, introspection=in_spec)) + + +def graphql_clear(app: "FastAPI"): + for r in app.routes: + if r.path == '/graphql': + app.routes.remove(r) diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py index 1568f7d0c3..38b10ef7d2 100644 --- a/src/services/api/rest/routers.py +++ b/src/services/api/rest/routers.py @@ -682,4 +682,12 @@ def poweroff_op(data: PoweroffModel): def rest_init(app: "FastAPI"): + if all(r in app.routes for r in router.routes): + return app.include_router(router) + + +def rest_clear(app: "FastAPI"): + for r in router.routes: + if r in app.routes: + app.routes.remove(r) diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server index 456ea3d17b..6bfc2c4352 100755 --- a/src/services/vyos-http-api-server +++ b/src/services/vyos-http-api-server @@ -159,11 +159,21 @@ def initialization(session: SessionState, app: FastAPI = app): if session.rest: from api.rest.routers import rest_init rest_init(app) + else: + from api.rest.routers import rest_clear + rest_clear(app) # add GraphQL route if session.graphql: from api.graphql.routers import graphql_init graphql_init(app) + else: + from api.graphql.routers import graphql_clear + graphql_clear(app) + + LOG.debug('Active routes are:') + for r in app.routes: + LOG.debug(f'{r.path}') config = ApiServerConfig(app, uds="/run/api.sock", proxy_headers=True) server = ApiServer(config) From 9f7767c8022f5e6ad70253ef1c94fc22dbd04f1e Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Wed, 25 Sep 2024 22:10:34 -0500 Subject: [PATCH 4/9] http-api: T6736: regenerate openapi docs --- src/services/vyos-http-api-server | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/services/vyos-http-api-server b/src/services/vyos-http-api-server index 6bfc2c4352..5585611824 100755 --- a/src/services/vyos-http-api-server +++ b/src/services/vyos-http-api-server @@ -113,6 +113,19 @@ def flatten_keys(d: dict) -> list[dict]: return keys_list +def regenerate_docs(app: FastAPI) -> None: + docs = ('/openapi.json', '/docs', '/docs/oauth2-redirect', '/redoc') + remove = [] + for r in app.routes: + if r.path in docs: + remove.append(r) + for r in remove: + app.routes.remove(r) + + app.openapi_schema = None + app.setup() + + def initialization(session: SessionState, app: FastAPI = app): # pylint: disable=global-statement,broad-exception-caught,import-outside-toplevel @@ -171,6 +184,8 @@ def initialization(session: SessionState, app: FastAPI = app): from api.graphql.routers import graphql_clear graphql_clear(app) + regenerate_docs(app) + LOG.debug('Active routes are:') for r in app.routes: LOG.debug(f'{r.path}') From 8183e2cc8178f5da67a918ee37c21797a05c142d Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Wed, 25 Sep 2024 08:43:41 -0500 Subject: [PATCH 5/9] http-api: T6736: add distinct XML path for REST API --- interface-definitions/service_https.xml.in | 53 ++++++++++++---------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/interface-definitions/service_https.xml.in b/interface-definitions/service_https.xml.in index afe430c0c8..7bb63fa5a6 100644 --- a/interface-definitions/service_https.xml.in +++ b/interface-definitions/service_https.xml.in @@ -32,22 +32,29 @@ - + - Enforce strict path checking - + REST API - - - - Debug - - - - + + + + Enforce strict path checking + + + + + + Debug + + + + + + - GraphQL support + GraphQL API @@ -109,19 +116,19 @@ - - - - - Set CORS options - - - + - Allow resource request from origin - + Set CORS options - + + + + Allow resource request from origin + + + + + From 8322cc4d82067d7efe5051250c58d8cfbc895bce Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Sun, 29 Sep 2024 16:44:29 -0500 Subject: [PATCH 6/9] http-api: T6736: update smoketest for syntax change --- smoketest/scripts/cli/test_service_https.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/smoketest/scripts/cli/test_service_https.py b/smoketest/scripts/cli/test_service_https.py index 8a6386e4f9..2f4fcd6abc 100755 --- a/smoketest/scripts/cli/test_service_https.py +++ b/smoketest/scripts/cli/test_service_https.py @@ -154,6 +154,8 @@ def test_api_auth(self): key = 'MySuperSecretVyOS' self.cli_set(base_path + ['api', 'keys', 'id', 'key-01', 'key', key]) + self.cli_set(base_path + ['api', 'rest']) + self.cli_set(base_path + ['listen-address', address]) self.cli_commit() @@ -304,6 +306,7 @@ def test_api_add_delete(self): self.assertEqual(r.status_code, 503) self.cli_set(base_path + ['api', 'keys', 'id', 'key-01', 'key', key]) + self.cli_set(base_path + ['api', 'rest']) self.cli_commit() sleep(2) @@ -326,6 +329,7 @@ def test_api_show(self): headers = {} self.cli_set(base_path + ['api', 'keys', 'id', 'key-01', 'key', key]) + self.cli_set(base_path + ['api', 'rest']) self.cli_commit() payload = { @@ -343,6 +347,7 @@ def test_api_generate(self): headers = {} self.cli_set(base_path + ['api', 'keys', 'id', 'key-01', 'key', key]) + self.cli_set(base_path + ['api', 'rest']) self.cli_commit() payload = { @@ -362,6 +367,7 @@ def test_api_configure(self): conf_address = '192.0.2.44/32' self.cli_set(base_path + ['api', 'keys', 'id', 'key-01', 'key', key]) + self.cli_set(base_path + ['api', 'rest']) self.cli_commit() payload_path = [ @@ -385,6 +391,7 @@ def test_api_config_file(self): headers = {} self.cli_set(base_path + ['api', 'keys', 'id', 'key-01', 'key', key]) + self.cli_set(base_path + ['api', 'rest']) self.cli_commit() payload = { @@ -402,6 +409,7 @@ def test_api_reset(self): headers = {} self.cli_set(base_path + ['api', 'keys', 'id', 'key-01', 'key', key]) + self.cli_set(base_path + ['api', 'rest']) self.cli_commit() payload = { @@ -419,6 +427,7 @@ def test_api_image(self): headers = {} self.cli_set(base_path + ['api', 'keys', 'id', 'key-01', 'key', key]) + self.cli_set(base_path + ['api', 'rest']) self.cli_commit() payload = { @@ -462,6 +471,7 @@ def test_api_config_file_load_http(self): headers = {} self.cli_set(base_path + ['api', 'keys', 'id', 'key-01', 'key', key]) + self.cli_set(base_path + ['api', 'rest']) self.cli_commit() # load config via HTTP requires nginx config From 40d966310cb5d8d758f7fa801facd0a560795783 Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Sun, 29 Sep 2024 20:49:30 -0500 Subject: [PATCH 7/9] http-api: T6736: add migration script and update version --- .../include/version/https-version.xml.i | 2 +- src/migration-scripts/https/6-to-7 | 43 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 src/migration-scripts/https/6-to-7 diff --git a/interface-definitions/include/version/https-version.xml.i b/interface-definitions/include/version/https-version.xml.i index 525314dbd9..a889a7805f 100644 --- a/interface-definitions/include/version/https-version.xml.i +++ b/interface-definitions/include/version/https-version.xml.i @@ -1,3 +1,3 @@ - + diff --git a/src/migration-scripts/https/6-to-7 b/src/migration-scripts/https/6-to-7 new file mode 100644 index 0000000000..571f3b6ae5 --- /dev/null +++ b/src/migration-scripts/https/6-to-7 @@ -0,0 +1,43 @@ +# Copyright 2024 VyOS maintainers and contributors +# +# This library is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 2.1 of the License, or (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . + +# T6736: move REST API to distinct node + + +from vyos.configtree import ConfigTree + + +base = ['service', 'https', 'api'] + +def migrate(config: ConfigTree) -> None: + if not config.exists(base): + # Nothing to do + return + + # Move REST API configuration to new node + # REST API was previously enabled if base path exists + config.set(['service', 'https', 'api', 'rest']) + for entry in ('debug', 'strict'): + if config.exists(base + [entry]): + config.set(base + ['rest', entry]) + config.delete(base + [entry]) + + # Move CORS settings under GraphQL + # CORS is not implemented for REST API + if config.exists(base + ['cors']): + old_base = base + ['cors'] + new_base = base + ['graphql', 'cors'] + config.copy(old_base, new_base) + config.delete(old_base) From 7e23fd9da028b3c623b69fda8a6bcfd887f1c18c Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Mon, 30 Sep 2024 20:40:02 -0500 Subject: [PATCH 8/9] http-api: T6736: normalize formatting --- smoketest/scripts/cli/test_service_https.py | 75 +++++-- src/services/api/graphql/bindings.py | 29 ++- .../graphql/graphql/auth_token_mutation.py | 31 +-- src/services/api/graphql/graphql/mutations.py | 71 +++--- src/services/api/graphql/graphql/queries.py | 71 +++--- src/services/api/graphql/libs/key_auth.py | 4 +- src/services/api/graphql/libs/token_auth.py | 10 +- src/services/api/graphql/routers.py | 41 ++-- src/services/api/graphql/session/session.py | 33 ++- src/services/api/rest/models.py | 143 +++++++----- src/services/api/rest/routers.py | 209 +++++++++++------- src/services/api/session.py | 1 + 12 files changed, 427 insertions(+), 291 deletions(-) diff --git a/smoketest/scripts/cli/test_service_https.py b/smoketest/scripts/cli/test_service_https.py index 2f4fcd6abc..04c4a2e519 100755 --- a/smoketest/scripts/cli/test_service_https.py +++ b/smoketest/scripts/cli/test_service_https.py @@ -89,6 +89,7 @@ PROCESS_NAME = 'nginx' + class TestHTTPSService(VyOSUnitTestSHIM.TestCase): @classmethod def setUpClass(cls): @@ -120,19 +121,29 @@ def test_certificate(self): # verify() - certificates do not exist (yet) with self.assertRaises(ConfigSessionError): self.cli_commit() - self.cli_set(pki_base + ['certificate', cert_name, 'certificate', cert_data.replace('\n','')]) - self.cli_set(pki_base + ['certificate', cert_name, 'private', 'key', key_data.replace('\n','')]) + self.cli_set( + pki_base + + ['certificate', cert_name, 'certificate', cert_data.replace('\n', '')] + ) + self.cli_set( + pki_base + + ['certificate', cert_name, 'private', 'key', key_data.replace('\n', '')] + ) self.cli_set(base_path + ['certificates', 'dh-params', dh_name]) # verify() - dh-params do not exist (yet) with self.assertRaises(ConfigSessionError): self.cli_commit() - self.cli_set(pki_base + ['dh', dh_name, 'parameters', dh_1024.replace('\n','')]) + self.cli_set( + pki_base + ['dh', dh_name, 'parameters', dh_1024.replace('\n', '')] + ) # verify() - dh-param minimum length is 2048 bit with self.assertRaises(ConfigSessionError): self.cli_commit() - self.cli_set(pki_base + ['dh', dh_name, 'parameters', dh_2048.replace('\n','')]) + self.cli_set( + pki_base + ['dh', dh_name, 'parameters', dh_2048.replace('\n', '')] + ) self.cli_commit() self.assertTrue(process_named_running(PROCESS_NAME)) @@ -162,7 +173,7 @@ def test_api_auth(self): nginx_config = read_file('/etc/nginx/sites-enabled/default') self.assertIn(f'listen {address}:{port} ssl;', nginx_config) - self.assertIn(f'ssl_protocols TLSv1.2 TLSv1.3;', nginx_config) # default + self.assertIn('ssl_protocols TLSv1.2 TLSv1.3;', nginx_config) # default url = f'https://{address}/retrieve' payload = {'data': '{"op": "showConfig", "path": []}', 'key': f'{key}'} @@ -182,11 +193,16 @@ def test_api_auth(self): self.assertEqual(r.status_code, 401) # Check path config - payload = {'data': '{"op": "showConfig", "path": ["system", "login"]}', 'key': f'{key}'} + payload = { + 'data': '{"op": "showConfig", "path": ["system", "login"]}', + 'key': f'{key}', + } r = request('POST', url, verify=False, headers=headers, data=payload) response = r.json() vyos_user_exists = 'vyos' in response.get('data', {}).get('user', {}) - self.assertTrue(vyos_user_exists, "The 'vyos' user does not exist in the response.") + self.assertTrue( + vyos_user_exists, "The 'vyos' user does not exist in the response." + ) # GraphQL auth test: a missing key will return status code 400, as # 'key' is a non-nullable field in the schema; an incorrect key is @@ -210,7 +226,13 @@ def test_api_auth(self): }} """ - r = request('POST', graphql_url, verify=False, headers=headers, json={'query': query_valid_key}) + r = request( + 'POST', + graphql_url, + verify=False, + headers=headers, + json={'query': query_valid_key}, + ) success = r.json()['data']['SystemStatus']['success'] self.assertTrue(success) @@ -226,7 +248,13 @@ def test_api_auth(self): } """ - r = request('POST', graphql_url, verify=False, headers=headers, json={'query': query_invalid_key}) + r = request( + 'POST', + graphql_url, + verify=False, + headers=headers, + json={'query': query_invalid_key}, + ) success = r.json()['data']['SystemStatus']['success'] self.assertFalse(success) @@ -242,7 +270,13 @@ def test_api_auth(self): } """ - r = request('POST', graphql_url, verify=False, headers=headers, json={'query': query_no_key}) + r = request( + 'POST', + graphql_url, + verify=False, + headers=headers, + json={'query': query_no_key}, + ) success = r.json()['data']['SystemStatus']['success'] self.assertFalse(success) @@ -263,7 +297,9 @@ def test_api_auth(self): } } """ - r = request('POST', graphql_url, verify=False, headers=headers, json={'query': mutation}) + r = request( + 'POST', graphql_url, verify=False, headers=headers, json={'query': mutation} + ) token = r.json()['data']['AuthToken']['data']['result']['token'] @@ -286,7 +322,9 @@ def test_api_auth(self): } """ - r = request('POST', graphql_url, verify=False, headers=headers, json={'query': query}) + r = request( + 'POST', graphql_url, verify=False, headers=headers, json={'query': query} + ) success = r.json()['data']['ShowVersion']['success'] self.assertTrue(success) @@ -371,14 +409,14 @@ def test_api_configure(self): self.cli_commit() payload_path = [ - "interfaces", - "dummy", - f"{conf_interface}", - "address", - f"{conf_address}", + 'interfaces', + 'dummy', + f'{conf_interface}', + 'address', + f'{conf_address}', ] - payload = {'data': json.dumps({"op": "set", "path": payload_path}), 'key': key} + payload = {'data': json.dumps({'op': 'set', 'path': payload_path}), 'key': key} r = request('POST', url, verify=False, headers=headers, data=payload) self.assertEqual(r.status_code, 200) @@ -508,5 +546,6 @@ def test_api_config_file_load_http(self): call(f'sudo rm -f {nginx_tmp_site}') call('sudo systemctl reload nginx') + if __name__ == '__main__': unittest.main(verbosity=5) diff --git a/src/services/api/graphql/bindings.py b/src/services/api/graphql/bindings.py index 93dd0fbfb5..ebf745f325 100644 --- a/src/services/api/graphql/bindings.py +++ b/src/services/api/graphql/bindings.py @@ -20,18 +20,18 @@ from ariadne import load_schema_from_path from ariadne import snake_case_fallback_resolvers -from . graphql.queries import query -from . graphql.mutations import mutation -from . graphql.directives import directives_dict -from . graphql.errors import op_mode_error -from . graphql.auth_token_mutation import auth_token_mutation -from . libs.token_auth import init_secret +from .graphql.queries import query +from .graphql.mutations import mutation +from .graphql.directives import directives_dict +from .graphql.errors import op_mode_error +from .graphql.auth_token_mutation import auth_token_mutation +from .libs.token_auth import init_secret -from .. session import SessionState +from ..session import SessionState def generate_schema(): - state = SessionState() + state = SessionState() api_schema_dir = vyos.defaults.directories['api_schema'] if state.auth_type == 'token': @@ -39,9 +39,14 @@ def generate_schema(): type_defs = load_schema_from_path(api_schema_dir) - schema = make_executable_schema(type_defs, query, op_mode_error, - mutation, auth_token_mutation, - snake_case_fallback_resolvers, - directives=directives_dict) + schema = make_executable_schema( + type_defs, + query, + op_mode_error, + mutation, + auth_token_mutation, + snake_case_fallback_resolvers, + directives=directives_dict, + ) return schema diff --git a/src/services/api/graphql/graphql/auth_token_mutation.py b/src/services/api/graphql/graphql/auth_token_mutation.py index 1649602174..c74364603d 100644 --- a/src/services/api/graphql/graphql/auth_token_mutation.py +++ b/src/services/api/graphql/graphql/auth_token_mutation.py @@ -19,11 +19,12 @@ from ariadne import ObjectType from graphql import GraphQLResolveInfo -from .. libs.token_auth import generate_token -from .. session.session import get_user_info -from ... session import SessionState +from ..libs.token_auth import generate_token +from ..session.session import get_user_info +from ...session import SessionState + +auth_token_mutation = ObjectType('Mutation') -auth_token_mutation = ObjectType("Mutation") @auth_token_mutation.field('AuthToken') def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): @@ -35,8 +36,9 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): secret = getattr(state, 'secret', '') exp_interval = int(state.token_exp) - expiration = (datetime.datetime.now(tz=datetime.timezone.utc) + - datetime.timedelta(seconds=exp_interval)) + expiration = datetime.datetime.now(tz=datetime.timezone.utc) + datetime.timedelta( + seconds=exp_interval + ) res = generate_token(user, passwd, secret, expiration) try: @@ -46,18 +48,9 @@ def auth_token_resolver(obj: Any, info: GraphQLResolveInfo, data: Dict): pass if 'token' in res: data['result'] = res - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} if 'errors' in res: - return { - "success": False, - "errors": res['errors'] - } - - return { - "success": False, - "errors": ['token generation failed'] - } + return {'success': False, 'errors': res['errors']} + + return {'success': False, 'errors': ['token generation failed']} diff --git a/src/services/api/graphql/graphql/mutations.py b/src/services/api/graphql/graphql/mutations.py index 62031ada3c..0b391c0708 100644 --- a/src/services/api/graphql/graphql/mutations.py +++ b/src/services/api/graphql/graphql/mutations.py @@ -14,20 +14,23 @@ # along with this library. If not, see . from importlib import import_module -from ariadne import ObjectType, convert_camel_case_to_snake -from makefun import with_signature # used below by func_sig -from typing import Any, Dict, Optional # pylint: disable=W0611 -from graphql import GraphQLResolveInfo # pylint: disable=W0611 +from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401 +from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401 + +from ariadne import ObjectType, convert_camel_case_to_snake +from makefun import with_signature -from ... session import SessionState -from .. libs import key_auth -from .. session.session import Session -from .. session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code from vyos.opmode import Error as OpModeError -mutation = ObjectType("Mutation") +from ...session import SessionState +from ..libs import key_auth +from ..session.session import Session +from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code + +mutation = ObjectType('Mutation') + def make_mutation_resolver(mutation_name, class_name, session_func): """Dynamically generate a resolver for the mutation named in the @@ -59,10 +62,7 @@ async def func_impl(*args, **kwargs): auth = key_auth.auth_required(key) if auth is None: - return { - "success": False, - "errors": ['invalid API key'] - } + return {'success': False, 'errors': ['invalid API key']} # We are finished with the 'key' entry, and may remove so as to # pass the rest of data (if any) to function. @@ -77,14 +77,8 @@ async def func_impl(*args, **kwargs): if user is None: error = info.context.get('error') if error is not None: - return { - "success": False, - "errors": [error] - } - return { - "success": False, - "errors": ['not authenticated'] - } + return {'success': False, 'errors': [error]} + return {'success': False, 'errors': ['not authenticated']} else: # AtrributeError will have already been raised if no # auth_type; validation and defaultValue ensure it is @@ -106,35 +100,36 @@ async def func_impl(*args, **kwargs): result = method() data['result'] = result - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} except OpModeError as e: typename = type(e).__name__ msg = str(e) return { - "success": False, - "errore": ['op_mode_error'], - "op_mode_error": {"name": f"{typename}", - "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"), - "vyos_code": op_mode_err_code.get(typename, 9999)} + 'success': False, + 'errore': ['op_mode_error'], + 'op_mode_error': { + 'name': f'{typename}', + 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'), + 'vyos_code': op_mode_err_code.get(typename, 9999), + }, } except Exception as error: - return { - "success": False, - "errors": [repr(error)] - } + return {'success': False, 'errors': [repr(error)]} return func_impl + def make_config_session_mutation_resolver(mutation_name): - return make_mutation_resolver(mutation_name, mutation_name, - convert_camel_case_to_snake(mutation_name)) + return make_mutation_resolver( + mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name) + ) + def make_gen_op_mutation_resolver(mutation_name): return make_mutation_resolver(mutation_name, mutation_name, 'gen_op_mutation') + def make_composite_mutation_resolver(mutation_name): - return make_mutation_resolver(mutation_name, mutation_name, - convert_camel_case_to_snake(mutation_name)) + return make_mutation_resolver( + mutation_name, mutation_name, convert_camel_case_to_snake(mutation_name) + ) diff --git a/src/services/api/graphql/graphql/queries.py b/src/services/api/graphql/graphql/queries.py index 1e9036574b..9303fe9099 100644 --- a/src/services/api/graphql/graphql/queries.py +++ b/src/services/api/graphql/graphql/queries.py @@ -14,20 +14,23 @@ # along with this library. If not, see . from importlib import import_module -from ariadne import ObjectType, convert_camel_case_to_snake -from makefun import with_signature # used below by func_sig -from typing import Any, Dict, Optional # pylint: disable=W0611 -from graphql import GraphQLResolveInfo # pylint: disable=W0611 +from typing import Any, Dict, Optional # pylint: disable=W0611 # noqa: F401 +from graphql import GraphQLResolveInfo # pylint: disable=W0611 # noqa: F401 + +from ariadne import ObjectType, convert_camel_case_to_snake +from makefun import with_signature -from ... session import SessionState -from .. libs import key_auth -from .. session.session import Session -from .. session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code from vyos.opmode import Error as OpModeError -query = ObjectType("Query") +from ...session import SessionState +from ..libs import key_auth +from ..session.session import Session +from ..session.errors.op_mode_errors import op_mode_err_msg, op_mode_err_code + +query = ObjectType('Query') + def make_query_resolver(query_name, class_name, session_func): """Dynamically generate a resolver for the query named in the @@ -59,10 +62,7 @@ async def func_impl(*args, **kwargs): auth = key_auth.auth_required(key) if auth is None: - return { - "success": False, - "errors": ['invalid API key'] - } + return {'success': False, 'errors': ['invalid API key']} # We are finished with the 'key' entry, and may remove so as to # pass the rest of data (if any) to function. @@ -77,14 +77,8 @@ async def func_impl(*args, **kwargs): if user is None: error = info.context.get('error') if error is not None: - return { - "success": False, - "errors": [error] - } - return { - "success": False, - "errors": ['not authenticated'] - } + return {'success': False, 'errors': [error]} + return {'success': False, 'errors': ['not authenticated']} else: # AtrributeError will have already been raised if no # auth_type; validation and defaultValue ensure it is @@ -106,35 +100,36 @@ async def func_impl(*args, **kwargs): result = method() data['result'] = result - return { - "success": True, - "data": data - } + return {'success': True, 'data': data} except OpModeError as e: typename = type(e).__name__ msg = str(e) return { - "success": False, - "errors": ['op_mode_error'], - "op_mode_error": {"name": f"{typename}", - "message": msg if msg else op_mode_err_msg.get(typename, "Unknown"), - "vyos_code": op_mode_err_code.get(typename, 9999)} + 'success': False, + 'errors': ['op_mode_error'], + 'op_mode_error': { + 'name': f'{typename}', + 'message': msg if msg else op_mode_err_msg.get(typename, 'Unknown'), + 'vyos_code': op_mode_err_code.get(typename, 9999), + }, } except Exception as error: - return { - "success": False, - "errors": [repr(error)] - } + return {'success': False, 'errors': [repr(error)]} return func_impl + def make_config_session_query_resolver(query_name): - return make_query_resolver(query_name, query_name, - convert_camel_case_to_snake(query_name)) + return make_query_resolver( + query_name, query_name, convert_camel_case_to_snake(query_name) + ) + def make_gen_op_query_resolver(query_name): return make_query_resolver(query_name, query_name, 'gen_op_query') + def make_composite_query_resolver(query_name): - return make_query_resolver(query_name, query_name, - convert_camel_case_to_snake(query_name)) + return make_query_resolver( + query_name, query_name, convert_camel_case_to_snake(query_name) + ) diff --git a/src/services/api/graphql/libs/key_auth.py b/src/services/api/graphql/libs/key_auth.py index 9e49a1203e..ffd7f32b2e 100644 --- a/src/services/api/graphql/libs/key_auth.py +++ b/src/services/api/graphql/libs/key_auth.py @@ -14,7 +14,8 @@ # along with this library. If not, see . -from ... session import SessionState +from ...session import SessionState + def check_auth(key_list, key): if not key_list: @@ -25,6 +26,7 @@ def check_auth(key_list, key): key_id = k['id'] return key_id + def auth_required(key): state = SessionState() api_keys = None diff --git a/src/services/api/graphql/libs/token_auth.py b/src/services/api/graphql/libs/token_auth.py index 2d772e0357..4f743a0964 100644 --- a/src/services/api/graphql/libs/token_auth.py +++ b/src/services/api/graphql/libs/token_auth.py @@ -19,7 +19,7 @@ import pam from secrets import token_hex -from ... session import SessionState +from ...session import SessionState def _check_passwd_pam(username: str, passwd: str) -> bool: @@ -48,13 +48,13 @@ def generate_token(user: str, passwd: str, secret: str, exp: int) -> dict: payload_data = {'iss': user, 'sub': user_id, 'exp': exp} secret = getattr(state, 'secret', None) if secret is None: - return {"errors": ['missing secret']} - token = jwt.encode(payload=payload_data, key=secret, algorithm="HS256") + return {'errors': ['missing secret']} + token = jwt.encode(payload=payload_data, key=secret, algorithm='HS256') users |= {user_id: user} return {'token': token} else: - return {"errors": ['failed pam authentication']} + return {'errors': ['failed pam authentication']} def get_user_context(request): @@ -70,7 +70,7 @@ def get_user_context(request): try: secret = getattr(state, 'secret', None) - payload = jwt.decode(token, secret, algorithms=["HS256"]) + payload = jwt.decode(token, secret, algorithms=['HS256']) user_id: str = payload.get('sub') if user_id is None: return context diff --git a/src/services/api/graphql/routers.py b/src/services/api/graphql/routers.py index f02380cdc8..ed3ee1e8cf 100644 --- a/src/services/api/graphql/routers.py +++ b/src/services/api/graphql/routers.py @@ -26,14 +26,15 @@ from fastapi import FastAPI -def graphql_init(app: "FastAPI"): - from .. session import SessionState +def graphql_init(app: 'FastAPI'): + from ..session import SessionState from .libs.token_auth import get_user_context state = SessionState() # import after initializaion of state from .bindings import generate_schema + schema = generate_schema() in_spec = state.introspection @@ -44,21 +45,33 @@ def graphql_init(app: "FastAPI"): if state.origins: origins = state.origins - app.add_route('/graphql', CORSMiddleware(GraphQL(schema, - context_value=get_user_context, - debug=True, - introspection=in_spec), - allow_origins=origins, - allow_methods=("GET", "POST", "OPTIONS"), - allow_headers=("Authorization",))) + app.add_route( + '/graphql', + CORSMiddleware( + GraphQL( + schema, + context_value=get_user_context, + debug=True, + introspection=in_spec, + ), + allow_origins=origins, + allow_methods=('GET', 'POST', 'OPTIONS'), + allow_headers=('Authorization',), + ), + ) else: - app.add_route('/graphql', GraphQL(schema, - context_value=get_user_context, - debug=True, - introspection=in_spec)) + app.add_route( + '/graphql', + GraphQL( + schema, + context_value=get_user_context, + debug=True, + introspection=in_spec, + ), + ) -def graphql_clear(app: "FastAPI"): +def graphql_clear(app: 'FastAPI'): for r in app.routes: if r.path == '/graphql': app.routes.remove(r) diff --git a/src/services/api/graphql/session/session.py b/src/services/api/graphql/session/session.py index 6e2875f3c2..619534f43f 100644 --- a/src/services/api/graphql/session/session.py +++ b/src/services/api/graphql/session/session.py @@ -28,34 +28,45 @@ op_mode_include_file = os.path.join(directories['data'], 'op-mode-standardized.json') -def get_config_dict(path=[], effective=False, key_mangling=None, - get_first_key=False, no_multi_convert=False, - no_tag_node_value_mangle=False): + +def get_config_dict( + path=[], + effective=False, + key_mangling=None, + get_first_key=False, + no_multi_convert=False, + no_tag_node_value_mangle=False, +): config = Config() - return config.get_config_dict(path=path, effective=effective, - key_mangling=key_mangling, - get_first_key=get_first_key, - no_multi_convert=no_multi_convert, - no_tag_node_value_mangle=no_tag_node_value_mangle) + return config.get_config_dict( + path=path, + effective=effective, + key_mangling=key_mangling, + get_first_key=get_first_key, + no_multi_convert=no_multi_convert, + no_tag_node_value_mangle=no_tag_node_value_mangle, + ) + def get_user_info(user): user_info = {} - info = get_config_dict(['system', 'login', 'user', user], - get_first_key=True) + info = get_config_dict(['system', 'login', 'user', user], get_first_key=True) if not info: - raise ValueError("No such user") + raise ValueError('No such user') user_info['user'] = user user_info['full_name'] = info.get('full-name', '') return user_info + class Session: """ Wrapper for calling configsession functions based on GraphQL requests. Non-nullable fields in the respective schema allow avoiding a key check in 'data'. """ + def __init__(self, session, data): self._session = session self._data = data diff --git a/src/services/api/rest/models.py b/src/services/api/rest/models.py index 034e3fcdb3..d65d6e1ec7 100644 --- a/src/services/api/rest/models.py +++ b/src/services/api/rest/models.py @@ -31,75 +31,88 @@ def error(code, msg): - resp = {"success": False, "error": msg, "data": None} + resp = {'success': False, 'error': msg, 'data': None} resp = json.dumps(resp) return HTMLResponse(resp, status_code=code) + def success(data): - resp = {"success": True, "data": data, "error": None} + resp = {'success': True, 'data': data, 'error': None} resp = json.dumps(resp) return HTMLResponse(resp) + # Pydantic models for validation # Pydantic will cast when possible, so use StrictStr validators added as # needed for additional constraints # json_schema_extra adds anotations to OpenAPI to add examples + class ApiModel(BaseModel): key: StrictStr + class BasePathModel(BaseModel): op: StrictStr path: List[StrictStr] - @field_validator("path") + @field_validator('path') @classmethod def check_non_empty(cls, path: str) -> str: if not len(path) > 0: raise ValueError('path must be non-empty') return path + class BaseConfigureModel(BasePathModel): value: StrictStr = None + class ConfigureModel(ApiModel, BaseConfigureModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "set | delete | comment", - "path": ["config", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'set | delete | comment', + 'path': ['config', 'mode', 'path'], } } + class ConfigureListModel(ApiModel): commands: List[BaseConfigureModel] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "commands": "list of commands", + 'example': { + 'key': 'id_key', + 'commands': 'list of commands', } } + class BaseConfigSectionModel(BasePathModel): section: Dict + class ConfigSectionModel(ApiModel, BaseConfigSectionModel): pass + class ConfigSectionListModel(ApiModel): commands: List[BaseConfigSectionModel] + class BaseConfigSectionTreeModel(BaseModel): op: StrictStr mask: Dict config: Dict + class ConfigSectionTreeModel(ApiModel, BaseConfigSectionTreeModel): pass + class RetrieveModel(ApiModel): op: StrictStr path: List[StrictStr] @@ -107,34 +120,34 @@ class RetrieveModel(ApiModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "returnValue | returnValues | exists | showConfig", - "path": ["config", "mode", "path"], - "configFormat": "json (default) | json_ast | raw", - + 'example': { + 'key': 'id_key', + 'op': 'returnValue | returnValues | exists | showConfig', + 'path': ['config', 'mode', 'path'], + 'configFormat': 'json (default) | json_ast | raw', } } + class ConfigFileModel(ApiModel): op: StrictStr file: StrictStr = None class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "save | load", - "file": "filename", + 'example': { + 'key': 'id_key', + 'op': 'save | load', + 'file': 'filename', } } class ImageOp(str, Enum): - add = "add" - delete = "delete" - show = "show" - set_default = "set_default" + add = 'add' + delete = 'delete' + show = 'show' + set_default = 'set_default' class ImageModel(ApiModel): @@ -146,23 +159,24 @@ class ImageModel(ApiModel): def check_data(self) -> Self: if self.op == 'add': if not self.url: - raise ValueError("Missing required field \"url\"") + raise ValueError('Missing required field "url"') elif self.op in ['delete', 'set_default']: if not self.name: - raise ValueError("Missing required field \"name\"") + raise ValueError('Missing required field "name"') return self class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show | set_default", - "url": "imagelocation", - "name": "imagename", + 'example': { + 'key': 'id_key', + 'op': 'add | delete | show | set_default', + 'url': 'imagelocation', + 'name': 'imagename', } } + class ImportPkiModel(ApiModel): op: StrictStr path: List[StrictStr] @@ -170,11 +184,11 @@ class ImportPkiModel(ApiModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "import_pki", - "path": ["op", "mode", "path"], - "passphrase": "passphrase", + 'example': { + 'key': 'id_key', + 'op': 'import_pki', + 'path': ['op', 'mode', 'path'], + 'passphrase': 'passphrase', } } @@ -185,75 +199,80 @@ class ContainerImageModel(ApiModel): class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "add | delete | show", - "name": "imagename", + 'example': { + 'key': 'id_key', + 'op': 'add | delete | show', + 'name': 'imagename', } } + class GenerateModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "generate", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'generate', + 'path': ['op', 'mode', 'path'], } } + class ShowModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "show", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'show', + 'path': ['op', 'mode', 'path'], } } + class RebootModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "reboot", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'reboot', + 'path': ['op', 'mode', 'path'], } } + class ResetModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "reset", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'reset', + 'path': ['op', 'mode', 'path'], } } + class PoweroffModel(ApiModel): op: StrictStr path: List[StrictStr] class Config: json_schema_extra = { - "example": { - "key": "id_key", - "op": "poweroff", - "path": ["op", "mode", "path"], + 'example': { + 'key': 'id_key', + 'op': 'poweroff', + 'path': ['op', 'mode', 'path'], } } @@ -263,14 +282,16 @@ class Success(BaseModel): data: Union[str, bool, Dict] error: str + class Error(BaseModel): success: bool = False data: Union[str, bool, Dict] error: str + responses = { 200: {'model': Success}, 400: {'model': Error}, 422: {'model': Error, 'description': 'Validation Error'}, - 500: {'model': Error} + 500: {'model': Error}, } diff --git a/src/services/api/rest/routers.py b/src/services/api/rest/routers.py index 38b10ef7d2..da981d5bff 100644 --- a/src/services/api/rest/routers.py +++ b/src/services/api/rest/routers.py @@ -18,10 +18,12 @@ # pylint: disable=wildcard-import,unused-wildcard-import # pylint: disable=broad-exception-caught +import json import copy import logging import traceback from threading import Lock +from typing import Union from typing import Callable from typing import TYPE_CHECKING @@ -43,8 +45,30 @@ from vyos.configdiff import get_config_diff from vyos.configsession import ConfigSessionError -from .. session import SessionState -from . models import * +from ..session import SessionState +from .models import success +from .models import error +from .models import responses +from .models import ApiModel +from .models import ConfigureModel +from .models import ConfigureListModel +from .models import ConfigSectionModel +from .models import ConfigSectionListModel +from .models import ConfigSectionTreeModel +from .models import BaseConfigSectionTreeModel +from .models import BaseConfigureModel +from .models import BaseConfigSectionModel +from .models import RetrieveModel +from .models import ConfigFileModel +from .models import ImageModel +from .models import ContainerImageModel +from .models import GenerateModel +from .models import ShowModel +from .models import RebootModel +from .models import ResetModel +from .models import ImportPkiModel +from .models import PoweroffModel + if TYPE_CHECKING: from fastapi import FastAPI @@ -69,7 +93,7 @@ def auth_required(data: ApiModel): api_keys = session.keys key_id = check_auth(api_keys, key) if not key_id: - raise HTTPException(status_code=401, detail="Valid API key is required") + raise HTTPException(status_code=401, detail='Valid API key is required') session.id = key_id @@ -79,10 +103,12 @@ def auth_required(data: ApiModel): # validation by FastAPI/Pydantic, as is used for application/json requests class MultipartRequest(Request): """Override Request class to convert form request to json""" + # pylint: disable=attribute-defined-outside-init # pylint: disable=too-many-branches,too-many-statements _form_err = () + @property def form_err(self): return self._form_err @@ -104,16 +130,16 @@ def headers(self): return self._headers async def _get_form( - self, *, max_files: int | float = 1000, max_fields: int | float = 1000 - ) -> FormData: + self, *, max_files: int | float = 1000, max_fields: int | float = 1000 + ) -> FormData: if self._form is None: assert ( parse_options_header is not None - ), "The `python-multipart` library must be installed to use form parsing." - content_type_header = self.orig_headers.get("Content-Type") + ), 'The `python-multipart` library must be installed to use form parsing.' + content_type_header = self.orig_headers.get('Content-Type') content_type: bytes content_type, _ = parse_options_header(content_type_header) - if content_type == b"multipart/form-data": + if content_type == b'multipart/form-data': try: multipart_parser = MultiPartParser( self.orig_headers, @@ -123,10 +149,10 @@ async def _get_form( ) self._form = await multipart_parser.parse() except MultiPartException as exc: - if "app" in self.scope: + if 'app' in self.scope: raise HTTPException(status_code=400, detail=exc.message) raise exc - elif content_type == b"application/x-www-form-urlencoded": + elif content_type == b'application/x-www-form-urlencoded': form_parser = FormParser(self.orig_headers, self.stream()) self._form = await form_parser.parse() else: @@ -134,7 +160,7 @@ async def _get_form( return self._form async def body(self) -> bytes: - if not hasattr(self, "_body"): + if not hasattr(self, '_body'): forms = {} merge = {} body = await super().body() @@ -143,12 +169,12 @@ async def body(self) -> bytes: form_data = await self.form() if form_data: endpoint = self.url.path - LOG.debug("processing form data") + LOG.debug('processing form data') for k, v in form_data.multi_items(): forms[k] = v if 'data' not in forms: - self.form_err = (422, "Non-empty data field is required") + self.form_err = (422, 'Non-empty data field is required') return self._body try: tmp = json.loads(forms['data']) @@ -168,37 +194,57 @@ async def body(self) -> bytes: for c in cmds: if not isinstance(c, dict): - self.form_err = (400, - f"Malformed command '{c}': any command must be JSON of dict") + self.form_err = ( + 400, + f"Malformed command '{c}': any command must be JSON of dict", + ) return self._body if 'op' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'op' field") - if endpoint not in ('/config-file', '/container-image', - '/image', '/configure-section'): + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'op' field", + ) + if endpoint not in ( + '/config-file', + '/container-image', + '/image', + '/configure-section', + ): if 'path' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'path' field") + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'path' field", + ) elif not isinstance(c['path'], list): - self.form_err = (400, - f"Malformed command '{c}': 'path' field must be a list") + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' field must be a list", + ) elif not all(isinstance(el, str) for el in c['path']): - self.form_err = (400, - f"Malformed command '{0}': 'path' field must be a list of strings") + self.form_err = ( + 400, + f"Malformed command '{0}': 'path' field must be a list of strings", + ) if endpoint in ('/configure'): if not c['path']: - self.form_err = (400, - f"Malformed command '{c}': 'path' list must be non-empty") + self.form_err = ( + 400, + f"Malformed command '{c}': 'path' list must be non-empty", + ) if 'value' in c and not isinstance(c['value'], str): - self.form_err = (400, - f"Malformed command '{c}': 'value' field must be a string") + self.form_err = ( + 400, + f"Malformed command '{c}': 'value' field must be a string", + ) if endpoint in ('/configure-section'): if 'section' not in c and 'config' not in c: - self.form_err = (400, - f"Malformed command '{c}': missing 'section' or 'config' field") + self.form_err = ( + 400, + f"Malformed command '{c}': missing 'section' or 'config' field", + ) if 'key' not in forms and 'key' not in merge: - self.form_err = (401, "Valid API key is required") + self.form_err = (401, 'Valid API key is required') if 'key' in forms and 'key' not in merge: merge['key'] = forms['key'] @@ -232,12 +278,15 @@ async def custom_route_handler(request: Request) -> Response: return custom_route_handler -router = APIRouter(route_class=MultipartRoute, - responses={**responses}, - dependencies=[Depends(auth_required)]) +router = APIRouter( + route_class=MultipartRoute, + responses={**responses}, + dependencies=[Depends(auth_required)], +) -self_ref_msg = "Requested HTTP API server configuration change; commit will be called in the background" +self_ref_msg = 'Requested HTTP API server configuration change; commit will be called in the background' + def call_commit(s: SessionState): try: @@ -245,15 +294,22 @@ def call_commit(s: SessionState): except ConfigSessionError as e: s.session.discard() if s.debug: - LOG.warning(f"ConfigSessionError:\n {traceback.format_exc()}") + LOG.warning(f'ConfigSessionError:\n {traceback.format_exc()}') else: - LOG.warning(f"ConfigSessionError: {e}") - - -def _configure_op(data: Union[ConfigureModel, ConfigureListModel, - ConfigSectionModel, ConfigSectionListModel, - ConfigSectionTreeModel], - _request: Request, background_tasks: BackgroundTasks): + LOG.warning(f'ConfigSessionError: {e}') + + +def _configure_op( + data: Union[ + ConfigureModel, + ConfigureListModel, + ConfigSectionModel, + ConfigSectionListModel, + ConfigSectionTreeModel, + ], + _request: Request, + background_tasks: BackgroundTasks, +): # pylint: disable=too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements # pylint: disable=consider-using-with @@ -287,10 +343,10 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, if c.value: value = c.value else: - value = "" + value = '' # For vyos.configsession calls that have no separate value arguments, # and for type checking too - cfg_path = " ".join(path + [value]).strip() + cfg_path = ' '.join(path + [value]).strip() elif isinstance(c, BaseConfigSectionModel): section = c.section @@ -304,7 +360,9 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, session.set(path, value=value) elif op == 'delete': if state.strict and not config.exists(cfg_path): - raise ConfigSessionError(f"Cannot delete [{cfg_path}]: path/value does not exist") + raise ConfigSessionError( + f'Cannot delete [{cfg_path}]: path/value does not exist' + ) session.delete(path, value=value) elif op == 'comment': session.comment(path, value=value) @@ -343,7 +401,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, session.discard() status = 400 if state.debug: - LOG.critical(f"ConfigSessionError:\n {traceback.format_exc()}") + LOG.critical(f'ConfigSessionError:\n {traceback.format_exc()}') error_msg = str(e) except Exception: session.discard() @@ -351,7 +409,7 @@ def _configure_op(data: Union[ConfigureModel, ConfigureListModel, status = 500 # Don't give the details away to the outer world - error_msg = "An internal error occured. Check the logs for details." + error_msg = 'An internal error occured. Check the logs for details.' finally: lock.release() @@ -371,21 +429,24 @@ def create_path_import_pki_no_prompt(path): @router.post('/configure') -def configure_op(data: Union[ConfigureModel, - ConfigureListModel], - request: Request, background_tasks: BackgroundTasks): +def configure_op( + data: Union[ConfigureModel, ConfigureListModel], + request: Request, + background_tasks: BackgroundTasks, +): return _configure_op(data, request, background_tasks) @router.post('/configure-section') -def configure_section_op(data: Union[ConfigSectionModel, - ConfigSectionListModel, - ConfigSectionTreeModel], - request: Request, background_tasks: BackgroundTasks): +def configure_section_op( + data: Union[ConfigSectionModel, ConfigSectionListModel, ConfigSectionTreeModel], + request: Request, + background_tasks: BackgroundTasks, +): return _configure_op(data, request, background_tasks) -@router.post("/retrieve") +@router.post('/retrieve') async def retrieve_op(data: RetrieveModel): state = SessionState() session = state.session @@ -393,7 +454,7 @@ async def retrieve_op(data: RetrieveModel): config = Config(session_env=env) op = data.op - path = " ".join(data.path) + path = ' '.join(data.path) try: if op == 'returnValue': @@ -424,7 +485,7 @@ async def retrieve_op(data: RetrieveModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -448,7 +509,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): if data.file: path = data.file else: - return error(400, "Missing required field \"file\"") + return error(400, 'Missing required field "file"') session.migrate_and_load_config(path) @@ -466,7 +527,7 @@ def config_file_op(data: ConfigFileModel, background_tasks: BackgroundTasks): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(msg) @@ -484,14 +545,14 @@ def image_op(data: ImageModel): elif op == 'delete': res = session.remove_image(data.name) elif op == 'show': - res = session.show(["system", "image"]) + res = session.show(['system', 'image']) elif op == 'set_default': res = session.set_default_image(data.name) except ConfigSessionError as e: return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -508,13 +569,13 @@ def container_image_op(data: ContainerImageModel): if data.name: name = data.name else: - return error(400, "Missing required field \"name\"") + return error(400, 'Missing required field "name"') res = session.add_container_image(name) elif op == 'delete': if data.name: name = data.name else: - return error(400, "Missing required field \"name\"") + return error(400, 'Missing required field "name"') res = session.delete_container_image(name) elif op == 'show': res = session.show_container_image() @@ -524,7 +585,7 @@ def container_image_op(data: ContainerImageModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -546,7 +607,7 @@ def generate_op(data: GenerateModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -568,7 +629,7 @@ def show_op(data: ShowModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -590,7 +651,7 @@ def reboot_op(data: RebootModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -612,7 +673,7 @@ def reset_op(data: ResetModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) @@ -652,7 +713,7 @@ def import_pki(data: ImportPkiModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') finally: lock.release() @@ -676,18 +737,18 @@ def poweroff_op(data: PoweroffModel): return error(400, str(e)) except Exception: LOG.critical(traceback.format_exc()) - return error(500, "An internal error occured. Check the logs for details.") + return error(500, 'An internal error occured. Check the logs for details.') return success(res) -def rest_init(app: "FastAPI"): +def rest_init(app: 'FastAPI'): if all(r in app.routes for r in router.routes): return app.include_router(router) -def rest_clear(app: "FastAPI"): +def rest_clear(app: 'FastAPI'): for r in router.routes: if r in app.routes: app.routes.remove(r) diff --git a/src/services/api/session.py b/src/services/api/session.py index dcdc7246c5..ad3ef660c3 100644 --- a/src/services/api/session.py +++ b/src/services/api/session.py @@ -13,6 +13,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . + class SessionState: # pylint: disable=attribute-defined-outside-init # pylint: disable=too-many-instance-attributes,too-few-public-methods From c21fa1fb77264c0a92653b064824ac3bce5086ce Mon Sep 17 00:00:00 2001 From: John Estabrook Date: Mon, 30 Sep 2024 21:51:56 -0500 Subject: [PATCH 9/9] http-api: T6736: sanitize error message containing user input --- src/services/api/rest/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/services/api/rest/models.py b/src/services/api/rest/models.py index d65d6e1ec7..23ae9be9d9 100644 --- a/src/services/api/rest/models.py +++ b/src/services/api/rest/models.py @@ -17,6 +17,7 @@ # pylint: disable=too-few-public-methods import json +from html import escape from enum import Enum from typing import List from typing import Union @@ -31,6 +32,7 @@ def error(code, msg): + msg = escape(msg, quote=False) resp = {'success': False, 'error': msg, 'data': None} resp = json.dumps(resp) return HTMLResponse(resp, status_code=code)