From 57b573d02bc11ba87aec0afabf1656d7052b7c22 Mon Sep 17 00:00:00 2001 From: chariri Date: Wed, 3 Jun 2026 17:59:00 +0900 Subject: [PATCH] refactor(api): migrate tenant/user via DI for several endpoints (#37004) --- api/controllers/console/agent/composer.py | 43 ++++--- api/controllers/console/agent/roster.py | 51 ++++---- api/controllers/console/app/app.py | 32 ++--- api/controllers/console/app/model_config.py | 25 ++-- .../console/explore/installed_app.py | 25 ++-- .../console/workspace/model_providers.py | 55 ++++----- api/controllers/console/workspace/models.py | 12 +- api/controllers/openapi/oauth_device.py | 14 ++- .../console/agent/test_agent_controllers.py | 106 +++++++++------- .../console/app/test_app_response_models.py | 28 +++-- .../console/app/test_model_config_api.py | 13 +- .../console/explore/test_installed_app.py | 92 +++++++------- .../console/workspace/test_model_providers.py | 114 ++++-------------- .../console/workspace/test_models.py | 5 +- 14 files changed, 306 insertions(+), 309 deletions(-) diff --git a/api/controllers/console/agent/composer.py b/api/controllers/console/agent/composer.py index db8d7fec60..7f7370454c 100644 --- a/api/controllers/console/agent/composer.py +++ b/api/controllers/console/agent/composer.py @@ -3,7 +3,13 @@ from flask_restx import Resource from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_tenant_id, + with_current_user_id, +) from fields.agent_fields import ( AgentAppComposerResponse, AgentComposerCandidatesResponse, @@ -12,7 +18,7 @@ from fields.agent_fields import ( WorkflowAgentComposerResponse, ) from libs.helper import dump_response -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models.model import App, AppMode from services.agent.composer_service import AgentComposerService from services.agent.composer_validator import ComposerConfigValidator @@ -38,8 +44,8 @@ class WorkflowAgentComposerApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - def get(self, app_model: App, node_id: str): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, app_model: App, node_id: str): return dump_response( WorkflowAgentComposerResponse, AgentComposerService.load_workflow_composer( @@ -58,8 +64,9 @@ class WorkflowAgentComposerApi(Resource): @account_initialization_required @edit_permission_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - def put(self, app_model: App, node_id: str): - account, tenant_id = current_account_with_tenant() + @with_current_user_id + @with_current_tenant_id + def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str): payload = ComposerSavePayload.model_validate(console_ns.payload or {}) return dump_response( WorkflowAgentComposerResponse, @@ -67,7 +74,7 @@ class WorkflowAgentComposerApi(Resource): tenant_id=tenant_id, app_id=app_model.id, node_id=node_id, - account_id=account.id, + account_id=account_id, payload=payload, ), ) @@ -113,8 +120,8 @@ class WorkflowAgentComposerImpactApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - def post(self, app_model: App, node_id: str): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, tenant_id: str, app_model: App, node_id: str): payload = ComposerSavePayload.model_validate(console_ns.payload or {}) current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None if not current_snapshot_id: @@ -138,8 +145,9 @@ class WorkflowAgentComposerSaveToRosterApi(Resource): @account_initialization_required @edit_permission_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) - def post(self, app_model: App, node_id: str): - account, tenant_id = current_account_with_tenant() + @with_current_user_id + @with_current_tenant_id + def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str): payload = ComposerSavePayload.model_validate(console_ns.payload or {}) return dump_response( WorkflowAgentComposerResponse, @@ -147,7 +155,7 @@ class WorkflowAgentComposerSaveToRosterApi(Resource): tenant_id=tenant_id, app_id=app_model.id, node_id=node_id, - account_id=account.id, + account_id=account_id, payload=payload, ), ) @@ -160,8 +168,8 @@ class AgentAppComposerApi(Resource): @login_required @account_initialization_required @get_app_model() - def get(self, app_model: App): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, app_model: App): return dump_response( AgentAppComposerResponse, AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id), @@ -174,15 +182,16 @@ class AgentAppComposerApi(Resource): @account_initialization_required @edit_permission_required @get_app_model() - def put(self, app_model: App): - account, tenant_id = current_account_with_tenant() + @with_current_user_id + @with_current_tenant_id + def put(self, tenant_id: str, account_id: str, app_model: App): payload = ComposerSavePayload.model_validate(console_ns.payload or {}) return dump_response( AgentAppComposerResponse, AgentComposerService.save_agent_app_composer( tenant_id=tenant_id, app_id=app_model.id, - account_id=account.id, + account_id=account_id, payload=payload, ), ) diff --git a/api/controllers/console/agent/roster.py b/api/controllers/console/agent/roster.py index be41b4e3b3..c305a816ee 100644 --- a/api/controllers/console/agent/roster.py +++ b/api/controllers/console/agent/roster.py @@ -6,7 +6,13 @@ from pydantic import BaseModel, Field from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_tenant_id, + with_current_user_id, +) from extensions.ext_database import db from fields.agent_fields import ( AgentConfigSnapshotDetailResponse, @@ -16,7 +22,7 @@ from fields.agent_fields import ( AgentRosterResponse, ) from libs.helper import dump_response -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from services.agent.roster_service import AgentRosterService from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery @@ -58,8 +64,8 @@ class AgentRosterListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str): query = RosterListQuery.model_validate(request.args.to_dict(flat=True)) return dump_response( AgentRosterListResponse, @@ -74,11 +80,12 @@ class AgentRosterListApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self): - account, tenant_id = current_account_with_tenant() + @with_current_user_id + @with_current_tenant_id + def post(self, tenant_id: str, account_id: str): payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {}) service = _agent_roster_service() - agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload) + agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload) return dump_response( AgentRosterResponse, service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id), @@ -92,8 +99,8 @@ class AgentInviteOptionsApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str): query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True)) return dump_response( AgentInviteOptionsResponse, @@ -113,8 +120,8 @@ class AgentRosterDetailApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, agent_id: UUID): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, agent_id: UUID): return dump_response( AgentRosterResponse, _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)), @@ -126,13 +133,14 @@ class AgentRosterDetailApi(Resource): @login_required @account_initialization_required @edit_permission_required - def patch(self, agent_id: UUID): - account, tenant_id = current_account_with_tenant() + @with_current_user_id + @with_current_tenant_id + def patch(self, tenant_id: str, account_id: str, agent_id: UUID): payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {}) return dump_response( AgentRosterResponse, _agent_roster_service().update_roster_agent( - tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload + tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload ), ) @@ -141,9 +149,10 @@ class AgentRosterDetailApi(Resource): @login_required @account_initialization_required @edit_permission_required - def delete(self, agent_id: UUID): - account, tenant_id = current_account_with_tenant() - _agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id) + @with_current_user_id + @with_current_tenant_id + def delete(self, tenant_id: str, account_id: str, agent_id: UUID): + _agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id) return "", 204 @@ -153,8 +162,8 @@ class AgentRosterVersionsApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, agent_id: UUID): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, agent_id: UUID): return dump_response( AgentConfigSnapshotListResponse, {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}, @@ -167,8 +176,8 @@ class AgentRosterVersionDetailApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, agent_id: UUID, version_id: UUID): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, agent_id: UUID, version_id: UUID): return dump_response( AgentConfigSnapshotDetailResponse, _agent_roster_service().get_agent_version_detail( diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index e44e32c892..d51dc68391 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -25,6 +25,9 @@ from controllers.console.wraps import ( enterprise_license_required, is_admin_or_owner_required, setup_required, + with_current_tenant_id, + with_current_user, + with_current_user_id, ) from core.ops.ops_trace_manager import OpsTraceManager from core.rag.entities import PreProcessingRule, Rule, Segmentation @@ -34,8 +37,8 @@ from extensions.ext_database import db from fields.base import ResponseModel from graphon.enums import WorkflowExecutionStatus from libs.helper import build_icon_url, to_timestamp -from libs.login import current_account_with_tenant, login_required -from models import App, DatasetPermissionEnum, Workflow +from libs.login import login_required +from models import Account, App, DatasetPermissionEnum, Workflow from models.model import IconType from services.app_dsl_service import AppDslService from services.app_service import AppListParams, AppService, CreateAppParams @@ -472,10 +475,10 @@ class AppListApi(Resource): @account_initialization_required @enterprise_license_required @with_session(write=False) - def get(self, session: Session): + @with_current_user_id + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user_id: str, session: Session): """Get app list""" - current_user, current_tenant_id = current_account_with_tenant() - args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args)) params = AppListParams( page=args.page, @@ -488,7 +491,7 @@ class AppListApi(Resource): # get app list app_service = AppService() - app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params) + app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params) if not app_pagination: empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[]) return empty.model_dump(mode="json"), 200 @@ -548,9 +551,10 @@ class AppListApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("apps") @edit_permission_required - def post(self): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account): """Create app""" - current_user, current_tenant_id = current_account_with_tenant() args = CreateAppPayload.model_validate(console_ns.payload) params = CreateAppParams( name=args.name, @@ -653,11 +657,10 @@ class AppCopyApi(Resource): @account_initialization_required @get_app_model(mode=None) @edit_permission_required - def post(self, app_model: App): + @with_current_user + def post(self, current_user: Account, app_model: App): """Copy app""" # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - args = CopyAppPayload.model_validate(console_ns.payload or {}) with Session(db.engine, expire_on_commit=False) as session: @@ -736,7 +739,8 @@ class AppPublishToCreatorsPlatformApi(Resource): @account_initialization_required @get_app_model(mode=None) @edit_permission_required - def post(self, app_model: App): + @with_current_user_id + def post(self, current_user_id: str, app_model: App): """Publish app to Creators Platform""" from configs import dify_config from core.helper.creators import get_redirect_url, upload_dsl @@ -744,13 +748,11 @@ class AppPublishToCreatorsPlatformApi(Resource): if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED: return {"error": "Creators Platform features are not enabled"}, 403 - current_user, _ = current_account_with_tenant() - dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False) dsl_bytes = dsl_content.encode("utf-8") claim_code = upload_dsl(dsl_bytes) - redirect_url = get_redirect_url(str(current_user.id), claim_code) + redirect_url = get_redirect_url(current_user_id, claim_code) return {"redirect_url": redirect_url} diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a893b66911..8951a71510 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -8,14 +8,20 @@ from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_tenant_id, + with_current_user_id, +) from core.agent.entities import AgentToolEntity from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models.model import App, AppMode, AppModelConfig from services.app_model_config_service import AppModelConfigService @@ -52,9 +58,10 @@ class ModelConfigResource(Resource): @edit_permission_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) - def post(self, app_model: App): + @with_current_user_id + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user_id: str, app_model: App): """Modify app model config""" - current_user, current_tenant_id = current_account_with_tenant() # validate config model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_tenant_id, @@ -64,8 +71,8 @@ class ModelConfigResource(Resource): new_app_model_config = AppModelConfig( app_id=app_model.id, - created_by=current_user.id, - updated_by=current_user.id, + created_by=current_user_id, + updated_by=current_user_id, ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) @@ -90,7 +97,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, - user_id=current_user.id, + user_id=current_user_id, ) manager = ToolParameterConfigurationManager( tenant_id=current_tenant_id, @@ -130,7 +137,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, - user_id=current_user.id, + user_id=current_user_id, ) except Exception: continue @@ -167,7 +174,7 @@ class ModelConfigResource(Resource): db.session.flush() app_model.app_model_config_id = new_app_model_config.id - app_model.updated_by = current_user.id + app_model.updated_by = current_user_id app_model.updated_at = naive_utc_now() db.session.commit() diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index bd4d1ef49f..86b36e3c92 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -12,14 +12,19 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.explore.wraps import InstalledAppResource -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + with_current_tenant_id, + with_current_user, +) from extensions.ext_database import db from fields.base import ResponseModel from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.helper import to_timestamp -from libs.login import current_account_with_tenant, login_required -from models import App, InstalledApp, RecommendedApp +from libs.login import login_required +from models import Account, App, InstalledApp, RecommendedApp from models.model import IconType from services.account_service import TenantService from services.enterprise.enterprise_service import EnterpriseService @@ -131,9 +136,10 @@ class InstalledAppsListApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__]) - def get(self): + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account): query = InstalledAppsListQuery.model_validate(request.args.to_dict()) - current_user, current_tenant_id = current_account_with_tenant() if query.app_id: installed_apps = db.session.scalars( @@ -212,7 +218,8 @@ class InstalledAppsListApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("apps") @console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__]) - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {}) recommended_app = db.session.scalar( @@ -221,8 +228,6 @@ class InstalledAppsListApi(Resource): if recommended_app is None: raise NotFound("Recommended app not found") - _, current_tenant_id = current_account_with_tenant() - app = db.session.get(App, payload.app_id) if app is None: @@ -262,8 +267,8 @@ class InstalledAppApi(InstalledAppResource): """ @console_ns.response(204, "App uninstalled successfully") - def delete(self, installed_app: InstalledApp): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def delete(self, current_tenant_id: str, installed_app: InstalledApp): if installed_app.app_owner_tenant_id == current_tenant_id: raise BadRequest("You can't uninstall an app owned by the current tenant") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 221cb3e406..e77f17b2d0 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -8,12 +8,19 @@ from pydantic import BaseModel, Field, field_validator from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + is_admin_or_owner_required, + setup_required, + with_current_tenant_id, + with_current_user, +) from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from services.billing_service import BillingService from services.model_provider_service import ModelProviderService @@ -95,10 +102,8 @@ class ModelProviderListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, current_tenant_id = current_account_with_tenant() - tenant_id = current_tenant_id - + @with_current_tenant_id + def get(self, tenant_id: str): payload = request.args.to_dict(flat=True) args = ParserModelList.model_validate(payload) @@ -114,9 +119,8 @@ class ModelProviderCredentialApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): - _, current_tenant_id = current_account_with_tenant() - tenant_id = current_tenant_id + @with_current_tenant_id + def get(self, tenant_id: str, provider: str): # if credential_id is not provided, return current used credential payload = request.args.to_dict(flat=True) args = ParserCredentialId.model_validate(payload) @@ -133,8 +137,8 @@ class ModelProviderCredentialApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, current_tenant_id: str, provider: str): payload = console_ns.payload or {} args = ParserCredentialCreate.model_validate(payload) @@ -157,9 +161,8 @@ class ModelProviderCredentialApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def put(self, provider: str): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def put(self, current_tenant_id: str, provider: str): payload = console_ns.payload or {} args = ParserCredentialUpdate.model_validate(payload) @@ -184,8 +187,8 @@ class ModelProviderCredentialApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def delete(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def delete(self, current_tenant_id: str, provider: str): payload = console_ns.payload or {} args = ParserCredentialDelete.model_validate(payload) @@ -205,8 +208,8 @@ class ModelProviderCredentialSwitchApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, current_tenant_id: str, provider: str): payload = console_ns.payload or {} args = ParserCredentialSwitch.model_validate(payload) @@ -225,8 +228,8 @@ class ModelProviderValidateApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, current_tenant_id: str, provider: str): payload = console_ns.payload or {} args = ParserCredentialValidate.model_validate(payload) @@ -280,11 +283,8 @@ class PreferredProviderTypeUpdateApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): - _, current_tenant_id = current_account_with_tenant() - - tenant_id = current_tenant_id - + @with_current_tenant_id + def post(self, tenant_id: str, provider: str): payload = console_ns.payload or {} args = ParserPreferredProviderType.model_validate(payload) @@ -301,10 +301,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, provider: str): if provider != "anthropic": raise ValueError(f"provider name {provider} is invalid") - current_user, current_tenant_id = current_account_with_tenant() BillingService.is_tenant_owner_or_admin(current_user) data = BillingService.get_model_provider_payment_link( provider_name=provider, diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 5cd849725f..19e3fc60bb 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -13,12 +13,14 @@ from controllers.console.wraps import ( is_admin_or_owner_required, setup_required, with_current_tenant_id, + with_current_user, ) from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService @@ -269,8 +271,9 @@ class ModelProviderModelCredentialApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @with_current_tenant_id - def get(self, tenant_id: str, provider: str): + def get(self, tenant_id: str, user: Account, provider: str): args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) model_provider_service = ModelProviderService() @@ -293,9 +296,8 @@ class ModelProviderModelCredentialApi(Resource): if args.config_from == "predefined-model": # Only the predefined-model branch needs visibility filtering by user. - # Defer the auth lookup so the other branch (and its tests) doesn't - # require flask-login setup. - user, _ = current_account_with_tenant() + # The account is injected once by the handler and only passed into the + # service branch that needs user-scoped credential visibility. available_credentials = model_provider_service.get_provider_available_credentials( tenant_id=tenant_id, provider=provider, diff --git a/api/controllers/openapi/oauth_device.py b/api/controllers/openapi/oauth_device.py index bbee345767..d685d1fb29 100644 --- a/api/controllers/openapi/oauth_device.py +++ b/api/controllers/openapi/oauth_device.py @@ -26,7 +26,12 @@ from werkzeug.exceptions import BadRequest from configs import dify_config from controllers.common.schema import query_params_from_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + setup_required, + with_current_tenant_id, + with_current_user, +) from controllers.openapi import openapi_ns from controllers.openapi._models import ( AccountPayload, @@ -42,7 +47,6 @@ from controllers.openapi._models import ( from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.helper import extract_remote_ip -from libs.login import current_account_with_tenant from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required from libs.rate_limit import ( LIMIT_APPROVE_CONSOLE, @@ -50,6 +54,7 @@ from libs.rate_limit import ( LIMIT_LOOKUP_PUBLIC, rate_limit, ) +from models import Account from services.account_service import TenantService from services.oauth_device_flow import ( ACCOUNT_ISSUER_SENTINEL, @@ -206,11 +211,12 @@ class DeviceApproveApi(Resource): @account_initialization_required @bearer_feature_required @rate_limit(LIMIT_APPROVE_CONSOLE) - def post(self): + @with_current_user + @with_current_tenant_id + def post(self, tenant: str, account: Account): payload = _validate_json(DeviceMutateRequest) user_code = payload.user_code.strip().upper() - account, tenant = current_account_with_tenant() store = DeviceFlowRedis(redis_client) found = store.load_by_user_code(user_code) diff --git a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py index 718f44cdf2..a1f567de74 100644 --- a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py +++ b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py @@ -1,6 +1,8 @@ from types import SimpleNamespace +from typing import Protocol, cast import pytest +from flask import Flask from controllers.console.agent import composer as composer_controller from controllers.console.agent import roster as roster_controller @@ -114,34 +116,34 @@ def _candidates_response(variant: str) -> dict: } +class _PayloadWithDescription(Protocol): + description: object + + @pytest.fixture -def account(): - return SimpleNamespace(id="account-1") +def account_id() -> str: + return "account-1" -@pytest.fixture(autouse=True) -def patch_account_context(monkeypatch, account): - monkeypatch.setattr(roster_controller, "current_account_with_tenant", lambda: (account, "tenant-1")) - monkeypatch.setattr(composer_controller, "current_account_with_tenant", lambda: (account, "tenant-1")) +def test_roster_list_get_parses_query_and_calls_service(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, object] = {} - -def test_roster_list_get_parses_query_and_calls_service(app, monkeypatch): - captured = {} - - def list_roster_agents(_self, **kwargs): + def list_roster_agents(_self: object, **kwargs: object) -> dict[str, object]: captured.update(kwargs) return {"data": [], "page": kwargs["page"], "limit": kwargs["limit"], "total": 0, "has_more": False} monkeypatch.setattr(roster_controller.AgentRosterService, "list_roster_agents", list_roster_agents) with app.test_request_context("/console/api/agents?page=2&limit=5&keyword=analyst"): - result = _unwrap(AgentRosterListApi.get)(AgentRosterListApi()) + result = _unwrap(AgentRosterListApi.get)(AgentRosterListApi(), "tenant-1") assert result["page"] == 2 assert captured == {"tenant_id": "tenant-1", "page": 2, "limit": 5, "keyword": "analyst"} -def test_roster_list_post_creates_agent_and_returns_detail(app, monkeypatch): +def test_roster_list_post_creates_agent_and_returns_detail( + app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str +) -> None: created_agent = SimpleNamespace(id="agent-1") monkeypatch.setattr( roster_controller.AgentRosterService, @@ -155,42 +157,47 @@ def test_roster_list_post_creates_agent_and_returns_detail(app, monkeypatch): ) with app.test_request_context(json={"name": "Analyst", "agent_soul": {"prompt": {"system_prompt": "x"}}}): - result, status = _unwrap(AgentRosterListApi.post)(AgentRosterListApi()) + result, status = _unwrap(AgentRosterListApi.post)(AgentRosterListApi(), "tenant-1", account_id) assert status == 201 assert result["id"] == "agent-1" assert result["agent_kind"] == "dify_agent" -def test_invite_options_get_parses_app_id(app, monkeypatch): - captured = {} +def test_invite_options_get_parses_app_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, object] = {} - def list_invite_options(_self, **kwargs): + def list_invite_options(_self: object, **kwargs: object) -> dict[str, object]: captured.update(kwargs) return {"data": [], "page": kwargs["page"], "limit": kwargs["limit"], "total": 0, "has_more": False} monkeypatch.setattr(roster_controller.AgentRosterService, "list_invite_options", list_invite_options) with app.test_request_context("/console/api/agents/invite-options?page=1&limit=10&app_id=app-1"): - result = _unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi()) + result = _unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi(), "tenant-1") assert result == {"data": [], "page": 1, "limit": 10, "total": 0, "has_more": False} assert captured == {"tenant_id": "tenant-1", "page": 1, "limit": 10, "keyword": None, "app_id": "app-1"} -def test_roster_detail_patch_delete_and_versions_call_services(app, monkeypatch): +def test_roster_detail_patch_delete_and_versions_call_services( + app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str +) -> None: agent_id = "00000000-0000-0000-0000-000000000001" version_id = "00000000-0000-0000-0000-000000000002" - archived = {} + archived: dict[str, object] = {} monkeypatch.setattr( roster_controller.AgentRosterService, "get_roster_agent_detail", - lambda _self, **kwargs: _agent_response(kwargs["agent_id"]), + lambda _self, **kwargs: _agent_response(cast(str, kwargs["agent_id"])), ) monkeypatch.setattr( roster_controller.AgentRosterService, "update_roster_agent", - lambda _self, **kwargs: {**_agent_response(kwargs["agent_id"]), "description": kwargs["payload"].description}, + lambda _self, **kwargs: { + **_agent_response(cast(str, kwargs["agent_id"])), + "description": cast(_PayloadWithDescription, kwargs["payload"]).description, + }, ) monkeypatch.setattr( roster_controller.AgentRosterService, @@ -206,7 +213,7 @@ def test_roster_detail_patch_delete_and_versions_call_services(app, monkeypatch) roster_controller.AgentRosterService, "get_agent_version_detail", lambda _self, **kwargs: { - **_version_response(kwargs["version_id"]), + **_version_response(cast(str, kwargs["version_id"])), "agent_id": kwargs["agent_id"], "config_snapshot": {}, "revisions": [ @@ -225,18 +232,28 @@ def test_roster_detail_patch_delete_and_versions_call_services(app, monkeypatch) }, ) - assert _unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), agent_id)["id"] == agent_id + assert _unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), "tenant-1", agent_id)["id"] == agent_id with app.test_request_context(json={"description": "updated"}): - assert _unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), agent_id)["description"] == "updated" - assert _unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), agent_id) == ("", 204) + assert ( + _unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id)["description"] + == "updated" + ) + assert _unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id) == ("", 204) assert archived["account_id"] == "account-1" - assert _unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), agent_id)["data"][0]["id"] == "version-1" - version_detail = _unwrap(AgentRosterVersionDetailApi.get)(AgentRosterVersionDetailApi(), agent_id, version_id) + assert ( + _unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"] + == "version-1" + ) + version_detail = _unwrap(AgentRosterVersionDetailApi.get)( + AgentRosterVersionDetailApi(), "tenant-1", agent_id, version_id + ) assert version_detail["id"] == version_id assert version_detail["agent_id"] == agent_id -def test_workflow_composer_get_put_validate_candidates_impact_and_save(app, monkeypatch): +def test_workflow_composer_get_put_validate_candidates_impact_and_save( + app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str +) -> None: app_model = SimpleNamespace(id="app-1") payload = { "variant": ComposerVariant.WORKFLOW.value, @@ -269,10 +286,12 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save(app, monk }, ) - workflow_state = _unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), app_model, "node-1") + workflow_state = _unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), "tenant-1", app_model, "node-1") assert workflow_state["node_id"] == "node-1" with app.test_request_context(json=payload): - saved_state = _unwrap(WorkflowAgentComposerApi.put)(WorkflowAgentComposerApi(), app_model, "node-1") + saved_state = _unwrap(WorkflowAgentComposerApi.put)( + WorkflowAgentComposerApi(), "tenant-1", account_id, app_model, "node-1" + ) assert saved_state["save_options"] == ["node_job_only"] assert _unwrap(WorkflowAgentComposerValidateApi.post)( WorkflowAgentComposerValidateApi(), app_model, "node-1" @@ -284,28 +303,28 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save(app, monk == "workflow" ) with app.test_request_context(json=payload): - assert _unwrap(WorkflowAgentComposerImpactApi.post)(WorkflowAgentComposerImpactApi(), app_model, "node-1") == { - "current_snapshot_id": "version-1", - "workflow_node_count": 1, - "bindings": [], - } + assert _unwrap(WorkflowAgentComposerImpactApi.post)( + WorkflowAgentComposerImpactApi(), "tenant-1", app_model, "node-1" + ) == {"current_snapshot_id": "version-1", "workflow_node_count": 1, "bindings": []} assert _unwrap(WorkflowAgentComposerSaveToRosterApi.post)( - WorkflowAgentComposerSaveToRosterApi(), app_model, "node-1" + WorkflowAgentComposerSaveToRosterApi(), "tenant-1", account_id, app_model, "node-1" )["save_options"] == ["node_job_only"] -def test_workflow_impact_returns_empty_without_version(app): +def test_workflow_impact_returns_empty_without_version(app: Flask) -> None: payload = {"variant": ComposerVariant.WORKFLOW.value, "save_strategy": ComposerSaveStrategy.NODE_JOB_ONLY.value} with app.test_request_context(json=payload): result = _unwrap(WorkflowAgentComposerImpactApi.post)( - WorkflowAgentComposerImpactApi(), SimpleNamespace(id="app-1"), "node-1" + WorkflowAgentComposerImpactApi(), "tenant-1", SimpleNamespace(id="app-1"), "node-1" ) assert result == {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []} -def test_agent_app_composer_get_put_validate_and_candidates(app, monkeypatch): +def test_agent_app_composer_get_put_validate_and_candidates( + app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str +) -> None: app_model = SimpleNamespace(id="app-1") payload = { "variant": ComposerVariant.AGENT_APP.value, @@ -329,9 +348,12 @@ def test_agent_app_composer_get_put_validate_and_candidates(app, monkeypatch): lambda **kwargs: _candidates_response("agent_app"), ) - assert _unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), app_model)["variant"] == "agent_app" + assert _unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), "tenant-1", app_model)["variant"] == "agent_app" with app.test_request_context(json=payload): - assert _unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), app_model)["variant"] == "agent_app" + assert ( + _unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), "tenant-1", account_id, app_model)["variant"] + == "agent_app" + ) assert _unwrap(AgentAppComposerValidateApi.post)(AgentAppComposerValidateApi(), app_model) == { "result": "success", "errors": [], diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py index f9d3f0ad87..d698b03893 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py @@ -6,10 +6,10 @@ from datetime import datetime from importlib import util from pathlib import Path from types import ModuleType, SimpleNamespace -from typing import Any from unittest.mock import MagicMock import pytest +from flask import Flask from flask.views import MethodView from pydantic import ValidationError from werkzeug.datastructures import MultiDict @@ -19,6 +19,13 @@ if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] +class _ConsoleModule(ModuleType): + console_ns: object + api: object | None + bp: object | None + app: ModuleType + + def _unwrap(func): bound_self = getattr(func, "__self__", None) while hasattr(func, "__wrapped__"): @@ -36,7 +43,7 @@ def app_module(): class _StubNamespace: def __init__(self): - self.models: dict[str, Any] = {} + self.models: dict[str, object] = {} self.payload = None def schema_model(self, name, schema): @@ -77,7 +84,7 @@ def app_module(): } stubbed_modules: list[tuple[str, ModuleType | None]] = [] - console_module = ModuleType("controllers.console") + console_module = _ConsoleModule("controllers.console") console_module.__path__ = [str(root / "controllers" / "console")] console_module.console_ns = stub_namespace console_module.api = None @@ -89,7 +96,7 @@ def app_module(): sys.modules["controllers.console.app"] = app_package console_module.app = app_package - def _stub_module(name: str, attrs: dict[str, Any]): + def _stub_module(name: str, attrs: dict[str, object]) -> None: original = sys.modules.get(name) module = ModuleType(name) for key, value in attrs.items(): @@ -99,7 +106,7 @@ def app_module(): class _OpsTraceManager: @staticmethod - def get_app_tracing_config(app_id: str) -> dict[str, Any]: + def get_app_tracing_config(app_id: str) -> dict[str, object]: return {} @staticmethod @@ -116,6 +123,7 @@ def app_module(): ) spec = util.spec_from_file_location(module_name, module_path) + assert spec is not None module = util.module_from_spec(spec) sys.modules[module_name] = module @@ -147,7 +155,7 @@ def app_models(app_module): @pytest.fixture(autouse=True) -def patch_signed_url(monkeypatch, app_module): +def patch_signed_url(monkeypatch: pytest.MonkeyPatch, app_module: ModuleType) -> None: """Ensure icon URL generation uses a deterministic helper for tests.""" def _fake_build_icon_url(_icon_type, key: str | None) -> str | None: @@ -407,10 +415,11 @@ def test_app_pagination_aliases_per_page_and_has_next(app_models): assert serialized["data"][1]["icon_url"] is None -def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, monkeypatch): +def test_app_list_uses_injected_session_for_draft_workflows( + app: Flask, app_module: ModuleType, monkeypatch: pytest.MonkeyPatch +) -> None: api = app_module.AppListApi() method = _unwrap(api.get) - current_user = SimpleNamespace(id="user-1") app_item = SimpleNamespace( id="app-1", name="Workflow App", @@ -428,7 +437,6 @@ def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, mon session.execute.return_value.scalars.return_value.all.return_value = [workflow] scoped_session = SimpleNamespace(execute=MagicMock(side_effect=AssertionError("db.session should not be used"))) - monkeypatch.setattr(app_module, "current_account_with_tenant", lambda: (current_user, "tenant-1")) monkeypatch.setattr( app_module, "AppService", @@ -442,7 +450,7 @@ def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, mon monkeypatch.setattr(app_module, "db", SimpleNamespace(session=scoped_session)) with app.test_request_context("/console/api/apps?page=1&limit=20", method="GET"): - response, status = method(session) + response, status = method("tenant-1", "user-1", session) assert status == 200 assert response["data"][0]["has_draft_trigger"] is True diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py index a0e2edb8cf..5fc60d8046 100644 --- a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from flask import Flask from controllers.console.app import model_config as model_config_module from models.model import AppMode, AppModelConfig @@ -19,7 +20,7 @@ def _unwrap(func): return func -def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = model_config_module.ModelConfigResource() method = _unwrap(api.post) @@ -36,8 +37,6 @@ def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyP "validate_configuration", lambda **_kwargs: {"pre_prompt": "hi"}, ) - monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) - session = MagicMock() monkeypatch.setattr(model_config_module.db, "session", session) @@ -51,7 +50,7 @@ def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyP monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock) with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}): - response = method(app_model=app_model) + response = method("t1", "u1", app_model=app_model) session.add.assert_called_once() session.flush.assert_called_once() @@ -61,7 +60,7 @@ def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyP assert response["result"] == "success" -def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_post_encrypts_agent_tool_parameters(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = model_config_module.ModelConfigResource() method = _unwrap(api.post) @@ -115,8 +114,6 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc }, }, ) - monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) - monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object()) class _ParamManager: @@ -140,7 +137,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock) with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}): - response = method(app_model=app_model) + response = method("t1", "u1", app_model=app_model) stored_config = session.add.call_args[0][0] stored_agent_mode = json.loads(stored_config.agent_mode) diff --git a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py index 47ac8d8f3f..d7ec808efa 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py @@ -1,3 +1,5 @@ +from collections.abc import Callable +from contextlib import AbstractContextManager from datetime import datetime from unittest.mock import MagicMock, PropertyMock, patch @@ -7,6 +9,9 @@ from werkzeug.exceptions import BadRequest, Forbidden, NotFound import controllers.console.explore.installed_app as module +type Payload = dict[str, object] +type PayloadPatch = Callable[[Payload], AbstractContextManager[object]] + def unwrap(func): while hasattr(func, "__wrapped__"): @@ -15,12 +20,12 @@ def unwrap(func): @pytest.fixture -def tenant_id(): +def tenant_id() -> str: return "t1" @pytest.fixture -def current_user(tenant_id): +def current_user(tenant_id: str) -> MagicMock: user = MagicMock() user.id = "u1" user.current_tenant = MagicMock(id=tenant_id) @@ -28,7 +33,7 @@ def current_user(tenant_id): @pytest.fixture -def installed_app(): +def installed_app() -> MagicMock: app = MagicMock() app.id = "ia1" app.app = MagicMock(id="a1") @@ -39,8 +44,8 @@ def installed_app(): @pytest.fixture -def payload_patch(): - def _patch(payload): +def payload_patch() -> PayloadPatch: + def _patch(payload: Payload) -> AbstractContextManager[object]: return patch.object( type(module.console_ns), "payload", @@ -52,7 +57,9 @@ def payload_patch(): class TestInstalledAppsListApi: - def test_get_installed_apps(self, app: Flask, current_user, tenant_id, installed_app): + def test_get_installed_apps( + self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock + ) -> None: api = module.InstalledAppsListApi() method = unwrap(api.get) @@ -61,7 +68,6 @@ class TestInstalledAppsListApi: with ( app.test_request_context("/"), - patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), patch.object(module.db, "session", session), patch.object(module.TenantService, "get_user_role", return_value="owner"), patch.object( @@ -70,13 +76,13 @@ class TestInstalledAppsListApi: return_value=MagicMock(webapp_auth=MagicMock(enabled=False)), ), ): - result = method(api) + result = method(api, tenant_id, current_user) assert "installed_apps" in result assert result["installed_apps"][0]["editable"] is True assert result["installed_apps"][0]["uninstallable"] is False - def test_get_installed_apps_with_app_id_filter(self, app: Flask, current_user, tenant_id): + def test_get_installed_apps_with_app_id_filter(self, app: Flask, current_user: MagicMock, tenant_id: str) -> None: api = module.InstalledAppsListApi() method = unwrap(api.get) @@ -85,7 +91,6 @@ class TestInstalledAppsListApi: with ( app.test_request_context("/?app_id=a1"), - patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), patch.object(module.db, "session", session), patch.object(module.TenantService, "get_user_role", return_value="member"), patch.object( @@ -94,11 +99,13 @@ class TestInstalledAppsListApi: return_value=MagicMock(webapp_auth=MagicMock(enabled=False)), ), ): - result = method(api) + result = method(api, tenant_id, current_user) assert result == {"installed_apps": []} - def test_get_installed_apps_with_webapp_auth_enabled(self, app: Flask, current_user, tenant_id, installed_app): + def test_get_installed_apps_with_webapp_auth_enabled( + self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock + ) -> None: """Test filtering when webapp_auth is enabled.""" api = module.InstalledAppsListApi() method = unwrap(api.get) @@ -111,7 +118,6 @@ class TestInstalledAppsListApi: with ( app.test_request_context("/"), - patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), patch.object(module.db, "session", session), patch.object(module.TenantService, "get_user_role", return_value="owner"), patch.object( @@ -130,11 +136,13 @@ class TestInstalledAppsListApi: return_value={"a1": True}, ), ): - result = method(api) + result = method(api, tenant_id, current_user) assert len(result["installed_apps"]) == 1 - def test_get_installed_apps_with_webapp_auth_user_denied(self, app: Flask, current_user, tenant_id, installed_app): + def test_get_installed_apps_with_webapp_auth_user_denied( + self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock + ) -> None: """Test filtering when user doesn't have access.""" api = module.InstalledAppsListApi() method = unwrap(api.get) @@ -147,7 +155,6 @@ class TestInstalledAppsListApi: with ( app.test_request_context("/"), - patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), patch.object(module.db, "session", session), patch.object(module.TenantService, "get_user_role", return_value="member"), patch.object( @@ -166,11 +173,13 @@ class TestInstalledAppsListApi: return_value={"a1": False}, ), ): - result = method(api) + result = method(api, tenant_id, current_user) assert result["installed_apps"] == [] - def test_get_installed_apps_with_sso_verified_access(self, app: Flask, current_user, tenant_id, installed_app): + def test_get_installed_apps_with_sso_verified_access( + self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock + ) -> None: """Test that sso_verified access mode apps are skipped in filtering.""" api = module.InstalledAppsListApi() method = unwrap(api.get) @@ -183,7 +192,6 @@ class TestInstalledAppsListApi: with ( app.test_request_context("/"), - patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), patch.object(module.db, "session", session), patch.object(module.TenantService, "get_user_role", return_value="owner"), patch.object( @@ -197,11 +205,11 @@ class TestInstalledAppsListApi: return_value={"a1": mock_webapp_setting}, ), ): - result = method(api) + result = method(api, tenant_id, current_user) assert len(result["installed_apps"]) == 0 - def test_get_installed_apps_filters_null_apps(self, app: Flask, current_user, tenant_id): + def test_get_installed_apps_filters_null_apps(self, app: Flask, current_user: MagicMock, tenant_id: str) -> None: """Test that installed apps with null app are filtered out.""" api = module.InstalledAppsListApi() method = unwrap(api.get) @@ -214,7 +222,6 @@ class TestInstalledAppsListApi: with ( app.test_request_context("/"), - patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), patch.object(module.db, "session", session), patch.object(module.TenantService, "get_user_role", return_value="owner"), patch.object( @@ -223,11 +230,11 @@ class TestInstalledAppsListApi: return_value=MagicMock(webapp_auth=MagicMock(enabled=False)), ), ): - result = method(api) + result = method(api, tenant_id, current_user) assert result["installed_apps"] == [] - def test_get_installed_apps_current_tenant_none(self, app: Flask, tenant_id, installed_app): + def test_get_installed_apps_current_tenant_none(self, app: Flask, tenant_id: str, installed_app: MagicMock) -> None: """Test error when current_user.current_tenant is None.""" api = module.InstalledAppsListApi() method = unwrap(api.get) @@ -240,15 +247,14 @@ class TestInstalledAppsListApi: with ( app.test_request_context("/"), - patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), patch.object(module.db, "session", session), ): with pytest.raises(ValueError, match="current_user.current_tenant must not be None"): - method(api) + method(api, tenant_id, current_user) class TestInstalledAppsCreateApi: - def test_post_success(self, app: Flask, tenant_id, payload_patch): + def test_post_success(self, app: Flask, tenant_id: str, payload_patch: PayloadPatch) -> None: api = module.InstalledAppsListApi() method = unwrap(api.post) @@ -270,14 +276,13 @@ class TestInstalledAppsCreateApi: app.test_request_context("/", json={"app_id": "a1"}), payload_patch({"app_id": "a1"}), patch.object(module.db, "session", session), - patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)), ): - result = method(api) + result = method(api, tenant_id) assert result == {"message": "App installed successfully"} assert recommended.install_count == 1 - def test_post_recommended_not_found(self, app: Flask, payload_patch): + def test_post_recommended_not_found(self, app: Flask, tenant_id: str, payload_patch: PayloadPatch) -> None: api = module.InstalledAppsListApi() method = unwrap(api.post) @@ -290,9 +295,9 @@ class TestInstalledAppsCreateApi: patch.object(module.db, "session", session), ): with pytest.raises(NotFound): - method(api) + method(api, tenant_id) - def test_post_app_not_public(self, app: Flask, tenant_id, payload_patch): + def test_post_app_not_public(self, app: Flask, tenant_id: str, payload_patch: PayloadPatch) -> None: api = module.InstalledAppsListApi() method = unwrap(api.post) @@ -309,37 +314,32 @@ class TestInstalledAppsCreateApi: app.test_request_context("/", json={"app_id": "a1"}), payload_patch({"app_id": "a1"}), patch.object(module.db, "session", session), - patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)), ): with pytest.raises(Forbidden): - method(api) + method(api, tenant_id) class TestInstalledAppApi: - def test_delete_success(self, tenant_id: str, installed_app): + def test_delete_success(self, tenant_id: str, installed_app: MagicMock) -> None: api = module.InstalledAppApi() method = unwrap(api.delete) - with ( - patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)), - patch.object(module.db, "session"), - ): - resp, status = method(installed_app) + with patch.object(module.db, "session"): + resp, status = method(api, tenant_id, installed_app) assert status == 204 assert resp == "" - def test_delete_owned_by_current_tenant(self, tenant_id: str): + def test_delete_owned_by_current_tenant(self, tenant_id: str) -> None: api = module.InstalledAppApi() method = unwrap(api.delete) installed_app = MagicMock(app_owner_tenant_id=tenant_id) - with patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)): - with pytest.raises(BadRequest): - method(installed_app) + with pytest.raises(BadRequest): + method(api, tenant_id, installed_app) - def test_patch_update_pin(self, app: Flask, payload_patch, installed_app): + def test_patch_update_pin(self, app: Flask, payload_patch: PayloadPatch, installed_app: MagicMock) -> None: api = module.InstalledAppApi() method = unwrap(api.patch) @@ -353,7 +353,7 @@ class TestInstalledAppApi: assert installed_app.is_pinned is True assert result["result"] == "success" - def test_patch_no_change(self, app: Flask, payload_patch, installed_app): + def test_patch_no_change(self, app: Flask, payload_patch: PayloadPatch, installed_app: MagicMock) -> None: api = module.InstalledAppApi() method = unwrap(api.patch) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index a81a8e1b1a..d938558806 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -33,16 +33,12 @@ class TestModelProviderListApi: with ( app.test_request_context("/?model_type=llm"), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch( "controllers.console.workspace.model_providers.ModelProviderService.get_provider_list", return_value=[{"name": "openai"}], ), ): - result = method(api) + result = method(api, "tenant1") assert "data" in result @@ -54,16 +50,12 @@ class TestModelProviderCredentialApi: with ( app.test_request_context(f"/?credential_id={VALID_UUID}"), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch( "controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential", return_value={"key": "value"}, ), ): - result = method(api, provider="openai") + result = method(api, "tenant1", provider="openai") assert "credentials" in result @@ -71,15 +63,9 @@ class TestModelProviderCredentialApi: api = ModelProviderCredentialApi() method = unwrap(api.get) - with ( - app.test_request_context(f"/?credential_id={INVALID_UUID}"), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), - ): + with app.test_request_context(f"/?credential_id={INVALID_UUID}"): with pytest.raises(ValidationError): - method(api, provider="openai") + method(api, "tenant1", provider="openai") def test_post_create_success(self, app: Flask): api = ModelProviderCredentialApi() @@ -89,16 +75,12 @@ class TestModelProviderCredentialApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch( "controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential", return_value=None, ), ): - result, status = method(api, provider="openai") + result, status = method(api, "tenant1", provider="openai") assert result["result"] == "success" assert status == 201 @@ -111,17 +93,13 @@ class TestModelProviderCredentialApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch( "controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential", side_effect=CredentialsValidateFailedError("bad"), ), ): with pytest.raises(ValueError): - method(api, provider="openai") + method(api, "tenant1", provider="openai") def test_put_update_success(self, app: Flask): api = ModelProviderCredentialApi() @@ -131,16 +109,12 @@ class TestModelProviderCredentialApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch( "controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential", return_value=None, ), ): - result = method(api, provider="openai") + result = method(api, "tenant1", provider="openai") assert result["result"] == "success" @@ -150,15 +124,9 @@ class TestModelProviderCredentialApi: payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}} - with ( - app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), - ): + with app.test_request_context("/", json=payload): with pytest.raises(ValidationError): - method(api, provider="openai") + method(api, "tenant1", provider="openai") def test_delete_success(self, app: Flask): api = ModelProviderCredentialApi() @@ -168,16 +136,12 @@ class TestModelProviderCredentialApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch( "controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential", return_value=None, ), ): - result, status = method(api, provider="openai") + result, status = method(api, "tenant1", provider="openai") assert status == 204 assert result == "" @@ -192,16 +156,12 @@ class TestModelProviderCredentialSwitchApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch( "controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential", return_value=None, ), ): - result = method(api, provider="openai") + result = method(api, "tenant1", provider="openai") assert result["result"] == "success" @@ -211,15 +171,9 @@ class TestModelProviderCredentialSwitchApi: payload = {"credential_id": INVALID_UUID} - with ( - app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), - ): + with app.test_request_context("/", json=payload): with pytest.raises(ValidationError): - method(api, provider="openai") + method(api, "tenant1", provider="openai") class TestModelProviderValidateApi: @@ -231,16 +185,12 @@ class TestModelProviderValidateApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch( "controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials", return_value=None, ), ): - result = method(api, provider="openai") + result = method(api, "tenant1", provider="openai") assert result["result"] == "success" @@ -252,16 +202,12 @@ class TestModelProviderValidateApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch( "controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials", side_effect=CredentialsValidateFailedError("bad"), ), ): - result = method(api, provider="openai") + result = method(api, "tenant1", provider="openai") assert result["result"] == "error" @@ -304,16 +250,12 @@ class TestPreferredProviderTypeUpdateApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), patch( "controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider", return_value=None, ), ): - result = method(api, provider="openai") + result = method(api, "tenant1", provider="openai") assert result["result"] == "success" @@ -323,15 +265,9 @@ class TestPreferredProviderTypeUpdateApi: payload = {"preferred_provider_type": "invalid"} - with ( - app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(MagicMock(), "tenant1"), - ), - ): + with app.test_request_context("/", json=payload): with pytest.raises(ValidationError): - method(api, provider="openai") + method(api, "tenant1", provider="openai") class TestModelProviderPaymentCheckoutUrlApi: @@ -343,10 +279,6 @@ class TestModelProviderPaymentCheckoutUrlApi: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(user, "tenant1"), - ), patch( "controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin", return_value=None, @@ -356,7 +288,7 @@ class TestModelProviderPaymentCheckoutUrlApi: return_value={"url": "x"}, ), ): - result = method(api, provider="anthropic") + result = method(api, "tenant1", user, provider="anthropic") assert "url" in result @@ -366,7 +298,7 @@ class TestModelProviderPaymentCheckoutUrlApi: with app.test_request_context("/"): with pytest.raises(ValueError): - method(api, provider="openai") + method(api, "tenant1", MagicMock(), provider="openai") def test_permission_denied(self, app: Flask): api = ModelProviderPaymentCheckoutUrlApi() @@ -376,14 +308,10 @@ class TestModelProviderPaymentCheckoutUrlApi: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.model_providers.current_account_with_tenant", - return_value=(user, "tenant1"), - ), patch( "controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin", side_effect=Forbidden(), ), ): with pytest.raises(Forbidden): - method(api, provider="anthropic") + method(api, "tenant1", user, provider="anthropic") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index 564505d32b..00977e6d7b 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from unittest.mock import patch import pytest @@ -172,7 +173,7 @@ class TestModelProviderModelCredentialApi: provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = [] lb_service.return_value.get_load_balancing_configs.return_value = (False, []) - result = method(api, "tenant1", "openai") + result = method(api, "tenant1", SimpleNamespace(id="u1"), "openai") assert "credentials" in result @@ -207,7 +208,7 @@ class TestModelProviderModelCredentialApi: service.return_value.provider_manager.get_provider_model_available_credentials.return_value = [] lb.return_value.get_load_balancing_configs.return_value = (False, []) - result = method(api, "t1", "openai") + result = method(api, "t1", SimpleNamespace(id="u1"), "openai") assert result["credentials"] == {}