From 0a448a13c8b3b062063a5704b2737729e37af62b Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Wed, 24 Dec 2025 10:41:42 +0900 Subject: [PATCH] refactor: split changes for api/controllers/console/extension.py (#29888) --- api/controllers/console/extension.py | 75 +++--- .../controllers/console/test_extension.py | 236 ++++++++++++++++++ 2 files changed, 271 insertions(+), 40 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/test_extension.py diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 08f29b4655..efa46c9779 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,14 +1,32 @@ -from flask_restx import Resource, fields, marshal_with, reqparse +from flask import request +from flask_restx import Resource, fields, marshal_with +from pydantic import BaseModel, Field from constants import HIDDEN_VALUE -from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import current_account_with_tenant, login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService +from ..common.schema import register_schema_models +from . import console_ns +from .wraps import account_initialization_required, setup_required + + +class CodeBasedExtensionQuery(BaseModel): + module: str + + +class APIBasedExtensionPayload(BaseModel): + name: str = Field(description="Extension name") + api_endpoint: str = Field(description="API endpoint URL") + api_key: str = Field(description="API key for authentication") + + +register_schema_models(console_ns, APIBasedExtensionPayload) + + api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields) api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model)) @@ -18,11 +36,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m class CodeBasedExtensionAPI(Resource): @console_ns.doc("get_code_based_extension") @console_ns.doc(description="Get code-based extension data by module name") - @console_ns.expect( - console_ns.parser().add_argument( - "module", type=str, required=True, location="args", help="Extension module name" - ) - ) + @console_ns.doc(params={"module": "Extension module name"}) @console_ns.response( 200, "Success", @@ -35,10 +49,9 @@ class CodeBasedExtensionAPI(Resource): @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args") - args = parser.parse_args() + query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} + return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)} @console_ns.route("/api-based-extension") @@ -56,30 +69,21 @@ class APIBasedExtensionAPI(Resource): @console_ns.doc("create_api_based_extension") @console_ns.doc(description="Create a new API-based extension") - @console_ns.expect( - console_ns.model( - "CreateAPIBasedExtensionRequest", - { - "name": fields.String(required=True, description="Extension name"), - "api_endpoint": fields.String(required=True, description="API endpoint URL"), - "api_key": fields.String(required=True, description="API key for authentication"), - }, - ) - ) + @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__]) @console_ns.response(201, "Extension created successfully", api_based_extension_model) @setup_required @login_required @account_initialization_required @marshal_with(api_based_extension_model) def post(self): - args = console_ns.payload + payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() extension_data = APIBasedExtension( tenant_id=current_tenant_id, - name=args["name"], - api_endpoint=args["api_endpoint"], - api_key=args["api_key"], + name=payload.name, + api_endpoint=payload.api_endpoint, + api_key=payload.api_key, ) return APIBasedExtensionService.save(extension_data) @@ -104,16 +108,7 @@ class APIBasedExtensionDetailAPI(Resource): @console_ns.doc("update_api_based_extension") @console_ns.doc(description="Update API-based extension") @console_ns.doc(params={"id": "Extension ID"}) - @console_ns.expect( - console_ns.model( - "UpdateAPIBasedExtensionRequest", - { - "name": fields.String(required=True, description="Extension name"), - "api_endpoint": fields.String(required=True, description="API endpoint URL"), - "api_key": fields.String(required=True, description="API key for authentication"), - }, - ) - ) + @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__]) @console_ns.response(200, "Extension updated successfully", api_based_extension_model) @setup_required @login_required @@ -125,13 +120,13 @@ class APIBasedExtensionDetailAPI(Resource): extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) - args = console_ns.payload + payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {}) - extension_data_from_db.name = args["name"] - extension_data_from_db.api_endpoint = args["api_endpoint"] + extension_data_from_db.name = payload.name + extension_data_from_db.api_endpoint = payload.api_endpoint - if args["api_key"] != HIDDEN_VALUE: - extension_data_from_db.api_key = args["api_key"] + if payload.api_key != HIDDEN_VALUE: + extension_data_from_db.api_key = payload.api_key return APIBasedExtensionService.save(extension_data_from_db) diff --git a/api/tests/unit_tests/controllers/console/test_extension.py b/api/tests/unit_tests/controllers/console/test_extension.py new file mode 100644 index 0000000000..32b41baa27 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_extension.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import builtins +import uuid +from datetime import UTC, datetime +from unittest.mock import MagicMock + +import pytest +from flask import Flask +from flask.views import MethodView as FlaskMethodView + +_NEEDS_METHOD_VIEW_CLEANUP = False +if not hasattr(builtins, "MethodView"): + builtins.MethodView = FlaskMethodView + _NEEDS_METHOD_VIEW_CLEANUP = True + +from constants import HIDDEN_VALUE +from controllers.console.extension import ( + APIBasedExtensionAPI, + APIBasedExtensionDetailAPI, + CodeBasedExtensionAPI, +) + +if _NEEDS_METHOD_VIEW_CLEANUP: + delattr(builtins, "MethodView") +from models.account import AccountStatus +from models.api_based_extension import APIBasedExtension + + +def _make_extension( + *, + name: str = "Sample Extension", + api_endpoint: str = "https://example.com/api", + api_key: str = "super-secret-key", +) -> APIBasedExtension: + extension = APIBasedExtension( + tenant_id="tenant-123", + name=name, + api_endpoint=api_endpoint, + api_key=api_key, + ) + extension.id = f"{uuid.uuid4()}" + extension.created_at = datetime.now(tz=UTC) + return extension + + +@pytest.fixture(autouse=True) +def _mock_console_guards(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + """Bypass console decorators so handlers can run in isolation.""" + + import controllers.console.extension as extension_module + from controllers.console import wraps as wraps_module + + account = MagicMock() + account.status = AccountStatus.ACTIVE + account.current_tenant_id = "tenant-123" + account.id = "account-123" + account.is_authenticated = True + + monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD") + monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True) + monkeypatch.delenv("INIT_PASSWORD", raising=False) + monkeypatch.setattr(extension_module, "current_account_with_tenant", lambda: (account, "tenant-123")) + monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (account, "tenant-123")) + + # The login_required decorator consults the shared LocalProxy in libs.login. + monkeypatch.setattr("libs.login.current_user", account) + monkeypatch.setattr("libs.login.check_csrf_token", lambda *_, **__: None) + + return account + + +@pytest.fixture(autouse=True) +def _restx_mask_defaults(app: Flask): + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + app.config.setdefault("RESTX_MASK_SWAGGER", False) + + +def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch): + service_result = {"entrypoint": "main:agent"} + service_mock = MagicMock(return_value=service_result) + monkeypatch.setattr( + "controllers.console.extension.CodeBasedExtensionService.get_code_based_extension", + service_mock, + ) + + with app.test_request_context( + "/console/api/code-based-extension", + method="GET", + query_string={"module": "workflow.tools"}, + ): + response = CodeBasedExtensionAPI().get() + + assert response == {"module": "workflow.tools", "data": service_result} + service_mock.assert_called_once_with("workflow.tools") + + +def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypatch: pytest.MonkeyPatch): + extension = _make_extension(name="Weather API", api_key="abcdefghi123") + service_mock = MagicMock(return_value=[extension]) + monkeypatch.setattr( + "controllers.console.extension.APIBasedExtensionService.get_all_by_tenant_id", + service_mock, + ) + + with app.test_request_context("/console/api/api-based-extension", method="GET"): + response = APIBasedExtensionAPI().get() + + assert response[0]["id"] == extension.id + assert response[0]["name"] == "Weather API" + assert response[0]["api_endpoint"] == extension.api_endpoint + assert response[0]["api_key"].startswith(extension.api_key[:3]) + service_mock.assert_called_once_with("tenant-123") + + +def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch): + saved_extension = _make_extension(name="Docs API", api_key="saved-secret") + save_mock = MagicMock(return_value=saved_extension) + monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock) + + payload = { + "name": "Docs API", + "api_endpoint": "https://docs.example.com/hook", + "api_key": "plain-secret", + } + + with app.test_request_context("/console/api/api-based-extension", method="POST", json=payload): + response = APIBasedExtensionAPI().post() + + args, _ = save_mock.call_args + created_extension: APIBasedExtension = args[0] + assert created_extension.tenant_id == "tenant-123" + assert created_extension.name == payload["name"] + assert created_extension.api_endpoint == payload["api_endpoint"] + assert created_extension.api_key == payload["api_key"] + assert response["name"] == saved_extension.name + save_mock.assert_called_once() + + +def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatch: pytest.MonkeyPatch): + extension = _make_extension(name="Docs API", api_key="abcdefg12345") + service_mock = MagicMock(return_value=extension) + monkeypatch.setattr( + "controllers.console.extension.APIBasedExtensionService.get_with_tenant_id", + service_mock, + ) + + extension_id = uuid.uuid4() + with app.test_request_context(f"/console/api/api-based-extension/{extension_id}", method="GET"): + response = APIBasedExtensionDetailAPI().get(extension_id) + + assert response["id"] == extension.id + assert response["name"] == extension.name + service_mock.assert_called_once_with("tenant-123", str(extension_id)) + + +def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch): + existing_extension = _make_extension(name="Docs API", api_key="keep-me") + get_mock = MagicMock(return_value=existing_extension) + save_mock = MagicMock(return_value=existing_extension) + monkeypatch.setattr( + "controllers.console.extension.APIBasedExtensionService.get_with_tenant_id", + get_mock, + ) + monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock) + + payload = { + "name": "Docs API Updated", + "api_endpoint": "https://docs.example.com/v2", + "api_key": HIDDEN_VALUE, + } + + extension_id = uuid.uuid4() + with app.test_request_context( + f"/console/api/api-based-extension/{extension_id}", + method="POST", + json=payload, + ): + response = APIBasedExtensionDetailAPI().post(extension_id) + + assert existing_extension.name == payload["name"] + assert existing_extension.api_endpoint == payload["api_endpoint"] + assert existing_extension.api_key == "keep-me" + save_mock.assert_called_once_with(existing_extension) + assert response["name"] == payload["name"] + + +def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flask, monkeypatch: pytest.MonkeyPatch): + existing_extension = _make_extension(name="Docs API", api_key="old-secret") + get_mock = MagicMock(return_value=existing_extension) + save_mock = MagicMock(return_value=existing_extension) + monkeypatch.setattr( + "controllers.console.extension.APIBasedExtensionService.get_with_tenant_id", + get_mock, + ) + monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock) + + payload = { + "name": "Docs API Updated", + "api_endpoint": "https://docs.example.com/v2", + "api_key": "new-secret", + } + + extension_id = uuid.uuid4() + with app.test_request_context( + f"/console/api/api-based-extension/{extension_id}", + method="POST", + json=payload, + ): + response = APIBasedExtensionDetailAPI().post(extension_id) + + assert existing_extension.api_key == "new-secret" + save_mock.assert_called_once_with(existing_extension) + assert response["name"] == payload["name"] + + +def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeypatch: pytest.MonkeyPatch): + existing_extension = _make_extension() + get_mock = MagicMock(return_value=existing_extension) + delete_mock = MagicMock() + monkeypatch.setattr( + "controllers.console.extension.APIBasedExtensionService.get_with_tenant_id", + get_mock, + ) + monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.delete", delete_mock) + + extension_id = uuid.uuid4() + with app.test_request_context( + f"/console/api/api-based-extension/{extension_id}", + method="DELETE", + ): + response, status = APIBasedExtensionDetailAPI().delete(extension_id) + + delete_mock.assert_called_once_with(existing_extension) + assert response == {"result": "success"} + assert status == 204