From 24b280a0ed6d11a5baa38b3eea4bdf28e807f76b Mon Sep 17 00:00:00 2001 From: lif <1835304752@qq.com> Date: Fri, 30 Jan 2026 20:19:35 +0800 Subject: [PATCH 01/32] fix(i18n): improve Chinese translation of Max Tokens (#31771) Signed-off-by: majiayu000 <1835304752@qq.com> --- api/core/model_runtime/entities/defaults.py | 2 +- web/utils/completion-params.spec.ts | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index 76969fea70..51c9c51257 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -88,7 +88,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.MAX_TOKENS: { "label": { "en_US": "Max Tokens", - "zh_Hans": "最大标记", + "zh_Hans": "最大 Token 数", }, "type": "int", "help": { diff --git a/web/utils/completion-params.spec.ts b/web/utils/completion-params.spec.ts index 0b691a0baa..e56957de8f 100644 --- a/web/utils/completion-params.spec.ts +++ b/web/utils/completion-params.spec.ts @@ -21,7 +21,7 @@ describe('completion-params', () => { it('validates int type parameter within range', () => { const rules: ModelParameterRule[] = [ - { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大标记' }, required: false }, + { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大 Token 数' }, required: false }, ] const oldParams: FormValue = { max_tokens: 100 } const result = mergeValidCompletionParams(oldParams, rules) @@ -32,7 +32,7 @@ describe('completion-params', () => { it('removes int parameter below minimum', () => { const rules: ModelParameterRule[] = [ - { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大标记' }, required: false }, + { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大 Token 数' }, required: false }, ] const oldParams: FormValue = { max_tokens: 0 } const result = mergeValidCompletionParams(oldParams, rules) @@ -43,7 +43,7 @@ describe('completion-params', () => { it('removes int parameter above maximum', () => { const rules: ModelParameterRule[] = [ - { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大标记' }, required: false }, + { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大 Token 数' }, required: false }, ] const oldParams: FormValue = { max_tokens: 5000 } const result = mergeValidCompletionParams(oldParams, rules) @@ -54,7 +54,7 @@ describe('completion-params', () => { it('removes int parameter with invalid type', () => { const rules: ModelParameterRule[] = [ - { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大标记' }, required: false }, + { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大 Token 数' }, required: false }, ] const oldParams: FormValue = { max_tokens: 'not a number' as any } const result = mergeValidCompletionParams(oldParams, rules) @@ -184,7 +184,7 @@ describe('completion-params', () => { it('handles multiple parameters with mixed validity', () => { const rules: ModelParameterRule[] = [ { name: 'temperature', type: 'float', min: 0, max: 2, label: { en_US: 'Temperature', zh_Hans: '温度' }, required: false }, - { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大标记' }, required: false }, + { name: 'max_tokens', type: 'int', min: 1, max: 4096, label: { en_US: 'Max Tokens', zh_Hans: '最大 Token 数' }, required: false }, { name: 'model', type: 'string', options: ['gpt-4'], label: { en_US: 'Model', zh_Hans: '模型' }, required: false }, ] const oldParams: FormValue = { From a4db32244027d8c252204e329e22b7493747bfdb Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Fri, 30 Jan 2026 22:24:49 +0900 Subject: [PATCH 02/32] chore: update restx to 1.3.2 (#31229) --- api/pyproject.toml | 2 +- api/uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 482dd4c8ad..41e532047c 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -87,7 +87,7 @@ dependencies = [ "sseclient-py~=1.8.0", "httpx-sse~=0.4.0", "sendgrid~=6.12.3", - "flask-restx~=1.3.0", + "flask-restx~=1.3.2", "packaging~=23.2", "croniter>=6.0.0", "weaviate-client==4.17.0", diff --git a/api/uv.lock b/api/uv.lock index 7bb43fbb12..0677f5ad98 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1589,7 +1589,7 @@ requires-dist = [ { name = "flask-login", specifier = "~=0.6.3" }, { name = "flask-migrate", specifier = "~=4.0.7" }, { name = "flask-orjson", specifier = "~=2.0.0" }, - { name = "flask-restx", specifier = "~=1.3.0" }, + { name = "flask-restx", specifier = "~=1.3.2" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "gevent", specifier = "~=25.9.1" }, { name = "gmpy2", specifier = "~=2.2.1" }, From b58d9e030a6fc123845073eae6472763f9f6042e Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Fri, 30 Jan 2026 22:39:02 +0900 Subject: [PATCH 03/32] refactor: init_validate.py to v3 (#31457) --- api/controllers/console/init_validate.py | 93 ++++++++----------- api/extensions/ext_fastopenapi.py | 2 + .../console/test_fastopenapi_init_validate.py | 46 +++++++++ 3 files changed, 88 insertions(+), 53 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 2bebe79eac..f086bf1862 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,87 +1,74 @@ import os +from typing import Literal from flask import session -from flask_restx import Resource, fields from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config +from controllers.fastopenapi import console_router from extensions.ext_database import db from models.model import DifySetup from services.account_service import TenantService -from . import console_ns from .error import AlreadySetupError, InitValidateFailedError from .wraps import only_edition_self_hosted -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class InitValidatePayload(BaseModel): - password: str = Field(..., max_length=30) + password: str = Field(..., max_length=30, description="Initialization password") -console_ns.schema_model( - InitValidatePayload.__name__, - InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +class InitStatusResponse(BaseModel): + status: Literal["finished", "not_started"] = Field(..., description="Initialization status") + + +class InitValidateResponse(BaseModel): + result: str = Field(description="Operation result", examples=["success"]) + + +@console_router.get( + "/init", + response_model=InitStatusResponse, + tags=["console"], ) +def get_init_status() -> InitStatusResponse: + """Get initialization validation status.""" + init_status = get_init_validate_status() + if init_status: + return InitStatusResponse(status="finished") + return InitStatusResponse(status="not_started") -@console_ns.route("/init") -class InitValidateAPI(Resource): - @console_ns.doc("get_init_status") - @console_ns.doc(description="Get initialization validation status") - @console_ns.response( - 200, - "Success", - model=console_ns.model( - "InitStatusResponse", - {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])}, - ), - ) - def get(self): - """Get initialization validation status""" - init_status = get_init_validate_status() - if init_status: - return {"status": "finished"} - return {"status": "not_started"} +@console_router.post( + "/init", + response_model=InitValidateResponse, + tags=["console"], + status_code=201, +) +@only_edition_self_hosted +def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse: + """Validate initialization password.""" + tenant_count = TenantService.get_tenant_count() + if tenant_count > 0: + raise AlreadySetupError() - @console_ns.doc("validate_init_password") - @console_ns.doc(description="Validate initialization password for self-hosted edition") - @console_ns.expect(console_ns.models[InitValidatePayload.__name__]) - @console_ns.response( - 201, - "Success", - model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}), - ) - @console_ns.response(400, "Already setup or validation failed") - @only_edition_self_hosted - def post(self): - """Validate initialization password""" - # is tenant created - tenant_count = TenantService.get_tenant_count() - if tenant_count > 0: - raise AlreadySetupError() + if payload.password != os.environ.get("INIT_PASSWORD"): + session["is_init_validated"] = False + raise InitValidateFailedError() - payload = InitValidatePayload.model_validate(console_ns.payload) - input_password = payload.password - - if input_password != os.environ.get("INIT_PASSWORD"): - session["is_init_validated"] = False - raise InitValidateFailedError() - - session["is_init_validated"] = True - return {"result": "success"}, 201 + session["is_init_validated"] = True + return InitValidateResponse(result="success") -def get_init_validate_status(): +def get_init_validate_status() -> bool: if dify_config.EDITION == "SELF_HOSTED": if os.environ.get("INIT_PASSWORD"): if session.get("is_init_validated"): return True with Session(db.engine) as db_session: - return db_session.execute(select(DifySetup)).scalar_one_or_none() + return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None return True diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py index 719456803a..ab4d23a072 100644 --- a/api/extensions/ext_fastopenapi.py +++ b/api/extensions/ext_fastopenapi.py @@ -27,9 +27,11 @@ def init_app(app: DifyApp) -> None: ) # Ensure route decorators are evaluated. + import controllers.console.init_validate as init_validate_module import controllers.console.ping as ping_module from controllers.console import remote_files, setup + _ = init_validate_module _ = ping_module _ = remote_files _ = setup diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py new file mode 100644 index 0000000000..b9bc42fb25 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py @@ -0,0 +1,46 @@ +import builtins +from unittest.mock import patch + +import pytest +from flask import Flask +from flask.views import MethodView + +from extensions import ext_fastopenapi + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +@pytest.fixture +def app() -> Flask: + app = Flask(__name__) + app.config["TESTING"] = True + app.secret_key = "test-secret-key" + return app + + +def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch): + ext_fastopenapi.init_app(app) + monkeypatch.delenv("INIT_PASSWORD", raising=False) + + with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"): + client = app.test_client() + response = client.get("/console/api/init") + + assert response.status_code == 200 + assert response.get_json() == {"status": "finished"} + + +def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch): + ext_fastopenapi.init_app(app) + monkeypatch.setenv("INIT_PASSWORD", "test-init-password") + + with ( + patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0), + ): + client = app.test_client() + response = client.post("/console/api/init", json={"password": "test-init-password"}) + + assert response.status_code == 201 + assert response.get_json() == {"result": "success"} From a433d5ed36d803636a187908905621c88079615f Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Fri, 30 Jan 2026 22:40:14 +0900 Subject: [PATCH 04/32] refactor: port api/controllers/console/tag/tags.py to ov3 (#31767) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/controllers/console/tag/tags.py | 211 +++++++++-------- api/services/tag_service.py | 2 +- .../console/test_fastopenapi_tags.py | 222 ++++++++++++++++++ 3 files changed, 334 insertions(+), 101 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/test_fastopenapi_tags.py diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 9988524a80..e828d54ff4 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,14 +1,11 @@ from typing import Literal +from uuid import UUID -from flask import request -from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden -from controllers.common.schema import register_schema_models -from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required -from fields.tag_fields import dataset_tag_fields +from controllers.fastopenapi import console_router from libs.login import current_account_with_tenant, login_required from services.tag_service import TagService @@ -35,115 +32,129 @@ class TagListQueryParam(BaseModel): keyword: str | None = Field(None, description="Search keyword") -register_schema_models( - console_ns, - TagBasePayload, - TagBindingPayload, - TagBindingRemovePayload, - TagListQueryParam, +class TagResponse(BaseModel): + id: str = Field(description="Tag ID") + name: str = Field(description="Tag name") + type: str = Field(description="Tag type") + binding_count: int = Field(description="Number of bindings") + + +class TagBindingResult(BaseModel): + result: Literal["success"] = Field(description="Operation result", examples=["success"]) + + +@console_router.get( + "/tags", + response_model=list[TagResponse], + tags=["console"], ) +@setup_required +@login_required +@account_initialization_required +def list_tags(query: TagListQueryParam) -> list[TagResponse]: + _, current_tenant_id = current_account_with_tenant() + tags = TagService.get_tags(query.type, current_tenant_id, query.keyword) + + return [ + TagResponse( + id=tag.id, + name=tag.name, + type=tag.type, + binding_count=int(tag.binding_count), + ) + for tag in tags + ] -@console_ns.route("/tags") -class TagListApi(Resource): - @setup_required - @login_required - @account_initialization_required - @console_ns.doc( - params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."} - ) - @marshal_with(dataset_tag_fields) - def get(self): - _, current_tenant_id = current_account_with_tenant() - raw_args = request.args.to_dict() - param = TagListQueryParam.model_validate(raw_args) - tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) +@console_router.post( + "/tags", + response_model=TagResponse, + tags=["console"], +) +@setup_required +@login_required +@account_initialization_required +def create_tag(payload: TagBasePayload) -> TagResponse: + current_user, _ = current_account_with_tenant() + # The role of the current user in the tag table must be admin, owner, or editor + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() - return tags, 200 + tag = TagService.save_tags(payload.model_dump()) - @console_ns.expect(console_ns.models[TagBasePayload.__name__]) - @setup_required - @login_required - @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, or editor - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() - - payload = TagBasePayload.model_validate(console_ns.payload or {}) - tag = TagService.save_tags(payload.model_dump()) - - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} - - return response, 200 + return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=0) -@console_ns.route("/tags/") -class TagUpdateDeleteApi(Resource): - @console_ns.expect(console_ns.models[TagBasePayload.__name__]) - @setup_required - @login_required - @account_initialization_required - def patch(self, tag_id): - current_user, _ = current_account_with_tenant() - tag_id = str(tag_id) - # The role of the current user in the ta table must be admin, owner, or editor - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() +@console_router.patch( + "/tags/", + response_model=TagResponse, + tags=["console"], +) +@setup_required +@login_required +@account_initialization_required +def update_tag(tag_id: UUID, payload: TagBasePayload) -> TagResponse: + current_user, _ = current_account_with_tenant() + tag_id_str = str(tag_id) + # The role of the current user in the ta table must be admin, owner, or editor + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() - payload = TagBasePayload.model_validate(console_ns.payload or {}) - tag = TagService.update_tags(payload.model_dump(), tag_id) + tag = TagService.update_tags(payload.model_dump(), tag_id_str) - binding_count = TagService.get_tag_binding_count(tag_id) + binding_count = TagService.get_tag_binding_count(tag_id_str) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} - - return response, 200 - - @setup_required - @login_required - @account_initialization_required - @edit_permission_required - def delete(self, tag_id): - tag_id = str(tag_id) - - TagService.delete_tag(tag_id) - - return 204 + return TagResponse(id=tag.id, name=tag.name, type=tag.type, binding_count=binding_count) -@console_ns.route("/tag-bindings/create") -class TagBindingCreateApi(Resource): - @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) - @setup_required - @login_required - @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() +@console_router.delete( + "/tags/", + tags=["console"], + status_code=204, +) +@setup_required +@login_required +@account_initialization_required +@edit_permission_required +def delete_tag(tag_id: UUID) -> None: + tag_id_str = str(tag_id) - payload = TagBindingPayload.model_validate(console_ns.payload or {}) - TagService.save_tag_binding(payload.model_dump()) - - return {"result": "success"}, 200 + TagService.delete_tag(tag_id_str) -@console_ns.route("/tag-bindings/remove") -class TagBindingDeleteApi(Resource): - @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) - @setup_required - @login_required - @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() +@console_router.post( + "/tag-bindings/create", + response_model=TagBindingResult, + tags=["console"], +) +@setup_required +@login_required +@account_initialization_required +def create_tag_binding(payload: TagBindingPayload) -> TagBindingResult: + current_user, _ = current_account_with_tenant() + # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() - payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) - TagService.delete_tag_binding(payload.model_dump()) + TagService.save_tag_binding(payload.model_dump()) - return {"result": "success"}, 200 + return TagBindingResult(result="success") + + +@console_router.post( + "/tag-bindings/remove", + response_model=TagBindingResult, + tags=["console"], +) +@setup_required +@login_required +@account_initialization_required +def delete_tag_binding(payload: TagBindingRemovePayload) -> TagBindingResult: + current_user, _ = current_account_with_tenant() + # The role of the current user in the tag table must be admin, owner, editor, or dataset_operator + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() + + TagService.delete_tag_binding(payload.model_dump()) + + return TagBindingResult(result="success") diff --git a/api/services/tag_service.py b/api/services/tag_service.py index bd3585acf4..56f4ae9494 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -24,7 +24,7 @@ class TagService: escaped_keyword = escape_like_pattern(keyword) query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\"))) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) - results: list = query.order_by(Tag.created_at.desc()).all() + results = query.order_by(Tag.created_at.desc()).all() return results @staticmethod diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_tags.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_tags.py new file mode 100644 index 0000000000..62d143f32d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_tags.py @@ -0,0 +1,222 @@ +import builtins +import contextlib +import importlib +import sys +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from flask.views import MethodView + +from extensions import ext_fastopenapi +from extensions.ext_database import db + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["SECRET_KEY"] = "test-secret" + app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" + + db.init_app(app) + + return app + + +@pytest.fixture(autouse=True) +def fix_method_view_issue(monkeypatch): + if not hasattr(builtins, "MethodView"): + monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False) + + +def _create_isolated_router(): + import controllers.fastopenapi + + router_class = type(controllers.fastopenapi.console_router) + return router_class() + + +@contextlib.contextmanager +def _patch_auth_and_router(temp_router): + def noop(func): + return func + + default_user = MagicMock(has_edit_permission=True, is_dataset_editor=False) + + with ( + patch("controllers.fastopenapi.console_router", temp_router), + patch("extensions.ext_fastopenapi.console_router", temp_router), + patch("controllers.console.wraps.setup_required", side_effect=noop), + patch("libs.login.login_required", side_effect=noop), + patch("controllers.console.wraps.account_initialization_required", side_effect=noop), + patch("controllers.console.wraps.edit_permission_required", side_effect=noop), + patch("libs.login.current_account_with_tenant", return_value=(default_user, "tenant-id")), + patch("configs.dify_config.EDITION", "CLOUD"), + ): + import extensions.ext_fastopenapi + + importlib.reload(extensions.ext_fastopenapi) + + yield + + +def _force_reload_module(target_module: str, alias_module: str): + if target_module in sys.modules: + del sys.modules[target_module] + if alias_module in sys.modules: + del sys.modules[alias_module] + + module = importlib.import_module(target_module) + sys.modules[alias_module] = sys.modules[target_module] + + return module + + +def _dedupe_routes(router): + seen = set() + unique_routes = [] + for path, method, endpoint in reversed(router.get_routes()): + key = (path, method, endpoint.__name__) + if key in seen: + continue + seen.add(key) + unique_routes.append((path, method, endpoint)) + router._routes = list(reversed(unique_routes)) + + +def _cleanup_modules(target_module: str, alias_module: str): + if target_module in sys.modules: + del sys.modules[target_module] + if alias_module in sys.modules: + del sys.modules[alias_module] + + +@pytest.fixture +def mock_tags_module_env(): + target_module = "controllers.console.tag.tags" + alias_module = "api.controllers.console.tag.tags" + temp_router = _create_isolated_router() + + try: + with _patch_auth_and_router(temp_router): + tags_module = _force_reload_module(target_module, alias_module) + _dedupe_routes(temp_router) + yield tags_module + finally: + _cleanup_modules(target_module, alias_module) + + +def test_list_tags_success(app: Flask, mock_tags_module_env): + # Arrange + tag = SimpleNamespace(id="tag-1", name="Alpha", type="app", binding_count=2) + with patch("controllers.console.tag.tags.TagService.get_tags", return_value=[tag]): + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.get("/console/api/tags?type=app&keyword=Alpha") + + # Assert + assert response.status_code == 200 + assert response.get_json() == [ + {"id": "tag-1", "name": "Alpha", "type": "app", "binding_count": 2}, + ] + + +def test_create_tag_success(app: Flask, mock_tags_module_env): + # Arrange + tag = SimpleNamespace(id="tag-2", name="Beta", type="app") + with patch("controllers.console.tag.tags.TagService.save_tags", return_value=tag) as mock_save: + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.post("/console/api/tags", json={"name": "Beta", "type": "app"}) + + # Assert + assert response.status_code == 200 + assert response.get_json() == { + "id": "tag-2", + "name": "Beta", + "type": "app", + "binding_count": 0, + } + mock_save.assert_called_once_with({"name": "Beta", "type": "app"}) + + +def test_update_tag_success(app: Flask, mock_tags_module_env): + # Arrange + tag = SimpleNamespace(id="tag-3", name="Gamma", type="app") + with ( + patch("controllers.console.tag.tags.TagService.update_tags", return_value=tag) as mock_update, + patch("controllers.console.tag.tags.TagService.get_tag_binding_count", return_value=4), + ): + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.patch( + "/console/api/tags/11111111-1111-1111-1111-111111111111", + json={"name": "Gamma", "type": "app"}, + ) + + # Assert + assert response.status_code == 200 + assert response.get_json() == { + "id": "tag-3", + "name": "Gamma", + "type": "app", + "binding_count": 4, + } + mock_update.assert_called_once_with( + {"name": "Gamma", "type": "app"}, + "11111111-1111-1111-1111-111111111111", + ) + + +def test_delete_tag_success(app: Flask, mock_tags_module_env): + # Arrange + with patch("controllers.console.tag.tags.TagService.delete_tag") as mock_delete: + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.delete("/console/api/tags/11111111-1111-1111-1111-111111111111") + + # Assert + assert response.status_code == 204 + mock_delete.assert_called_once_with("11111111-1111-1111-1111-111111111111") + + +def test_create_tag_binding_success(app: Flask, mock_tags_module_env): + # Arrange + payload = {"tag_ids": ["tag-1", "tag-2"], "target_id": "target-1", "type": "app"} + with patch("controllers.console.tag.tags.TagService.save_tag_binding") as mock_bind: + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.post("/console/api/tag-bindings/create", json=payload) + + # Assert + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + mock_bind.assert_called_once_with(payload) + + +def test_delete_tag_binding_success(app: Flask, mock_tags_module_env): + # Arrange + payload = {"tag_id": "tag-1", "target_id": "target-1", "type": "app"} + with patch("controllers.console.tag.tags.TagService.delete_tag_binding") as mock_unbind: + ext_fastopenapi.init_app(app) + client = app.test_client() + + # Act + response = client.post("/console/api/tag-bindings/remove", json=payload) + + # Assert + assert response.status_code == 200 + assert response.get_json() == {"result": "success"} + mock_unbind.assert_called_once_with(payload) From 5bc99995fcd8123382c3ee04d4bf07cc3bba1cc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Sat, 31 Jan 2026 00:57:36 +0800 Subject: [PATCH 05/32] fix(api): align graph protocols for response streaming (#31777) --- .../response_coordinator/session.py | 17 +++++++---- .../workflow/runtime/graph_runtime_state.py | 30 +++++++++++++++---- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/core/workflow/graph_engine/response_coordinator/session.py index 8ceaa428c3..5e4fada7d9 100644 --- a/api/core/workflow/graph_engine/response_coordinator/session.py +++ b/api/core/workflow/graph_engine/response_coordinator/session.py @@ -10,10 +10,10 @@ from __future__ import annotations from dataclasses import dataclass from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.template import Template from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.knowledge_index import KnowledgeIndexNode +from core.workflow.runtime.graph_runtime_state import NodeProtocol @dataclass @@ -29,21 +29,26 @@ class ResponseSession: index: int = 0 # Current position in the template segments @classmethod - def from_node(cls, node: Node) -> ResponseSession: + def from_node(cls, node: NodeProtocol) -> ResponseSession: """ - Create a ResponseSession from an AnswerNode or EndNode. + Create a ResponseSession from a response-capable node. + + The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer, + but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides: + - `id: str` + - `get_streaming_template() -> Template` Args: - node: Must be either an AnswerNode or EndNode instance + node: Node from the materialized workflow graph. Returns: ResponseSession configured with the node's streaming template Raises: - TypeError: If node is not an AnswerNode or EndNode + TypeError: If node is not a supported response node type. """ if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode): - raise TypeError + raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode") return cls( node_id=node.id, template=node.get_streaming_template(), diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/core/workflow/runtime/graph_runtime_state.py index 401cecc162..acf0ee6839 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/core/workflow/runtime/graph_runtime_state.py @@ -6,12 +6,13 @@ import threading from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from typing import Any, Protocol +from typing import Any, ClassVar, Protocol from pydantic.json import pydantic_encoder from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.pause_reason import PauseReason +from core.workflow.enums import NodeExecutionType, NodeState, NodeType from core.workflow.runtime.variable_pool import VariablePool @@ -103,14 +104,33 @@ class ResponseStreamCoordinatorProtocol(Protocol): ... +class NodeProtocol(Protocol): + """Structural interface for graph nodes.""" + + id: str + state: NodeState + execution_type: NodeExecutionType + node_type: ClassVar[NodeType] + + def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ... + + +class EdgeProtocol(Protocol): + id: str + state: NodeState + tail: str + head: str + source_handle: str + + class GraphProtocol(Protocol): """Structural interface required from graph instances attached to the runtime state.""" - nodes: Mapping[str, object] - edges: Mapping[str, object] - root_node: object + nodes: Mapping[str, NodeProtocol] + edges: Mapping[str, EdgeProtocol] + root_node: NodeProtocol - def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ... + def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... @dataclass(slots=True) From b8cb5f5ea250ca6d14c46b80d497d6aa30bacab2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Sat, 31 Jan 2026 17:00:56 +0800 Subject: [PATCH 06/32] refactor(typing): Fixup typing A2 - workflow engine & nodes (#31723) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato --- api/core/app/workflow/node_factory.py | 35 +++++++---------- api/core/file/file_manager.py | 15 +++++++ .../code_executor/code_node_provider.py | 23 +++++------ api/core/helper/ssrf_proxy.py | 38 ++++++++++++++++++ api/core/schemas/registry.py | 1 + api/core/tools/tool_manager.py | 38 ++++++++---------- api/core/workflow/entities/graph_config.py | 24 ++++++++++++ api/core/workflow/graph/graph.py | 31 +++++++-------- .../workflow/graph_engine/graph_engine.py | 3 +- .../response_coordinator/coordinator.py | 4 +- api/core/workflow/nodes/base/entities.py | 2 +- api/core/workflow/nodes/code/entities.py | 4 +- .../nodes/datasource/datasource_node.py | 4 +- .../workflow/nodes/http_request/executor.py | 21 +++++----- api/core/workflow/nodes/http_request/node.py | 13 ++++--- .../nodes/iteration/iteration_node.py | 2 +- api/core/workflow/nodes/list_operator/node.py | 9 ++--- api/core/workflow/nodes/llm/node.py | 23 +++++------ api/core/workflow/nodes/protocols.py | 14 +++---- api/core/workflow/workflow_entry.py | 6 +-- api/models/workflow.py | 6 +-- api/pyproject.toml | 2 +- .../services/test_webhook_service.py | 1 + api/ty.toml | 12 ++++-- api/uv.lock | 39 +++++++++---------- 25 files changed, 217 insertions(+), 153 deletions(-) create mode 100644 api/core/workflow/entities/graph_config.py diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py index e0a0059a38..a5773bbef8 100644 --- a/api/core/app/workflow/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -4,13 +4,14 @@ from typing import TYPE_CHECKING, final from typing_extensions import override from configs import dify_config -from core.file import file_manager -from core.helper import ssrf_proxy +from core.file.file_manager import file_manager from core.helper.code_executor.code_executor import CodeExecutor from core.helper.code_executor.code_node_provider import CodeNodeProvider +from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager +from core.workflow.entities.graph_config import NodeConfigDict from core.workflow.enums import NodeType -from core.workflow.graph import NodeFactory +from core.workflow.graph.graph import NodeFactory from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits @@ -22,7 +23,6 @@ from core.workflow.nodes.template_transform.template_renderer import ( Jinja2TemplateRenderer, ) from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from libs.typing import is_str, is_str_dict if TYPE_CHECKING: from core.workflow.entities import GraphInitParams @@ -47,9 +47,9 @@ class DifyNodeFactory(NodeFactory): code_providers: Sequence[type[CodeNodeProvider]] | None = None, code_limits: CodeNodeLimits | None = None, template_renderer: Jinja2TemplateRenderer | None = None, - http_request_http_client: HttpClientProtocol = ssrf_proxy, + http_request_http_client: HttpClientProtocol | None = None, http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, - http_request_file_manager: FileManagerProtocol = file_manager, + http_request_file_manager: FileManagerProtocol | None = None, ) -> None: self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state @@ -68,12 +68,12 @@ class DifyNodeFactory(NodeFactory): max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() - self._http_request_http_client = http_request_http_client + self._http_request_http_client = http_request_http_client or ssrf_proxy self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory - self._http_request_file_manager = http_request_file_manager + self._http_request_file_manager = http_request_file_manager or file_manager @override - def create_node(self, node_config: dict[str, object]) -> Node: + def create_node(self, node_config: NodeConfigDict) -> Node: """ Create a Node instance from node configuration data using the traditional mapping. @@ -82,23 +82,14 @@ class DifyNodeFactory(NodeFactory): :raises ValueError: if node type is unknown or configuration is invalid """ # Get node_id from config - node_id = node_config.get("id") - if not is_str(node_id): - raise ValueError("Node config missing id") + node_id = node_config["id"] # Get node type from config - node_data = node_config.get("data", {}) - if not is_str_dict(node_data): - raise ValueError(f"Node {node_id} missing data information") - - node_type_str = node_data.get("type") - if not is_str(node_type_str): - raise ValueError(f"Node {node_id} missing or invalid type information") - + node_data = node_config["data"] try: - node_type = NodeType(node_type_str) + node_type = NodeType(node_data["type"]) except ValueError: - raise ValueError(f"Unknown node type: {node_type_str}") + raise ValueError(f"Unknown node type: {node_data['type']}") # Get node class node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index c0fefef3d0..9945d7c1ab 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -168,3 +168,18 @@ def _to_url(f: File, /): return sign_tool_file(tool_file_id=f.related_id, extension=f.extension) else: raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + +class FileManager: + """ + Adapter exposing file manager helpers behind FileManagerProtocol. + + This is intentionally a thin wrapper over the existing module-level functions so callers can inject it + where a protocol-typed file manager is expected. + """ + + def download(self, f: File, /) -> bytes: + return download(f) + + +file_manager = FileManager() diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index e93e1e4414..f4cce0b332 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -47,15 +47,16 @@ class CodeNodeProvider(BaseModel, ABC): @classmethod def get_default_config(cls) -> DefaultConfig: - return { - "type": "code", - "config": { - "variables": [ - {"variable": "arg1", "value_selector": []}, - {"variable": "arg2", "value_selector": []}, - ], - "code_language": cls.get_language(), - "code": cls.get_default_code(), - "outputs": {"result": {"type": "string", "children": None}}, - }, + variables: list[VariableConfig] = [ + {"variable": "arg1", "value_selector": []}, + {"variable": "arg2", "value_selector": []}, + ] + outputs: dict[str, OutputConfig] = {"result": {"type": "string", "children": None}} + + config: CodeConfig = { + "variables": variables, + "code_language": cls.get_language(), + "code": cls.get_default_code(), + "outputs": outputs, } + return {"type": "code", "config": config} diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index ddccfbaf45..54068fc28d 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -230,3 +230,41 @@ def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: return make_request("HEAD", url, max_retries=max_retries, **kwargs) + + +class SSRFProxy: + """ + Adapter exposing SSRF-protected HTTP helpers behind HttpClientProtocol. + + This is intentionally a thin wrapper over the existing module-level functions so callers can inject it + where a protocol-typed HTTP client is expected. + """ + + @property + def max_retries_exceeded_error(self) -> type[Exception]: + return max_retries_exceeded_error + + @property + def request_error(self) -> type[Exception]: + return request_error + + def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return get(url=url, max_retries=max_retries, **kwargs) + + def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return head(url=url, max_retries=max_retries, **kwargs) + + def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return post(url=url, max_retries=max_retries, **kwargs) + + def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return put(url=url, max_retries=max_retries, **kwargs) + + def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return delete(url=url, max_retries=max_retries, **kwargs) + + def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response: + return patch(url=url, max_retries=max_retries, **kwargs) + + +ssrf_proxy = SSRFProxy() diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py index b4ecfe47ff..b87fba4eaa 100644 --- a/api/core/schemas/registry.py +++ b/api/core/schemas/registry.py @@ -35,6 +35,7 @@ class SchemaRegistry: registry.load_all_versions() cls._default_instance = registry + return cls._default_instance return cls._default_instance diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index f8213d9fd7..d561d39923 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -189,16 +189,13 @@ class ToolManager: raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") if not provider_controller.need_credentials: - return cast( - BuiltinTool, - builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), + return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ) builtin_provider = None if isinstance(provider_controller, PluginToolProviderController): @@ -300,18 +297,15 @@ class ToolManager: decrypted_credentials = refreshed_credentials.credentials cache.delete() - return cast( - BuiltinTool, - builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials=dict(decrypted_credentials), - credential_type=CredentialType.of(builtin_provider.credential_type), - runtime_parameters={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), + return builtin_tool.fork_tool_runtime( + runtime=ToolRuntime( + tenant_id=tenant_id, + credentials=dict(decrypted_credentials), + credential_type=CredentialType.of(builtin_provider.credential_type), + runtime_parameters={}, + invoke_from=invoke_from, + tool_invoke_from=tool_invoke_from, + ) ) elif provider_type == ToolProviderType.API: diff --git a/api/core/workflow/entities/graph_config.py b/api/core/workflow/entities/graph_config.py new file mode 100644 index 0000000000..209dcfe6bc --- /dev/null +++ b/api/core/workflow/entities/graph_config.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import sys + +from pydantic import TypeAdapter, with_config + +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +@with_config(extra="allow") +class NodeConfigData(TypedDict): + type: str + + +@with_config(extra="allow") +class NodeConfigDict(TypedDict): + id: str + data: NodeConfigData + + +NodeConfigDictAdapter = TypeAdapter(NodeConfigDict) diff --git a/api/core/workflow/graph/graph.py b/api/core/workflow/graph/graph.py index 31bf6f3b27..52bbbb20cc 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/core/workflow/graph/graph.py @@ -5,15 +5,20 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Protocol, cast, final +from pydantic import TypeAdapter + +from core.workflow.entities.graph_config import NodeConfigDict from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType from core.workflow.nodes.base.node import Node -from libs.typing import is_str, is_str_dict +from libs.typing import is_str from .edge import Edge from .validation import get_graph_validator logger = logging.getLogger(__name__) +_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict]) + class NodeFactory(Protocol): """ @@ -23,7 +28,7 @@ class NodeFactory(Protocol): allowing for different node creation strategies while maintaining type safety. """ - def create_node(self, node_config: dict[str, object]) -> Node: + def create_node(self, node_config: NodeConfigDict) -> Node: """ Create a Node instance from node configuration data. @@ -63,28 +68,24 @@ class Graph: self.root_node = root_node @classmethod - def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]: + def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]: """ Parse node configurations and build a mapping of node IDs to configs. :param node_configs: list of node configuration dictionaries :return: mapping of node ID to node config """ - node_configs_map: dict[str, dict[str, object]] = {} + node_configs_map: dict[str, NodeConfigDict] = {} for node_config in node_configs: - node_id = node_config.get("id") - if not node_id or not isinstance(node_id, str): - continue - - node_configs_map[node_id] = node_config + node_configs_map[node_config["id"]] = node_config return node_configs_map @classmethod def _find_root_node_id( cls, - node_configs_map: Mapping[str, Mapping[str, object]], + node_configs_map: Mapping[str, NodeConfigDict], edge_configs: Sequence[Mapping[str, object]], root_node_id: str | None = None, ) -> str: @@ -113,10 +114,8 @@ class Graph: # Prefer START node if available start_node_id = None for nid in root_candidates: - node_data = node_configs_map[nid].get("data") - if not is_str_dict(node_data): - continue - node_type = node_data.get("type") + node_data = node_configs_map[nid]["data"] + node_type = node_data["type"] if not isinstance(node_type, str): continue if NodeType(node_type).is_start_node: @@ -176,7 +175,7 @@ class Graph: @classmethod def _create_node_instances( cls, - node_configs_map: dict[str, dict[str, object]], + node_configs_map: dict[str, NodeConfigDict], node_factory: NodeFactory, ) -> dict[str, Node]: """ @@ -303,7 +302,7 @@ class Graph: node_configs = graph_config.get("nodes", []) edge_configs = cast(list[dict[str, object]], edge_configs) - node_configs = cast(list[dict[str, object]], node_configs) + node_configs = _ListNodeConfigDict.validate_python(node_configs) if not node_configs: raise ValueError("Graph must have at least one node") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 0b359a2392..2b76b563ff 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -46,7 +46,6 @@ from .graph_traversal import EdgeProcessor, SkipPropagator from .layers.base import GraphEngineLayer from .orchestration import Dispatcher, ExecutionCoordinator from .protocols.command_channel import CommandChannel -from .ready_queue import ReadyQueue from .worker_management import WorkerPool if TYPE_CHECKING: @@ -90,7 +89,7 @@ class GraphEngine: self._graph_execution.workflow_id = workflow_id # === Execution Queues === - self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue) + self._ready_queue = self._graph_runtime_state.ready_queue # Queue for events generated during execution self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/core/workflow/graph_engine/response_coordinator/coordinator.py index 98e0ea91ef..e82ba29438 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/core/workflow/graph_engine/response_coordinator/coordinator.py @@ -15,10 +15,10 @@ from uuid import uuid4 from pydantic import BaseModel, Field from core.workflow.enums import NodeExecutionType, NodeState -from core.workflow.graph import Graph from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent from core.workflow.nodes.base.template import TextSegment, VariableSegment from core.workflow.runtime import VariablePool +from core.workflow.runtime.graph_runtime_state import GraphProtocol from .path import Path from .session import ResponseSession @@ -75,7 +75,7 @@ class ResponseStreamCoordinator: Ensures ordered streaming of responses based on upstream node outputs and constants. """ - def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None: + def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None: """ Initialize coordinator with variable pool. diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index e5a20c8e91..c5426e3fb7 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -115,7 +115,7 @@ class DefaultValue(BaseModel): @model_validator(mode="after") def validate_value_type(self) -> DefaultValue: # Type validation configuration - type_validators = { + type_validators: dict[DefaultValueType, dict[str, Any]] = { DefaultValueType.STRING: { "type": str, "converter": lambda x: x, diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 10a1c897e9..8026011196 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal, Self +from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel @@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData): class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: dict[str, Self] | None = None + children: dict[str, "CodeNodeData.Output"] | None = None class Dependency(BaseModel): name: str diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 925561cf7c..fd71d610b4 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -69,11 +69,13 @@ class DatasourceNode(Node[DatasourceNodeData]): if datasource_type is None: raise DatasourceNodeError("Datasource type is not set") + datasource_type = DatasourceProviderType.value_of(datasource_type) + datasource_runtime = DatasourceManager.get_datasource_runtime( provider_id=f"{node_data.plugin_id}/{node_data.provider_name}", datasource_name=node_data.datasource_name or "", tenant_id=self.tenant_id, - datasource_type=DatasourceProviderType.value_of(datasource_type), + datasource_type=datasource_type, ) datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id) diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 429f8411a6..7de8216562 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -2,7 +2,7 @@ import base64 import json import secrets import string -from collections.abc import Mapping +from collections.abc import Callable, Mapping from copy import deepcopy from typing import Any, Literal from urllib.parse import urlencode, urlparse @@ -11,9 +11,9 @@ import httpx from json_repair import repair_json from configs import dify_config -from core.file import file_manager from core.file.enums import FileTransferMethod -from core.helper import ssrf_proxy +from core.file.file_manager import file_manager as default_file_manager +from core.helper.ssrf_proxy import ssrf_proxy from core.variables.segments import ArrayFileSegment, FileSegment from core.workflow.runtime import VariablePool @@ -79,8 +79,8 @@ class Executor: timeout: HttpRequestNodeTimeout, variable_pool: VariablePool, max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES, - http_client: HttpClientProtocol = ssrf_proxy, - file_manager: FileManagerProtocol = file_manager, + http_client: HttpClientProtocol | None = None, + file_manager: FileManagerProtocol | None = None, ): # If authorization API key is present, convert the API key using the variable pool if node_data.authorization.type == "api-key": @@ -107,8 +107,8 @@ class Executor: self.data = None self.json = None self.max_retries = max_retries - self._http_client = http_client - self._file_manager = file_manager + self._http_client = http_client or ssrf_proxy + self._file_manager = file_manager or default_file_manager # init template self.variable_pool = variable_pool @@ -336,7 +336,7 @@ class Executor: """ do http request depending on api bundle """ - _METHOD_MAP = { + _METHOD_MAP: dict[str, Callable[..., httpx.Response]] = { "get": self._http_client.get, "head": self._http_client.head, "post": self._http_client.post, @@ -348,7 +348,7 @@ class Executor: if method_lc not in _METHOD_MAP: raise InvalidHttpMethodError(f"Invalid http method {self.method}") - request_args = { + request_args: dict[str, Any] = { "data": self.data, "files": self.files, "json": self.json, @@ -361,14 +361,13 @@ class Executor: } # request_args = {k: v for k, v in request_args.items() if v is not None} try: - response: httpx.Response = _METHOD_MAP[method_lc]( + response = _METHOD_MAP[method_lc]( url=self.url, **request_args, max_retries=self.max_retries, ) except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e: raise HttpRequestNodeError(str(e)) from e - # FIXME: fix type ignore, this maybe httpx type issue return response def invoke(self) -> Response: diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 964e53e03c..480482375f 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -4,8 +4,9 @@ from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any from configs import dify_config -from core.file import File, FileTransferMethod, file_manager -from core.helper import ssrf_proxy +from core.file import File, FileTransferMethod +from core.file.file_manager import file_manager as default_file_manager +from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.variables.segments import ArrayFileSegment from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus @@ -47,9 +48,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - http_client: HttpClientProtocol = ssrf_proxy, + http_client: HttpClientProtocol | None = None, tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, - file_manager: FileManagerProtocol = file_manager, + file_manager: FileManagerProtocol | None = None, ) -> None: super().__init__( id=id, @@ -57,9 +58,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._http_client = http_client + self._http_client = http_client or ssrf_proxy self._tool_file_manager_factory = tool_file_manager_factory - self._file_manager = file_manager + self._file_manager = file_manager or default_file_manager @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index c19182549f..25a881ea7d 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -397,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return outputs # Check if all non-None outputs are lists - non_none_outputs = [output for output in outputs if output is not None] + non_none_outputs: list[object] = [output for output in outputs if output is not None] if not non_none_outputs: return outputs diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 813d898b9a..235f5b9c52 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: case "name": return lambda x: x.filename or "" case "type": - return lambda x: x.type + return lambda x: str(x.type) case "extension": return lambda x: x.extension or "" case "mime_type": return lambda x: x.mime_type or "" case "transfer_method": - return lambda x: x.transfer_method + return lambda x: str(x.transfer_method) case "url": return lambda x: x.remote_url or "" case "related_id": @@ -276,7 +276,6 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: - extract_func: Callable[[File], Any] if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str): extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) @@ -284,8 +283,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) elif key == "size" and isinstance(value, str): - extract_func = _get_file_extract_number_func(key=key) - return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) + extract_number = _get_file_extract_number_func(key=key) + return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x)) else: raise InvalidKeyError(f"Invalid key: {key}") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 17d82c2118..beccf79344 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -852,18 +852,16 @@ class LLMNode(Node[LLMNodeData]): # Insert histories into the prompt prompt_content = prompt_messages[0].content # For issue #11247 - Check if prompt content is a string or a list - prompt_content_type = type(prompt_content) - if prompt_content_type == str: + if isinstance(prompt_content, str): prompt_content = str(prompt_content) if "#histories#" in prompt_content: prompt_content = prompt_content.replace("#histories#", memory_text) else: prompt_content = memory_text + "\n" + prompt_content prompt_messages[0].content = prompt_content - elif prompt_content_type == list: - prompt_content = prompt_content if isinstance(prompt_content, list) else [] + elif isinstance(prompt_content, list): for content_item in prompt_content: - if content_item.type == PromptMessageContentType.TEXT: + if isinstance(content_item, TextPromptMessageContent): if "#histories#" in content_item.data: content_item.data = content_item.data.replace("#histories#", memory_text) else: @@ -873,13 +871,12 @@ class LLMNode(Node[LLMNodeData]): # Add current query to the prompt message if sys_query: - if prompt_content_type == str: + if isinstance(prompt_content, str): prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) prompt_messages[0].content = prompt_content - elif prompt_content_type == list: - prompt_content = prompt_content if isinstance(prompt_content, list) else [] + elif isinstance(prompt_content, list): for content_item in prompt_content: - if content_item.type == PromptMessageContentType.TEXT: + if isinstance(content_item, TextPromptMessageContent): content_item.data = sys_query + "\n" + content_item.data else: raise ValueError("Invalid prompt content type") @@ -1033,14 +1030,14 @@ class LLMNode(Node[LLMNodeData]): if typed_node_data.prompt_config: enable_jinja = False - if isinstance(prompt_template, list): + if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + if prompt_template.edition_type == "jinja2": + enable_jinja = True + else: for prompt in prompt_template: if prompt.edition_type == "jinja2": enable_jinja = True break - else: - if prompt_template.edition_type == "jinja2": - enable_jinja = True if enable_jinja: for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py index e7dcf62fcf..2ad39e0ab5 100644 --- a/api/core/workflow/nodes/protocols.py +++ b/api/core/workflow/nodes/protocols.py @@ -1,4 +1,4 @@ -from typing import Protocol +from typing import Any, Protocol import httpx @@ -12,17 +12,17 @@ class HttpClientProtocol(Protocol): @property def request_error(self) -> type[Exception]: ... - def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ... + def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... class FileManagerProtocol(Protocol): diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 43f15f6fd0..4b1845cda2 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -144,11 +144,11 @@ class WorkflowEntry: :param user_inputs: user inputs :return: """ - node_config = dict(workflow.get_node_config_by_id(node_id)) - node_config_data = node_config.get("data", {}) + node_config = workflow.get_node_config_by_id(node_id) + node_config_data = node_config["data"] # Get node type - node_type = NodeType(node_config_data.get("type")) + node_type = NodeType(node_config_data["type"]) # init graph init params and runtime state graph_init_params = GraphInitParams( diff --git a/api/models/workflow.py b/api/models/workflow.py index df83228c2a..83956b1114 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -29,6 +29,7 @@ from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) +from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause from core.workflow.enums import NodeType from extensions.ext_storage import Storage @@ -229,7 +230,7 @@ class Workflow(Base): # bug # - `_get_graph_and_variable_pool_for_single_node_run`. return json.loads(self.graph) if self.graph else {} - def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]: + def get_node_config_by_id(self, node_id: str) -> NodeConfigDict: """Extract a node configuration from the workflow graph by node ID. A node configuration is a dictionary containing the node's properties, including the node's id, title, and its data as a dict. @@ -247,8 +248,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - assert isinstance(node_config, dict) - return node_config + return NodeConfigDictAdapter.validate_python(node_config) @staticmethod def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType: diff --git a/api/pyproject.toml b/api/pyproject.toml index 41e532047c..97e6c83ed6 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -116,7 +116,7 @@ dev = [ "dotenv-linter~=0.5.0", "faker~=38.2.0", "lxml-stubs~=0.5.1", - "ty~=0.0.1a19", + "ty>=0.0.14", "basedpyright~=1.31.0", "ruff~=0.14.0", "pytest~=8.3.2", diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index e3431fd382..934d1bdd34 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -90,6 +90,7 @@ class TestWebhookService: "id": "webhook_node", "type": "webhook", "data": { + "type": "trigger-webhook", "title": "Test Webhook", "method": "post", "content_type": "application/json", diff --git a/api/ty.toml b/api/ty.toml index afdd37897e..380e14dbef 100644 --- a/api/ty.toml +++ b/api/ty.toml @@ -1,16 +1,15 @@ [src] exclude = [ # deps groups (A1/A2/B/C/D/E) - # A2: workflow engine/nodes - "core/workflow", - "core/app/workflow", - "core/helper/code_executor", # B: app runner + prompt "core/prompt", "core/app/apps/base_app_runner.py", "core/app/apps/workflow_app_runner.py", + "core/agent", + "core/plugin", # C: services/controllers/fields/libs "services", + "controllers/inner_api", "controllers/console/app", "controllers/console/explore", "controllers/console/datasets", @@ -28,3 +27,8 @@ exclude = [ "tests", ] + +[rules] +deprecated = "ignore" +unused-ignore-comment = "ignore" +# possibly-missing-attribute = "ignore" \ No newline at end of file diff --git a/api/uv.lock b/api/uv.lock index 0677f5ad98..04d9a7c021 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1684,7 +1684,7 @@ dev = [ { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "testcontainers", specifier = "~=4.13.2" }, - { name = "ty", specifier = "~=0.0.1a19" }, + { name = "ty", specifier = ">=0.0.14" }, { name = "types-aiofiles", specifier = "~=24.1.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0" }, { name = "types-cachetools", specifier = "~=5.5.0" }, @@ -6239,27 +6239,26 @@ wheels = [ [[package]] name = "ty" -version = "0.0.1a27" +version = "0.0.14" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8f/65/3592d7c73d80664378fc90d0a00c33449a99cbf13b984433c883815245f3/ty-0.0.1a27.tar.gz", hash = "sha256:d34fe04979f2c912700cbf0919e8f9b4eeaa10c4a2aff7450e5e4c90f998bc28", size = 4516059, upload-time = "2025-11-18T21:55:18.381Z" } +sdist = { url = "https://files.pythonhosted.org/packages/af/57/22c3d6bf95c2229120c49ffc2f0da8d9e8823755a1c3194da56e51f1cc31/ty-0.0.14.tar.gz", hash = "sha256:a691010565f59dd7f15cf324cdcd1d9065e010c77a04f887e1ea070ba34a7de2", size = 5036573, upload-time = "2026-01-27T00:57:31.427Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/05/7945aa97356446fd53ed3ddc7ee02a88d8ad394217acd9428f472d6b109d/ty-0.0.1a27-py3-none-linux_armv6l.whl", hash = "sha256:3cbb735f5ecb3a7a5f5b82fb24da17912788c109086df4e97d454c8fb236fbc5", size = 9375047, upload-time = "2025-11-18T21:54:31.577Z" }, - { url = "https://files.pythonhosted.org/packages/69/4e/89b167a03de0e9ec329dc89bc02e8694768e4576337ef6c0699987681342/ty-0.0.1a27-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:4a6367236dc456ba2416563301d498aef8c6f8959be88777ef7ba5ac1bf15f0b", size = 9169540, upload-time = "2025-11-18T21:54:34.036Z" }, - { url = "https://files.pythonhosted.org/packages/38/07/e62009ab9cc242e1becb2bd992097c80a133fce0d4f055fba6576150d08a/ty-0.0.1a27-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8e93e231a1bcde964cdb062d2d5e549c24493fb1638eecae8fcc42b81e9463a4", size = 8711942, upload-time = "2025-11-18T21:54:36.3Z" }, - { url = "https://files.pythonhosted.org/packages/b5/43/f35716ec15406f13085db52e762a3cc663c651531a8124481d0ba602eca0/ty-0.0.1a27-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5b6a8166b60117da1179851a3d719cc798bf7e61f91b35d76242f0059e9ae1d", size = 8984208, upload-time = "2025-11-18T21:54:39.453Z" }, - { url = "https://files.pythonhosted.org/packages/2d/79/486a3374809523172379768de882c7a369861165802990177fe81489b85f/ty-0.0.1a27-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bfbe8b0e831c072b79a078d6c126d7f4d48ca17f64a103de1b93aeda32265dc5", size = 9157209, upload-time = "2025-11-18T21:54:42.664Z" }, - { url = "https://files.pythonhosted.org/packages/ff/08/9a7c8efcb327197d7d347c548850ef4b54de1c254981b65e8cd0672dc327/ty-0.0.1a27-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:90e09678331552e7c25d7eb47868b0910dc5b9b212ae22c8ce71a52d6576ddbb", size = 9519207, upload-time = "2025-11-18T21:54:45.311Z" }, - { url = "https://files.pythonhosted.org/packages/e0/9d/7b4680683e83204b9edec551bb91c21c789ebc586b949c5218157ee474b7/ty-0.0.1a27-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:88c03e4beeca79d85a5618921e44b3a6ea957e0453e08b1cdd418b51da645939", size = 10148794, upload-time = "2025-11-18T21:54:48.329Z" }, - { url = "https://files.pythonhosted.org/packages/89/21/8b961b0ab00c28223f06b33222427a8e31aa04f39d1b236acc93021c626c/ty-0.0.1a27-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ece5811322789fefe22fc088ed36c5879489cd39e913f9c1ff2a7678f089c61", size = 9900563, upload-time = "2025-11-18T21:54:51.214Z" }, - { url = "https://files.pythonhosted.org/packages/85/eb/95e1f0b426c2ea8d443aa923fcab509059c467bbe64a15baaf573fea1203/ty-0.0.1a27-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f2ccb4f0fddcd6e2017c268dfce2489e9a36cb82a5900afe6425835248b1086", size = 9926355, upload-time = "2025-11-18T21:54:53.927Z" }, - { url = "https://files.pythonhosted.org/packages/f5/78/40e7f072049e63c414f2845df780be3a494d92198c87c2ffa65e63aecf3f/ty-0.0.1a27-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33450528312e41d003e96a1647780b2783ab7569bbc29c04fc76f2d1908061e3", size = 9480580, upload-time = "2025-11-18T21:54:56.617Z" }, - { url = "https://files.pythonhosted.org/packages/18/da/f4a2dfedab39096808ddf7475f35ceb750d9a9da840bee4afd47b871742f/ty-0.0.1a27-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a0a9ac635deaa2b15947701197ede40cdecd13f89f19351872d16f9ccd773fa1", size = 8957524, upload-time = "2025-11-18T21:54:59.085Z" }, - { url = "https://files.pythonhosted.org/packages/21/ea/26fee9a20cf77a157316fd3ab9c6db8ad5a0b20b2d38a43f3452622587ac/ty-0.0.1a27-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:797fb2cd49b6b9b3ac9f2f0e401fb02d3aa155badc05a8591d048d38d28f1e0c", size = 9201098, upload-time = "2025-11-18T21:55:01.845Z" }, - { url = "https://files.pythonhosted.org/packages/b0/53/e14591d1275108c9ae28f97ac5d4b93adcc2c8a4b1b9a880dfa9d07c15f8/ty-0.0.1a27-py3-none-musllinux_1_2_i686.whl", hash = "sha256:7fe81679a0941f85e98187d444604e24b15bde0a85874957c945751756314d03", size = 9275470, upload-time = "2025-11-18T21:55:04.23Z" }, - { url = "https://files.pythonhosted.org/packages/37/44/e2c9acecac70bf06fb41de285e7be2433c2c9828f71e3bf0e886fc85c4fd/ty-0.0.1a27-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:355f651d0cdb85535a82bd9f0583f77b28e3fd7bba7b7da33dcee5a576eff28b", size = 9592394, upload-time = "2025-11-18T21:55:06.542Z" }, - { url = "https://files.pythonhosted.org/packages/ee/a7/4636369731b24ed07c2b4c7805b8d990283d677180662c532d82e4ef1a36/ty-0.0.1a27-py3-none-win32.whl", hash = "sha256:61782e5f40e6df622093847b34c366634b75d53f839986f1bf4481672ad6cb55", size = 8783816, upload-time = "2025-11-18T21:55:09.648Z" }, - { url = "https://files.pythonhosted.org/packages/a7/1d/b76487725628d9e81d9047dc0033a5e167e0d10f27893d04de67fe1a9763/ty-0.0.1a27-py3-none-win_amd64.whl", hash = "sha256:c682b238085d3191acddcf66ef22641562946b1bba2a7f316012d5b2a2f4de11", size = 9616833, upload-time = "2025-11-18T21:55:12.457Z" }, - { url = "https://files.pythonhosted.org/packages/3a/db/c7cd5276c8f336a3cf87992b75ba9d486a7cf54e753fcd42495b3bc56fb7/ty-0.0.1a27-py3-none-win_arm64.whl", hash = "sha256:e146dfa32cbb0ac6afb0cb65659e87e4e313715e68d76fe5ae0a4b3d5b912ce8", size = 9137796, upload-time = "2025-11-18T21:55:15.897Z" }, + { url = "https://files.pythonhosted.org/packages/99/cb/cc6d1d8de59beb17a41f9a614585f884ec2d95450306c173b3b7cc090d2e/ty-0.0.14-py3-none-linux_armv6l.whl", hash = "sha256:32cf2a7596e693094621d3ae568d7ee16707dce28c34d1762947874060fdddaa", size = 10034228, upload-time = "2026-01-27T00:57:53.133Z" }, + { url = "https://files.pythonhosted.org/packages/f3/96/dd42816a2075a8f31542296ae687483a8d047f86a6538dfba573223eaf9a/ty-0.0.14-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f971bf9805f49ce8c0968ad53e29624d80b970b9eb597b7cbaba25d8a18ce9a2", size = 9939162, upload-time = "2026-01-27T00:57:43.857Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b4/73c4859004e0f0a9eead9ecb67021438b2e8e5fdd8d03e7f5aca77623992/ty-0.0.14-py3-none-macosx_11_0_arm64.whl", hash = "sha256:45448b9e4806423523268bc15e9208c4f3f2ead7c344f615549d2e2354d6e924", size = 9418661, upload-time = "2026-01-27T00:58:03.411Z" }, + { url = "https://files.pythonhosted.org/packages/58/35/839c4551b94613db4afa20ee555dd4f33bfa7352d5da74c5fa416ffa0fd2/ty-0.0.14-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee94a9b747ff40114085206bdb3205a631ef19a4d3fb89e302a88754cbbae54c", size = 9837872, upload-time = "2026-01-27T00:57:23.718Z" }, + { url = "https://files.pythonhosted.org/packages/41/2b/bbecf7e2faa20c04bebd35fc478668953ca50ee5847ce23e08acf20ea119/ty-0.0.14-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6756715a3c33182e9ab8ffca2bb314d3c99b9c410b171736e145773ee0ae41c3", size = 9848819, upload-time = "2026-01-27T00:57:58.501Z" }, + { url = "https://files.pythonhosted.org/packages/be/60/3c0ba0f19c0f647ad9d2b5b5ac68c0f0b4dc899001bd53b3a7537fb247a2/ty-0.0.14-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89d0038a2f698ba8b6fec5cf216a4e44e2f95e4a5095a8c0f57fe549f87087c2", size = 10324371, upload-time = "2026-01-27T00:57:29.291Z" }, + { url = "https://files.pythonhosted.org/packages/24/32/99d0a0b37d0397b0a989ffc2682493286aa3bc252b24004a6714368c2c3d/ty-0.0.14-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c64a83a2d669b77f50a4957039ca1450626fb474619f18f6f8a3eb885bf7544", size = 10865898, upload-time = "2026-01-27T00:57:33.542Z" }, + { url = "https://files.pythonhosted.org/packages/1a/88/30b583a9e0311bb474269cfa91db53350557ebec09002bfc3fb3fc364e8c/ty-0.0.14-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:242488bfb547ef080199f6fd81369ab9cb638a778bb161511d091ffd49c12129", size = 10555777, upload-time = "2026-01-27T00:58:05.853Z" }, + { url = "https://files.pythonhosted.org/packages/cd/a2/cb53fb6325dcf3d40f2b1d0457a25d55bfbae633c8e337bde8ec01a190eb/ty-0.0.14-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4790c3866f6c83a4f424fc7d09ebdb225c1f1131647ba8bdc6fcdc28f09ed0ff", size = 10412913, upload-time = "2026-01-27T00:57:38.834Z" }, + { url = "https://files.pythonhosted.org/packages/42/8f/f2f5202d725ed1e6a4e5ffaa32b190a1fe70c0b1a2503d38515da4130b4c/ty-0.0.14-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:950f320437f96d4ea9a2332bbfb5b68f1c1acd269ebfa4c09b6970cc1565bd9d", size = 9837608, upload-time = "2026-01-27T00:57:55.898Z" }, + { url = "https://files.pythonhosted.org/packages/f7/ba/59a2a0521640c489dafa2c546ae1f8465f92956fede18660653cce73b4c5/ty-0.0.14-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4a0ec3ee70d83887f86925bbc1c56f4628bd58a0f47f6f32ddfe04e1f05466df", size = 9884324, upload-time = "2026-01-27T00:57:46.786Z" }, + { url = "https://files.pythonhosted.org/packages/03/95/8d2a49880f47b638743212f011088552ecc454dd7a665ddcbdabea25772a/ty-0.0.14-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a1a4e6b6da0c58b34415955279eff754d6206b35af56a18bb70eb519d8d139ef", size = 10033537, upload-time = "2026-01-27T00:58:01.149Z" }, + { url = "https://files.pythonhosted.org/packages/e9/40/4523b36f2ce69f92ccf783855a9e0ebbbd0f0bb5cdce6211ee1737159ed3/ty-0.0.14-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:dc04384e874c5de4c5d743369c277c8aa73d1edea3c7fc646b2064b637db4db3", size = 10495910, upload-time = "2026-01-27T00:57:26.691Z" }, + { url = "https://files.pythonhosted.org/packages/08/d5/655beb51224d1bfd4f9ddc0bb209659bfe71ff141bcf05c418ab670698f0/ty-0.0.14-py3-none-win32.whl", hash = "sha256:b20e22cf54c66b3e37e87377635da412d9a552c9bf4ad9fc449fed8b2e19dad2", size = 9507626, upload-time = "2026-01-27T00:57:41.43Z" }, + { url = "https://files.pythonhosted.org/packages/b6/d9/c569c9961760e20e0a4bc008eeb1415754564304fd53997a371b7cf3f864/ty-0.0.14-py3-none-win_amd64.whl", hash = "sha256:e312ff9475522d1a33186657fe74d1ec98e4a13e016d66f5758a452c90ff6409", size = 10437980, upload-time = "2026-01-27T00:57:36.422Z" }, + { url = "https://files.pythonhosted.org/packages/ad/0c/186829654f5bfd9a028f6648e9caeb11271960a61de97484627d24443f91/ty-0.0.14-py3-none-win_arm64.whl", hash = "sha256:b6facdbe9b740cb2c15293a1d178e22ffc600653646452632541d01c36d5e378", size = 9885831, upload-time = "2026-01-27T00:57:49.747Z" }, ] [[package]] From 7828508b3075bd4adcb6b272d43bd799bf96c8f9 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sun, 1 Feb 2026 13:43:14 +0900 Subject: [PATCH 07/32] refactor: remove all reqparser (#29289) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com> --- api/.ruff.toml | 14 +- .../rag_pipeline/rag_pipeline_workflow.py | 19 +- .../console/workspace/tool_providers.py | 676 +++++++++--------- api/controllers/service_api/app/completion.py | 3 +- .../service_api/app/conversation.py | 6 +- api/controllers/service_api/app/message.py | 6 +- .../service_api/dataset/hit_testing.py | 6 +- .../utils/workflow_configuration_sync.py | 5 - .../tools/workflow_tools_manage_service.py | 16 +- .../test_workflow_tools_manage_service.py | 95 +-- 10 files changed, 434 insertions(+), 412 deletions(-) diff --git a/api/.ruff.toml b/api/.ruff.toml index 8db0cbcb21..3301452ad9 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -53,6 +53,7 @@ select = [ "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S302", # suspicious-marshal-usage, disallow use of `marshal` module "S311", # suspicious-non-cryptographic-random-usage, + "TID", # flake8-tidy-imports ] @@ -88,6 +89,7 @@ ignore = [ "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false + "TID252", # allow relative imports from parent modules ] [lint.per-file-ignores] @@ -109,10 +111,20 @@ ignore = [ "S110", # allow ignoring exceptions in tests code (currently) ] +"controllers/console/explore/trial.py" = ["TID251"] +"controllers/console/human_input_form.py" = ["TID251"] +"controllers/web/human_input_form.py" = ["TID251"] [lint.pyflakes] allowed-unused-imports = [ - "_pytest.monkeypatch", "tests.integration_tests", "tests.unit_tests", ] + +[lint.flake8-tidy-imports] + +[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"] +msg = "Use Pydantic payload/query models instead of reqparse." + +[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"] +msg = "Use Pydantic payload/query models instead of reqparse." diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index d34fd5088d..29b6b64b94 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -1,10 +1,9 @@ import json import logging from typing import Any, Literal, cast -from uuid import UUID from flask import abort, request -from flask_restx import Resource, marshal_with, reqparse # type: ignore +from flask_restx import Resource, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -38,7 +37,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory from libs import helper -from libs.helper import TimestampField +from libs.helper import TimestampField, UUIDStrOrEmpty from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline @@ -110,7 +109,7 @@ class NodeIdQuery(BaseModel): class WorkflowRunQuery(BaseModel): - last_id: UUID | None = None + last_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100) @@ -121,6 +120,10 @@ class DatasourceVariablesPayload(BaseModel): start_node_title: str +class RagPipelineRecommendedPluginQuery(BaseModel): + type: str = "all" + + register_schema_models( console_ns, DraftWorkflowSyncPayload, @@ -135,6 +138,7 @@ register_schema_models( NodeIdQuery, WorkflowRunQuery, DatasourceVariablesPayload, + RagPipelineRecommendedPluginQuery, ) @@ -975,11 +979,8 @@ class RagPipelineRecommendedPluginApi(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument("type", type=str, location="args", required=False, default="all") - args = parser.parse_args() - type = args["type"] + query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict()) rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) + recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type) return recommended_plugins diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index e9e7b72718..5bfa895849 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,16 +1,16 @@ import io import logging +from typing import Any, Literal from urllib.parse import urlparse from flask import make_response, redirect, request, send_file -from flask_restx import ( - Resource, - reqparse, -) +from flask_restx import Resource +from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -26,8 +26,9 @@ from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db -from libs.helper import StrLen, alphanumeric, uuid_value +from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID @@ -52,24 +53,209 @@ def is_valid_url(url: str) -> bool: parsed = urlparse(url) return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"] except (ValueError, TypeError): - # ValueError: Invalid URL format - # TypeError: url is not a string return False -parser_tool = reqparse.RequestParser().add_argument( - "type", - type=str, - choices=["builtin", "model", "api", "workflow", "mcp"], - required=False, - nullable=True, - location="args", +class ToolProviderListQuery(BaseModel): + type: Literal["builtin", "model", "api", "workflow", "mcp"] | None = None + + +class BuiltinToolCredentialDeletePayload(BaseModel): + credential_id: str + + +class BuiltinToolAddPayload(BaseModel): + credentials: dict[str, Any] + name: str | None = Field(default=None, max_length=30) + type: CredentialType + + +class BuiltinToolUpdatePayload(BaseModel): + credential_id: str + credentials: dict[str, Any] | None = None + name: str | None = Field(default=None, max_length=30) + + +class ApiToolProviderBasePayload(BaseModel): + credentials: dict[str, Any] + schema_type: ApiProviderSchemaType + schema_: str = Field(alias="schema") + provider: str + icon: dict[str, Any] + privacy_policy: str | None = None + labels: list[str] | None = None + custom_disclaimer: str = "" + + +class ApiToolProviderAddPayload(ApiToolProviderBasePayload): + pass + + +class ApiToolProviderUpdatePayload(ApiToolProviderBasePayload): + original_provider: str + + +class UrlQuery(BaseModel): + url: HttpUrl + + +class ProviderQuery(BaseModel): + provider: str + + +class ApiToolProviderDeletePayload(BaseModel): + provider: str + + +class ApiToolSchemaPayload(BaseModel): + schema_: str = Field(alias="schema") + + +class ApiToolTestPayload(BaseModel): + tool_name: str + provider_name: str | None = None + credentials: dict[str, Any] + parameters: dict[str, Any] + schema_type: ApiProviderSchemaType + schema_: str = Field(alias="schema") + + +class WorkflowToolBasePayload(BaseModel): + name: str + label: str + description: str + icon: dict[str, Any] + parameters: list[WorkflowToolParameterConfiguration] = Field(default_factory=list) + privacy_policy: str | None = "" + labels: list[str] | None = None + + @field_validator("name") + @classmethod + def validate_name(cls, value: str) -> str: + return alphanumeric(value) + + +class WorkflowToolCreatePayload(WorkflowToolBasePayload): + workflow_app_id: str + + @field_validator("workflow_app_id") + @classmethod + def validate_workflow_app_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolUpdatePayload(WorkflowToolBasePayload): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolDeletePayload(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class WorkflowToolGetQuery(BaseModel): + workflow_tool_id: str | None = None + workflow_app_id: str | None = None + + @field_validator("workflow_tool_id", "workflow_app_id") + @classmethod + def validate_ids(cls, value: str | None) -> str | None: + if value is None: + return value + return uuid_value(value) + + @model_validator(mode="after") + def ensure_one(self) -> "WorkflowToolGetQuery": + if not self.workflow_tool_id and not self.workflow_app_id: + raise ValueError("workflow_tool_id or workflow_app_id is required") + return self + + +class WorkflowToolListQuery(BaseModel): + workflow_tool_id: str + + @field_validator("workflow_tool_id") + @classmethod + def validate_workflow_tool_id(cls, value: str) -> str: + return uuid_value(value) + + +class BuiltinProviderDefaultCredentialPayload(BaseModel): + id: str + + +class ToolOAuthCustomClientPayload(BaseModel): + client_params: dict[str, Any] | None = None + enable_oauth_custom_client: bool | None = True + + +class MCPProviderBasePayload(BaseModel): + server_url: str + name: str + icon: str + icon_type: str + icon_background: str = "" + server_identifier: str + configuration: dict[str, Any] | None = Field(default_factory=dict) + headers: dict[str, Any] | None = Field(default_factory=dict) + authentication: dict[str, Any] | None = Field(default_factory=dict) + + +class MCPProviderCreatePayload(MCPProviderBasePayload): + pass + + +class MCPProviderUpdatePayload(MCPProviderBasePayload): + provider_id: str + + +class MCPProviderDeletePayload(BaseModel): + provider_id: str + + +class MCPAuthPayload(BaseModel): + provider_id: str + authorization_code: str | None = None + + +class MCPCallbackQuery(BaseModel): + code: str + state: str + + +register_schema_models( + console_ns, + BuiltinToolCredentialDeletePayload, + BuiltinToolAddPayload, + BuiltinToolUpdatePayload, + ApiToolProviderAddPayload, + ApiToolProviderUpdatePayload, + ApiToolProviderDeletePayload, + ApiToolSchemaPayload, + ApiToolTestPayload, + WorkflowToolCreatePayload, + WorkflowToolUpdatePayload, + WorkflowToolDeletePayload, + BuiltinProviderDefaultCredentialPayload, + ToolOAuthCustomClientPayload, + MCPProviderCreatePayload, + MCPProviderUpdatePayload, + MCPProviderDeletePayload, + MCPAuthPayload, ) @console_ns.route("/workspaces/current/tool-providers") class ToolProviderListApi(Resource): - @console_ns.expect(parser_tool) @setup_required @login_required @account_initialization_required @@ -78,9 +264,10 @@ class ToolProviderListApi(Resource): user_id = user.id - args = parser_tool.parse_args() + raw_args = request.args.to_dict() + query = ToolProviderListQuery.model_validate(raw_args) - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore @console_ns.route("/workspaces/current/tool-provider/builtin//tools") @@ -110,14 +297,9 @@ class ToolBuiltinProviderInfoApi(Resource): return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) -parser_delete = reqparse.RequestParser().add_argument( - "credential_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//delete") class ToolBuiltinProviderDeleteApi(Resource): - @console_ns.expect(parser_delete) + @console_ns.expect(console_ns.models[BuiltinToolCredentialDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -125,26 +307,18 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): _, tenant_id = current_account_with_tenant() - args = parser_delete.parse_args() + payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.delete_builtin_tool_provider( tenant_id, provider, - args["credential_id"], + payload.credential_id, ) -parser_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") - .add_argument("type", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//add") class ToolBuiltinProviderAddApi(Resource): - @console_ns.expect(parser_add) + @console_ns.expect(console_ns.models[BuiltinToolAddPayload.__name__]) @setup_required @login_required @account_initialization_required @@ -153,32 +327,21 @@ class ToolBuiltinProviderAddApi(Resource): user_id = user.id - args = parser_add.parse_args() - - if args["type"] not in CredentialType.values(): - raise ValueError(f"Invalid credential type: {args['type']}") + payload = BuiltinToolAddPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.add_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credentials=args["credentials"], - name=args["name"], - api_type=CredentialType.of(args["type"]), + credentials=payload.credentials, + name=payload.name, + api_type=CredentialType.of(payload.type), ) -parser_update = ( - reqparse.RequestParser() - .add_argument("credential_id", type=str, required=True, nullable=False, location="json") - .add_argument("credentials", type=dict, required=False, nullable=True, location="json") - .add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//update") class ToolBuiltinProviderUpdateApi(Resource): - @console_ns.expect(parser_update) + @console_ns.expect(console_ns.models[BuiltinToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -187,15 +350,15 @@ class ToolBuiltinProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_update.parse_args() + payload = BuiltinToolUpdatePayload.model_validate(console_ns.payload or {}) result = BuiltinToolManageService.update_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, provider=provider, - credential_id=args["credential_id"], - credentials=args.get("credentials", None), - name=args.get("name", ""), + credential_id=payload.credential_id, + credentials=payload.credentials, + name=payload.name or "", ) return result @@ -225,22 +388,9 @@ class ToolBuiltinProviderIconApi(Resource): return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) -parser_api_add = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) - .add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/add") class ToolApiProviderAddApi(Resource): - @console_ns.expect(parser_api_add) + @console_ns.expect(console_ns.models[ApiToolProviderAddPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -250,28 +400,24 @@ class ToolApiProviderAddApi(Resource): user_id = user.id - args = parser_api_add.parse_args() + payload = ApiToolProviderAddPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args["provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args.get("privacy_policy", ""), - args.get("custom_disclaimer", ""), - args.get("labels", []), + payload.provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema_, + payload.privacy_policy or "", + payload.custom_disclaimer or "", + payload.labels or [], ) -parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/remote") class ToolApiProviderGetRemoteSchemaApi(Resource): - @console_ns.expect(parser_remote) @setup_required @login_required @account_initialization_required @@ -280,23 +426,18 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): user_id = user.id - args = parser_remote.parse_args() + raw_args = request.args.to_dict() + query = UrlQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider_remote_schema( user_id, tenant_id, - args["url"], + str(query.url), ) -parser_tools = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/api/tools") class ToolApiProviderListToolsApi(Resource): - @console_ns.expect(parser_tools) @setup_required @login_required @account_initialization_required @@ -305,34 +446,21 @@ class ToolApiProviderListToolsApi(Resource): user_id = user.id - args = parser_tools.parse_args() + raw_args = request.args.to_dict() + query = ProviderQuery.model_validate(raw_args) return jsonable_encoder( ApiToolManageService.list_api_tool_provider_tools( user_id, tenant_id, - args["provider"], + query.provider, ) ) -parser_api_update = ( - reqparse.RequestParser() - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") - .add_argument("provider", type=str, required=True, nullable=False, location="json") - .add_argument("original_provider", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") - .add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/update") class ToolApiProviderUpdateApi(Resource): - @console_ns.expect(parser_api_update) + @console_ns.expect(console_ns.models[ApiToolProviderUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -342,31 +470,26 @@ class ToolApiProviderUpdateApi(Resource): user_id = user.id - args = parser_api_update.parse_args() + payload = ApiToolProviderUpdatePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args["provider"], - args["original_provider"], - args["icon"], - args["credentials"], - args["schema_type"], - args["schema"], - args["privacy_policy"], - args["custom_disclaimer"], - args.get("labels", []), + payload.provider, + payload.original_provider, + payload.icon, + payload.credentials, + payload.schema_type, + payload.schema_, + payload.privacy_policy, + payload.custom_disclaimer, + payload.labels or [], ) -parser_api_delete = reqparse.RequestParser().add_argument( - "provider", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/delete") class ToolApiProviderDeleteApi(Resource): - @console_ns.expect(parser_api_delete) + @console_ns.expect(console_ns.models[ApiToolProviderDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -376,21 +499,17 @@ class ToolApiProviderDeleteApi(Resource): user_id = user.id - args = parser_api_delete.parse_args() + payload = ApiToolProviderDeletePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, - args["provider"], + payload.provider, ) -parser_get = reqparse.RequestParser().add_argument("provider", type=str, required=True, nullable=False, location="args") - - @console_ns.route("/workspaces/current/tool-provider/api/get") class ToolApiProviderGetApi(Resource): - @console_ns.expect(parser_get) @setup_required @login_required @account_initialization_required @@ -399,12 +518,13 @@ class ToolApiProviderGetApi(Resource): user_id = user.id - args = parser_get.parse_args() + raw_args = request.args.to_dict() + query = ProviderQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args["provider"], + query.provider, ) @@ -423,72 +543,43 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): ) -parser_schema = reqparse.RequestParser().add_argument( - "schema", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/api/schema") class ToolApiProviderSchemaApi(Resource): - @console_ns.expect(parser_schema) + @console_ns.expect(console_ns.models[ApiToolSchemaPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_schema.parse_args() + payload = ApiToolSchemaPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.parser_api_schema( - schema=args["schema"], + schema=payload.schema_, ) -parser_pre = ( - reqparse.RequestParser() - .add_argument("tool_name", type=str, required=True, nullable=False, location="json") - .add_argument("provider_name", type=str, required=False, nullable=False, location="json") - .add_argument("credentials", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=dict, required=True, nullable=False, location="json") - .add_argument("schema_type", type=str, required=True, nullable=False, location="json") - .add_argument("schema", type=str, required=True, nullable=False, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/api/test/pre") class ToolApiProviderPreviousTestApi(Resource): - @console_ns.expect(parser_pre) + @console_ns.expect(console_ns.models[ApiToolTestPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_pre.parse_args() + payload = ApiToolTestPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() return ApiToolManageService.test_api_tool_preview( current_tenant_id, - args["provider_name"] or "", - args["tool_name"], - args["credentials"], - args["parameters"], - args["schema_type"], - args["schema"], + payload.provider_name or "", + payload.tool_name, + payload.credentials, + payload.parameters, + payload.schema_type, + payload.schema_, ) -parser_create = ( - reqparse.RequestParser() - .add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/create") class ToolWorkflowProviderCreateApi(Resource): - @console_ns.expect(parser_create) + @console_ns.expect(console_ns.models[WorkflowToolCreatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -498,38 +589,25 @@ class ToolWorkflowProviderCreateApi(Resource): user_id = user.id - args = parser_create.parse_args() + payload = WorkflowToolCreatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.create_workflow_tool( user_id=user_id, tenant_id=tenant_id, - workflow_app_id=args["workflow_app_id"], - name=args["name"], - label=args["label"], - icon=args["icon"], - description=args["description"], - parameters=args["parameters"], - privacy_policy=args["privacy_policy"], - labels=args["labels"], + workflow_app_id=payload.workflow_app_id, + name=payload.name, + label=payload.label, + icon=payload.icon, + description=payload.description, + parameters=payload.parameters, + privacy_policy=payload.privacy_policy or "", + labels=payload.labels or [], ) -parser_workflow_update = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") - .add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") - .add_argument("label", type=str, required=True, nullable=False, location="json") - .add_argument("description", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=dict, required=True, nullable=False, location="json") - .add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") - .add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") - .add_argument("labels", type=list[str], required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/update") class ToolWorkflowProviderUpdateApi(Resource): - @console_ns.expect(parser_workflow_update) + @console_ns.expect(console_ns.models[WorkflowToolUpdatePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -538,33 +616,25 @@ class ToolWorkflowProviderUpdateApi(Resource): user, tenant_id = current_account_with_tenant() user_id = user.id - args = parser_workflow_update.parse_args() - - if not args["workflow_tool_id"]: - raise ValueError("incorrect workflow_tool_id") + payload = WorkflowToolUpdatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], - args["name"], - args["label"], - args["icon"], - args["description"], - args["parameters"], - args["privacy_policy"], - args.get("labels", []), + payload.workflow_tool_id, + payload.name, + payload.label, + payload.icon, + payload.description, + payload.parameters, + payload.privacy_policy or "", + payload.labels or [], ) -parser_workflow_delete = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/delete") class ToolWorkflowProviderDeleteApi(Resource): - @console_ns.expect(parser_workflow_delete) + @console_ns.expect(console_ns.models[WorkflowToolDeletePayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @@ -574,25 +644,17 @@ class ToolWorkflowProviderDeleteApi(Resource): user_id = user.id - args = parser_workflow_delete.parse_args() + payload = WorkflowToolDeletePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args["workflow_tool_id"], + payload.workflow_tool_id, ) -parser_wf_get = ( - reqparse.RequestParser() - .add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") - .add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/get") class ToolWorkflowProviderGetApi(Resource): - @console_ns.expect(parser_wf_get) @setup_required @login_required @account_initialization_required @@ -601,19 +663,20 @@ class ToolWorkflowProviderGetApi(Resource): user_id = user.id - args = parser_wf_get.parse_args() + raw_args = request.args.to_dict() + query = WorkflowToolGetQuery.model_validate(raw_args) - if args.get("workflow_tool_id"): + if query.workflow_tool_id: tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) - elif args.get("workflow_app_id"): + elif query.workflow_app_id: tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args["workflow_app_id"], + query.workflow_app_id, ) else: raise ValueError("incorrect workflow_tool_id or workflow_app_id") @@ -621,14 +684,8 @@ class ToolWorkflowProviderGetApi(Resource): return jsonable_encoder(tool) -parser_wf_tools = reqparse.RequestParser().add_argument( - "workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args" -) - - @console_ns.route("/workspaces/current/tool-provider/workflow/tools") class ToolWorkflowProviderListToolApi(Resource): - @console_ns.expect(parser_wf_tools) @setup_required @login_required @account_initialization_required @@ -637,13 +694,14 @@ class ToolWorkflowProviderListToolApi(Resource): user_id = user.id - args = parser_wf_tools.parse_args() + raw_args = request.args.to_dict() + query = WorkflowToolListQuery.model_validate(raw_args) return jsonable_encoder( WorkflowToolManageService.list_single_workflow_tools( user_id, tenant_id, - args["workflow_tool_id"], + query.workflow_tool_id, ) ) @@ -810,49 +868,39 @@ class ToolOAuthCallback(Resource): return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") -parser_default_cred = reqparse.RequestParser().add_argument( - "id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//default-credential") class ToolBuiltinProviderSetDefaultApi(Resource): - @console_ns.expect(parser_default_cred) + @console_ns.expect(console_ns.models[BuiltinProviderDefaultCredentialPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self, provider): current_user, current_tenant_id = current_account_with_tenant() - args = parser_default_cred.parse_args() + payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.set_default_provider( - tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id ) -parser_custom = ( - reqparse.RequestParser() - .add_argument("client_params", type=dict, required=False, nullable=True, location="json") - .add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/builtin//oauth/custom-client") class ToolOAuthCustomClient(Resource): - @console_ns.expect(parser_custom) + @console_ns.expect(console_ns.models[ToolOAuthCustomClientPayload.__name__]) @setup_required @login_required @is_admin_or_owner_required @account_initialization_required def post(self, provider: str): - args = parser_custom.parse_args() + payload = ToolOAuthCustomClientPayload.model_validate(console_ns.payload or {}) _, tenant_id = current_account_with_tenant() return BuiltinToolManageService.save_custom_oauth_client_params( tenant_id=tenant_id, provider=provider, - client_params=args.get("client_params", {}), - enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), + client_params=payload.client_params or {}, + enable_oauth_custom_client=payload.enable_oauth_custom_client + if payload.enable_oauth_custom_client is not None + else True, ) @setup_required @@ -904,49 +952,19 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): ) -parser_mcp = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_put = ( - reqparse.RequestParser() - .add_argument("server_url", type=str, required=True, nullable=False, location="json") - .add_argument("name", type=str, required=True, nullable=False, location="json") - .add_argument("icon", type=str, required=True, nullable=False, location="json") - .add_argument("icon_type", type=str, required=True, nullable=False, location="json") - .add_argument("icon_background", type=str, required=False, nullable=True, location="json") - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("server_identifier", type=str, required=True, nullable=False, location="json") - .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={}) - .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={}) -) -parser_mcp_delete = reqparse.RequestParser().add_argument( - "provider_id", type=str, required=True, nullable=False, location="json" -) - - @console_ns.route("/workspaces/current/tool-provider/mcp") class ToolProviderMCPApi(Resource): - @console_ns.expect(parser_mcp) + @console_ns.expect(console_ns.models[MCPProviderCreatePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_mcp.parse_args() + payload = MCPProviderCreatePayload.model_validate(console_ns.payload or {}) user, tenant_id = current_account_with_tenant() # Parse and validate models - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None # 1) Create provider in a short transaction (no network I/O inside) with session_factory.create_session() as session, session.begin(): @@ -954,13 +972,13 @@ class ToolProviderMCPApi(Resource): result = service.create_provider( tenant_id=tenant_id, user_id=user.id, - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background, + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, ) @@ -969,8 +987,8 @@ class ToolProviderMCPApi(Resource): # Perform network I/O outside any DB session to avoid holding locks. try: reconnect = MCPToolManageService.reconnect_with_url( - server_url=args["server_url"], - headers=args.get("headers") or {}, + server_url=payload.server_url, + headers=payload.headers or {}, timeout=configuration.timeout, sse_read_timeout=configuration.sse_read_timeout, ) @@ -988,14 +1006,14 @@ class ToolProviderMCPApi(Resource): return jsonable_encoder(result) - @console_ns.expect(parser_mcp_put) + @console_ns.expect(console_ns.models[MCPProviderUpdatePayload.__name__]) @setup_required @login_required @account_initialization_required def put(self): - args = parser_mcp_put.parse_args() - configuration = MCPConfiguration.model_validate(args["configuration"]) - authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None + payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {}) + configuration = MCPConfiguration.model_validate(payload.configuration or {}) + authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None _, current_tenant_id = current_account_with_tenant() # Step 1: Get provider data for URL validation (short-lived session, no network I/O) @@ -1003,14 +1021,14 @@ class ToolProviderMCPApi(Resource): with Session(db.engine) as session: service = MCPToolManageService(session=session) validation_data = service.get_provider_for_url_validation( - tenant_id=current_tenant_id, provider_id=args["provider_id"] + tenant_id=current_tenant_id, provider_id=payload.provider_id ) # Step 2: Perform URL validation with network I/O OUTSIDE of any database session # This prevents holding database locks during potentially slow network operations validation_result = MCPToolManageService.validate_server_url_standalone( tenant_id=current_tenant_id, - new_server_url=args["server_url"], + new_server_url=payload.server_url, validation_data=validation_data, ) @@ -1019,14 +1037,14 @@ class ToolProviderMCPApi(Resource): service = MCPToolManageService(session=session) service.update_provider( tenant_id=current_tenant_id, - provider_id=args["provider_id"], - server_url=args["server_url"], - name=args["name"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - server_identifier=args["server_identifier"], - headers=args["headers"], + provider_id=payload.provider_id, + server_url=payload.server_url, + name=payload.name, + icon=payload.icon, + icon_type=payload.icon_type, + icon_background=payload.icon_background, + server_identifier=payload.server_identifier, + headers=payload.headers or {}, configuration=configuration, authentication=authentication, validation_result=validation_result, @@ -1034,37 +1052,30 @@ class ToolProviderMCPApi(Resource): return {"result": "success"} - @console_ns.expect(parser_mcp_delete) + @console_ns.expect(console_ns.models[MCPProviderDeletePayload.__name__]) @setup_required @login_required @account_initialization_required def delete(self): - args = parser_mcp_delete.parse_args() + payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): service = MCPToolManageService(session=session) - service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"]) + service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id) return {"result": "success"} -parser_auth = ( - reqparse.RequestParser() - .add_argument("provider_id", type=str, required=True, nullable=False, location="json") - .add_argument("authorization_code", type=str, required=False, nullable=True, location="json") -) - - @console_ns.route("/workspaces/current/tool-provider/mcp/auth") class ToolMCPAuthApi(Resource): - @console_ns.expect(parser_auth) + @console_ns.expect(console_ns.models[MCPAuthPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - args = parser_auth.parse_args() - provider_id = args["provider_id"] + payload = MCPAuthPayload.model_validate(console_ns.payload or {}) + provider_id = payload.provider_id _, tenant_id = current_account_with_tenant() with Session(db.engine) as session, session.begin(): @@ -1102,7 +1113,7 @@ class ToolMCPAuthApi(Resource): # Pass the extracted OAuth metadata hints to auth() auth_result = auth( provider_entity, - args.get("authorization_code"), + payload.authorization_code, resource_metadata_url=e.resource_metadata_url, scope_hint=e.scope_hint, ) @@ -1167,20 +1178,13 @@ class ToolMCPUpdateApi(Resource): return jsonable_encoder(tools) -parser_cb = ( - reqparse.RequestParser() - .add_argument("code", type=str, required=True, nullable=False, location="args") - .add_argument("state", type=str, required=True, nullable=False, location="args") -) - - @console_ns.route("/mcp/oauth/callback") class ToolMCPCallbackApi(Resource): - @console_ns.expect(parser_cb) def get(self): - args = parser_cb.parse_args() - state_key = args["state"] - authorization_code = args["code"] + raw_args = request.args.to_dict() + query = MCPCallbackQuery.model_validate(raw_args) + state_key = query.state + authorization_code = query.code # Create service instance for handle_callback with Session(db.engine) as session, session.begin(): diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index b3836f3a47..9d8431f066 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -30,6 +30,7 @@ from core.errors.error import ( from core.helper.trace_id_helper import get_external_trace_id from core.model_runtime.errors.invoke import InvokeError from libs import helper +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService @@ -52,7 +53,7 @@ class ChatRequestPayload(BaseModel): query: str files: list[dict[str, Any]] | None = None response_mode: Literal["blocking", "streaming"] | None = None - conversation_id: str | None = Field(default=None, description="Conversation UUID") + conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID") retriever_from: str = Field(default="dev") auto_generate_name: bool = Field(default=True, description="Auto generate conversation name") workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 62e8258e25..8e29c9ff0f 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,4 @@ from typing import Any, Literal -from uuid import UUID from flask import request from flask_restx import Resource @@ -23,12 +22,13 @@ from fields.conversation_variable_fields import ( build_conversation_variable_infinite_scroll_pagination_model, build_conversation_variable_model, ) +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService class ConversationListQuery(BaseModel): - last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination") + last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return") sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field( default="-updated_at", description="Sort order for conversations" @@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel): class ConversationVariablesQuery(BaseModel): - last_id: UUID | None = Field(default=None, description="Last variable ID for pagination") + last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination") limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return") variable_name: str | None = Field( default=None, description="Filter variables by name", min_length=1, max_length=255 diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 8981bbd7d5..2aaf920efb 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,6 +1,5 @@ import logging from typing import Literal -from uuid import UUID from flask import request from flask_restx import Resource @@ -15,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem +from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -27,8 +27,8 @@ logger = logging.getLogger(__name__) class MessageListQuery(BaseModel): - conversation_id: UUID - first_id: UUID | None = None + conversation_id: UUIDStrOrEmpty + first_id: UUIDStrOrEmpty | None = None limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return") diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 8dbb690901..97a70f5d0e 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,7 +1,10 @@ -from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase +from controllers.common.schema import register_schema_model +from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +register_schema_model(service_api_ns, HitTestingPayload) + @service_api_ns.route("/datasets//hit-testing", "/datasets//retrieve") class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): @@ -15,6 +18,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): 404: "Dataset not found", } ) + @service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__]) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Perform hit testing on a dataset. diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 188da0c32d..6d75df3603 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -7,11 +7,6 @@ from core.workflow.nodes.base.entities import OutputVariableEntity class WorkflowToolConfigurationUtils: - @classmethod - def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): - for configuration in configurations: - WorkflowToolParameterConfiguration.model_validate(configuration) - @classmethod def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]: """ diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index ab5d5480df..6d84d4e250 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,8 +1,6 @@ import json import logging -from collections.abc import Mapping from datetime import datetime -from typing import Any from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -10,8 +8,8 @@ from sqlalchemy.orm import Session from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db @@ -38,12 +36,10 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -75,7 +71,7 @@ class WorkflowToolManageService: label=label, icon=json.dumps(icon), description=description, - parameter_configuration=json.dumps(parameters), + parameter_configuration=json.dumps([p.model_dump() for p in parameters]), privacy_policy=privacy_policy, version=workflow.version, ) @@ -104,7 +100,7 @@ class WorkflowToolManageService: label: str, icon: dict, description: str, - parameters: list[Mapping[str, Any]], + parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): @@ -122,8 +118,6 @@ class WorkflowToolManageService: :param labels: labels :return: the updated tool """ - WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) - # check if the name is unique existing_workflow_tool_provider = ( db.session.query(WorkflowToolProvider) @@ -162,7 +156,7 @@ class WorkflowToolManageService: workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps(parameters) + workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) workflow_tool_provider.privacy_policy = privacy_policy workflow_tool_provider.version = workflow.version workflow_tool_provider.updated_at = datetime.now() diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 3d46735a1a..3c0a660e7c 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -3,7 +3,9 @@ from unittest.mock import patch import pytest from faker import Faker +from pydantic import ValidationError +from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from models.tools import WorkflowToolProvider from models.workflow import Workflow as WorkflowModel from services.account_service import AccountService, TenantService @@ -130,20 +132,24 @@ class TestWorkflowToolManageService: def _create_test_workflow_tool_parameters(self): """Helper method to create valid workflow tool parameters.""" return [ - { - "name": "input_text", - "description": "Input text for processing", - "form": "form", - "type": "string", - "required": True, - }, - { - "name": "output_format", - "description": "Output format specification", - "form": "form", - "type": "select", - "required": False, - }, + WorkflowToolParameterConfiguration.model_validate( + { + "name": "input_text", + "description": "Input text for processing", + "form": "form", + "type": "string", + "required": True, + } + ), + WorkflowToolParameterConfiguration.model_validate( + { + "name": "output_format", + "description": "Output format specification", + "form": "form", + "type": "select", + "required": False, + } + ), ] def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): @@ -208,7 +214,7 @@ class TestWorkflowToolManageService: assert created_tool_provider.label == tool_label assert created_tool_provider.icon == json.dumps(tool_icon) assert created_tool_provider.description == tool_description - assert created_tool_provider.parameter_configuration == json.dumps(tool_parameters) + assert created_tool_provider.parameter_configuration == json.dumps([p.model_dump() for p in tool_parameters]) assert created_tool_provider.privacy_policy == tool_privacy_policy assert created_tool_provider.version == workflow.version assert created_tool_provider.user_id == account.id @@ -353,18 +359,9 @@ class TestWorkflowToolManageService: app, account, workflow = self._create_test_app_and_account( db_session_with_containers, mock_external_service_dependencies ) - - # Setup invalid workflow tool parameters (missing required fields) - invalid_parameters = [ - { - "name": "input_text", - # Missing description and form fields - "type": "string", - "required": True, - } - ] # Attempt to create workflow tool with invalid parameters - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValidationError) as exc_info: + # Setup invalid workflow tool parameters (missing required fields) WorkflowToolManageService.create_workflow_tool( user_id=account.id, tenant_id=account.current_tenant.id, @@ -373,7 +370,16 @@ class TestWorkflowToolManageService: label=fake.word(), icon={"type": "emoji", "emoji": "🔧"}, description=fake.text(max_nb_chars=200), - parameters=invalid_parameters, + parameters=[ + WorkflowToolParameterConfiguration.model_validate( + { + "name": "input_text", + # Missing description and form fields + "type": "string", + "required": True, + } + ) + ], ) # Verify error message contains validation error @@ -579,11 +585,12 @@ class TestWorkflowToolManageService: # Verify database state was updated db.session.refresh(created_tool) + assert created_tool is not None assert created_tool.name == updated_tool_name assert created_tool.label == updated_tool_label assert created_tool.icon == json.dumps(updated_tool_icon) assert created_tool.description == updated_tool_description - assert created_tool.parameter_configuration == json.dumps(updated_tool_parameters) + assert created_tool.parameter_configuration == json.dumps([p.model_dump() for p in updated_tool_parameters]) assert created_tool.privacy_policy == updated_tool_privacy_policy assert created_tool.version == workflow.version assert created_tool.updated_at is not None @@ -750,13 +757,15 @@ class TestWorkflowToolManageService: # Setup workflow tool parameters with FILE type file_parameters = [ - { - "name": "document", - "description": "Upload a document", - "form": "form", - "type": "file", - "required": False, - } + WorkflowToolParameterConfiguration.model_validate( + { + "name": "document", + "description": "Upload a document", + "form": "form", + "type": "file", + "required": False, + } + ) ] # Execute the method under test @@ -823,13 +832,15 @@ class TestWorkflowToolManageService: # Setup workflow tool parameters with FILES type files_parameters = [ - { - "name": "documents", - "description": "Upload multiple documents", - "form": "form", - "type": "files", - "required": False, - } + WorkflowToolParameterConfiguration.model_validate( + { + "name": "documents", + "description": "Upload multiple documents", + "form": "form", + "type": "files", + "required": False, + } + ) ] # Execute the method under test From 3216b67bfa93417814dbdf05fd8d0e174df3c627 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sun, 1 Feb 2026 19:25:54 +0900 Subject: [PATCH 08/32] refactor: examples of use match case (#31312) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/annotation.py | 9 +-- api/controllers/console/auth/oauth_server.py | 66 ++++++++++---------- 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 6a4c1528b0..a07145ce9f 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -107,10 +107,11 @@ class AnnotationReplyActionApi(Resource): def post(self, app_id, action: Literal["enable", "disable"]): app_id = str(app_id) args = AnnotationReplyPayload.model_validate(console_ns.payload) - if action == "enable": - result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) - elif action == "disable": - result = AppAnnotationService.disable_app_annotation(app_id) + match action: + case "enable": + result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) + case "disable": + result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 6162d88a0b..38ea5d2dae 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource): grant_type = OAuthGrantType(payload.grant_type) except ValueError: raise BadRequest("invalid grant_type") + match grant_type: + case OAuthGrantType.AUTHORIZATION_CODE: + if not payload.code: + raise BadRequest("code is required") - if grant_type == OAuthGrantType.AUTHORIZATION_CODE: - if not payload.code: - raise BadRequest("code is required") + if payload.client_secret != oauth_provider_app.client_secret: + raise BadRequest("client_secret is invalid") - if payload.client_secret != oauth_provider_app.client_secret: - raise BadRequest("client_secret is invalid") + if payload.redirect_uri not in oauth_provider_app.redirect_uris: + raise BadRequest("redirect_uri is invalid") - if payload.redirect_uri not in oauth_provider_app.redirect_uris: - raise BadRequest("redirect_uri is invalid") + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, code=payload.code, client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) + case OAuthGrantType.REFRESH_TOKEN: + if not payload.refresh_token: + raise BadRequest("refresh_token is required") - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, code=payload.code, client_id=oauth_provider_app.client_id - ) - return jsonable_encoder( - { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, - "refresh_token": refresh_token, - } - ) - elif grant_type == OAuthGrantType.REFRESH_TOKEN: - if not payload.refresh_token: - raise BadRequest("refresh_token is required") - - access_token, refresh_token = OAuthServerService.sign_oauth_access_token( - grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id - ) - return jsonable_encoder( - { - "access_token": access_token, - "token_type": "Bearer", - "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, - "refresh_token": refresh_token, - } - ) + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id + ) + return jsonable_encoder( + { + "access_token": access_token, + "token_type": "Bearer", + "expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN, + "refresh_token": refresh_token, + } + ) @console_ns.route("/oauth/provider/account") From 4f826b4641f44a8e5c1185ee09455dcd5bff4042 Mon Sep 17 00:00:00 2001 From: yyh <92089059+lyzno1@users.noreply.github.com> Date: Mon, 2 Feb 2026 09:41:34 +0800 Subject: [PATCH 09/32] refactor(typing): use enum types for workflow status fields (#31792) --- .../app/apps/common/workflow_response_converter.py | 10 +++++----- api/core/app/entities/task_entities.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 38ecec5d30..cefff7be92 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -250,7 +250,7 @@ class WorkflowResponseConverter: data=WorkflowFinishStreamResponse.Data( id=run_id, workflow_id=workflow_id, - status=status.value, + status=status, outputs=encoded_outputs, error=error, elapsed_time=elapsed_time, @@ -340,13 +340,13 @@ class WorkflowResponseConverter: metadata = self._merge_metadata(event.execution_metadata, snapshot) if isinstance(event, QueueNodeSucceededEvent): - status = WorkflowNodeExecutionStatus.SUCCEEDED.value + status = WorkflowNodeExecutionStatus.SUCCEEDED error_message = event.error elif isinstance(event, QueueNodeFailedEvent): - status = WorkflowNodeExecutionStatus.FAILED.value + status = WorkflowNodeExecutionStatus.FAILED error_message = event.error else: - status = WorkflowNodeExecutionStatus.EXCEPTION.value + status = WorkflowNodeExecutionStatus.EXCEPTION error_message = event.error return NodeFinishStreamResponse( @@ -413,7 +413,7 @@ class WorkflowResponseConverter: process_data_truncated=process_data_truncated, outputs=outputs, outputs_truncated=outputs_truncated, - status=WorkflowNodeExecutionStatus.RETRY.value, + status=WorkflowNodeExecutionStatus.RETRY, error=event.error, elapsed_time=elapsed_time, execution_metadata=metadata, diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 79a5e657b3..26fb17ccef 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class AnnotationReplyAccount(BaseModel): @@ -223,7 +223,7 @@ class WorkflowFinishStreamResponse(StreamResponse): id: str workflow_id: str - status: str + status: WorkflowExecutionStatus outputs: Mapping[str, Any] | None = None error: str | None = None elapsed_time: float @@ -311,7 +311,7 @@ class NodeFinishStreamResponse(StreamResponse): process_data_truncated: bool = False outputs: Mapping[str, Any] | None = None outputs_truncated: bool = True - status: str + status: WorkflowNodeExecutionStatus error: str | None = None elapsed_time: float execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None @@ -375,7 +375,7 @@ class NodeRetryStreamResponse(StreamResponse): process_data_truncated: bool = False outputs: Mapping[str, Any] | None = None outputs_truncated: bool = False - status: str + status: WorkflowNodeExecutionStatus error: str | None = None elapsed_time: float execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None @@ -719,7 +719,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str - status: str + status: WorkflowExecutionStatus outputs: Mapping[str, Any] | None = None error: str | None = None elapsed_time: float From 41177757e64a174abcdf577079e45936294e19d4 Mon Sep 17 00:00:00 2001 From: FFXN <31929997+FFXN@users.noreply.github.com> Date: Mon, 2 Feb 2026 09:45:17 +0800 Subject: [PATCH 10/32] fix: summary index bug (#31810) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com> Co-authored-by: zxhlyh Co-authored-by: Yansong Zhang <916125788@qq.com> Co-authored-by: hj24 Co-authored-by: CodingOnStar Co-authored-by: CodingOnStar Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../console/datasets/datasets_document.py | 12 ++++++ api/core/indexing_runner.py | 4 +- api/core/llm_generator/prompts.py | 4 +- .../index_processor/index_processor_base.py | 12 +++++- .../processor/paragraph_index_processor.py | 31 ++++++++++++-- .../processor/parent_child_index_processor.py | 8 +++- .../processor/qa_index_processor.py | 6 ++- .../knowledge_index/knowledge_index_node.py | 13 ++++++ api/services/dataset_service.py | 41 +++++++++++++++++++ .../rag_pipeline_transform_service.py | 4 ++ api/services/summary_index_service.py | 11 ++++- 11 files changed, 137 insertions(+), 9 deletions(-) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 6e3c0db8a3..6a0c9e5f77 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1339,6 +1339,18 @@ class DocumentGenerateSummaryApi(Resource): missing_ids = set(document_list) - found_ids raise NotFound(f"Some documents not found: {list(missing_ids)}") + # Update need_summary to True for documents that don't have it set + # This handles the case where documents were created when summary_index_setting was disabled + documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"] + + if documents_to_update: + document_ids_to_update = [str(doc.id) for doc in documents_to_update] + DocumentService.update_documents_need_summary( + dataset_id=dataset_id, + document_ids=document_ids_to_update, + need_summary=True, + ) + # Dispatch async tasks for each document for document in documents: # Skip qa_model documents as they don't generate summaries diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index e172e88298..61f168a26f 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -369,7 +369,9 @@ class IndexingRunner: # Generate summary preview summary_index_setting = tmp_processing_rule.get("summary_index_setting") if summary_index_setting and summary_index_setting.get("enable") and preview_texts: - preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting) + preview_texts = index_processor.generate_summary_preview( + tenant_id, preview_texts, summary_index_setting, doc_language + ) return IndexingEstimate(total_segments=total_segments, preview=preview_texts) diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index d46cf049dd..ee9a016c95 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -441,11 +441,13 @@ DEFAULT_GENERATOR_SUMMARY_PROMPT = ( Requirements: 1. Write a concise summary in plain text -2. Use the same language as the input content +2. You must write in {language}. No language other than {language} should be used. 3. Focus on important facts, concepts, and details 4. If images are included, describe their key information 5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions" 6. Write directly without extra words +7. If there is not enough content to generate a meaningful summary, + return an empty string without any explanation or prompt Output only the summary text. Start summarizing now: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 151a3de7d9..6e76321ea0 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -48,12 +48,22 @@ class BaseIndexProcessor(ABC): @abstractmethod def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each segment in preview_texts, generate a summary using LLM and attach it to the segment. The summary can be stored in a new attribute, e.g., summary. This method should be implemented by subclasses. + + Args: + tenant_id: Tenant ID + preview_texts: List of preview details to generate summaries for + summary_index_setting: Summary index configuration + doc_language: Optional document language to ensure summary is generated in the correct language """ raise NotImplementedError diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index ab91e29145..41d7656f8a 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -275,7 +275,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("Chunks is not a list") def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each segment, concurrently call generate_summary to generate a summary @@ -298,11 +302,15 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if flask_app: # Ensure Flask app context in worker thread with flask_app.app_context(): - summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + summary, _ = self.generate_summary( + tenant_id, preview.content, summary_index_setting, document_language=doc_language + ) preview.summary = summary else: # Fallback: try without app context (may fail) - summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting) + summary, _ = self.generate_summary( + tenant_id, preview.content, summary_index_setting, document_language=doc_language + ) preview.summary = summary # Generate summaries concurrently using ThreadPoolExecutor @@ -356,6 +364,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): text: str, summary_index_setting: dict | None = None, segment_id: str | None = None, + document_language: str | None = None, ) -> tuple[str, LLMUsage]: """ Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt, @@ -366,6 +375,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor): text: Text content to summarize summary_index_setting: Summary index configuration segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table + document_language: Optional document language (e.g., "Chinese", "English") + to ensure summary is generated in the correct language Returns: Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object @@ -381,8 +392,22 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("model_name and model_provider_name are required in summary_index_setting") # Import default summary prompt + is_default_prompt = False if not summary_prompt: summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT + is_default_prompt = True + + # Format prompt with document language only for default prompt + # Custom prompts are used as-is to avoid interfering with user-defined templates + # If document_language is provided, use it; otherwise, use "the same language as the input content" + # This is especially important for image-only chunks where text is empty or minimal + if is_default_prompt: + language_for_prompt = document_language or "the same language as the input content" + try: + summary_prompt = summary_prompt.format(language=language_for_prompt) + except KeyError: + # If default prompt doesn't have {language} placeholder, use it as-is + pass provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 961df2e50c..0ea77405ed 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -358,7 +358,11 @@ class ParentChildIndexProcessor(BaseIndexProcessor): } def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary @@ -389,6 +393,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): tenant_id=tenant_id, text=preview.content, summary_index_setting=summary_index_setting, + document_language=doc_language, ) preview.summary = summary else: @@ -397,6 +402,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): tenant_id=tenant_id, text=preview.content, summary_index_setting=summary_index_setting, + document_language=doc_language, ) preview.summary = summary diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 272d2ed351..40d9caaa69 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -241,7 +241,11 @@ class QAIndexProcessor(BaseIndexProcessor): } def generate_summary_preview( - self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict + self, + tenant_id: str, + preview_texts: list[PreviewDetail], + summary_index_setting: dict, + doc_language: str | None = None, ) -> list[PreviewDetail]: """ QA model doesn't generate summaries, so this method returns preview_texts unchanged. diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index b88c2d510f..2aff953bc6 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -78,12 +78,21 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): indexing_technique = node_data.indexing_technique or dataset.indexing_technique summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting + # Try to get document language if document_id is available + doc_language = None + document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + if document_id: + document = db.session.query(Document).filter_by(id=document_id.value).first() + if document and document.doc_language: + doc_language = document.doc_language + outputs = self._get_preview_output_with_summaries( node_data.chunk_structure, chunks, dataset=dataset, indexing_technique=indexing_technique, summary_index_setting=summary_index_setting, + doc_language=doc_language, ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -315,6 +324,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): dataset: Dataset, indexing_technique: str | None = None, summary_index_setting: dict | None = None, + doc_language: str | None = None, ) -> Mapping[str, Any]: """ Generate preview output with summaries for chunks in preview mode. @@ -326,6 +336,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): dataset: Dataset object (for tenant_id) indexing_technique: Indexing technique from node config or dataset summary_index_setting: Summary index setting from node config or dataset + doc_language: Optional document language to ensure summary is generated in the correct language """ index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() preview_output = index_processor.format_preview(chunks) @@ -365,6 +376,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): tenant_id=dataset.tenant_id, text=preview_item["content"], summary_index_setting=summary_index_setting, + document_language=doc_language, ) if summary: preview_item["summary"] = summary @@ -374,6 +386,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): tenant_id=dataset.tenant_id, text=preview_item["content"], summary_index_setting=summary_index_setting, + document_language=doc_language, ) if summary: preview_item["summary"] = summary diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 0b3fcbe4ae..16945fca6a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config +from core.db.session_factory import session_factory from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.file import helpers as file_helpers from core.helper.name_generator import generate_incremental_name @@ -1388,6 +1389,46 @@ class DocumentService: ).all() return documents + @staticmethod + def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int: + """ + Update need_summary field for multiple documents. + + This method handles the case where documents were created when summary_index_setting was disabled, + and need to be updated when summary_index_setting is later enabled. + + Args: + dataset_id: Dataset ID + document_ids: List of document IDs to update + need_summary: Value to set for need_summary field (default: True) + + Returns: + Number of documents updated + """ + if not document_ids: + return 0 + + document_id_list: list[str] = [str(document_id) for document_id in document_ids] + + with session_factory.create_session() as session: + updated_count = ( + session.query(Document) + .filter( + Document.id.in_(document_id_list), + Document.dataset_id == dataset_id, + Document.doc_form != "qa_model", # Skip qa_model documents + ) + .update({Document.need_summary: need_summary}, synchronize_session=False) + ) + session.commit() + logger.info( + "Updated need_summary to %s for %d documents in dataset %s", + need_summary, + updated_count, + dataset_id, + ) + return updated_count + @staticmethod def get_document_download_url(document: Document) -> str: """ diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 8ea365e907..d0dfbc1070 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -174,6 +174,10 @@ class RagPipelineTransformService: else: dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() + # Copy summary_index_setting from dataset to knowledge_index node configuration + if dataset.summary_index_setting: + knowledge_configuration.summary_index_setting = dataset.summary_index_setting + knowledge_configuration_dict.update(knowledge_configuration.model_dump()) node["data"] = knowledge_configuration_dict return node diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index b8e1f8bc3f..7c03ceed5b 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -49,11 +49,18 @@ class SummaryIndexService: # Use lazy import to avoid circular import from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor + # Get document language to ensure summary is generated in the correct language + # This is especially important for image-only chunks where text is empty or minimal + document_language = None + if segment.document and segment.document.doc_language: + document_language = segment.document.doc_language + summary_content, usage = ParagraphIndexProcessor.generate_summary( tenant_id=dataset.tenant_id, text=segment.content, summary_index_setting=summary_index_setting, segment_id=segment.id, + document_language=document_language, ) if not summary_content: @@ -558,6 +565,9 @@ class SummaryIndexService: ) session.add(summary_record) + # Commit the batch created records + session.commit() + @staticmethod def update_summary_record_error( segment: DocumentSegment, @@ -762,7 +772,6 @@ class SummaryIndexService: dataset=dataset, status="not_started", ) - session.commit() # Commit initial records summary_records = [] From 603a896c496295042235dad1e94fbad74b21d927 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 2 Feb 2026 11:12:04 +0800 Subject: [PATCH 11/32] chore(CODEOWNERS): assign `.agents/skills` to @hyoban (#31816) Signed-off-by: -LAN- --- .github/CODEOWNERS | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 106c26bbed..36fa39b5d7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,6 +9,9 @@ # CODEOWNERS file /.github/CODEOWNERS @laipz8200 @crazywoola +# Agents +/.agents/skills/ @hyoban + # Docs /docs/ @crazywoola From 9fb72c151cadf84d5c7353baf7076b43f2f5a952 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 2 Feb 2026 11:18:18 +0800 Subject: [PATCH 12/32] refactor: "chore: update version to 1.12.0" (#31817) --- api/pyproject.toml | 2 +- api/uv.lock | 2 +- docker/docker-compose-template.yaml | 8 ++++---- docker/docker-compose.yaml | 8 ++++---- web/package.json | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 97e6c83ed6..02d1aea21d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.12.0" +version = "1.11.4" requires-python = ">=3.11,<3.13" dependencies = [ diff --git a/api/uv.lock b/api/uv.lock index 04d9a7c021..ad84b35212 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1368,7 +1368,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.12.0" +version = "1.11.4" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index e27b51bcc0..eb8c2b53c5 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.12.0 + image: langgenius/dify-web:1.11.4 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index a0a755f570..02b8146aa9 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -707,7 +707,7 @@ services: # API service api: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -749,7 +749,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -788,7 +788,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.11.4 restart: always environment: # Use the shared environment variables. @@ -818,7 +818,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.12.0 + image: langgenius/dify-web:1.11.4 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/package.json b/web/package.json index 954366fc89..83a4f98dee 100644 --- a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "dify-web", "type": "module", - "version": "1.12.0", + "version": "1.11.4", "private": true, "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a", "imports": { From 840a975fef42b965700699928e9a02fe3e2383b4 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 14:54:16 +0900 Subject: [PATCH 13/32] =?UTF-8?q?refactor:=20add=20test=20for=20api/contro?= =?UTF-8?q?llers/console/workspace/tool=5Fpr=E2=80=A6=20(#29886)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../console/workspace/test_tool_providers.py | 364 ++++++++++++++++++ 1 file changed, 364 insertions(+) create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py new file mode 100644 index 0000000000..94c3019d5e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -0,0 +1,364 @@ +"""Endpoint tests for controllers.console.workspace.tool_providers.""" + +from __future__ import annotations + +import builtins +import importlib +from contextlib import contextmanager +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from flask.views import MethodView + +if not hasattr(builtins, "MethodView"): + builtins.MethodView = MethodView # type: ignore[attr-defined] + + +_CONTROLLER_MODULE: ModuleType | None = None +_WRAPS_MODULE: ModuleType | None = None +_CONTROLLER_PATCHERS: list[patch] = [] + + +@contextmanager +def _mock_db(): + mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True)) + with patch("extensions.ext_database.db.session", mock_session): + yield + + +@pytest.fixture +def app() -> Flask: + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture +def controller_module(monkeypatch: pytest.MonkeyPatch): + module_name = "controllers.console.workspace.tool_providers" + global _CONTROLLER_MODULE + if _CONTROLLER_MODULE is None: + + def _noop(func): + return func + + patch_targets = [ + ("libs.login.login_required", _noop), + ("controllers.console.wraps.setup_required", _noop), + ("controllers.console.wraps.account_initialization_required", _noop), + ("controllers.console.wraps.is_admin_or_owner_required", _noop), + ("controllers.console.wraps.enterprise_license_required", _noop), + ] + for target, value in patch_targets: + patcher = patch(target, value) + patcher.start() + _CONTROLLER_PATCHERS.append(patcher) + monkeypatch.setenv("DIFY_SETUP_READY", "true") + with _mock_db(): + _CONTROLLER_MODULE = importlib.import_module(module_name) + + module = _CONTROLLER_MODULE + monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) + + # Ensure decorators that consult deployment edition do not reach the database. + global _WRAPS_MODULE + wraps_module = importlib.import_module("controllers.console.wraps") + _WRAPS_MODULE = wraps_module + monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD") + monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD") + + login_module = importlib.import_module("libs.login") + monkeypatch.setattr(login_module, "check_csrf_token", lambda *args, **kwargs: None) + return module + + +def _mock_account(user_id: str = "user-123") -> SimpleNamespace: + return SimpleNamespace(id=user_id, status="active", is_authenticated=True, current_tenant_id=None) + + +def _set_current_account( + monkeypatch: pytest.MonkeyPatch, + controller_module: ModuleType, + user: SimpleNamespace, + tenant_id: str, +) -> None: + def _getter(): + return user, tenant_id + + user.current_tenant_id = tenant_id + + monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter) + if _WRAPS_MODULE is not None: + monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter) + + login_module = importlib.import_module("libs.login") + monkeypatch.setattr(login_module, "_get_user", lambda: user) + + +def test_tool_provider_list_calls_service_with_query( + app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch +): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-456") + + service_mock = MagicMock(return_value=[{"provider": "builtin"}]) + monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock) + + with app.test_request_context("/workspaces/current/tool-providers?type=builtin"): + response = controller_module.ToolProviderListApi().get() + + assert response == [{"provider": "builtin"}] + service_mock.assert_called_once_with(user.id, "tenant-456", "builtin") + + +def test_builtin_provider_add_passes_payload( + app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch +): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-456") + + service_mock = MagicMock(return_value={"status": "ok"}) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock) + + payload = { + "credentials": {"api_key": "sk-test"}, + "name": "MyTool", + "type": controller_module.CredentialType.API_KEY, + } + + with app.test_request_context( + "/workspaces/current/tool-provider/builtin/openai/add", + method="POST", + json=payload, + ): + response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai") + + assert response == {"status": "ok"} + service_mock.assert_called_once_with( + user_id="user-123", + tenant_id="tenant-456", + provider="openai", + credentials={"api_key": "sk-test"}, + name="MyTool", + api_type=controller_module.CredentialType.API_KEY, + ) + + +def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-789") + _set_current_account(monkeypatch, controller_module, user, "tenant-789") + + service_mock = MagicMock(return_value=[{"name": "tool-a"}]) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock) + monkeypatch.setattr(controller_module, "jsonable_encoder", lambda payload: payload) + + with app.test_request_context( + "/workspaces/current/tool-provider/builtin/my-provider/tools", + method="GET", + ): + response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider") + + assert response == [{"name": "tool-a"}] + service_mock.assert_called_once_with("tenant-789", "my-provider") + + +def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-9") + _set_current_account(monkeypatch, controller_module, user, "tenant-9") + service_mock = MagicMock(return_value={"info": True}) + monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock) + + with app.test_request_context("/info", method="GET"): + resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo") + + assert resp == {"info": True} + service_mock.assert_called_once_with("tenant-9", "demo") + + +def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-cred") + _set_current_account(monkeypatch, controller_module, user, "tenant-cred") + service_mock = MagicMock(return_value=[{"cred": 1}]) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "get_builtin_tool_provider_credentials", + service_mock, + ) + + with app.test_request_context("/creds", method="GET"): + resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo") + + assert resp == [{"cred": 1}] + service_mock.assert_called_once_with(tenant_id="tenant-cred", provider_name="demo") + + +def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-10") + service_mock = MagicMock(return_value={"schema": "ok"}) + monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock) + + with app.test_request_context("/remote?url=https://example.com/"): + resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get() + + assert resp == {"schema": "ok"} + service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/") + + +def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-11") + service_mock = MagicMock(return_value=[{"tool": "t"}]) + monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock) + + with app.test_request_context("/tools?provider=foo"): + resp = controller_module.ToolApiProviderListToolsApi().get() + + assert resp == [{"tool": "t"}] + service_mock.assert_called_once_with(user.id, "tenant-11", "foo") + + +def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-12") + service_mock = MagicMock(return_value={"provider": "foo"}) + monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock) + + with app.test_request_context("/get?provider=foo"): + resp = controller_module.ToolApiProviderGetApi().get() + + assert resp == {"provider": "foo"} + service_mock.assert_called_once_with(user.id, "tenant-12", "foo") + + +def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-13") + _set_current_account(monkeypatch, controller_module, user, "tenant-13") + service_mock = MagicMock(return_value={"schema": True}) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "list_builtin_provider_credentials_schema", + service_mock, + ) + + with app.test_request_context("/schema", method="GET"): + resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get( + provider="demo", credential_type="api-key" + ) + + assert resp == {"schema": True} + service_mock.assert_called_once() + + +def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf") + tool_service = MagicMock(return_value={"wf": 1}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "get_workflow_tool_by_tool_id", + tool_service, + ) + + tool_id = "00000000-0000-0000-0000-000000000001" + with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"): + resp = controller_module.ToolWorkflowProviderGetApi().get() + + assert resp == {"wf": 1} + tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id) + + +def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf2") + service_mock = MagicMock(return_value={"app": 1}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "get_workflow_tool_by_app_id", + service_mock, + ) + + app_id = "00000000-0000-0000-0000-000000000002" + with app.test_request_context(f"/workflow?workflow_app_id={app_id}"): + resp = controller_module.ToolWorkflowProviderGetApi().get() + + assert resp == {"app": 1} + service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id) + + +def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf3") + service_mock = MagicMock(return_value=[{"id": 1}]) + monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock) + + tool_id = "00000000-0000-0000-0000-000000000003" + with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"): + resp = controller_module.ToolWorkflowProviderListToolApi().get() + + assert resp == [{"id": 1}] + service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id) + + +def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-bt") + + provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"}) + monkeypatch.setattr( + controller_module.BuiltinToolManageService, + "list_builtin_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/builtin"): + resp = controller_module.ToolBuiltinListApi().get() + + assert resp == [{"name": "builtin"}] + + +def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-tenant-api") + _set_current_account(monkeypatch, controller_module, user, "tenant-api") + + provider = SimpleNamespace(to_dict=lambda: {"name": "api"}) + monkeypatch.setattr( + controller_module.ApiToolManageService, + "list_api_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/api"): + resp = controller_module.ToolApiListApi().get() + + assert resp == [{"name": "api"}] + + +def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account() + _set_current_account(monkeypatch, controller_module, user, "tenant-wf4") + + provider = SimpleNamespace(to_dict=lambda: {"name": "wf"}) + monkeypatch.setattr( + controller_module.WorkflowToolManageService, + "list_tenant_workflow_tools", + MagicMock(return_value=[provider]), + ) + + with app.test_request_context("/tools/workflow"): + resp = controller_module.ToolWorkflowListApi().get() + + assert resp == [{"name": "wf"}] + + +def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): + user = _mock_account("user-label") + _set_current_account(monkeypatch, controller_module, user, "tenant-labels") + monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"]) + + with app.test_request_context("/tool-labels"): + resp = controller_module.ToolLabelsApi().get() + + assert resp == ["a", "b"] From ac222a4dd4f030e06a0e0b47daa7c11d0514f0d1 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 18:03:07 +0900 Subject: [PATCH 14/32] refactor: port api/controllers/console/app/audio.py api/controllers/console/app/message.py api/controllers/console/auth/data_source_oauth.py api/controllers/console/auth/forgot_password.py api/controllers/console/workspace/endpoint.py (#30680) --- api/controllers/console/app/audio.py | 16 ++--- api/controllers/console/app/message.py | 31 +++++---- .../console/auth/data_source_oauth.py | 33 +++++++-- .../console/auth/forgot_password.py | 50 ++++++++------ api/controllers/console/workspace/endpoint.py | 69 ++++++++++++++----- .../clickzetta_volume_storage.py | 3 +- 6 files changed, 135 insertions(+), 67 deletions(-) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index d344ede466..941db325bf 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -33,7 +34,6 @@ from services.errors.audio import ( ) logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class TextToSpeechPayload(BaseModel): @@ -47,13 +47,11 @@ class TextToSpeechVoiceQuery(BaseModel): language: str = Field(..., description="Language code") -console_ns.schema_model( - TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - TextToSpeechVoiceQuery.__name__, - TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +class AudioTranscriptResponse(BaseModel): + text: str = Field(description="Transcribed text from audio") + + +register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery) @console_ns.route("/apps//audio-to-text") @@ -64,7 +62,7 @@ class ChatMessageAudioApi(Resource): @console_ns.response( 200, "Audio transcription successful", - console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), + console_ns.models[AudioTranscriptResponse.__name__], ) @console_ns.response(400, "Bad request - No audio uploaded or unsupported type") @console_ns.response(413, "Audio file too large") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 12ada8b798..0be3e0ec49 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, select from werkzeug.exceptions import InternalServerError, NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( CompletionRequestError, @@ -35,7 +36,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft from services.message_service import MessageService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class ChatMessagesQuery(BaseModel): @@ -90,13 +90,22 @@ class FeedbackExportQuery(BaseModel): raise ValueError("has_comment must be a boolean value") -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class AnnotationCountResponse(BaseModel): + count: int = Field(description="Number of annotations") -reg(ChatMessagesQuery) -reg(MessageFeedbackPayload) -reg(FeedbackExportQuery) +class SuggestedQuestionsResponse(BaseModel): + data: list[str] = Field(description="Suggested question") + + +register_schema_models( + console_ns, + ChatMessagesQuery, + MessageFeedbackPayload, + FeedbackExportQuery, + AnnotationCountResponse, + SuggestedQuestionsResponse, +) # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -231,7 +240,7 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_model) @edit_permission_required def get(self, app_model): - args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ChatMessagesQuery.model_validate(request.args.to_dict()) conversation = ( db.session.query(Conversation) @@ -356,7 +365,7 @@ class MessageAnnotationCountApi(Resource): @console_ns.response( 200, "Annotation count retrieved successfully", - console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), + console_ns.models[AnnotationCountResponse.__name__], ) @get_app_model @setup_required @@ -376,9 +385,7 @@ class MessageSuggestedQuestionApi(Resource): @console_ns.response( 200, "Suggested questions retrieved successfully", - console_ns.model( - "SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))} - ), + console_ns.models[SuggestedQuestionsResponse.__name__], ) @console_ns.response(404, "Message or conversation not found") @setup_required @@ -428,7 +435,7 @@ class MessageFeedbackExportApi(Resource): @login_required @account_initialization_required def get(self, app_model): - args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FeedbackExportQuery.model_validate(request.args.to_dict()) # Import the service function from services.feedback_service import FeedbackService diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 0dd7d33ae9..3a3278ec9d 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -2,9 +2,11 @@ import logging import httpx from flask import current_app, redirect, request -from flask_restx import Resource, fields +from flask_restx import Resource +from pydantic import BaseModel, Field from configs import dify_config +from controllers.common.schema import register_schema_models from libs.login import login_required from libs.oauth_data_source import NotionOAuth @@ -14,6 +16,26 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required, logger = logging.getLogger(__name__) +class OAuthDataSourceResponse(BaseModel): + data: str = Field(description="Authorization URL or 'internal' for internal setup") + + +class OAuthDataSourceBindingResponse(BaseModel): + result: str = Field(description="Operation result") + + +class OAuthDataSourceSyncResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + OAuthDataSourceResponse, + OAuthDataSourceBindingResponse, + OAuthDataSourceSyncResponse, +) + + def get_oauth_providers(): with current_app.app_context(): notion_oauth = NotionOAuth( @@ -34,10 +56,7 @@ class OAuthDataSource(Resource): @console_ns.response( 200, "Authorization URL or internal setup success", - console_ns.model( - "OAuthDataSourceResponse", - {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, - ), + console_ns.models[OAuthDataSourceResponse.__name__], ) @console_ns.response(400, "Invalid provider") @console_ns.response(403, "Admin privileges required") @@ -101,7 +120,7 @@ class OAuthDataSourceBinding(Resource): @console_ns.response( 200, "Data source binding success", - console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceBindingResponse.__name__], ) @console_ns.response(400, "Invalid provider or code") def get(self, provider: str): @@ -133,7 +152,7 @@ class OAuthDataSourceSync(Resource): @console_ns.response( 200, "Data source sync success", - console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[OAuthDataSourceSyncResponse.__name__], ) @console_ns.response(400, "Invalid provider or sync failed") @setup_required diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 394f205d93..1ed931b0d7 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,10 +2,11 @@ import base64 import secrets from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailCodeError, @@ -48,8 +49,31 @@ class ForgotPasswordResetPayload(BaseModel): return valid_password(value) -for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class ForgotPasswordEmailResponse(BaseModel): + result: str = Field(description="Operation result") + data: str | None = Field(default=None, description="Reset token") + code: str | None = Field(default=None, description="Error code if account not found") + + +class ForgotPasswordCheckResponse(BaseModel): + is_valid: bool = Field(description="Whether code is valid") + email: EmailStr = Field(description="Email address") + token: str = Field(description="New reset token") + + +class ForgotPasswordResetResponse(BaseModel): + result: str = Field(description="Operation result") + + +register_schema_models( + console_ns, + ForgotPasswordSendPayload, + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordEmailResponse, + ForgotPasswordCheckResponse, + ForgotPasswordResetResponse, +) @console_ns.route("/forgot-password") @@ -60,14 +84,7 @@ class ForgotPasswordSendEmailApi(Resource): @console_ns.response( 200, "Email sent successfully", - console_ns.model( - "ForgotPasswordEmailResponse", - { - "result": fields.String(description="Operation result"), - "data": fields.String(description="Reset token"), - "code": fields.String(description="Error code if account not found"), - }, - ), + console_ns.models[ForgotPasswordEmailResponse.__name__], ) @console_ns.response(400, "Invalid email or rate limit exceeded") @setup_required @@ -106,14 +123,7 @@ class ForgotPasswordCheckApi(Resource): @console_ns.response( 200, "Code verified successfully", - console_ns.model( - "ForgotPasswordCheckResponse", - { - "is_valid": fields.Boolean(description="Whether code is valid"), - "email": fields.String(description="Email address"), - "token": fields.String(description="New reset token"), - }, - ), + console_ns.models[ForgotPasswordCheckResponse.__name__], ) @console_ns.response(400, "Invalid code or token") @setup_required @@ -163,7 +173,7 @@ class ForgotPasswordResetApi(Resource): @console_ns.response( 200, "Password reset successfully", - console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), + console_ns.models[ForgotPasswordResetResponse.__name__], ) @console_ns.response(400, "Invalid token or password mismatch") @setup_required diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index bfd9fc6c29..1897cbdca7 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,9 +1,10 @@ from typing import Any from flask import request -from flask_restx import Resource, fields +from flask_restx import Resource from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder @@ -38,15 +39,53 @@ class EndpointListForPluginQuery(EndpointListQuery): plugin_id: str +class EndpointCreateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class PluginEndpointListResponse(BaseModel): + endpoints: list[dict[str, Any]] = Field(description="Endpoint information") + + +class EndpointDeleteResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointUpdateResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointEnableResponse(BaseModel): + success: bool = Field(description="Operation success") + + +class EndpointDisableResponse(BaseModel): + success: bool = Field(description="Operation success") + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -reg(EndpointCreatePayload) -reg(EndpointIdPayload) -reg(EndpointUpdatePayload) -reg(EndpointListQuery) -reg(EndpointListForPluginQuery) +register_schema_models( + console_ns, + EndpointCreatePayload, + EndpointIdPayload, + EndpointUpdatePayload, + EndpointListQuery, + EndpointListForPluginQuery, + EndpointCreateResponse, + EndpointListResponse, + PluginEndpointListResponse, + EndpointDeleteResponse, + EndpointUpdateResponse, + EndpointEnableResponse, + EndpointDisableResponse, +) @console_ns.route("/workspaces/current/endpoints/create") @@ -57,7 +96,7 @@ class EndpointCreateApi(Resource): @console_ns.response( 200, "Endpoint created successfully", - console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointCreateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -91,9 +130,7 @@ class EndpointListApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[EndpointListResponse.__name__], ) @setup_required @login_required @@ -126,9 +163,7 @@ class EndpointListForSinglePluginApi(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} - ), + console_ns.models[PluginEndpointListResponse.__name__], ) @setup_required @login_required @@ -163,7 +198,7 @@ class EndpointDeleteApi(Resource): @console_ns.response( 200, "Endpoint deleted successfully", - console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDeleteResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -190,7 +225,7 @@ class EndpointUpdateApi(Resource): @console_ns.response( 200, "Endpoint updated successfully", - console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointUpdateResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -221,7 +256,7 @@ class EndpointEnableApi(Resource): @console_ns.response( 200, "Endpoint enabled successfully", - console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointEnableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required @@ -248,7 +283,7 @@ class EndpointDisableApi(Resource): @console_ns.response( 200, "Endpoint disabled successfully", - console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), + console_ns.models[EndpointDisableResponse.__name__], ) @console_ns.response(403, "Admin privileges required") @setup_required diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index c1608f58a5..18eed4e481 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -390,8 +390,7 @@ class ClickZettaVolumeStorage(BaseStorage): """ content = self.load_once(filename) - with Path(target_filepath).open("wb") as f: - f.write(content) + Path(target_filepath).write_bytes(content) logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath) From 920db69ef2d52034df50a4f1821a7c63d003a544 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 18:12:03 +0900 Subject: [PATCH 15/32] refactor: if to match (#31799) --- api/commands.py | 203 +++++++++--------- api/controllers/console/app/conversation.py | 23 +- .../console/datasets/datasets_document.py | 107 +++++---- api/controllers/service_api/wraps.py | 16 +- 4 files changed, 179 insertions(+), 170 deletions(-) diff --git a/api/commands.py b/api/commands.py index 4b811fb1e6..c4f2c9edbb 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1450,54 +1450,58 @@ def clear_orphaned_file_records(force: bool): all_ids_in_tables = [] for ids_table in ids_tables: query = "" - if ids_table["type"] == "uuid": - click.echo( - click.style( - f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white" + match ids_table["type"]: + case "uuid": + click.echo( + click.style( + f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", + fg="white", + ) ) - ) - query = ( - f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) - elif ids_table["type"] == "text": - click.echo( - click.style( - f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}", - fg="white", + c = ids_table["column"] + query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) + case "text": + t = ids_table["table"] + click.echo( + click.style( + f"- Listing file-id-like strings in column {ids_table['column']} in table {t}", + fg="white", + ) ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) - elif ids_table["type"] == "json": - click.echo( - click.style( - ( - f"- Listing file-id-like JSON string in column {ids_table['column']} " - f"in table {ids_table['table']}" - ), - fg="white", + query = ( + f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case "json": + click.echo( + click.style( + ( + f"- Listing file-id-like JSON string in column {ids_table['column']} " + f"in table {ids_table['table']}" + ), + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case _: + pass click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) except Exception as e: @@ -1737,59 +1741,18 @@ def file_usage( if src_filter != src: continue - if ids_table["type"] == "uuid": - # Direct UUID match - query = ( - f"SELECT {ids_table['pk_column']}, {ids_table['column']} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - ref_file_id = str(row[1]) - if ref_file_id not in file_key_map: - continue - storage_key = file_key_map[ref_file_id] - - # Apply filters - if file_id and ref_file_id != file_id: - continue - if key and not storage_key.endswith(key): - continue - - # Only collect items within the requested page range - if offset <= total_count < offset + limit: - paginated_usages.append( - { - "src": f"{ids_table['table']}.{ids_table['column']}", - "record_id": record_id, - "file_id": ref_file_id, - "key": storage_key, - } - ) - total_count += 1 - - elif ids_table["type"] in ("text", "json"): - # Extract UUIDs from text/json content - column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] - query = ( - f"SELECT {ids_table['pk_column']}, {column_cast} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - content = str(row[1]) - - # Find all UUIDs in the content - import re - - uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) - matches = uuid_pattern.findall(content) - - for ref_file_id in matches: + match ids_table["type"]: + case "uuid": + # Direct UUID match + query = ( + f"SELECT {ids_table['pk_column']}, {ids_table['column']} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + ref_file_id = str(row[1]) if ref_file_id not in file_key_map: continue storage_key = file_key_map[ref_file_id] @@ -1812,6 +1775,50 @@ def file_usage( ) total_count += 1 + case "text" | "json": + # Extract UUIDs from text/json content + column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] + query = ( + f"SELECT {ids_table['pk_column']}, {column_cast} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + content = str(row[1]) + + # Find all UUIDs in the content + import re + + uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) + matches = uuid_pattern.findall(content) + + for ref_file_id in matches: + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + case _: + pass + # Output results if output_json: result = { diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 55fdcb51e4..82cc957d04 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -508,16 +508,19 @@ class ChatConversationApi(Resource): case "created_at" | "-created_at" | _: query = query.where(Conversation.created_at <= end_datetime_utc) - if args.annotation_status == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ) - elif args.annotation_status == "not_annotated": - query = ( - query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) - .group_by(Conversation.id) - .having(func.count(MessageAnnotation.id) == 0) - ) + match args.annotation_status: + case "annotated": + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + case "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) + case "all": + pass if app_model.mode == AppMode.ADVANCED_CHAT: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 6a0c9e5f77..e8b8f2ec6d 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -576,63 +576,62 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict + match document.data_source_type: + case "upload_file": + if not data_source_info: + continue + file_id = data_source_info["upload_file_id"] + file_detail = ( + db.session.query(UploadFile) + .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) + .first() + ) - if document.data_source_type == "upload_file": - if not data_source_info: - continue - file_id = data_source_info["upload_file_id"] - file_detail = ( - db.session.query(UploadFile) - .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) - .first() - ) + if file_detail is None: + raise NotFound("File not found.") - if file_detail is None: - raise NotFound("File not found.") + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form + ) + extract_settings.append(extract_setting) + case "notion_import": + if not data_source_info: + continue + extract_setting = ExtractSetting( + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info.get("credential_id"), + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_tenant_id, + } + ), + document_model=document.doc_form, + ) + extract_settings.append(extract_setting) + case "website_crawl": + if not data_source_info: + continue + extract_setting = ExtractSetting( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], + "tenant_id": current_tenant_id, + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), + document_model=document.doc_form, + ) + extract_settings.append(extract_setting) - extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form - ) - extract_settings.append(extract_setting) - - elif document.data_source_type == "notion_import": - if not data_source_info: - continue - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( - { - "credential_id": data_source_info.get("credential_id"), - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "tenant_id": current_tenant_id, - } - ), - document_model=document.doc_form, - ) - extract_settings.append(extract_setting) - elif document.data_source_type == "website_crawl": - if not data_source_info: - continue - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "url": data_source_info["url"], - "tenant_id": current_tenant_id, - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - } - ), - document_model=document.doc_form, - ) - extract_settings.append(extract_setting) - - else: - raise ValueError("Data source type not support") + case _: + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: response = indexing_runner.indexing_estimate( diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 24acced0d1..e597a72fc0 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -73,14 +73,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe # If caller needs end-user context, attach EndUser to current_user if fetch_user_arg: - if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: - user_id = request.args.get("user") - elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: - user_id = request.get_json().get("user") - elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: - user_id = request.form.get("user") - else: - user_id = None + user_id = None + match fetch_user_arg.fetch_from: + case WhereisUserArg.QUERY: + user_id = request.args.get("user") + case WhereisUserArg.JSON: + user_id = request.get_json().get("user") + case WhereisUserArg.FORM: + user_id = request.form.get("user") if not user_id and fetch_user_arg.required: raise ValueError("Arg user must be provided.") From ce2c41bbf5662f2334e5b178f1d807c62b44c20e Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 19:07:30 +0900 Subject: [PATCH 16/32] refactor: port api/controllers/console/datasets/datasets_document.py api/controllers/service_api/app/annotation.py api/core/app/app_config/easy_ui_based_app/agent/manager.py api/core/app/apps/pipeline/pipeline_generator.py api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py to match case (#31832) --- .../console/datasets/datasets_document.py | 29 +-- api/controllers/service_api/app/annotation.py | 9 +- .../easy_ui_based_app/agent/manager.py | 17 +- .../app/apps/pipeline/pipeline_generator.py | 30 +-- api/core/indexing_runner.py | 113 ++++++----- .../knowledge_retrieval_node.py | 180 +++++++++--------- api/core/workflow/nodes/tool/tool_node.py | 21 +- 7 files changed, 202 insertions(+), 197 deletions(-) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index e8b8f2ec6d..bf097d374a 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -953,23 +953,24 @@ class DocumentProcessingApi(DocumentResource): if not current_user.is_dataset_editor: raise Forbidden() - if action == "pause": - if document.indexing_status != "indexing": - raise InvalidActionError("Document not in indexing state.") + match action: + case "pause": + if document.indexing_status != "indexing": + raise InvalidActionError("Document not in indexing state.") - document.paused_by = current_user.id - document.paused_at = naive_utc_now() - document.is_paused = True - db.session.commit() + document.paused_by = current_user.id + document.paused_at = naive_utc_now() + document.is_paused = True + db.session.commit() - elif action == "resume": - if document.indexing_status not in {"paused", "error"}: - raise InvalidActionError("Document not in paused or error state.") + case "resume": + if document.indexing_status not in {"paused", "error"}: + raise InvalidActionError("Document not in paused or error state.") - document.paused_by = None - document.paused_at = None - document.is_paused = False - db.session.commit() + document.paused_by = None + document.paused_at = None + document.is_paused = False + db.session.commit() return {"result": "success"}, 200 diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 85ac9336d6..5be146a13e 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -45,10 +45,11 @@ class AnnotationReplyActionApi(Resource): def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() - if action == "enable": - result = AppAnnotationService.enable_app_annotation(args, app_model.id) - elif action == "disable": - result = AppAnnotationService.disable_app_annotation(app_model.id) + match action: + case "enable": + result = AppAnnotationService.enable_app_annotation(args, app_model.id) + case "disable": + result = AppAnnotationService.disable_app_annotation(app_model.id) return result, 200 diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index c1f336fdde..9b981dfc09 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -14,16 +14,17 @@ class AgentConfigManager: agent_dict = config.get("agent_mode", {}) agent_strategy = agent_dict.get("strategy", "cot") - if agent_strategy == "function_call": - strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy in {"cot", "react"}: - strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT - else: - # old configs, try to detect default strategy - if config["model"]["provider"] == "openai": + match agent_strategy: + case "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - else: + case "cot" | "react": strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT + case _: + # old configs, try to detect default strategy + if config["model"]["provider"] == "openai": + strategy = AgentEntity.Strategy.FUNCTION_CALLING + else: + strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT agent_tools = [] for tool in agent_dict.get("tools", []): diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index ea4441b5d8..eca96cb074 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator): raise ValueError("Pipeline dataset is required") inputs: Mapping[str, Any] = args["inputs"] start_node_id: str = args["start_node_id"] - datasource_type: str = args["datasource_type"] + datasource_type = DatasourceProviderType(args["datasource_type"]) datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list( datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user ) @@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator): tenant_id: str, dataset_id: str, built_in_field_enabled: bool, - datasource_type: str, + datasource_type: DatasourceProviderType, datasource_info: Mapping[str, Any], created_from: str, position: int, @@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator): batch: str, document_form: str, ): - if datasource_type == "local_file": - name = datasource_info.get("name", "untitled") - elif datasource_type == "online_document": - name = datasource_info.get("page", {}).get("page_name", "untitled") - elif datasource_type == "website_crawl": - name = datasource_info.get("title", "untitled") - elif datasource_type == "online_drive": - name = datasource_info.get("name", "untitled") - else: - raise ValueError(f"Unsupported datasource type: {datasource_type}") - + match datasource_type: + case DatasourceProviderType.LOCAL_FILE: + name = datasource_info.get("name", "untitled") + case DatasourceProviderType.ONLINE_DOCUMENT: + name = datasource_info.get("page", {}).get("page_name", "untitled") + case DatasourceProviderType.WEBSITE_CRAWL: + name = datasource_info.get("title", "untitled") + case DatasourceProviderType.ONLINE_DRIVE: + name = datasource_info.get("name", "untitled") + case _: + raise ValueError(f"Unsupported datasource type: {datasource_type}") document = Document( tenant_id=tenant_id, dataset_id=dataset_id, @@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator): def _format_datasource_info_list( self, - datasource_type: str, + datasource_type: DatasourceProviderType, datasource_info_list: list[Mapping[str, Any]], pipeline: Pipeline, workflow: Workflow, @@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator): """ Format datasource info list. """ - if datasource_type == "online_drive": + if datasource_type == DatasourceProviderType.ONLINE_DRIVE: all_files: list[Mapping[str, Any]] = [] datasource_node_data = None datasource_nodes = workflow.graph_dict.get("nodes", []) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 61f168a26f..4e3ad7bb75 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -378,70 +378,69 @@ class IndexingRunner: def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict ) -> list[Document]: - # load file - if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: - return [] - data_source_info = dataset_document.data_source_info_dict text_docs = [] - if dataset_document.data_source_type == "upload_file": - if not data_source_info or "upload_file_id" not in data_source_info: - raise ValueError("no upload file found") - stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) - file_detail = db.session.scalars(stmt).one_or_none() + match dataset_document.data_source_type: + case "upload_file": + if not data_source_info or "upload_file_id" not in data_source_info: + raise ValueError("no upload file found") + stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) + file_detail = db.session.scalars(stmt).one_or_none() - if file_detail: + if file_detail: + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, + upload_file=file_detail, + document_model=dataset_document.doc_form, + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case "notion_import": + if ( + not data_source_info + or "notion_workspace_id" not in data_source_info + or "notion_page_id" not in data_source_info + ): + raise ValueError("no notion import info found") extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, - upload_file=file_detail, + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info.get("credential_id"), + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "document": dataset_document, + "tenant_id": dataset_document.tenant_id, + } + ), document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - elif dataset_document.data_source_type == "notion_import": - if ( - not data_source_info - or "notion_workspace_id" not in data_source_info - or "notion_page_id" not in data_source_info - ): - raise ValueError("no notion import info found") - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( - { - "credential_id": data_source_info.get("credential_id"), - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "document": dataset_document, - "tenant_id": dataset_document.tenant_id, - } - ), - document_model=dataset_document.doc_form, - ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - elif dataset_document.data_source_type == "website_crawl": - if ( - not data_source_info - or "provider" not in data_source_info - or "url" not in data_source_info - or "job_id" not in data_source_info - ): - raise ValueError("no website import info found") - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "tenant_id": dataset_document.tenant_id, - "url": data_source_info["url"], - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - } - ), - document_model=dataset_document.doc_form, - ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case "website_crawl": + if ( + not data_source_info + or "provider" not in data_source_info + or "url" not in data_source_info + or "job_id" not in data_source_info + ): + raise ValueError("no website import info found") + extract_setting = ExtractSetting( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "tenant_id": dataset_document.tenant_id, + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), + document_model=dataset_document.doc_form, + ) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + case _: + return [] # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 3c4850ebac..0827494a48 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -303,33 +303,34 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") - if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": - if node_data.multiple_retrieval_config.reranking_model: - reranking_model = { - "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, - "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, - } - else: + match node_data.multiple_retrieval_config.reranking_mode: + case "reranking_model": + if node_data.multiple_retrieval_config.reranking_model: + reranking_model = { + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, + } + else: + reranking_model = None + weights = None + case "weighted_score": + if node_data.multiple_retrieval_config.weights is None: + raise ValueError("weights is required") reranking_model = None - weights = None - elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": - if node_data.multiple_retrieval_config.weights is None: - raise ValueError("weights is required") - reranking_model = None - vector_setting = node_data.multiple_retrieval_config.weights.vector_setting - weights = { - "vector_setting": { - "vector_weight": vector_setting.vector_weight, - "embedding_provider_name": vector_setting.embedding_provider_name, - "embedding_model_name": vector_setting.embedding_model_name, - }, - "keyword_setting": { - "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight - }, - } - else: - reranking_model = None - weights = None + vector_setting = node_data.multiple_retrieval_config.weights.vector_setting + weights = { + "vector_setting": { + "vector_weight": vector_setting.vector_weight, + "embedding_provider_name": vector_setting.embedding_provider_name, + "embedding_model_name": vector_setting.embedding_model_name, + }, + "keyword_setting": { + "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight + }, + } + case _: + reranking_model = None + weights = None all_documents = dataset_retrieval.multiple_retrieve( app_id=self.app_id, tenant_id=self.tenant_id, @@ -453,73 +454,74 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD ) filters: list[Any] = [] metadata_condition = None - if node_data.metadata_filtering_mode == "disabled": - return None, None, usage - elif node_data.metadata_filtering_mode == "automatic": - automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( - dataset_ids, query, node_data - ) - usage = self._merge_usage(usage, automatic_usage) - if automatic_metadata_filters: - conditions = [] - for sequence, filter in enumerate(automatic_metadata_filters): - DatasetRetrieval.process_metadata_filter_func( - sequence, - filter.get("condition", ""), - filter.get("metadata_name", ""), - filter.get("value"), - filters, - ) - conditions.append( - Condition( - name=filter.get("metadata_name"), # type: ignore - comparison_operator=filter.get("condition"), # type: ignore - value=filter.get("value"), - ) - ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator - if node_data.metadata_filtering_conditions - else "or", - conditions=conditions, + match node_data.metadata_filtering_mode: + case "disabled": + return None, None, usage + case "automatic": + automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( + dataset_ids, query, node_data ) - elif node_data.metadata_filtering_mode == "manual": - if node_data.metadata_filtering_conditions: - conditions = [] - for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore - metadata_name = condition.name - expected_value = condition.value - if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): - if isinstance(expected_value, str): - expected_value = self.graph_runtime_state.variable_pool.convert_template( - expected_value - ).value[0] - if expected_value.value_type in {"number", "integer", "float"}: - expected_value = expected_value.value - elif expected_value.value_type == "string": - expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() - else: - raise ValueError("Invalid expected metadata value type") - conditions.append( - Condition( - name=metadata_name, - comparison_operator=condition.comparison_operator, - value=expected_value, + usage = self._merge_usage(usage, automatic_usage) + if automatic_metadata_filters: + conditions = [] + for sequence, filter in enumerate(automatic_metadata_filters): + DatasetRetrieval.process_metadata_filter_func( + sequence, + filter.get("condition", ""), + filter.get("metadata_name", ""), + filter.get("value"), + filters, ) + conditions.append( + Condition( + name=filter.get("metadata_name"), # type: ignore + comparison_operator=filter.get("condition"), # type: ignore + value=filter.get("value"), + ) + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator + if node_data.metadata_filtering_conditions + else "or", + conditions=conditions, ) - filters = DatasetRetrieval.process_metadata_filter_func( - sequence, - condition.comparison_operator, - metadata_name, - expected_value, - filters, + case "manual": + if node_data.metadata_filtering_conditions: + conditions = [] + for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore + metadata_name = condition.name + expected_value = condition.value + if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): + if isinstance(expected_value, str): + expected_value = self.graph_runtime_state.variable_pool.convert_template( + expected_value + ).value[0] + if expected_value.value_type in {"number", "integer", "float"}: + expected_value = expected_value.value + elif expected_value.value_type == "string": + expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() + else: + raise ValueError("Invalid expected metadata value type") + conditions.append( + Condition( + name=metadata_name, + comparison_operator=condition.comparison_operator, + value=expected_value, + ) + ) + filters = DatasetRetrieval.process_metadata_filter_func( + sequence, + condition.comparison_operator, + metadata_name, + expected_value, + filters, + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator, + conditions=conditions, ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator, - conditions=conditions, - ) - else: - raise ValueError("Invalid metadata filtering mode") + case _: + raise ValueError("Invalid metadata filtering mode") if filters: if ( node_data.metadata_filtering_conditions diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 68ac60e4f6..60d76db9b6 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -482,16 +482,17 @@ class ToolNode(Node[ToolNodeData]): result = {} for parameter_name in typed_node_data.tool_parameters: input = typed_node_data.tool_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + selector_key = ".".join(input.value) + result[f"#{selector_key}#"] = input.value + case "constant": + pass result = {node_id + "." + key: value for key, value in result.items()} From 491fa9923b8d1fd3820a5df40eda4ebf22affdf5 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 21:03:16 +0900 Subject: [PATCH 17/32] refactor: port api/controllers/console/datasets/data_source.py /datasets/metadata.py /service_api/dataset/metadata.py /nodes/agent/agent_node.py api/core/workflow/nodes/datasource/datasource_node.py api/services/dataset_service.py to match case (#31836) --- .../console/datasets/data_source.py | 40 ++-- api/controllers/console/datasets/metadata.py | 9 +- .../service_api/dataset/metadata.py | 9 +- api/core/workflow/nodes/agent/agent_node.py | 64 +++--- .../nodes/datasource/datasource_node.py | 193 +++++++++--------- api/services/dataset_service.py | 114 ++++++----- 6 files changed, 223 insertions(+), 206 deletions(-) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 01e9bf77c0..daef4e005a 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator -from typing import Any, cast +from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal_with @@ -157,9 +157,8 @@ class DataSourceApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, binding_id, action): + def patch(self, binding_id, action: Literal["enable", "disable"]): binding_id = str(binding_id) - action = str(action) with Session(db.engine) as session: data_source_binding = session.execute( select(DataSourceOauthBinding).filter_by(id=binding_id) @@ -167,23 +166,24 @@ class DataSourceApi(Resource): if data_source_binding is None: raise NotFound("Data source binding not found.") # enable binding - if action == "enable": - if data_source_binding.disabled: - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is not disabled.") - # disable binding - if action == "disable": - if not data_source_binding.disabled: - data_source_binding.disabled = True - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is disabled.") + match action: + case "enable": + if data_source_binding.disabled: + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is not disabled.") + # disable binding + case "disable": + if not data_source_binding.disabled: + data_source_binding.disabled = True + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is disabled.") return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 05fc4cd714..2e69ddc5ab 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index b8d9508004..692342a38a 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5a365f769d..e195aebe6d 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]): result[parameter_name] = None continue agent_input = node_data.agent_parameters[parameter_name] - if agent_input.type == "variable": - variable = variable_pool.get(agent_input.value) # type: ignore - if variable is None: - raise AgentVariableNotFoundError(str(agent_input.value)) - parameter_value = variable.value - elif agent_input.type in {"mixed", "constant"}: - # variable_pool.convert_template expects a string template, - # but if passing a dict, convert to JSON string first before rendering - try: - if not isinstance(agent_input.value, str): - parameter_value = json.dumps(agent_input.value, ensure_ascii=False) - else: + match agent_input.type: + case "variable": + variable = variable_pool.get(agent_input.value) # type: ignore + if variable is None: + raise AgentVariableNotFoundError(str(agent_input.value)) + parameter_value = variable.value + case "mixed" | "constant": + # variable_pool.convert_template expects a string template, + # but if passing a dict, convert to JSON string first before rendering + try: + if not isinstance(agent_input.value, str): + parameter_value = json.dumps(agent_input.value, ensure_ascii=False) + else: + parameter_value = str(agent_input.value) + except TypeError: parameter_value = str(agent_input.value) - except TypeError: - parameter_value = str(agent_input.value) - segment_group = variable_pool.convert_template(parameter_value) - parameter_value = segment_group.log if for_log else segment_group.text - # variable_pool.convert_template returns a string, - # so we need to convert it back to a dictionary - try: - if not isinstance(agent_input.value, str): - parameter_value = json.loads(parameter_value) - except json.JSONDecodeError: - parameter_value = parameter_value - else: - raise AgentInputTypeError(agent_input.type) + segment_group = variable_pool.convert_template(parameter_value) + parameter_value = segment_group.log if for_log else segment_group.text + # variable_pool.convert_template returns a string, + # so we need to convert it back to a dictionary + try: + if not isinstance(agent_input.value, str): + parameter_value = json.loads(parameter_value) + except json.JSONDecodeError: + parameter_value = parameter_value + case _: + raise AgentInputTypeError(agent_input.type) value = parameter_value if parameter.type == "array[tools]": value = cast(list[dict[str, Any]], value) @@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]): result: dict[str, Any] = {} for parameter_name in typed_node_data.agent_parameters: input = typed_node_data.agent_parameters[parameter_name] - if input.type in ["mixed", "constant"]: - selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value + match input.type: + case "mixed" | "constant": + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value result = {node_id + "." + key: value for key, value in result.items()} diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index fd71d610b4..a732a70417 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -270,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]): if typed_node_data.datasource_parameters: for parameter_name in typed_node_data.datasource_parameters: input = typed_node_data.datasource_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value + case "constant": + pass + case None: + pass result = {node_id + "." + key: value for key, value in result.items()} @@ -308,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]): variables: dict[str, Any] = {} for message in message_stream: - if message.type in { - DatasourceMessage.MessageType.IMAGE_LINK, - DatasourceMessage.MessageType.BINARY_LINK, - DatasourceMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, DatasourceMessage.TextMessage) + match message.type: + case ( + DatasourceMessage.MessageType.IMAGE_LINK + | DatasourceMessage.MessageType.BINARY_LINK + | DatasourceMessage.MessageType.IMAGE + ): + assert isinstance(message.message, DatasourceMessage.TextMessage) - url = message.message.text - transfer_method = FileTransferMethod.TOOL_FILE + url = message.message.text + transfer_method = FileTransferMethod.TOOL_FILE - datasource_file_id = str(url).split("/")[-1].split(".")[0] + datasource_file_id = str(url).split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"Tool file {datasource_file_id} does not exist") + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"Tool file {datasource_file_id} does not exist") - mapping = { - "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - elif message.type == DatasourceMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, DatasourceMessage.TextMessage) - assert message.meta - - datasource_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"datasource file {datasource_file_id} not exists") - - mapping = { - "tool_file_id": datasource_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( + mapping = { + "tool_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( mapping=mapping, tenant_id=self.tenant_id, ) - ) - elif message.type == DatasourceMessage.MessageType.TEXT: - assert isinstance(message.message, DatasourceMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.JSON: - assert isinstance(message.message, DatasourceMessage.JsonMessage) - json.append(message.message.json_object) - elif message.type == DatasourceMessage.MessageType.LINK: - assert isinstance(message.message, DatasourceMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value + files.append(file) + case DatasourceMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, DatasourceMessage.TextMessage) + assert message.meta + datasource_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"datasource file {datasource_file_id} not exists") + + mapping = { + "tool_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + case DatasourceMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceMessage.TextMessage) + text += message.message.text yield StreamChunkEvent( - selector=[self._node_id, variable_name], - chunk=variable_value, + selector=[self._node_id, "text"], + chunk=message.message.text, is_final=False, ) - else: - variables[variable_name] = variable_value - elif message.type == DatasourceMessage.MessageType.FILE: - assert message.meta is not None - files.append(message.meta["file"]) + case DatasourceMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceMessage.JsonMessage) + json.append(message.message.json_object) + case DatasourceMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=stream_text, + is_final=False, + ) + case DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[self._node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + case DatasourceMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) + case ( + DatasourceMessage.MessageType.BLOB_CHUNK + | DatasourceMessage.MessageType.LOG + | DatasourceMessage.MessageType.RETRIEVER_RESOURCES + ): + pass + # mark the end of the stream yield StreamChunkEvent( selector=[self._node_id, "text"], diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 16945fca6a..1ea6c4e1c3 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2978,14 +2978,15 @@ class DocumentService: """ now = naive_utc_now() - if action == "enable": - return DocumentService._prepare_enable_update(document, now) - elif action == "disable": - return DocumentService._prepare_disable_update(document, user, now) - elif action == "archive": - return DocumentService._prepare_archive_update(document, user, now) - elif action == "un_archive": - return DocumentService._prepare_unarchive_update(document, now) + match action: + case "enable": + return DocumentService._prepare_enable_update(document, now) + case "disable": + return DocumentService._prepare_disable_update(document, user, now) + case "archive": + return DocumentService._prepare_archive_update(document, user, now) + case "un_archive": + return DocumentService._prepare_unarchive_update(document, now) return None @@ -3622,56 +3623,57 @@ class SegmentService: # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return - if action == "enable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == False, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + match action: + case "enable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) - elif action == "disable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == True, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = False - segment.disabled_at = naive_utc_now() - segment.disabled_by = current_user.id - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + case "disable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = naive_utc_now() + segment.disabled_by = current_user.id + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) @classmethod def create_child_chunk( From 47f8de3f8ec8c89450ed8c0e92f2347fc1a83765 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 3 Feb 2026 10:59:00 +0900 Subject: [PATCH 18/32] refactor: port api/controllers/console/app/annotation.py api/controllers/console/explore/trial.py api/controllers/console/workspace/account.py api/controllers/console/workspace/members.py api/controllers/service_api/app/annotation.py to basemodel (#31833) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- api/controllers/console/app/annotation.py | 81 +++++++------ api/controllers/console/explore/trial.py | 6 +- api/controllers/console/workspace/account.py | 41 ++++--- api/controllers/console/workspace/members.py | 27 +++-- api/controllers/service_api/app/annotation.py | 67 +++++------ .../service_api/dataset/dataset.py | 19 +-- api/fields/annotation_fields.py | 89 +++++++++----- api/fields/end_user_fields.py | 22 +++- api/fields/member_fields.py | 109 +++++++++++++----- api/fields/tag_fields.py | 26 +++-- api/fields/workflow_app_log_fields.py | 20 +--- api/services/annotation_service.py | 4 +- 12 files changed, 307 insertions(+), 204 deletions(-) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index a07145ce9f..9931bb5dd7 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,10 +1,11 @@ from typing import Any, Literal from flask import abort, make_response, request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field, field_validator +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter, field_validator from controllers.common.errors import NoFileUploadedError, TooManyFilesError +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import ( account_initialization_required, @@ -16,9 +17,11 @@ from controllers.console.wraps import ( ) from extensions.ext_redis import redis_client from fields.annotation_fields import ( - annotation_fields, - annotation_hit_history_fields, - build_annotation_model, + Annotation, + AnnotationExportList, + AnnotationHitHistory, + AnnotationHitHistoryList, + AnnotationList, ) from libs.helper import uuid_value from libs.login import login_required @@ -89,6 +92,14 @@ reg(CreateAnnotationPayload) reg(UpdateAnnotationPayload) reg(AnnotationReplyStatusQuery) reg(AnnotationFilePayload) +register_schema_models( + console_ns, + Annotation, + AnnotationList, + AnnotationExportList, + AnnotationHitHistory, + AnnotationHitHistoryList, +) @console_ns.route("/apps//annotation-reply/") @@ -202,33 +213,33 @@ class AnnotationApi(Resource): app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) - response = { - "data": marshal(annotation_list, annotation_fields), - "has_more": len(annotation_list) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response, 200 + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response = AnnotationList( + data=annotation_models, + has_more=len(annotation_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json"), 200 @console_ns.doc("create_annotation") @console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__]) - @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) + @console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @cloud_edition_billing_resource_check("annotation") - @marshal_with(annotation_fields) @edit_permission_required def post(self, app_id): app_id = str(app_id) args = CreateAnnotationPayload.model_validate(console_ns.payload) data = args.model_dump(exclude_none=True) annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) - return annotation + return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @@ -265,7 +276,7 @@ class AnnotationExportApi(Resource): @console_ns.response( 200, "Annotations exported successfully", - console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}), + console_ns.models[AnnotationExportList.__name__], ) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -275,7 +286,8 @@ class AnnotationExportApi(Resource): def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response_data = {"data": marshal(annotation_list, annotation_fields)} + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json") # Create response with secure headers for CSV export response = make_response(response_data, 200) @@ -290,7 +302,7 @@ class AnnotationUpdateDeleteApi(Resource): @console_ns.doc("update_delete_annotation") @console_ns.doc(description="Update or delete an annotation") @console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) - @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) + @console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__]) @console_ns.response(204, "Annotation deleted successfully") @console_ns.response(403, "Insufficient permissions") @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__]) @@ -299,7 +311,6 @@ class AnnotationUpdateDeleteApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - @marshal_with(annotation_fields) def post(self, app_id, annotation_id): app_id = str(app_id) annotation_id = str(annotation_id) @@ -307,7 +318,7 @@ class AnnotationUpdateDeleteApi(Resource): annotation = AppAnnotationService.update_app_annotation_directly( args.model_dump(exclude_none=True), app_id, annotation_id ) - return annotation + return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @@ -415,14 +426,7 @@ class AnnotationHitHistoryListApi(Resource): @console_ns.response( 200, "Hit histories retrieved successfully", - console_ns.model( - "AnnotationHitHistoryList", - { - "data": fields.List( - fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields)) - ) - }, - ), + console_ns.models[AnnotationHitHistoryList.__name__], ) @console_ns.response(403, "Insufficient permissions") @setup_required @@ -437,11 +441,14 @@ class AnnotationHitHistoryListApi(Resource): annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( app_id, annotation_id, page, limit ) - response = { - "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), - "has_more": len(annotation_hit_history_list) == limit, - "limit": limit, - "total": total, - "page": page, - } - return response + history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python( + annotation_hit_history_list, from_attributes=True + ) + response = AnnotationHitHistoryList( + data=history_models, + has_more=len(annotation_hit_history_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json") diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 1eb0cdb019..cd523b481c 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -9,7 +9,7 @@ import services from controllers.common.fields import Parameters as ParametersResponse from controllers.common.fields import Site as SiteResponse from controllers.common.schema import get_or_create_model -from controllers.console import api, console_ns +from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -51,7 +51,7 @@ from fields.app_fields import ( tag_fields, ) from fields.dataset_fields import dataset_fields -from fields.member_fields import build_simple_account_model +from fields.member_fields import simple_account_fields from fields.workflow_fields import ( conversation_variable_fields, pipeline_variable_fields, @@ -103,7 +103,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model)) app_detail_fields_with_site_copy["site"] = fields.Nested(site_model) app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy) -simple_account_model = build_simple_account_model(console_ns) +simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields) conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields) pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 38c66525b3..708df62642 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import Session from configs import dify_config from constants.languages import supported_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, @@ -37,7 +38,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_fields +from fields.member_fields import Account as AccountResponse from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required @@ -170,6 +171,12 @@ reg(ChangeEmailSendPayload) reg(ChangeEmailValidityPayload) reg(ChangeEmailResetPayload) reg(CheckEmailUniquePayload) +register_schema_models(console_ns, AccountResponse) + + +def _serialize_account(account) -> dict: + return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") + integrate_fields = { "provider": fields.String, @@ -236,11 +243,11 @@ class AccountProfileApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) @enterprise_license_required def get(self): current_user, _ = current_account_with_tenant() - return current_user + return _serialize_account(current_user) @console_ns.route("/account/name") @@ -249,14 +256,14 @@ class AccountNameApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} args = AccountNamePayload.model_validate(payload) updated_account = AccountService.update_account(current_user, name=args.name) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/avatar") @@ -265,7 +272,7 @@ class AccountAvatarApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -273,7 +280,7 @@ class AccountAvatarApi(Resource): updated_account = AccountService.update_account(current_user, avatar=args.avatar) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/interface-language") @@ -282,7 +289,7 @@ class AccountInterfaceLanguageApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -290,7 +297,7 @@ class AccountInterfaceLanguageApi(Resource): updated_account = AccountService.update_account(current_user, interface_language=args.interface_language) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/interface-theme") @@ -299,7 +306,7 @@ class AccountInterfaceThemeApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -307,7 +314,7 @@ class AccountInterfaceThemeApi(Resource): updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/timezone") @@ -316,7 +323,7 @@ class AccountTimezoneApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -324,7 +331,7 @@ class AccountTimezoneApi(Resource): updated_account = AccountService.update_account(current_user, timezone=args.timezone) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/password") @@ -333,7 +340,7 @@ class AccountPasswordApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} @@ -344,7 +351,7 @@ class AccountPasswordApi(Resource): except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() - return {"result": "success"} + return _serialize_account(current_user) @console_ns.route("/account/integrates") @@ -620,7 +627,7 @@ class ChangeEmailResetApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_fields) + @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) def post(self): payload = console_ns.payload or {} args = ChangeEmailResetPayload.model_validate(payload) @@ -649,7 +656,7 @@ class ChangeEmailResetApi(Resource): email=normalized_new_email, ) - return updated_account + return _serialize_account(updated_account) @console_ns.route("/account/change-email/check-email-unique") diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 271cdce3c3..dd302b90d6 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,12 +1,12 @@ from urllib import parse from flask import abort, request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter import services from configs import dify_config -from controllers.common.schema import get_or_create_model, register_enum_models +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( CannotTransferOwnerToSelfError, @@ -25,7 +25,7 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.member_fields import account_with_role_fields, account_with_role_list_fields +from fields.member_fields import AccountWithRole, AccountWithRoleList from libs.helper import extract_remote_ip from libs.login import current_account_with_tenant, login_required from models.account import Account, TenantAccountRole @@ -69,12 +69,7 @@ reg(OwnerTransferEmailPayload) reg(OwnerTransferCheckPayload) reg(OwnerTransferPayload) register_enum_models(console_ns, TenantAccountRole) - -account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields) - -account_with_role_list_fields_copy = account_with_role_list_fields.copy() -account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model)) -account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy) +register_schema_models(console_ns, AccountWithRole, AccountWithRoleList) @console_ns.route("/workspaces/current/members") @@ -84,13 +79,15 @@ class MemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_tenant_members(current_user.current_tenant) - return {"result": "success", "accounts": members}, 200 + member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = AccountWithRoleList(accounts=member_models) + return response.model_dump(mode="json"), 200 @console_ns.route("/workspaces/current/members/invite-email") @@ -235,13 +232,15 @@ class DatasetOperatorMemberListApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(account_with_role_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__]) def get(self): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") members = TenantService.get_dataset_operator_members(current_user.current_tenant) - return {"result": "success", "accounts": members}, 200 + member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = AccountWithRoleList(accounts=member_models) + return response.model_dump(mode="json"), 200 @console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email") diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 5be146a13e..ef254ca357 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -1,16 +1,16 @@ from typing import Literal from flask import request -from flask_restx import Namespace, Resource, fields +from flask_restx import Resource from flask_restx.api import HTTPStatus -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from controllers.common.schema import register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client -from fields.annotation_fields import annotation_fields, build_annotation_model +from fields.annotation_fields import Annotation, AnnotationList from models.model import App from services.annotation_service import AppAnnotationService @@ -26,7 +26,9 @@ class AnnotationReplyActionPayload(BaseModel): embedding_model_name: str = Field(description="Embedding model name") -register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload) +register_schema_models( + service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList +) @service_api_ns.route("/apps/annotation-reply/") @@ -83,23 +85,6 @@ class AnnotationReplyActionStatusApi(Resource): return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 -# Define annotation list response model -annotation_list_fields = { - "data": fields.List(fields.Nested(annotation_fields)), - "has_more": fields.Boolean, - "limit": fields.Integer, - "total": fields.Integer, - "page": fields.Integer, -} - - -def build_annotation_list_model(api_or_ns: Namespace): - """Build the annotation list model for the API or Namespace.""" - copied_annotation_list_fields = annotation_list_fields.copy() - copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) - return api_or_ns.model("AnnotationList", copied_annotation_list_fields) - - @service_api_ns.route("/apps/annotations") class AnnotationListApi(Resource): @service_api_ns.doc("list_annotations") @@ -110,8 +95,12 @@ class AnnotationListApi(Resource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + 200, + "Annotations retrieved successfully", + service_api_ns.models[AnnotationList.__name__], + ) @validate_app_token - @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns)) def get(self, app_model: App): """List annotations for the application.""" page = request.args.get("page", default=1, type=int) @@ -119,13 +108,15 @@ class AnnotationListApi(Resource): keyword = request.args.get("keyword", default="", type=str) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) - return { - "data": annotation_list, - "has_more": len(annotation_list) == limit, - "limit": limit, - "total": total, - "page": page, - } + annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) + response = AnnotationList( + data=annotation_models, + has_more=len(annotation_list) == limit, + limit=limit, + total=total, + page=page, + ) + return response.model_dump(mode="json") @service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__]) @service_api_ns.doc("create_annotation") @@ -136,13 +127,18 @@ class AnnotationListApi(Resource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + HTTPStatus.CREATED, + "Annotation created successfully", + service_api_ns.models[Annotation.__name__], + ) @validate_app_token - @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) def post(self, app_model: App): """Create a new annotation.""" args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) - return annotation, 201 + response = Annotation.model_validate(annotation, from_attributes=True) + return response.model_dump(mode="json"), HTTPStatus.CREATED @service_api_ns.route("/apps/annotations/") @@ -159,14 +155,19 @@ class AnnotationUpdateDeleteApi(Resource): 404: "Annotation not found", } ) + @service_api_ns.response( + 200, + "Annotation updated successfully", + service_api_ns.models[Annotation.__name__], + ) @validate_app_token @edit_permission_required - @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) - return annotation + response = Annotation.model_validate(annotation, from_attributes=True) + return response.model_dump(mode="json") @service_api_ns.doc("delete_annotation") @service_api_ns.doc(description="Delete an annotation") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index c11f64585a..db5cabe8aa 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -17,7 +17,7 @@ from controllers.service_api.wraps import ( from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields -from fields.tag_fields import build_dataset_tag_fields +from fields.tag_fields import DataSetTag from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum @@ -114,6 +114,7 @@ register_schema_models( TagBindingPayload, TagUnbindingPayload, DatasetListQuery, + DataSetTag, ) @@ -480,15 +481,14 @@ class DatasetTagsApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def get(self, _): """Get all knowledge type tags.""" assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None tags = TagService.get_tags("knowledge", cid) - - return tags, 200 + tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True) + return [tag.model_dump(mode="json") for tag in tag_models], 200 @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__]) @service_api_ns.doc("create_dataset_tag") @@ -500,7 +500,6 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def post(self, _): """Add a knowledge type tag.""" assert isinstance(current_user, Account) @@ -510,7 +509,9 @@ class DatasetTagsApi(DatasetApiResource): payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + response = DataSetTag.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + ).model_dump(mode="json") return response, 200 @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__]) @@ -523,7 +524,6 @@ class DatasetTagsApi(DatasetApiResource): 403: "Forbidden - insufficient permissions", } ) - @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) def patch(self, _): assert isinstance(current_user, Account) if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -536,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource): binding_count = TagService.get_tag_binding_count(tag_id) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} - + response = DataSetTag.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + ).model_dump(mode="json") return response, 200 @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__]) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index e69306dcb2..a646950722 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,36 +1,69 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import TimestampField +from datetime import datetime -annotation_fields = { - "id": fields.String, - "question": fields.String, - "answer": fields.Raw(attribute="content"), - "hit_count": fields.Integer, - "created_at": TimestampField, - # 'account': fields.Nested(simple_account_fields, allow_null=True) -} +from pydantic import BaseModel, ConfigDict, Field, field_validator -def build_annotation_model(api_or_ns: Namespace): - """Build the annotation model for the API or Namespace.""" - return api_or_ns.model("Annotation", annotation_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -annotation_list_fields = { - "data": fields.List(fields.Nested(annotation_fields)), -} +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) -annotation_hit_history_fields = { - "id": fields.String, - "source": fields.String, - "score": fields.Float, - "question": fields.String, - "created_at": TimestampField, - "match": fields.String(attribute="annotation_question"), - "response": fields.String(attribute="annotation_content"), -} -annotation_hit_history_list_fields = { - "data": fields.List(fields.Nested(annotation_hit_history_fields)), -} +class Annotation(ResponseModel): + id: str + question: str | None = None + answer: str | None = Field(default=None, validation_alias="content") + hit_count: int | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AnnotationList(ResponseModel): + data: list[Annotation] + has_more: bool + limit: int + total: int + page: int + + +class AnnotationExportList(ResponseModel): + data: list[Annotation] + + +class AnnotationHitHistory(ResponseModel): + id: str + source: str | None = None + score: float | None = None + question: str | None = None + created_at: int | None = None + match: str | None = Field(default=None, validation_alias="annotation_question") + response: str | None = Field(default=None, validation_alias="annotation_content") + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AnnotationHitHistoryList(ResponseModel): + data: list[AnnotationHitHistory] + has_more: bool + limit: int + total: int + page: int diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 5389b0213a..effe7bfb20 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,7 @@ -from flask_restx import Namespace, fields +from __future__ import annotations + +from flask_restx import fields +from pydantic import BaseModel, ConfigDict simple_end_user_fields = { "id": fields.String, @@ -8,5 +11,18 @@ simple_end_user_fields = { } -def build_simple_end_user_model(api_or_ns: Namespace): - return api_or_ns.model("SimpleEndUser", simple_end_user_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class SimpleEndUser(ResponseModel): + id: str + type: str + is_anonymous: bool + session_id: str | None = None diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 25160927e6..11d9a1a2fc 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,6 +1,11 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -from libs.helper import AvatarUrlField, TimestampField +from datetime import datetime + +from flask_restx import fields +from pydantic import BaseModel, ConfigDict, computed_field, field_validator + +from core.file import helpers as file_helpers simple_account_fields = { "id": fields.String, @@ -9,36 +14,78 @@ simple_account_fields = { } -def build_simple_account_model(api_or_ns: Namespace): - return api_or_ns.model("SimpleAccount", simple_account_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -account_fields = { - "id": fields.String, - "name": fields.String, - "avatar": fields.String, - "avatar_url": AvatarUrlField, - "email": fields.String, - "is_password_set": fields.Boolean, - "interface_language": fields.String, - "interface_theme": fields.String, - "timezone": fields.String, - "last_login_at": TimestampField, - "last_login_ip": fields.String, - "created_at": TimestampField, -} +def _build_avatar_url(avatar: str | None) -> str | None: + if avatar is None: + return None + if avatar.startswith(("http://", "https://")): + return avatar + return file_helpers.get_signed_file_url(avatar) -account_with_role_fields = { - "id": fields.String, - "name": fields.String, - "avatar": fields.String, - "avatar_url": AvatarUrlField, - "email": fields.String, - "last_login_at": TimestampField, - "last_active_at": TimestampField, - "created_at": TimestampField, - "role": fields.String, - "status": fields.String, -} -account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))} +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class SimpleAccount(ResponseModel): + id: str + name: str + email: str + + +class _AccountAvatar(ResponseModel): + avatar: str | None = None + + @computed_field(return_type=str | None) # type: ignore[prop-decorator] + @property + def avatar_url(self) -> str | None: + return _build_avatar_url(self.avatar) + + +class Account(_AccountAvatar): + id: str + name: str + email: str + is_password_set: bool + interface_language: str | None = None + interface_theme: str | None = None + timezone: str | None = None + last_login_at: int | None = None + last_login_ip: str | None = None + created_at: int | None = None + + @field_validator("last_login_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountWithRole(_AccountAvatar): + id: str + name: str + email: str + last_login_at: int | None = None + last_active_at: int | None = None + created_at: int | None = None + role: str + status: str + + @field_validator("last_login_at", "last_active_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountWithRoleList(ResponseModel): + accounts: list[AccountWithRole] diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index e359a4408c..7cb64e5ca8 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,12 +1,20 @@ -from flask_restx import Namespace, fields +from __future__ import annotations -dataset_tag_fields = { - "id": fields.String, - "name": fields.String, - "type": fields.String, - "binding_count": fields.String, -} +from pydantic import BaseModel, ConfigDict -def build_dataset_tag_fields(api_or_ns: Namespace): - return api_or_ns.model("DataSetTag", dataset_tag_fields) +class ResponseModel(BaseModel): + model_config = ConfigDict( + from_attributes=True, + extra="ignore", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + +class DataSetTag(ResponseModel): + id: str + name: str + type: str + binding_count: str | None = None diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index ae70356322..d0e762f62b 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,7 +1,7 @@ from flask_restx import Namespace, fields -from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields -from fields.member_fields import build_simple_account_model, simple_account_fields +from fields.end_user_fields import simple_end_user_fields +from fields.member_fields import simple_account_fields from fields.workflow_run_fields import ( build_workflow_run_for_archived_log_model, build_workflow_run_for_log_model, @@ -25,17 +25,9 @@ workflow_app_log_partial_fields = { def build_workflow_app_log_partial_model(api_or_ns: Namespace): """Build the workflow app log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_log_model(api_or_ns) - simple_account_model = build_simple_account_model(api_or_ns) - simple_end_user_model = build_simple_end_user_model(api_or_ns) copied_fields = workflow_app_log_partial_fields.copy() copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True) - copied_fields["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True - ) - copied_fields["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True - ) return api_or_ns.model("WorkflowAppLogPartial", copied_fields) @@ -52,17 +44,9 @@ workflow_archived_log_partial_fields = { def build_workflow_archived_log_partial_model(api_or_ns: Namespace): """Build the workflow archived log partial model for the API or Namespace.""" workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns) - simple_account_model = build_simple_account_model(api_or_ns) - simple_end_user_model = build_simple_end_user_model(api_or_ns) copied_fields = workflow_archived_log_partial_fields.copy() copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True) - copied_fields["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True - ) - copied_fields["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True - ) return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 56e9cc6a00..8ebc87a670 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -158,7 +158,7 @@ class AppAnnotationService: .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) ) annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False) - return annotations.items, annotations.total + return annotations.items, annotations.total or 0 @classmethod def export_annotation_list_by_app_id(cls, app_id: str): @@ -524,7 +524,7 @@ class AppAnnotationService: annotation_hit_histories = db.paginate( select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False ) - return annotation_hit_histories.items, annotation_hit_histories.total + return annotation_hit_histories.items, annotation_hit_histories.total or 0 @classmethod def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: From 8b50c0d9205857800d925288c6bdc0d582045d19 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 09:59:29 +0800 Subject: [PATCH 19/32] chore(deps-dev): bump types-psutil from 7.0.0.20251116 to 7.2.2.20260130 in /api (#31814) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- api/pyproject.toml | 2 +- api/uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 02d1aea21d..ab1f523267 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -145,7 +145,7 @@ dev = [ "types-openpyxl~=3.1.5", "types-pexpect~=4.9.0", "types-protobuf~=5.29.1", - "types-psutil~=7.0.0", + "types-psutil~=7.2.2", "types-psycopg2~=2.9.21", "types-pygments~=2.19.0", "types-pymysql~=1.1.0", diff --git a/api/uv.lock b/api/uv.lock index ad84b35212..f253976cc1 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1707,7 +1707,7 @@ dev = [ { name = "types-openpyxl", specifier = "~=3.1.5" }, { name = "types-pexpect", specifier = "~=4.9.0" }, { name = "types-protobuf", specifier = "~=5.29.1" }, - { name = "types-psutil", specifier = "~=7.0.0" }, + { name = "types-psutil", specifier = "~=7.2.2" }, { name = "types-psycopg2", specifier = "~=2.9.21" }, { name = "types-pygments", specifier = "~=2.19.0" }, { name = "types-pymysql", specifier = "~=1.1.0" }, @@ -6508,11 +6508,11 @@ wheels = [ [[package]] name = "types-psutil" -version = "7.0.0.20251116" +version = "7.2.2.20260130" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/47/ec/c1e9308b91582cad1d7e7d3007fd003ef45a62c2500f8219313df5fc3bba/types_psutil-7.0.0.20251116.tar.gz", hash = "sha256:92b5c78962e55ce1ed7b0189901a4409ece36ab9fd50c3029cca7e681c606c8a", size = 22192, upload-time = "2025-11-16T03:10:32.859Z" } +sdist = { url = "https://files.pythonhosted.org/packages/69/14/fc5fb0a6ddfadf68c27e254a02ececd4d5c7fdb0efcb7e7e917a183497fb/types_psutil-7.2.2.20260130.tar.gz", hash = "sha256:15b0ab69c52841cf9ce3c383e8480c620a4d13d6a8e22b16978ebddac5590950", size = 26535, upload-time = "2026-01-30T03:58:14.116Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/0e/11ba08a5375c21039ed5f8e6bba41e9452fb69f0e2f7ee05ed5cca2a2cdf/types_psutil-7.0.0.20251116-py3-none-any.whl", hash = "sha256:74c052de077c2024b85cd435e2cba971165fe92a5eace79cbeb821e776dbc047", size = 25376, upload-time = "2025-11-16T03:10:31.813Z" }, + { url = "https://files.pythonhosted.org/packages/17/d7/60974b7e31545d3768d1770c5fe6e093182c3bfd819429b33133ba6b3e89/types_psutil-7.2.2.20260130-py3-none-any.whl", hash = "sha256:15523a3caa7b3ff03ac7f9b78a6470a59f88f48df1d74a39e70e06d2a99107da", size = 32876, upload-time = "2026-01-30T03:58:13.172Z" }, ] [[package]] From b55c0ec4de805e49218040f9b928379172a9d948 Mon Sep 17 00:00:00 2001 From: Stephen Zhou <38493346+hyoban@users.noreply.github.com> Date: Tue, 3 Feb 2026 12:26:47 +0800 Subject: [PATCH 20/32] fix: revert "refactor: api/controllers/console/feature.py (test)" (#31850) --- api/controllers/console/feature.py | 94 +++--- .../console/test_fastopenapi_feature.py | 291 ------------------ 2 files changed, 48 insertions(+), 337 deletions(-) delete mode 100644 api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 1e98d622fe..d3811e2d1b 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,58 +1,60 @@ -from pydantic import BaseModel, Field +from flask_restx import Resource, fields from werkzeug.exceptions import Unauthorized -from controllers.fastopenapi import console_router from libs.login import current_account_with_tenant, current_user, login_required -from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel +from services.feature_service import FeatureService +from . import console_ns from .wraps import account_initialization_required, cloud_utm_record, setup_required -class FeatureResponse(BaseModel): - features: FeatureModel = Field(description="Feature configuration object") +@console_ns.route("/features") +class FeatureApi(Resource): + @console_ns.doc("get_tenant_features") + @console_ns.doc(description="Get feature configuration for current tenant") + @console_ns.response( + 200, + "Success", + console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}), + ) + @setup_required + @login_required + @account_initialization_required + @cloud_utm_record + def get(self): + """Get feature configuration for current tenant""" + _, current_tenant_id = current_account_with_tenant() + + return FeatureService.get_features(current_tenant_id).model_dump() -class SystemFeatureResponse(BaseModel): - features: SystemFeatureModel = Field(description="System feature configuration object") +@console_ns.route("/system-features") +class SystemFeatureApi(Resource): + @console_ns.doc("get_system_features") + @console_ns.doc(description="Get system-wide feature configuration") + @console_ns.response( + 200, + "Success", + console_ns.model( + "SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")} + ), + ) + def get(self): + """Get system-wide feature configuration + NOTE: This endpoint is unauthenticated by design, as it provides system features + data required for dashboard initialization. -@console_router.get( - "/features", - response_model=FeatureResponse, - tags=["console"], -) -@setup_required -@login_required -@account_initialization_required -@cloud_utm_record -def get_tenant_features() -> FeatureResponse: - """Get feature configuration for current tenant.""" - _, current_tenant_id = current_account_with_tenant() + Authentication would create circular dependency (can't login without dashboard loading). - return FeatureResponse(features=FeatureService.get_features(current_tenant_id)) - - -@console_router.get( - "/system-features", - response_model=SystemFeatureResponse, - tags=["console"], -) -def get_system_features() -> SystemFeatureResponse: - """Get system-wide feature configuration - - NOTE: This endpoint is unauthenticated by design, as it provides system features - data required for dashboard initialization. - - Authentication would create circular dependency (can't login without dashboard loading). - - Only non-sensitive configuration data should be returned by this endpoint. - """ - # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated` - # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request` - # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will - # raise `Unauthorized` exception if authentication token is not provided. - try: - is_authenticated = current_user.is_authenticated - except Unauthorized: - is_authenticated = False - return SystemFeatureResponse(features=FeatureService.get_system_features(is_authenticated=is_authenticated)) + Only non-sensitive configuration data should be returned by this endpoint. + """ + # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated` + # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request` + # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will + # raise `Unauthorized` exception if authentication token is not provided. + try: + is_authenticated = current_user.is_authenticated + except Unauthorized: + is_authenticated = False + return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump() diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py deleted file mode 100644 index 68495dd979..0000000000 --- a/api/tests/unit_tests/controllers/console/test_fastopenapi_feature.py +++ /dev/null @@ -1,291 +0,0 @@ -import builtins -import contextlib -import importlib -import sys -from unittest.mock import MagicMock, PropertyMock, patch - -import pytest -from flask import Flask -from flask.views import MethodView -from werkzeug.exceptions import Unauthorized - -from extensions import ext_fastopenapi -from extensions.ext_database import db -from services.feature_service import FeatureModel, SystemFeatureModel - - -@pytest.fixture -def app(): - """ - Creates a Flask application instance configured for testing. - """ - app = Flask(__name__) - app.config["TESTING"] = True - app.config["SECRET_KEY"] = "test-secret" - app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" - - # Initialize the database with the app - db.init_app(app) - - return app - - -@pytest.fixture(autouse=True) -def fix_method_view_issue(monkeypatch): - """ - Automatic fixture to patch 'builtins.MethodView'. - - Why this is needed: - The official legacy codebase contains a global patch in its initialization logic: - if not hasattr(builtins, "MethodView"): - builtins.MethodView = MethodView - - Some dependencies (like ext_fastopenapi or older Flask extensions) might implicitly - rely on 'MethodView' being available in the global builtins namespace. - - Refactoring Note: - While patching builtins is generally discouraged due to global side effects, - this fixture reproduces the production environment's state to ensure tests are realistic. - We use 'monkeypatch' to ensure that this change is undone after the test finishes, - keeping other tests isolated. - """ - if not hasattr(builtins, "MethodView"): - # 'raising=False' allows us to set an attribute that doesn't exist yet - monkeypatch.setattr(builtins, "MethodView", MethodView, raising=False) - - -# ------------------------------------------------------------------------------ -# Helper Functions for Fixture Complexity Reduction -# ------------------------------------------------------------------------------ - - -def _create_isolated_router(): - """ - Creates a fresh, isolated router instance to prevent route pollution. - """ - import controllers.fastopenapi - - # Dynamically get the class type (e.g., FlaskRouter) to avoid hardcoding dependencies - RouterClass = type(controllers.fastopenapi.console_router) - return RouterClass() - - -@contextlib.contextmanager -def _patch_auth_and_router(temp_router): - """ - Context manager that applies all necessary patches for: - 1. The console_router (redirecting to our isolated temp_router) - 2. Authentication decorators (disabling them with no-ops) - 3. User/Account loaders (mocking authenticated state) - """ - - def noop(f): - return f - - # We patch the SOURCE of the decorators/functions, not the destination module. - # This ensures that when 'controllers.console.feature' imports them, it gets the mocks. - with ( - patch("controllers.fastopenapi.console_router", temp_router), - patch("extensions.ext_fastopenapi.console_router", temp_router), - patch("controllers.console.wraps.setup_required", side_effect=noop), - patch("libs.login.login_required", side_effect=noop), - patch("controllers.console.wraps.account_initialization_required", side_effect=noop), - patch("controllers.console.wraps.cloud_utm_record", side_effect=noop), - patch("libs.login.current_account_with_tenant", return_value=(MagicMock(), "tenant-id")), - patch("libs.login.current_user", MagicMock(is_authenticated=True)), - ): - # Explicitly reload ext_fastopenapi to ensure it uses the patched console_router - import extensions.ext_fastopenapi - - importlib.reload(extensions.ext_fastopenapi) - - yield - - -def _force_reload_module(target_module: str, alias_module: str): - """ - Forces a reload of the specified module and handles sys.modules aliasing. - - Why reload? - Python decorators (like @route, @login_required) run at IMPORT time. - To apply our patches (mocks/no-ops) to these decorators, we must re-import - the module while the patches are active. - - Why alias? - If 'ext_fastopenapi' imports the controller as 'api.controllers...', but we import - it as 'controllers...', Python treats them as two separate modules. This causes: - 1. Double execution of decorators (registering routes twice -> AssertionError). - 2. Type mismatch errors (Class A from module X is not Class A from module Y). - - This function ensures both names point to the SAME loaded module instance. - """ - # 1. Clean existing entries to force re-import - if target_module in sys.modules: - del sys.modules[target_module] - if alias_module in sys.modules: - del sys.modules[alias_module] - - # 2. Import the module (triggering decorators with active patches) - module = importlib.import_module(target_module) - - # 3. Alias the module in sys.modules to prevent double loading - sys.modules[alias_module] = sys.modules[target_module] - - return module - - -def _cleanup_modules(target_module: str, alias_module: str): - """ - Removes the module and its alias from sys.modules to prevent side effects - on other tests. - """ - if target_module in sys.modules: - del sys.modules[target_module] - if alias_module in sys.modules: - del sys.modules[alias_module] - - -@pytest.fixture -def mock_feature_module_env(): - """ - Sets up a mocked environment for the feature module. - - This fixture orchestrates: - 1. Creating an isolated router. - 2. Patching authentication and global dependencies. - 3. Reloading the controller module to apply patches to decorators. - 4. cleaning up sys.modules afterwards. - """ - target_module = "controllers.console.feature" - alias_module = "api.controllers.console.feature" - - # 1. Prepare isolated router - temp_router = _create_isolated_router() - - # 2. Apply patches - try: - with _patch_auth_and_router(temp_router): - # 3. Reload module to register routes on the temp_router - feature_module = _force_reload_module(target_module, alias_module) - - yield feature_module - - finally: - # 4. Teardown: Clean up sys.modules - _cleanup_modules(target_module, alias_module) - - -# ------------------------------------------------------------------------------ -# Test Cases -# ------------------------------------------------------------------------------ - - -@pytest.mark.parametrize( - ("url", "service_mock_path", "mock_model_instance", "json_key"), - [ - ( - "/console/api/features", - "controllers.console.feature.FeatureService.get_features", - FeatureModel(can_replace_logo=True), - "features", - ), - ( - "/console/api/system-features", - "controllers.console.feature.FeatureService.get_system_features", - SystemFeatureModel(enable_marketplace=True), - "features", - ), - ], -) -def test_console_features_success(app, mock_feature_module_env, url, service_mock_path, mock_model_instance, json_key): - """ - Tests that the feature APIs return a 200 OK status and correct JSON structure. - """ - # Patch the service layer to return our mock model instance - with patch(service_mock_path, return_value=mock_model_instance): - # Initialize the API extension - ext_fastopenapi.init_app(app) - - client = app.test_client() - response = client.get(url) - - # Assertions - assert response.status_code == 200, f"Request failed with status {response.status_code}: {response.text}" - - # Verify the JSON response matches the Pydantic model dump - expected_data = mock_model_instance.model_dump(mode="json") - assert response.get_json() == {json_key: expected_data} - - -@pytest.mark.parametrize( - ("url", "service_mock_path"), - [ - ("/console/api/features", "controllers.console.feature.FeatureService.get_features"), - ("/console/api/system-features", "controllers.console.feature.FeatureService.get_system_features"), - ], -) -def test_console_features_service_error(app, mock_feature_module_env, url, service_mock_path): - """ - Tests how the application handles Service layer errors. - - Note: When an exception occurs in the view, it is typically caught by the framework - (Flask or the OpenAPI wrapper) and converted to a 500 error response. - This test verifies that the application returns a 500 status code. - """ - # Simulate a service failure - with patch(service_mock_path, side_effect=ValueError("Service Failure")): - ext_fastopenapi.init_app(app) - client = app.test_client() - - # When an exception occurs in the view, it is typically caught by the framework - # (Flask or the OpenAPI wrapper) and converted to a 500 error response. - response = client.get(url) - - assert response.status_code == 500 - # Check if the error details are exposed in the response (depends on error handler config) - # We accept either generic 500 or the specific error message - assert "Service Failure" in response.text or "Internal Server Error" in response.text - - -def test_system_features_unauthenticated(app, mock_feature_module_env): - """ - Tests that /console/api/system-features endpoint works without authentication. - - This test verifies the try-except block in get_system_features that handles - unauthenticated requests by passing is_authenticated=False to the service layer. - """ - feature_module = mock_feature_module_env - - # Override the behavior of the current_user mock - # The fixture patched 'libs.login.current_user', so 'controllers.console.feature.current_user' - # refers to that same Mock object. - mock_user = feature_module.current_user - - # Simulate property access raising Unauthorized - # Note: We must reset side_effect if it was set, or set it here. - # The fixture initialized it as MagicMock(is_authenticated=True). - # We want type(mock_user).is_authenticated to raise Unauthorized. - type(mock_user).is_authenticated = PropertyMock(side_effect=Unauthorized) - - # Patch the service layer for this specific test - with patch("controllers.console.feature.FeatureService.get_system_features") as mock_service: - # Setup mock service return value - mock_model = SystemFeatureModel(enable_marketplace=True) - mock_service.return_value = mock_model - - # Initialize app - ext_fastopenapi.init_app(app) - client = app.test_client() - - # Act - response = client.get("/console/api/system-features") - - # Assert - assert response.status_code == 200, f"Request failed: {response.text}" - - # Verify service was called with is_authenticated=False - mock_service.assert_called_once_with(is_authenticated=False) - - # Verify response body - expected_data = mock_model.model_dump(mode="json") - assert response.get_json() == {"features": expected_data} From aa7fe42615b7d0fd4a8fe638e8d57a5959a7f7a8 Mon Sep 17 00:00:00 2001 From: Coding On Star <447357187@qq.com> Date: Tue, 3 Feb 2026 13:47:30 +0800 Subject: [PATCH 21/32] test: enhance CommandSelector and GotoAnythingProvider tests (#31743) Co-authored-by: CodingOnStar --- .../app/create-app-modal/index.spec.tsx | 4 +- .../explore/create-app-modal/index.spec.tsx | 32 +- .../goto-anything/command-selector.spec.tsx | 201 ++++++ .../components/empty-state.spec.tsx | 157 +++++ .../goto-anything/components/empty-state.tsx | 105 ++++ .../goto-anything/components/footer.spec.tsx | 273 ++++++++ .../goto-anything/components/footer.tsx | 90 +++ .../goto-anything/components/index.ts | 14 + .../goto-anything/components/result-item.tsx | 38 ++ .../goto-anything/components/result-list.tsx | 49 ++ .../components/search-input.spec.tsx | 206 ++++++ .../goto-anything/components/search-input.tsx | 62 ++ .../components/goto-anything/context.spec.tsx | 77 ++- .../components/goto-anything/hooks/index.ts | 11 + .../hooks/use-goto-anything-modal.spec.ts | 291 +++++++++ .../hooks/use-goto-anything-modal.ts | 59 ++ .../use-goto-anything-navigation.spec.ts | 391 ++++++++++++ .../hooks/use-goto-anything-navigation.ts | 96 +++ .../hooks/use-goto-anything-results.spec.ts | 354 +++++++++++ .../hooks/use-goto-anything-results.ts | 115 ++++ .../hooks/use-goto-anything-search.spec.ts | 301 +++++++++ .../hooks/use-goto-anything-search.ts | 77 +++ .../components/goto-anything/index.spec.tsx | 581 +++++++++++++++-- web/app/components/goto-anything/index.tsx | 585 +++++------------- .../workflow-onboarding-modal/index.spec.tsx | 4 +- web/eslint-suppressions.json | 10 - 26 files changed, 3666 insertions(+), 517 deletions(-) create mode 100644 web/app/components/goto-anything/components/empty-state.spec.tsx create mode 100644 web/app/components/goto-anything/components/empty-state.tsx create mode 100644 web/app/components/goto-anything/components/footer.spec.tsx create mode 100644 web/app/components/goto-anything/components/footer.tsx create mode 100644 web/app/components/goto-anything/components/index.ts create mode 100644 web/app/components/goto-anything/components/result-item.tsx create mode 100644 web/app/components/goto-anything/components/result-list.tsx create mode 100644 web/app/components/goto-anything/components/search-input.spec.tsx create mode 100644 web/app/components/goto-anything/components/search-input.tsx create mode 100644 web/app/components/goto-anything/hooks/index.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-modal.spec.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-modal.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-navigation.spec.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-navigation.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-results.spec.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-results.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-search.spec.ts create mode 100644 web/app/components/goto-anything/hooks/use-goto-anything-search.ts diff --git a/web/app/components/app/create-app-modal/index.spec.tsx b/web/app/components/app/create-app-modal/index.spec.tsx index cb8f4db67f..d26a581fda 100644 --- a/web/app/components/app/create-app-modal/index.spec.tsx +++ b/web/app/components/app/create-app-modal/index.spec.tsx @@ -124,7 +124,7 @@ describe('CreateAppModal', () => { const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') fireEvent.change(nameInput, { target: { value: 'My App' } }) - fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' })) + fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) await waitFor(() => expect(mockCreateApp).toHaveBeenCalledWith({ name: 'My App', @@ -152,7 +152,7 @@ describe('CreateAppModal', () => { const nameInput = screen.getByPlaceholderText('app.newApp.appNamePlaceholder') fireEvent.change(nameInput, { target: { value: 'My App' } }) - fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Create' })) + fireEvent.click(screen.getByRole('button', { name: /app\.newApp\.Create/ })) await waitFor(() => expect(mockCreateApp).toHaveBeenCalled()) expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'boom' }) diff --git a/web/app/components/explore/create-app-modal/index.spec.tsx b/web/app/components/explore/create-app-modal/index.spec.tsx index 7ddb5a9082..65ec0e6096 100644 --- a/web/app/components/explore/create-app-modal/index.spec.tsx +++ b/web/app/components/explore/create-app-modal/index.spec.tsx @@ -138,7 +138,7 @@ describe('CreateAppModal', () => { setup({ appName: 'My App', isEditModal: false }) expect(screen.getByText('explore.appCustomize.title:{"name":"My App"}')).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeInTheDocument() expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument() }) @@ -146,7 +146,7 @@ describe('CreateAppModal', () => { setup({ isEditModal: true, appMode: AppModeEnum.CHAT, max_active_requests: 5 }) expect(screen.getByText('app.editAppTitle')).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeInTheDocument() expect(screen.getByRole('switch')).toBeInTheDocument() expect((screen.getByRole('spinbutton') as HTMLInputElement).value).toBe('5') }) @@ -166,7 +166,7 @@ describe('CreateAppModal', () => { it('should not render modal content when hidden', () => { setup({ show: false }) - expect(screen.queryByRole('button', { name: 'common.operation.create' })).not.toBeInTheDocument() + expect(screen.queryByRole('button', { name: /common\.operation\.create/ })).not.toBeInTheDocument() }) }) @@ -175,13 +175,13 @@ describe('CreateAppModal', () => { it('should disable confirm action when confirmDisabled is true', () => { setup({ confirmDisabled: true }) - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled() }) it('should disable confirm action when appName is empty', () => { setup({ appName: ' ' }) - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled() }) }) @@ -245,7 +245,7 @@ describe('CreateAppModal', () => { setup({ isEditModal: false }) expect(screen.getByText('billing.apps.fullTip2')).toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.create' })).toBeDisabled() + expect(screen.getByRole('button', { name: /common\.operation\.create/ })).toBeDisabled() }) it('should allow saving when apps quota is reached in edit mode', () => { @@ -257,7 +257,7 @@ describe('CreateAppModal', () => { setup({ isEditModal: true }) expect(screen.queryByText('billing.apps.fullTip2')).not.toBeInTheDocument() - expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeEnabled() + expect(screen.getByRole('button', { name: /common\.operation\.save/ })).toBeEnabled() }) }) @@ -384,7 +384,7 @@ describe('CreateAppModal', () => { fireEvent.click(screen.getByRole('button', { name: 'app.iconPicker.ok' })) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -433,7 +433,7 @@ describe('CreateAppModal', () => { expect(screen.queryByRole('button', { name: 'app.iconPicker.cancel' })).not.toBeInTheDocument() // Submit and verify the payload uses the original icon (cancel reverts to props) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -471,7 +471,7 @@ describe('CreateAppModal', () => { appIconBackground: '#000000', }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -495,7 +495,7 @@ describe('CreateAppModal', () => { const { onConfirm } = setup({ appDescription: 'Old description' }) fireEvent.change(screen.getByPlaceholderText('app.newApp.appDescriptionPlaceholder'), { target: { value: 'Updated description' } }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -512,7 +512,7 @@ describe('CreateAppModal', () => { appIconBackground: null, }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -536,7 +536,7 @@ describe('CreateAppModal', () => { fireEvent.click(screen.getByRole('switch')) fireEvent.change(screen.getByRole('spinbutton'), { target: { value: '12' } }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -551,7 +551,7 @@ describe('CreateAppModal', () => { it('should omit max_active_requests when input is empty', () => { const { onConfirm } = setup({ isEditModal: true, max_active_requests: null }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -564,7 +564,7 @@ describe('CreateAppModal', () => { const { onConfirm } = setup({ isEditModal: true, max_active_requests: null }) fireEvent.change(screen.getByRole('spinbutton'), { target: { value: 'abc' } }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.save' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.save/ })) act(() => { vi.advanceTimersByTime(300) }) @@ -576,7 +576,7 @@ describe('CreateAppModal', () => { it('should show toast error and not submit when name becomes empty before debounced submit runs', () => { const { onConfirm, onHide } = setup({ appName: 'My App' }) - fireEvent.click(screen.getByRole('button', { name: 'common.operation.create' })) + fireEvent.click(screen.getByRole('button', { name: /common\.operation\.create/ })) fireEvent.change(screen.getByPlaceholderText('app.newApp.appNamePlaceholder'), { target: { value: ' ' } }) act(() => { diff --git a/web/app/components/goto-anything/command-selector.spec.tsx b/web/app/components/goto-anything/command-selector.spec.tsx index 0ee2086058..0712a1afd6 100644 --- a/web/app/components/goto-anything/command-selector.spec.tsx +++ b/web/app/components/goto-anything/command-selector.spec.tsx @@ -81,4 +81,205 @@ describe('CommandSelector', () => { expect(onSelect).toHaveBeenCalledWith('/zen') }) + + it('should show all slash commands when no filter provided', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + // Should show the zen command from mock + expect(screen.getByText('/zen')).toBeInTheDocument() + }) + + it('should exclude slash action when in @ mode', () => { + const actions = { + ...createActions(), + slash: { + key: '/', + shortcut: '/', + title: 'Slash', + search: vi.fn(), + description: '', + } as ActionItem, + } + const onSelect = vi.fn() + + render( + + + , + ) + + // Should show @ commands but not / + expect(screen.getByText('@app')).toBeInTheDocument() + expect(screen.queryByText('/')).not.toBeInTheDocument() + }) + + it('should show all actions when no filter in @ mode', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('@app')).toBeInTheDocument() + expect(screen.getByText('@plugin')).toBeInTheDocument() + }) + + it('should set default command value when items exist but value does not', () => { + const actions = createActions() + const onSelect = vi.fn() + const onCommandValueChange = vi.fn() + + render( + + + , + ) + + expect(onCommandValueChange).toHaveBeenCalledWith('@app') + }) + + it('should NOT set command value when value already exists in items', () => { + const actions = createActions() + const onSelect = vi.fn() + const onCommandValueChange = vi.fn() + + render( + + + , + ) + + expect(onCommandValueChange).not.toHaveBeenCalled() + }) + + it('should show no matching commands message when filter has no results', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument() + }) + + it('should show no matching commands for slash mode with no results', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument() + }) + + it('should render description for @ commands', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.actions.searchPluginsDesc')).toBeInTheDocument() + }) + + it('should render group header for @ mode', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('app.gotoAnything.selectSearchType')).toBeInTheDocument() + }) + + it('should render group header for slash mode', () => { + const actions = createActions() + const onSelect = vi.fn() + + render( + + + , + ) + + expect(screen.getByText('app.gotoAnything.groups.commands')).toBeInTheDocument() + }) }) diff --git a/web/app/components/goto-anything/components/empty-state.spec.tsx b/web/app/components/goto-anything/components/empty-state.spec.tsx new file mode 100644 index 0000000000..e1e5e0dc89 --- /dev/null +++ b/web/app/components/goto-anything/components/empty-state.spec.tsx @@ -0,0 +1,157 @@ +import { render, screen } from '@testing-library/react' +import EmptyState from './empty-state' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string, shortcuts?: string }) => { + if (options?.shortcuts !== undefined) + return `${key}:${options.shortcuts}` + return `${options?.ns || 'common'}.${key}` + }, + }), +})) + +describe('EmptyState', () => { + describe('loading variant', () => { + it('should render loading spinner', () => { + render() + + expect(screen.getByText('app.gotoAnything.searching')).toBeInTheDocument() + }) + + it('should have spinner animation class', () => { + const { container } = render() + + const spinner = container.querySelector('.animate-spin') + expect(spinner).toBeInTheDocument() + }) + }) + + describe('error variant', () => { + it('should render error message when error has message', () => { + const error = new Error('Connection failed') + render() + + expect(screen.getByText('app.gotoAnything.searchFailed')).toBeInTheDocument() + expect(screen.getByText('Connection failed')).toBeInTheDocument() + }) + + it('should render generic error when error has no message', () => { + render() + + expect(screen.getByText('app.gotoAnything.searchTemporarilyUnavailable')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.servicesUnavailableMessage')).toBeInTheDocument() + }) + + it('should render generic error when error is undefined', () => { + render() + + expect(screen.getByText('app.gotoAnything.searchTemporarilyUnavailable')).toBeInTheDocument() + }) + + it('should have red error text styling', () => { + const error = new Error('Test error') + const { container } = render() + + const errorText = container.querySelector('.text-red-500') + expect(errorText).toBeInTheDocument() + }) + }) + + describe('default variant', () => { + it('should render search title', () => { + render() + + expect(screen.getByText('app.gotoAnything.searchTitle')).toBeInTheDocument() + }) + + it('should render all hint messages', () => { + render() + + expect(screen.getByText('app.gotoAnything.searchHint')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.commandHint')).toBeInTheDocument() + expect(screen.getByText('app.gotoAnything.slashHint')).toBeInTheDocument() + }) + }) + + describe('no-results variant', () => { + describe('general search mode', () => { + it('should render generic no results message', () => { + render() + + expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument() + }) + + it('should show specific search hint with shortcuts', () => { + const Actions = { + app: { key: '@app', shortcut: '@app' }, + plugin: { key: '@plugin', shortcut: '@plugin' }, + } as unknown as Record + render() + + expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:@app, @plugin')).toBeInTheDocument() + }) + }) + + describe('app search mode', () => { + it('should render no apps found message', () => { + render() + + expect(screen.getByText('app.gotoAnything.emptyState.noAppsFound')).toBeInTheDocument() + }) + + it('should show try different term hint', () => { + render() + + expect(screen.getByText('app.gotoAnything.emptyState.tryDifferentTerm')).toBeInTheDocument() + }) + }) + + describe('plugin search mode', () => { + it('should render no plugins found message', () => { + render() + + expect(screen.getByText('app.gotoAnything.emptyState.noPluginsFound')).toBeInTheDocument() + }) + }) + + describe('knowledge search mode', () => { + it('should render no knowledge bases found message', () => { + render() + + expect(screen.getByText('app.gotoAnything.emptyState.noKnowledgeBasesFound')).toBeInTheDocument() + }) + }) + + describe('node search mode', () => { + it('should render no workflow nodes found message', () => { + render() + + expect(screen.getByText('app.gotoAnything.emptyState.noWorkflowNodesFound')).toBeInTheDocument() + }) + }) + + describe('unknown search mode', () => { + it('should fallback to generic no results message', () => { + render() + + expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument() + }) + }) + }) + + describe('default props', () => { + it('should use general as default searchMode', () => { + render() + + expect(screen.getByText('app.gotoAnything.noResults')).toBeInTheDocument() + }) + + it('should use empty object as default Actions', () => { + render() + + // Should show empty shortcuts + expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:')).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/goto-anything/components/empty-state.tsx b/web/app/components/goto-anything/components/empty-state.tsx new file mode 100644 index 0000000000..a07bc1d45a --- /dev/null +++ b/web/app/components/goto-anything/components/empty-state.tsx @@ -0,0 +1,105 @@ +'use client' + +import type { FC } from 'react' +import type { ActionItem } from '../actions/types' +import { useTranslation } from 'react-i18next' + +export type EmptyStateVariant = 'no-results' | 'error' | 'default' | 'loading' + +export type EmptyStateProps = { + variant: EmptyStateVariant + searchMode?: string + error?: Error | null + Actions?: Record +} + +const EmptyState: FC = ({ + variant, + searchMode = 'general', + error, + Actions = {}, +}) => { + const { t } = useTranslation() + + if (variant === 'loading') { + return ( +
+
+
+ {t('gotoAnything.searching', { ns: 'app' })} +
+
+ ) + } + + if (variant === 'error') { + return ( +
+
+
+ {error?.message + ? t('gotoAnything.searchFailed', { ns: 'app' }) + : t('gotoAnything.searchTemporarilyUnavailable', { ns: 'app' })} +
+
+ {error?.message || t('gotoAnything.servicesUnavailableMessage', { ns: 'app' })} +
+
+
+ ) + } + + if (variant === 'default') { + return ( +
+
+
{t('gotoAnything.searchTitle', { ns: 'app' })}
+
+
{t('gotoAnything.searchHint', { ns: 'app' })}
+
{t('gotoAnything.commandHint', { ns: 'app' })}
+
{t('gotoAnything.slashHint', { ns: 'app' })}
+
+
+
+ ) + } + + // variant === 'no-results' + const isCommandSearch = searchMode !== 'general' + const commandType = isCommandSearch ? searchMode.replace('@', '') : '' + + const getNoResultsMessage = () => { + if (!isCommandSearch) { + return t('gotoAnything.noResults', { ns: 'app' }) + } + + const keyMap = { + app: 'gotoAnything.emptyState.noAppsFound', + plugin: 'gotoAnything.emptyState.noPluginsFound', + knowledge: 'gotoAnything.emptyState.noKnowledgeBasesFound', + node: 'gotoAnything.emptyState.noWorkflowNodesFound', + } as const + + return t(keyMap[commandType as keyof typeof keyMap] || 'gotoAnything.noResults', { ns: 'app' }) + } + + const getHintMessage = () => { + if (isCommandSearch) { + return t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' }) + } + + const shortcuts = Object.values(Actions).map(action => action.shortcut).join(', ') + return t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts }) + } + + return ( +
+
+
{getNoResultsMessage()}
+
{getHintMessage()}
+
+
+ ) +} + +export default EmptyState diff --git a/web/app/components/goto-anything/components/footer.spec.tsx b/web/app/components/goto-anything/components/footer.spec.tsx new file mode 100644 index 0000000000..3dfac5f71c --- /dev/null +++ b/web/app/components/goto-anything/components/footer.spec.tsx @@ -0,0 +1,273 @@ +import { render, screen } from '@testing-library/react' +import Footer from './footer' + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string, count?: number, scope?: string }) => { + if (options?.count !== undefined) + return `${key}:${options.count}` + if (options?.scope) + return `${key}:${options.scope}` + return `${options?.ns || 'common'}.${key}` + }, + }), +})) + +describe('Footer', () => { + describe('left content', () => { + describe('when there are results', () => { + it('should show result count', () => { + render( +