diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index d4f501d34c..47a9b8aedb 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -11,7 +11,7 @@ from controllers.console.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id from core.app.app_config.entities import ModelConfig from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.helper.code_executor.code_node_provider import CodeNodeProvider @@ -22,7 +22,7 @@ from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMMode from graphon.model_runtime.errors.invoke import InvokeError -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models import App from services.workflow_service import WorkflowService @@ -64,9 +64,9 @@ class RuleGenerateApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): args = RuleGeneratePayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() try: rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args) @@ -93,9 +93,9 @@ class RuleCodeGenerateApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): args = RuleCodeGeneratePayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() try: code_result = LLMGenerator.generate_code( @@ -125,9 +125,9 @@ class RuleStructuredOutputGenerateApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): args = RuleStructuredOutputPayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() try: structured_output = LLMGenerator.generate_structured_output( @@ -157,9 +157,9 @@ class InstructionGenerateApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): args = InstructionGeneratePayload.model_validate(console_ns.payload) - _, current_tenant_id = current_account_with_tenant() providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] code_provider: type[CodeNodeProvider] | None = next( (p for p in providers if p.is_accept_language(args.language)), None diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index e5ef15c3e1..157bec4fbc 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -11,11 +11,16 @@ from werkzeug.exceptions import NotFound from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_tenant_id, +) from extensions.ext_database import db from fields.base import ResponseModel from libs.helper import to_timestamp -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer @@ -92,8 +97,8 @@ class AppMCPServerController(Resource): @login_required @setup_required @edit_permission_required - def post(self, app_model: App): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, current_tenant_id: str, app_model: App): payload = MCPServerCreatePayload.model_validate(console_ns.payload or {}) description = payload.description @@ -163,8 +168,8 @@ class AppMCPServerRefreshController(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, server_id: UUID): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str, server_id: UUID): server = db.session.scalar( select(AppMCPServer) .where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 21c6b1c7ce..6c703d782d 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -8,12 +8,17 @@ from pydantic import BaseModel, Field, field_validator from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + is_admin_or_owner_required, + setup_required, + with_current_tenant_id, +) from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService @@ -138,9 +143,8 @@ class DefaultModelApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str): args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) model_provider_service = ModelProviderService() @@ -156,9 +160,8 @@ class DefaultModelApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str): args = ParserPostDefault.model_validate(console_ns.payload) model_provider_service = ModelProviderService() model_settings = args.model_settings @@ -189,9 +192,8 @@ class ModelProviderModelApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str, provider): model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) @@ -202,9 +204,9 @@ class ModelProviderModelApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): + @with_current_tenant_id + def post(self, tenant_id: str, provider: str): # To save the model's load balance configs - _, tenant_id = current_account_with_tenant() args = ParserPostModels.model_validate(console_ns.payload) if args.config_from == "custom-model": @@ -249,9 +251,8 @@ class ModelProviderModelApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def delete(self, provider: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def delete(self, tenant_id: str, provider: str): args = ParserDeleteModels.model_validate(console_ns.payload) model_provider_service = ModelProviderService() @@ -268,9 +269,8 @@ class ModelProviderModelCredentialApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str, provider: str): args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) model_provider_service = ModelProviderService() @@ -323,9 +323,8 @@ class ModelProviderModelCredentialApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str, provider: str): args = ParserCreateCredential.model_validate(console_ns.payload) model_provider_service = ModelProviderService() @@ -355,8 +354,8 @@ class ModelProviderModelCredentialApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def put(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def put(self, current_tenant_id: str, provider: str): args = ParserUpdateCredential.model_validate(console_ns.payload) model_provider_service = ModelProviderService() @@ -382,8 +381,8 @@ class ModelProviderModelCredentialApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def delete(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def delete(self, current_tenant_id: str, provider: str): args = ParserDeleteCredential.model_validate(console_ns.payload) model_provider_service = ModelProviderService() @@ -406,8 +405,8 @@ class ModelProviderModelCredentialSwitchApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, current_tenant_id: str, provider: str): args = ParserSwitch.model_validate(console_ns.payload) service = ModelProviderService() @@ -430,9 +429,8 @@ class ModelProviderModelEnableApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, provider: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def patch(self, tenant_id: str, provider: str): args = ParserDeleteModels.model_validate(console_ns.payload) model_provider_service = ModelProviderService() @@ -452,9 +450,8 @@ class ModelProviderModelDisableApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, provider: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def patch(self, tenant_id: str, provider: str): args = ParserDeleteModels.model_validate(console_ns.payload) model_provider_service = ModelProviderService() @@ -480,8 +477,8 @@ class ModelProviderModelValidateApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, provider: str): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, tenant_id: str, provider: str): args = ParserValidate.model_validate(console_ns.payload) model_provider_service = ModelProviderService() @@ -515,9 +512,9 @@ class ModelProviderModelParameterRuleApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): + @with_current_tenant_id + def get(self, tenant_id: str, provider: str): args = ParserParameter.model_validate(request.args.to_dict(flat=True)) - _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() parameter_rules = model_provider_service.get_model_parameter_rules( @@ -532,8 +529,8 @@ class ModelProviderAvailableModelApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, model_type: str): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, model_type: str): model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py index e64c508b82..11c6acfcc1 100644 --- a/api/tests/unit_tests/controllers/console/app/test_generator_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -34,7 +34,6 @@ def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.RuleGenerateApi() method = _unwrap(api.post) - monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []}) with app.test_request_context( @@ -42,7 +41,7 @@ def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None: method="POST", json={"instruction": "do it", "model_config": _model_config_payload()}, ): - response = method() + response = method("t1") assert response == {"rules": []} @@ -51,8 +50,6 @@ def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatc api = generator_module.RuleCodeGenerateApi() method = _unwrap(api.post) - monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - def _raise(*_args, **_kwargs): raise ProviderTokenNotInitError("missing token") @@ -64,15 +61,13 @@ def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatc json={"instruction": "do it", "model_config": _model_config_payload()}, ): with pytest.raises(ProviderNotInitializeError): - method() + method("t1") def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.InstructionGenerateApi() method = _unwrap(api.post) - monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: None)) with app.test_request_context( @@ -85,7 +80,7 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch "model_config": _model_config_payload(), }, ): - response, status = method() + response, status = method("t1") assert status == 400 assert response["error"] == "app app-1 not found" @@ -95,8 +90,6 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey api = generator_module.InstructionGenerateApi() method = _unwrap(api.post) - monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - app_model = SimpleNamespace(id="app-1") monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) _install_workflow_service(monkeypatch, workflow=None) @@ -111,7 +104,7 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey "model_config": _model_config_payload(), }, ): - response, status = method() + response, status = method("t1") assert status == 400 assert response["error"] == "workflow app-1 not found" @@ -121,8 +114,6 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) api = generator_module.InstructionGenerateApi() method = _unwrap(api.post) - monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - app_model = SimpleNamespace(id="app-1") monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) @@ -139,7 +130,7 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) "model_config": _model_config_payload(), }, ): - response, status = method() + response, status = method("t1") assert status == 400 assert response["error"] == "node node-1 not found" @@ -149,8 +140,6 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> api = generator_module.InstructionGenerateApi() method = _unwrap(api.post) - monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - app_model = SimpleNamespace(id="app-1") monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(get=lambda *_args, **_kwargs: app_model)) @@ -174,7 +163,7 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> "model_config": _model_config_payload(), }, ): - response = method() + response = method("t1") assert response == {"code": "x"} @@ -183,7 +172,6 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch api = generator_module.InstructionGenerateApi() method = _unwrap(api.post) - monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) monkeypatch.setattr( generator_module.LLMGenerator, "instruction_modify_legacy", @@ -201,7 +189,7 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch "model_config": _model_config_payload(), }, ): - response = method() + response = method("t1") assert response == {"instruction": "ok"} @@ -210,8 +198,6 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke api = generator_module.InstructionGenerateApi() method = _unwrap(api.post) - monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) - with app.test_request_context( "/console/api/instruction-generate", method="POST", @@ -223,7 +209,7 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke "model_config": _model_config_payload(), }, ): - response, status = method() + response, status = method("t1") assert status == 400 assert response["error"] == "incompatible parameters" diff --git a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py index 1af15d8dc6..c7dba15216 100644 --- a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py +++ b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py @@ -121,7 +121,6 @@ class TestAppMCPServerController: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), - patch("controllers.console.app.mcp_server.current_account_with_tenant", return_value=(None, "tenant-1")), patch("controllers.console.app.mcp_server.db.session.add"), patch("controllers.console.app.mcp_server.db.session.commit"), patch("controllers.console.app.mcp_server.AppMCPServer.generate_server_code", return_value="server-code"), @@ -131,7 +130,7 @@ class TestAppMCPServerController: ), ): response, status_code = method( - api, app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description") + api, "tenant-1", app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description") ) assert response == {"id": "server-1"} diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index 3c4acbab44..564505d32b 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from flask import Flask @@ -34,15 +34,11 @@ class TestDefaultModelApi: "/", query_string={"model_type": ModelType.LLM}, ), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService") as service_mock, ): service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"} - result = method(api) + result = method(api, "tenant1") assert "data" in result @@ -62,13 +58,9 @@ class TestDefaultModelApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService"), ): - result = method(api) + result = method(api, "tenant1") assert result["result"] == "success" @@ -78,12 +70,11 @@ class TestDefaultModelApi: with ( app.test_request_context("/", query_string={"model_type": ModelType.LLM}), - patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), patch("controllers.console.workspace.models.ModelProviderService") as service, ): service.return_value.get_default_model_of_model_type.return_value = None - result = method(api) + result = method(api, "t1") assert "data" in result @@ -95,15 +86,11 @@ class TestModelProviderModelApi: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService") as service_mock, ): service_mock.return_value.get_models_by_provider.return_value = [] - result = method(api, "openai") + result = method(api, "tenant1", "openai") assert "data" in result @@ -122,14 +109,10 @@ class TestModelProviderModelApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService"), patch("controllers.console.workspace.models.ModelLoadBalancingService"), ): - result, status = method(api, "openai") + result, status = method(api, "tenant1", "openai") assert status == 200 @@ -144,13 +127,9 @@ class TestModelProviderModelApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService"), ): - result, status = method(api, "openai") + result, status = method(api, "tenant1", "openai") assert status == 204 @@ -160,12 +139,11 @@ class TestModelProviderModelApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), patch("controllers.console.workspace.models.ModelProviderService") as service, ): service.return_value.get_models_by_provider.return_value = [] - result = method(api, "openai") + result = method(api, "t1", "openai") assert "data" in result @@ -183,10 +161,6 @@ class TestModelProviderModelCredentialApi: "model_type": ModelType.LLM, }, ), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService") as provider_service, patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service, ): @@ -198,7 +172,7 @@ class TestModelProviderModelCredentialApi: provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = [] lb_service.return_value.get_load_balancing_configs.return_value = (False, []) - result = method(api, "openai") + result = method(api, "tenant1", "openai") assert "credentials" in result @@ -214,13 +188,9 @@ class TestModelProviderModelCredentialApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService"), ): - result, status = method(api, "openai") + result, status = method(api, "tenant1", "openai") assert status == 201 @@ -230,7 +200,6 @@ class TestModelProviderModelCredentialApi: with ( app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM}), - patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), patch("controllers.console.workspace.models.ModelProviderService") as service, patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb, ): @@ -238,7 +207,7 @@ class TestModelProviderModelCredentialApi: service.return_value.provider_manager.get_provider_model_available_credentials.return_value = [] lb.return_value.get_load_balancing_configs.return_value = (False, []) - result = method(api, "openai") + result = method(api, "t1", "openai") assert result["credentials"] == {} @@ -254,10 +223,9 @@ class TestModelProviderModelCredentialApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), patch("controllers.console.workspace.models.ModelProviderService"), ): - result, status = method(api, "openai") + result, status = method(api, "t1", "openai") assert status == 204 @@ -275,13 +243,9 @@ class TestModelProviderModelCredentialSwitchApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService"), ): - result = method(api, "openai") + result = method(api, "tenant1", "openai") assert result["result"] == "success" @@ -298,13 +262,9 @@ class TestModelEnableDisableApis: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService"), ): - result = method(api, "openai") + result = method(api, "tenant1", "openai") assert result["result"] == "success" @@ -319,13 +279,9 @@ class TestModelEnableDisableApis: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService"), ): - result = method(api, "openai") + result = method(api, "tenant1", "openai") assert result["result"] == "success" @@ -343,13 +299,9 @@ class TestModelProviderModelValidateApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService"), ): - result = method(api, "openai") + result = method(api, "tenant1", "openai") assert result["result"] == "success" @@ -366,15 +318,11 @@ class TestModelProviderModelValidateApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService") as service_mock, ): service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid") - result = method(api, "openai") + result = method(api, "tenant1", "openai") assert result["result"] == "error" @@ -386,15 +334,11 @@ class TestParameterAndAvailableModels: with ( app.test_request_context("/", query_string={"model": "gpt-4"}), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService") as service_mock, ): service_mock.return_value.get_model_parameter_rules.return_value = [] - result = method(api, "openai") + result = method(api, "tenant1", "openai") assert "data" in result @@ -404,15 +348,11 @@ class TestParameterAndAvailableModels: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.models.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch("controllers.console.workspace.models.ModelProviderService") as service_mock, ): service_mock.return_value.get_models_by_model_type.return_value = [] - result = method(api, ModelType.LLM) + result = method(api, "tenant1", ModelType.LLM) assert "data" in result @@ -422,12 +362,11 @@ class TestParameterAndAvailableModels: with ( app.test_request_context("/", query_string={"model": "gpt"}), - patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), patch("controllers.console.workspace.models.ModelProviderService") as service, ): service.return_value.get_model_parameter_rules.return_value = [] - result = method(api, "openai") + result = method(api, "t1", "openai") assert result["data"] == [] @@ -437,11 +376,10 @@ class TestParameterAndAvailableModels: with ( app.test_request_context("/"), - patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), patch("controllers.console.workspace.models.ModelProviderService") as service, ): service.return_value.get_models_by_model_type.return_value = [] - result = method(api, ModelType.LLM) + result = method(api, "t1", ModelType.LLM) assert result["data"] == []