Skip to content

Commit

Permalink
refactor: enhance API token validation with session locking and last …
Browse files Browse the repository at this point in the history
…used timestamp update (#12426)

Signed-off-by: -LAN- <[email protected]>
  • Loading branch information
laipz8200 authored Jan 7, 2025
1 parent 41f39bf commit 0eeacdc
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 23 deletions.
37 changes: 22 additions & 15 deletions api/controllers/service_api/wraps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Optional
Expand All @@ -8,6 +8,8 @@
from flask_login import user_logged_in # type: ignore
from flask_restful import Resource # type: ignore
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized

from extensions.ext_database import db
Expand Down Expand Up @@ -174,7 +176,7 @@ def decorated(*args, **kwargs):
return decorator


def validate_and_get_api_token(scope=None):
def validate_and_get_api_token(scope: str | None = None):
"""
Validate and get API token.
"""
Expand All @@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")

api_token = (
db.session.query(ApiToken)
.filter(
ApiToken.token == auth_token,
ApiToken.type == scope,
current_time = datetime.now(UTC).replace(tzinfo=None)
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
.values(last_used_at=current_time)
.returning(ApiToken)
)
.first()
)

if not api_token:
raise Unauthorized("Access token is invalid")

api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
result = session.execute(update_stmt)
api_token = result.scalar_one_or_none()

if not api_token:
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
api_token = session.scalar(stmt)
if not api_token:
raise Unauthorized("Access token is invalid")
else:
session.commit()

return api_token

Expand Down
1 change: 1 addition & 0 deletions api/docker/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ else
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
fi
Expand Down
8 changes: 4 additions & 4 deletions api/services/billing_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Literal, Optional

import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
Expand All @@ -17,7 +17,6 @@ def get_info(cls, tenant_id: str):
params = {"tenant_id": tenant_id}

billing_info = cls._send_request("GET", "/subscription/info", params=params)

return billing_info

@classmethod
Expand Down Expand Up @@ -47,12 +46,13 @@ def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""):
retry=retry_if_exception_type(httpx.RequestError),
reraise=True,
)
def _send_request(cls, method, endpoint, json=None, params=None):
def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None):
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}

url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)

if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
return response.json()

@staticmethod
Expand Down
7 changes: 5 additions & 2 deletions docker/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,13 @@ DIFY_PORT=5001
# The number of API server workers, i.e., the number of workers.
# Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent
# Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers
SERVER_WORKER_AMOUNT=
SERVER_WORKER_AMOUNT=1

# Defaults to gevent. If using windows, it can be switched to sync or solo.
SERVER_WORKER_CLASS=
SERVER_WORKER_CLASS=gevent

# Default number of worker connections, the default is 10.
SERVER_WORKER_CONNECTIONS=10

# Similar to SERVER_WORKER_CLASS.
# If using windows, it can be switched to sync or solo.
Expand Down
5 changes: 3 additions & 2 deletions docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ x-shared-env: &shared-api-worker-env
APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
DIFY_PORT: ${DIFY_PORT:-5001}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-}
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1}
SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent}
SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10}
CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-}
GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360}
CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-}
Expand Down

0 comments on commit 0eeacdc

Please sign in to comment.