mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:23:44 +08:00
refactor(api): migrate tenant/user via DI for several endpoints (#37004)
This commit is contained in:
parent
9de40e8f21
commit
57b573d02b
@ -3,7 +3,13 @@ from flask_restx import Resource
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from fields.agent_fields import (
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerCandidatesResponse,
|
||||
@ -12,7 +18,7 @@ from fields.agent_fields import (
|
||||
WorkflowAgentComposerResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
@ -38,8 +44,8 @@ class WorkflowAgentComposerApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, node_id: str):
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.load_workflow_composer(
|
||||
@ -58,8 +64,9 @@ class WorkflowAgentComposerApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def put(self, app_model: App, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
@ -67,7 +74,7 @@ class WorkflowAgentComposerApi(Resource):
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
@ -113,8 +120,8 @@ class WorkflowAgentComposerImpactApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
|
||||
if not current_snapshot_id:
|
||||
@ -138,8 +145,9 @@ class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
@ -147,7 +155,7 @@ class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
@ -160,8 +168,8 @@ class AgentAppComposerApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
return dump_response(
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id),
|
||||
@ -174,15 +182,16 @@ class AgentAppComposerApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model()
|
||||
def put(self, app_model: App):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, account_id: str, app_model: App):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerService.save_agent_app_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
account_id=account.id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
@ -6,7 +6,13 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.agent_fields import (
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
@ -16,7 +22,7 @@ from fields.agent_fields import (
|
||||
AgentRosterResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
|
||||
|
||||
@ -58,8 +64,8 @@ class AgentRosterListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return dump_response(
|
||||
AgentRosterListResponse,
|
||||
@ -74,11 +80,12 @@ class AgentRosterListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str):
|
||||
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
|
||||
service = _agent_roster_service()
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload)
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id),
|
||||
@ -92,8 +99,8 @@ class AgentInviteOptionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return dump_response(
|
||||
AgentInviteOptionsResponse,
|
||||
@ -113,8 +120,8 @@ class AgentRosterDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
_agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)),
|
||||
@ -126,13 +133,14 @@ class AgentRosterDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, agent_id: UUID):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, account_id: str, agent_id: UUID):
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
_agent_roster_service().update_roster_agent(
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload
|
||||
),
|
||||
)
|
||||
|
||||
@ -141,9 +149,10 @@ class AgentRosterDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, agent_id: UUID):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, account_id: str, agent_id: UUID):
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id)
|
||||
return "", 204
|
||||
|
||||
|
||||
@ -153,8 +162,8 @@ class AgentRosterVersionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotListResponse,
|
||||
{"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))},
|
||||
@ -167,8 +176,8 @@ class AgentRosterVersionDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID, version_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID, version_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
_agent_roster_service().get_agent_version_detail(
|
||||
|
||||
@ -25,6 +25,9 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
@ -34,8 +37,8 @@ from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import build_icon_url, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from libs.login import login_required
|
||||
from models import Account, App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppListParams, AppService, CreateAppParams
|
||||
@ -472,10 +475,10 @@ class AppListApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@with_session(write=False)
|
||||
def get(self, session: Session):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user_id: str, session: Session):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
|
||||
params = AppListParams(
|
||||
page=args.page,
|
||||
@ -488,7 +491,7 @@ class AppListApi(Resource):
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
|
||||
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
|
||||
if not app_pagination:
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
@ -548,9 +551,10 @@ class AppListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
"""Create app"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
params = CreateAppParams(
|
||||
name=args.name,
|
||||
@ -653,11 +657,10 @@ class AppCopyApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@ -736,7 +739,8 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
@ -744,13 +748,11 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
return {"error": "Creators Platform features are not enabled"}, 403
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
|
||||
dsl_bytes = dsl_content.encode("utf-8")
|
||||
|
||||
claim_code = upload_dsl(dsl_bytes)
|
||||
redirect_url = get_redirect_url(str(current_user.id), claim_code)
|
||||
redirect_url = get_redirect_url(current_user_id, claim_code)
|
||||
|
||||
return {"redirect_url": redirect_url}
|
||||
|
||||
|
||||
@ -8,14 +8,20 @@ from pydantic import BaseModel, Field
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
@ -52,9 +58,10 @@ class ModelConfigResource(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model: App):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user_id: str, app_model: App):
|
||||
"""Modify app model config"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -64,8 +71,8 @@ class ModelConfigResource(Resource):
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
updated_by=current_user.id,
|
||||
created_by=current_user_id,
|
||||
updated_by=current_user_id,
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
@ -90,7 +97,7 @@ class ModelConfigResource(Resource):
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -130,7 +137,7 @@ class ModelConfigResource(Resource):
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
@ -167,7 +174,7 @@ class ModelConfigResource(Resource):
|
||||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app_model.updated_by = current_user.id
|
||||
app_model.updated_by = current_user_id
|
||||
app_model.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@ -12,14 +12,19 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from libs.login import login_required
|
||||
from models import Account, App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
from services.account_service import TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -131,9 +136,10 @@ class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
|
||||
def get(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if query.app_id:
|
||||
installed_apps = db.session.scalars(
|
||||
@ -212,7 +218,8 @@ class InstalledAppsListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
recommended_app = db.session.scalar(
|
||||
@ -221,8 +228,6 @@ class InstalledAppsListApi(Resource):
|
||||
if recommended_app is None:
|
||||
raise NotFound("Recommended app not found")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.get(App, payload.app_id)
|
||||
|
||||
if app is None:
|
||||
@ -262,8 +267,8 @@ class InstalledAppApi(InstalledAppResource):
|
||||
"""
|
||||
|
||||
@console_ns.response(204, "App uninstalled successfully")
|
||||
def delete(self, installed_app: InstalledApp):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, installed_app: InstalledApp):
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
|
||||
|
||||
@ -8,12 +8,19 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -95,10 +102,8 @@ class ModelProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = ParserModelList.model_validate(payload)
|
||||
|
||||
@ -114,9 +119,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
# if credential_id is not provided, return current used credential
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = ParserCredentialId.model_validate(payload)
|
||||
@ -133,8 +137,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialCreate.model_validate(payload)
|
||||
|
||||
@ -157,9 +161,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialUpdate.model_validate(payload)
|
||||
|
||||
@ -184,8 +187,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialDelete.model_validate(payload)
|
||||
|
||||
@ -205,8 +208,8 @@ class ModelProviderCredentialSwitchApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialSwitch.model_validate(payload)
|
||||
|
||||
@ -225,8 +228,8 @@ class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialValidate.model_validate(payload)
|
||||
|
||||
@ -280,11 +283,8 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserPreferredProviderType.model_validate(payload)
|
||||
|
||||
@ -301,10 +301,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider: str):
|
||||
if provider != "anthropic":
|
||||
raise ValueError(f"provider name {provider} is invalid")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
data = BillingService.get_model_provider_payment_link(
|
||||
provider_name=provider,
|
||||
|
||||
@ -13,12 +13,14 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -269,8 +271,9 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -293,9 +296,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
|
||||
if args.config_from == "predefined-model":
|
||||
# Only the predefined-model branch needs visibility filtering by user.
|
||||
# Defer the auth lookup so the other branch (and its tests) doesn't
|
||||
# require flask-login setup.
|
||||
user, _ = current_account_with_tenant()
|
||||
# The account is injected once by the handler and only passed into the
|
||||
# service branch that needs user-scoped credential visibility.
|
||||
available_credentials = model_provider_service.get_provider_available_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
|
||||
@ -26,7 +26,12 @@ from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
@ -42,7 +47,6 @@ from controllers.openapi._models import (
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
@ -50,6 +54,7 @@ from libs.rate_limit import (
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from models import Account
|
||||
from services.account_service import TenantService
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
@ -206,11 +211,12 @@ class DeviceApproveApi(Resource):
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant: str, account: Account):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
account, tenant = current_account_with_tenant()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
found = store.load_by_user_code(user_code)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Protocol, cast
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.agent import composer as composer_controller
|
||||
from controllers.console.agent import roster as roster_controller
|
||||
@ -114,34 +116,34 @@ def _candidates_response(variant: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
class _PayloadWithDescription(Protocol):
|
||||
description: object
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def account():
|
||||
return SimpleNamespace(id="account-1")
|
||||
def account_id() -> str:
|
||||
return "account-1"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_account_context(monkeypatch, account):
|
||||
monkeypatch.setattr(roster_controller, "current_account_with_tenant", lambda: (account, "tenant-1"))
|
||||
monkeypatch.setattr(composer_controller, "current_account_with_tenant", lambda: (account, "tenant-1"))
|
||||
def test_roster_list_get_parses_query_and_calls_service(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
|
||||
def test_roster_list_get_parses_query_and_calls_service(app, monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def list_roster_agents(_self, **kwargs):
|
||||
def list_roster_agents(_self: object, **kwargs: object) -> dict[str, object]:
|
||||
captured.update(kwargs)
|
||||
return {"data": [], "page": kwargs["page"], "limit": kwargs["limit"], "total": 0, "has_more": False}
|
||||
|
||||
monkeypatch.setattr(roster_controller.AgentRosterService, "list_roster_agents", list_roster_agents)
|
||||
|
||||
with app.test_request_context("/console/api/agents?page=2&limit=5&keyword=analyst"):
|
||||
result = _unwrap(AgentRosterListApi.get)(AgentRosterListApi())
|
||||
result = _unwrap(AgentRosterListApi.get)(AgentRosterListApi(), "tenant-1")
|
||||
|
||||
assert result["page"] == 2
|
||||
assert captured == {"tenant_id": "tenant-1", "page": 2, "limit": 5, "keyword": "analyst"}
|
||||
|
||||
|
||||
def test_roster_list_post_creates_agent_and_returns_detail(app, monkeypatch):
|
||||
def test_roster_list_post_creates_agent_and_returns_detail(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
|
||||
) -> None:
|
||||
created_agent = SimpleNamespace(id="agent-1")
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
@ -155,42 +157,47 @@ def test_roster_list_post_creates_agent_and_returns_detail(app, monkeypatch):
|
||||
)
|
||||
|
||||
with app.test_request_context(json={"name": "Analyst", "agent_soul": {"prompt": {"system_prompt": "x"}}}):
|
||||
result, status = _unwrap(AgentRosterListApi.post)(AgentRosterListApi())
|
||||
result, status = _unwrap(AgentRosterListApi.post)(AgentRosterListApi(), "tenant-1", account_id)
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "agent-1"
|
||||
assert result["agent_kind"] == "dify_agent"
|
||||
|
||||
|
||||
def test_invite_options_get_parses_app_id(app, monkeypatch):
|
||||
captured = {}
|
||||
def test_invite_options_get_parses_app_id(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def list_invite_options(_self, **kwargs):
|
||||
def list_invite_options(_self: object, **kwargs: object) -> dict[str, object]:
|
||||
captured.update(kwargs)
|
||||
return {"data": [], "page": kwargs["page"], "limit": kwargs["limit"], "total": 0, "has_more": False}
|
||||
|
||||
monkeypatch.setattr(roster_controller.AgentRosterService, "list_invite_options", list_invite_options)
|
||||
|
||||
with app.test_request_context("/console/api/agents/invite-options?page=1&limit=10&app_id=app-1"):
|
||||
result = _unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi())
|
||||
result = _unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi(), "tenant-1")
|
||||
|
||||
assert result == {"data": [], "page": 1, "limit": 10, "total": 0, "has_more": False}
|
||||
assert captured == {"tenant_id": "tenant-1", "page": 1, "limit": 10, "keyword": None, "app_id": "app-1"}
|
||||
|
||||
|
||||
def test_roster_detail_patch_delete_and_versions_call_services(app, monkeypatch):
|
||||
def test_roster_detail_patch_delete_and_versions_call_services(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
|
||||
) -> None:
|
||||
agent_id = "00000000-0000-0000-0000-000000000001"
|
||||
version_id = "00000000-0000-0000-0000-000000000002"
|
||||
archived = {}
|
||||
archived: dict[str, object] = {}
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"get_roster_agent_detail",
|
||||
lambda _self, **kwargs: _agent_response(kwargs["agent_id"]),
|
||||
lambda _self, **kwargs: _agent_response(cast(str, kwargs["agent_id"])),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
"update_roster_agent",
|
||||
lambda _self, **kwargs: {**_agent_response(kwargs["agent_id"]), "description": kwargs["payload"].description},
|
||||
lambda _self, **kwargs: {
|
||||
**_agent_response(cast(str, kwargs["agent_id"])),
|
||||
"description": cast(_PayloadWithDescription, kwargs["payload"]).description,
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
roster_controller.AgentRosterService,
|
||||
@ -206,7 +213,7 @@ def test_roster_detail_patch_delete_and_versions_call_services(app, monkeypatch)
|
||||
roster_controller.AgentRosterService,
|
||||
"get_agent_version_detail",
|
||||
lambda _self, **kwargs: {
|
||||
**_version_response(kwargs["version_id"]),
|
||||
**_version_response(cast(str, kwargs["version_id"])),
|
||||
"agent_id": kwargs["agent_id"],
|
||||
"config_snapshot": {},
|
||||
"revisions": [
|
||||
@ -225,18 +232,28 @@ def test_roster_detail_patch_delete_and_versions_call_services(app, monkeypatch)
|
||||
},
|
||||
)
|
||||
|
||||
assert _unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), agent_id)["id"] == agent_id
|
||||
assert _unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), "tenant-1", agent_id)["id"] == agent_id
|
||||
with app.test_request_context(json={"description": "updated"}):
|
||||
assert _unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), agent_id)["description"] == "updated"
|
||||
assert _unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), agent_id) == ("", 204)
|
||||
assert (
|
||||
_unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id)["description"]
|
||||
== "updated"
|
||||
)
|
||||
assert _unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id) == ("", 204)
|
||||
assert archived["account_id"] == "account-1"
|
||||
assert _unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), agent_id)["data"][0]["id"] == "version-1"
|
||||
version_detail = _unwrap(AgentRosterVersionDetailApi.get)(AgentRosterVersionDetailApi(), agent_id, version_id)
|
||||
assert (
|
||||
_unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"]
|
||||
== "version-1"
|
||||
)
|
||||
version_detail = _unwrap(AgentRosterVersionDetailApi.get)(
|
||||
AgentRosterVersionDetailApi(), "tenant-1", agent_id, version_id
|
||||
)
|
||||
assert version_detail["id"] == version_id
|
||||
assert version_detail["agent_id"] == agent_id
|
||||
|
||||
|
||||
def test_workflow_composer_get_put_validate_candidates_impact_and_save(app, monkeypatch):
|
||||
def test_workflow_composer_get_put_validate_candidates_impact_and_save(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
|
||||
) -> None:
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
payload = {
|
||||
"variant": ComposerVariant.WORKFLOW.value,
|
||||
@ -269,10 +286,12 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save(app, monk
|
||||
},
|
||||
)
|
||||
|
||||
workflow_state = _unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), app_model, "node-1")
|
||||
workflow_state = _unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), "tenant-1", app_model, "node-1")
|
||||
assert workflow_state["node_id"] == "node-1"
|
||||
with app.test_request_context(json=payload):
|
||||
saved_state = _unwrap(WorkflowAgentComposerApi.put)(WorkflowAgentComposerApi(), app_model, "node-1")
|
||||
saved_state = _unwrap(WorkflowAgentComposerApi.put)(
|
||||
WorkflowAgentComposerApi(), "tenant-1", account_id, app_model, "node-1"
|
||||
)
|
||||
assert saved_state["save_options"] == ["node_job_only"]
|
||||
assert _unwrap(WorkflowAgentComposerValidateApi.post)(
|
||||
WorkflowAgentComposerValidateApi(), app_model, "node-1"
|
||||
@ -284,28 +303,28 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save(app, monk
|
||||
== "workflow"
|
||||
)
|
||||
with app.test_request_context(json=payload):
|
||||
assert _unwrap(WorkflowAgentComposerImpactApi.post)(WorkflowAgentComposerImpactApi(), app_model, "node-1") == {
|
||||
"current_snapshot_id": "version-1",
|
||||
"workflow_node_count": 1,
|
||||
"bindings": [],
|
||||
}
|
||||
assert _unwrap(WorkflowAgentComposerImpactApi.post)(
|
||||
WorkflowAgentComposerImpactApi(), "tenant-1", app_model, "node-1"
|
||||
) == {"current_snapshot_id": "version-1", "workflow_node_count": 1, "bindings": []}
|
||||
assert _unwrap(WorkflowAgentComposerSaveToRosterApi.post)(
|
||||
WorkflowAgentComposerSaveToRosterApi(), app_model, "node-1"
|
||||
WorkflowAgentComposerSaveToRosterApi(), "tenant-1", account_id, app_model, "node-1"
|
||||
)["save_options"] == ["node_job_only"]
|
||||
|
||||
|
||||
def test_workflow_impact_returns_empty_without_version(app):
|
||||
def test_workflow_impact_returns_empty_without_version(app: Flask) -> None:
|
||||
payload = {"variant": ComposerVariant.WORKFLOW.value, "save_strategy": ComposerSaveStrategy.NODE_JOB_ONLY.value}
|
||||
|
||||
with app.test_request_context(json=payload):
|
||||
result = _unwrap(WorkflowAgentComposerImpactApi.post)(
|
||||
WorkflowAgentComposerImpactApi(), SimpleNamespace(id="app-1"), "node-1"
|
||||
WorkflowAgentComposerImpactApi(), "tenant-1", SimpleNamespace(id="app-1"), "node-1"
|
||||
)
|
||||
|
||||
assert result == {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
|
||||
|
||||
|
||||
def test_agent_app_composer_get_put_validate_and_candidates(app, monkeypatch):
|
||||
def test_agent_app_composer_get_put_validate_and_candidates(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str
|
||||
) -> None:
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
payload = {
|
||||
"variant": ComposerVariant.AGENT_APP.value,
|
||||
@ -329,9 +348,12 @@ def test_agent_app_composer_get_put_validate_and_candidates(app, monkeypatch):
|
||||
lambda **kwargs: _candidates_response("agent_app"),
|
||||
)
|
||||
|
||||
assert _unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), app_model)["variant"] == "agent_app"
|
||||
assert _unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), "tenant-1", app_model)["variant"] == "agent_app"
|
||||
with app.test_request_context(json=payload):
|
||||
assert _unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), app_model)["variant"] == "agent_app"
|
||||
assert (
|
||||
_unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), "tenant-1", account_id, app_model)["variant"]
|
||||
== "agent_app"
|
||||
)
|
||||
assert _unwrap(AgentAppComposerValidateApi.post)(AgentAppComposerValidateApi(), app_model) == {
|
||||
"result": "success",
|
||||
"errors": [],
|
||||
|
||||
@ -6,10 +6,10 @@ from datetime import datetime
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.datastructures import MultiDict
|
||||
@ -19,6 +19,13 @@ if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class _ConsoleModule(ModuleType):
|
||||
console_ns: object
|
||||
api: object | None
|
||||
bp: object | None
|
||||
app: ModuleType
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
@ -36,7 +43,7 @@ def app_module():
|
||||
|
||||
class _StubNamespace:
|
||||
def __init__(self):
|
||||
self.models: dict[str, Any] = {}
|
||||
self.models: dict[str, object] = {}
|
||||
self.payload = None
|
||||
|
||||
def schema_model(self, name, schema):
|
||||
@ -77,7 +84,7 @@ def app_module():
|
||||
}
|
||||
stubbed_modules: list[tuple[str, ModuleType | None]] = []
|
||||
|
||||
console_module = ModuleType("controllers.console")
|
||||
console_module = _ConsoleModule("controllers.console")
|
||||
console_module.__path__ = [str(root / "controllers" / "console")]
|
||||
console_module.console_ns = stub_namespace
|
||||
console_module.api = None
|
||||
@ -89,7 +96,7 @@ def app_module():
|
||||
sys.modules["controllers.console.app"] = app_package
|
||||
console_module.app = app_package
|
||||
|
||||
def _stub_module(name: str, attrs: dict[str, Any]):
|
||||
def _stub_module(name: str, attrs: dict[str, object]) -> None:
|
||||
original = sys.modules.get(name)
|
||||
module = ModuleType(name)
|
||||
for key, value in attrs.items():
|
||||
@ -99,7 +106,7 @@ def app_module():
|
||||
|
||||
class _OpsTraceManager:
|
||||
@staticmethod
|
||||
def get_app_tracing_config(app_id: str) -> dict[str, Any]:
|
||||
def get_app_tracing_config(app_id: str) -> dict[str, object]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
@ -116,6 +123,7 @@ def app_module():
|
||||
)
|
||||
|
||||
spec = util.spec_from_file_location(module_name, module_path)
|
||||
assert spec is not None
|
||||
module = util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
@ -147,7 +155,7 @@ def app_models(app_module):
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_signed_url(monkeypatch, app_module):
|
||||
def patch_signed_url(monkeypatch: pytest.MonkeyPatch, app_module: ModuleType) -> None:
|
||||
"""Ensure icon URL generation uses a deterministic helper for tests."""
|
||||
|
||||
def _fake_build_icon_url(_icon_type, key: str | None) -> str | None:
|
||||
@ -407,10 +415,11 @@ def test_app_pagination_aliases_per_page_and_has_next(app_models):
|
||||
assert serialized["data"][1]["icon_url"] is None
|
||||
|
||||
|
||||
def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, monkeypatch):
|
||||
def test_app_list_uses_injected_session_for_draft_workflows(
|
||||
app: Flask, app_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = app_module.AppListApi()
|
||||
method = _unwrap(api.get)
|
||||
current_user = SimpleNamespace(id="user-1")
|
||||
app_item = SimpleNamespace(
|
||||
id="app-1",
|
||||
name="Workflow App",
|
||||
@ -428,7 +437,6 @@ def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, mon
|
||||
session.execute.return_value.scalars.return_value.all.return_value = [workflow]
|
||||
scoped_session = SimpleNamespace(execute=MagicMock(side_effect=AssertionError("db.session should not be used")))
|
||||
|
||||
monkeypatch.setattr(app_module, "current_account_with_tenant", lambda: (current_user, "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
app_module,
|
||||
"AppService",
|
||||
@ -442,7 +450,7 @@ def test_app_list_uses_injected_session_for_draft_workflows(app, app_module, mon
|
||||
monkeypatch.setattr(app_module, "db", SimpleNamespace(session=scoped_session))
|
||||
|
||||
with app.test_request_context("/console/api/apps?page=1&limit=20", method="GET"):
|
||||
response, status = method(session)
|
||||
response, status = method("tenant-1", "user-1", session)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"][0]["has_draft_trigger"] is True
|
||||
|
||||
@ -5,6 +5,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.app import model_config as model_config_module
|
||||
from models.model import AppMode, AppModelConfig
|
||||
@ -19,7 +20,7 @@ def _unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -36,8 +37,6 @@ def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyP
|
||||
"validate_configuration",
|
||||
lambda **_kwargs: {"pre_prompt": "hi"},
|
||||
)
|
||||
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
session = MagicMock()
|
||||
monkeypatch.setattr(model_config_module.db, "session", session)
|
||||
|
||||
@ -51,7 +50,7 @@ def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyP
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method(app_model=app_model)
|
||||
response = method("t1", "u1", app_model=app_model)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.flush.assert_called_once()
|
||||
@ -61,7 +60,7 @@ def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyP
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_post_encrypts_agent_tool_parameters(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
@ -115,8 +114,6 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc
|
||||
},
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object())
|
||||
|
||||
class _ParamManager:
|
||||
@ -140,7 +137,7 @@ def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatc
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method(app_model=app_model)
|
||||
response = method("t1", "u1", app_model=app_model)
|
||||
|
||||
stored_config = session.add.call_args[0][0]
|
||||
stored_agent_mode = json.loads(stored_config.agent_mode)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
@ -7,6 +9,9 @@ from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
import controllers.console.explore.installed_app as module
|
||||
|
||||
type Payload = dict[str, object]
|
||||
type PayloadPatch = Callable[[Payload], AbstractContextManager[object]]
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
@ -15,12 +20,12 @@ def unwrap(func):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_id():
|
||||
def tenant_id() -> str:
|
||||
return "t1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user(tenant_id):
|
||||
def current_user(tenant_id: str) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.current_tenant = MagicMock(id=tenant_id)
|
||||
@ -28,7 +33,7 @@ def current_user(tenant_id):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def installed_app():
|
||||
def installed_app() -> MagicMock:
|
||||
app = MagicMock()
|
||||
app.id = "ia1"
|
||||
app.app = MagicMock(id="a1")
|
||||
@ -39,8 +44,8 @@ def installed_app():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def payload_patch():
|
||||
def _patch(payload):
|
||||
def payload_patch() -> PayloadPatch:
|
||||
def _patch(payload: Payload) -> AbstractContextManager[object]:
|
||||
return patch.object(
|
||||
type(module.console_ns),
|
||||
"payload",
|
||||
@ -52,7 +57,9 @@ def payload_patch():
|
||||
|
||||
|
||||
class TestInstalledAppsListApi:
|
||||
def test_get_installed_apps(self, app: Flask, current_user, tenant_id, installed_app):
|
||||
def test_get_installed_apps(
|
||||
self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock
|
||||
) -> None:
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -61,7 +68,6 @@ class TestInstalledAppsListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
@ -70,13 +76,13 @@ class TestInstalledAppsListApi:
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, current_user)
|
||||
|
||||
assert "installed_apps" in result
|
||||
assert result["installed_apps"][0]["editable"] is True
|
||||
assert result["installed_apps"][0]["uninstallable"] is False
|
||||
|
||||
def test_get_installed_apps_with_app_id_filter(self, app: Flask, current_user, tenant_id):
|
||||
def test_get_installed_apps_with_app_id_filter(self, app: Flask, current_user: MagicMock, tenant_id: str) -> None:
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -85,7 +91,6 @@ class TestInstalledAppsListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?app_id=a1"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="member"),
|
||||
patch.object(
|
||||
@ -94,11 +99,13 @@ class TestInstalledAppsListApi:
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, current_user)
|
||||
|
||||
assert result == {"installed_apps": []}
|
||||
|
||||
def test_get_installed_apps_with_webapp_auth_enabled(self, app: Flask, current_user, tenant_id, installed_app):
|
||||
def test_get_installed_apps_with_webapp_auth_enabled(
|
||||
self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock
|
||||
) -> None:
|
||||
"""Test filtering when webapp_auth is enabled."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
@ -111,7 +118,6 @@ class TestInstalledAppsListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
@ -130,11 +136,13 @@ class TestInstalledAppsListApi:
|
||||
return_value={"a1": True},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, current_user)
|
||||
|
||||
assert len(result["installed_apps"]) == 1
|
||||
|
||||
def test_get_installed_apps_with_webapp_auth_user_denied(self, app: Flask, current_user, tenant_id, installed_app):
|
||||
def test_get_installed_apps_with_webapp_auth_user_denied(
|
||||
self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock
|
||||
) -> None:
|
||||
"""Test filtering when user doesn't have access."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
@ -147,7 +155,6 @@ class TestInstalledAppsListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="member"),
|
||||
patch.object(
|
||||
@ -166,11 +173,13 @@ class TestInstalledAppsListApi:
|
||||
return_value={"a1": False},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, current_user)
|
||||
|
||||
assert result["installed_apps"] == []
|
||||
|
||||
def test_get_installed_apps_with_sso_verified_access(self, app: Flask, current_user, tenant_id, installed_app):
|
||||
def test_get_installed_apps_with_sso_verified_access(
|
||||
self, app: Flask, current_user: MagicMock, tenant_id: str, installed_app: MagicMock
|
||||
) -> None:
|
||||
"""Test that sso_verified access mode apps are skipped in filtering."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
@ -183,7 +192,6 @@ class TestInstalledAppsListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
@ -197,11 +205,11 @@ class TestInstalledAppsListApi:
|
||||
return_value={"a1": mock_webapp_setting},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, current_user)
|
||||
|
||||
assert len(result["installed_apps"]) == 0
|
||||
|
||||
def test_get_installed_apps_filters_null_apps(self, app: Flask, current_user, tenant_id):
|
||||
def test_get_installed_apps_filters_null_apps(self, app: Flask, current_user: MagicMock, tenant_id: str) -> None:
|
||||
"""Test that installed apps with null app are filtered out."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
@ -214,7 +222,6 @@ class TestInstalledAppsListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module.TenantService, "get_user_role", return_value="owner"),
|
||||
patch.object(
|
||||
@ -223,11 +230,11 @@ class TestInstalledAppsListApi:
|
||||
return_value=MagicMock(webapp_auth=MagicMock(enabled=False)),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, current_user)
|
||||
|
||||
assert result["installed_apps"] == []
|
||||
|
||||
def test_get_installed_apps_current_tenant_none(self, app: Flask, tenant_id, installed_app):
|
||||
def test_get_installed_apps_current_tenant_none(self, app: Flask, tenant_id: str, installed_app: MagicMock) -> None:
|
||||
"""Test error when current_user.current_tenant is None."""
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.get)
|
||||
@ -240,15 +247,14 @@ class TestInstalledAppsListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)),
|
||||
patch.object(module.db, "session", session),
|
||||
):
|
||||
with pytest.raises(ValueError, match="current_user.current_tenant must not be None"):
|
||||
method(api)
|
||||
method(api, tenant_id, current_user)
|
||||
|
||||
|
||||
class TestInstalledAppsCreateApi:
|
||||
def test_post_success(self, app: Flask, tenant_id, payload_patch):
|
||||
def test_post_success(self, app: Flask, tenant_id: str, payload_patch: PayloadPatch) -> None:
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -270,14 +276,13 @@ class TestInstalledAppsCreateApi:
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
payload_patch({"app_id": "a1"}),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id)
|
||||
|
||||
assert result == {"message": "App installed successfully"}
|
||||
assert recommended.install_count == 1
|
||||
|
||||
def test_post_recommended_not_found(self, app: Flask, payload_patch):
|
||||
def test_post_recommended_not_found(self, app: Flask, tenant_id: str, payload_patch: PayloadPatch) -> None:
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -290,9 +295,9 @@ class TestInstalledAppsCreateApi:
|
||||
patch.object(module.db, "session", session),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
method(api, tenant_id)
|
||||
|
||||
def test_post_app_not_public(self, app: Flask, tenant_id, payload_patch):
|
||||
def test_post_app_not_public(self, app: Flask, tenant_id: str, payload_patch: PayloadPatch) -> None:
|
||||
api = module.InstalledAppsListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -309,37 +314,32 @@ class TestInstalledAppsCreateApi:
|
||||
app.test_request_context("/", json={"app_id": "a1"}),
|
||||
payload_patch({"app_id": "a1"}),
|
||||
patch.object(module.db, "session", session),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, tenant_id)
|
||||
|
||||
|
||||
class TestInstalledAppApi:
|
||||
def test_delete_success(self, tenant_id: str, installed_app):
|
||||
def test_delete_success(self, tenant_id: str, installed_app: MagicMock) -> None:
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)),
|
||||
patch.object(module.db, "session"),
|
||||
):
|
||||
resp, status = method(installed_app)
|
||||
with patch.object(module.db, "session"):
|
||||
resp, status = method(api, tenant_id, installed_app)
|
||||
|
||||
assert status == 204
|
||||
assert resp == ""
|
||||
|
||||
def test_delete_owned_by_current_tenant(self, tenant_id: str):
|
||||
def test_delete_owned_by_current_tenant(self, tenant_id: str) -> None:
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
installed_app = MagicMock(app_owner_tenant_id=tenant_id)
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)):
|
||||
with pytest.raises(BadRequest):
|
||||
method(installed_app)
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, tenant_id, installed_app)
|
||||
|
||||
def test_patch_update_pin(self, app: Flask, payload_patch, installed_app):
|
||||
def test_patch_update_pin(self, app: Flask, payload_patch: PayloadPatch, installed_app: MagicMock) -> None:
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@ -353,7 +353,7 @@ class TestInstalledAppApi:
|
||||
assert installed_app.is_pinned is True
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_patch_no_change(self, app: Flask, payload_patch, installed_app):
|
||||
def test_patch_no_change(self, app: Flask, payload_patch: PayloadPatch, installed_app: MagicMock) -> None:
|
||||
api = module.InstalledAppApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
|
||||
@ -33,16 +33,12 @@ class TestModelProviderListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?model_type=llm"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_list",
|
||||
return_value=[{"name": "openai"}],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "tenant1")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
@ -54,16 +50,12 @@ class TestModelProviderCredentialApi:
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/?credential_id={VALID_UUID}"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential",
|
||||
return_value={"key": "value"},
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
result = method(api, "tenant1", provider="openai")
|
||||
|
||||
assert "credentials" in result
|
||||
|
||||
@ -71,15 +63,9 @@ class TestModelProviderCredentialApi:
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/?credential_id={INVALID_UUID}"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context(f"/?credential_id={INVALID_UUID}"):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
method(api, "tenant1", provider="openai")
|
||||
|
||||
def test_post_create_success(self, app: Flask):
|
||||
api = ModelProviderCredentialApi()
|
||||
@ -89,16 +75,12 @@ class TestModelProviderCredentialApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result, status = method(api, provider="openai")
|
||||
result, status = method(api, "tenant1", provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert status == 201
|
||||
@ -111,17 +93,13 @@ class TestModelProviderCredentialApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
|
||||
side_effect=CredentialsValidateFailedError("bad"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, provider="openai")
|
||||
method(api, "tenant1", provider="openai")
|
||||
|
||||
def test_put_update_success(self, app: Flask):
|
||||
api = ModelProviderCredentialApi()
|
||||
@ -131,16 +109,12 @@ class TestModelProviderCredentialApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
result = method(api, "tenant1", provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@ -150,15 +124,9 @@ class TestModelProviderCredentialApi:
|
||||
|
||||
payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/", json=payload):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
method(api, "tenant1", provider="openai")
|
||||
|
||||
def test_delete_success(self, app: Flask):
|
||||
api = ModelProviderCredentialApi()
|
||||
@ -168,16 +136,12 @@ class TestModelProviderCredentialApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result, status = method(api, provider="openai")
|
||||
result, status = method(api, "tenant1", provider="openai")
|
||||
|
||||
assert status == 204
|
||||
assert result == ""
|
||||
@ -192,16 +156,12 @@ class TestModelProviderCredentialSwitchApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
result = method(api, "tenant1", provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@ -211,15 +171,9 @@ class TestModelProviderCredentialSwitchApi:
|
||||
|
||||
payload = {"credential_id": INVALID_UUID}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/", json=payload):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
method(api, "tenant1", provider="openai")
|
||||
|
||||
|
||||
class TestModelProviderValidateApi:
|
||||
@ -231,16 +185,12 @@ class TestModelProviderValidateApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
result = method(api, "tenant1", provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@ -252,16 +202,12 @@ class TestModelProviderValidateApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
|
||||
side_effect=CredentialsValidateFailedError("bad"),
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
result = method(api, "tenant1", provider="openai")
|
||||
|
||||
assert result["result"] == "error"
|
||||
|
||||
@ -304,16 +250,12 @@ class TestPreferredProviderTypeUpdateApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
result = method(api, "tenant1", provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@ -323,15 +265,9 @@ class TestPreferredProviderTypeUpdateApi:
|
||||
|
||||
payload = {"preferred_provider_type": "invalid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/", json=payload):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
method(api, "tenant1", provider="openai")
|
||||
|
||||
|
||||
class TestModelProviderPaymentCheckoutUrlApi:
|
||||
@ -343,10 +279,6 @@ class TestModelProviderPaymentCheckoutUrlApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(user, "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
|
||||
return_value=None,
|
||||
@ -356,7 +288,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
|
||||
return_value={"url": "x"},
|
||||
),
|
||||
):
|
||||
result = method(api, provider="anthropic")
|
||||
result = method(api, "tenant1", user, provider="anthropic")
|
||||
|
||||
assert "url" in result
|
||||
|
||||
@ -366,7 +298,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, provider="openai")
|
||||
method(api, "tenant1", MagicMock(), provider="openai")
|
||||
|
||||
def test_permission_denied(self, app: Flask):
|
||||
api = ModelProviderPaymentCheckoutUrlApi()
|
||||
@ -376,14 +308,10 @@ class TestModelProviderPaymentCheckoutUrlApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(user, "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
|
||||
side_effect=Forbidden(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, provider="anthropic")
|
||||
method(api, "tenant1", user, provider="anthropic")
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -172,7 +173,7 @@ class TestModelProviderModelCredentialApi:
|
||||
provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
|
||||
lb_service.return_value.get_load_balancing_configs.return_value = (False, [])
|
||||
|
||||
result = method(api, "tenant1", "openai")
|
||||
result = method(api, "tenant1", SimpleNamespace(id="u1"), "openai")
|
||||
|
||||
assert "credentials" in result
|
||||
|
||||
@ -207,7 +208,7 @@ class TestModelProviderModelCredentialApi:
|
||||
service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
|
||||
lb.return_value.get_load_balancing_configs.return_value = (False, [])
|
||||
|
||||
result = method(api, "t1", "openai")
|
||||
result = method(api, "t1", SimpleNamespace(id="u1"), "openai")
|
||||
|
||||
assert result["credentials"] == {}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user