mirror of https://github.com/langgenius/dify.git
refactor: split changes for api/controllers/console/workspace/load_ba… (#29887)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
fa1009b938
commit
93d1b2fc32
|
|
@ -1,6 +1,8 @@
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
|
@ -10,10 +12,20 @@ from models import TenantAccountRole
|
||||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||||
|
|
||||||
|
|
||||||
|
class LoadBalancingCredentialPayload(BaseModel):
|
||||||
|
model: str
|
||||||
|
model_type: ModelType
|
||||||
|
credentials: dict[str, object]
|
||||||
|
|
||||||
|
|
||||||
|
register_schema_models(console_ns, LoadBalancingCredentialPayload)
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route(
|
@console_ns.route(
|
||||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
|
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
|
||||||
)
|
)
|
||||||
class LoadBalancingCredentialsValidateApi(Resource):
|
class LoadBalancingCredentialsValidateApi(Resource):
|
||||||
|
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||||
|
|
||||||
tenant_id = current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
parser = (
|
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument(
|
|
||||||
"model_type",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
nullable=False,
|
|
||||||
choices=[mt.value for mt in ModelType],
|
|
||||||
location="json",
|
|
||||||
)
|
|
||||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# validate model load balancing credentials
|
# validate model load balancing credentials
|
||||||
model_load_balancing_service = ModelLoadBalancingService()
|
model_load_balancing_service = ModelLoadBalancingService()
|
||||||
|
|
@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||||
model_load_balancing_service.validate_load_balancing_credentials(
|
model_load_balancing_service.validate_load_balancing_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=args["model"],
|
model=payload.model,
|
||||||
model_type=args["model_type"],
|
model_type=payload.model_type,
|
||||||
credentials=args["credentials"],
|
credentials=payload.credentials,
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
result = False
|
result = False
|
||||||
|
|
@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
|
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
|
||||||
)
|
)
|
||||||
class LoadBalancingConfigCredentialsValidateApi(Resource):
|
class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||||
|
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
|
@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||||
|
|
||||||
tenant_id = current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
parser = (
|
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
|
||||||
reqparse.RequestParser()
|
|
||||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
|
||||||
.add_argument(
|
|
||||||
"model_type",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
nullable=False,
|
|
||||||
choices=[mt.value for mt in ModelType],
|
|
||||||
location="json",
|
|
||||||
)
|
|
||||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# validate model load balancing config credentials
|
# validate model load balancing config credentials
|
||||||
model_load_balancing_service = ModelLoadBalancingService()
|
model_load_balancing_service = ModelLoadBalancingService()
|
||||||
|
|
@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||||
model_load_balancing_service.validate_load_balancing_credentials(
|
model_load_balancing_service.validate_load_balancing_credentials(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
model=args["model"],
|
model=payload.model,
|
||||||
model_type=args["model_type"],
|
model_type=payload.model_type,
|
||||||
credentials=args["credentials"],
|
credentials=payload.credentials,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
)
|
)
|
||||||
except CredentialsValidateFailedError as ex:
|
except CredentialsValidateFailedError as ex:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,145 @@
|
||||||
|
"""Unit tests for load balancing credential validation APIs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
from flask.views import MethodView
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
|
||||||
|
if not hasattr(builtins, "MethodView"):
|
||||||
|
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
from models.account import TenantAccountRole
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app() -> Flask:
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config["TESTING"] = True
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def load_balancing_module(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Reload controller module with lightweight decorators for testing."""
|
||||||
|
|
||||||
|
from controllers.console import console_ns, wraps
|
||||||
|
from libs import login
|
||||||
|
|
||||||
|
def _noop(func):
|
||||||
|
return func
|
||||||
|
|
||||||
|
monkeypatch.setattr(login, "login_required", _noop)
|
||||||
|
monkeypatch.setattr(wraps, "setup_required", _noop)
|
||||||
|
monkeypatch.setattr(wraps, "account_initialization_required", _noop)
|
||||||
|
|
||||||
|
def _noop_route(*args, **kwargs): # type: ignore[override]
|
||||||
|
def _decorator(cls):
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return _decorator
|
||||||
|
|
||||||
|
monkeypatch.setattr(console_ns, "route", _noop_route)
|
||||||
|
|
||||||
|
module_name = "controllers.console.workspace.load_balancing_config"
|
||||||
|
sys.modules.pop(module_name, None)
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(current_role=role)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
|
||||||
|
user = _mock_user(role)
|
||||||
|
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
|
||||||
|
mock_service = MagicMock()
|
||||||
|
monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
|
||||||
|
return mock_service
|
||||||
|
|
||||||
|
|
||||||
|
def _request_payload():
|
||||||
|
return {"model": "gpt-4o", "model_type": ModelType.LLM, "credentials": {"api_key": "sk-***"}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||||
|
method="POST",
|
||||||
|
json=_request_payload(),
|
||||||
|
):
|
||||||
|
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
|
||||||
|
|
||||||
|
assert response == {"result": "success"}
|
||||||
|
service.validate_load_balancing_credentials.assert_called_once_with(
|
||||||
|
tenant_id="tenant-123",
|
||||||
|
provider="openai",
|
||||||
|
model="gpt-4o",
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
credentials={"api_key": "sk-***"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||||
|
service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials")
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||||
|
method="POST",
|
||||||
|
json=_request_payload(),
|
||||||
|
):
|
||||||
|
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
|
||||||
|
|
||||||
|
assert response == {"result": "error", "error": "invalid credentials"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials_requires_privileged_role(
|
||||||
|
app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
_prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL)
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||||
|
method="POST",
|
||||||
|
json=_request_payload(),
|
||||||
|
):
|
||||||
|
api = load_balancing_module.LoadBalancingCredentialsValidateApi()
|
||||||
|
with pytest.raises(Forbidden):
|
||||||
|
api.post(provider="openai")
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||||
|
|
||||||
|
with app.test_request_context(
|
||||||
|
"/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate",
|
||||||
|
method="POST",
|
||||||
|
json=_request_payload(),
|
||||||
|
):
|
||||||
|
response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post(
|
||||||
|
provider="openai", config_id="cfg-1"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response == {"result": "success"}
|
||||||
|
service.validate_load_balancing_credentials.assert_called_once_with(
|
||||||
|
tenant_id="tenant-123",
|
||||||
|
provider="openai",
|
||||||
|
model="gpt-4o",
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
credentials={"api_key": "sk-***"},
|
||||||
|
config_id="cfg-1",
|
||||||
|
)
|
||||||
Loading…
Reference in New Issue