Skip to content

Commit

Permalink
Merge branch 'main' into chore/bump-version-to-0.15.0
Browse files Browse the repository at this point in the history
  • Loading branch information
laipz8200 committed Jan 7, 2025
2 parents 27aace0 + dc650c5 commit 7769fdb
Show file tree
Hide file tree
Showing 147 changed files with 2,748 additions and 1,897 deletions.
3 changes: 3 additions & 0 deletions api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

# Refresh token expiration time in days
REFRESH_TOKEN_EXPIRE_DAYS=30

# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1

Expand Down
5 changes: 4 additions & 1 deletion api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ def is_db_command():

app = create_migrations_app()
else:
if os.environ.get("FLASK_DEBUG", "False") != "True":
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey # type: ignore

# gevent
Expand Down
10 changes: 10 additions & 0 deletions api/configs/feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,11 @@ class AuthConfig(BaseSettings):
default=60,
)

REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field(
description="Expiration time for refresh tokens in days",
default=30,
)

LOGIN_LOCKOUT_DURATION: PositiveInt = Field(
description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.",
default=86400,
Expand Down Expand Up @@ -667,6 +672,11 @@ class IndexingConfig(BaseSettings):
default=4000,
)

CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field(
description="Maximum number of child chunks to preview",
default=50,
)


class MultiModalTransferConfig(BaseSettings):
MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field(
Expand Down
3 changes: 2 additions & 1 deletion api/controllers/console/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ def uuid_list(value):
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)

args = parser.parse_args()

# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}

Expand Down
248 changes: 140 additions & 108 deletions api/core/app/apps/advanced_chat/generate_task_pipeline.py

Large diffs are not rendered by default.

162 changes: 89 additions & 73 deletions api/core/app/apps/workflow/generate_task_pipeline.py

Large diffs are not rendered by default.

11 changes: 0 additions & 11 deletions api/core/app/task_pipeline/based_generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from core.app.entities.task_entities import (
ErrorStreamResponse,
PingStreamResponse,
TaskState,
)
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
Expand All @@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""

_task_state: TaskState
_application_generate_entity: AppGenerateEntity

def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._start_at = time.perf_counter()
Expand Down
17 changes: 13 additions & 4 deletions api/core/app/task_pipeline/message_cycle_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,19 @@


class MessageCycleManage:
_application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
def __init__(
self,
*,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
self._application_generate_entity = application_generate_entity
self._task_state = task_state

def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""
Expand Down
58 changes: 35 additions & 23 deletions api/core/app/task_pipeline/workflow_cycle_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
Expand All @@ -58,13 +57,20 @@
WorkflowRunStatus,
)

from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
from .exc import WorkflowRunNotFoundError


class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables

def _handle_workflow_run_start(
self,
Expand Down Expand Up @@ -240,24 +246,26 @@ def _handle_workflow_run_failed(
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count

stmt = select(WorkflowNodeExecution).where(
stmt = select(WorkflowNodeExecution.node_execution_id).where(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)

running_workflow_node_executions = session.scalars(stmt).all()
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated
running_workflow_node_executions = [
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
]

for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds()
workflow_node_execution.finished_at = now
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()

if trace_manager:
trace_manager.add_trace_task(
Expand Down Expand Up @@ -299,6 +307,8 @@ def _handle_node_execution_start(
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)

session.add(workflow_node_execution)

self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution

def _handle_workflow_node_execution_success(
Expand Down Expand Up @@ -326,6 +336,7 @@ def _handle_workflow_node_execution_success(
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time

workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution

def _handle_workflow_node_execution_failed(
Expand Down Expand Up @@ -365,6 +376,7 @@ def _handle_workflow_node_execution_failed(
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata

workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution

def _handle_workflow_node_execution_retried(
Expand Down Expand Up @@ -416,6 +428,8 @@ def _handle_workflow_node_execution_retried(
workflow_node_execution.index = event.node_run_index

session.add(workflow_node_execution)

self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution

#################################################
Expand Down Expand Up @@ -812,22 +826,20 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any
return None

def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
if self._workflow_run and self._workflow_run.id == workflow_run_id:
cached_workflow_run = self._workflow_run
cached_workflow_run = session.merge(cached_workflow_run)
return cached_workflow_run
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run

return workflow_run

def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
workflow_node_execution = session.scalar(stmt)
if not workflow_node_execution:
raise WorkflowNodeExecutionNotFoundError(node_execution_id)

return workflow_node_execution
if node_execution_id not in self._workflow_node_executions:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return cached_workflow_node_execution
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from concurrent.futures import ProcessPoolExecutor
from os.path import abspath, dirname, join
from threading import Lock
from typing import Any, cast

import gevent.threadpool # type: ignore
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore

_tokenizer: Any = None
_lock = Lock()
_pool = gevent.threadpool.ThreadPool(1)
_executor = ProcessPoolExecutor(max_workers=1)


class GPT2Tokenizer:
Expand All @@ -22,8 +22,8 @@ def _get_num_tokens_by_gpt2(text: str) -> int:

@staticmethod
def get_num_tokens(text: str) -> int:
future = _pool.spawn(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
result = future.get(block=True)
future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
result = future.result()
return cast(int, result)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,5 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
super().validate_credentials(model, credentials)

@staticmethod
def _add_custom_parameters(credentials: dict, model: Optional[str]) -> None:
if model is None:
model = "bge-m3"

credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model}/v1/"
def _add_custom_parameters(credentials: dict, model: str) -> None:
credentials["endpoint_url"] = "https://ai.gitee.com/v1"
18 changes: 18 additions & 0 deletions api/core/model_runtime/model_providers/gpustack/gpustack.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
Expand Down Expand Up @@ -118,3 +120,19 @@ model_credential_schema:
label:
en_US: Not Support
zh_Hans: 不支持
- variable: voices
show_on:
- variable: __model_type
value: tts
label:
en_US: Available Voices (comma-separated)
zh_Hans: 可用声音(用英文逗号分隔)
type: text-input
required: false
default: "Chinese Female"
placeholder:
en_US: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
zh_Hans: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female"
help:
en_US: "List voice names separated by commas. First voice will be used as default."
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"
16 changes: 10 additions & 6 deletions api/core/model_runtime/model_providers/gpustack/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from collections.abc import Generator

from yarl import URL

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
Expand All @@ -24,9 +22,10 @@ def _invoke(
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(
model,
credentials,
compatible_credentials,
prompt_messages,
model_parameters,
tools,
Expand All @@ -36,10 +35,15 @@ def _invoke(
)

def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)

def _get_compatible_credentials(self, credentials: dict) -> dict:
credentials = credentials.copy()
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
credentials["endpoint_url"] = f"{base_url}/v1-openai"
return credentials

@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
credentials["mode"] = "chat"
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import IO, Optional

from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel


class GPUStackSpeech2TextModel(OAICompatSpeech2TextModel):
"""
Model class for GPUStack Speech to text model.
"""

def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
compatible_credentials = self._get_compatible_credentials(credentials)
return super()._invoke(model, compatible_credentials, file)

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
"""
compatible_credentials = self._get_compatible_credentials(credentials)
super().validate_credentials(model, compatible_credentials)

def _get_compatible_credentials(self, credentials: dict) -> dict:
"""
Get compatible credentials
:param credentials: model credentials
:return: compatible credentials
"""
compatible_credentials = credentials.copy()
base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai")
compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai"
return compatible_credentials
Loading

0 comments on commit 7769fdb

Please sign in to comment.