refactor(api): migrate tenant/user via DI for several endpoints (#37004)

This commit is contained in:
chariri 2026-06-03 17:59:00 +09:00 committed by GitHub
parent 9de40e8f21
commit 57b573d02b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 306 additions and 309 deletions

View File

@ -3,7 +3,13 @@ from flask_restx import Resource
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from fields.agent_fields import (
AgentAppComposerResponse,
AgentComposerCandidatesResponse,
@ -12,7 +18,7 @@ from fields.agent_fields import (
WorkflowAgentComposerResponse,
)
from libs.helper import dump_response
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.model import App, AppMode
from services.agent.composer_service import AgentComposerService
from services.agent.composer_validator import ComposerConfigValidator
@ -38,8 +44,8 @@ class WorkflowAgentComposerApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def get(self, app_model: App, node_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App, node_id: str):
return dump_response(
WorkflowAgentComposerResponse,
AgentComposerService.load_workflow_composer(
@ -58,8 +64,9 @@ class WorkflowAgentComposerApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def put(self, app_model: App, node_id: str):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return dump_response(
WorkflowAgentComposerResponse,
@ -67,7 +74,7 @@ class WorkflowAgentComposerApi(Resource):
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account.id,
account_id=account_id,
payload=payload,
),
)
@ -113,8 +120,8 @@ class WorkflowAgentComposerImpactApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def post(self, app_model: App, node_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
if not current_snapshot_id:
@ -138,8 +145,9 @@ class WorkflowAgentComposerSaveToRosterApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def post(self, app_model: App, node_id: str):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return dump_response(
WorkflowAgentComposerResponse,
@ -147,7 +155,7 @@ class WorkflowAgentComposerSaveToRosterApi(Resource):
tenant_id=tenant_id,
app_id=app_model.id,
node_id=node_id,
account_id=account.id,
account_id=account_id,
payload=payload,
),
)
@ -160,8 +168,8 @@ class AgentAppComposerApi(Resource):
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
return dump_response(
AgentAppComposerResponse,
AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id),
@ -174,15 +182,16 @@ class AgentAppComposerApi(Resource):
@account_initialization_required
@edit_permission_required
@get_app_model()
def put(self, app_model: App):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def put(self, tenant_id: str, account_id: str, app_model: App):
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
return dump_response(
AgentAppComposerResponse,
AgentComposerService.save_agent_app_composer(
tenant_id=tenant_id,
app_id=app_model.id,
account_id=account.id,
account_id=account_id,
payload=payload,
),
)

View File

@ -6,7 +6,13 @@ from pydantic import BaseModel, Field
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from extensions.ext_database import db
from fields.agent_fields import (
AgentConfigSnapshotDetailResponse,
@ -16,7 +22,7 @@ from fields.agent_fields import (
AgentRosterResponse,
)
from libs.helper import dump_response
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from services.agent.roster_service import AgentRosterService
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
@ -58,8 +64,8 @@ class AgentRosterListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
return dump_response(
AgentRosterListResponse,
@ -74,11 +80,12 @@ class AgentRosterListApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, account_id: str):
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
service = _agent_roster_service()
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload)
return dump_response(
AgentRosterResponse,
service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id),
@ -92,8 +99,8 @@ class AgentInviteOptionsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
return dump_response(
AgentInviteOptionsResponse,
@ -113,8 +120,8 @@ class AgentRosterDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
return dump_response(
AgentRosterResponse,
_agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)),
@ -126,13 +133,14 @@ class AgentRosterDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, agent_id: UUID):
account, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def patch(self, tenant_id: str, account_id: str, agent_id: UUID):
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
return dump_response(
AgentRosterResponse,
_agent_roster_service().update_roster_agent(
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload
),
)
@ -141,9 +149,10 @@ class AgentRosterDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, agent_id: UUID):
account, tenant_id = current_account_with_tenant()
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
@with_current_user_id
@with_current_tenant_id
def delete(self, tenant_id: str, account_id: str, agent_id: UUID):
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id)
return "", 204
@ -153,8 +162,8 @@ class AgentRosterVersionsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID):
return dump_response(
AgentConfigSnapshotListResponse,
{"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))},
@ -167,8 +176,8 @@ class AgentRosterVersionDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID, version_id: UUID):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, agent_id: UUID, version_id: UUID):
return dump_response(
AgentConfigSnapshotDetailResponse,
_agent_roster_service().get_agent_version_detail(

View File

@ -25,6 +25,9 @@ from controllers.console.wraps import (
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
with_current_user_id,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
@ -34,8 +37,8 @@ from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.enums import WorkflowExecutionStatus
from libs.helper import build_icon_url, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from libs.login import login_required
from models import Account, App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_service import AppListParams, AppService, CreateAppParams
@ -472,10 +475,10 @@ class AppListApi(Resource):
@account_initialization_required
@enterprise_license_required
@with_session(write=False)
def get(self, session: Session):
@with_current_user_id
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user_id: str, session: Session):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
params = AppListParams(
page=args.page,
@ -488,7 +491,7 @@ class AppListApi(Resource):
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200
@ -548,9 +551,10 @@ class AppListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
args = CreateAppPayload.model_validate(console_ns.payload)
params = CreateAppParams(
name=args.name,
@ -653,11 +657,10 @@ class AppCopyApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model: App):
@with_current_user
def post(self, current_user: Account, app_model: App):
"""Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = CopyAppPayload.model_validate(console_ns.payload or {})
with Session(db.engine, expire_on_commit=False) as session:
@ -736,7 +739,8 @@ class AppPublishToCreatorsPlatformApi(Resource):
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model: App):
@with_current_user_id
def post(self, current_user_id: str, app_model: App):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
@ -744,13 +748,11 @@ class AppPublishToCreatorsPlatformApi(Resource):
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
redirect_url = get_redirect_url(current_user_id, claim_code)
return {"redirect_url": redirect_url}

View File

@ -8,14 +8,20 @@ from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.model import App, AppMode, AppModelConfig
from services.app_model_config_service import AppModelConfigService
@ -52,9 +58,10 @@ class ModelConfigResource(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model: App):
@with_current_user_id
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user_id: str, app_model: App):
"""Modify app model config"""
current_user, current_tenant_id = current_account_with_tenant()
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_tenant_id,
@ -64,8 +71,8 @@ class ModelConfigResource(Resource):
new_app_model_config = AppModelConfig(
app_id=app_model.id,
created_by=current_user.id,
updated_by=current_user.id,
created_by=current_user_id,
updated_by=current_user_id,
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
@ -90,7 +97,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
user_id=current_user_id,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_tenant_id,
@ -130,7 +137,7 @@ class ModelConfigResource(Resource):
tenant_id=current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
user_id=current_user.id,
user_id=current_user_id,
)
except Exception:
continue
@ -167,7 +174,7 @@ class ModelConfigResource(Resource):
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
app_model.updated_by = current_user.id
app_model.updated_by = current_user_id
app_model.updated_at = naive_utc_now()
db.session.commit()

View File

@ -12,14 +12,19 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from libs.login import login_required
from models import Account, App, InstalledApp, RecommendedApp
from models.model import IconType
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
@ -131,9 +136,10 @@ class InstalledAppsListApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
def get(self):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
if query.app_id:
installed_apps = db.session.scalars(
@ -212,7 +218,8 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
recommended_app = db.session.scalar(
@ -221,8 +228,6 @@ class InstalledAppsListApi(Resource):
if recommended_app is None:
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.get(App, payload.app_id)
if app is None:
@ -262,8 +267,8 @@ class InstalledAppApi(InstalledAppResource):
"""
@console_ns.response(204, "App uninstalled successfully")
def delete(self, installed_app: InstalledApp):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, installed_app: InstalledApp):
if installed_app.app_owner_tenant_id == current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")

View File

@ -8,12 +8,19 @@ from pydantic import BaseModel, Field, field_validator
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@ -95,10 +102,8 @@ class ModelProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
@with_current_tenant_id
def get(self, tenant_id: str):
payload = request.args.to_dict(flat=True)
args = ParserModelList.model_validate(payload)
@ -114,9 +119,8 @@ class ModelProviderCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
# if credential_id is not provided, return current used credential
payload = request.args.to_dict(flat=True)
args = ParserCredentialId.model_validate(payload)
@ -133,8 +137,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialCreate.model_validate(payload)
@ -157,9 +161,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def put(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialUpdate.model_validate(payload)
@ -184,8 +187,8 @@ class ModelProviderCredentialApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialDelete.model_validate(payload)
@ -205,8 +208,8 @@ class ModelProviderCredentialSwitchApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialSwitch.model_validate(payload)
@ -225,8 +228,8 @@ class ModelProviderValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserCredentialValidate.model_validate(payload)
@ -280,11 +283,8 @@ class PreferredProviderTypeUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
payload = console_ns.payload or {}
args = ParserPreferredProviderType.model_validate(payload)
@ -301,10 +301,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider: str):
if provider != "anthropic":
raise ValueError(f"provider name {provider} is invalid")
current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user)
data = BillingService.get_model_provider_payment_link(
provider_name=provider,

View File

@ -13,12 +13,14 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@ -269,8 +271,9 @@ class ModelProviderModelCredentialApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
def get(self, tenant_id: str, user: Account, provider: str):
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
model_provider_service = ModelProviderService()
@ -293,9 +296,8 @@ class ModelProviderModelCredentialApi(Resource):
if args.config_from == "predefined-model":
# Only the predefined-model branch needs visibility filtering by user.
# Defer the auth lookup so the other branch (and its tests) doesn't
# require flask-login setup.
user, _ = current_account_with_tenant()
# The account is injected once by the handler and only passed into the
# service branch that needs user-scoped credential visibility.
available_credentials = model_provider_service.get_provider_available_credentials(
tenant_id=tenant_id,
provider=provider,

View File

@ -26,7 +26,12 @@ from werkzeug.exceptions import BadRequest
from configs import dify_config
from controllers.common.schema import query_params_from_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
AccountPayload,
@ -42,7 +47,6 @@ from controllers.openapi._models import (
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
from libs.rate_limit import (
LIMIT_APPROVE_CONSOLE,
@ -50,6 +54,7 @@ from libs.rate_limit import (
LIMIT_LOOKUP_PUBLIC,
rate_limit,
)
from models import Account
from services.account_service import TenantService
from services.oauth_device_flow import (
ACCOUNT_ISSUER_SENTINEL,
@ -206,11 +211,12 @@ class DeviceApproveApi(Resource):
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, tenant: str, account: Account):
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
account, tenant = current_account_with_tenant()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)

View File

@ -1,6 +1,8 @@
from types import SimpleNamespace
from typing import Protocol, cast
import pytest
from flask import Flask
from controllers.console.agent import composer as composer_controller
from controllers.console.agent import roster as roster_controller
@ -114,34 +116,34 @@ def _candidates_response(variant: str) -> dict:
}
class _PayloadWithDescription(Protocol):
description: object
@pytest.fixture
def account():
return SimpleNamespace(id="account-1")
def account_id() -> str:
return "account-1"
@pytest.fixture(autouse=True)
def patch_account_context(monkeypatch, account):
monkeypatch.setattr(roster_controller, "current_account_with_tenant", lambda: (account, "tenant-1"))
monkeypatch.setattr(composer_controller, "current_account_with_tenant", lambda: (account, "tenant-1"))
def test_roster_list_get_parses_query_and_calls_service(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
captured: dict[str, object] = {}
def test_roster_list_get_parses_query_and_calls_service(app, monkeypatch):
captured = {}
def list_roster_agents(_self, **kwargs):
def list_roster_agents(_self: object, **kwargs: object) -> dict[str, object]:
captured.update(kwargs)
return {"data": [], "page": kwargs["page"], "limit": kwargs["limit"], "total": 0, "has_more": False}
monkeypatch.setattr(roster_controller.AgentRosterService, "list_roster_agents", list_roster_agents)
with app.test_request_context("/console/api/agents?page=2&limit=5&keyword=analyst"):
result = _unwrap(AgentRosterListApi.get)(AgentRosterListApi())
result = _unwrap(AgentRosterListApi.get)(AgentRosterListApi(), "tenant-1")
assert result["page"] == 2
assert captured == {"tenant_id": "tenant-1", "page": 2, "limit": 5, "keyword": "analyst"}
def test_roster_list_post_creates_agent_and_returns_detail(app, monkeypatch):
def test_roster_list_post_creates_agent_and_returns_detail(
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
) -> None:
created_agent = SimpleNamespace(id="agent-1")
monkeypatch.setattr(
roster_controller.AgentRosterService,
@ -155,42 +157,47 @@ def test_roster_list_post_creates_agent_and_returns_detail(app, monkeypatch):
)
with app.test_request_context(json={"name": "Analyst", "agent_soul": {"prompt": {"system_prompt": "x"}}}):
result, status = _unwrap(AgentRosterListApi.post)(AgentRosterListApi())
result, status = _unwrap(AgentRosterListApi.post)(AgentRosterListApi(), "tenant-1", account_id)
assert status == 201
assert result["id"] == "agent-1"
assert result["agent_kind"] == "dify_agent"
def test_invite_options_get_parses_app_id(app, monkeypatch):
captured = {}
def test_invite_options_get_parses_app_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
captured: dict[str, object] = {}
def list_invite_options(_self, **kwargs):
def list_invite_options(_self: object, **kwargs: object) -> dict[str, object]:
captured.update(kwargs)
return {"data": [], "page": kwargs["page"], "limit": kwargs["limit"], "total": 0, "has_more": False}
monkeypatch.setattr(roster_controller.AgentRosterService, "list_invite_options", list_invite_options)
with app.test_request_context("/console/api/agents/invite-options?page=1&limit=10&app_id=app-1"):
result = _unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi())
result = _unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi(), "tenant-1")
assert result == {"data": [], "page": 1, "limit": 10, "total": 0, "has_more": False}
assert captured == {"tenant_id": "tenant-1", "page": 1, "limit": 10, "keyword": None, "app_id": "app-1"}
def test_roster_detail_patch_delete_and_versions_call_services(app, monkeypatch):
def test_roster_detail_patch_delete_and_versions_call_services(
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
) -> None:
agent_id = "00000000-0000-0000-0000-000000000001"
version_id = "00000000-0000-0000-0000-000000000002"
archived = {}
archived: dict[str, object] = {}
monkeypatch.setattr(
roster_controller.AgentRosterService,
"get_roster_agent_detail",
lambda _self, **kwargs: _agent_response(kwargs["agent_id"]),
lambda _self, **kwargs: _agent_response(cast(str, kwargs["agent_id"])),
)
monkeypatch.setattr(
roster_controller.AgentRosterService,
"update_roster_agent",
lambda _self, **kwargs: {**_agent_response(kwargs["agent_id"]), "description": kwargs["payload"].description},
lambda _self, **kwargs: {
**_agent_response(cast(str, kwargs["agent_id"])),
"description": cast(_PayloadWithDescription, kwargs["payload"]).description,
},
)
monkeypatch.setattr(
roster_controller.AgentRosterService,
@ -206,7 +213,7 @@ def test_roster_detail_patch_delete_and_versions_call_services(app, monkeypatch)
roster_controller.AgentRosterService,
"get_agent_version_detail",
lambda _self, **kwargs: {
**_version_response(kwargs["version_id"]),
**_version_response(cast(str, kwargs["version_id"])),
"agent_id": kwargs["agent_id"],
"config_snapshot": {},
"revisions": [
@ -225,18 +232,28 @@ def test_roster_detail_patch_delete_and_versions_call_services(app, monkeypatch)
},
)
assert _unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), agent_id)["id"] == agent_id
assert _unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), "tenant-1", agent_id)["id"] == agent_id
with app.test_request_context(json={"description": "updated"}):
assert _unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), agent_id)["description"] == "updated"
assert _unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), agent_id) == ("", 204)
assert (
_unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id)["description"]
== "updated"
)
assert _unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id) == ("", 204)
assert archived["account_id"] == "account-1"
assert _unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), agent_id)["data"][0]["id"] == "version-1"
version_detail = _unwrap(AgentRosterVersionDetailApi.get)(AgentRosterVersionDetailApi(), agent_id, version_id)
assert (
_unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"]
== "version-1"
)
version_detail = _unwrap(AgentRosterVersionDetailApi.get)(
AgentRosterVersionDetailApi(), "tenant-1", agent_id, version_id
)
assert version_detail["id"] == version_id
assert version_detail["agent_id"] == agent_id
def test_workflow_composer_get_put_validate_candidates_impact_and_save(app, monkeypatch):
def test_workflow_composer_get_put_validate_candidates_impact_and_save(
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
) -> None:
app_model = SimpleNamespace(id="app-1")
payload = {
"variant": ComposerVariant.WORKFLOW.value,
@ -269,10 +286,12 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save(app, monk
},
)
workflow_state = _unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), app_model, "node-1")
workflow_state = _unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), "tenant-1", app_model, "node-1")
assert workflow_state["node_id"] == "node-1"
with app.test_request_context(json=payload):
saved_state = _unwrap(WorkflowAgentComposerApi.put)(WorkflowAgentComposerApi(), app_model, "node-1")
saved_state = _unwrap(WorkflowAgentComposerApi.put)(
WorkflowAgentComposerApi(), "tenant-1", account_id, app_model, "node-1"
)
assert saved_state["save_options"] == ["node_job_only"]
assert _unwrap(WorkflowAgentComposerValidateApi.post)(
WorkflowAgentComposerValidateApi(), app_model, "node-1"
@ -284,28 +303,28 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save(app, monk
== "workflow"
)
with app.test_request_context(json=payload):
assert _unwrap(WorkflowAgentComposerImpactApi.post)(WorkflowAgentComposerImpactApi(), app_model, "node-1") == {
"current_snapshot_id": "version-1",
"workflow_node_count": 1,
"bindings": [],
}
assert _unwrap(WorkflowAgentComposerImpactApi.post)(
WorkflowAgentComposerImpactApi(), "tenant-1", app_model, "node-1"
) == {"current_snapshot_id": "version-1", "workflow_node_count": 1, "bindings": []}
assert _unwrap(WorkflowAgentComposerSaveToRosterApi.post)(
WorkflowAgentComposerSaveToRosterApi(), app_model, "node-1"
WorkflowAgentComposerSaveToRosterApi(), "tenant-1", account_id, app_model, "node-1"
)["save_options"] == ["node_job_only"]
def test_workflow_impact_returns_empty_without_version(app):
def test_workflow_impact_returns_empty_without_version(app: Flask) -> None:
payload = {"variant": ComposerVariant.WORKFLOW.value, "save_strategy": ComposerSaveStrategy.NODE_JOB_ONLY.value}
with app.test_request_context(json=payload):
result = _unwrap(WorkflowAgentComposerImpactApi.post)(
WorkflowAgentComposerImpactApi(), SimpleNamespace(id="app-1"), "node-1"
WorkflowAgentComposerImpactApi(), "tenant-1", SimpleNamespace(id="app-1"), "node-1"
)
assert result == {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
def test_agent_app_composer_get_put_validate_and_candidates(app, monkeypatch):
def test_agent_app_composer_get_put_validate_and_candidates(
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
) -> None:
app_model = SimpleNamespace(id="app-1")
payload = {
"variant": ComposerVariant.AGENT_APP.value,
@ -329,9 +348,12 @@ def test_agent_app_composer_get_put_validate_and_candidates(app, monkeypatch):
lambda **kwargs: _candidates_response("agent_app"),
)
assert _unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), app_model)["variant"] == "agent_app"
assert _unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), "tenant-1", app_model)["variant"] == "agent_app"
with app.test_request_context(json=payload):
assert _unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), app_model)["variant"] == "agent_app"
assert (
_unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), "tenant-1", account_id, app_model)["variant"]
== "agent_app"
)
assert _unwrap(AgentAppComposerValidateApi.post)(AgentAppComposerValidateApi(), app_model) == {
"result": "success",
"errors": [],

View File

@ -6,10 +6,10 @@ from datetime import datetime
from importlib import util
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from flask import Flask
from flask.views import MethodView
from pydantic import ValidationError
from werkzeug.datastructures import MultiDict
@ -19,6 +19,13 @@ if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
class _ConsoleModule(ModuleType):
console_ns: object
api: object | None
bp: object | None
app: ModuleType
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
@ -36,7 +43,7 @@ def app_module():
class _StubNamespace:
def __init__(self):
self.models: dict[str, Any] = {}
self.models: dict[str, object] = {}
self.payload = None
def schema_model(self, name, schema):
@ -77,7 +84,7 @@ def app_module():
}
stubbed_modules: list[tuple[str, ModuleType | None]] = []
console_module = ModuleType("controllers.console")
console_module = _ConsoleModule("controllers.console")
console_module.__path__ = [str(root / "controllers" / "console")]
console_module.console_ns = stub_namespace
console_module.api = None
@ -89,7 +96,7 @@ def app_module():
sys.modules["controllers.console.app"] = app_package
console_module.app = app_package
def _stub_module(name: str, attrs: dict[str, Any]):
def _stub_module(name: str, attrs: dict[str, object]) -> None:
original = sys.modules.get(name)
module = ModuleType(name)
for key, value in attrs.items():
@ -99,7 +106,7 @@ def app_module():
class _OpsTraceManager:
@staticmethod
def get_app_tracing_config(app_id: str) -> dict[str, Any]:
def get_app_tracing_config(app_id: str) -> dict[str, object]:
return {}
@staticmethod
@ -116,6 +123,7 @@ def app_module():
)
spec = util.spec_from_file_location(module_name, module_path)
assert spec is not None
module = util.module_from_spec(spec)
sys.modules[module_name] = module
@ -147,7 +155,7 @@ def app_models(app_module):
@pytest.fixture(autouse=True)
def patch_signed_url(monkeypatch, app_module):
def patch_signed_url(monkeypatch: pytest.MonkeyPatch, app_module: ModuleType) -> None:
"""Ensure icon URL generation uses a deterministic helper for tests."""
def _fake_build_icon_url(_icon_type, key: str | None) -> str | None:
@ -407,10 +415,11 @@ def test_app_pagination_aliases_per_page_and_has_next(app_models):
assert serialized["data"][1]["icon_url"] is None
def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, monkeypatch):
def test_app_list_uses_injected_session_for_draft_workflows(
app: Flask, app_module: ModuleType, monkeypatch: pytest.MonkeyPatch
) -> None:
api = app_module.AppListApi()
method = _unwrap(api.get)
current_user = SimpleNamespace(id="user-1")
app_item = SimpleNamespace(
id="app-1",
name="Workflow App",
@ -428,7 +437,6 @@ def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, mon
session.execute.return_value.scalars.return_value.all.return_value = [workflow]
scoped_session = SimpleNamespace(execute=MagicMock(side_effect=AssertionError("db.session should not be used")))
monkeypatch.setattr(app_module, "current_account_with_tenant", lambda: (current_user, "tenant-1"))
monkeypatch.setattr(
app_module,
"AppService",
@ -442,7 +450,7 @@ def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, mon
monkeypatch.setattr(app_module, "db", SimpleNamespace(session=scoped_session))
with app.test_request_context("/console/api/apps?page=1&limit=20", method="GET"):
response, status = method(session)
response, status = method("tenant-1", "user-1", session)
assert status == 200
assert response["data"][0]["has_draft_trigger"] is True

View File

@ -5,6 +5,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from controllers.console.app import model_config as model_config_module
from models.model import AppMode, AppModelConfig
@ -19,7 +20,7 @@ def _unwrap(func):
return func
def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = model_config_module.ModelConfigResource()
method = _unwrap(api.post)
@ -36,8 +37,6 @@ def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyP
"validate_configuration",
lambda **_kwargs: {"pre_prompt": "hi"},
)
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
session = MagicMock()
monkeypatch.setattr(model_config_module.db, "session", session)
@ -51,7 +50,7 @@ def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyP
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
response = method(app_model=app_model)
response = method("t1", "u1", app_model=app_model)
session.add.assert_called_once()
session.flush.assert_called_once()
@ -61,7 +60,7 @@ def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyP
assert response["result"] == "success"
def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_post_encrypts_agent_tool_parameters(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = model_config_module.ModelConfigResource()
method = _unwrap(api.post)
@ -115,8 +114,6 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc
},
},
)
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object())
class _ParamManager:
@ -140,7 +137,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
response = method(app_model=app_model)
response = method("t1", "u1", app_model=app_model)
stored_config = session.add.call_args[0][0]
stored_agent_mode = json.loads(stored_config.agent_mode)

View File

@ -1,3 +1,5 @@
from collections.abc import Callable
from contextlib import AbstractContextManager
from datetime import datetime
from unittest.mock import MagicMock, PropertyMock, patch
@ -7,6 +9,9 @@ from werkzeug.exceptions import BadRequest, Forbidden, NotFound
import controllers.console.explore.installed_app as module
type Payload = dict[str, object]
type PayloadPatch = Callable[[Payload], AbstractContextManager[object]]
def unwrap(func):
while hasattr(func, "__wrapped__"):
@ -15,12 +20,12 @@ def unwrap(func):
@pytest.fixture
def tenant_id():
def tenant_id() -> str:
return "t1"
@pytest.fixture
def current_user(tenant_id):
def current_user(tenant_id: str) -> MagicMock:
user = MagicMock()
user.id = "u1"
user.current_tenant = MagicMock(id=tenant_id)
@ -28,7 +33,7 @@ def current_user(tenant_id):
@pytest.fixture
def installed_app():
def installed_app() -> MagicMock:
app = MagicMock()
app.id = "ia1"
app.app = MagicMock(id="a1")
@ -39,8 +44,8 @@ def installed_app():
@pytest.fixture
def payload_patch():
def _patch(payload):
def payload_patch() -> PayloadPatch:
def _patch(payload: Payload) -> AbstractContextManager[object]:
return patch.object(
type(module.console_ns),
"payload",
@ -52,7 +57,9 @@ def payload_patch():
class TestInstalledAppsListApi:
def test_get_installed_apps(self, app: Flask, current_user, tenant_id, installed_app):
def test_get_installed_apps(
self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock
) -> None:
api = module.InstalledAppsListApi()
method = unwrap(api.get)
@ -61,7 +68,6 @@ class TestInstalledAppsListApi:
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
@ -70,13 +76,13 @@ class TestInstalledAppsListApi:
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
),
):
result = method(api)
result = method(api, tenant_id, current_user)
assert "installed_apps" in result
assert result["installed_apps"][0]["editable"] is True
assert result["installed_apps"][0]["uninstallable"] is False
def test_get_installed_apps_with_app_id_filter(self, app: Flask, current_user, tenant_id):
def test_get_installed_apps_with_app_id_filter(self, app: Flask, current_user: MagicMock, tenant_id: str) -> None:
api = module.InstalledAppsListApi()
method = unwrap(api.get)
@ -85,7 +91,6 @@ class TestInstalledAppsListApi:
with (
app.test_request_context("/?app_id=a1"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="member"),
patch.object(
@ -94,11 +99,13 @@ class TestInstalledAppsListApi:
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
),
):
result = method(api)
result = method(api, tenant_id, current_user)
assert result == {"installed_apps": []}
def test_get_installed_apps_with_webapp_auth_enabled(self, app: Flask, current_user, tenant_id, installed_app):
def test_get_installed_apps_with_webapp_auth_enabled(
self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock
) -> None:
"""Test filtering when webapp_auth is enabled."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
@ -111,7 +118,6 @@ class TestInstalledAppsListApi:
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
@ -130,11 +136,13 @@ class TestInstalledAppsListApi:
return_value={"a1": True},
),
):
result = method(api)
result = method(api, tenant_id, current_user)
assert len(result["installed_apps"]) == 1
def test_get_installed_apps_with_webapp_auth_user_denied(self, app: Flask, current_user, tenant_id, installed_app):
def test_get_installed_apps_with_webapp_auth_user_denied(
self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock
) -> None:
"""Test filtering when user doesn't have access."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
@ -147,7 +155,6 @@ class TestInstalledAppsListApi:
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="member"),
patch.object(
@ -166,11 +173,13 @@ class TestInstalledAppsListApi:
return_value={"a1": False},
),
):
result = method(api)
result = method(api, tenant_id, current_user)
assert result["installed_apps"] == []
def test_get_installed_apps_with_sso_verified_access(self, app: Flask, current_user, tenant_id, installed_app):
def test_get_installed_apps_with_sso_verified_access(
self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock
) -> None:
"""Test that sso_verified access mode apps are skipped in filtering."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
@ -183,7 +192,6 @@ class TestInstalledAppsListApi:
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
@ -197,11 +205,11 @@ class TestInstalledAppsListApi:
return_value={"a1": mock_webapp_setting},
),
):
result = method(api)
result = method(api, tenant_id, current_user)
assert len(result["installed_apps"]) == 0
def test_get_installed_apps_filters_null_apps(self, app: Flask, current_user, tenant_id):
def test_get_installed_apps_filters_null_apps(self, app: Flask, current_user: MagicMock, tenant_id: str) -> None:
"""Test that installed apps with null app are filtered out."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
@ -214,7 +222,6 @@ class TestInstalledAppsListApi:
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
patch.object(module.TenantService, "get_user_role", return_value="owner"),
patch.object(
@ -223,11 +230,11 @@ class TestInstalledAppsListApi:
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
),
):
result = method(api)
result = method(api, tenant_id, current_user)
assert result["installed_apps"] == []
def test_get_installed_apps_current_tenant_none(self, app: Flask, tenant_id, installed_app):
def test_get_installed_apps_current_tenant_none(self, app: Flask, tenant_id: str, installed_app: MagicMock) -> None:
"""Test error when current_user.current_tenant is None."""
api = module.InstalledAppsListApi()
method = unwrap(api.get)
@ -240,15 +247,14 @@ class TestInstalledAppsListApi:
with (
app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
patch.object(module.db, "session", session),
):
with pytest.raises(ValueError, match="current_user.current_tenant must not be None"):
method(api)
method(api, tenant_id, current_user)
class TestInstalledAppsCreateApi:
def test_post_success(self, app: Flask, tenant_id, payload_patch):
def test_post_success(self, app: Flask, tenant_id: str, payload_patch: PayloadPatch) -> None:
api = module.InstalledAppsListApi()
method = unwrap(api.post)
@ -270,14 +276,13 @@ class TestInstalledAppsCreateApi:
app.test_request_context("/", json={"app_id": "a1"}),
payload_patch({"app_id": "a1"}),
patch.object(module.db, "session", session),
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
):
result = method(api)
result = method(api, tenant_id)
assert result == {"message": "App installed successfully"}
assert recommended.install_count == 1
def test_post_recommended_not_found(self, app: Flask, payload_patch):
def test_post_recommended_not_found(self, app: Flask, tenant_id: str, payload_patch: PayloadPatch) -> None:
api = module.InstalledAppsListApi()
method = unwrap(api.post)
@ -290,9 +295,9 @@ class TestInstalledAppsCreateApi:
patch.object(module.db, "session", session),
):
with pytest.raises(NotFound):
method(api)
method(api, tenant_id)
def test_post_app_not_public(self, app: Flask, tenant_id, payload_patch):
def test_post_app_not_public(self, app: Flask, tenant_id: str, payload_patch: PayloadPatch) -> None:
api = module.InstalledAppsListApi()
method = unwrap(api.post)
@ -309,37 +314,32 @@ class TestInstalledAppsCreateApi:
app.test_request_context("/", json={"app_id": "a1"}),
payload_patch({"app_id": "a1"}),
patch.object(module.db, "session", session),
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
):
with pytest.raises(Forbidden):
method(api)
method(api, tenant_id)
class TestInstalledAppApi:
def test_delete_success(self, tenant_id: str, installed_app):
def test_delete_success(self, tenant_id: str, installed_app: MagicMock) -> None:
api = module.InstalledAppApi()
method = unwrap(api.delete)
with (
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
patch.object(module.db, "session"),
):
resp, status = method(installed_app)
with patch.object(module.db, "session"):
resp, status = method(api, tenant_id, installed_app)
assert status == 204
assert resp == ""
def test_delete_owned_by_current_tenant(self, tenant_id: str):
def test_delete_owned_by_current_tenant(self, tenant_id: str) -> None:
api = module.InstalledAppApi()
method = unwrap(api.delete)
installed_app = MagicMock(app_owner_tenant_id=tenant_id)
with patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)):
with pytest.raises(BadRequest):
method(installed_app)
with pytest.raises(BadRequest):
method(api, tenant_id, installed_app)
def test_patch_update_pin(self, app: Flask, payload_patch, installed_app):
def test_patch_update_pin(self, app: Flask, payload_patch: PayloadPatch, installed_app: MagicMock) -> None:
api = module.InstalledAppApi()
method = unwrap(api.patch)
@ -353,7 +353,7 @@ class TestInstalledAppApi:
assert installed_app.is_pinned is True
assert result["result"] == "success"
def test_patch_no_change(self, app: Flask, payload_patch, installed_app):
def test_patch_no_change(self, app: Flask, payload_patch: PayloadPatch, installed_app: MagicMock) -> None:
api = module.InstalledAppApi()
method = unwrap(api.patch)

View File

@ -33,16 +33,12 @@ class TestModelProviderListApi:
with (
app.test_request_context("/?model_type=llm"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_list",
return_value=[{"name": "openai"}],
),
):
result = method(api)
result = method(api, "tenant1")
assert "data" in result
@ -54,16 +50,12 @@ class TestModelProviderCredentialApi:
with (
app.test_request_context(f"/?credential_id={VALID_UUID}"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential",
return_value={"key": "value"},
),
):
result = method(api, provider="openai")
result = method(api, "tenant1", provider="openai")
assert "credentials" in result
@ -71,15 +63,9 @@ class TestModelProviderCredentialApi:
api = ModelProviderCredentialApi()
method = unwrap(api.get)
with (
app.test_request_context(f"/?credential_id={INVALID_UUID}"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with app.test_request_context(f"/?credential_id={INVALID_UUID}"):
with pytest.raises(ValidationError):
method(api, provider="openai")
method(api, "tenant1", provider="openai")
def test_post_create_success(self, app: Flask):
api = ModelProviderCredentialApi()
@ -89,16 +75,12 @@ class TestModelProviderCredentialApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
return_value=None,
),
):
result, status = method(api, provider="openai")
result, status = method(api, "tenant1", provider="openai")
assert result["result"] == "success"
assert status == 201
@ -111,17 +93,13 @@ class TestModelProviderCredentialApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
side_effect=CredentialsValidateFailedError("bad"),
),
):
with pytest.raises(ValueError):
method(api, provider="openai")
method(api, "tenant1", provider="openai")
def test_put_update_success(self, app: Flask):
api = ModelProviderCredentialApi()
@ -131,16 +109,12 @@ class TestModelProviderCredentialApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential",
return_value=None,
),
):
result = method(api, provider="openai")
result = method(api, "tenant1", provider="openai")
assert result["result"] == "success"
@ -150,15 +124,9 @@ class TestModelProviderCredentialApi:
payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with app.test_request_context("/", json=payload):
with pytest.raises(ValidationError):
method(api, provider="openai")
method(api, "tenant1", provider="openai")
def test_delete_success(self, app: Flask):
api = ModelProviderCredentialApi()
@ -168,16 +136,12 @@ class TestModelProviderCredentialApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential",
return_value=None,
),
):
result, status = method(api, provider="openai")
result, status = method(api, "tenant1", provider="openai")
assert status == 204
assert result == ""
@ -192,16 +156,12 @@ class TestModelProviderCredentialSwitchApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential",
return_value=None,
),
):
result = method(api, provider="openai")
result = method(api, "tenant1", provider="openai")
assert result["result"] == "success"
@ -211,15 +171,9 @@ class TestModelProviderCredentialSwitchApi:
payload = {"credential_id": INVALID_UUID}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with app.test_request_context("/", json=payload):
with pytest.raises(ValidationError):
method(api, provider="openai")
method(api, "tenant1", provider="openai")
class TestModelProviderValidateApi:
@ -231,16 +185,12 @@ class TestModelProviderValidateApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
return_value=None,
),
):
result = method(api, provider="openai")
result = method(api, "tenant1", provider="openai")
assert result["result"] == "success"
@ -252,16 +202,12 @@ class TestModelProviderValidateApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
side_effect=CredentialsValidateFailedError("bad"),
),
):
result = method(api, provider="openai")
result = method(api, "tenant1", provider="openai")
assert result["result"] == "error"
@ -304,16 +250,12 @@ class TestPreferredProviderTypeUpdateApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider",
return_value=None,
),
):
result = method(api, provider="openai")
result = method(api, "tenant1", provider="openai")
assert result["result"] == "success"
@ -323,15 +265,9 @@ class TestPreferredProviderTypeUpdateApi:
payload = {"preferred_provider_type": "invalid"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(MagicMock(), "tenant1"),
),
):
with app.test_request_context("/", json=payload):
with pytest.raises(ValidationError):
method(api, provider="openai")
method(api, "tenant1", provider="openai")
class TestModelProviderPaymentCheckoutUrlApi:
@ -343,10 +279,6 @@ class TestModelProviderPaymentCheckoutUrlApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(user, "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
return_value=None,
@ -356,7 +288,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
return_value={"url": "x"},
),
):
result = method(api, provider="anthropic")
result = method(api, "tenant1", user, provider="anthropic")
assert "url" in result
@ -366,7 +298,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
with app.test_request_context("/"):
with pytest.raises(ValueError):
method(api, provider="openai")
method(api, "tenant1", MagicMock(), provider="openai")
def test_permission_denied(self, app: Flask):
api = ModelProviderPaymentCheckoutUrlApi()
@ -376,14 +308,10 @@ class TestModelProviderPaymentCheckoutUrlApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.model_providers.current_account_with_tenant",
return_value=(user, "tenant1"),
),
patch(
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
side_effect=Forbidden(),
),
):
with pytest.raises(Forbidden):
method(api, provider="anthropic")
method(api, "tenant1", user, provider="anthropic")

View File

@ -1,3 +1,4 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
@ -172,7 +173,7 @@ class TestModelProviderModelCredentialApi:
provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
lb_service.return_value.get_load_balancing_configs.return_value = (False, [])
result = method(api, "tenant1", "openai")
result = method(api, "tenant1", SimpleNamespace(id="u1"), "openai")
assert "credentials" in result
@ -207,7 +208,7 @@ class TestModelProviderModelCredentialApi:
service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
lb.return_value.get_load_balancing_configs.return_value = (False, [])
result = method(api, "t1", "openai")
result = method(api, "t1", SimpleNamespace(id="u1"), "openai")
assert result["credentials"] == {}