From 38057b1b0ed4398970dc34c78d4d67dec02b84c9 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 9 Sep 2025 17:48:33 +0900 Subject: [PATCH 01/86] add typing to all wraps (#25405) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/wraps.py | 11 +++++--- api/controllers/inner_api/plugin/wraps.py | 23 +++++++++------- api/controllers/inner_api/wraps.py | 4 +-- .../service_api/workspace/models.py | 2 +- api/controllers/service_api/wraps.py | 15 ++++++----- api/controllers/web/wraps.py | 10 +++---- .../vdb/matrixone/matrixone_vector.py | 27 ++++++++++--------- .../enterprise/plugin_manager_service.py | 15 +++++++---- 8 files changed, 61 insertions(+), 46 deletions(-) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index c7e300279a..5a871f896a 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, Union +from typing import Optional, ParamSpec, TypeVar, Union from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db @@ -8,6 +8,9 @@ from libs.login import current_user from models import App, AppMode from models.account import Account +P = ParamSpec("P") +R = TypeVar("R") + def _load_app_model(app_id: str) -> Optional[App]: assert isinstance(current_user, Account) @@ -19,10 +22,10 @@ def _load_app_model(app_id: str) -> Optional[App]: return app_model -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): - def decorator(view_func): +def get_app_model(view: Optional[Callable[P, R]] = None, *, mode: Union[AppMode, list[AppMode], None] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): if not kwargs.get("app_id"): raise ValueError("missing app_id in path parameters") diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index f751e06ddf..68711f7257 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional +from typing import Optional, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -14,6 +14,9 @@ from libs.login import _get_user from models.account import Tenant from models.model import EndUser +P = ParamSpec("P") +R = TypeVar("R") + def get_user(tenant_id: str, user_id: str | None) -> EndUser: """ @@ -52,19 +55,19 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: return user_model -def get_user_tenant(view: Optional[Callable] = None): - def decorator(view_func): +def get_user_tenant(view: Optional[Callable[P, R]] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): # fetch json body parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") parser.add_argument("user_id", type=str, required=True, location="json") - kwargs = parser.parse_args() + p = parser.parse_args() - user_id = kwargs.get("user_id") - tenant_id = kwargs.get("tenant_id") + user_id: Optional[str] = p.get("user_id") + tenant_id: str = p.get("tenant_id") if not tenant_id: raise ValueError("tenant_id is required") @@ -107,9 +110,9 @@ def get_user_tenant(view: Optional[Callable] = None): return decorator(view) -def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]): - def decorator(view_func): - def decorated_view(*args, **kwargs): +def plugin_data(view: Optional[Callable[P, R]] = None, *, payload_type: type[BaseModel]): + def decorator(view_func: Callable[P, R]): + def decorated_view(*args: P.args, **kwargs: P.kwargs): try: data = request.get_json() except Exception: diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index de4f1da801..4bdcc6832a 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -46,9 +46,9 @@ def enterprise_inner_api_only(view: Callable[P, R]): return decorated -def enterprise_inner_api_user_auth(view): +def enterprise_inner_api_user_auth(view: Callable[P, R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): if not dify_config.INNER_API: return view(*args, **kwargs) diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 536cf81a2f..fffcb47bd4 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -19,7 +19,7 @@ class ModelProviderAvailableModelApi(Resource): } ) @validate_dataset_token - def get(self, _, model_type): + def get(self, _, model_type: str): """Get available models by model type. Returns a list of available models for the specified model type. diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 14291578d5..4394e64ad9 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -3,7 +3,7 @@ from collections.abc import Callable from datetime import timedelta from enum import StrEnum, auto from functools import wraps -from typing import Optional, ParamSpec, TypeVar +from typing import Concatenate, Optional, ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in @@ -25,6 +25,7 @@ from services.feature_service import FeatureService P = ParamSpec("P") R = TypeVar("R") +T = TypeVar("T") class WhereisUserArg(StrEnum): @@ -42,10 +43,10 @@ class FetchUserArg(BaseModel): required: bool = False -def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): - def decorator(view_func): +def validate_app_token(view: Optional[Callable[P, R]] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): + def decorator(view_func: Callable[P, R]): @wraps(view_func) - def decorated_view(*args, **kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token("app") app_model = db.session.query(App).where(App.id == api_token.app_id).first() @@ -189,10 +190,10 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): return interceptor -def validate_dataset_token(view=None): - def decorator(view): +def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None): + def decorator(view: Callable[Concatenate[T, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): api_token = validate_and_get_api_token("dataset") tenant_account_join = ( db.session.query(Tenant, TenantAccountJoin) diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 1fbb2c165f..e79456535a 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,6 +1,7 @@ +from collections.abc import Callable from datetime import UTC, datetime from functools import wraps -from typing import ParamSpec, TypeVar +from typing import Concatenate, Optional, ParamSpec, TypeVar from flask import request from flask_restx import Resource @@ -20,12 +21,11 @@ P = ParamSpec("P") R = TypeVar("R") -def validate_jwt_token(view=None): - def decorator(view): +def validate_jwt_token(view: Optional[Callable[Concatenate[App, EndUser, P], R]] = None): + def decorator(view: Callable[Concatenate[App, EndUser, P], R]): @wraps(view) - def decorated(*args, **kwargs): + def decorated(*args: P.args, **kwargs: P.kwargs): app_model, end_user = decode_jwt_token() - return view(app_model, end_user, *args, **kwargs) return decorated diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py index 7da830f643..3dd073ce50 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -1,8 +1,9 @@ import json import logging import uuid +from collections.abc import Callable from functools import wraps -from typing import Any, Optional +from typing import Any, Concatenate, Optional, ParamSpec, TypeVar from mo_vector.client import MoVectorClient # type: ignore from pydantic import BaseModel, model_validator @@ -17,7 +18,6 @@ from extensions.ext_redis import redis_client from models.dataset import Dataset logger = logging.getLogger(__name__) -from typing import ParamSpec, TypeVar P = ParamSpec("P") R = TypeVar("R") @@ -47,16 +47,6 @@ class MatrixoneConfig(BaseModel): return values -def ensure_client(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if self.client is None: - self.client = self._get_client(None, False) - return func(self, *args, **kwargs) - - return wrapper - - class MatrixoneVector(BaseVector): """ Matrixone vector storage implementation. @@ -216,6 +206,19 @@ class MatrixoneVector(BaseVector): self.client.delete() +T = TypeVar("T", bound=MatrixoneVector) + + +def ensure_client(func: Callable[Concatenate[T, P], R]): + @wraps(func) + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs): + if self.client is None: + self.client = self._get_client(None, False) + return func(self, *args, **kwargs) + + return wrapper + + class MatrixoneVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector: if dataset.index_struct_dict: diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py index cfcc39416a..ee8a932ded 100644 --- a/api/services/enterprise/plugin_manager_service.py +++ b/api/services/enterprise/plugin_manager_service.py @@ -6,10 +6,12 @@ from pydantic import BaseModel from services.enterprise.base import EnterprisePluginManagerRequest from services.errors.base import BaseServiceError +logger = logging.getLogger(__name__) -class PluginCredentialType(enum.Enum): - MODEL = 0 - TOOL = 1 + +class PluginCredentialType(enum.IntEnum): + MODEL = enum.auto() + TOOL = enum.auto() def to_number(self): return self.value @@ -47,6 +49,9 @@ class PluginManagerService: if not ret.get("result", False): raise CredentialPolicyViolationError("Credentials not available: Please use ENTERPRISE global credentials") - logging.debug( - f"Credential policy compliance checked for {body.provider} with credential {body.dify_credential_id}, result: {ret.get('result', False)}" + logger.debug( + "Credential policy compliance checked for %s with credential %s, result: %s", + body.provider, + body.dify_credential_id, + ret.get("result", False), ) From 22cd97e2e0563ee0269d64e46ed267b47722e556 Mon Sep 17 00:00:00 2001 From: KVOJJJin Date: Tue, 9 Sep 2025 16:49:22 +0800 Subject: [PATCH 02/86] Fix: judgement of open in explore (#25420) --- web/app/components/apps/app-card.tsx | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index d0d42dc32c..e9a64d8867 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -279,12 +279,21 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { )} { - (isGettingUserCanAccessApp || !userCanAccessApp?.result) ? null : <> - - - + (!systemFeatures.webapp_auth.enabled) + ? <> + + + + : !(isGettingUserCanAccessApp || !userCanAccessApp?.result) && ( + <> + + + + ) } { From e5122945fe1fdb4584ac367729182da38530dadb Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 9 Sep 2025 17:00:00 +0800 Subject: [PATCH 03/86] Fix: Use --fix flag instead of --fix-only in autofix workflow (#25425) --- .github/workflows/autofix.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 82ba95444f..be6ce80dfc 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -20,7 +20,7 @@ jobs: cd api uv sync --dev # Fix lint errors - uv run ruff check --fix-only . + uv run ruff check --fix . # Format code uv run ruff format . - name: ast-grep From a1cf48f84e7af7c792f456d1caec8b2be271868a Mon Sep 17 00:00:00 2001 From: GuanMu Date: Tue, 9 Sep 2025 17:11:49 +0800 Subject: [PATCH 04/86] Add lib test (#25410) --- api/tests/unit_tests/libs/test_file_utils.py | 55 ++++++++++++ .../unit_tests/libs/test_json_in_md_parser.py | 88 +++++++++++++++++++ api/tests/unit_tests/libs/test_orjson.py | 25 ++++++ 3 files changed, 168 insertions(+) create mode 100644 api/tests/unit_tests/libs/test_file_utils.py create mode 100644 api/tests/unit_tests/libs/test_json_in_md_parser.py create mode 100644 api/tests/unit_tests/libs/test_orjson.py diff --git a/api/tests/unit_tests/libs/test_file_utils.py b/api/tests/unit_tests/libs/test_file_utils.py new file mode 100644 index 0000000000..8d9b4e803a --- /dev/null +++ b/api/tests/unit_tests/libs/test_file_utils.py @@ -0,0 +1,55 @@ +from pathlib import Path + +import pytest + +from libs.file_utils import search_file_upwards + + +def test_search_file_upwards_found_in_parent(tmp_path: Path): + base = tmp_path / "a" / "b" / "c" + base.mkdir(parents=True) + + target = tmp_path / "a" / "target.txt" + target.write_text("ok", encoding="utf-8") + + found = search_file_upwards(base, "target.txt", max_search_parent_depth=5) + assert found == target + + +def test_search_file_upwards_found_in_current(tmp_path: Path): + base = tmp_path / "x" + base.mkdir() + target = base / "here.txt" + target.write_text("x", encoding="utf-8") + + found = search_file_upwards(base, "here.txt", max_search_parent_depth=1) + assert found == target + + +def test_search_file_upwards_not_found_raises(tmp_path: Path): + base = tmp_path / "m" / "n" + base.mkdir(parents=True) + with pytest.raises(ValueError) as exc: + search_file_upwards(base, "missing.txt", max_search_parent_depth=3) + # error message should contain file name and base path + msg = str(exc.value) + assert "missing.txt" in msg + assert str(base) in msg + + +def test_search_file_upwards_root_breaks_and_raises(): + # Using filesystem root triggers the 'break' branch (parent == current) + with pytest.raises(ValueError): + search_file_upwards(Path("/"), "__definitely_not_exists__.txt", max_search_parent_depth=1) + + +def test_search_file_upwards_depth_limit_raises(tmp_path: Path): + base = tmp_path / "a" / "b" / "c" + base.mkdir(parents=True) + target = tmp_path / "a" / "target.txt" + target.write_text("ok", encoding="utf-8") + # The file is 2 levels up from `c` (in `a`), but search depth is only 2. + # The search path is `c` (depth 1) -> `b` (depth 2). The file is in `a` (would need depth 3). + # So, this should not find the file and should raise an error. + with pytest.raises(ValueError): + search_file_upwards(base, "target.txt", max_search_parent_depth=2) diff --git a/api/tests/unit_tests/libs/test_json_in_md_parser.py b/api/tests/unit_tests/libs/test_json_in_md_parser.py new file mode 100644 index 0000000000..53fd0bea16 --- /dev/null +++ b/api/tests/unit_tests/libs/test_json_in_md_parser.py @@ -0,0 +1,88 @@ +import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from libs.json_in_md_parser import ( + parse_and_check_json_markdown, + parse_json_markdown, +) + + +def test_parse_json_markdown_triple_backticks_json(): + src = """ + ```json + {"a": 1, "b": "x"} + ``` + """ + assert parse_json_markdown(src) == {"a": 1, "b": "x"} + + +def test_parse_json_markdown_triple_backticks_generic(): + src = """ + ``` + {"k": [1, 2, 3]} + ``` + """ + assert parse_json_markdown(src) == {"k": [1, 2, 3]} + + +def test_parse_json_markdown_single_backticks(): + src = '`{"x": true}`' + assert parse_json_markdown(src) == {"x": True} + + +def test_parse_json_markdown_braces_only(): + src = ' {\n \t"ok": "yes"\n} ' + assert parse_json_markdown(src) == {"ok": "yes"} + + +def test_parse_json_markdown_not_found(): + with pytest.raises(ValueError): + parse_json_markdown("no json here") + + +def test_parse_and_check_json_markdown_missing_key(): + src = """ + ``` + {"present": 1} + ``` + """ + with pytest.raises(OutputParserError) as exc: + parse_and_check_json_markdown(src, ["present", "missing"]) + assert "expected key `missing`" in str(exc.value) + + +def test_parse_and_check_json_markdown_invalid_json(): + src = """ + ```json + {invalid json} + ``` + """ + with pytest.raises(OutputParserError) as exc: + parse_and_check_json_markdown(src, []) + assert "got invalid json object" in str(exc.value) + + +def test_parse_and_check_json_markdown_success(): + src = """ + ```json + {"present": 1, "other": 2} + ``` + """ + obj = parse_and_check_json_markdown(src, ["present"]) + assert obj == {"present": 1, "other": 2} + + +def test_parse_and_check_json_markdown_multiple_blocks_fails(): + src = """ + ```json + {"a": 1} + ``` + Some text + ```json + {"b": 2} + ``` + """ + # The current implementation is greedy and will match from the first + # opening fence to the last closing fence, causing JSON decode failure. + with pytest.raises(OutputParserError): + parse_and_check_json_markdown(src, []) diff --git a/api/tests/unit_tests/libs/test_orjson.py b/api/tests/unit_tests/libs/test_orjson.py new file mode 100644 index 0000000000..6df1d077df --- /dev/null +++ b/api/tests/unit_tests/libs/test_orjson.py @@ -0,0 +1,25 @@ +import orjson +import pytest + +from libs.orjson import orjson_dumps + + +def test_orjson_dumps_round_trip_basic(): + obj = {"a": 1, "b": [1, 2, 3], "c": {"d": True}} + s = orjson_dumps(obj) + assert orjson.loads(s) == obj + + +def test_orjson_dumps_with_unicode_and_indent(): + obj = {"msg": "你好,Dify"} + s = orjson_dumps(obj, option=orjson.OPT_INDENT_2) + # contains indentation newline/spaces + assert "\n" in s + assert orjson.loads(s) == obj + + +def test_orjson_dumps_non_utf8_encoding_fails(): + obj = {"msg": "你好"} + # orjson.dumps() always produces UTF-8 bytes; decoding with non-UTF8 fails. + with pytest.raises(UnicodeDecodeError): + orjson_dumps(obj, encoding="ascii") From 7443c5a6fcb7af3e8d7b723a29d0ceeb00cef242 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 9 Sep 2025 17:12:45 +0800 Subject: [PATCH 05/86] refactor: update pyrightconfig to scan all API files (#25429) --- api/pyrightconfig.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index a3a5f2044e..352161523f 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -1,5 +1,5 @@ { - "include": ["models", "configs"], + "include": ["."], "exclude": [".venv", "tests/", "migrations/"], "ignore": [ "core/", From 240b65b980cbc3d679d348ee852c9e7246b4979e Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 9 Sep 2025 20:06:35 +0800 Subject: [PATCH 06/86] fix(mcp): properly handle arrays containing both numbers and strings (#25430) Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/core/tools/mcp_tool/tool.py | 50 +++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 6810ac683d..21d256ae03 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -67,22 +67,42 @@ class MCPTool(Tool): for content in result.content: if isinstance(content, TextContent): - try: - content_json = json.loads(content.text) - if isinstance(content_json, dict): - yield self.create_json_message(content_json) - elif isinstance(content_json, list): - for item in content_json: - yield self.create_json_message(item) - else: - yield self.create_text_message(content.text) - except json.JSONDecodeError: - yield self.create_text_message(content.text) - + yield from self._process_text_content(content) elif isinstance(content, ImageContent): - yield self.create_blob_message( - blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType} - ) + yield self._process_image_content(content) + + def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]: + """Process text content and yield appropriate messages.""" + try: + content_json = json.loads(content.text) + yield from self._process_json_content(content_json) + except json.JSONDecodeError: + yield self.create_text_message(content.text) + + def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]: + """Process JSON content based on its type.""" + if isinstance(content_json, dict): + yield self.create_json_message(content_json) + elif isinstance(content_json, list): + yield from self._process_json_list(content_json) + else: + # For primitive types (str, int, bool, etc.), convert to string + yield self.create_text_message(str(content_json)) + + def _process_json_list(self, json_list: list) -> Generator[ToolInvokeMessage, None, None]: + """Process a list of JSON items.""" + if any(not isinstance(item, dict) for item in json_list): + # If the list contains any non-dict item, treat the entire list as a text message. + yield self.create_text_message(str(json_list)) + return + + # Otherwise, process each dictionary as a separate JSON message. + for item in json_list: + yield self.create_json_message(item) + + def _process_image_content(self, content: ImageContent) -> ToolInvokeMessage: + """Process image content and return a blob message.""" + return self.create_blob_message(blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}) def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool": return MCPTool( From 2ac7a9c8fc586c0895ec329cca005a41e6700922 Mon Sep 17 00:00:00 2001 From: Yongtao Huang Date: Tue, 9 Sep 2025 20:07:17 +0800 Subject: [PATCH 07/86] Chore: thanks to bump-pydantic (#25437) --- api/core/app/entities/app_invoke_entities.py | 2 +- api/core/app/entities/queue_entities.py | 4 ++-- api/core/app/entities/task_entities.py | 4 ++-- api/core/entities/provider_entities.py | 2 +- api/core/mcp/types.py | 2 +- api/core/ops/entities/trace_entity.py | 10 +++++----- api/core/plugin/entities/plugin_daemon.py | 2 +- .../datasource/vdb/huawei/huawei_cloud_vector.py | 4 ++-- .../rag/datasource/vdb/tencent/tencent_vector.py | 6 +++--- api/core/variables/segments.py | 2 +- api/core/workflow/nodes/base/entities.py | 2 +- .../nodes/variable_assigner/common/helpers.py | 2 +- api/services/app_dsl_service.py | 14 +++++++------- 13 files changed, 28 insertions(+), 28 deletions(-) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 72b62eb67c..9151137fe8 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -95,7 +95,7 @@ class AppGenerateEntity(BaseModel): task_id: str # app config - app_config: Any + app_config: Any = None file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index db0297c352..fc04e60836 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -432,8 +432,8 @@ class QueueAgentLogEvent(AppQueueEvent): id: str label: str node_execution_id: str - parent_id: str | None - error: str | None + parent_id: str | None = None + error: str | None = None status: str data: Mapping[str, Any] metadata: Optional[Mapping[str, Any]] = None diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index a1c0368354..29f3e3427e 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -828,8 +828,8 @@ class AgentLogStreamResponse(StreamResponse): node_execution_id: str id: str label: str - parent_id: str | None - error: str | None + parent_id: str | None = None + error: str | None = None status: str data: Mapping[str, Any] metadata: Optional[Mapping[str, Any]] = None diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 9b8baf1973..52acbc1eef 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -107,7 +107,7 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType - credentials: dict | None + credentials: dict | None = None current_credential_id: Optional[str] = None current_credential_name: Optional[str] = None available_model_credentials: list[CredentialConfiguration] = [] diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 49aa8e4498..a2c3157b3b 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -809,7 +809,7 @@ class LoggingMessageNotificationParams(NotificationParams): """The severity of this log message.""" logger: str | None = None """An optional name of the logger issuing this message.""" - data: Any + data: Any = None """ The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here. diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 3bad5c92fb..1870da3781 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -35,7 +35,7 @@ class BaseTraceInfo(BaseModel): class WorkflowTraceInfo(BaseTraceInfo): - workflow_data: Any + workflow_data: Any = None conversation_id: Optional[str] = None workflow_app_log_id: Optional[str] = None workflow_id: str @@ -89,7 +89,7 @@ class SuggestedQuestionTraceInfo(BaseTraceInfo): class DatasetRetrievalTraceInfo(BaseTraceInfo): - documents: Any + documents: Any = None class ToolTraceInfo(BaseTraceInfo): @@ -97,12 +97,12 @@ class ToolTraceInfo(BaseTraceInfo): tool_inputs: dict[str, Any] tool_outputs: str metadata: dict[str, Any] - message_file_data: Any + message_file_data: Any = None error: Optional[str] = None tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] + file_url: Union[str, None, list] = None class GenerateNameTraceInfo(BaseTraceInfo): @@ -113,7 +113,7 @@ class GenerateNameTraceInfo(BaseTraceInfo): class TaskData(BaseModel): app_id: str trace_info_type: str - trace_info: Any + trace_info: Any = None trace_info_info_map = { diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 16ab661092..f1d6860bb4 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -24,7 +24,7 @@ class PluginDaemonBasicResponse(BaseModel, Generic[T]): code: int message: str - data: Optional[T] + data: Optional[T] = None class InstallPluginMessage(BaseModel): diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index 107ea75e6a..0eca37a129 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -28,8 +28,8 @@ def create_ssl_context() -> ssl.SSLContext: class HuaweiCloudVectorConfig(BaseModel): hosts: str - username: str | None - password: str | None + username: str | None = None + password: str | None = None @model_validator(mode="before") @classmethod diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 4af34bbb2d..2485857070 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -24,10 +24,10 @@ logger = logging.getLogger(__name__) class TencentConfig(BaseModel): url: str - api_key: Optional[str] + api_key: Optional[str] = None timeout: float = 30 - username: Optional[str] - database: Optional[str] + username: Optional[str] = None + database: Optional[str] = None index_type: str = "HNSW" metric_type: str = "IP" shard: int = 1 diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index cfef193633..7da43a6504 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -19,7 +19,7 @@ class Segment(BaseModel): model_config = ConfigDict(frozen=True) value_type: SegmentType - value: Any + value: Any = None @field_validator("value_type") @classmethod diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 708da21177..90e45e9d25 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -23,7 +23,7 @@ NumberType = Union[int, float] class DefaultValue(BaseModel): - value: Any + value: Any = None type: DefaultValueType key: str diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py index 8caee27363..04a7323739 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -16,7 +16,7 @@ class UpdatedVariable(BaseModel): name: str selector: Sequence[str] value_type: SegmentType - new_value: Any + new_value: Any = None _T = TypeVar("_T", bound=MutableMapping[str, Any]) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 2ed73ffec1..49ff28d191 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -99,17 +99,17 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus: class PendingData(BaseModel): import_mode: str yaml_content: str - name: str | None - description: str | None - icon_type: str | None - icon: str | None - icon_background: str | None - app_id: str | None + name: str | None = None + description: str | None = None + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + app_id: str | None = None class CheckDependenciesPendingData(BaseModel): dependencies: list[PluginDependency] - app_id: str | None + app_id: str | None = None class AppDslService: From 08dd3f7b5079fe9171351ea79054302c915e42d1 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 10 Sep 2025 01:54:26 +0800 Subject: [PATCH 08/86] Fix basedpyright type errors (#25435) Signed-off-by: -LAN- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/commands.py | 18 +++- api/constants/__init__.py | 12 +-- api/contexts/__init__.py | 1 - api/controllers/console/__init__.py | 100 ++++++++++-------- api/controllers/console/apikey.py | 13 +-- api/controllers/console/app/app.py | 30 ++++-- api/controllers/console/app/audio.py | 4 +- api/controllers/console/app/completion.py | 28 ++--- api/controllers/console/app/conversation.py | 6 +- api/controllers/console/app/message.py | 13 ++- api/controllers/console/app/site.py | 6 +- api/controllers/console/app/statistic.py | 12 +-- .../console/app/workflow_statistic.py | 6 +- api/controllers/console/auth/oauth.py | 5 +- api/controllers/console/explore/completion.py | 11 +- .../console/explore/conversation.py | 13 ++- .../console/explore/installed_app.py | 13 ++- api/controllers/console/explore/message.py | 11 +- .../console/explore/recommended_app.py | 8 +- .../console/explore/saved_message.py | 9 +- api/controllers/console/files.py | 3 + api/controllers/console/version.py | 6 +- api/controllers/console/workspace/account.py | 32 ++++++ api/controllers/console/workspace/members.py | 59 +++++++++-- .../console/workspace/model_providers.py | 37 +++++++ .../console/workspace/workspace.py | 24 ++++- api/controllers/files/__init__.py | 2 +- api/controllers/inner_api/__init__.py | 6 +- api/controllers/inner_api/plugin/plugin.py | 30 +++--- api/controllers/inner_api/plugin/wraps.py | 10 +- api/controllers/mcp/__init__.py | 2 +- api/controllers/service_api/__init__.py | 26 ++++- .../service_api/app/conversation.py | 3 +- .../service_api/dataset/document.py | 6 ++ api/controllers/service_api/wraps.py | 4 +- api/controllers/web/__init__.py | 28 ++--- api/core/__init__.py | 1 - api/core/agent/cot_agent_runner.py | 2 + api/core/agent/fc_agent_runner.py | 1 + .../sensitive_word_avoidance/manager.py | 11 +- .../prompt_template/manager.py | 10 +- .../generate_response_converter.py | 12 +-- .../advanced_chat/generate_task_pipeline.py | 24 ++--- .../app/apps/agent_chat/app_config_manager.py | 34 +++--- .../agent_chat/generate_response_converter.py | 11 +- api/core/app/apps/base_app_queue_manager.py | 1 + .../apps/chat/generate_response_converter.py | 11 +- api/core/app/apps/completion/app_generator.py | 2 + .../completion/generate_response_converter.py | 13 ++- .../workflow/generate_response_converter.py | 10 +- .../apps/workflow/generate_task_pipeline.py | 10 +- api/core/app/entities/app_invoke_entities.py | 6 +- api/core/app/entities/task_entities.py | 7 -- .../annotation_reply/annotation_reply.py | 3 + .../app/features/rate_limiting/__init__.py | 2 + .../app/features/rate_limiting/rate_limit.py | 2 +- .../based_generate_task_pipeline.py | 22 ++-- .../easy_ui_based_generate_task_pipeline.py | 22 ++-- .../base/tts/app_generator_tts_publisher.py | 6 +- api/core/entities/provider_configuration.py | 8 +- api/core/file/file_manager.py | 6 +- api/core/file/models.py | 8 ++ api/core/helper/ssrf_proxy.py | 14 +-- api/core/indexing_runner.py | 7 +- api/core/llm_generator/llm_generator.py | 12 ++- .../output_parser/structured_output.py | 14 ++- api/core/mcp/client/sse_client.py | 8 +- api/core/mcp/server/streamable_http.py | 28 ++--- api/core/mcp/session/base_session.py | 12 +-- .../__base/large_language_model.py | 2 +- api/core/plugin/entities/parameters.py | 5 +- api/core/plugin/utils/chunk_merger.py | 4 +- api/core/prompt/simple_prompt_transform.py | 32 ++++-- .../datasource/vdb/qdrant/qdrant_vector.py | 35 ++++-- ...lery_workflow_node_execution_repository.py | 4 +- api/core/variables/segment_group.py | 2 +- api/core/variables/segments.py | 24 ++--- api/core/workflow/errors.py | 4 +- api/core/workflow/nodes/list_operator/node.py | 4 +- api/core/workflow/nodes/llm/node.py | 3 +- api/factories/file_factory.py | 4 +- api/fields/_value_type_serializer.py | 5 +- api/libs/external_api.py | 14 ++- api/libs/helper.py | 7 -- api/pyrightconfig.json | 54 +++++++--- api/services/account_service.py | 4 +- api/services/annotation_service.py | 54 ++++++---- .../clear_free_plan_tenant_expired_logs.py | 1 + api/services/dataset_service.py | 66 ++---------- api/services/external_knowledge_service.py | 2 +- api/services/file_service.py | 4 +- api/services/model_load_balancing_service.py | 17 +-- api/services/plugin/plugin_migration.py | 1 + .../tools/builtin_tools_manage_service.py | 10 +- api/services/workflow/workflow_converter.py | 16 ++- api/services/workflow_service.py | 4 +- api/services/workspace_service.py | 2 +- .../services/test_account_service.py | 4 +- .../workflow/test_workflow_converter.py | 3 +- .../services/test_account_service.py | 16 +-- 100 files changed, 847 insertions(+), 497 deletions(-) diff --git a/api/commands.py b/api/commands.py index 9b13cc2e1a..2bef83b2a7 100644 --- a/api/commands.py +++ b/api/commands.py @@ -511,7 +511,7 @@ def add_qdrant_index(field: str): from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import PayloadSchemaType - from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig + from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig for binding in bindings: if dify_config.QDRANT_URL is None: @@ -525,7 +525,21 @@ def add_qdrant_index(field: str): prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, ) try: - client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params()) + params = qdrant_config.to_qdrant_params() + # Check the type before using + if isinstance(params, PathQdrantParams): + # PathQdrantParams case + client = qdrant_client.QdrantClient(path=params.path) + else: + # UrlQdrantParams case - params is UrlQdrantParams + client = qdrant_client.QdrantClient( + url=params.url, + api_key=params.api_key, + timeout=int(params.timeout), + verify=params.verify, + grpc_port=params.grpc_port, + prefer_grpc=params.prefer_grpc, + ) # create payload index client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) create_count += 1 diff --git a/api/constants/__init__.py b/api/constants/__init__.py index c98f4d55c8..fe8f4f8785 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -16,14 +16,14 @@ AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"] AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) +_doc_extensions: list[str] if dify_config.ETL_TYPE == "Unstructured": - DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] - DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) + _doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"] + _doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub")) if dify_config.UNSTRUCTURED_API_URL: - DOCUMENT_EXTENSIONS.append("ppt") - DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) + _doc_extensions.append("ppt") else: - DOCUMENT_EXTENSIONS = [ + _doc_extensions = [ "txt", "markdown", "md", @@ -38,4 +38,4 @@ else: "vtt", "properties", ] - DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) +DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions] diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index ae41a2c03a..a07e6a08a6 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: from core.model_runtime.entities.model_entities import AIModelEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController - from core.workflow.entities.variable_pool import VariablePool """ diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 5ad7645969..9a8e840554 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -43,56 +43,64 @@ api.add_resource(AppImportConfirmApi, "/apps/imports//confirm" api.add_resource(AppImportCheckDependenciesApi, "/apps/imports//check-dependencies") # Import other controllers -from . import admin, apikey, extension, feature, ping, setup, version +from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport] # Import app controllers from .app import ( - advanced_prompt_template, - agent, - annotation, - app, - audio, - completion, - conversation, - conversation_variables, - generator, - mcp_server, - message, - model_config, - ops_trace, - site, - statistic, - workflow, - workflow_app_log, - workflow_draft_variable, - workflow_run, - workflow_statistic, + advanced_prompt_template, # pyright: ignore[reportUnusedImport] + agent, # pyright: ignore[reportUnusedImport] + annotation, # pyright: ignore[reportUnusedImport] + app, # pyright: ignore[reportUnusedImport] + audio, # pyright: ignore[reportUnusedImport] + completion, # pyright: ignore[reportUnusedImport] + conversation, # pyright: ignore[reportUnusedImport] + conversation_variables, # pyright: ignore[reportUnusedImport] + generator, # pyright: ignore[reportUnusedImport] + mcp_server, # pyright: ignore[reportUnusedImport] + message, # pyright: ignore[reportUnusedImport] + model_config, # pyright: ignore[reportUnusedImport] + ops_trace, # pyright: ignore[reportUnusedImport] + site, # pyright: ignore[reportUnusedImport] + statistic, # pyright: ignore[reportUnusedImport] + workflow, # pyright: ignore[reportUnusedImport] + workflow_app_log, # pyright: ignore[reportUnusedImport] + workflow_draft_variable, # pyright: ignore[reportUnusedImport] + workflow_run, # pyright: ignore[reportUnusedImport] + workflow_statistic, # pyright: ignore[reportUnusedImport] ) # Import auth controllers -from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server +from .auth import ( + activate, # pyright: ignore[reportUnusedImport] + data_source_bearer_auth, # pyright: ignore[reportUnusedImport] + data_source_oauth, # pyright: ignore[reportUnusedImport] + forgot_password, # pyright: ignore[reportUnusedImport] + login, # pyright: ignore[reportUnusedImport] + oauth, # pyright: ignore[reportUnusedImport] + oauth_server, # pyright: ignore[reportUnusedImport] +) # Import billing controllers -from .billing import billing, compliance +from .billing import billing, compliance # pyright: ignore[reportUnusedImport] # Import datasets controllers from .datasets import ( - data_source, - datasets, - datasets_document, - datasets_segments, - external, - hit_testing, - metadata, - website, + data_source, # pyright: ignore[reportUnusedImport] + datasets, # pyright: ignore[reportUnusedImport] + datasets_document, # pyright: ignore[reportUnusedImport] + datasets_segments, # pyright: ignore[reportUnusedImport] + external, # pyright: ignore[reportUnusedImport] + hit_testing, # pyright: ignore[reportUnusedImport] + metadata, # pyright: ignore[reportUnusedImport] + website, # pyright: ignore[reportUnusedImport] ) # Import explore controllers from .explore import ( - installed_app, - parameter, - recommended_app, - saved_message, + installed_app, # pyright: ignore[reportUnusedImport] + parameter, # pyright: ignore[reportUnusedImport] + recommended_app, # pyright: ignore[reportUnusedImport] + saved_message, # pyright: ignore[reportUnusedImport] ) # Explore Audio @@ -167,18 +175,18 @@ api.add_resource( ) # Import tag controllers -from .tag import tags +from .tag import tags # pyright: ignore[reportUnusedImport] # Import workspace controllers from .workspace import ( - account, - agent_providers, - endpoint, - load_balancing_config, - members, - model_providers, - models, - plugin, - tool_providers, - workspace, + account, # pyright: ignore[reportUnusedImport] + agent_providers, # pyright: ignore[reportUnusedImport] + endpoint, # pyright: ignore[reportUnusedImport] + load_balancing_config, # pyright: ignore[reportUnusedImport] + members, # pyright: ignore[reportUnusedImport] + model_providers, # pyright: ignore[reportUnusedImport] + models, # pyright: ignore[reportUnusedImport] + plugin, # pyright: ignore[reportUnusedImport] + tool_providers, # pyright: ignore[reportUnusedImport] + workspace, # pyright: ignore[reportUnusedImport] ) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index cfd5f73ade..58a1d437d1 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,8 +1,9 @@ -from typing import Any, Optional +from typing import Optional import flask_restx from flask_login import current_user from flask_restx import Resource, fields, marshal_with +from flask_restx._http import HTTPStatus from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -40,7 +41,7 @@ def _get_resource(resource_id, tenant_id, resource_model): ).scalar_one_or_none() if resource is None: - flask_restx.abort(404, message=f"{resource_model.__name__} not found.") + flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.") return resource @@ -49,7 +50,7 @@ class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[Any] = None + resource_model: Optional[type] = None resource_id_field: str | None = None token_prefix: str | None = None max_keys = 10 @@ -82,7 +83,7 @@ class BaseApiKeyListResource(Resource): if current_key_count >= self.max_keys: flask_restx.abort( - 400, + HTTPStatus.BAD_REQUEST, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", custom="max_keys_exceeded", ) @@ -102,7 +103,7 @@ class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] resource_type: str | None = None - resource_model: Optional[Any] = None + resource_model: Optional[type] = None resource_id_field: str | None = None def delete(self, resource_id, api_key_id): @@ -126,7 +127,7 @@ class BaseApiKeyResource(Resource): ) if key is None: - flask_restx.abort(404, message="API key not found") + flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found") db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.commit() diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 10753d2f95..1db9d2e764 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -115,6 +115,10 @@ class AppListApi(Resource): raise BadRequest("mode is required") app_service = AppService() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + if current_user.current_tenant_id is None: + raise ValueError("current_user.current_tenant_id cannot be None") app = app_service.create_app(current_user.current_tenant_id, args, current_user) return app, 201 @@ -161,14 +165,26 @@ class AppApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app(app_model, args) + # Construct ArgsDict from parsed arguments + from services.app_service import AppService as AppServiceType + + args_dict: AppServiceType.ArgsDict = { + "name": args["name"], + "description": args.get("description", ""), + "icon_type": args.get("icon_type", ""), + "icon": args.get("icon", ""), + "icon_background": args.get("icon_background", ""), + "use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False), + "max_active_requests": args.get("max_active_requests", 0), + } + app_model = app_service.update_app(app_model, args_dict) return app_model + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def delete(self, app_model): """Delete app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -224,10 +240,10 @@ class AppCopyApi(Resource): class AppExportApi(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): """Export app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -263,7 +279,7 @@ class AppNameApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_name(app_model, args.get("name")) + app_model = app_service.update_app_name(app_model, args["name"]) return app_model @@ -285,7 +301,7 @@ class AppIconApi(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) + app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "") return app_model @@ -306,7 +322,7 @@ class AppSiteStatus(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) + app_model = app_service.update_app_site_status(app_model, args["enable_site"]) return app_model @@ -327,7 +343,7 @@ class AppApiStatus(Resource): args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) + app_model = app_service.update_app_api_status(app_model, args["enable_api"]) return app_model diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index aaf5c3dfaa..447bcb37c2 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -77,10 +77,10 @@ class ChatMessageAudioApi(Resource): class ChatMessageTextApi(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def post(self, app_model: App): try: parser = reqparse.RequestParser() @@ -125,10 +125,10 @@ class ChatMessageTextApi(Resource): class TextModesApi(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): try: parser = reqparse.RequestParser() diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 701ebb0b4a..2083c15a9b 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,6 +1,5 @@ import logging -import flask_login from flask import request from flask_restx import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -29,7 +28,8 @@ from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value -from libs.login import login_required +from libs.login import current_user, login_required +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -56,11 +56,11 @@ class CompletionMessageApi(Resource): streaming = args["response_mode"] != "blocking" args["auto_generate_name"] = False - account = flask_login.current_user - try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account or EndUser instance") response = AppGenerateService.generate( - app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -92,9 +92,9 @@ class CompletionMessageStopApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model, task_id): - account = flask_login.current_user - - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return {"result": "success"}, 200 @@ -123,11 +123,11 @@ class ChatMessageApi(Resource): if external_trace_id: args["external_trace_id"] = external_trace_id - account = flask_login.current_user - try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account or EndUser instance") response = AppGenerateService.generate( - app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -161,9 +161,9 @@ class ChatMessageStopApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model, task_id): - account = flask_login.current_user - - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return {"result": "success"}, 200 diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index bc825effad..2f2cd66aaa 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -22,7 +22,7 @@ from fields.conversation_fields import ( from libs.datetime_utils import naive_utc_now from libs.helper import DatetimeString from libs.login import login_required -from models import Conversation, EndUser, Message, MessageAnnotation +from models import Account, Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError @@ -124,6 +124,8 @@ class CompletionConversationDetailApi(Resource): conversation_id = str(conversation_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -282,6 +284,8 @@ class ChatConversationDetailApi(Resource): conversation_id = str(conversation_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index f0605a37f9..272f360c06 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy import exists, select @@ -27,7 +26,8 @@ from extensions.ext_database import db from fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination -from libs.login import login_required +from libs.login import current_user, login_required +from models.account import Account from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.annotation_service import AppAnnotationService from services.errors.conversation import ConversationNotExistsError @@ -118,11 +118,14 @@ class ChatMessageListApi(Resource): class MessageFeedbackApi(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def post(self, app_model): + if current_user is None: + raise Forbidden() + parser = reqparse.RequestParser() parser.add_argument("message_id", required=True, type=uuid_value, location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") @@ -167,6 +170,8 @@ class MessageAnnotationApi(Resource): @get_app_model @marshal_with(annotation_fields) def post(self, app_model): + if not isinstance(current_user, Account): + raise Forbidden() if not current_user.is_editor: raise Forbidden() @@ -182,10 +187,10 @@ class MessageAnnotationApi(Resource): class MessageAnnotationCountApi(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 778ce92da6..871efd989c 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -10,7 +10,7 @@ from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.datetime_utils import naive_utc_now from libs.login import login_required -from models import Site +from models import Account, Site def parse_app_site_args(): @@ -75,6 +75,8 @@ class AppSite(Resource): if value is not None: setattr(site, attr_name, value) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") site.updated_by = current_user.id site.updated_at = naive_utc_now() db.session.commit() @@ -99,6 +101,8 @@ class AppSiteAccessTokenReset(Resource): raise NotFound site.code = Site.generate_code(16) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") site.updated_by = current_user.id site.updated_at = naive_utc_now() db.session.commit() diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 27e405af38..2116732c73 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -18,10 +18,10 @@ from models import AppMode, Message class DailyMessageStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -75,10 +75,10 @@ WHERE class DailyConversationStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -127,10 +127,10 @@ class DailyConversationStatistic(Resource): class DailyTerminalsStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -184,10 +184,10 @@ WHERE class DailyTokenCostStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -320,10 +320,10 @@ ORDER BY class UserSatisfactionRateStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -443,10 +443,10 @@ WHERE class TokensPerSecondStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 7cef175c14..da7216086e 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -18,10 +18,10 @@ from models.model import AppMode class WorkflowDailyRunsStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -80,10 +80,10 @@ WHERE class WorkflowDailyTerminalsStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user @@ -142,10 +142,10 @@ WHERE class WorkflowDailyTokenCostStatistic(Resource): + @get_app_model @setup_required @login_required @account_initialization_required - @get_app_model def get(self, app_model): account = current_user diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 332a98c474..06151ee39b 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -77,6 +77,9 @@ class OAuthCallback(Resource): if state: invite_token = state + if not code: + return {"error": "Authorization code is required"}, 400 + try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) @@ -86,7 +89,7 @@ class OAuthCallback(Resource): return {"error": "OAuth process failed"}, 400 if invite_token and RegisterService.is_valid_invite_token(invite_token): - invitation = RegisterService._get_invitation_by_token(token=invite_token) + invitation = RegisterService.get_invitation_by_token(token=invite_token) if invitation: invitation_email = invitation.get("email", None) if invitation_email != user_info.email: diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index cc46f54ea3..a99708b7cd 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound @@ -28,6 +27,8 @@ from extensions.ext_database import db from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -57,6 +58,8 @@ class CompletionApi(InstalledAppResource): db.session.commit() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) @@ -90,6 +93,8 @@ class CompletionStopApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 @@ -117,6 +122,8 @@ class ChatApi(InstalledAppResource): db.session.commit() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) @@ -153,6 +160,8 @@ class ChatStopApi(InstalledAppResource): if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 43ad3ecfbd..1aef9c544d 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import marshal_with, reqparse from flask_restx.inputs import int_range from sqlalchemy.orm import Session @@ -10,6 +9,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError @@ -35,6 +36,8 @@ class ConversationListApi(InstalledAppResource): pinned = args["pinned"] == "true" try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") with Session(db.engine) as session: return WebConversationService.pagination_by_last_id( session=session, @@ -58,6 +61,8 @@ class ConversationApi(InstalledAppResource): conversation_id = str(c_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -81,6 +86,8 @@ class ConversationRenameApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return ConversationService.rename( app_model, conversation_id, current_user, args["name"], args["auto_generate"] ) @@ -98,6 +105,8 @@ class ConversationPinApi(InstalledAppResource): conversation_id = str(c_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") WebConversationService.pin(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -113,6 +122,8 @@ class ConversationUnPinApi(InstalledAppResource): raise NotChatAppError() conversation_id = str(c_id) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") WebConversationService.unpin(app_model, conversation_id, current_user) return {"result": "success"} diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3ccedd654b..22aa753d92 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,7 +2,6 @@ import logging from typing import Any from flask import request -from flask_login import current_user from flask_restx import Resource, inputs, marshal_with, reqparse from sqlalchemy import and_ from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -13,8 +12,8 @@ from controllers.console.wraps import account_initialization_required, cloud_edi from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields from libs.datetime_utils import naive_utc_now -from libs.login import login_required -from models import App, InstalledApp, RecommendedApp +from libs.login import current_user, login_required +from models import Account, App, InstalledApp, RecommendedApp from services.account_service import TenantService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService @@ -29,6 +28,8 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): app_id = request.args.get("app_id", default=None, type=str) + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") current_tenant_id = current_user.current_tenant_id if app_id: @@ -40,6 +41,8 @@ class InstalledAppsListApi(Resource): else: installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() + if current_user.current_tenant is None: + raise ValueError("current_user.current_tenant must not be None") current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_app_list: list[dict[str, Any]] = [ { @@ -115,6 +118,8 @@ class InstalledAppsListApi(Resource): if recommended_app is None: raise NotFound("App not found") + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") current_tenant_id = current_user.current_tenant_id app = db.session.query(App).where(App.id == args["app_id"]).first() @@ -154,6 +159,8 @@ class InstalledAppApi(InstalledAppResource): """ def delete(self, installed_app): + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") if installed_app.app_owner_tenant_id == current_user.current_tenant_id: raise BadRequest("You can't uninstall an app owned by the current tenant") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 608bc6d007..c46c1c1f4f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,6 +1,5 @@ import logging -from flask_login import current_user from flask_restx import marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound @@ -24,6 +23,8 @@ from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields from libs import helper from libs.helper import uuid_value +from libs.login import current_user +from models import Account from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -54,6 +55,8 @@ class MessageListApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return MessageService.pagination_by_first_id( app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] ) @@ -75,6 +78,8 @@ class MessageFeedbackApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") MessageService.create_feedback( app_model=app_model, message_id=message_id, @@ -105,6 +110,8 @@ class MessageMoreLikeThisApi(InstalledAppResource): streaming = args["response_mode"] == "streaming" try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate_more_like_this( app_model=app_model, user=current_user, @@ -142,6 +149,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource): message_id = str(message_id) try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 62f9350b71..974222ddf7 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,11 +1,10 @@ -from flask_login import current_user from flask_restx import Resource, fields, marshal_with, reqparse from constants.languages import languages from controllers.console import api from controllers.console.wraps import account_initialization_required from libs.helper import AppIconUrlField -from libs.login import login_required +from libs.login import current_user, login_required from services.recommended_app_service import RecommendedAppService app_fields = { @@ -46,8 +45,9 @@ class RecommendedAppListApi(Resource): parser.add_argument("language", type=str, location="args") args = parser.parse_args() - if args.get("language") and args.get("language") in languages: - language_prefix = args.get("language") + language = args.get("language") + if language and language in languages: + language_prefix = language elif current_user and current_user.interface_language: language_prefix = current_user.interface_language else: diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 5353dbcad5..6f05f898f9 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,4 +1,3 @@ -from flask_login import current_user from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound @@ -8,6 +7,8 @@ from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields from libs.helper import TimestampField, uuid_value +from libs.login import current_user +from models import Account from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -42,6 +43,8 @@ class SavedMessageListApi(InstalledAppResource): parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, installed_app): @@ -54,6 +57,8 @@ class SavedMessageListApi(InstalledAppResource): args = parser.parse_args() try: + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -70,6 +75,8 @@ class SavedMessageApi(InstalledAppResource): if app_model.mode != "completion": raise NotCompletionAppError() + if not isinstance(current_user, Account): + raise ValueError("current_user must be an Account instance") SavedMessageService.delete(app_model, current_user, message_id) return {"result": "success"}, 204 diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 101a49a32e..5d11dec523 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -22,6 +22,7 @@ from controllers.console.wraps import ( ) from fields.file_fields import file_fields, upload_config_fields from libs.login import login_required +from models import Account from services.file_service import FileService PREVIEW_WORDS_LIMIT = 3000 @@ -68,6 +69,8 @@ class FileApi(Resource): source = None try: + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") upload_file = FileService.upload_file( filename=file.filename, content=file.read(), diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 95515c38f9..8409e7d1ab 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -34,14 +34,14 @@ class VersionApi(Resource): return result try: - response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10)) + response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10)) except Exception as error: logger.warning("Check update version error: %s.", str(error)) - result["version"] = args.get("current_version") + result["version"] = args["current_version"] return result content = json.loads(response.content) - if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"): + if _has_new_version(latest_version=content["version"], current_version=f"{args['current_version']}"): result["version"] = content["version"] result["release_date"] = content["releaseDate"] result["release_notes"] = content["releaseNotes"] diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 5b2828dbab..bd078729c4 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -49,6 +49,8 @@ class AccountInitApi(Resource): @setup_required @login_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user if account.status == "active": @@ -102,6 +104,8 @@ class AccountProfileApi(Resource): @marshal_with(account_fields) @enterprise_license_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") return current_user @@ -111,6 +115,8 @@ class AccountNameApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() @@ -130,6 +136,8 @@ class AccountAvatarApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() @@ -145,6 +153,8 @@ class AccountInterfaceLanguageApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() @@ -160,6 +170,8 @@ class AccountInterfaceThemeApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() @@ -175,6 +187,8 @@ class AccountTimezoneApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() @@ -194,6 +208,8 @@ class AccountPasswordApi(Resource): @account_initialization_required @marshal_with(account_fields) def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("password", type=str, required=False, location="json") parser.add_argument("new_password", type=str, required=True, location="json") @@ -228,6 +244,8 @@ class AccountIntegrateApi(Resource): @account_initialization_required @marshal_with(integrate_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() @@ -268,6 +286,8 @@ class AccountDeleteVerifyApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user token, code = AccountService.generate_account_deletion_verification_code(account) @@ -281,6 +301,8 @@ class AccountDeleteApi(Resource): @login_required @account_initialization_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user parser = reqparse.RequestParser() @@ -321,6 +343,8 @@ class EducationVerifyApi(Resource): @cloud_edition_billing_enabled @marshal_with(verify_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user return BillingService.EducationIdentity.verify(account.id, account.email) @@ -340,6 +364,8 @@ class EducationApi(Resource): @only_edition_cloud @cloud_edition_billing_enabled def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user parser = reqparse.RequestParser() @@ -357,6 +383,8 @@ class EducationApi(Resource): @cloud_edition_billing_enabled @marshal_with(status_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") account = current_user res = BillingService.EducationIdentity.status(account.id) @@ -421,6 +449,8 @@ class ChangeEmailSendEmailApi(Resource): raise InvalidTokenError() user_email = reset_data.get("email", "") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if user_email != current_user.email: raise InvalidEmailError() else: @@ -501,6 +531,8 @@ class ChangeEmailResetApi(Resource): AccountService.revoke_change_email_token(args["token"]) old_email = reset_data.get("old_email", "") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if current_user.email != old_email: raise AccountNotFound() diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index cf2a10f453..77f0c9a735 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,8 +1,8 @@ from urllib import parse -from flask import request +from flask import abort, request from flask_login import current_user -from flask_restx import Resource, abort, marshal_with, reqparse +from flask_restx import Resource, marshal_with, reqparse import services from configs import dify_config @@ -41,6 +41,10 @@ class MemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) return {"result": "success", "accounts": members}, 200 @@ -65,7 +69,11 @@ class MemberInviteEmailApi(Resource): if not TenantAccountRole.is_non_owner_role(invitee_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") inviter = current_user + if not inviter.current_tenant: + raise ValueError("No current tenant") invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL @@ -76,6 +84,8 @@ class MemberInviteEmailApi(Resource): for invitee_email in invitee_emails: try: + if not inviter.current_tenant: + raise ValueError("No current tenant") token = RegisterService.invite_new_member( inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter ) @@ -97,7 +107,7 @@ class MemberInviteEmailApi(Resource): return { "result": "success", "invitation_results": invitation_results, - "tenant_id": str(current_user.current_tenant.id), + "tenant_id": str(inviter.current_tenant.id) if inviter.current_tenant else "", }, 201 @@ -108,6 +118,10 @@ class MemberCancelInviteApi(Resource): @login_required @account_initialization_required def delete(self, member_id): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") member = db.session.query(Account).where(Account.id == str(member_id)).first() if member is None: abort(404) @@ -123,7 +137,10 @@ class MemberCancelInviteApi(Resource): except Exception as e: raise ValueError(str(e)) - return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200 + return { + "result": "success", + "tenant_id": str(current_user.current_tenant.id) if current_user.current_tenant else "", + }, 200 class MemberUpdateRoleApi(Resource): @@ -141,6 +158,10 @@ class MemberUpdateRoleApi(Resource): if not TenantAccountRole.is_valid_role(new_role): return {"code": "invalid-role", "message": "Invalid role"}, 400 + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") member = db.session.get(Account, str(member_id)) if not member: abort(404) @@ -164,6 +185,10 @@ class DatasetOperatorMemberListApi(Resource): @account_initialization_required @marshal_with(account_with_role_list_fields) def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) return {"result": "success", "accounts": members}, 200 @@ -184,6 +209,10 @@ class SendOwnerTransferEmailApi(Resource): raise EmailSendIpLimitError() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -198,7 +227,7 @@ class SendOwnerTransferEmailApi(Resource): account=current_user, email=email, language=language, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", ) return {"result": "success", "data": token} @@ -215,6 +244,10 @@ class OwnerTransferCheckApi(Resource): parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -256,6 +289,10 @@ class OwnerTransfer(Resource): args = parser.parse_args() # check if the current user is the owner of the workspace + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant: + raise ValueError("No current tenant") if not TenantService.is_owner(current_user, current_user.current_tenant): raise NotOwnerError() @@ -274,9 +311,11 @@ class OwnerTransfer(Resource): member = db.session.get(Account, str(member_id)) if not member: abort(404) - else: - member_account = member - if not TenantService.is_member(member_account, current_user.current_tenant): + return # Never reached, but helps type checker + + if not current_user.current_tenant: + raise ValueError("No current tenant") + if not TenantService.is_member(member, current_user.current_tenant): raise MemberNotInTenantError() try: @@ -286,13 +325,13 @@ class OwnerTransfer(Resource): AccountService.send_new_owner_transfer_notify_email( account=member, email=member.email, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", ) AccountService.send_old_owner_transfer_notify_email( account=current_user, email=current_user.email, - workspace_name=current_user.current_tenant.name, + workspace_name=current_user.current_tenant.name if current_user.current_tenant else "", new_owner_email=member.email, ) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index bfcc9a7f0a..0c9db660aa 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import StrLen, uuid_value from libs.login import login_required +from models.account import Account from services.billing_service import BillingService from services.model_provider_service import ModelProviderService @@ -21,6 +22,10 @@ class ModelProviderListApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() @@ -45,6 +50,10 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def get(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id # if credential_id is not provided, return current used credential parser = reqparse.RequestParser() @@ -62,6 +71,8 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() @@ -72,6 +83,8 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") try: model_provider_service.create_provider_credential( tenant_id=current_user.current_tenant_id, @@ -88,6 +101,8 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def put(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() @@ -99,6 +114,8 @@ class ModelProviderCredentialApi(Resource): model_provider_service = ModelProviderService() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") try: model_provider_service.update_provider_credential( tenant_id=current_user.current_tenant_id, @@ -116,12 +133,16 @@ class ModelProviderCredentialApi(Resource): @login_required @account_initialization_required def delete(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") model_provider_service = ModelProviderService() model_provider_service.remove_provider_credential( tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] @@ -135,12 +156,16 @@ class ModelProviderCredentialSwitchApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") service = ModelProviderService() service.switch_active_provider_credential( tenant_id=current_user.current_tenant_id, @@ -155,10 +180,14 @@ class ModelProviderValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() @@ -205,9 +234,13 @@ class PreferredProviderTypeUpdateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") if not current_user.is_admin_or_owner: raise Forbidden() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() @@ -236,7 +269,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): def get(self, provider: str): if provider != "anthropic": raise ValueError(f"provider name {provider} is invalid") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") BillingService.is_tenant_owner_or_admin(current_user) + if not current_user.current_tenant_id: + raise ValueError("No current tenant") data = BillingService.get_model_provider_payment_link( provider_name=provider, tenant_id=current_user.current_tenant_id, diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index e7a3aca66c..655afbe73f 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -25,7 +25,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required -from models.account import Tenant, TenantStatus +from models.account import Account, Tenant, TenantStatus from services.account_service import TenantService from services.feature_service import FeatureService from services.file_service import FileService @@ -70,6 +70,8 @@ class TenantListApi(Resource): @login_required @account_initialization_required def get(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") tenants = TenantService.get_join_tenants(current_user) tenant_dicts = [] @@ -83,7 +85,7 @@ class TenantListApi(Resource): "status": tenant.status, "created_at": tenant.created_at, "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", - "current": tenant.id == current_user.current_tenant_id, + "current": tenant.id == current_user.current_tenant_id if current_user.current_tenant_id else False, } tenant_dicts.append(tenant_dict) @@ -125,7 +127,11 @@ class TenantApi(Resource): if request.path == "/info": logger.warning("Deprecated URL /info was used.") + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") tenant = current_user.current_tenant + if not tenant: + raise ValueError("No current tenant") if tenant.status == TenantStatus.ARCHIVE: tenants = TenantService.get_join_tenants(current_user) @@ -137,6 +143,8 @@ class TenantApi(Resource): else: raise Unauthorized("workspace is archived") + if not tenant: + raise ValueError("No tenant available") return WorkspaceService.get_tenant_info(tenant), 200 @@ -145,6 +153,8 @@ class SwitchWorkspaceApi(Resource): @login_required @account_initialization_required def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("tenant_id", type=str, required=True, location="json") args = parser.parse_args() @@ -168,11 +178,15 @@ class CustomConfigWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("remove_webapp_brand", type=bool, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant = db.get_or_404(Tenant, current_user.current_tenant_id) custom_config_dict = { @@ -194,6 +208,8 @@ class WebappLogoWorkspaceApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") # check file if "file" not in request.files: raise NoFileUploadedError() @@ -232,10 +248,14 @@ class WorkspaceInfoApi(Resource): @account_initialization_required # Change workspace name def post(self): + if not isinstance(current_user, Account): + raise ValueError("Invalid user account") parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() + if not current_user.current_tenant_id: + raise ValueError("No current tenant") tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant.name = args["name"] db.session.commit() diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 821ad220a2..a1b8bb7cfe 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -15,6 +15,6 @@ api = ExternalApi( files_ns = Namespace("files", description="File operations", path="/") -from . import image_preview, tool_files, upload +from . import image_preview, tool_files, upload # pyright: ignore[reportUnusedImport] api.add_namespace(files_ns) diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index d29a7be139..b09c39309f 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -16,8 +16,8 @@ api = ExternalApi( # Create namespace inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") -from . import mail -from .plugin import plugin -from .workspace import workspace +from . import mail as _mail # pyright: ignore[reportUnusedImport] +from .plugin import plugin as _plugin # pyright: ignore[reportUnusedImport] +from .workspace import workspace as _workspace # pyright: ignore[reportUnusedImport] api.add_namespace(inner_api_ns) diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 170a794d89..c5bb2f2545 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -37,9 +37,9 @@ from models.model import EndUser @inner_api_ns.route("/invoke/llm") class PluginInvokeLLMApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeLLM) @inner_api_ns.doc("plugin_invoke_llm") @inner_api_ns.doc(description="Invoke LLM models through plugin interface") @@ -60,9 +60,9 @@ class PluginInvokeLLMApi(Resource): @inner_api_ns.route("/invoke/llm/structured-output") class PluginInvokeLLMWithStructuredOutputApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput) @inner_api_ns.doc("plugin_invoke_llm_structured") @inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface") @@ -85,9 +85,9 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource): @inner_api_ns.route("/invoke/text-embedding") class PluginInvokeTextEmbeddingApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTextEmbedding) @inner_api_ns.doc("plugin_invoke_text_embedding") @inner_api_ns.doc(description="Invoke text embedding models through plugin interface") @@ -115,9 +115,9 @@ class PluginInvokeTextEmbeddingApi(Resource): @inner_api_ns.route("/invoke/rerank") class PluginInvokeRerankApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeRerank) @inner_api_ns.doc("plugin_invoke_rerank") @inner_api_ns.doc(description="Invoke rerank models through plugin interface") @@ -141,9 +141,9 @@ class PluginInvokeRerankApi(Resource): @inner_api_ns.route("/invoke/tts") class PluginInvokeTTSApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTTS) @inner_api_ns.doc("plugin_invoke_tts") @inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface") @@ -168,9 +168,9 @@ class PluginInvokeTTSApi(Resource): @inner_api_ns.route("/invoke/speech2text") class PluginInvokeSpeech2TextApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeSpeech2Text) @inner_api_ns.doc("plugin_invoke_speech2text") @inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface") @@ -194,9 +194,9 @@ class PluginInvokeSpeech2TextApi(Resource): @inner_api_ns.route("/invoke/moderation") class PluginInvokeModerationApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeModeration) @inner_api_ns.doc("plugin_invoke_moderation") @inner_api_ns.doc(description="Invoke moderation models through plugin interface") @@ -220,9 +220,9 @@ class PluginInvokeModerationApi(Resource): @inner_api_ns.route("/invoke/tool") class PluginInvokeToolApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeTool) @inner_api_ns.doc("plugin_invoke_tool") @inner_api_ns.doc(description="Invoke tools through plugin interface") @@ -252,9 +252,9 @@ class PluginInvokeToolApi(Resource): @inner_api_ns.route("/invoke/parameter-extractor") class PluginInvokeParameterExtractorNodeApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeParameterExtractorNode) @inner_api_ns.doc("plugin_invoke_parameter_extractor") @inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface") @@ -285,9 +285,9 @@ class PluginInvokeParameterExtractorNodeApi(Resource): @inner_api_ns.route("/invoke/question-classifier") class PluginInvokeQuestionClassifierNodeApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) @inner_api_ns.doc("plugin_invoke_question_classifier") @inner_api_ns.doc(description="Invoke question classifier node through plugin interface") @@ -318,9 +318,9 @@ class PluginInvokeQuestionClassifierNodeApi(Resource): @inner_api_ns.route("/invoke/app") class PluginInvokeAppApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeApp) @inner_api_ns.doc("plugin_invoke_app") @inner_api_ns.doc(description="Invoke application through plugin interface") @@ -348,9 +348,9 @@ class PluginInvokeAppApi(Resource): @inner_api_ns.route("/invoke/encrypt") class PluginInvokeEncryptApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeEncrypt) @inner_api_ns.doc("plugin_invoke_encrypt") @inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface") @@ -375,9 +375,9 @@ class PluginInvokeEncryptApi(Resource): @inner_api_ns.route("/invoke/summary") class PluginInvokeSummaryApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestInvokeSummary) @inner_api_ns.doc("plugin_invoke_summary") @inner_api_ns.doc(description="Invoke summary functionality through plugin interface") @@ -405,9 +405,9 @@ class PluginInvokeSummaryApi(Resource): @inner_api_ns.route("/upload/file/request") class PluginUploadFileRequestApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestRequestUploadFile) @inner_api_ns.doc("plugin_upload_file_request") @inner_api_ns.doc(description="Request signed URL for file upload through plugin interface") @@ -426,9 +426,9 @@ class PluginUploadFileRequestApi(Resource): @inner_api_ns.route("/fetch/app/info") class PluginFetchAppInfoApi(Resource): + @get_user_tenant @setup_required @plugin_inner_api_only - @get_user_tenant @plugin_data(payload_type=RequestFetchAppInfo) @inner_api_ns.doc("plugin_fetch_app_info") @inner_api_ns.doc(description="Fetch application information through plugin interface") diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 68711f7257..18b530f2c4 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Optional, ParamSpec, TypeVar +from typing import Optional, ParamSpec, TypeVar, cast from flask import current_app, request from flask_login import user_logged_in @@ -10,7 +10,7 @@ from sqlalchemy.orm import Session from core.file.constants import DEFAULT_SERVICE_API_USER_ID from extensions.ext_database import db -from libs.login import _get_user +from libs.login import current_user from models.account import Tenant from models.model import EndUser @@ -66,8 +66,8 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): p = parser.parse_args() - user_id: Optional[str] = p.get("user_id") - tenant_id: str = p.get("tenant_id") + user_id = cast(str, p.get("user_id")) + tenant_id = cast(str, p.get("tenant_id")) if not tenant_id: raise ValueError("tenant_id is required") @@ -98,7 +98,7 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): kwargs["user_model"] = user current_app.login_manager._update_request_context_with_user(user) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore return view_func(*args, **kwargs) diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py index c344ffad08..43b36a70b4 100644 --- a/api/controllers/mcp/__init__.py +++ b/api/controllers/mcp/__init__.py @@ -15,6 +15,6 @@ api = ExternalApi( mcp_ns = Namespace("mcp", description="MCP operations", path="/") -from . import mcp +from . import mcp # pyright: ignore[reportUnusedImport] api.add_namespace(mcp_ns) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 763345d723..d69f49d957 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -15,9 +15,27 @@ api = ExternalApi( service_api_ns = Namespace("service_api", description="Service operations", path="/") -from . import index -from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow -from .dataset import dataset, document, hit_testing, metadata, segment, upload_file -from .workspace import models +from . import index # pyright: ignore[reportUnusedImport] +from .app import ( + annotation, # pyright: ignore[reportUnusedImport] + app, # pyright: ignore[reportUnusedImport] + audio, # pyright: ignore[reportUnusedImport] + completion, # pyright: ignore[reportUnusedImport] + conversation, # pyright: ignore[reportUnusedImport] + file, # pyright: ignore[reportUnusedImport] + file_preview, # pyright: ignore[reportUnusedImport] + message, # pyright: ignore[reportUnusedImport] + site, # pyright: ignore[reportUnusedImport] + workflow, # pyright: ignore[reportUnusedImport] +) +from .dataset import ( + dataset, # pyright: ignore[reportUnusedImport] + document, # pyright: ignore[reportUnusedImport] + hit_testing, # pyright: ignore[reportUnusedImport] + metadata, # pyright: ignore[reportUnusedImport] + segment, # pyright: ignore[reportUnusedImport] + upload_file, # pyright: ignore[reportUnusedImport] +) +from .workspace import models # pyright: ignore[reportUnusedImport] api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 4860bf3a79..711dd5704c 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,4 +1,5 @@ from flask_restx import Resource, reqparse +from flask_restx._http import HTTPStatus from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound @@ -121,7 +122,7 @@ class ConversationDetailApi(Resource): } ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204) + @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT) def delete(self, app_model: App, end_user: EndUser, c_id): """Delete a specific conversation.""" app_mode = AppMode.value_of(app_model.mode) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index de41384270..721cf530c3 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -30,6 +30,7 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment +from models.model import EndUser from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -298,6 +299,9 @@ class DocumentAddByFileApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError + if not isinstance(current_user, EndUser): + raise ValueError("Invalid user account") + upload_file = FileService.upload_file( filename=file.filename, content=file.read(), @@ -387,6 +391,8 @@ class DocumentUpdateByFileApi(DatasetApiResource): raise FilenameNotExistsError try: + if not isinstance(current_user, EndUser): + raise ValueError("Invalid user account") upload_file = FileService.upload_file( filename=file.filename, content=file.read(), diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 4394e64ad9..64a2f5445c 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -17,7 +17,7 @@ from core.file.constants import DEFAULT_SERVICE_API_USER_ID from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now -from libs.login import _get_user +from libs.login import current_user from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.dataset import Dataset, RateLimitLog from models.model import ApiToken, App, EndUser @@ -210,7 +210,7 @@ def validate_dataset_token(view: Optional[Callable[Concatenate[T, P], R]] = None if account: account.current_tenant = tenant current_app.login_manager._update_request_context_with_user(account) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore else: raise Unauthorized("Tenant owner account does not exist.") else: diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 3b0a9e341a..a825a2a0d8 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -17,20 +17,20 @@ api = ExternalApi( web_ns = Namespace("web", description="Web application API operations", path="/") from . import ( - app, - audio, - completion, - conversation, - feature, - files, - forgot_password, - login, - message, - passport, - remote_files, - saved_message, - site, - workflow, + app, # pyright: ignore[reportUnusedImport] + audio, # pyright: ignore[reportUnusedImport] + completion, # pyright: ignore[reportUnusedImport] + conversation, # pyright: ignore[reportUnusedImport] + feature, # pyright: ignore[reportUnusedImport] + files, # pyright: ignore[reportUnusedImport] + forgot_password, # pyright: ignore[reportUnusedImport] + login, # pyright: ignore[reportUnusedImport] + message, # pyright: ignore[reportUnusedImport] + passport, # pyright: ignore[reportUnusedImport] + remote_files, # pyright: ignore[reportUnusedImport] + saved_message, # pyright: ignore[reportUnusedImport] + site, # pyright: ignore[reportUnusedImport] + workflow, # pyright: ignore[reportUnusedImport] ) api.add_namespace(web_ns) diff --git a/api/core/__init__.py b/api/core/__init__.py index 6eaea7b1c8..e69de29bb2 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -1 +0,0 @@ -import core.moderation.base diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index b94a60c40a..d1d5a011e0 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -72,6 +72,8 @@ class CotAgentRunner(BaseAgentRunner, ABC): function_call_state = True llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" + prompt_messages: list = [] # Initialize prompt_messages + agent_thought_id = "" # Initialize agent_thought_id def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): if not final_llm_usage_dict["usage"]: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 9eb853aa74..5236266908 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -54,6 +54,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): function_call_state = True llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" + prompt_messages: list = [] # Initialize prompt_messages # get tracing instance trace_manager = app_generate_entity.trace_manager diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 037037e6ca..97ede178c7 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -21,7 +21,7 @@ class SensitiveWordAvoidanceConfigManager: @classmethod def validate_and_set_defaults( - cls, tenant_id, config: dict, only_structure_validate: bool = False + cls, tenant_id: str, config: dict, only_structure_validate: bool = False ) -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): config["sensitive_word_avoidance"] = {"enabled": False} @@ -38,7 +38,14 @@ class SensitiveWordAvoidanceConfigManager: if not only_structure_validate: typ = config["sensitive_word_avoidance"]["type"] - sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] + if not isinstance(typ, str): + raise ValueError("sensitive_word_avoidance.type must be a string") + + sensitive_word_avoidance_config = config["sensitive_word_avoidance"].get("config") + if sensitive_word_avoidance_config is None: + sensitive_word_avoidance_config = {} + if not isinstance(sensitive_word_avoidance_config, dict): + raise ValueError("sensitive_word_avoidance.config must be a dict") ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config) diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index e6ab31e586..cda17c0010 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -25,10 +25,14 @@ class PromptTemplateConfigManager: if chat_prompt_config: chat_prompt_messages = [] for message in chat_prompt_config.get("prompt", []): + text = message.get("text") + if not isinstance(text, str): + raise ValueError("message text must be a string") + role = message.get("role") + if not isinstance(role, str): + raise ValueError("message role must be a string") chat_prompt_messages.append( - AdvancedChatMessageEntity( - **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} - ) + AdvancedChatMessageEntity(text=text, role=PromptMessageRole.value_of(role)) ) advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 627f6b47ce..02ec96f209 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -71,7 +71,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, Any] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -102,7 +102,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, Any] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -118,8 +118,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): - response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 8207b70f9e..cec3b83674 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -174,7 +174,7 @@ class AdvancedChatAppGenerateTaskPipeline: generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._base_task_pipeline._stream: + if self._base_task_pipeline.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -302,13 +302,13 @@ class AdvancedChatAppGenerateTaskPipeline: def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: """Handle ping events.""" - yield self._base_task_pipeline._ping_stream_response() + yield self._base_task_pipeline.ping_stream_response() def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: """Handle error events.""" with self._database_session() as session: - err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id) - yield self._base_task_pipeline._error_to_stream_response(err) + err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]: """Handle workflow started events.""" @@ -627,10 +627,10 @@ class AdvancedChatAppGenerateTaskPipeline: workflow_execution=workflow_execution, ) err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) - err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id) + err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id) yield workflow_finish_resp - yield self._base_task_pipeline._error_to_stream_response(err) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_stop_event( self, @@ -683,7 +683,7 @@ class AdvancedChatAppGenerateTaskPipeline: """Handle advanced chat message end events.""" self._ensure_graph_runtime_initialized(graph_runtime_state) - output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( + output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished( self._task_state.answer ) if output_moderation_answer: @@ -899,7 +899,7 @@ class AdvancedChatAppGenerateTaskPipeline: message.answer = answer_text message.updated_at = naive_utc_now() - message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at + message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at message.message_metadata = self._task_state.metadata.model_dump_json() message_files = [ MessageFile( @@ -955,9 +955,9 @@ class AdvancedChatAppGenerateTaskPipeline: :param text: text :return: True if output moderation should direct output, otherwise False """ - if self._base_task_pipeline._output_moderation_handler: - if self._base_task_pipeline._output_moderation_handler.should_direct_output(): - self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() + if self._base_task_pipeline.output_moderation_handler: + if self._base_task_pipeline.output_moderation_handler.should_direct_output(): + self._task_state.answer = self._base_task_pipeline.output_moderation_handler.get_final_output() self._base_task_pipeline.queue_manager.publish( QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) @@ -967,7 +967,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) return True else: - self._base_task_pipeline._output_moderation_handler.append_new_token(text) + self._base_task_pipeline.output_moderation_handler.append_new_token(text) return False diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 349b583833..54d1a9595f 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Mapping -from typing import Any, Optional +from typing import Any, Optional, cast from core.agent.entities import AgentEntity from core.app.app_config.base_app_config_manager import BaseAppConfigManager @@ -160,7 +160,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager): return filtered_config @classmethod - def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + def validate_agent_mode_and_set_defaults( + cls, tenant_id: str, config: dict[str, Any] + ) -> tuple[dict[str, Any], list[str]]: """ Validate agent_mode and set defaults for agent feature @@ -170,30 +172,32 @@ class AgentChatAppConfigManager(BaseAppConfigManager): if not config.get("agent_mode"): config["agent_mode"] = {"enabled": False, "tools": []} - if not isinstance(config["agent_mode"], dict): + agent_mode = config["agent_mode"] + if not isinstance(agent_mode, dict): raise ValueError("agent_mode must be of object type") - if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: - config["agent_mode"]["enabled"] = False + # FIXME(-LAN-): Cast needed due to basedpyright limitation with dict type narrowing + agent_mode = cast(dict[str, Any], agent_mode) - if not isinstance(config["agent_mode"]["enabled"], bool): + if "enabled" not in agent_mode or not agent_mode["enabled"]: + agent_mode["enabled"] = False + + if not isinstance(agent_mode["enabled"], bool): raise ValueError("enabled in agent_mode must be of boolean type") - if not config["agent_mode"].get("strategy"): - config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value + if not agent_mode.get("strategy"): + agent_mode["strategy"] = PlanningStrategy.ROUTER.value - if config["agent_mode"]["strategy"] not in [ - member.value for member in list(PlanningStrategy.__members__.values()) - ]: + if agent_mode["strategy"] not in [member.value for member in list(PlanningStrategy.__members__.values())]: raise ValueError("strategy in agent_mode must be in the specified strategy list") - if not config["agent_mode"].get("tools"): - config["agent_mode"]["tools"] = [] + if not agent_mode.get("tools"): + agent_mode["tools"] = [] - if not isinstance(config["agent_mode"]["tools"], list): + if not isinstance(agent_mode["tools"], list): raise ValueError("tools in agent_mode must be a list of objects") - for tool in config["agent_mode"]["tools"]: + for tool in agent_mode["tools"]: key = list(tool.keys())[0] if key in OLD_TOOLS: # old style, use tool name as key diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 89a5b8e3b5..e35e9d9408 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -46,7 +46,10 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -78,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -106,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -114,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 795a7befff..2a7fe7902b 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -32,6 +32,7 @@ class AppQueueManager: self._task_id = task_id self._user_id = user_id self._invoke_from = invoke_from + self.invoke_from = invoke_from # Public accessor for invoke_from user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" redis_client.setex( diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 816d6d79a9..3aa1161fd8 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -46,7 +46,10 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -78,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -106,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) @@ -114,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 8485ce7519..843328f904 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -271,6 +271,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator): raise MoreLikeThisDisabledError() app_model_config = message.app_model_config + if not app_model_config: + raise ValueError("Message app_model_config is None") override_model_config_dict = app_model_config.to_dict() model_dict = override_model_config_dict["model"] completion_params = model_dict.get("completion_params") diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 4d45c61145..d7e9ebdf24 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -45,7 +45,10 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) + else: + response["metadata"] = {} return response @@ -76,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -103,14 +106,16 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): } if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.to_dict() + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") metadata = sub_stream_response_dict.get("metadata", {}) + if not isinstance(metadata, dict): + metadata = {} sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 210f6110b1..01ecf0298f 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): :param blocking_response: blocking response :return: """ - return dict(blocking_response.to_dict()) + return blocking_response.model_dump() @classmethod def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] @@ -51,7 +51,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, object] = { "event": sub_stream_response.event.value, "workflow_run_id": chunk.workflow_run_id, } @@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk @classmethod @@ -80,7 +80,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, object] = { "event": sub_stream_response.event.value, "workflow_run_id": chunk.workflow_run_id, } @@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute] else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 6ab89dbd61..1c950063dd 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -137,7 +137,7 @@ class WorkflowAppGenerateTaskPipeline: self._application_generate_entity = application_generate_entity self._workflow_features_dict = workflow.features_dict self._workflow_run_id = "" - self._invoke_from = queue_manager._invoke_from + self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -146,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline: :return: """ generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._base_task_pipeline._stream: + if self._base_task_pipeline.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -276,12 +276,12 @@ class WorkflowAppGenerateTaskPipeline: def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: """Handle ping events.""" - yield self._base_task_pipeline._ping_stream_response() + yield self._base_task_pipeline.ping_stream_response() def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: """Handle error events.""" - err = self._base_task_pipeline._handle_error(event=event) - yield self._base_task_pipeline._error_to_stream_response(err) + err = self._base_task_pipeline.handle_error(event=event) + yield self._base_task_pipeline.error_to_stream_response(err) def _handle_workflow_started_event( self, event: QueueWorkflowStartedEvent, **kwargs diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 9151137fe8..1d5ebabaf7 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -123,7 +123,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): """ # app config - app_config: EasyUIBasedAppConfig + app_config: EasyUIBasedAppConfig = None # type: ignore model_conf: ModelConfigWithCredentialsEntity query: Optional[str] = None @@ -186,7 +186,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): """ # app config - app_config: WorkflowUIBasedAppConfig + app_config: WorkflowUIBasedAppConfig = None # type: ignore workflow_run_id: Optional[str] = None query: str @@ -218,7 +218,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): """ # app config - app_config: WorkflowUIBasedAppConfig + app_config: WorkflowUIBasedAppConfig = None # type: ignore workflow_execution_id: str class SingleIterationRunEntity(BaseModel): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 29f3e3427e..31183d19a3 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -5,7 +5,6 @@ from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -92,9 +91,6 @@ class StreamResponse(BaseModel): event: StreamEvent task_id: str - def to_dict(self): - return jsonable_encoder(self) - class ErrorStreamResponse(StreamResponse): """ @@ -745,9 +741,6 @@ class AppBlockingResponse(BaseModel): task_id: str - def to_dict(self): - return jsonable_encoder(self) - class ChatbotAppBlockingResponse(AppBlockingResponse): """ diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index be183e2086..3853dccdc5 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -35,6 +35,9 @@ class AnnotationReplyFeature: collection_binding_detail = annotation_setting.collection_binding_detail + if not collection_binding_detail: + return None + try: score_threshold = annotation_setting.score_threshold or 1 embedding_provider_name = collection_binding_detail.provider_name diff --git a/api/core/app/features/rate_limiting/__init__.py b/api/core/app/features/rate_limiting/__init__.py index 6624f6ad9d..4ad33acd0f 100644 --- a/api/core/app/features/rate_limiting/__init__.py +++ b/api/core/app/features/rate_limiting/__init__.py @@ -1 +1,3 @@ from .rate_limit import RateLimit + +__all__ = ["RateLimit"] diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index f526d2a16a..6f13f11da0 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,7 +19,7 @@ class RateLimit: _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} - def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): + def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: instance = super().__new__(cls) cls._instance_dict[client_id] = instance diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 7d98cceb1a..4931300901 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -38,11 +38,11 @@ class BasedGenerateTaskPipeline: ): self._application_generate_entity = application_generate_entity self.queue_manager = queue_manager - self._start_at = time.perf_counter() - self._output_moderation_handler = self._init_output_moderation() - self._stream = stream + self.start_at = time.perf_counter() + self.output_moderation_handler = self._init_output_moderation() + self.stream = stream - def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): + def handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): logger.debug("error: %s", event.error) e = event.error err: Exception @@ -86,7 +86,7 @@ class BasedGenerateTaskPipeline: return message - def _error_to_stream_response(self, e: Exception): + def error_to_stream_response(self, e: Exception): """ Error to stream response. :param e: exception @@ -94,7 +94,7 @@ class BasedGenerateTaskPipeline: """ return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) - def _ping_stream_response(self) -> PingStreamResponse: + def ping_stream_response(self) -> PingStreamResponse: """ Ping stream response. :return: @@ -118,21 +118,21 @@ class BasedGenerateTaskPipeline: ) return None - def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: + def handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: """ Handle output moderation when task finished. :param completion: completion :return: """ # response moderation - if self._output_moderation_handler: - self._output_moderation_handler.stop_thread() + if self.output_moderation_handler: + self.output_moderation_handler.stop_thread() - completion, flagged = self._output_moderation_handler.moderation_completion( + completion, flagged = self.output_moderation_handler.moderation_completion( completion=completion, public_event=False ) - self._output_moderation_handler = None + self.output_moderation_handler = None if flagged: return completion diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 0dad0a5a9d..71fd5ac653 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -125,7 +125,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._stream: + if self.stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -265,9 +265,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if isinstance(event, QueueErrorEvent): with Session(db.engine) as session: - err = self._handle_error(event=event, session=session, message_id=self._message_id) + err = self.handle_error(event=event, session=session, message_id=self._message_id) session.commit() - yield self._error_to_stream_response(err) + yield self.error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): @@ -277,7 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): self._handle_stop(event) # handle output moderation - output_moderation_answer = self._handle_output_moderation_when_task_finished( + output_moderation_answer = self.handle_output_moderation_when_task_finished( cast(str, self._task_state.llm_result.message.content) ) if output_moderation_answer: @@ -354,7 +354,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): elif isinstance(event, QueueMessageReplaceEvent): yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + yield self.ping_stream_response() else: continue if publisher: @@ -394,7 +394,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): message.answer_tokens = usage.completion_tokens message.answer_unit_price = usage.completion_unit_price message.answer_price_unit = usage.completion_price_unit - message.provider_response_latency = time.perf_counter() - self._start_at + message.provider_response_latency = time.perf_counter() - self.start_at message.total_price = usage.total_price message.currency = usage.currency self._task_state.llm_result.usage.latency = message.provider_response_latency @@ -438,7 +438,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): # transform usage model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) - self._task_state.llm_result.usage = model_type_instance._calc_response_usage( + self._task_state.llm_result.usage = model_type_instance.calc_response_usage( model, credentials, prompt_tokens, completion_tokens ) @@ -498,10 +498,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): :param text: text :return: True if output moderation should direct output, otherwise False """ - if self._output_moderation_handler: - if self._output_moderation_handler.should_direct_output(): + if self.output_moderation_handler: + if self.output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output - self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() + self._task_state.llm_result.message.content = self.output_moderation_handler.get_final_output() self.queue_manager.publish( QueueLLMChunkEvent( chunk=LLMResultChunk( @@ -521,6 +521,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): ) return True else: - self._output_moderation_handler.append_new_token(text) + self.output_moderation_handler.append_new_token(text) return False diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 4e6422e2df..89190c36cc 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -72,7 +72,7 @@ class AppGeneratorTTSPublisher: self.voice = voice if not voice or voice not in values: self.voice = self.voices[0].get("value") - self.MAX_SENTENCE = 2 + self.max_sentence = 2 self._last_audio_event: Optional[AudioTrunk] = None # FIXME better way to handle this threading.start threading.Thread(target=self._runtime).start() @@ -113,8 +113,8 @@ class AppGeneratorTTSPublisher: self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) - if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): - self.MAX_SENTENCE += 1 + if len(sentence_arr) >= min(self.max_sentence, 7): + self.max_sentence += 1 text_content = "".join(sentence_arr) futures_result = self.executor.submit( _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 9cf35e559d..5309e4e638 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1840,8 +1840,14 @@ class ProviderConfigurations(BaseModel): def __setitem__(self, key, value): self.configurations[key] = value + def __contains__(self, key): + if "/" not in key: + key = str(ModelProviderID(key)) + return key in self.configurations + def __iter__(self): - return iter(self.configurations) + # Return an iterator of (key, value) tuples to match BaseModel's __iter__ + yield from self.configurations.items() def values(self) -> Iterator[ProviderConfiguration]: return iter(self.configurations.values()) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index e3fd175d95..2a5f6c3dc7 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -98,7 +98,7 @@ def to_prompt_message_content( def download(f: File, /): if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE): - return _download_file_content(f._storage_key) + return _download_file_content(f.storage_key) elif f.transfer_method == FileTransferMethod.REMOTE_URL: response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() @@ -134,9 +134,9 @@ def _get_encoded_string(f: File, /): response.raise_for_status() data = response.content case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f._storage_key) + data = _download_file_content(f.storage_key) case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f._storage_key) + data = _download_file_content(f.storage_key) encoded_string = base64.b64encode(data).decode("utf-8") return encoded_string diff --git a/api/core/file/models.py b/api/core/file/models.py index f61334e7bc..9b74fa387f 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -146,3 +146,11 @@ class File(BaseModel): if not self.related_id: raise ValueError("Missing file related_id") return self + + @property + def storage_key(self) -> str: + return self._storage_key + + @storage_key.setter + def storage_key(self, value: str): + self._storage_key = value diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index efeba9e5ee..cbb78939d2 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -13,18 +13,18 @@ logger = logging.getLogger(__name__) SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES -HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True +http_request_node_ssl_verify = True # Default value for http_request_node_ssl_verify is True try: - HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY - http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower() + config_value = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY + http_request_node_ssl_verify_lower = str(config_value).lower() if http_request_node_ssl_verify_lower == "true": - HTTP_REQUEST_NODE_SSL_VERIFY = True + http_request_node_ssl_verify = True elif http_request_node_ssl_verify_lower == "false": - HTTP_REQUEST_NODE_SSL_VERIFY = False + http_request_node_ssl_verify = False else: raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'") except NameError: - HTTP_REQUEST_NODE_SSL_VERIFY = True + http_request_node_ssl_verify = True BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] @@ -51,7 +51,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): ) if "ssl_verify" not in kwargs: - kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY + kwargs["ssl_verify"] = http_request_node_ssl_verify ssl_verify = kwargs.pop("ssl_verify") diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 89a05e02c8..ed02b70b03 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -529,6 +529,7 @@ class IndexingRunner: # chunk nodes by chunk size indexing_start_at = time.perf_counter() tokens = 0 + create_keyword_thread = None if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": # create keyword index create_keyword_thread = threading.Thread( @@ -567,7 +568,11 @@ class IndexingRunner: for future in futures: tokens += future.result() - if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy": + if ( + dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX + and dataset.indexing_technique == "economy" + and create_keyword_thread is not None + ): create_keyword_thread.join() indexing_end_at = time.perf_counter() diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 94b8258e9c..d4c4f10a12 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -20,7 +20,7 @@ from core.llm_generator.prompts import ( ) from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.entities.trace_entity import TraceTaskName @@ -313,14 +313,20 @@ class LLMGenerator: model_type=ModelType.LLM, ) - prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] + prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] - response: LLMResult = model_instance.invoke_llm( + # Explicitly use the non-streaming overload + result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False, ) + # Runtime type check since pyright has issues with the overload + if not isinstance(result, LLMResult): + raise TypeError("Expected LLMResult when stream=False") + response = result + answer = cast(str, response.message.content) return answer.strip() diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 28833fe8e8..e0b70c132f 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -45,6 +45,7 @@ class SpecialModelType(StrEnum): @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, @@ -53,14 +54,13 @@ def invoke_llm_with_structured_output( model_parameters: Optional[Mapping] = None, tools: Sequence[PromptMessageTool] | None = None, stop: Optional[list[str]] = None, - stream: Literal[True] = True, + stream: Literal[True], user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, @@ -69,14 +69,13 @@ def invoke_llm_with_structured_output( model_parameters: Optional[Mapping] = None, tools: Sequence[PromptMessageTool] | None = None, stop: Optional[list[str]] = None, - stream: Literal[False] = False, + stream: Literal[False], user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> LLMResultWithStructuredOutput: ... - - @overload def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, @@ -89,9 +88,8 @@ def invoke_llm_with_structured_output( user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - def invoke_llm_with_structured_output( + *, provider: str, model_schema: AIModelEntity, model_instance: ModelInstance, diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index cc4263c0aa..6db22a09e0 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -23,13 +23,13 @@ DEFAULT_QUEUE_READ_TIMEOUT = 3 @final class _StatusReady: def __init__(self, endpoint_url: str): - self._endpoint_url = endpoint_url + self.endpoint_url = endpoint_url @final class _StatusError: def __init__(self, exc: Exception): - self._exc = exc + self.exc = exc # Type aliases for better readability @@ -211,9 +211,9 @@ class SSETransport: raise ValueError("failed to get endpoint URL") if isinstance(status, _StatusReady): - return status._endpoint_url + return status.endpoint_url elif isinstance(status, _StatusError): - raise status._exc + raise status.exc else: raise ValueError("failed to get endpoint URL") diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 3d51ac2333..6f52c65234 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -38,6 +38,7 @@ def handle_mcp_request( """ request_type = type(request.root) + request_root = request.root def create_success_response(result_data: mcp_types.Result) -> mcp_types.JSONRPCResponse: """Create success response with business result data""" @@ -58,21 +59,20 @@ def handle_mcp_request( error=error_data, ) - # Request handler mapping using functional approach - request_handlers = { - mcp_types.InitializeRequest: lambda: handle_initialize(mcp_server.description), - mcp_types.ListToolsRequest: lambda: handle_list_tools( - app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict - ), - mcp_types.CallToolRequest: lambda: handle_call_tool(app, request, user_input_form, end_user), - mcp_types.PingRequest: lambda: handle_ping(), - } - try: - # Dispatch request to appropriate handler - handler = request_handlers.get(request_type) - if handler: - return create_success_response(handler()) + # Dispatch request to appropriate handler based on instance type + if isinstance(request_root, mcp_types.InitializeRequest): + return create_success_response(handle_initialize(mcp_server.description)) + elif isinstance(request_root, mcp_types.ListToolsRequest): + return create_success_response( + handle_list_tools( + app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict + ) + ) + elif isinstance(request_root, mcp_types.CallToolRequest): + return create_success_response(handle_call_tool(app, request, user_input_form, end_user)) + elif isinstance(request_root, mcp_types.PingRequest): + return create_success_response(handle_ping()) else: return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}") diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 96c48034c7..fbad5576aa 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -81,7 +81,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): self.request_meta = request_meta self.request = request self._session = session - self._completed = False + self.completed = False self._on_complete = on_complete self._entered = False # Track if we're in a context manager @@ -98,7 +98,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): ): """Exit the context manager, performing cleanup and notifying completion.""" try: - if self._completed: + if self.completed: self._on_complete(self) finally: self._entered = False @@ -113,9 +113,9 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """ if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") - assert not self._completed, "Request already responded to" + assert not self.completed, "Request already responded to" - self._completed = True + self.completed = True self._session._send_response(request_id=self.request_id, response=response) @@ -124,7 +124,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") - self._completed = True # Mark as completed so it's removed from in_flight + self.completed = True # Mark as completed so it's removed from in_flight # Send an error response to indicate cancellation self._session._send_response( request_id=self.request_id, @@ -351,7 +351,7 @@ class BaseSession( self._in_flight[responder.request_id] = responder self._received_request(responder) - if not responder._completed: + if not responder.completed: self._handle_incoming(responder) elif isinstance(message.message.root, JSONRPCNotification): diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 24b206fdbe..1d7fd7d447 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -354,7 +354,7 @@ class LargeLanguageModel(AIModel): ) return 0 - def _calc_response_usage( + def calc_response_usage( self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int ) -> LLMUsage: """ diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 47290ee613..92427a7426 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -1,4 +1,5 @@ import enum +import json from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -162,8 +163,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): # Try to parse JSON string for arrays if isinstance(value, str): try: - import json - parsed_value = json.loads(value) if isinstance(parsed_value, list): return parsed_value @@ -176,8 +175,6 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): # Try to parse JSON string for objects if isinstance(value, str): try: - import json - parsed_value = json.loads(value) if isinstance(parsed_value, dict): return parsed_value diff --git a/api/core/plugin/utils/chunk_merger.py b/api/core/plugin/utils/chunk_merger.py index ec66ba02ee..e30076f9d3 100644 --- a/api/core/plugin/utils/chunk_merger.py +++ b/api/core/plugin/utils/chunk_merger.py @@ -82,7 +82,9 @@ def merge_blob_chunks( message_class = type(resp) merged_message = message_class( type=ToolInvokeMessage.MessageType.BLOB, - message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]), + message=ToolInvokeMessage.BlobMessage( + blob=bytes(files[chunk_id].data[: files[chunk_id].bytes_written]) + ), meta=resp.meta, ) yield cast(MessageType, merged_message) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d75a230d73..d15cb7cbc1 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -101,9 +101,22 @@ class SimplePromptTransform(PromptTransform): with_memory_prompt=histories is not None, ) - variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} + custom_variable_keys_obj = prompt_template_config["custom_variable_keys"] + special_variable_keys_obj = prompt_template_config["special_variable_keys"] - for v in prompt_template_config["special_variable_keys"]: + # Type check for custom_variable_keys + if not isinstance(custom_variable_keys_obj, list): + raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}") + custom_variable_keys = cast(list[str], custom_variable_keys_obj) + + # Type check for special_variable_keys + if not isinstance(special_variable_keys_obj, list): + raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}") + special_variable_keys = cast(list[str], special_variable_keys_obj) + + variables = {k: inputs[k] for k in custom_variable_keys if k in inputs} + + for v in special_variable_keys: # support #context#, #query# and #histories# if v == "#context#": variables["#context#"] = context or "" @@ -113,9 +126,16 @@ class SimplePromptTransform(PromptTransform): variables["#histories#"] = histories or "" prompt_template = prompt_template_config["prompt_template"] + if not isinstance(prompt_template, PromptTemplateParser): + raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template)}") + prompt = prompt_template.format(variables) - return prompt, prompt_template_config["prompt_rules"] + prompt_rules = prompt_template_config["prompt_rules"] + if not isinstance(prompt_rules, dict): + raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}") + + return prompt, prompt_rules def get_prompt_template( self, @@ -126,11 +146,11 @@ class SimplePromptTransform(PromptTransform): has_context: bool, query_in_prompt: bool, with_memory_prompt: bool = False, - ): + ) -> dict[str, object]: prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) - custom_variable_keys = [] - special_variable_keys = [] + custom_variable_keys: list[str] = [] + special_variable_keys: list[str] = [] prompt = "" for order in prompt_rules["system_prompt_orders"]: diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 12d97c500f..d329220580 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -40,6 +40,19 @@ if TYPE_CHECKING: MetadataFilter = Union[DictFilter, common_types.Filter] +class PathQdrantParams(BaseModel): + path: str + + +class UrlQdrantParams(BaseModel): + url: str + api_key: Optional[str] + timeout: float + verify: bool + grpc_port: int + prefer_grpc: bool + + class QdrantConfig(BaseModel): endpoint: str api_key: Optional[str] = None @@ -50,7 +63,7 @@ class QdrantConfig(BaseModel): replication_factor: int = 1 write_consistency_factor: int = 1 - def to_qdrant_params(self): + def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams: if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): @@ -58,23 +71,23 @@ class QdrantConfig(BaseModel): raise ValueError("Root path is not set") path = os.path.join(self.root_path, path) - return {"path": path} + return PathQdrantParams(path=path) else: - return { - "url": self.endpoint, - "api_key": self.api_key, - "timeout": self.timeout, - "verify": self.endpoint.startswith("https"), - "grpc_port": self.grpc_port, - "prefer_grpc": self.prefer_grpc, - } + return UrlQdrantParams( + url=self.endpoint, + api_key=self.api_key, + timeout=self.timeout, + verify=self.endpoint.startswith("https"), + grpc_port=self.grpc_port, + prefer_grpc=self.prefer_grpc, + ) class QdrantVector(BaseVector): def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): super().__init__(collection_name) self._client_config = config - self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) + self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump()) self._distance_func = distance_func.upper() self._group_id = group_id diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index b36252dba2..95ad9f25fe 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -94,10 +94,10 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER # In-memory cache for workflow node executions - self._execution_cache: dict[str, WorkflowNodeExecution] = {} + self._execution_cache = {} # Cache for mapping workflow_execution_ids to execution IDs for efficient retrieval - self._workflow_execution_mapping: dict[str, list[str]] = {} + self._workflow_execution_mapping = {} logger.info( "Initialized CeleryWorkflowNodeExecutionRepository for tenant %s, app %s, triggered_from %s", diff --git a/api/core/variables/segment_group.py b/api/core/variables/segment_group.py index b363255b2c..0a41b64228 100644 --- a/api/core/variables/segment_group.py +++ b/api/core/variables/segment_group.py @@ -4,7 +4,7 @@ from .types import SegmentType class SegmentGroup(Segment): value_type: SegmentType = SegmentType.GROUP - value: list[Segment] + value: list[Segment] = None # type: ignore @property def text(self): diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 7da43a6504..28644b0169 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -74,12 +74,12 @@ class NoneSegment(Segment): class StringSegment(Segment): value_type: SegmentType = SegmentType.STRING - value: str + value: str = None # type: ignore class FloatSegment(Segment): value_type: SegmentType = SegmentType.FLOAT - value: float + value: float = None # type: ignore # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. # @@ -98,12 +98,12 @@ class FloatSegment(Segment): class IntegerSegment(Segment): value_type: SegmentType = SegmentType.INTEGER - value: int + value: int = None # type: ignore class ObjectSegment(Segment): value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] + value: Mapping[str, Any] = None # type: ignore @property def text(self) -> str: @@ -136,7 +136,7 @@ class ArraySegment(Segment): class FileSegment(Segment): value_type: SegmentType = SegmentType.FILE - value: File + value: File = None # type: ignore @property def markdown(self) -> str: @@ -153,17 +153,17 @@ class FileSegment(Segment): class BooleanSegment(Segment): value_type: SegmentType = SegmentType.BOOLEAN - value: bool + value: bool = None # type: ignore class ArrayAnySegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] + value: Sequence[Any] = None # type: ignore class ArrayStringSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] + value: Sequence[str] = None # type: ignore @property def text(self) -> str: @@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment): class ArrayNumberSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] + value: Sequence[float | int] = None # type: ignore class ArrayObjectSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] + value: Sequence[Mapping[str, Any]] = None # type: ignore class ArrayFileSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] + value: Sequence[File] = None # type: ignore @property def markdown(self) -> str: @@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment): class ArrayBooleanSegment(ArraySegment): value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] + value: Sequence[bool] = None # type: ignore def get_segment_discriminator(v: Any) -> SegmentType | None: diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index 594bb2b32e..63513bdc9f 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -3,6 +3,6 @@ from core.workflow.nodes.base import BaseNode class WorkflowNodeRunFailedError(Exception): def __init__(self, node: BaseNode, err_msg: str): - self._node = node - self._error = err_msg + self.node = node + self.error = err_msg super().__init__(f"Node {node.title} run failed: {err_msg}") diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index eb7b9fc2c6..cf46870254 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -67,8 +67,8 @@ class ListOperatorNode(BaseNode): return "1" def _run(self): - inputs: dict[str, list] = {} - process_data: dict[str, list] = {} + inputs: dict[str, Sequence[object]] = {} + process_data: dict[str, Sequence[object]] = {} outputs: dict[str, Any] = {} variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index c34a06d981..fdcdac1ec2 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1183,7 +1183,8 @@ def _combine_message_content_with_role( return AssistantPromptMessage(content=contents) case PromptMessageRole.SYSTEM: return SystemPromptMessage(content=contents) - raise NotImplementedError(f"Role {role} is not supported") + case _: + raise NotImplementedError(f"Role {role} is not supported") def _render_jinja2_message( diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 9433b312cf..f2c37e1a4b 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -462,9 +462,9 @@ class StorageKeyLoader: upload_file_row = upload_files.get(model_id) if upload_file_row is None: raise ValueError(f"Upload file not found for id: {model_id}") - file._storage_key = upload_file_row.key + file.storage_key = upload_file_row.key elif file.transfer_method == FileTransferMethod.TOOL_FILE: tool_file_row = tool_files.get(model_id) if tool_file_row is None: raise ValueError(f"Tool file not found for id: {model_id}") - file._storage_key = tool_file_row.file_key + file.storage_key = tool_file_row.file_key diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index 8288bd54a3..b2b793d40e 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -12,4 +12,7 @@ def serialize_value_type(v: _VarTypedDict | Segment) -> str: if isinstance(v, Segment): return v.value_type.exposed_type().value else: - return v["value_type"].exposed_type().value + value_type = v.get("value_type") + if value_type is None: + raise ValueError("value_type is required but not provided") + return value_type.exposed_type().value diff --git a/api/libs/external_api.py b/api/libs/external_api.py index cee80f7f24..cf91b0117f 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -69,6 +69,8 @@ def register_external_error_handlers(api: Api): headers["WWW-Authenticate"] = 'Bearer realm="api"' return data, status_code, headers + _ = handle_http_exception + @api.errorhandler(ValueError) def handle_value_error(e: ValueError): got_request_exception.send(current_app, exception=e) @@ -76,6 +78,8 @@ def register_external_error_handlers(api: Api): data = {"code": "invalid_param", "message": str(e), "status": status_code} return data, status_code + _ = handle_value_error + @api.errorhandler(AppInvokeQuotaExceededError) def handle_quota_exceeded(e: AppInvokeQuotaExceededError): got_request_exception.send(current_app, exception=e) @@ -83,15 +87,17 @@ def register_external_error_handlers(api: Api): data = {"code": "too_many_requests", "message": str(e), "status": status_code} return data, status_code + _ = handle_quota_exceeded + @api.errorhandler(Exception) def handle_general_exception(e: Exception): got_request_exception.send(current_app, exception=e) status_code = 500 - data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)}) + data = getattr(e, "data", {"message": http_status_message(status_code)}) # 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response) - if not isinstance(data, Mapping): + if not isinstance(data, dict): data = {"message": str(e)} data.setdefault("code", "unknown") @@ -101,10 +107,12 @@ def register_external_error_handlers(api: Api): exc_info: Any = sys.exc_info() if exc_info[1] is None: exc_info = None - current_app.log_exception(exc_info) # ty: ignore [invalid-argument-type] + current_app.log_exception(exc_info) return data, status_code + _ = handle_general_exception + class ExternalApi(Api): _authorizations = { diff --git a/api/libs/helper.py b/api/libs/helper.py index 139cb329de..f3c46b4843 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -167,13 +167,6 @@ class DatetimeString: return value -def _get_float(value): - try: - return float(value) - except (TypeError, ValueError): - raise ValueError(f"{value} is not a valid float") - - def timezone(timezone_string): if timezone_string and timezone_string in available_timezones(): return timezone_string diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index 352161523f..7c59c2ca28 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -1,24 +1,44 @@ { "include": ["."], - "exclude": [".venv", "tests/", "migrations/"], - "ignore": [ - "core/", - "controllers/", - "tasks/", - "services/", - "schedule/", - "extensions/", - "utils/", - "repositories/", - "libs/", - "fields/", - "factories/", - "events/", - "contexts/", - "constants/", - "commands.py" + "exclude": [ + ".venv", + "tests/", + "migrations/", + "core/rag", + "extensions", + "libs", + "controllers/console/datasets", + "controllers/service_api/dataset", + "core/ops", + "core/tools", + "core/model_runtime", + "core/workflow", + "core/app/app_config/easy_ui_based_app/dataset" ], "typeCheckingMode": "strict", + "allowedUntypedLibraries": [ + "flask_restx", + "flask_login", + "opentelemetry.instrumentation.celery", + "opentelemetry.instrumentation.flask", + "opentelemetry.instrumentation.requests", + "opentelemetry.instrumentation.sqlalchemy", + "opentelemetry.instrumentation.redis" + ], + "reportUnknownMemberType": "hint", + "reportUnknownParameterType": "hint", + "reportUnknownArgumentType": "hint", + "reportUnknownVariableType": "hint", + "reportUnknownLambdaType": "hint", + "reportMissingParameterType": "hint", + "reportMissingTypeArgument": "hint", + "reportUnnecessaryContains": "hint", + "reportUnnecessaryComparison": "hint", + "reportUnnecessaryCast": "hint", + "reportUnnecessaryIsInstance": "hint", + "reportUntypedFunctionDecorator": "hint", + + "reportAttributeAccessIssue": "hint", "pythonVersion": "3.11", "pythonPlatform": "All" } diff --git a/api/services/account_service.py b/api/services/account_service.py index a76792f88e..f66c1aa677 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1318,7 +1318,7 @@ class RegisterService: def get_invitation_if_token_valid( cls, workspace_id: Optional[str], email: str, token: str ) -> Optional[dict[str, Any]]: - invitation_data = cls._get_invitation_by_token(token, workspace_id, email) + invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -1355,7 +1355,7 @@ class RegisterService: } @classmethod - def _get_invitation_by_token( + def get_invitation_by_token( cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None ) -> Optional[dict[str, str]]: if workspace_id is not None and email is not None: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index ba86a31240..82b1d21179 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -349,7 +349,7 @@ class AppAnnotationService: try: # Skip the first row - df = pd.read_csv(file, dtype=str) + df = pd.read_csv(file.stream, dtype=str) result = [] for _, row in df.iterrows(): content = {"question": row.iloc[0], "answer": row.iloc[1]} @@ -463,15 +463,23 @@ class AppAnnotationService: annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail - return { - "id": annotation_setting.id, - "enabled": True, - "score_threshold": annotation_setting.score_threshold, - "embedding_model": { - "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name, - }, - } + if collection_binding_detail: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + "embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name, + }, + } + else: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": {}, + } return {"enabled": False} @classmethod @@ -506,15 +514,23 @@ class AppAnnotationService: collection_binding_detail = annotation_setting.collection_binding_detail - return { - "id": annotation_setting.id, - "enabled": True, - "score_threshold": annotation_setting.score_threshold, - "embedding_model": { - "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name, - }, - } + if collection_binding_detail: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": { + "embedding_provider_name": collection_binding_detail.provider_name, + "embedding_model_name": collection_binding_detail.model_name, + }, + } + else: + return { + "id": annotation_setting.id, + "enabled": True, + "score_threshold": annotation_setting.score_threshold, + "embedding_model": {}, + } @classmethod def clear_all_annotations(cls, app_id: str): diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 2f1b63664f..3b4cb1900a 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -407,6 +407,7 @@ class ClearFreePlanTenantExpiredLogs: datetime.timedelta(hours=1), ] + tenant_count = 0 for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 65dc673100..20a9c73f08 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -134,11 +134,14 @@ class DatasetService: # Check if tag_ids is not empty to avoid WHERE false condition if tag_ids and len(tag_ids) > 0: - target_ids = TagService.get_target_ids_by_tag_ids( - "knowledge", - tenant_id, # ty: ignore [invalid-argument-type] - tag_ids, - ) + if tenant_id is not None: + target_ids = TagService.get_target_ids_by_tag_ids( + "knowledge", + tenant_id, + tag_ids, + ) + else: + target_ids = [] if target_ids and len(target_ids) > 0: query = query.where(Dataset.id.in_(target_ids)) else: @@ -987,7 +990,8 @@ class DocumentService: for document in documents if document.data_source_type == "upload_file" and document.data_source_info_dict ] - batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) + if dataset.doc_form is not None: + batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) for document in documents: db.session.delete(document) @@ -2688,56 +2692,6 @@ class SegmentService: return paginated_segments.items, paginated_segments.total - @classmethod - def update_segment_by_id( - cls, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, segment_data: dict, user_id: str - ) -> tuple[DocumentSegment, Document]: - """Update a segment by its ID with validation and checks.""" - # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() - if not dataset: - raise NotFound("Dataset not found.") - - # check user's model setting - DatasetService.check_dataset_model_setting(dataset) - - # check document - document = DocumentService.get_document(dataset_id, document_id) - if not document: - raise NotFound("Document not found.") - - # check embedding model setting if high quality - if dataset.indexing_technique == "high_quality": - try: - model_manager = ModelManager() - model_manager.get_model_instance( - tenant_id=user_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) - - # check segment - segment = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) - .first() - ) - if not segment: - raise NotFound("Segment not found.") - - # validate and update segment - cls.segment_create_args_validate(segment_data, document) - updated_segment = cls.update_segment(SegmentUpdateArgs(**segment_data), segment, document, dataset) - - return updated_segment, document - @classmethod def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: """Get a segment by its ID.""" diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 3262a00663..3911b763b6 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -181,7 +181,7 @@ class ExternalDatasetService: do http request depending on api bundle """ - kwargs = { + kwargs: dict[str, Any] = { "url": settings.url, "headers": settings.headers, "follow_redirects": True, diff --git a/api/services/file_service.py b/api/services/file_service.py index 8a4655d25e..364a872a91 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,7 +1,7 @@ import hashlib import os import uuid -from typing import Any, Literal, Union +from typing import Literal, Union from werkzeug.exceptions import NotFound @@ -35,7 +35,7 @@ class FileService: filename: str, content: bytes, mimetype: str, - user: Union[Account, EndUser, Any], + user: Union[Account, EndUser], source: Literal["datasets"] | None = None, source_url: str = "", ) -> UploadFile: diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index c638087f63..d0e2230540 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -165,7 +165,7 @@ class ModelLoadBalancingService: try: if load_balancing_config.encrypted_config: - credentials = json.loads(load_balancing_config.encrypted_config) + credentials: dict[str, object] = json.loads(load_balancing_config.encrypted_config) else: credentials = {} except JSONDecodeError: @@ -180,11 +180,13 @@ class ModelLoadBalancingService: for variable in credential_secret_variables: if variable in credentials: try: - credentials[variable] = encrypter.decrypt_token_with_decoding( - credentials.get(variable), # ty: ignore [invalid-argument-type] - decoding_rsa_key, - decoding_cipher_rsa, - ) + token_value = credentials.get(variable) + if isinstance(token_value, str): + credentials[variable] = encrypter.decrypt_token_with_decoding( + token_value, + decoding_rsa_key, + decoding_cipher_rsa, + ) except ValueError: pass @@ -345,8 +347,9 @@ class ModelLoadBalancingService: credential_id = config.get("credential_id") enabled = config.get("enabled") + credential_record: ProviderCredential | ProviderModelCredential | None = None + if credential_id: - credential_record: ProviderCredential | ProviderModelCredential | None = None if config_from == "predefined-model": credential_record = ( db.session.query(ProviderCredential) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 8dbf117fd3..bae2921a27 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -99,6 +99,7 @@ class PluginMigration: datetime.timedelta(hours=1), ] + tenant_count = 0 for test_interval in test_intervals: tenant_count = ( session.query(Tenant.id) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index bce389b949..cb31111485 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -223,8 +223,8 @@ class BuiltinToolManageService: """ add builtin tool provider """ - try: - with Session(db.engine) as session: + with Session(db.engine) as session: + try: lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" with redis_client.lock(lock, timeout=20): provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) @@ -285,9 +285,9 @@ class BuiltinToolManageService: session.add(db_provider) session.commit() - except Exception as e: - session.rollback() - raise ValueError(str(e)) + except Exception as e: + session.rollback() + raise ValueError(str(e)) return {"result": "success"} @staticmethod diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 2994856b54..8a58289b22 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -18,6 +18,7 @@ from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.nodes import NodeType from events.app_event import app_was_created from extensions.ext_database import db @@ -420,7 +421,11 @@ class WorkflowConverter: query_in_prompt=False, ) - template = prompt_template_config["prompt_template"].template + prompt_template_obj = prompt_template_config["prompt_template"] + if not isinstance(prompt_template_obj, PromptTemplateParser): + raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}") + + template = prompt_template_obj.template if not template: prompts = [] else: @@ -457,7 +462,11 @@ class WorkflowConverter: query_in_prompt=False, ) - template = prompt_template_config["prompt_template"].template + prompt_template_obj = prompt_template_config["prompt_template"] + if not isinstance(prompt_template_obj, PromptTemplateParser): + raise TypeError(f"Expected PromptTemplateParser, got {type(prompt_template_obj)}") + + template = prompt_template_obj.template template = self._replace_template_variables( template=template, variables=start_node["data"]["variables"], @@ -467,6 +476,9 @@ class WorkflowConverter: prompts = {"text": template} prompt_rules = prompt_template_config["prompt_rules"] + if not isinstance(prompt_rules, dict): + raise TypeError(f"Expected dict for prompt_rules, got {type(prompt_rules)}") + role_prefix = { "user": prompt_rules.get("human_prefix", "Human"), "assistant": prompt_rules.get("assistant_prefix", "Assistant"), diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0a14007349..4e0ae15841 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -769,10 +769,10 @@ class WorkflowService: ) error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: - node = e._node + node = e.node run_succeeded = False node_run_result = None - error = e._error + error = e.error # Create a NodeExecution domain model node_execution = WorkflowNodeExecution( diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index d4fc68a084..292ac6e008 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -12,7 +12,7 @@ class WorkspaceService: def get_tenant_info(cls, tenant: Tenant): if not tenant: return None - tenant_info = { + tenant_info: dict[str, object] = { "id": tenant.id, "name": tenant.name, "plan": tenant.plan, diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 415e65ce51..6b5ac713e6 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -3278,7 +3278,7 @@ class TestRegisterService: redis_client.setex(cache_key, 24 * 60 * 60, account_id) # Execute invitation retrieval - result = RegisterService._get_invitation_by_token( + result = RegisterService.get_invitation_by_token( token=token, workspace_id=workspace_id, email=email, @@ -3316,7 +3316,7 @@ class TestRegisterService: redis_client.setex(token_key, 24 * 60 * 60, json.dumps(invitation_data)) # Execute invitation retrieval - result = RegisterService._get_invitation_by_token(token=token) + result = RegisterService.get_invitation_by_token(token=token) # Verify result contains expected data assert result is not None diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 8b3db27525..18ab4bb73c 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -14,6 +14,7 @@ from core.app.app_config.entities import ( VariableEntityType, ) from core.model_runtime.entities.llm_entities import LLMMode +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.account import Account, Tenant from models.api_based_extension import APIBasedExtension from models.model import App, AppMode, AppModelConfig @@ -37,7 +38,7 @@ class TestWorkflowConverter: # Setup default mock returns mock_encrypter.decrypt_token.return_value = "decrypted_api_key" mock_prompt_transform.return_value.get_prompt_template.return_value = { - "prompt_template": type("obj", (object,), {"template": "You are a helpful assistant {{text_input}}"})(), + "prompt_template": PromptTemplateParser(template="You are a helpful assistant {{text_input}}"), "prompt_rules": {"human_prefix": "Human", "assistant_prefix": "Assistant"}, } mock_agent_chat_config_manager.get_app_config.return_value = self._create_mock_app_config() diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 442839e44e..d7404ee90a 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1370,8 +1370,8 @@ class TestRegisterService: account_id="user-123", email="test@example.com" ) - with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token: - # Mock the invitation data returned by _get_invitation_by_token + with patch("services.account_service.RegisterService.get_invitation_by_token") as mock_get_invitation_by_token: + # Mock the invitation data returned by get_invitation_by_token invitation_data = { "account_id": "user-123", "email": "test@example.com", @@ -1503,12 +1503,12 @@ class TestRegisterService: assert result == "member_invite:token:test-token" def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies): - """Test _get_invitation_by_token with workspace ID and email.""" + """Test get_invitation_by_token with workspace ID and email.""" # Setup mock mock_redis_dependencies.get.return_value = b"user-123" # Execute test - result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com") + result = RegisterService.get_invitation_by_token("token-123", "workspace-456", "test@example.com") # Verify results assert result is not None @@ -1517,7 +1517,7 @@ class TestRegisterService: assert result["workspace_id"] == "workspace-456" def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies): - """Test _get_invitation_by_token without workspace ID and email.""" + """Test get_invitation_by_token without workspace ID and email.""" # Setup mock invitation_data = { "account_id": "user-123", @@ -1527,19 +1527,19 @@ class TestRegisterService: mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() # Execute test - result = RegisterService._get_invitation_by_token("token-123") + result = RegisterService.get_invitation_by_token("token-123") # Verify results assert result is not None assert result == invitation_data def test_get_invitation_by_token_no_data(self, mock_redis_dependencies): - """Test _get_invitation_by_token with no data.""" + """Test get_invitation_by_token with no data.""" # Setup mock mock_redis_dependencies.get.return_value = None # Execute test - result = RegisterService._get_invitation_by_token("token-123") + result = RegisterService.get_invitation_by_token("token-123") # Verify results assert result is None From 928bef9d82c8e12b0ee645ae7177dc348f638a8b Mon Sep 17 00:00:00 2001 From: 17hz <0x149527@gmail.com> Date: Wed, 10 Sep 2025 08:45:00 +0800 Subject: [PATCH 09/86] fix: imporve the condition for stopping the think timer. (#25365) --- web/app/components/base/markdown-blocks/think-block.tsx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/web/app/components/base/markdown-blocks/think-block.tsx b/web/app/components/base/markdown-blocks/think-block.tsx index 46f992d758..a5813266f1 100644 --- a/web/app/components/base/markdown-blocks/think-block.tsx +++ b/web/app/components/base/markdown-blocks/think-block.tsx @@ -1,5 +1,6 @@ import React, { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import { useChatContext } from '../chat/chat/context' const hasEndThink = (children: any): boolean => { if (typeof children === 'string') @@ -35,6 +36,7 @@ const removeEndThink = (children: any): any => { } const useThinkTimer = (children: any) => { + const { isResponding } = useChatContext() const [startTime] = useState(Date.now()) const [elapsedTime, setElapsedTime] = useState(0) const [isComplete, setIsComplete] = useState(false) @@ -54,9 +56,9 @@ const useThinkTimer = (children: any) => { }, [startTime, isComplete]) useEffect(() => { - if (hasEndThink(children)) + if (hasEndThink(children) || !isResponding) setIsComplete(true) - }, [children]) + }, [children, isResponding]) return { elapsedTime, isComplete } } From cce13750adbc33e0740c5693280f26fb4b9bb698 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 10 Sep 2025 09:51:21 +0900 Subject: [PATCH 10/86] add rule for strenum (#25445) --- api/.ruff.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/api/.ruff.toml b/api/.ruff.toml index 9668dc9f76..9a15754d9a 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -45,6 +45,7 @@ select = [ "G001", # don't use str format to logging messages "G003", # don't use + in logging messages "G004", # don't use f-strings to format logging messages + "UP042", # use StrEnum ] ignore = [ From 6574e9f0b2ac270547b0b5f52b1b44603ad6130e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Newton=20Jos=C3=A9?= Date: Tue, 9 Sep 2025 21:58:39 -0300 Subject: [PATCH 11/86] Fix: Add Password Validation to Account Creation (#25382) --- api/services/account_service.py | 2 ++ .../services/test_account_service.py | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/api/services/account_service.py b/api/services/account_service.py index f66c1aa677..f917959350 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -246,6 +246,8 @@ class AccountService: account.name = name if password: + valid_password(password) + # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 6b5ac713e6..dac1fe643a 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -91,6 +91,28 @@ class TestAccountService: assert account.password is None assert account.password_salt is None + def test_create_account_password_invalid_new_password( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test account create with invalid new password format. + """ + fake = Faker() + email = fake.email() + name = fake.name() + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Test with too short password (assuming minimum length validation) + with pytest.raises(ValueError): # Password validation error + AccountService.create_account( + email=email, + name=name, + interface_language="en-US", + password="invalid_new_password", + ) + def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies): """ Test account creation when registration is disabled. From 45ef177809e36afa2529abcf862e57fec6f2159b Mon Sep 17 00:00:00 2001 From: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Date: Wed, 10 Sep 2025 10:02:53 +0800 Subject: [PATCH 12/86] Feature add test containers create segment to index task (#25450) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../test_create_segment_to_index_task.py | 1099 +++++++++++++++++ 1 file changed, 1099 insertions(+) create mode 100644 api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py new file mode 100644 index 0000000000..de81295100 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -0,0 +1,1099 @@ +""" +Integration tests for create_segment_to_index_task using TestContainers. + +This module provides comprehensive testing for the create_segment_to_index_task +which handles asynchronous document segment indexing operations. +""" + +import time +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from faker import Faker + +from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.create_segment_to_index_task import create_segment_to_index_task + + +class TestCreateSegmentToIndexTask: + """Integration tests for create_segment_to_index_task using testcontainers.""" + + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers): + """Clean up database and Redis before each test to ensure isolation.""" + from extensions.ext_database import db + + # Clear all test data + db.session.query(DocumentSegment).delete() + db.session.query(Document).delete() + db.session.query(Dataset).delete() + db.session.query(TenantAccountJoin).delete() + db.session.query(Tenant).delete() + db.session.query(Account).delete() + db.session.commit() + + # Clear Redis cache + redis_client.flushdb() + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("tasks.create_segment_to_index_task.IndexProcessorFactory") as mock_factory, + ): + # Setup default mock returns + mock_processor = MagicMock() + mock_factory.return_value.init_index_processor.return_value = mock_processor + + yield { + "index_processor_factory": mock_factory, + "index_processor": mock_processor, + } + + def _create_test_account_and_tenant(self, db_session_with_containers): + """ + Helper method to create a test account and tenant for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + + Returns: + tuple: (account, tenant) - Created account and tenant instances + """ + fake = Faker() + + # Create account + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + + from extensions.ext_database import db + + db.session.add(account) + db.session.commit() + + # Create tenant + tenant = Tenant( + name=fake.company(), + status="normal", + plan="basic", + ) + db.session.add(tenant) + db.session.commit() + + # Create tenant-account join with owner role + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER.value, + current=True, + ) + db.session.add(join) + db.session.commit() + + # Set current tenant for account + account.current_tenant = tenant + + return account, tenant + + def _create_test_dataset_and_document(self, db_session_with_containers, tenant_id, account_id): + """ + Helper method to create a test dataset and document for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + tenant_id: Tenant ID for the dataset + account_id: Account ID for the document + + Returns: + tuple: (dataset, document) - Created dataset and document instances + """ + fake = Faker() + + # Create dataset + dataset = Dataset( + name=fake.company(), + description=fake.text(max_nb_chars=100), + tenant_id=tenant_id, + data_source_type="upload_file", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + created_by=account_id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + # Create document + document = Document( + name=fake.file_name(), + dataset_id=dataset.id, + tenant_id=tenant_id, + position=1, + data_source_type="upload_file", + batch="test_batch", + created_from="upload_file", + created_by=account_id, + enabled=True, + archived=False, + indexing_status="completed", + doc_form="qa_model", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + return dataset, document + + def _create_test_segment( + self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status="waiting" + ): + """ + Helper method to create a test document segment for testing. + + Args: + db_session_with_containers: Database session from testcontainers infrastructure + dataset_id: Dataset ID for the segment + document_id: Document ID for the segment + tenant_id: Tenant ID for the segment + account_id: Account ID for the segment + status: Initial status of the segment + + Returns: + DocumentSegment: Created document segment instance + """ + fake = Faker() + + segment = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + position=1, + content=fake.text(max_nb_chars=500), + answer=fake.text(max_nb_chars=200), + word_count=len(fake.text(max_nb_chars=500).split()), + tokens=len(fake.text(max_nb_chars=500).split()) * 2, + keywords=["test", "document", "segment"], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status=status, + created_by=account_id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + return segment + + def test_create_segment_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test successful creation of segment to index. + + This test verifies: + - Segment status transitions from waiting to indexing to completed + - Index processor is called with correct parameters + - Segment metadata is properly updated + - Redis cache key is cleaned up + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify segment status changes + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + assert segment.error is None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + # Verify Redis cache cleanup + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 + + def test_create_segment_to_index_segment_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of non-existent segment ID. + + This test verifies: + - Task gracefully handles missing segment + - No exceptions are raised + - Database session is properly closed + """ + # Arrange: Use non-existent segment ID + non_existent_segment_id = str(uuid4()) + + # Act & Assert: Task should complete without error + result = create_segment_to_index_task(non_existent_segment_id) + assert result is None + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_invalid_status( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with invalid status. + + This test verifies: + - Task skips segments not in 'waiting' status + - No processing occurs for invalid status + - Database session is properly closed + """ + # Arrange: Create segment with invalid status + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="completed" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status unchanged + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is None + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_no_dataset(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test handling of segment without associated dataset. + + This test verifies: + - Task gracefully handles missing dataset + - Segment status remains unchanged + - No processing occurs + """ + # Arrange: Create segment with invalid dataset_id + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + invalid_dataset_id = str(uuid4()) + + # Create document with invalid dataset_id + document = Document( + name="test_doc", + dataset_id=invalid_dataset_id, + tenant_id=tenant.id, + position=1, + data_source_type="upload_file", + batch="test_batch", + created_from="upload_file", + created_by=account.id, + enabled=True, + archived=False, + indexing_status="completed", + doc_form="text_model", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, invalid_dataset_id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_no_document(self, db_session_with_containers, mock_external_service_dependencies): + """ + Test handling of segment without associated document. + + This test verifies: + - Task gracefully handles missing document + - Segment status remains unchanged + - No processing occurs + """ + # Arrange: Create segment with invalid document_id + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, _ = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + invalid_document_id = str(uuid4()) + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, invalid_document_id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_document_disabled( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with disabled document. + + This test verifies: + - Task skips segments with disabled documents + - No processing occurs for disabled documents + - Segment status remains unchanged + """ + # Arrange: Create disabled document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Disable the document + document.enabled = False + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_document_archived( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with archived document. + + This test verifies: + - Task skips segments with archived documents + - No processing occurs for archived documents + - Segment status remains unchanged + """ + # Arrange: Create archived document + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Archive the document + document.archived = True + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_document_indexing_incomplete( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of segment with document that has incomplete indexing. + + This test verifies: + - Task skips segments with incomplete indexing documents + - No processing occurs for incomplete indexing + - Segment status remains unchanged + """ + # Arrange: Create document with incomplete indexing + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Set incomplete indexing status + document.indexing_status = "indexing" + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + result = create_segment_to_index_task(segment.id) + + # Assert: Task should complete without processing + assert result is None + + # Verify segment status changed to indexing (task updates status before checking document) + db_session_with_containers.refresh(segment) + assert segment.status == "indexing" + + # Verify no index processor calls were made + mock_external_service_dependencies["index_processor_factory"].assert_not_called() + + def test_create_segment_to_index_processor_exception( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test handling of index processor exceptions. + + This test verifies: + - Task properly handles index processor failures + - Segment status is updated to error + - Segment is disabled with error information + - Redis cache is cleaned up despite errors + """ + # Arrange: Create test data and mock processor exception + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Mock processor to raise exception + mock_external_service_dependencies["index_processor"].load.side_effect = Exception("Processor failed") + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify error handling + db_session_with_containers.refresh(segment) + assert segment.status == "error" + assert segment.enabled is False + assert segment.disabled_at is not None + assert segment.error == "Processor failed" + + # Verify Redis cache cleanup still occurs + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 + + def test_create_segment_to_index_with_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with custom keywords. + + This test verifies: + - Task accepts and processes keywords parameter + - Keywords are properly passed through the task + - Indexing completes successfully with keywords + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + custom_keywords = ["custom", "keywords", "test"] + + # Act: Execute the task with keywords + create_segment_to_index_task(segment.id, keywords=custom_keywords) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_different_doc_forms( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with different document forms. + + This test verifies: + - Task works with various document forms + - Index processor factory receives correct doc_form + - Processing completes successfully for different forms + """ + # Arrange: Test different doc_forms + doc_forms = ["qa_model", "text_model", "web_model"] + + for doc_form in doc_forms: + # Create fresh test data for each form + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document( + db_session_with_containers, tenant.id, account.id + ) + + # Update document's doc_form for testing + document.doc_form = doc_form + db_session_with_containers.commit() + + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + # Verify correct doc_form was passed to factory + mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) + + def test_create_segment_to_index_performance_timing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing performance and timing. + + This test verifies: + - Task execution time is reasonable + - Performance metrics are properly recorded + - No significant performance degradation + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task and measure time + start_time = time.time() + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify performance + execution_time = end_time - start_time + assert execution_time < 5.0 # Should complete within 5 seconds + + # Verify successful completion + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + def test_create_segment_to_index_concurrent_execution( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test concurrent execution of segment indexing tasks. + + This test verifies: + - Multiple tasks can run concurrently + - No race conditions occur + - All segments are processed correctly + """ + # Arrange: Create multiple test segments + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + segments = [] + for i in range(3): + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + segments.append(segment) + + # Act: Execute tasks concurrently (simulated) + segment_ids = [segment.id for segment in segments] + for segment_id in segment_ids: + create_segment_to_index_task(segment_id) + + # Assert: Verify all segments processed + for segment in segments: + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called for each segment + assert mock_external_service_dependencies["index_processor_factory"].call_count == 3 + + def test_create_segment_to_index_large_content( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with large content. + + This test verifies: + - Task handles large content segments + - Performance remains acceptable with large content + - No memory or processing issues occur + """ + # Arrange: Create segment with large content + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Generate large content (simulate large document) + large_content = "Large content " * 1000 # ~15KB content + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content=large_content, + answer="Large answer " * 100, + word_count=len(large_content.split()), + tokens=len(large_content.split()) * 2, + keywords=["large", "content", "test"], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + start_time = time.time() + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify successful processing + execution_time = end_time - start_time + assert execution_time < 10.0 # Should complete within 10 seconds + + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_redis_failure( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing when Redis operations fail. + + This test verifies: + - Task continues to work even if Redis fails + - Indexing completes successfully + - Redis errors don't affect core functionality + """ + # Arrange: Create test data and mock Redis failure + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Set up Redis cache key to simulate indexing in progress + cache_key = f"segment_{segment.id}_indexing" + redis_client.set(cache_key, "processing", ex=300) + + # Mock Redis to raise exception in finally block + with patch.object(redis_client, "delete", side_effect=Exception("Redis connection failed")): + # Act: Execute the task - Redis failure should not prevent completion + with pytest.raises(Exception) as exc_info: + create_segment_to_index_task(segment.id) + + # Verify the exception contains the expected Redis error message + assert "Redis connection failed" in str(exc_info.value) + + # Assert: Verify indexing still completed successfully despite Redis failure + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify Redis cache key still exists (since delete failed) + assert redis_client.exists(cache_key) == 1 + + def test_create_segment_to_index_database_transaction_rollback( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with database transaction handling. + + This test verifies: + - Database transactions are properly managed + - Rollback occurs on errors + - Data consistency is maintained + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Mock global database session to simulate transaction issues + from extensions.ext_database import db + + original_commit = db.session.commit + commit_called = False + + def mock_commit(): + nonlocal commit_called + if not commit_called: + commit_called = True + raise Exception("Database commit failed") + return original_commit() + + db.session.commit = mock_commit + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify error handling and rollback + db_session_with_containers.refresh(segment) + assert segment.status == "error" + assert segment.enabled is False + assert segment.disabled_at is not None + assert segment.error is not None + + # Restore original commit method + db.session.commit = original_commit + + def test_create_segment_to_index_metadata_validation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with metadata validation. + + This test verifies: + - Document metadata is properly constructed + - All required metadata fields are present + - Metadata is correctly passed to index processor + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + + # Verify index processor was called with correct metadata + mock_processor = mock_external_service_dependencies["index_processor"] + mock_processor.load.assert_called_once() + + # Get the call arguments to verify metadata structure + call_args = mock_processor.load.call_args + assert len(call_args[0]) == 2 # dataset and documents + + # Verify basic structure without deep object inspection + called_dataset = call_args[0][0] # first arg should be dataset + assert called_dataset is not None + + documents = call_args[0][1] # second arg should be list of documents + assert len(documents) == 1 + doc = documents[0] + assert doc is not None + + def test_create_segment_to_index_status_transition_flow( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test complete status transition flow during indexing. + + This test verifies: + - Status transitions: waiting -> indexing -> completed + - Timestamps are properly recorded at each stage + - No intermediate states are skipped + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Verify initial state + assert segment.status == "waiting" + assert segment.indexing_at is None + assert segment.completed_at is None + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify final state + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify timestamp ordering + assert segment.indexing_at <= segment.completed_at + + def test_create_segment_to_index_with_empty_content( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with empty or minimal content. + + This test verifies: + - Task handles empty content gracefully + - Indexing completes successfully with minimal content + - No errors occur with edge case content + """ + # Arrange: Create segment with minimal content + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content="", # Empty content + answer="", + word_count=0, + tokens=0, + keywords=[], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_with_special_characters( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with special characters and unicode content. + + This test verifies: + - Task handles special characters correctly + - Unicode content is processed properly + - No encoding issues occur + """ + # Arrange: Create segment with special characters + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~" + unicode_content = "Unicode: 中文测试 🚀 🌟 💻" + mixed_content = special_content + "\n" + unicode_content + + segment = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=1, + content=mixed_content, + answer="Special answer: 🎯", + word_count=len(mixed_content.split()), + tokens=len(mixed_content.split()) * 2, + keywords=["special", "unicode", "test"], + index_node_id=str(uuid4()), + index_node_hash=str(uuid4()), + status="waiting", + created_by=account.id, + ) + db_session_with_containers.add(segment) + db_session_with_containers.commit() + + # Act: Execute the task + create_segment_to_index_task(segment.id) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + def test_create_segment_to_index_with_long_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with long keyword lists. + + This test verifies: + - Task handles long keyword lists + - Keywords parameter is properly processed + - No performance issues with large keyword sets + """ + # Arrange: Create segment with long keywords + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Create long keyword list + long_keywords = [f"keyword_{i}" for i in range(100)] + + # Act: Execute the task with long keywords + create_segment_to_index_task(segment.id, keywords=long_keywords) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_tenant_isolation( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with proper tenant isolation. + + This test verifies: + - Tasks are properly isolated by tenant + - No cross-tenant data access occurs + - Tenant boundaries are respected + """ + # Arrange: Create multiple tenants with segments + account1, tenant1 = self._create_test_account_and_tenant(db_session_with_containers) + account2, tenant2 = self._create_test_account_and_tenant(db_session_with_containers) + + dataset1, document1 = self._create_test_dataset_and_document( + db_session_with_containers, tenant1.id, account1.id + ) + dataset2, document2 = self._create_test_dataset_and_document( + db_session_with_containers, tenant2.id, account2.id + ) + + segment1 = self._create_test_segment( + db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status="waiting" + ) + segment2 = self._create_test_segment( + db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status="waiting" + ) + + # Act: Execute tasks for both tenants + create_segment_to_index_task(segment1.id) + create_segment_to_index_task(segment2.id) + + # Assert: Verify both segments processed independently + db_session_with_containers.refresh(segment1) + db_session_with_containers.refresh(segment2) + + assert segment1.status == "completed" + assert segment2.status == "completed" + assert segment1.tenant_id == tenant1.id + assert segment2.tenant_id == tenant2.id + assert segment1.tenant_id != segment2.tenant_id + + def test_create_segment_to_index_with_none_keywords( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test segment indexing with None keywords parameter. + + This test verifies: + - Task handles None keywords gracefully + - Default behavior works correctly + - No errors occur with None parameters + """ + # Arrange: Create test data + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + + # Act: Execute the task with None keywords + create_segment_to_index_task(segment.id, keywords=None) + + # Assert: Verify successful indexing + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + + # Verify index processor was called + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(dataset.doc_form) + mock_external_service_dependencies["index_processor"].load.assert_called_once() + + def test_create_segment_to_index_comprehensive_integration( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Comprehensive integration test covering multiple scenarios. + + This test verifies: + - Complete workflow from creation to completion + - All components work together correctly + - End-to-end functionality is maintained + - Performance and reliability under normal conditions + """ + # Arrange: Create comprehensive test scenario + account, tenant = self._create_test_account_and_tenant(db_session_with_containers) + dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) + + # Create multiple segments with different characteristics + segments = [] + for i in range(5): + segment = self._create_test_segment( + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + ) + segments.append(segment) + + # Act: Process all segments + start_time = time.time() + for segment in segments: + create_segment_to_index_task(segment.id) + end_time = time.time() + + # Assert: Verify comprehensive success + total_time = end_time - start_time + assert total_time < 25.0 # Should complete all within 25 seconds + + # Verify all segments processed successfully + for segment in segments: + db_session_with_containers.refresh(segment) + assert segment.status == "completed" + assert segment.indexing_at is not None + assert segment.completed_at is not None + assert segment.error is None + + # Verify index processor was called for each segment + expected_calls = len(segments) + assert mock_external_service_dependencies["index_processor_factory"].call_count == expected_calls + + # Verify Redis cleanup for each segment + for segment in segments: + cache_key = f"segment_{segment.id}_indexing" + assert redis_client.exists(cache_key) == 0 From fecdb9554d2b774072e4b7aab9244350d38dca60 Mon Sep 17 00:00:00 2001 From: Will Date: Wed, 10 Sep 2025 11:31:16 +0800 Subject: [PATCH 13/86] fix: inner_api get_user_tenant (#25462) --- api/controllers/inner_api/plugin/wraps.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 18b530f2c4..bde0150ffd 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -75,9 +75,6 @@ def get_user_tenant(view: Optional[Callable[P, R]] = None): if not user_id: user_id = DEFAULT_SERVICE_API_USER_ID - del kwargs["tenant_id"] - del kwargs["user_id"] - try: tenant_model = ( db.session.query(Tenant) From 26a9abef6480eedf402fb781b0b0f992bfc73150 Mon Sep 17 00:00:00 2001 From: GuanMu Date: Wed, 10 Sep 2025 11:36:22 +0800 Subject: [PATCH 14/86] test: imporve (#25461) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/tests/unit_tests/libs/test_email_i18n.py | 37 ++++++ .../unit_tests/libs/test_external_api.py | 122 ++++++++++++++++++ api/tests/unit_tests/libs/test_oauth_base.py | 19 +++ .../unit_tests/libs/test_sendgrid_client.py | 53 ++++++++ api/tests/unit_tests/libs/test_smtp_client.py | 100 ++++++++++++++ 5 files changed, 331 insertions(+) create mode 100644 api/tests/unit_tests/libs/test_external_api.py create mode 100644 api/tests/unit_tests/libs/test_oauth_base.py create mode 100644 api/tests/unit_tests/libs/test_sendgrid_client.py create mode 100644 api/tests/unit_tests/libs/test_smtp_client.py diff --git a/api/tests/unit_tests/libs/test_email_i18n.py b/api/tests/unit_tests/libs/test_email_i18n.py index b80c711cac..962a36fe03 100644 --- a/api/tests/unit_tests/libs/test_email_i18n.py +++ b/api/tests/unit_tests/libs/test_email_i18n.py @@ -246,6 +246,43 @@ class TestEmailI18nService: sent_email = mock_sender.sent_emails[0] assert sent_email["subject"] == "Reset Your Dify Password" + def test_subject_format_keyerror_fallback_path( + self, + mock_renderer: MockEmailRenderer, + mock_sender: MockEmailSender, + ): + """Trigger subject KeyError and cover except branch.""" + # Config with subject that references an unknown key (no {application_title} to avoid second format) + config = EmailI18nConfig( + templates={ + EmailType.INVITE_MEMBER: { + EmailLanguage.EN_US: EmailTemplate( + subject="Invite: {unknown_placeholder}", + template_path="invite_member_en.html", + branded_template_path="branded/invite_member_en.html", + ), + } + } + ) + branding_service = MockBrandingService(enabled=False) + service = EmailI18nService( + config=config, + renderer=mock_renderer, + branding_service=branding_service, + sender=mock_sender, + ) + + # Will raise KeyError on subject.format(**full_context), then hit except branch and skip fallback + service.send_email( + email_type=EmailType.INVITE_MEMBER, + language_code="en-US", + to="test@example.com", + ) + + assert len(mock_sender.sent_emails) == 1 + # Subject is left unformatted due to KeyError fallback path without application_title + assert mock_sender.sent_emails[0]["subject"] == "Invite: {unknown_placeholder}" + def test_send_change_email_old_phase( self, email_config: EmailI18nConfig, diff --git a/api/tests/unit_tests/libs/test_external_api.py b/api/tests/unit_tests/libs/test_external_api.py new file mode 100644 index 0000000000..a9edb913ea --- /dev/null +++ b/api/tests/unit_tests/libs/test_external_api.py @@ -0,0 +1,122 @@ +from flask import Blueprint, Flask +from flask_restx import Resource +from werkzeug.exceptions import BadRequest, Unauthorized + +from core.errors.error import AppInvokeQuotaExceededError +from libs.external_api import ExternalApi + + +def _create_api_app(): + app = Flask(__name__) + bp = Blueprint("t", __name__) + api = ExternalApi(bp) + + @api.route("/bad-request") + class Bad(Resource): # type: ignore + def get(self): # type: ignore + raise BadRequest("invalid input") + + @api.route("/unauth") + class Unauth(Resource): # type: ignore + def get(self): # type: ignore + raise Unauthorized("auth required") + + @api.route("/value-error") + class ValErr(Resource): # type: ignore + def get(self): # type: ignore + raise ValueError("boom") + + @api.route("/quota") + class Quota(Resource): # type: ignore + def get(self): # type: ignore + raise AppInvokeQuotaExceededError("quota exceeded") + + @api.route("/general") + class Gen(Resource): # type: ignore + def get(self): # type: ignore + raise RuntimeError("oops") + + # Note: We avoid altering default_mediatype to keep normal error paths + + # Special 400 message rewrite + @api.route("/json-empty") + class JsonEmpty(Resource): # type: ignore + def get(self): # type: ignore + e = BadRequest() + # Force the specific message the handler rewrites + e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)" + raise e + + # 400 mapping payload path + @api.route("/param-errors") + class ParamErrors(Resource): # type: ignore + def get(self): # type: ignore + e = BadRequest() + # Coerce a mapping description to trigger param error shaping + e.description = {"field": "is required"} # type: ignore[assignment] + raise e + + app.register_blueprint(bp, url_prefix="/api") + return app + + +def test_external_api_error_handlers_basic_paths(): + app = _create_api_app() + client = app.test_client() + + # 400 + res = client.get("/api/bad-request") + assert res.status_code == 400 + data = res.get_json() + assert data["code"] == "bad_request" + assert data["status"] == 400 + + # 401 + res = client.get("/api/unauth") + assert res.status_code == 401 + assert "WWW-Authenticate" in res.headers + + # 400 ValueError + res = client.get("/api/value-error") + assert res.status_code == 400 + assert res.get_json()["code"] == "invalid_param" + + # 500 general + res = client.get("/api/general") + assert res.status_code == 500 + assert res.get_json()["status"] == 500 + + +def test_external_api_json_message_and_bad_request_rewrite(): + app = _create_api_app() + client = app.test_client() + + # JSON empty special rewrite + res = client.get("/api/json-empty") + assert res.status_code == 400 + assert res.get_json()["message"] == "Invalid JSON payload received or JSON payload is empty." + + +def test_external_api_param_mapping_and_quota_and_exc_info_none(): + # Force exc_info() to return (None,None,None) only during request + import libs.external_api as ext + + orig_exc_info = ext.sys.exc_info + try: + ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment] + + app = _create_api_app() + client = app.test_client() + + # Param errors mapping payload path + res = client.get("/api/param-errors") + assert res.status_code == 400 + data = res.get_json() + assert data["code"] == "invalid_param" + assert data["params"] == "field" + + # Quota path — depending on Flask-RESTX internals it may be handled + res = client.get("/api/quota") + assert res.status_code in (400, 429) + finally: + ext.sys.exc_info = orig_exc_info # type: ignore[assignment] diff --git a/api/tests/unit_tests/libs/test_oauth_base.py b/api/tests/unit_tests/libs/test_oauth_base.py new file mode 100644 index 0000000000..3e0c235fff --- /dev/null +++ b/api/tests/unit_tests/libs/test_oauth_base.py @@ -0,0 +1,19 @@ +import pytest + +from libs.oauth import OAuth + + +def test_oauth_base_methods_raise_not_implemented(): + oauth = OAuth(client_id="id", client_secret="sec", redirect_uri="uri") + + with pytest.raises(NotImplementedError): + oauth.get_authorization_url() + + with pytest.raises(NotImplementedError): + oauth.get_access_token("code") + + with pytest.raises(NotImplementedError): + oauth.get_raw_user_info("token") + + with pytest.raises(NotImplementedError): + oauth._transform_user_info({}) # type: ignore[name-defined] diff --git a/api/tests/unit_tests/libs/test_sendgrid_client.py b/api/tests/unit_tests/libs/test_sendgrid_client.py new file mode 100644 index 0000000000..85744003c7 --- /dev/null +++ b/api/tests/unit_tests/libs/test_sendgrid_client.py @@ -0,0 +1,53 @@ +from unittest.mock import MagicMock, patch + +import pytest +from python_http_client.exceptions import UnauthorizedError + +from libs.sendgrid import SendGridClient + + +def _mail(to: str = "user@example.com") -> dict: + return {"to": to, "subject": "Hi", "html": "Hi"} + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_success(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + # nested attribute access: client.mail.send.post + mock_client.client.mail.send.post.return_value = MagicMock(status_code=202, body=b"", headers={}) + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + sg.send(_mail()) + + mock_client_cls.assert_called_once() + mock_client.client.mail.send.post.assert_called_once() + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_missing_to_raises(mock_client_cls: MagicMock): + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(ValueError): + sg.send(_mail(to="")) + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_auth_errors_reraise(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + mock_client.client.mail.send.post.side_effect = UnauthorizedError(401, "Unauthorized", b"{}", {}) + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(UnauthorizedError): + sg.send(_mail()) + + +@patch("libs.sendgrid.sendgrid.SendGridAPIClient") +def test_sendgrid_timeout_reraise(mock_client_cls: MagicMock): + mock_client = MagicMock() + mock_client_cls.return_value = mock_client + mock_client.client.mail.send.post.side_effect = TimeoutError("timeout") + + sg = SendGridClient(sendgrid_api_key="key", _from="noreply@example.com") + with pytest.raises(TimeoutError): + sg.send(_mail()) diff --git a/api/tests/unit_tests/libs/test_smtp_client.py b/api/tests/unit_tests/libs/test_smtp_client.py new file mode 100644 index 0000000000..fcee01ca00 --- /dev/null +++ b/api/tests/unit_tests/libs/test_smtp_client.py @@ -0,0 +1,100 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from libs.smtp import SMTPClient + + +def _mail() -> dict: + return {"to": "user@example.com", "subject": "Hi", "html": "Hi"} + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_plain_success(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com") + client.send(_mail()) + + mock_smtp_cls.assert_called_once_with("smtp.example.com", 25, timeout=10) + mock_smtp.sendmail.assert_called_once() + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=587, + username="user", + password="pass", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=True, + ) + client.send(_mail()) + + mock_smtp_cls.assert_called_once_with("smtp.example.com", 587, timeout=10) + assert mock_smtp.ehlo.call_count == 2 + mock_smtp.starttls.assert_called_once() + mock_smtp.login.assert_called_once_with("user", "pass") + mock_smtp.sendmail.assert_called_once() + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP_SSL") +def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock): + # Cover SMTP_SSL branch and TimeoutError handling + mock_smtp = MagicMock() + mock_smtp.sendmail.side_effect = TimeoutError("timeout") + mock_smtp_ssl_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=465, + username="", + password="", + _from="noreply@example.com", + use_tls=True, + opportunistic_tls=False, + ) + with pytest.raises(TimeoutError): + client.send(_mail()) + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock): + mock_smtp = MagicMock() + mock_smtp.sendmail.side_effect = RuntimeError("oops") + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com") + with pytest.raises(RuntimeError): + client.send(_mail()) + mock_smtp.quit.assert_called_once() + + +@patch("libs.smtp.smtplib.SMTP") +def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock): + # Ensure we hit the specific SMTPException except branch + import smtplib + + mock_smtp = MagicMock() + mock_smtp.login.side_effect = smtplib.SMTPException("login-fail") + mock_smtp_cls.return_value = mock_smtp + + client = SMTPClient( + server="smtp.example.com", + port=25, + username="user", # non-empty to trigger login + password="pass", + _from="noreply@example.com", + ) + with pytest.raises(smtplib.SMTPException): + client.send(_mail()) + mock_smtp.quit.assert_called_once() From b51c724a9409ad79fec19f318dc69040fb92c9fe Mon Sep 17 00:00:00 2001 From: Guangdong Liu Date: Wed, 10 Sep 2025 12:15:47 +0800 Subject: [PATCH 15/86] refactor: Migrate part of the console basic API module to Flask-RESTX (#24732) Signed-off-by: -LAN- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: -LAN- --- api/controllers/console/__init__.py | 57 ++++++--- api/controllers/console/admin.py | 34 ++++- api/controllers/console/apikey.py | 62 ++++++++-- api/controllers/console/auth/activate.py | 78 ++++++++---- .../console/auth/data_source_oauth.py | 58 +++++++-- .../console/auth/forgot_password.py | 79 ++++++++++-- api/controllers/console/auth/oauth.py | 24 +++- api/controllers/console/extension.py | 64 ++++++++-- api/controllers/console/feature.py | 26 +++- api/controllers/console/init_validate.py | 34 ++++- api/controllers/console/ping.py | 19 +-- api/controllers/console/setup.py | 42 ++++++- api/controllers/console/version.py | 30 ++++- .../console/workspace/agent_providers.py | 25 +++- api/controllers/console/workspace/endpoint.py | 116 ++++++++++++++++-- api/controllers/files/__init__.py | 1 - api/controllers/inner_api/__init__.py | 1 - api/controllers/mcp/__init__.py | 1 - api/controllers/service_api/__init__.py | 1 - api/controllers/web/__init__.py | 1 - api/controllers/web/audio.py | 20 ++- api/controllers/web/completion.py | 44 ++++--- api/controllers/web/conversation.py | 113 +++++++++++++++-- api/controllers/web/message.py | 96 +++++++++++++-- api/controllers/web/saved_message.py | 61 ++++++++- api/controllers/web/site.py | 12 +- api/controllers/web/workflow.py | 24 ++-- 27 files changed, 917 insertions(+), 206 deletions(-) diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 9a8e840554..1400ee7085 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -1,4 +1,5 @@ from flask import Blueprint +from flask_restx import Namespace from libs.external_api import ExternalApi @@ -26,7 +27,16 @@ from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi bp = Blueprint("console", __name__, url_prefix="/console/api") -api = ExternalApi(bp) + +api = ExternalApi( + bp, + version="1.0", + title="Console API", + description="Console management APIs for app configuration, monitoring, and administration", +) + +# Create namespace +console_ns = Namespace("console", description="Console management API operations", path="/") # File api.add_resource(FileApi, "/files/upload") @@ -43,7 +53,16 @@ api.add_resource(AppImportConfirmApi, "/apps/imports//confirm" api.add_resource(AppImportCheckDependenciesApi, "/apps/imports//check-dependencies") # Import other controllers -from . import admin, apikey, extension, feature, ping, setup, version # pyright: ignore[reportUnusedImport] +from . import ( + admin, # pyright: ignore[reportUnusedImport] + apikey, # pyright: ignore[reportUnusedImport] + extension, # pyright: ignore[reportUnusedImport] + feature, # pyright: ignore[reportUnusedImport] + init_validate, # pyright: ignore[reportUnusedImport] + ping, # pyright: ignore[reportUnusedImport] + setup, # pyright: ignore[reportUnusedImport] + version, # pyright: ignore[reportUnusedImport] +) # Import app controllers from .app import ( @@ -103,6 +122,23 @@ from .explore import ( saved_message, # pyright: ignore[reportUnusedImport] ) +# Import tag controllers +from .tag import tags # pyright: ignore[reportUnusedImport] + +# Import workspace controllers +from .workspace import ( + account, # pyright: ignore[reportUnusedImport] + agent_providers, # pyright: ignore[reportUnusedImport] + endpoint, # pyright: ignore[reportUnusedImport] + load_balancing_config, # pyright: ignore[reportUnusedImport] + members, # pyright: ignore[reportUnusedImport] + model_providers, # pyright: ignore[reportUnusedImport] + models, # pyright: ignore[reportUnusedImport] + plugin, # pyright: ignore[reportUnusedImport] + tool_providers, # pyright: ignore[reportUnusedImport] + workspace, # pyright: ignore[reportUnusedImport] +) + # Explore Audio api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") @@ -174,19 +210,4 @@ api.add_resource( InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" ) -# Import tag controllers -from .tag import tags # pyright: ignore[reportUnusedImport] - -# Import workspace controllers -from .workspace import ( - account, # pyright: ignore[reportUnusedImport] - agent_providers, # pyright: ignore[reportUnusedImport] - endpoint, # pyright: ignore[reportUnusedImport] - load_balancing_config, # pyright: ignore[reportUnusedImport] - members, # pyright: ignore[reportUnusedImport] - model_providers, # pyright: ignore[reportUnusedImport] - models, # pyright: ignore[reportUnusedImport] - plugin, # pyright: ignore[reportUnusedImport] - tool_providers, # pyright: ignore[reportUnusedImport] - workspace, # pyright: ignore[reportUnusedImport] -) +api.add_namespace(console_ns) diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 1306efacf4..93f242ad28 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -3,7 +3,7 @@ from functools import wraps from typing import ParamSpec, TypeVar from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized @@ -12,7 +12,7 @@ P = ParamSpec("P") R = TypeVar("R") from configs import dify_config from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import only_edition_cloud from extensions.ext_database import db from models.model import App, InstalledApp, RecommendedApp @@ -45,7 +45,28 @@ def admin_required(view: Callable[P, R]): return decorated +@console_ns.route("/admin/insert-explore-apps") class InsertExploreAppListApi(Resource): + @api.doc("insert_explore_app") + @api.doc(description="Insert or update an app in the explore list") + @api.expect( + api.model( + "InsertExploreAppRequest", + { + "app_id": fields.String(required=True, description="Application ID"), + "desc": fields.String(description="App description"), + "copyright": fields.String(description="Copyright information"), + "privacy_policy": fields.String(description="Privacy policy"), + "custom_disclaimer": fields.String(description="Custom disclaimer"), + "language": fields.String(required=True, description="Language code"), + "category": fields.String(required=True, description="App category"), + "position": fields.Integer(required=True, description="Display position"), + }, + ) + ) + @api.response(200, "App updated successfully") + @api.response(201, "App inserted successfully") + @api.response(404, "App not found") @only_edition_cloud @admin_required def post(self): @@ -115,7 +136,12 @@ class InsertExploreAppListApi(Resource): return {"result": "success"}, 200 +@console_ns.route("/admin/insert-explore-apps/") class InsertExploreAppApi(Resource): + @api.doc("delete_explore_app") + @api.doc(description="Remove an app from the explore list") + @api.doc(params={"app_id": "Application ID to remove"}) + @api.response(204, "App removed successfully") @only_edition_cloud @admin_required def delete(self, app_id): @@ -152,7 +178,3 @@ class InsertExploreAppApi(Resource): db.session.commit() return {"result": "success"}, 204 - - -api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") -api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/") diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 58a1d437d1..06de2fa6b6 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -14,7 +14,7 @@ from libs.login import login_required from models.dataset import Dataset from models.model import ApiToken, App -from . import api +from . import api, console_ns from .wraps import account_initialization_required, setup_required api_key_fields = { @@ -135,7 +135,25 @@ class BaseApiKeyResource(Resource): return {"result": "success"}, 204 +@console_ns.route("/apps//api-keys") class AppApiKeyListResource(BaseApiKeyListResource): + @api.doc("get_app_api_keys") + @api.doc(description="Get all API keys for an app") + @api.doc(params={"resource_id": "App ID"}) + @api.response(200, "Success", api_key_list) + def get(self, resource_id): + """Get all API keys for an app""" + return super().get(resource_id) + + @api.doc("create_app_api_key") + @api.doc(description="Create a new API key for an app") + @api.doc(params={"resource_id": "App ID"}) + @api.response(201, "API key created successfully", api_key_fields) + @api.response(400, "Maximum keys exceeded") + def post(self, resource_id): + """Create a new API key for an app""" + return super().post(resource_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -147,7 +165,16 @@ class AppApiKeyListResource(BaseApiKeyListResource): token_prefix = "app-" +@console_ns.route("/apps//api-keys/") class AppApiKeyResource(BaseApiKeyResource): + @api.doc("delete_app_api_key") + @api.doc(description="Delete an API key for an app") + @api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"}) + @api.response(204, "API key deleted successfully") + def delete(self, resource_id, api_key_id): + """Delete an API key for an app""" + return super().delete(resource_id, api_key_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -158,7 +185,25 @@ class AppApiKeyResource(BaseApiKeyResource): resource_id_field = "app_id" +@console_ns.route("/datasets//api-keys") class DatasetApiKeyListResource(BaseApiKeyListResource): + @api.doc("get_dataset_api_keys") + @api.doc(description="Get all API keys for a dataset") + @api.doc(params={"resource_id": "Dataset ID"}) + @api.response(200, "Success", api_key_list) + def get(self, resource_id): + """Get all API keys for a dataset""" + return super().get(resource_id) + + @api.doc("create_dataset_api_key") + @api.doc(description="Create a new API key for a dataset") + @api.doc(params={"resource_id": "Dataset ID"}) + @api.response(201, "API key created successfully", api_key_fields) + @api.response(400, "Maximum keys exceeded") + def post(self, resource_id): + """Create a new API key for a dataset""" + return super().post(resource_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -170,7 +215,16 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): token_prefix = "ds-" +@console_ns.route("/datasets//api-keys/") class DatasetApiKeyResource(BaseApiKeyResource): + @api.doc("delete_dataset_api_key") + @api.doc(description="Delete an API key for a dataset") + @api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"}) + @api.response(204, "API key deleted successfully") + def delete(self, resource_id, api_key_id): + """Delete an API key for a dataset""" + return super().delete(resource_id, api_key_id) + def after_request(self, resp): resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers["Access-Control-Allow-Credentials"] = "true" @@ -179,9 +233,3 @@ class DatasetApiKeyResource(BaseApiKeyResource): resource_type = "dataset" resource_model = Dataset resource_id_field = "dataset_id" - - -api.add_resource(AppApiKeyListResource, "/apps//api-keys") -api.add_resource(AppApiKeyResource, "/apps//api-keys/") -api.add_resource(DatasetApiKeyListResource, "/datasets//api-keys") -api.add_resource(DatasetApiKeyResource, "/datasets//api-keys/") diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index e82e403ec2..8cdadfb03c 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,8 +1,8 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from constants.languages import supported_language -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.error import AlreadyActivateError from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -10,14 +10,36 @@ from libs.helper import StrLen, email, extract_remote_ip, timezone from models.account import AccountStatus from services.account_service import AccountService, RegisterService +active_check_parser = reqparse.RequestParser() +active_check_parser.add_argument( + "workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID" +) +active_check_parser.add_argument( + "email", type=email, required=False, nullable=True, location="args", help="Email address" +) +active_check_parser.add_argument( + "token", type=str, required=True, nullable=False, location="args", help="Activation token" +) + +@console_ns.route("/activate/check") class ActivateCheckApi(Resource): + @api.doc("check_activation_token") + @api.doc(description="Check if activation token is valid") + @api.expect(active_check_parser) + @api.response( + 200, + "Success", + api.model( + "ActivationCheckResponse", + { + "is_valid": fields.Boolean(description="Whether token is valid"), + "data": fields.Raw(description="Activation data if valid"), + }, + ), + ) def get(self): - parser = reqparse.RequestParser() - parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") - parser.add_argument("email", type=email, required=False, nullable=True, location="args") - parser.add_argument("token", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() + args = active_check_parser.parse_args() workspaceId = args["workspace_id"] reg_email = args["email"] @@ -38,18 +60,36 @@ class ActivateCheckApi(Resource): return {"is_valid": False} +active_parser = reqparse.RequestParser() +active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") +active_parser.add_argument("email", type=email, required=False, nullable=True, location="json") +active_parser.add_argument("token", type=str, required=True, nullable=False, location="json") +active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") +active_parser.add_argument( + "interface_language", type=supported_language, required=True, nullable=False, location="json" +) +active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") + + +@console_ns.route("/activate") class ActivateApi(Resource): + @api.doc("activate_account") + @api.doc(description="Activate account with invitation token") + @api.expect(active_parser) + @api.response( + 200, + "Account activated successfully", + api.model( + "ActivationResponse", + { + "result": fields.String(description="Operation result"), + "data": fields.Raw(description="Login token data"), + }, + ), + ) + @api.response(400, "Already activated or invalid token") def post(self): - parser = reqparse.RequestParser() - parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") - parser.add_argument("email", type=email, required=False, nullable=True, location="json") - parser.add_argument("token", type=str, required=True, nullable=False, location="json") - parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") - parser.add_argument( - "interface_language", type=supported_language, required=True, nullable=False, location="json" - ) - parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") - args = parser.parse_args() + args = active_parser.parse_args() invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) if invitation is None: @@ -70,7 +110,3 @@ class ActivateApi(Resource): token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) return {"result": "success", "data": token_pair.model_dump()} - - -api.add_resource(ActivateCheckApi, "/activate/check") -api.add_resource(ActivateApi, "/activate") diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 8f57b3d03e..fc4ba3a2c7 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -3,11 +3,11 @@ import logging import requests from flask import current_app, redirect, request from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.console import api +from controllers.console import api, console_ns from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -28,7 +28,21 @@ def get_oauth_providers(): return OAUTH_PROVIDERS +@console_ns.route("/oauth/data-source/") class OAuthDataSource(Resource): + @api.doc("oauth_data_source") + @api.doc(description="Get OAuth authorization URL for data source provider") + @api.doc(params={"provider": "Data source provider name (notion)"}) + @api.response( + 200, + "Authorization URL or internal setup success", + api.model( + "OAuthDataSourceResponse", + {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, + ), + ) + @api.response(400, "Invalid provider") + @api.response(403, "Admin privileges required") def get(self, provider: str): # The role of the current user in the table must be admin or owner if not current_user.is_admin_or_owner: @@ -49,7 +63,19 @@ class OAuthDataSource(Resource): return {"data": auth_url}, 200 +@console_ns.route("/oauth/data-source/callback/") class OAuthDataSourceCallback(Resource): + @api.doc("oauth_data_source_callback") + @api.doc(description="Handle OAuth callback from data source provider") + @api.doc( + params={ + "provider": "Data source provider name (notion)", + "code": "Authorization code from OAuth provider", + "error": "Error message from OAuth provider", + } + ) + @api.response(302, "Redirect to console with result") + @api.response(400, "Invalid provider") def get(self, provider: str): OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -68,7 +94,19 @@ class OAuthDataSourceCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") +@console_ns.route("/oauth/data-source/binding/") class OAuthDataSourceBinding(Resource): + @api.doc("oauth_data_source_binding") + @api.doc(description="Bind OAuth data source with authorization code") + @api.doc( + params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"} + ) + @api.response( + 200, + "Data source binding success", + api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid provider or code") def get(self, provider: str): OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -90,7 +128,17 @@ class OAuthDataSourceBinding(Resource): return {"result": "success"}, 200 +@console_ns.route("/oauth/data-source///sync") class OAuthDataSourceSync(Resource): + @api.doc("oauth_data_source_sync") + @api.doc(description="Sync data from OAuth data source") + @api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"}) + @api.response( + 200, + "Data source sync success", + api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid provider or sync failed") @setup_required @login_required @account_initialization_required @@ -111,9 +159,3 @@ class OAuthDataSourceSync(Resource): return {"error": "OAuth data source process failed"}, 400 return {"result": "success"}, 200 - - -api.add_resource(OAuthDataSource, "/oauth/data-source/") -api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") -api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") -api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index ede0696854..7f34adc0f3 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,12 +2,12 @@ import base64 import secrets from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session from constants.languages import languages -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.auth.error import ( EmailCodeError, EmailPasswordResetLimitError, @@ -28,7 +28,32 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces from services.feature_service import FeatureService +@console_ns.route("/forgot-password") class ForgotPasswordSendEmailApi(Resource): + @api.doc("send_forgot_password_email") + @api.doc(description="Send password reset email") + @api.expect( + api.model( + "ForgotPasswordEmailRequest", + { + "email": fields.String(required=True, description="Email address"), + "language": fields.String(description="Language for email (zh-Hans/en-US)"), + }, + ) + ) + @api.response( + 200, + "Email sent successfully", + api.model( + "ForgotPasswordEmailResponse", + { + "result": fields.String(description="Operation result"), + "data": fields.String(description="Reset token"), + "code": fields.String(description="Error code if account not found"), + }, + ), + ) + @api.response(400, "Invalid email or rate limit exceeded") @setup_required @email_password_login_enabled def post(self): @@ -61,7 +86,33 @@ class ForgotPasswordSendEmailApi(Resource): return {"result": "success", "data": token} +@console_ns.route("/forgot-password/validity") class ForgotPasswordCheckApi(Resource): + @api.doc("check_forgot_password_code") + @api.doc(description="Verify password reset code") + @api.expect( + api.model( + "ForgotPasswordCheckRequest", + { + "email": fields.String(required=True, description="Email address"), + "code": fields.String(required=True, description="Verification code"), + "token": fields.String(required=True, description="Reset token"), + }, + ) + ) + @api.response( + 200, + "Code verified successfully", + api.model( + "ForgotPasswordCheckResponse", + { + "is_valid": fields.Boolean(description="Whether code is valid"), + "email": fields.String(description="Email address"), + "token": fields.String(description="New reset token"), + }, + ), + ) + @api.response(400, "Invalid code or token") @setup_required @email_password_login_enabled def post(self): @@ -100,7 +151,26 @@ class ForgotPasswordCheckApi(Resource): return {"is_valid": True, "email": token_data.get("email"), "token": new_token} +@console_ns.route("/forgot-password/resets") class ForgotPasswordResetApi(Resource): + @api.doc("reset_password") + @api.doc(description="Reset password with verification token") + @api.expect( + api.model( + "ForgotPasswordResetRequest", + { + "token": fields.String(required=True, description="Verification token"), + "new_password": fields.String(required=True, description="New password"), + "password_confirm": fields.String(required=True, description="Password confirmation"), + }, + ) + ) + @api.response( + 200, + "Password reset successfully", + api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Invalid token or password mismatch") @setup_required @email_password_login_enabled def post(self): @@ -172,8 +242,3 @@ class ForgotPasswordResetApi(Resource): pass except AccountRegisterError: raise AccountInFreezeError() - - -api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") -api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") -api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 06151ee39b..c3c9de1589 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -22,7 +22,7 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.feature_service import FeatureService -from .. import api +from .. import api, console_ns logger = logging.getLogger(__name__) @@ -50,7 +50,13 @@ def get_oauth_providers(): return OAUTH_PROVIDERS +@console_ns.route("/oauth/login/") class OAuthLogin(Resource): + @api.doc("oauth_login") + @api.doc(description="Initiate OAuth login process") + @api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}) + @api.response(302, "Redirect to OAuth authorization URL") + @api.response(400, "Invalid provider") def get(self, provider: str): invite_token = request.args.get("invite_token") or None OAUTH_PROVIDERS = get_oauth_providers() @@ -63,7 +69,19 @@ class OAuthLogin(Resource): return redirect(auth_url) +@console_ns.route("/oauth/authorize/") class OAuthCallback(Resource): + @api.doc("oauth_callback") + @api.doc(description="Handle OAuth callback and complete login process") + @api.doc( + params={ + "provider": "OAuth provider name (github/google)", + "code": "Authorization code from OAuth provider", + "state": "Optional state parameter (used for invite token)", + } + ) + @api.response(302, "Redirect to console with access token") + @api.response(400, "OAuth process failed") def get(self, provider: str): OAUTH_PROVIDERS = get_oauth_providers() with current_app.app_context(): @@ -184,7 +202,3 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): AccountService.link_account_integrate(provider, user_info.id, account) return account - - -api.add_resource(OAuthLogin, "/oauth/login/") -api.add_resource(OAuthCallback, "/oauth/authorize/") diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index e157041c35..57f5ab191e 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,8 +1,8 @@ from flask_login import current_user -from flask_restx import Resource, marshal_with, reqparse +from flask_restx import Resource, fields, marshal_with, reqparse from constants import HIDDEN_VALUE -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import login_required @@ -11,7 +11,21 @@ from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService +@console_ns.route("/code-based-extension") class CodeBasedExtensionAPI(Resource): + @api.doc("get_code_based_extension") + @api.doc(description="Get code-based extension data by module name") + @api.expect( + api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name") + ) + @api.response( + 200, + "Success", + api.model( + "CodeBasedExtensionResponse", + {"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")}, + ), + ) @setup_required @login_required @account_initialization_required @@ -23,7 +37,11 @@ class CodeBasedExtensionAPI(Resource): return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} +@console_ns.route("/api-based-extension") class APIBasedExtensionAPI(Resource): + @api.doc("get_api_based_extensions") + @api.doc(description="Get all API-based extensions for current tenant") + @api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields))) @setup_required @login_required @account_initialization_required @@ -32,6 +50,19 @@ class APIBasedExtensionAPI(Resource): tenant_id = current_user.current_tenant_id return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + @api.doc("create_api_based_extension") + @api.doc(description="Create a new API-based extension") + @api.expect( + api.model( + "CreateAPIBasedExtensionRequest", + { + "name": fields.String(required=True, description="Extension name"), + "api_endpoint": fields.String(required=True, description="API endpoint URL"), + "api_key": fields.String(required=True, description="API key for authentication"), + }, + ) + ) + @api.response(201, "Extension created successfully", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -53,7 +84,12 @@ class APIBasedExtensionAPI(Resource): return APIBasedExtensionService.save(extension_data) +@console_ns.route("/api-based-extension/") class APIBasedExtensionDetailAPI(Resource): + @api.doc("get_api_based_extension") + @api.doc(description="Get API-based extension by ID") + @api.doc(params={"id": "Extension ID"}) + @api.response(200, "Success", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -64,6 +100,20 @@ class APIBasedExtensionDetailAPI(Resource): return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + @api.doc("update_api_based_extension") + @api.doc(description="Update API-based extension") + @api.doc(params={"id": "Extension ID"}) + @api.expect( + api.model( + "UpdateAPIBasedExtensionRequest", + { + "name": fields.String(required=True, description="Extension name"), + "api_endpoint": fields.String(required=True, description="API endpoint URL"), + "api_key": fields.String(required=True, description="API key for authentication"), + }, + ) + ) + @api.response(200, "Extension updated successfully", api_based_extension_fields) @setup_required @login_required @account_initialization_required @@ -88,6 +138,10 @@ class APIBasedExtensionDetailAPI(Resource): return APIBasedExtensionService.save(extension_data_from_db) + @api.doc("delete_api_based_extension") + @api.doc(description="Delete API-based extension") + @api.doc(params={"id": "Extension ID"}) + @api.response(204, "Extension deleted successfully") @setup_required @login_required @account_initialization_required @@ -100,9 +154,3 @@ class APIBasedExtensionDetailAPI(Resource): APIBasedExtensionService.delete(extension_data_from_db) return {"result": "success"}, 204 - - -api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") - -api.add_resource(APIBasedExtensionAPI, "/api-based-extension") -api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/") diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 6236832d39..d43b839291 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,26 +1,40 @@ from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields from libs.login import login_required from services.feature_service import FeatureService -from . import api +from . import api, console_ns from .wraps import account_initialization_required, cloud_utm_record, setup_required +@console_ns.route("/features") class FeatureApi(Resource): + @api.doc("get_tenant_features") + @api.doc(description="Get feature configuration for current tenant") + @api.response( + 200, + "Success", + api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}), + ) @setup_required @login_required @account_initialization_required @cloud_utm_record def get(self): + """Get feature configuration for current tenant""" return FeatureService.get_features(current_user.current_tenant_id).model_dump() +@console_ns.route("/system-features") class SystemFeatureApi(Resource): + @api.doc("get_system_features") + @api.doc(description="Get system-wide feature configuration") + @api.response( + 200, + "Success", + api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}), + ) def get(self): + """Get system-wide feature configuration""" return FeatureService.get_system_features().model_dump() - - -api.add_resource(FeatureApi, "/features") -api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 2a37b1708a..30b53458b2 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,7 +1,7 @@ import os from flask import session -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from sqlalchemy import select from sqlalchemy.orm import Session @@ -11,20 +11,47 @@ from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService -from . import api +from . import api, console_ns from .error import AlreadySetupError, InitValidateFailedError from .wraps import only_edition_self_hosted +@console_ns.route("/init") class InitValidateAPI(Resource): + @api.doc("get_init_status") + @api.doc(description="Get initialization validation status") + @api.response( + 200, + "Success", + model=api.model( + "InitStatusResponse", + {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])}, + ), + ) def get(self): + """Get initialization validation status""" init_status = get_init_validate_status() if init_status: return {"status": "finished"} return {"status": "not_started"} + @api.doc("validate_init_password") + @api.doc(description="Validate initialization password for self-hosted edition") + @api.expect( + api.model( + "InitValidateRequest", + {"password": fields.String(required=True, description="Initialization password", max_length=30)}, + ) + ) + @api.response( + 201, + "Success", + model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}), + ) + @api.response(400, "Already setup or validation failed") @only_edition_self_hosted def post(self): + """Validate initialization password""" # is tenant created tenant_count = TenantService.get_tenant_count() if tenant_count > 0: @@ -52,6 +79,3 @@ def get_init_validate_status(): return db_session.execute(select(DifySetup)).scalar_one_or_none() return True - - -api.add_resource(InitValidateAPI, "/init") diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index 1a53a2347e..29f49b99de 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,14 +1,17 @@ -from flask_restx import Resource +from flask_restx import Resource, fields -from controllers.console import api +from . import api, console_ns +@console_ns.route("/ping") class PingApi(Resource): + @api.doc("health_check") + @api.doc(description="Health check endpoint for connection testing") + @api.response( + 200, + "Success", + api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}), + ) def get(self): - """ - For connection health check - """ + """Health check endpoint for connection testing""" return {"result": "pong"} - - -api.add_resource(PingApi, "/ping") diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 8e230496f0..bff5fc1651 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip @@ -7,23 +7,56 @@ from libs.password import valid_password from models.model import DifySetup, db from services.account_service import RegisterService, TenantService -from . import api +from . import api, console_ns from .error import AlreadySetupError, NotInitValidateError from .init_validate import get_init_validate_status from .wraps import only_edition_self_hosted +@console_ns.route("/setup") class SetupApi(Resource): + @api.doc("get_setup_status") + @api.doc(description="Get system setup status") + @api.response( + 200, + "Success", + api.model( + "SetupStatusResponse", + { + "step": fields.String(description="Setup step status", enum=["not_started", "finished"]), + "setup_at": fields.String(description="Setup completion time (ISO format)", required=False), + }, + ), + ) def get(self): + """Get system setup status""" if dify_config.EDITION == "SELF_HOSTED": setup_status = get_setup_status() - if setup_status: + # Check if setup_status is a DifySetup object rather than a bool + if setup_status and not isinstance(setup_status, bool): return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} + elif setup_status: + return {"step": "finished"} return {"step": "not_started"} return {"step": "finished"} + @api.doc("setup_system") + @api.doc(description="Initialize system setup with admin account") + @api.expect( + api.model( + "SetupRequest", + { + "email": fields.String(required=True, description="Admin email address"), + "name": fields.String(required=True, description="Admin name (max 30 characters)"), + "password": fields.String(required=True, description="Admin password"), + }, + ) + ) + @api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")})) + @api.response(400, "Already setup or validation failed") @only_edition_self_hosted def post(self): + """Initialize system setup with admin account""" # is set up if get_setup_status(): raise AlreadySetupError() @@ -55,6 +88,3 @@ def get_setup_status(): return db.session.query(DifySetup).first() else: return True - - -api.add_resource(SetupApi, "/setup") diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 8409e7d1ab..8d081ad995 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,18 +2,41 @@ import json import logging import requests -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from packaging import version from configs import dify_config -from . import api +from . import api, console_ns logger = logging.getLogger(__name__) +@console_ns.route("/version") class VersionApi(Resource): + @api.doc("check_version_update") + @api.doc(description="Check for application version updates") + @api.expect( + api.parser().add_argument( + "current_version", type=str, required=True, location="args", help="Current application version" + ) + ) + @api.response( + 200, + "Success", + api.model( + "VersionResponse", + { + "version": fields.String(description="Latest version number"), + "release_date": fields.String(description="Release date of latest version"), + "release_notes": fields.String(description="Release notes for latest version"), + "can_auto_update": fields.Boolean(description="Whether auto-update is supported"), + "features": fields.Raw(description="Feature flags and capabilities"), + }, + ), + ) def get(self): + """Check for application version updates""" parser = reqparse.RequestParser() parser.add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() @@ -59,6 +82,3 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool: except version.InvalidVersion: logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version) return False - - -api.add_resource(VersionApi, "/version") diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 08bab6fcb5..0a2c8fcfb4 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,14 +1,22 @@ from flask_login import current_user -from flask_restx import Resource +from flask_restx import Resource, fields -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.login import login_required from services.agent_service import AgentService +@console_ns.route("/workspaces/current/agent-providers") class AgentProviderListApi(Resource): + @api.doc("list_agent_providers") + @api.doc(description="Get list of available agent providers") + @api.response( + 200, + "Success", + fields.List(fields.Raw(description="Agent provider information")), + ) @setup_required @login_required @account_initialization_required @@ -21,7 +29,16 @@ class AgentProviderListApi(Resource): return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) +@console_ns.route("/workspaces/current/agent-provider/") class AgentProviderApi(Resource): + @api.doc("get_agent_provider") + @api.doc(description="Get specific agent provider details") + @api.doc(params={"provider_name": "Agent provider name"}) + @api.response( + 200, + "Success", + fields.Raw(description="Agent provider details"), + ) @setup_required @login_required @account_initialization_required @@ -30,7 +47,3 @@ class AgentProviderApi(Resource): user_id = user.id tenant_id = user.current_tenant_id return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name)) - - -api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers") -api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/") diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 96e873d42b..0657b764cc 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,8 +1,8 @@ from flask_login import current_user -from flask_restx import Resource, reqparse +from flask_restx import Resource, fields, reqparse from werkzeug.exceptions import Forbidden -from controllers.console import api +from controllers.console import api, console_ns from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginPermissionDeniedError @@ -10,7 +10,26 @@ from libs.login import login_required from services.plugin.endpoint_service import EndpointService +@console_ns.route("/workspaces/current/endpoints/create") class EndpointCreateApi(Resource): + @api.doc("create_endpoint") + @api.doc(description="Create a new plugin endpoint") + @api.expect( + api.model( + "EndpointCreateRequest", + { + "plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"), + "settings": fields.Raw(required=True, description="Endpoint settings"), + "name": fields.String(required=True, description="Endpoint name"), + }, + ) + ) + @api.response( + 200, + "Endpoint created successfully", + api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -43,7 +62,20 @@ class EndpointCreateApi(Resource): raise ValueError(e.description) from e +@console_ns.route("/workspaces/current/endpoints/list") class EndpointListApi(Resource): + @api.doc("list_endpoints") + @api.doc(description="List plugin endpoints with pagination") + @api.expect( + api.parser() + .add_argument("page", type=int, required=True, location="args", help="Page number") + .add_argument("page_size", type=int, required=True, location="args", help="Page size") + ) + @api.response( + 200, + "Success", + api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}), + ) @setup_required @login_required @account_initialization_required @@ -70,7 +102,23 @@ class EndpointListApi(Resource): ) +@console_ns.route("/workspaces/current/endpoints/list/plugin") class EndpointListForSinglePluginApi(Resource): + @api.doc("list_plugin_endpoints") + @api.doc(description="List endpoints for a specific plugin") + @api.expect( + api.parser() + .add_argument("page", type=int, required=True, location="args", help="Page number") + .add_argument("page_size", type=int, required=True, location="args", help="Page size") + .add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID") + ) + @api.response( + 200, + "Success", + api.model( + "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} + ), + ) @setup_required @login_required @account_initialization_required @@ -100,7 +148,19 @@ class EndpointListForSinglePluginApi(Resource): ) +@console_ns.route("/workspaces/current/endpoints/delete") class EndpointDeleteApi(Resource): + @api.doc("delete_endpoint") + @api.doc(description="Delete a plugin endpoint") + @api.expect( + api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint deleted successfully", + api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -123,7 +183,26 @@ class EndpointDeleteApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/update") class EndpointUpdateApi(Resource): + @api.doc("update_endpoint") + @api.doc(description="Update a plugin endpoint") + @api.expect( + api.model( + "EndpointUpdateRequest", + { + "endpoint_id": fields.String(required=True, description="Endpoint ID"), + "settings": fields.Raw(required=True, description="Updated settings"), + "name": fields.String(required=True, description="Updated name"), + }, + ) + ) + @api.response( + 200, + "Endpoint updated successfully", + api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -154,7 +233,19 @@ class EndpointUpdateApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/enable") class EndpointEnableApi(Resource): + @api.doc("enable_endpoint") + @api.doc(description="Enable a plugin endpoint") + @api.expect( + api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint enabled successfully", + api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -177,7 +268,19 @@ class EndpointEnableApi(Resource): } +@console_ns.route("/workspaces/current/endpoints/disable") class EndpointDisableApi(Resource): + @api.doc("disable_endpoint") + @api.doc(description="Disable a plugin endpoint") + @api.expect( + api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")}) + ) + @api.response( + 200, + "Endpoint disabled successfully", + api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), + ) + @api.response(403, "Admin privileges required") @setup_required @login_required @account_initialization_required @@ -198,12 +301,3 @@ class EndpointDisableApi(Resource): tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id ) } - - -api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create") -api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list") -api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin") -api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete") -api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update") -api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable") -api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable") diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index a1b8bb7cfe..26fbf7097e 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Files API", description="API for file operations including upload and preview", - doc="/docs", # Enable Swagger UI at /files/docs ) files_ns = Namespace("files", description="File operations", path="/") diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index b09c39309f..f29f624ba5 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Inner API", description="Internal APIs for enterprise features, billing, and plugin communication", - doc="/docs", # Enable Swagger UI at /inner/api/docs ) # Create namespace diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py index 43b36a70b4..336a7801bb 100644 --- a/api/controllers/mcp/__init__.py +++ b/api/controllers/mcp/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="MCP API", description="API for Model Context Protocol operations", - doc="/docs", # Enable Swagger UI at /mcp/docs ) mcp_ns = Namespace("mcp", description="MCP operations", path="/") diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index d69f49d957..a6008fdb99 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Service API", description="API for application services", - doc="/docs", # Enable Swagger UI at /v1/docs ) service_api_ns = Namespace("service_api", description="Service operations", path="/") diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index a825a2a0d8..97bcd3d53c 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -10,7 +10,6 @@ api = ExternalApi( version="1.0", title="Web API", description="Public APIs for web applications including file uploads, chat interactions, and app management", - doc="/docs", # Enable Swagger UI at /api/docs ) # Create namespace diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2c0f6c9759..c1c46891b6 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -5,7 +5,7 @@ from flask_restx import fields, marshal_with, reqparse from werkzeug.exceptions import InternalServerError import services -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, AudioTooLargeError, @@ -32,15 +32,16 @@ from services.errors.audio import ( logger = logging.getLogger(__name__) +@web_ns.route("/audio-to-text") class AudioApi(WebApiResource): audio_to_text_response_fields = { "text": fields.String, } @marshal_with(audio_to_text_response_fields) - @api.doc("Audio to Text") - @api.doc(description="Convert audio file to text using speech-to-text service.") - @api.doc( + @web_ns.doc("Audio to Text") + @web_ns.doc(description="Convert audio file to text using speech-to-text service.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -85,6 +86,7 @@ class AudioApi(WebApiResource): raise InternalServerError() +@web_ns.route("/text-to-audio") class TextApi(WebApiResource): text_to_audio_response_fields = { "audio_url": fields.String, @@ -92,9 +94,9 @@ class TextApi(WebApiResource): } @marshal_with(text_to_audio_response_fields) - @api.doc("Text to Audio") - @api.doc(description="Convert text to audio using text-to-speech service.") - @api.doc( + @web_ns.doc("Text to Audio") + @web_ns.doc(description="Convert text to audio using text-to-speech service.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -145,7 +147,3 @@ class TextApi(WebApiResource): except Exception as e: logger.exception("Failed to handle post request to TextApi") raise InternalServerError() - - -api.add_resource(AudioApi, "/audio-to-text") -api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index a42bf5fc6e..67ae970388 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -4,7 +4,7 @@ from flask_restx import reqparse from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppUnavailableError, CompletionRequestError, @@ -35,10 +35,11 @@ logger = logging.getLogger(__name__) # define completion api for user +@web_ns.route("/completion-messages") class CompletionApi(WebApiResource): - @api.doc("Create Completion Message") - @api.doc(description="Create a completion message for text generation applications.") - @api.doc( + @web_ns.doc("Create Completion Message") + @web_ns.doc(description="Create a completion message for text generation applications.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the completion", "type": "object", "required": True}, "query": {"description": "Query text for completion", "type": "string", "required": False}, @@ -52,7 +53,7 @@ class CompletionApi(WebApiResource): "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -106,11 +107,12 @@ class CompletionApi(WebApiResource): raise InternalServerError() +@web_ns.route("/completion-messages//stop") class CompletionStopApi(WebApiResource): - @api.doc("Stop Completion Message") - @api.doc(description="Stop a running completion message task.") - @api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) - @api.doc( + @web_ns.doc("Stop Completion Message") + @web_ns.doc(description="Stop a running completion message task.") + @web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -129,10 +131,11 @@ class CompletionStopApi(WebApiResource): return {"result": "success"}, 200 +@web_ns.route("/chat-messages") class ChatApi(WebApiResource): - @api.doc("Create Chat Message") - @api.doc(description="Create a chat message for conversational applications.") - @api.doc( + @web_ns.doc("Create Chat Message") + @web_ns.doc(description="Create a chat message for conversational applications.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the chat", "type": "object", "required": True}, "query": {"description": "User query/message", "type": "string", "required": True}, @@ -148,7 +151,7 @@ class ChatApi(WebApiResource): "retriever_from": {"description": "Source of retriever", "type": "string", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -207,11 +210,12 @@ class ChatApi(WebApiResource): raise InternalServerError() +@web_ns.route("/chat-messages//stop") class ChatStopApi(WebApiResource): - @api.doc("Stop Chat Message") - @api.doc(description="Stop a running chat message task.") - @api.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) - @api.doc( + @web_ns.doc("Stop Chat Message") + @web_ns.doc(description="Stop a running chat message task.") + @web_ns.doc(params={"task_id": {"description": "Task ID to stop", "type": "string", "required": True}}) + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -229,9 +233,3 @@ class ChatStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {"result": "success"}, 200 - - -api.add_resource(CompletionApi, "/completion-messages") -api.add_resource(CompletionStopApi, "/completion-messages//stop") -api.add_resource(ChatApi, "/chat-messages") -api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 24de4f3f2e..03dd986aed 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -3,7 +3,7 @@ from flask_restx.inputs import int_range from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,7 +16,44 @@ from services.errors.conversation import ConversationNotExistsError, LastConvers from services.web_conversation_service import WebConversationService +@web_ns.route("/conversations") class ConversationListApi(WebApiResource): + @web_ns.doc("Get Conversation List") + @web_ns.doc(description="Retrieve paginated list of conversations for a chat application.") + @web_ns.doc( + params={ + "last_id": {"description": "Last conversation ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of conversations to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + "pinned": { + "description": "Filter by pinned status", + "type": "string", + "enum": ["true", "false"], + "required": False, + }, + "sort_by": { + "description": "Sort order", + "type": "string", + "enum": ["created_at", "-created_at", "updated_at", "-updated_at"], + "required": False, + "default": "-updated_at", + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "App Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -57,11 +94,25 @@ class ConversationListApi(WebApiResource): raise NotFound("Last Conversation Not Exists.") +@web_ns.route("/conversations/") class ConversationApi(WebApiResource): delete_response_fields = { "result": fields.String, } + @web_ns.doc("Delete Conversation") + @web_ns.doc(description="Delete a specific conversation.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 204: "Conversation deleted successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(delete_response_fields) def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -76,7 +127,32 @@ class ConversationApi(WebApiResource): return {"result": "success"}, 204 +@web_ns.route("/conversations//name") class ConversationRenameApi(WebApiResource): + @web_ns.doc("Rename Conversation") + @web_ns.doc(description="Rename a specific conversation with a custom name or auto-generate one.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + params={ + "name": {"description": "New conversation name", "type": "string", "required": False}, + "auto_generate": { + "description": "Auto-generate conversation name", + "type": "boolean", + "required": False, + "default": False, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Conversation renamed successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -96,11 +172,25 @@ class ConversationRenameApi(WebApiResource): raise NotFound("Conversation Not Exists.") +@web_ns.route("/conversations//pin") class ConversationPinApi(WebApiResource): pin_response_fields = { "result": fields.String, } + @web_ns.doc("Pin Conversation") + @web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Conversation pinned successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(pin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -117,11 +207,25 @@ class ConversationPinApi(WebApiResource): return {"result": "success"} +@web_ns.route("/conversations//unpin") class ConversationUnPinApi(WebApiResource): unpin_response_fields = { "result": fields.String, } + @web_ns.doc("Unpin Conversation") + @web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.") + @web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Conversation unpinned successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(unpin_response_fields) def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -132,10 +236,3 @@ class ConversationUnPinApi(WebApiResource): WebConversationService.unpin(app_model, conversation_id, end_user) return {"result": "success"} - - -api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="web_conversation_name") -api.add_resource(ConversationListApi, "/conversations") -api.add_resource(ConversationApi, "/conversations/") -api.add_resource(ConversationPinApi, "/conversations//pin") -api.add_resource(ConversationUnPinApi, "/conversations//unpin") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 17e06e8856..26c0b133d9 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -4,7 +4,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import InternalServerError, NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, @@ -38,6 +38,7 @@ from services.message_service import MessageService logger = logging.getLogger(__name__) +@web_ns.route("/messages") class MessageListApi(WebApiResource): message_fields = { "id": fields.String, @@ -62,6 +63,30 @@ class MessageListApi(WebApiResource): "data": fields.List(fields.Nested(message_fields)), } + @web_ns.doc("Get Message List") + @web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.") + @web_ns.doc( + params={ + "conversation_id": {"description": "Conversation UUID", "type": "string", "required": True}, + "first_id": {"description": "First message ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of messages to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Conversation Not Found or Not a Chat App", + 500: "Internal Server Error", + } + ) @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -84,11 +109,36 @@ class MessageListApi(WebApiResource): raise NotFound("First Message Not Exists.") +@web_ns.route("/messages//feedbacks") class MessageFeedbackApi(WebApiResource): feedback_response_fields = { "result": fields.String, } + @web_ns.doc("Create Message Feedback") + @web_ns.doc(description="Submit feedback (like/dislike) for a specific message.") + @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) + @web_ns.doc( + params={ + "rating": { + "description": "Feedback rating", + "type": "string", + "enum": ["like", "dislike"], + "required": False, + }, + "content": {"description": "Feedback content/comment", "type": "string", "required": False}, + } + ) + @web_ns.doc( + responses={ + 200: "Feedback submitted successfully", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(feedback_response_fields) def post(self, app_model, end_user, message_id): message_id = str(message_id) @@ -112,7 +162,31 @@ class MessageFeedbackApi(WebApiResource): return {"result": "success"} +@web_ns.route("/messages//more-like-this") class MessageMoreLikeThisApi(WebApiResource): + @web_ns.doc("Generate More Like This") + @web_ns.doc(description="Generate a new completion similar to an existing message (completion apps only).") + @web_ns.doc( + params={ + "message_id": {"description": "Message UUID", "type": "string", "required": True}, + "response_mode": { + "description": "Response mode", + "type": "string", + "enum": ["blocking", "streaming"], + "required": True, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a completion app or feature disabled", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) def get(self, app_model, end_user, message_id): if app_model.mode != "completion": raise NotCompletionAppError() @@ -156,11 +230,25 @@ class MessageMoreLikeThisApi(WebApiResource): raise InternalServerError() +@web_ns.route("/messages//suggested-questions") class MessageSuggestedQuestionApi(WebApiResource): suggested_questions_response_fields = { "data": fields.List(fields.String), } + @web_ns.doc("Get Suggested Questions") + @web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).") + @web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a chat app or feature disabled", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found or Conversation Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(suggested_questions_response_fields) def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) @@ -192,9 +280,3 @@ class MessageSuggestedQuestionApi(WebApiResource): raise InternalServerError() return {"data": questions} - - -api.add_resource(MessageListApi, "/messages") -api.add_resource(MessageFeedbackApi, "/messages//feedbacks") -api.add_resource(MessageMoreLikeThisApi, "/messages//more-like-this") -api.add_resource(MessageSuggestedQuestionApi, "/messages//suggested-questions") diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 7a9d24114e..96f09c8d3c 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with, reqparse from flask_restx.inputs import int_range from werkzeug.exceptions import NotFound -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import message_file_fields @@ -23,6 +23,7 @@ message_fields = { } +@web_ns.route("/saved-messages") class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { "limit": fields.Integer, @@ -34,6 +35,29 @@ class SavedMessageListApi(WebApiResource): "result": fields.String, } + @web_ns.doc("Get Saved Messages") + @web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.") + @web_ns.doc( + params={ + "last_id": {"description": "Last message ID for pagination", "type": "string", "required": False}, + "limit": { + "description": "Number of messages to return (1-100)", + "type": "integer", + "required": False, + "default": 20, + }, + } + ) + @web_ns.doc( + responses={ + 200: "Success", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "App Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): if app_model.mode != "completion": @@ -46,6 +70,23 @@ class SavedMessageListApi(WebApiResource): return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) + @web_ns.doc("Save Message") + @web_ns.doc(description="Save a specific message for later reference.") + @web_ns.doc( + params={ + "message_id": {"description": "Message UUID to save", "type": "string", "required": True}, + } + ) + @web_ns.doc( + responses={ + 200: "Message saved successfully", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(post_response_fields) def post(self, app_model, end_user): if app_model.mode != "completion": @@ -63,11 +104,25 @@ class SavedMessageListApi(WebApiResource): return {"result": "success"} +@web_ns.route("/saved-messages/") class SavedMessageApi(WebApiResource): delete_response_fields = { "result": fields.String, } + @web_ns.doc("Delete Saved Message") + @web_ns.doc(description="Remove a message from saved messages.") + @web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}}) + @web_ns.doc( + responses={ + 204: "Message removed successfully", + 400: "Bad Request - Not a completion app", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Message Not Found", + 500: "Internal Server Error", + } + ) @marshal_with(delete_response_fields) def delete(self, app_model, end_user, message_id): message_id = str(message_id) @@ -78,7 +133,3 @@ class SavedMessageApi(WebApiResource): SavedMessageService.delete(app_model, end_user, message_id) return {"result": "success"}, 204 - - -api.add_resource(SavedMessageListApi, "/saved-messages") -api.add_resource(SavedMessageApi, "/saved-messages/") diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 91d67bf9d8..b01aaba357 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -2,7 +2,7 @@ from flask_restx import fields, marshal_with from werkzeug.exceptions import Forbidden from configs import dify_config -from controllers.web import api +from controllers.web import web_ns from controllers.web.wraps import WebApiResource from extensions.ext_database import db from libs.helper import AppIconUrlField @@ -11,6 +11,7 @@ from models.model import Site from services.feature_service import FeatureService +@web_ns.route("/site") class AppSiteApi(WebApiResource): """Resource for app sites.""" @@ -53,9 +54,9 @@ class AppSiteApi(WebApiResource): "custom_config": fields.Raw(attribute="custom_config"), } - @api.doc("Get App Site Info") - @api.doc(description="Retrieve app site information and configuration.") - @api.doc( + @web_ns.doc("Get App Site Info") + @web_ns.doc(description="Retrieve app site information and configuration.") + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -82,9 +83,6 @@ class AppSiteApi(WebApiResource): return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) -api.add_resource(AppSiteApi, "/site") - - class AppSiteInfo: """Class to store site information.""" diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 3566cfae38..490dce8f05 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -3,7 +3,7 @@ import logging from flask_restx import reqparse from werkzeug.exceptions import InternalServerError -from controllers.web import api +from controllers.web import web_ns from controllers.web.error import ( CompletionRequestError, NotWorkflowAppError, @@ -29,16 +29,17 @@ from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) +@web_ns.route("/workflows/run") class WorkflowRunApi(WebApiResource): - @api.doc("Run Workflow") - @api.doc(description="Execute a workflow with provided inputs and files.") - @api.doc( + @web_ns.doc("Run Workflow") + @web_ns.doc(description="Execute a workflow with provided inputs and files.") + @web_ns.doc( params={ "inputs": {"description": "Input variables for the workflow", "type": "object", "required": True}, "files": {"description": "Files to be processed by the workflow", "type": "array", "required": False}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -84,15 +85,16 @@ class WorkflowRunApi(WebApiResource): raise InternalServerError() +@web_ns.route("/workflows/tasks//stop") class WorkflowTaskStopApi(WebApiResource): - @api.doc("Stop Workflow Task") - @api.doc(description="Stop a running workflow task.") - @api.doc( + @web_ns.doc("Stop Workflow Task") + @web_ns.doc(description="Stop a running workflow task.") + @web_ns.doc( params={ "task_id": {"description": "Task ID to stop", "type": "string", "required": True}, } ) - @api.doc( + @web_ns.doc( responses={ 200: "Success", 400: "Bad Request", @@ -113,7 +115,3 @@ class WorkflowTaskStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) return {"result": "success"} - - -api.add_resource(WorkflowRunApi, "/workflows/run") -api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") From cbc0e639e4eb034461ffbd2111e8aae10f6ba239 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 10 Sep 2025 14:00:17 +0900 Subject: [PATCH 16/86] update sql in batch (#24801) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- --- api/commands.py | 20 ++-- api/controllers/console/apikey.py | 10 +- .../console/datasets/data_source.py | 8 +- api/controllers/console/datasets/datasets.py | 31 +++--- .../console/datasets/datasets_document.py | 3 +- .../console/explore/installed_app.py | 16 +-- api/controllers/console/workspace/account.py | 4 +- api/core/memory/token_buffer_memory.py | 25 ++--- .../arize_phoenix_trace.py | 11 +-- .../vdb/tidb_on_qdrant/tidb_service.py | 3 +- api/core/tools/custom_tool/provider.py | 11 ++- api/core/tools/tool_label_manager.py | 4 +- api/core/tools/tool_manager.py | 12 +-- ...aset_join_when_app_model_config_updated.py | 4 +- ...oin_when_app_published_workflow_updated.py | 4 +- api/models/account.py | 10 +- api/models/dataset.py | 14 +-- api/models/model.py | 6 +- api/schedule/clean_unused_datasets_task.py | 18 ++-- .../mail_clean_document_notify_task.py | 7 +- .../update_tidb_serverless_status_task.py | 12 +-- api/services/annotation_service.py | 8 +- api/services/auth/api_key_auth_service.py | 12 ++- .../clear_free_plan_tenant_expired_logs.py | 3 +- api/services/dataset_service.py | 97 ++++++++----------- api/services/model_load_balancing_service.py | 10 +- .../database/database_retrieval.py | 18 ++-- api/services/tag_service.py | 35 +++---- .../tools/api_tools_manage_service.py | 5 +- .../tools/workflow_tools_manage_service.py | 6 +- .../enable_annotation_reply_task.py | 3 +- api/tasks/batch_clean_document_task.py | 7 +- api/tasks/clean_dataset_task.py | 5 +- api/tasks/clean_document_task.py | 3 +- api/tasks/clean_notion_document_task.py | 5 +- api/tasks/deal_dataset_vector_index_task.py | 17 ++-- api/tasks/disable_segments_from_index_task.py | 9 +- api/tasks/document_indexing_sync_task.py | 5 +- api/tasks/document_indexing_update_task.py | 3 +- api/tasks/duplicate_document_indexing_task.py | 5 +- api/tasks/enable_segments_to_index_task.py | 9 +- api/tasks/remove_document_from_index_task.py | 3 +- api/tasks/retry_document_indexing_task.py | 5 +- .../sync_website_document_indexing_task.py | 3 +- .../test_model_load_balancing_service.py | 7 +- .../services/test_tag_service.py | 9 +- .../services/test_web_conversation_service.py | 9 +- .../auth/test_api_key_auth_service.py | 20 ++-- .../services/auth/test_auth_integration.py | 4 +- 49 files changed, 281 insertions(+), 277 deletions(-) diff --git a/api/commands.py b/api/commands.py index 2bef83b2a7..1858cb2734 100644 --- a/api/commands.py +++ b/api/commands.py @@ -212,7 +212,9 @@ def migrate_annotation_vector_database(): if not dataset_collection_binding: click.echo(f"App annotation collection binding not found: {app.id}") continue - annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all() + annotations = db.session.scalars( + select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) + ).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, @@ -367,29 +369,25 @@ def migrate_knowledge_vector_database(): ) raise e - dataset_documents = ( - db.session.query(DatasetDocument) - .where( + dataset_documents = db.session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset.id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() documents = [] segments_count = 0 for dataset_document in dataset_documents: - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.status == "completed", DocumentSegment.enabled == True, ) - .all() - ) + ).all() for segment in segments: document = Document( diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 06de2fa6b6..56c61c2886 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -60,11 +60,11 @@ class BaseApiKeyListResource(Resource): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) - keys = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) - .all() - ) + keys = db.session.scalars( + select(ApiToken).where( + ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id + ) + ).all() return {"items": keys} @marshal_with(api_key_fields) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 45c647659b..6e49bfa510 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -29,14 +29,12 @@ class DataSourceApi(Resource): @marshal_with(integrate_list_fields) def get(self): # get workspace data source integrates - data_source_integrates = ( - db.session.query(DataSourceOauthBinding) - .where( + data_source_integrates = db.session.scalars( + select(DataSourceOauthBinding).where( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.disabled == False, ) - .all() - ) + ).all() base_url = request.url_root.rstrip("/") data_source_oauth_base_path = "/console/api/oauth/data-source" diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 11b7b1fec0..9fb092607a 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -2,6 +2,7 @@ import flask_restx from flask import request from flask_login import current_user from flask_restx import Resource, marshal, marshal_with, reqparse +from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound import services @@ -411,11 +412,11 @@ class DatasetIndexingEstimateApi(Resource): extract_settings = [] if args["info_list"]["data_source_type"] == "upload_file": file_ids = args["info_list"]["file_info_list"]["file_ids"] - file_details = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) - .all() - ) + file_details = db.session.scalars( + select(UploadFile).where( + UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids) + ) + ).all() if file_details is None: raise NotFound("File not found.") @@ -518,11 +519,11 @@ class DatasetIndexingStatusApi(Resource): @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - documents = ( - db.session.query(Document) - .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) - .all() - ) + documents = db.session.scalars( + select(Document).where( + Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id + ) + ).all() documents_status = [] for document in documents: completed_segments = ( @@ -569,11 +570,11 @@ class DatasetApiKeyApi(Resource): @account_initialization_required @marshal_with(api_key_list) def get(self): - keys = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) - .all() - ) + keys = db.session.scalars( + select(ApiToken).where( + ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id + ) + ).all() return {"items": keys} @setup_required diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index c9c0b6a5ce..f943fb3ccb 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,5 +1,6 @@ import logging from argparse import ArgumentTypeError +from collections.abc import Sequence from typing import Literal, cast from flask import request @@ -79,7 +80,7 @@ class DocumentResource(Resource): return document - def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: + def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 22aa753d92..bdc3fb0dbd 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -3,7 +3,7 @@ from typing import Any from flask import request from flask_restx import Resource, inputs, marshal_with, reqparse -from sqlalchemy import and_ +from sqlalchemy import and_, select from werkzeug.exceptions import BadRequest, Forbidden, NotFound from controllers.console import api @@ -33,13 +33,15 @@ class InstalledAppsListApi(Resource): current_tenant_id = current_user.current_tenant_id if app_id: - installed_apps = ( - db.session.query(InstalledApp) - .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) - .all() - ) + installed_apps = db.session.scalars( + select(InstalledApp).where( + and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id) + ) + ).all() else: - installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() + installed_apps = db.session.scalars( + select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id) + ).all() if current_user.current_tenant is None: raise ValueError("current_user.current_tenant must not be None") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index bd078729c4..7a41a8a5cc 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -248,7 +248,9 @@ class AccountIntegrateApi(Resource): raise ValueError("Invalid user account") account = current_user - account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() + account_integrates = db.session.scalars( + select(AccountIntegrate).where(AccountIntegrate.account_id == account.id) + ).all() base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 7be695812a..1f2525cfed 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -32,11 +32,16 @@ class TokenBufferMemory: self.model_instance = model_instance def _build_prompt_message_with_files( - self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool + self, + message_files: Sequence[MessageFile], + text_content: str, + message: Message, + app_record, + is_user_message: bool, ) -> PromptMessage: """ Build prompt message with files. - :param message_files: list of MessageFile objects + :param message_files: Sequence of MessageFile objects :param text_content: text content of the message :param message: Message object :param app_record: app record @@ -128,14 +133,12 @@ class TokenBufferMemory: prompt_messages: list[PromptMessage] = [] for message in messages: # Process user message with files - user_files = ( - db.session.query(MessageFile) - .where( + user_files = db.session.scalars( + select(MessageFile).where( MessageFile.message_id == message.id, (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)), ) - .all() - ) + ).all() if user_files: user_prompt_message = self._build_prompt_message_with_files( @@ -150,11 +153,9 @@ class TokenBufferMemory: prompt_messages.append(UserPromptMessage(content=message.query)) # Process assistant message with files - assistant_files = ( - db.session.query(MessageFile) - .where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") - .all() - ) + assistant_files = db.session.scalars( + select(MessageFile).where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant") + ).all() if assistant_files: assistant_prompt_message = self._build_prompt_message_with_files( diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index e7c90c1229..c5fbe4d78b 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -15,6 +15,7 @@ from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.id_generator import RandomIdGenerator from opentelemetry.trace import SpanContext, TraceFlags, TraceState +from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig @@ -699,8 +700,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): def _get_workflow_nodes(self, workflow_run_id: str): """Helper method to get workflow nodes""" - workflow_nodes = ( - db.session.query( + workflow_nodes = db.session.scalars( + select( WorkflowNodeExecutionModel.id, WorkflowNodeExecutionModel.tenant_id, WorkflowNodeExecutionModel.app_id, @@ -713,10 +714,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): WorkflowNodeExecutionModel.elapsed_time, WorkflowNodeExecutionModel.process_data, WorkflowNodeExecutionModel.execution_metadata, - ) - .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) - .all() - ) + ).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) + ).all() return workflow_nodes def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 184b5f2142..e1d4422144 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -1,5 +1,6 @@ import time import uuid +from collections.abc import Sequence import requests from requests.auth import HTTPDigestAuth @@ -139,7 +140,7 @@ class TidbService: @staticmethod def batch_update_tidb_serverless_cluster_status( - tidb_serverless_list: list[TidbAuthBinding], + tidb_serverless_list: Sequence[TidbAuthBinding], project_id: str, api_url: str, iam_url: str, diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 5790aea2b0..0cc992155a 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -1,4 +1,5 @@ from pydantic import Field +from sqlalchemy import select from core.entities.provider_entities import ProviderConfig from core.tools.__base.tool_provider import ToolProviderController @@ -176,11 +177,11 @@ class ApiToolProviderController(ToolProviderController): tools: list[ApiTool] = [] # get tenant api providers - db_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider) - .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) - .all() - ) + db_providers = db.session.scalars( + select(ApiToolProvider).where( + ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name + ) + ).all() if db_providers and len(db_providers) != 0: for db_provider in db_providers: diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 84b874975a..39646b7fc8 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -87,9 +87,7 @@ class ToolLabelManager: assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute] - labels: list[ToolLabelBinding] = ( - db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all() - ) + labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all() tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index bc1f09a2fc..b29da3f0ba 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -667,9 +667,9 @@ class ToolManager: # get db api providers if "api" in filters: - db_api_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() - ) + db_api_providers = db.session.scalars( + select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id) + ).all() api_provider_controllers: list[dict[str, Any]] = [ {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} @@ -690,9 +690,9 @@ class ToolManager: if "workflow" in filters: # get workflow providers - workflow_providers: list[WorkflowToolProvider] = ( - db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() - ) + workflow_providers = db.session.scalars( + select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) + ).all() workflow_provider_controllers: list[WorkflowToolProviderController] = [] for workflow_provider in workflow_providers: diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index b8b5a89dc5..69959acd19 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from events.app_event import app_model_config_was_updated from extensions.ext_database import db from models.dataset import AppDatasetJoin @@ -13,7 +15,7 @@ def handle(sender, **kwargs): dataset_ids = get_dataset_ids_from_model_config(app_model_config) - app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() + app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all() removed_dataset_ids: set[str] = set() if not app_dataset_joins: diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index fcc3b63fa7..898ec1f153 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,5 +1,7 @@ from typing import cast +from sqlalchemy import select + from core.workflow.nodes import NodeType from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated @@ -15,7 +17,7 @@ def handle(sender, **kwargs): published_workflow = cast(Workflow, published_workflow) dataset_ids = get_dataset_ids_from_workflow(published_workflow) - app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all() + app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all() removed_dataset_ids: set[str] = set() if not app_dataset_joins: diff --git a/api/models/account.py b/api/models/account.py index 019159d2da..4656b47e7a 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -218,10 +218,12 @@ class Tenant(Base): updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: - return ( - db.session.query(Account) - .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) - .all() + return list( + db.session.scalars( + select(Account).where( + Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id + ) + ).all() ) @property diff --git a/api/models/dataset.py b/api/models/dataset.py index 07f3eb18db..13087bf995 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -208,7 +208,9 @@ class Dataset(Base): @property def doc_metadata(self): - dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all() + dataset_metadatas = db.session.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id) + ).all() doc_metadata = [ { @@ -1055,13 +1057,11 @@ class ExternalKnowledgeApis(Base): @property def dataset_bindings(self) -> list[dict[str, Any]]: - external_knowledge_bindings = ( - db.session.query(ExternalKnowledgeBindings) - .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) - .all() - ) + external_knowledge_bindings = db.session.scalars( + select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) + ).all() dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] - datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() + datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() dataset_bindings: list[dict[str, Any]] = [] for dataset in datasets: dataset_bindings.append({"id": dataset.id, "name": dataset.name}) diff --git a/api/models/model.py b/api/models/model.py index f8ead1f872..5a4c5de6e1 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -812,7 +812,7 @@ class Conversation(Base): @property def status_count(self): - messages = db.session.query(Message).where(Message.conversation_id == self.id).all() + messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all() status_counts = { WorkflowExecutionStatus.RUNNING: 0, WorkflowExecutionStatus.SUCCEEDED: 0, @@ -1090,7 +1090,7 @@ class Message(Base): @property def feedbacks(self): - feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all() + feedbacks = db.session.scalars(select(MessageFeedback).where(MessageFeedback.message_id == self.id)).all() return feedbacks @property @@ -1145,7 +1145,7 @@ class Message(Base): def message_files(self) -> list[dict[str, Any]]: from factories import file_factory - message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() + message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all() current_app = db.session.query(App).where(App.id == self.app_id).first() if not current_app: raise ValueError(f"App {self.app_id} not found") diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 63e6132b6a..2b1e6b47cc 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -96,11 +96,11 @@ def clean_unused_datasets_task(): break for dataset in datasets: - dataset_query = ( - db.session.query(DatasetQuery) - .where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id) - .all() - ) + dataset_query = db.session.scalars( + select(DatasetQuery).where( + DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id + ) + ).all() if not dataset_query or len(dataset_query) == 0: try: @@ -121,15 +121,13 @@ def clean_unused_datasets_task(): if should_clean: # Add auto disable log if required if add_logs: - documents = ( - db.session.query(Document) - .where( + documents = db.session.scalars( + select(Document).where( Document.dataset_id == dataset.id, Document.enabled == True, Document.archived == False, ) - .all() - ) + ).all() for document in documents: dataset_auto_disable_log = DatasetAutoDisableLog( tenant_id=dataset.tenant_id, diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 9e32ecc716..ef6edd6709 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -3,6 +3,7 @@ import time from collections import defaultdict import click +from sqlalchemy import select import app from configs import dify_config @@ -31,9 +32,9 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all() - ) + dataset_auto_disable_logs = db.session.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) + ).all() # group by tenant_id dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) for dataset_auto_disable_log in dataset_auto_disable_logs: diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 1bfeb869e2..1befa0e8b5 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -1,6 +1,8 @@ import time +from collections.abc import Sequence import click +from sqlalchemy import select import app from configs import dify_config @@ -15,11 +17,9 @@ def update_tidb_serverless_status_task(): start_at = time.perf_counter() try: # check the number of idle tidb serverless - tidb_serverless_list = ( - db.session.query(TidbAuthBinding) - .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") - .all() - ) + tidb_serverless_list = db.session.scalars( + select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + ).all() if len(tidb_serverless_list) == 0: return # update tidb serverless status @@ -32,7 +32,7 @@ def update_tidb_serverless_status_task(): click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green")) -def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): +def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]): try: # batch 20 for i in range(0, len(tidb_serverless_list), 20): diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 82b1d21179..34681ba111 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -263,11 +263,9 @@ class AppAnnotationService: db.session.delete(annotation) - annotation_hit_histories = ( - db.session.query(AppAnnotationHitHistory) - .where(AppAnnotationHitHistory.annotation_id == annotation_id) - .all() - ) + annotation_hit_histories = db.session.scalars( + select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id) + ).all() if annotation_hit_histories: for annotation_hit_history in annotation_hit_histories: db.session.delete(annotation_hit_history) diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index f6e960b413..055cf65816 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -1,5 +1,7 @@ import json +from sqlalchemy import select + from core.helper import encrypter from extensions.ext_database import db from models.source import DataSourceApiKeyAuthBinding @@ -9,11 +11,11 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory class ApiKeyAuthService: @staticmethod def get_provider_auth_list(tenant_id: str): - data_source_api_key_bindings = ( - db.session.query(DataSourceApiKeyAuthBinding) - .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) - .all() - ) + data_source_api_key_bindings = db.session.scalars( + select(DataSourceApiKeyAuthBinding).where( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False) + ) + ).all() return data_source_api_key_bindings @staticmethod diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 3b4cb1900a..f8f89d7428 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app +from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs: @classmethod def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int): with flask_app.app_context(): - apps = db.session.query(App).where(App.tenant_id == tenant_id).all() + apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all() app_ids = [app.id for app in apps] while True: with Session(db.engine).no_autoflush as session: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 20a9c73f08..47bd06a7cc 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,6 +6,7 @@ import secrets import time import uuid from collections import Counter +from collections.abc import Sequence from typing import Any, Literal, Optional import sqlalchemy as sa @@ -741,14 +742,12 @@ class DatasetService: } # get recent 30 days auto disable logs start_date = datetime.datetime.now() - datetime.timedelta(days=30) - dataset_auto_disable_logs = ( - db.session.query(DatasetAutoDisableLog) - .where( + dataset_auto_disable_logs = db.session.scalars( + select(DatasetAutoDisableLog).where( DatasetAutoDisableLog.dataset_id == dataset_id, DatasetAutoDisableLog.created_at >= start_date, ) - .all() - ) + ).all() if dataset_auto_disable_logs: return { "document_ids": [log.document_id for log in dataset_auto_disable_logs], @@ -885,69 +884,58 @@ class DocumentService: return document @staticmethod - def get_document_by_ids(document_ids: list[str]) -> list[Document]: - documents = ( - db.session.query(Document) - .where( + def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]: + documents = db.session.scalars( + select(Document).where( Document.id.in_(document_ids), Document.enabled == True, Document.indexing_status == "completed", Document.archived == False, ) - .all() - ) + ).all() return documents @staticmethod - def get_document_by_dataset_id(dataset_id: str) -> list[Document]: - documents = ( - db.session.query(Document) - .where( + def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]: + documents = db.session.scalars( + select(Document).where( Document.dataset_id == dataset_id, Document.enabled == True, ) - .all() - ) + ).all() return documents @staticmethod - def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]: - documents = ( - db.session.query(Document) - .where( + def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: + documents = db.session.scalars( + select(Document).where( Document.dataset_id == dataset_id, Document.enabled == True, Document.indexing_status == "completed", Document.archived == False, ) - .all() - ) + ).all() return documents @staticmethod - def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: - documents = ( - db.session.query(Document) - .where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) - .all() - ) + def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: + documents = db.session.scalars( + select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + ).all() return documents @staticmethod - def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: + def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]: assert isinstance(current_user, Account) - - documents = ( - db.session.query(Document) - .where( + documents = db.session.scalars( + select(Document).where( Document.batch == batch, Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id, ) - .all() - ) + ).all() return documents @@ -984,7 +972,7 @@ class DocumentService: # Check if document_ids is not empty to avoid WHERE false condition if not document_ids or len(document_ids) == 0: return - documents = db.session.query(Document).where(Document.id.in_(document_ids)).all() + documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all() file_ids = [ document.data_source_info_dict["upload_file_id"] for document in documents @@ -2424,16 +2412,14 @@ class SegmentService: if not segment_ids or len(segment_ids) == 0: return if action == "enable": - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, DocumentSegment.enabled == False, ) - .all() - ) + ).all() if not segments: return real_deal_segment_ids = [] @@ -2451,16 +2437,14 @@ class SegmentService: enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) elif action == "disable": - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, DocumentSegment.document_id == document.id, DocumentSegment.enabled == True, ) - .all() - ) + ).all() if not segments: return real_deal_segment_ids = [] @@ -2532,16 +2516,13 @@ class SegmentService: dataset: Dataset, ) -> list[ChildChunk]: assert isinstance(current_user, Account) - - child_chunks = ( - db.session.query(ChildChunk) - .where( + child_chunks = db.session.scalars( + select(ChildChunk).where( ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, ChildChunk.segment_id == segment.id, ) - .all() - ) + ).all() child_chunks_map = {chunk.id: chunk for chunk in child_chunks} new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], [] @@ -2751,13 +2732,11 @@ class DatasetCollectionBindingService: class DatasetPermissionService: @classmethod def get_dataset_partial_member_list(cls, dataset_id): - user_list_query = ( - db.session.query( + user_list_query = db.session.scalars( + select( DatasetPermission.account_id, - ) - .where(DatasetPermission.dataset_id == dataset_id) - .all() - ) + ).where(DatasetPermission.dataset_id == dataset_id) + ).all() user_list = [] for user in user_list_query: diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index d0e2230540..33d7dacba0 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -3,7 +3,7 @@ import logging from json import JSONDecodeError from typing import Optional, Union -from sqlalchemy import or_ +from sqlalchemy import or_, select from constants import HIDDEN_VALUE from core.entities.provider_configuration import ProviderConfiguration @@ -322,16 +322,14 @@ class ModelLoadBalancingService: if not isinstance(configs, list): raise ValueError("Invalid load balancing configs") - current_load_balancing_configs = ( - db.session.query(LoadBalancingModelConfig) - .where( + current_load_balancing_configs = db.session.scalars( + select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) - .all() - ) + ).all() # id as key, config as value current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index e19f53f120..a9733e0826 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -1,5 +1,7 @@ from typing import Optional +from sqlalchemy import select + from constants.languages import languages from extensions.ext_database import db from models.model import App, RecommendedApp @@ -31,18 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): :param language: language :return: """ - recommended_apps = ( - db.session.query(RecommendedApp) - .where(RecommendedApp.is_listed == True, RecommendedApp.language == language) - .all() - ) + recommended_apps = db.session.scalars( + select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language) + ).all() if len(recommended_apps) == 0: - recommended_apps = ( - db.session.query(RecommendedApp) - .where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) - .all() - ) + recommended_apps = db.session.scalars( + select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + ).all() categories = set() recommended_apps_result = [] diff --git a/api/services/tag_service.py b/api/services/tag_service.py index a16bdb46cd..dd67b19966 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -2,7 +2,7 @@ import uuid from typing import Optional from flask_login import current_user -from sqlalchemy import func +from sqlalchemy import func, select from werkzeug.exceptions import NotFound from extensions.ext_database import db @@ -29,35 +29,30 @@ class TagService: # Check if tag_ids is not empty to avoid WHERE false condition if not tag_ids or len(tag_ids) == 0: return [] - tags = ( - db.session.query(Tag) - .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) - .all() - ) + tags = db.session.scalars( + select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + ).all() if not tags: return [] tag_ids = [tag.id for tag in tags] # Check if tag_ids is not empty to avoid WHERE false condition if not tag_ids or len(tag_ids) == 0: return [] - tag_bindings = ( - db.session.query(TagBinding.target_id) - .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) - .all() - ) - if not tag_bindings: - return [] - results = [tag_binding.target_id for tag_binding in tag_bindings] - return results + tag_bindings = db.session.scalars( + select(TagBinding.target_id).where( + TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id + ) + ).all() + return tag_bindings @staticmethod def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str): if not tag_type or not tag_name: return [] - tags = ( - db.session.query(Tag) - .where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) - .all() + tags = list( + db.session.scalars( + select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + ).all() ) if not tags: return [] @@ -117,7 +112,7 @@ class TagService: raise NotFound("Tag not found") db.session.delete(tag) # delete tag binding - tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all() + tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all() if tag_bindings: for tag_binding in tag_bindings: db.session.delete(tag_binding) diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 78e587abee..f86d7e51bf 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from typing import Any, cast from httpx import get +from sqlalchemy import select from core.entities.provider_entities import ProviderConfig from core.model_runtime.utils.encoders import jsonable_encoder @@ -443,9 +444,7 @@ class ApiToolManageService: list api tools """ # get all api providers - db_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or [] - ) + db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() result: list[ToolProviderApiEntity] = [] diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 2f8a91ed82..2449536d5c 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -3,7 +3,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any -from sqlalchemy import or_ +from sqlalchemy import or_, select from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController @@ -186,7 +186,9 @@ class WorkflowToolManageService: :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() + db_tools = db.session.scalars( + select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) + ).all() tools: list[WorkflowToolProviderController] = [] for provider in db_tools: diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 3498e08426..cdc07c77a8 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document @@ -39,7 +40,7 @@ def enable_annotation_reply_task( db.session.close() return - annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all() + annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all() enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}" diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 08e2c4a556..212f8c3c6a 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -34,7 +35,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form if not dataset: raise Exception("Document has no dataset") - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) + ).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -59,7 +62,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form db.session.commit() if file_ids: - files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all() + files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() for file in files: try: storage.delete(file.key) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 9d12b6a589..5f2a355d16 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -55,8 +56,8 @@ def clean_dataset_task( index_struct=index_struct, collection_binding_id=collection_binding_id, ) - documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all() - segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all() + documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace # This ensures all invalid doc_form values are properly handled diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 6549ad04b5..761ac6fc3d 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -4,6 +4,7 @@ from typing import Optional import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -35,7 +36,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if not dataset: raise Exception("Document has no dataset") - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index e7a61e22f2..771b43f9b0 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -34,7 +35,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): document = db.session.query(Document).where(Document.id == document_id).first() db.session.delete(document) - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() index_node_ids = [segment.index_node_id for segment in segments] index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 23e929c57e..dc6ef6fb61 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -4,6 +4,7 @@ from typing import Literal import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -36,16 +37,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a if action == "remove": index_processor.clean(dataset, None, with_keywords=False) elif action == "add": - dataset_documents = ( - db.session.query(DatasetDocument) - .where( + dataset_documents = db.session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] @@ -89,16 +88,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) db.session.commit() elif action == "update": - dataset_documents = ( - db.session.query(DatasetDocument) - .where( + dataset_documents = db.session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() # add new index if dataset_documents: # update document status diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index d4899fe0e4..9038dc179b 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -44,15 +45,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen # sync index processor index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, ) - .all() - ) + ).all() if not segments: db.session.close() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 687e3e9551..24d7d16578 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.extractor.notion_extractor import NotionExtractor @@ -85,7 +86,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 48566b6104..161502a228 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -45,7 +46,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index d93f30ba37..2020179cd9 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -79,7 +80,9 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index 647664641d..c5ca7a6171 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -45,15 +46,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i # sync index processor index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - segments = ( - db.session.query(DocumentSegment) - .where( + segments = db.session.scalars( + select(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id, DocumentSegment.document_id == document_id, ) - .all() - ) + ).all() if not segments: logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan")) db.session.close() diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index ec56ab583b..c0ab2d0b41 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -45,7 +46,7 @@ def remove_document_from_index_task(document_id: str): index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all() index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try: diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index c52218caae..b65eca7e0b 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -69,7 +70,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): # clean old data index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document_id) + ).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 3c7c69e3c8..0dc1d841f4 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -63,7 +64,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): # clean old data index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all() + segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index cb20238f0c..66527dd506 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from models.account import TenantAccountJoin, TenantAccountRole from models.model import Account, Tenant @@ -468,7 +469,7 @@ class TestModelLoadBalancingService: assert load_balancing_config.id is not None # Verify inherit config was created in database - inherit_configs = ( - db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all() - ) + inherit_configs = db.session.scalars( + select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__") + ).all() assert len(inherit_configs) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index d09a4a17ab..04cff397b2 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy import select from werkzeug.exceptions import NotFound from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -954,7 +955,9 @@ class TestTagService: from extensions.ext_database import db # Verify only one binding exists - bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() + bindings = db.session.scalars( + select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) + ).all() assert len(bindings) == 1 def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): @@ -1064,7 +1067,9 @@ class TestTagService: # No error should be raised, and database state should remain unchanged from extensions.ext_database import db - bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all() + bindings = db.session.scalars( + select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) + ).all() assert len(bindings) == 0 def test_check_target_exists_knowledge_success( diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 6d6f1dab72..c9ace46c55 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account @@ -354,16 +355,14 @@ class TestWebConversationService: # Verify only one pinned conversation record exists from extensions.ext_database import db - pinned_conversations = ( - db.session.query(PinnedConversation) - .where( + pinned_conversations = db.session.scalars( + select(PinnedConversation).where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, PinnedConversation.created_by_role == "account", PinnedConversation.created_by == account.id, ) - .all() - ) + ).all() assert len(pinned_conversations) == 1 diff --git a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py index dc42a04cf3..d23298f096 100644 --- a/api/tests/unit_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/unit_tests/services/auth/test_api_key_auth_service.py @@ -28,18 +28,20 @@ class TestApiKeyAuthService: mock_binding.provider = self.provider mock_binding.disabled = False - mock_session.query.return_value.where.return_value.all.return_value = [mock_binding] + mock_session.scalars.return_value.all.return_value = [mock_binding] result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) assert len(result) == 1 assert result[0].tenant_id == self.tenant_id - mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding) + assert mock_session.scalars.call_count == 1 + select_arg = mock_session.scalars.call_args[0][0] + assert "data_source_api_key_auth_binding" in str(select_arg).lower() @patch("services.auth.api_key_auth_service.db.session") def test_get_provider_auth_list_empty(self, mock_session): """Test get provider auth list - empty result""" - mock_session.query.return_value.where.return_value.all.return_value = [] + mock_session.scalars.return_value.all.return_value = [] result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id) @@ -48,13 +50,15 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.db.session") def test_get_provider_auth_list_filters_disabled(self, mock_session): """Test get provider auth list - filters disabled items""" - mock_session.query.return_value.where.return_value.all.return_value = [] + mock_session.scalars.return_value.all.return_value = [] ApiKeyAuthService.get_provider_auth_list(self.tenant_id) - - # Verify where conditions include disabled.is_(False) - where_call = mock_session.query.return_value.where.call_args[0] - assert len(where_call) == 2 # tenant_id and disabled filter conditions + select_stmt = mock_session.scalars.call_args[0][0] + where_clauses = list(getattr(select_stmt, "_where_criteria", []) or []) + # Ensure both tenant filter and disabled filter exist + where_strs = [str(c).lower() for c in where_clauses] + assert any("tenant_id" in s for s in where_strs) + assert any("disabled" in s for s in where_strs) @patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py index 4ce5525942..bb39b92c09 100644 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ b/api/tests/unit_tests/services/auth/test_auth_integration.py @@ -63,10 +63,10 @@ class TestAuthIntegration: tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) - mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding] + mock_session.scalars.return_value.all.return_value = [tenant1_binding] result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) - mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding] + mock_session.scalars.return_value.all.return_value = [tenant2_binding] result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) assert len(result1) == 1 From b690ac4e2a5e856aab549df231d81a45c5f6cd56 Mon Sep 17 00:00:00 2001 From: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:17:49 +0800 Subject: [PATCH 17/86] fix: Remove sticky positioning from workflow component fields (#25470) --- web/app/components/workflow/nodes/_base/components/field.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/workflow/nodes/_base/components/field.tsx b/web/app/components/workflow/nodes/_base/components/field.tsx index 44fa4f6f0a..aadcea1065 100644 --- a/web/app/components/workflow/nodes/_base/components/field.tsx +++ b/web/app/components/workflow/nodes/_base/components/field.tsx @@ -38,7 +38,7 @@ const Field: FC = ({
supportFold && toggleFold()} - className={cn('sticky top-0 flex items-center justify-between bg-components-panel-bg', supportFold && 'cursor-pointer')}> + className={cn('flex items-center justify-between', supportFold && 'cursor-pointer')}>
{title} {required && *} From 70e4d6be340731ada69ff399c5eb9a5d7abc4615 Mon Sep 17 00:00:00 2001 From: Eric Guo Date: Wed, 10 Sep 2025 15:57:04 +0800 Subject: [PATCH 18/86] Fix 500 in dataset page. (#25474) --- api/services/dataset_service.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 47bd06a7cc..997b28524a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2738,11 +2738,7 @@ class DatasetPermissionService: ).where(DatasetPermission.dataset_id == dataset_id) ).all() - user_list = [] - for user in user_list_query: - user_list.append(user.account_id) - - return user_list + return user_list_query @classmethod def update_partial_member_list(cls, tenant_id, dataset_id, user_list): From 34e55028aea90e1ef8ab8e1c6df6f3f50ec5ea16 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Wed, 10 Sep 2025 19:01:32 -0700 Subject: [PATCH 19/86] Feat/enteprise cd (#25485) --- .github/workflows/deploy-enterprise.yml | 28 ++++++++++++++++++------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/.github/workflows/deploy-enterprise.yml b/.github/workflows/deploy-enterprise.yml index 98fa7c3b49..7ef94f7622 100644 --- a/.github/workflows/deploy-enterprise.yml +++ b/.github/workflows/deploy-enterprise.yml @@ -19,11 +19,23 @@ jobs: github.event.workflow_run.head_branch == 'deploy/enterprise' steps: - - name: Deploy to server - uses: appleboy/ssh-action@v0.1.8 - with: - host: ${{ secrets.ENTERPRISE_SSH_HOST }} - username: ${{ secrets.ENTERPRISE_SSH_USER }} - password: ${{ secrets.ENTERPRISE_SSH_PASSWORD }} - script: | - ${{ vars.ENTERPRISE_SSH_SCRIPT || secrets.ENTERPRISE_SSH_SCRIPT }} + - name: trigger deployments + env: + DEV_ENV_ADDRS: ${{ vars.DEV_ENV_ADDRS }} + DEPLOY_SECRET: ${{ secrets.DEPLOY_SECRET }} + run: | + IFS=',' read -ra ENDPOINTS <<< "$DEV_ENV_ADDRS" + + for ENDPOINT in "${ENDPOINTS[@]}"; do + ENDPOINT=$(echo "$ENDPOINT" | xargs) + + BODY=$(cat < Date: Wed, 10 Sep 2025 20:53:42 -0700 Subject: [PATCH 20/86] Feat/enteprise cd (#25508) --- .github/workflows/deploy-enterprise.yml | 26 ++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/deploy-enterprise.yml b/.github/workflows/deploy-enterprise.yml index 7ef94f7622..9cff3a3482 100644 --- a/.github/workflows/deploy-enterprise.yml +++ b/.github/workflows/deploy-enterprise.yml @@ -24,18 +24,18 @@ jobs: DEV_ENV_ADDRS: ${{ vars.DEV_ENV_ADDRS }} DEPLOY_SECRET: ${{ secrets.DEPLOY_SECRET }} run: | - IFS=',' read -ra ENDPOINTS <<< "$DEV_ENV_ADDRS" - + IFS=',' read -ra ENDPOINTS <<< "${DEV_ENV_ADDRS:-}" + BODY='{"project":"dify-api","tag":"deploy-enterprise"}' + for ENDPOINT in "${ENDPOINTS[@]}"; do - ENDPOINT=$(echo "$ENDPOINT" | xargs) - - BODY=$(cat < Date: Thu, 11 Sep 2025 13:17:50 +0800 Subject: [PATCH 21/86] chore: support Zendesk widget (#25517) --- web/app/components/base/ga/index.tsx | 8 +-- web/app/components/base/zendesk/index.tsx | 21 +++++++ web/app/components/base/zendesk/utils.ts | 23 ++++++++ web/app/layout.tsx | 8 +++ web/config/index.ts | 71 +++++++++++++---------- web/context/app-context.tsx | 40 +++++++++++++ web/context/provider-context.tsx | 13 +++++ web/types/feature.ts | 6 ++ 8 files changed, 155 insertions(+), 35 deletions(-) create mode 100644 web/app/components/base/zendesk/index.tsx create mode 100644 web/app/components/base/zendesk/utils.ts diff --git a/web/app/components/base/ga/index.tsx b/web/app/components/base/ga/index.tsx index 7a95561754..81d84a85d3 100644 --- a/web/app/components/base/ga/index.tsx +++ b/web/app/components/base/ga/index.tsx @@ -24,7 +24,7 @@ const GA: FC = ({ if (IS_CE_EDITION) return null - const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') : '' + const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : '' return ( <> @@ -32,7 +32,7 @@ const GA: FC = ({ strategy="beforeInteractive" async src={`https://www.googletagmanager.com/gtag/js?id=${gaIdMaps[gaType]}`} - nonce={nonce!} + nonce={nonce ?? undefined} > {/* Cookie banner */} diff --git a/web/app/components/base/zendesk/index.tsx b/web/app/components/base/zendesk/index.tsx new file mode 100644 index 0000000000..b3d67eb390 --- /dev/null +++ b/web/app/components/base/zendesk/index.tsx @@ -0,0 +1,21 @@ +import { memo } from 'react' +import { type UnsafeUnwrappedHeaders, headers } from 'next/headers' +import Script from 'next/script' +import { IS_CE_EDITION, ZENDESK_WIDGET_KEY } from '@/config' + +const Zendesk = () => { + if (IS_CE_EDITION || !ZENDESK_WIDGET_KEY) + return null + + const nonce = process.env.NODE_ENV === 'production' ? (headers() as unknown as UnsafeUnwrappedHeaders).get('x-nonce') ?? '' : '' + + return ( +