mirror of https://github.com/langgenius/dify.git
refactor: split changes for api/controllers/console/extension.py (#29888)
This commit is contained in:
parent
111a39b549
commit
0a448a13c8
|
|
@ -1,14 +1,32 @@
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask import request
|
||||||
|
from flask_restx import Resource, fields, marshal_with
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from controllers.console import console_ns
|
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
|
||||||
from fields.api_based_extension_fields import api_based_extension_fields
|
from fields.api_based_extension_fields import api_based_extension_fields
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.api_based_extension import APIBasedExtension
|
from models.api_based_extension import APIBasedExtension
|
||||||
from services.api_based_extension_service import APIBasedExtensionService
|
from services.api_based_extension_service import APIBasedExtensionService
|
||||||
from services.code_based_extension_service import CodeBasedExtensionService
|
from services.code_based_extension_service import CodeBasedExtensionService
|
||||||
|
|
||||||
|
from ..common.schema import register_schema_models
|
||||||
|
from . import console_ns
|
||||||
|
from .wraps import account_initialization_required, setup_required
|
||||||
|
|
||||||
|
|
||||||
|
class CodeBasedExtensionQuery(BaseModel):
|
||||||
|
module: str
|
||||||
|
|
||||||
|
|
||||||
|
class APIBasedExtensionPayload(BaseModel):
|
||||||
|
name: str = Field(description="Extension name")
|
||||||
|
api_endpoint: str = Field(description="API endpoint URL")
|
||||||
|
api_key: str = Field(description="API key for authentication")
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, APIBasedExtensionPayload)
|
||||||
|
|
||||||
|
|
||||||
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
|
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
|
||||||
|
|
||||||
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
|
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
|
||||||
|
|
@ -18,11 +36,7 @@ api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_m
|
||||||
class CodeBasedExtensionAPI(Resource):
|
class CodeBasedExtensionAPI(Resource):
|
||||||
@console_ns.doc("get_code_based_extension")
|
@console_ns.doc("get_code_based_extension")
|
||||||
@console_ns.doc(description="Get code-based extension data by module name")
|
@console_ns.doc(description="Get code-based extension data by module name")
|
||||||
@console_ns.expect(
|
@console_ns.doc(params={"module": "Extension module name"})
|
||||||
console_ns.parser().add_argument(
|
|
||||||
"module", type=str, required=True, location="args", help="Extension module name"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(
|
@console_ns.response(
|
||||||
200,
|
200,
|
||||||
"Success",
|
"Success",
|
||||||
|
|
@ -35,10 +49,9 @@ class CodeBasedExtensionAPI(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
|
query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/api-based-extension")
|
@console_ns.route("/api-based-extension")
|
||||||
|
|
@ -56,30 +69,21 @@ class APIBasedExtensionAPI(Resource):
|
||||||
|
|
||||||
@console_ns.doc("create_api_based_extension")
|
@console_ns.doc("create_api_based_extension")
|
||||||
@console_ns.doc(description="Create a new API-based extension")
|
@console_ns.doc(description="Create a new API-based extension")
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"CreateAPIBasedExtensionRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(required=True, description="Extension name"),
|
|
||||||
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
|
|
||||||
"api_key": fields.String(required=True, description="API key for authentication"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
|
@console_ns.response(201, "Extension created successfully", api_based_extension_model)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_model)
|
@marshal_with(api_based_extension_model)
|
||||||
def post(self):
|
def post(self):
|
||||||
args = console_ns.payload
|
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
extension_data = APIBasedExtension(
|
extension_data = APIBasedExtension(
|
||||||
tenant_id=current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
name=args["name"],
|
name=payload.name,
|
||||||
api_endpoint=args["api_endpoint"],
|
api_endpoint=payload.api_endpoint,
|
||||||
api_key=args["api_key"],
|
api_key=payload.api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
return APIBasedExtensionService.save(extension_data)
|
return APIBasedExtensionService.save(extension_data)
|
||||||
|
|
@ -104,16 +108,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||||
@console_ns.doc("update_api_based_extension")
|
@console_ns.doc("update_api_based_extension")
|
||||||
@console_ns.doc(description="Update API-based extension")
|
@console_ns.doc(description="Update API-based extension")
|
||||||
@console_ns.doc(params={"id": "Extension ID"})
|
@console_ns.doc(params={"id": "Extension ID"})
|
||||||
@console_ns.expect(
|
@console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__])
|
||||||
console_ns.model(
|
|
||||||
"UpdateAPIBasedExtensionRequest",
|
|
||||||
{
|
|
||||||
"name": fields.String(required=True, description="Extension name"),
|
|
||||||
"api_endpoint": fields.String(required=True, description="API endpoint URL"),
|
|
||||||
"api_key": fields.String(required=True, description="API key for authentication"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
|
@console_ns.response(200, "Extension updated successfully", api_based_extension_model)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
|
|
@ -125,13 +120,13 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||||
|
|
||||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
args = console_ns.payload
|
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||||
|
|
||||||
extension_data_from_db.name = args["name"]
|
extension_data_from_db.name = payload.name
|
||||||
extension_data_from_db.api_endpoint = args["api_endpoint"]
|
extension_data_from_db.api_endpoint = payload.api_endpoint
|
||||||
|
|
||||||
if args["api_key"] != HIDDEN_VALUE:
|
if payload.api_key != HIDDEN_VALUE:
|
||||||
extension_data_from_db.api_key = args["api_key"]
|
extension_data_from_db.api_key = payload.api_key
|
||||||
|
|
||||||
return APIBasedExtensionService.save(extension_data_from_db)
|
return APIBasedExtensionService.save(extension_data_from_db)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,236 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
from flask.views import MethodView as FlaskMethodView
|
||||||
|
|
||||||
|
_NEEDS_METHOD_VIEW_CLEANUP = False
|
||||||
|
if not hasattr(builtins, "MethodView"):
|
||||||
|
builtins.MethodView = FlaskMethodView
|
||||||
|
_NEEDS_METHOD_VIEW_CLEANUP = True
|
||||||
|
|
||||||
|
from constants import HIDDEN_VALUE
|
||||||
|
from controllers.console.extension import (
|
||||||
|
APIBasedExtensionAPI,
|
||||||
|
APIBasedExtensionDetailAPI,
|
||||||
|
CodeBasedExtensionAPI,
|
||||||
|
)
|
||||||
|
|
||||||
|
if _NEEDS_METHOD_VIEW_CLEANUP:
|
||||||
|
delattr(builtins, "MethodView")
|
||||||
|
from models.account import AccountStatus
|
||||||
|
from models.api_based_extension import APIBasedExtension
|
||||||
|
|
||||||
|
|
||||||
|
def _make_extension(
|
||||||
|
*,
|
||||||
|
name: str = "Sample Extension",
|
||||||
|
api_endpoint: str = "https://example.com/api",
|
||||||
|
api_key: str = "super-secret-key",
|
||||||
|
) -> APIBasedExtension:
|
||||||
|
extension = APIBasedExtension(
|
||||||
|
tenant_id="tenant-123",
|
||||||
|
name=name,
|
||||||
|
api_endpoint=api_endpoint,
|
||||||
|
api_key=api_key,
|
||||||
|
)
|
||||||
|
extension.id = f"{uuid.uuid4()}"
|
||||||
|
extension.created_at = datetime.now(tz=UTC)
|
||||||
|
return extension
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _mock_console_guards(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||||
|
"""Bypass console decorators so handlers can run in isolation."""
|
||||||
|
|
||||||
|
import controllers.console.extension as extension_module
|
||||||
|
from controllers.console import wraps as wraps_module
|
||||||
|
|
||||||
|
account = MagicMock()
|
||||||
|
account.status = AccountStatus.ACTIVE
|
||||||
|
account.current_tenant_id = "tenant-123"
|
||||||
|
account.id = "account-123"
|
||||||
|
account.is_authenticated = True
|
||||||
|
|
||||||
|
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
|
||||||
|
monkeypatch.setattr("libs.login.dify_config.LOGIN_DISABLED", True)
|
||||||
|
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||||
|
monkeypatch.setattr(extension_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
|
||||||
|
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (account, "tenant-123"))
|
||||||
|
|
||||||
|
# The login_required decorator consults the shared LocalProxy in libs.login.
|
||||||
|
monkeypatch.setattr("libs.login.current_user", account)
|
||||||
|
monkeypatch.setattr("libs.login.check_csrf_token", lambda *_, **__: None)
|
||||||
|
|
||||||
|
return account
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _restx_mask_defaults(app: Flask):
|
||||||
|
app.config.setdefault("RESTX_MASK_HEADER", "X-Fields")
|
||||||
|
app.config.setdefault("RESTX_MASK_SWAGGER", False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
service_result = {"entrypoint": "main:agent"}
|
||||||
|
service_mock = MagicMock(return_value=service_result)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.CodeBasedExtensionService.get_code_based_extension",
|
||||||
|
service_mock,
|
||||||
|
)
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/console/api/code-based-extension",
|
||||||
|
method="GET",
|
||||||
|
query_string={"module": "workflow.tools"},
|
||||||
|
):
|
||||||
|
response = CodeBasedExtensionAPI().get()
|
||||||
|
|
||||||
|
assert response == {"module": "workflow.tools", "data": service_result}
|
||||||
|
service_mock.assert_called_once_with("workflow.tools")
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
extension = _make_extension(name="Weather API", api_key="abcdefghi123")
|
||||||
|
service_mock = MagicMock(return_value=[extension])
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.APIBasedExtensionService.get_all_by_tenant_id",
|
||||||
|
service_mock,
|
||||||
|
)
|
||||||
|
|
||||||
|
with app.test_request_context("/console/api/api-based-extension", method="GET"):
|
||||||
|
response = APIBasedExtensionAPI().get()
|
||||||
|
|
||||||
|
assert response[0]["id"] == extension.id
|
||||||
|
assert response[0]["name"] == "Weather API"
|
||||||
|
assert response[0]["api_endpoint"] == extension.api_endpoint
|
||||||
|
assert response[0]["api_key"].startswith(extension.api_key[:3])
|
||||||
|
service_mock.assert_called_once_with("tenant-123")
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
saved_extension = _make_extension(name="Docs API", api_key="saved-secret")
|
||||||
|
save_mock = MagicMock(return_value=saved_extension)
|
||||||
|
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"name": "Docs API",
|
||||||
|
"api_endpoint": "https://docs.example.com/hook",
|
||||||
|
"api_key": "plain-secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
with app.test_request_context("/console/api/api-based-extension", method="POST", json=payload):
|
||||||
|
response = APIBasedExtensionAPI().post()
|
||||||
|
|
||||||
|
args, _ = save_mock.call_args
|
||||||
|
created_extension: APIBasedExtension = args[0]
|
||||||
|
assert created_extension.tenant_id == "tenant-123"
|
||||||
|
assert created_extension.name == payload["name"]
|
||||||
|
assert created_extension.api_endpoint == payload["api_endpoint"]
|
||||||
|
assert created_extension.api_key == payload["api_key"]
|
||||||
|
assert response["name"] == saved_extension.name
|
||||||
|
save_mock.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
extension = _make_extension(name="Docs API", api_key="abcdefg12345")
|
||||||
|
service_mock = MagicMock(return_value=extension)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||||
|
service_mock,
|
||||||
|
)
|
||||||
|
|
||||||
|
extension_id = uuid.uuid4()
|
||||||
|
with app.test_request_context(f"/console/api/api-based-extension/{extension_id}", method="GET"):
|
||||||
|
response = APIBasedExtensionDetailAPI().get(extension_id)
|
||||||
|
|
||||||
|
assert response["id"] == extension.id
|
||||||
|
assert response["name"] == extension.name
|
||||||
|
service_mock.assert_called_once_with("tenant-123", str(extension_id))
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
existing_extension = _make_extension(name="Docs API", api_key="keep-me")
|
||||||
|
get_mock = MagicMock(return_value=existing_extension)
|
||||||
|
save_mock = MagicMock(return_value=existing_extension)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||||
|
get_mock,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"name": "Docs API Updated",
|
||||||
|
"api_endpoint": "https://docs.example.com/v2",
|
||||||
|
"api_key": HIDDEN_VALUE,
|
||||||
|
}
|
||||||
|
|
||||||
|
extension_id = uuid.uuid4()
|
||||||
|
with app.test_request_context(
|
||||||
|
f"/console/api/api-based-extension/{extension_id}",
|
||||||
|
method="POST",
|
||||||
|
json=payload,
|
||||||
|
):
|
||||||
|
response = APIBasedExtensionDetailAPI().post(extension_id)
|
||||||
|
|
||||||
|
assert existing_extension.name == payload["name"]
|
||||||
|
assert existing_extension.api_endpoint == payload["api_endpoint"]
|
||||||
|
assert existing_extension.api_key == "keep-me"
|
||||||
|
save_mock.assert_called_once_with(existing_extension)
|
||||||
|
assert response["name"] == payload["name"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
existing_extension = _make_extension(name="Docs API", api_key="old-secret")
|
||||||
|
get_mock = MagicMock(return_value=existing_extension)
|
||||||
|
save_mock = MagicMock(return_value=existing_extension)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||||
|
get_mock,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.save", save_mock)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"name": "Docs API Updated",
|
||||||
|
"api_endpoint": "https://docs.example.com/v2",
|
||||||
|
"api_key": "new-secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
extension_id = uuid.uuid4()
|
||||||
|
with app.test_request_context(
|
||||||
|
f"/console/api/api-based-extension/{extension_id}",
|
||||||
|
method="POST",
|
||||||
|
json=payload,
|
||||||
|
):
|
||||||
|
response = APIBasedExtensionDetailAPI().post(extension_id)
|
||||||
|
|
||||||
|
assert existing_extension.api_key == "new-secret"
|
||||||
|
save_mock.assert_called_once_with(existing_extension)
|
||||||
|
assert response["name"] == payload["name"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
existing_extension = _make_extension()
|
||||||
|
get_mock = MagicMock(return_value=existing_extension)
|
||||||
|
delete_mock = MagicMock()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.console.extension.APIBasedExtensionService.get_with_tenant_id",
|
||||||
|
get_mock,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("controllers.console.extension.APIBasedExtensionService.delete", delete_mock)
|
||||||
|
|
||||||
|
extension_id = uuid.uuid4()
|
||||||
|
with app.test_request_context(
|
||||||
|
f"/console/api/api-based-extension/{extension_id}",
|
||||||
|
method="DELETE",
|
||||||
|
):
|
||||||
|
response, status = APIBasedExtensionDetailAPI().delete(extension_id)
|
||||||
|
|
||||||
|
delete_mock.assert_called_once_with(existing_extension)
|
||||||
|
assert response == {"result": "success"}
|
||||||
|
assert status == 204
|
||||||
Loading…
Reference in New Issue