refactor(api): migrate tenant/user via DI for several endpoints (#37240)

This commit is contained in:
chariri 2026-06-10 13:11:53 +09:00 committed by GitHub
parent dad2e64a62
commit d849d60822
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 989 additions and 1470 deletions

View File

@ -24,9 +24,9 @@ from controllers.common.schema import (
)
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.model import App, AppMode
from services.agent_app_workspace_service import (
AgentAppWorkspaceService,
@ -142,8 +142,8 @@ class AgentAppWorkspaceListResource(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
query = query_params_from_request(AgentWorkspaceListQuery)
try:
result = AgentAppWorkspaceService().list_files(
@ -167,8 +167,8 @@ class AgentAppWorkspacePreviewResource(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
query = query_params_from_request(AgentWorkspaceFileQuery)
try:
result = AgentAppWorkspaceService().preview(
@ -194,8 +194,8 @@ class AgentAppWorkspaceDownloadResource(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
query = query_params_from_request(AgentWorkspaceFileQuery)
try:
result = AgentAppWorkspaceService().download(
@ -228,8 +228,8 @@ class WorkflowAgentWorkspaceListResource(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str):
query = query_params_from_request(WorkflowAgentWorkspaceListQuery)
try:
result = WorkflowAgentWorkspaceService().list_files(
@ -264,8 +264,8 @@ class WorkflowAgentWorkspacePreviewResource(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str):
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
try:
result = WorkflowAgentWorkspaceService().preview(
@ -302,8 +302,8 @@ class WorkflowAgentWorkspaceDownloadResource(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str):
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
try:
result = WorkflowAgentWorkspaceService().download(

View File

@ -25,6 +25,7 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
@ -58,7 +59,7 @@ from graphon.variables import SecretVariable, SegmentType, VariableBase
from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, dump_response, to_timestamp, uuid_value
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account, App
from models.model import AppMode
from models.workflow import Workflow
@ -1568,7 +1569,8 @@ class WorkflowOnlineUsersApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
app_ids = args.app_ids
@ -1578,7 +1580,6 @@ class WorkflowOnlineUsersApi(Resource):
if not app_ids:
return {"data": []}
_, current_tenant_id = current_account_with_tenant()
workflow_service = WorkflowService()
accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids]

View File

@ -22,6 +22,8 @@ from controllers.console.wraps import (
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
@ -36,9 +38,9 @@ from fields.base import ResponseModel
from fields.dataset_fields import DatasetDetailResponse
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import build_icon_url, dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models import Account, ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
from models.provider_ids import ModelProviderID
@ -389,8 +391,9 @@ class DatasetListApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
# Convert query parameters to dict, handling list parameters correctly
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
# Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value)
@ -471,9 +474,10 @@ class DatasetListApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@ -512,8 +516,9 @@ class DatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -566,14 +571,15 @@ class DatasetApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id: UUID):
@with_current_user
@with_current_tenant_id
def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY
@ -614,9 +620,9 @@ class DatasetApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Dataset deleted successfully")
def delete(self, dataset_id: UUID):
@with_current_user
def delete(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
current_user, _ = current_account_with_tenant()
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
@ -664,8 +670,8 @@ class DatasetQueryApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -704,10 +710,10 @@ class DatasetIndexingEstimateApi(Resource):
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
_, current_tenant_id = current_account_with_tenant()
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
@ -804,8 +810,8 @@ class DatasetRelatedAppListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -840,8 +846,8 @@ class DatasetIndexingStatusApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, dataset_id: UUID):
dataset_id_str = str(dataset_id)
documents = db.session.scalars(
select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
@ -898,8 +904,8 @@ class DatasetApiKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
keys = db.session.scalars(
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all()
@ -911,9 +917,8 @@ class DatasetApiKeyApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str):
current_key_count = (
db.session.scalar(
select(func.count(ApiToken.id)).where(
@ -952,8 +957,8 @@ class DatasetApiDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, api_key_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, api_key_id: UUID):
api_key_id_str = str(api_key_id)
key = db.session.scalar(
select(ApiToken)
@ -1079,8 +1084,8 @@ class DatasetPermissionUserListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -29,6 +29,8 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -46,7 +48,7 @@ from fields.workflow_run_fields import (
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs import helper
from libs.helper import TimestampField, UUIDStrOrEmpty, dump_response
from libs.login import current_account_with_tenant, current_user, login_required
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from models.model import EndUser
@ -187,16 +189,14 @@ class DraftRagPipelineApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@get_rag_pipeline
@edit_permission_required
@console_ns.response(200, "Success", console_ns.models[RagPipelineWorkflowSyncResponse.__name__])
def post(self, pipeline: Pipeline):
def post(self, current_user: Account, pipeline: Pipeline):
"""
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type:
@ -247,15 +247,13 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline, node_id: str):
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
"""
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
payload = NodeRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
@ -283,14 +281,12 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
"""
Run draft workflow loop node
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
payload = NodeRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
@ -318,14 +314,12 @@ class DraftRagPipelineRunApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@get_rag_pipeline
def post(self, pipeline: Pipeline):
def post(self, current_user: Account, pipeline: Pipeline):
"""
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
@ -350,14 +344,12 @@ class PublishedRagPipelineRunApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@get_rag_pipeline
def post(self, pipeline: Pipeline):
def post(self, current_user: Account, pipeline: Pipeline):
"""
Run published workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = payload.response_mode == "streaming"
@ -383,14 +375,12 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
"""
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
@ -416,14 +406,12 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
@with_current_user
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
"""
Run rag pipeline datasource
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
rag_pipeline_service = RagPipelineService()
@ -454,14 +442,12 @@ class RagPipelineDraftNodeRunApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
@with_current_user
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
"""
Run draft workflow node
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
inputs = payload.inputs
@ -485,14 +471,12 @@ class RagPipelineTaskStopApi(Resource):
@login_required
@edit_permission_required
@account_initialization_required
@with_current_user
@get_rag_pipeline
def post(self, pipeline: Pipeline, task_id: str):
def post(self, current_user: Account, pipeline: Pipeline, task_id: str):
"""
Stop workflow task
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
return {"result": "success"}
@ -532,13 +516,12 @@ class PublishedRagPipelineApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@get_rag_pipeline
def post(self, pipeline: Pipeline):
def post(self, current_user: Account, pipeline: Pipeline):
"""
Publish workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.publish_workflow(
session=db.session, # type: ignore[reportArgumentType,arg-type]
@ -609,13 +592,12 @@ class PublishedAllRagPipelineApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@get_rag_pipeline
def get(self, pipeline: Pipeline):
def get(self, current_user: Account, pipeline: Pipeline):
"""
Get published workflows
"""
current_user, _ = current_account_with_tenant()
query = WorkflowListQuery.model_validate(request.args.to_dict())
page = query.page
@ -655,9 +637,9 @@ class RagPipelineDraftWorkflowRestoreApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@get_rag_pipeline
def post(self, pipeline: Pipeline, workflow_id: str):
current_user, _ = current_account_with_tenant()
def post(self, current_user: Account, pipeline: Pipeline, workflow_id: str):
rag_pipeline_service = RagPipelineService()
try:
@ -689,14 +671,12 @@ class RagPipelineByIdApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
@with_current_user
@get_rag_pipeline
def patch(self, pipeline: Pipeline, workflow_id: str):
def patch(self, current_user: Account, pipeline: Pipeline, workflow_id: str):
"""
Update workflow attributes
"""
# Check permission
current_user, _ = current_account_with_tenant()
payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
@ -925,8 +905,8 @@ class DatasourceListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
@ -961,9 +941,8 @@ class RagPipelineTransformApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
@ -984,13 +963,13 @@ class RagPipelineDatasourceVariableApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@get_rag_pipeline
@edit_permission_required
def post(self, pipeline: Pipeline):
def post(self, current_user: Account, pipeline: Pipeline):
"""
Set datasource variables
"""
current_user, _ = current_account_with_tenant()
args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
rag_pipeline_service = RagPipelineService()

View File

@ -30,6 +30,7 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -45,6 +46,7 @@ from graphon.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import Account
from models.snippet import CustomizedSnippet
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.snippet_generate_service import SnippetGenerateService
@ -158,12 +160,11 @@ class SnippetDraftWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
def post(self, current_user: Account, snippet: CustomizedSnippet):
"""Sync draft workflow for snippet."""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
try:
@ -239,11 +240,11 @@ class SnippetPublishedWorkflowApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
def post(self, current_user: Account, snippet: CustomizedSnippet):
"""Publish snippet workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = _snippet_service()
with Session(db.engine) as session:
@ -331,11 +332,11 @@ class SnippetDraftWorkflowRestoreApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, workflow_id: str):
def post(self, current_user: Account, snippet: CustomizedSnippet, workflow_id: str):
"""Restore a published snippet workflow version into the draft workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = _snippet_service()
try:
@ -455,16 +456,16 @@ class SnippetDraftNodeRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str):
"""
Run a single node in snippet draft workflow.
Executes a specific node with provided inputs for single-step debugging.
Returns the node execution result including status, outputs, and timing.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
user_inputs = payload.inputs
@ -539,16 +540,16 @@ class SnippetDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow iteration node for snippet.
Iteration nodes execute their internal sub-graph multiple times over an input list.
Returns an SSE event stream with iteration progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
@ -580,16 +581,16 @@ class SnippetDraftRunLoopNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow loop node for snippet.
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
Returns an SSE event stream with loop progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
@ -619,17 +620,16 @@ class SnippetDraftWorkflowRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
def post(self, current_user: Account, snippet: CustomizedSnippet):
"""
Run draft workflow for snippet.
Executes the snippet's draft workflow with the provided inputs
and returns an SSE event stream with execution progress and results.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)

View File

@ -13,13 +13,19 @@ from controllers.common.fields import SuccessResponse
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.plugin_service import PluginService
from fields.base import ResponseModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from libs.login import login_required
from models.account import Account, TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
@ -200,9 +206,8 @@ class PluginDebuggingKeyApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
try:
return {
"key": PluginService.get_debugging_key(tenant_id),
@ -219,8 +224,8 @@ class PluginListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
args = ParserList.model_validate(request.args.to_dict(flat=True))
try:
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
@ -253,9 +258,8 @@ class PluginListInstallationsFromIdsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
args = ParserLatest.model_validate(console_ns.payload)
try:
@ -288,10 +292,10 @@ class PluginAssetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
@with_current_tenant_id
def get(self, tenant_id: str):
args = ParserAsset.model_validate(request.args.to_dict(flat=True))
_, tenant_id = current_account_with_tenant()
try:
binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
@ -305,9 +309,8 @@ class PluginUploadFromPkgApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
file = request.files["pkg"]
content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
try:
@ -325,9 +328,8 @@ class PluginUploadFromGithubApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
args = ParserGithubUpload.model_validate(console_ns.payload)
try:
@ -344,9 +346,8 @@ class PluginUploadFromBundleApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
file = request.files["bundle"]
content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE)
try:
@ -364,8 +365,8 @@ class PluginInstallFromPkgApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
try:
@ -383,9 +384,8 @@ class PluginInstallFromGithubApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
args = ParserGithubInstall.model_validate(console_ns.payload)
try:
@ -409,9 +409,8 @@ class PluginInstallFromMarketplaceApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
try:
@ -429,8 +428,8 @@ class PluginFetchMarketplacePkgApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True))
try:
@ -453,9 +452,8 @@ class PluginFetchManifestApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True))
try:
@ -473,9 +471,8 @@ class PluginFetchInstallTasksApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
args = ParserTasks.model_validate(request.args.to_dict(flat=True))
try:
@ -490,9 +487,8 @@ class PluginFetchInstallTaskApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self, task_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, task_id: str):
try:
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
except PluginDaemonClientSideError as e:
@ -506,9 +502,8 @@ class PluginDeleteInstallTaskApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str, task_id: str):
try:
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
except PluginDaemonClientSideError as e:
@ -522,9 +517,8 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
except PluginDaemonClientSideError as e:
@ -538,9 +532,8 @@ class PluginDeleteInstallTaskItemApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str, task_id: str, identifier: str):
try:
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
except PluginDaemonClientSideError as e:
@ -554,9 +547,8 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
args = ParserMarketplaceUpgrade.model_validate(console_ns.payload)
try:
@ -576,9 +568,8 @@ class PluginUpgradeFromGithubApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str):
args = ParserGithubUpgrade.model_validate(console_ns.payload)
try:
@ -604,11 +595,10 @@ class PluginUninstallApi(Resource):
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
@with_current_tenant_id
def post(self, tenant_id: str):
args = ParserUninstall.model_validate(console_ns.payload)
_, tenant_id = current_account_with_tenant()
try:
return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
except PluginDaemonClientSideError as e:
@ -622,16 +612,14 @@ class PluginChangePermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
if not user.is_admin_or_owner:
raise Forbidden()
args = ParserPermissionChange.model_validate(console_ns.payload)
tenant_id = current_tenant_id
return {
"success": PluginPermissionService.change_permission(
tenant_id, args.install_permission, args.debug_permission
@ -644,9 +632,8 @@ class PluginFetchPermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
return jsonable_encoder(
@ -671,16 +658,15 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self):
current_user, tenant_id = current_account_with_tenant()
user_id = current_user.id
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, current_user: Account):
args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True))
try:
options = PluginParameterService.get_dynamic_select_options(
tenant_id=tenant_id,
user_id=user_id,
user_id=current_user.id,
plugin_id=args.plugin_id,
provider=args.provider,
action=args.action,
@ -701,17 +687,16 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, current_user: Account):
"""Fetch dynamic options using credentials directly (for edit mode)."""
current_user, tenant_id = current_account_with_tenant()
user_id = current_user.id
args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
try:
options = PluginParameterService.get_dynamic_select_options_with_credentials(
tenant_id=tenant_id,
user_id=user_id,
user_id=current_user.id,
plugin_id=args.plugin_id,
provider=args.provider,
action=args.action,
@ -731,8 +716,9 @@ class PluginChangePreferencesApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
if not user.is_admin_or_owner:
raise Forbidden()
@ -780,9 +766,8 @@ class PluginFetchPreferencesApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
@ -820,10 +805,9 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_tenant_id
def post(self, tenant_id: str):
# exclude one single plugin
_, tenant_id = current_account_with_tenant()
args = ParserExcludePlugin.model_validate(console_ns.payload)
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
@ -835,8 +819,8 @@ class PluginReadmeApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
args = ParserReadme.model_validate(request.args.to_dict(flat=True))
return jsonable_encoder(
{"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}

View File

@ -21,10 +21,13 @@ from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.snippet import SnippetType
from services.app_dsl_service import ImportStatus
from services.snippet_dsl_service import SnippetDslService
@ -91,10 +94,9 @@ class CustomizedSnippetsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
@with_current_tenant_id
def get(self, current_tenant_id: str):
"""List customized snippets with pagination and search."""
_, current_tenant_id = current_account_with_tenant()
query = SnippetListQuery.model_validate(_normalize_snippet_list_query_args(request.args))
snippet_service = _snippet_service()
@ -124,10 +126,10 @@ class CustomizedSnippetsApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
"""Create a new customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
try:
@ -163,10 +165,9 @@ class CustomizedSnippetDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, snippet_id: str):
@with_current_tenant_id
def get(self, current_tenant_id: str, snippet_id: str):
"""Get customized snippet details."""
_, current_tenant_id = current_account_with_tenant()
snippet_service = _snippet_service()
snippet = snippet_service.get_snippet_by_id(
snippet_id=str(snippet_id),
@ -187,10 +188,10 @@ class CustomizedSnippetDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, snippet_id: str):
@with_current_user
@with_current_tenant_id
def patch(self, current_tenant_id: str, current_user: Account, snippet_id: str):
"""Update customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
snippet_service = _snippet_service()
snippet = snippet_service.get_snippet_by_id(
snippet_id=str(snippet_id),
@ -231,10 +232,9 @@ class CustomizedSnippetDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, snippet_id: str):
@with_current_tenant_id
def delete(self, current_tenant_id: str, snippet_id: str):
"""Delete customized snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet_service = _snippet_service()
snippet = snippet_service.get_snippet_by_id(
snippet_id=str(snippet_id),
@ -266,10 +266,9 @@ class CustomizedSnippetExportApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
@with_current_tenant_id
def get(self, current_tenant_id: str, snippet_id: str):
"""Export snippet as DSL."""
_, current_tenant_id = current_account_with_tenant()
snippet_service = _snippet_service()
snippet = snippet_service.get_snippet_by_id(
snippet_id=str(snippet_id),
@ -312,9 +311,9 @@ class CustomizedSnippetImportApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
@with_current_user
def post(self, current_user: Account):
"""Import snippet from DSL."""
current_user, _ = current_account_with_tenant()
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
with Session(db.engine) as session:
@ -350,10 +349,9 @@ class CustomizedSnippetImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
@with_current_user
def post(self, current_user: Account, import_id: str):
"""Confirm a pending snippet import."""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
import_service = SnippetDslService(session)
result = import_service.confirm_import(import_id=import_id, account=current_user)
@ -375,10 +373,9 @@ class CustomizedSnippetCheckDependenciesApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, snippet_id: str):
@with_current_tenant_id
def get(self, current_tenant_id: str, snippet_id: str):
"""Check dependencies for a snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet_service = _snippet_service()
snippet = snippet_service.get_snippet_by_id(
snippet_id=str(snippet_id),
@ -406,10 +403,9 @@ class CustomizedSnippetUseCountIncrementApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, snippet_id: str):
@with_current_tenant_id
def post(self, current_tenant_id: str, snippet_id: str):
"""Increment snippet use count when it is inserted into a workflow."""
_, current_tenant_id = current_account_with_tenant()
snippet_service = _snippet_service()
snippet = snippet_service.get_snippet_by_id(
snippet_id=str(snippet_id),

View File

@ -18,6 +18,8 @@ from controllers.console.wraps import (
enterprise_license_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from core.db.session_factory import session_factory
from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration
@ -30,7 +32,8 @@ from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToo
from extensions.ext_database import db
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import alphanumeric, uuid_value
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.provider_ids import ToolProviderID
# from models.provider_ids import ToolProviderID
@ -286,15 +289,13 @@ class ToolProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account):
raw_args = request.args.to_dict()
query = ToolProviderListQuery.model_validate(raw_args)
return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore
return ToolCommonService.list_tool_providers(user.id, tenant_id, query.type) # type: ignore
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
@ -302,9 +303,8 @@ class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
return jsonable_encoder(
BuiltinToolManageService.list_builtin_tool_provider_tools(
tenant_id,
@ -318,9 +318,8 @@ class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, provider: str):
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -331,9 +330,8 @@ class ToolBuiltinProviderDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.delete_builtin_tool_provider(
@ -349,15 +347,13 @@ class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account, provider: str):
payload = BuiltinToolAddPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
user_id=user.id,
tenant_id=tenant_id,
provider=provider,
credentials=payload.credentials,
@ -374,14 +370,13 @@ class ToolBuiltinProviderUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account, provider: str):
payload = BuiltinToolUpdatePayload.model_validate(console_ns.payload or {})
result = BuiltinToolManageService.update_builtin_tool_provider(
user_id=user_id,
user_id=user.id,
tenant_id=tenant_id,
provider=provider,
credential_id=payload.credential_id,
@ -396,8 +391,9 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, provider: str):
# Optional list of credential IDs to include even if visibility would hide them
# (used when a workflow/agent node still references another member's only_me credential).
include_credential_ids = request.args.getlist("include_credential_ids") or [
@ -430,15 +426,13 @@ class ToolApiProviderAddApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
payload = ApiToolProviderAddPayload.model_validate(console_ns.payload or {})
return ApiToolManageService.create_api_tool_provider(
user_id,
user.id,
tenant_id,
payload.provider,
payload.icon,
@ -456,16 +450,14 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account):
raw_args = request.args.to_dict()
query = UrlQuery.model_validate(raw_args)
return ApiToolManageService.get_api_tool_provider_remote_schema(
user_id,
user.id,
tenant_id,
str(query.url),
)
@ -476,17 +468,15 @@ class ToolApiProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account):
raw_args = request.args.to_dict()
query = ProviderQuery.model_validate(raw_args)
return jsonable_encoder(
ApiToolManageService.list_api_tool_provider_tools(
user_id,
user.id,
tenant_id,
query.provider,
)
@ -500,15 +490,13 @@ class ToolApiProviderUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
payload = ApiToolProviderUpdatePayload.model_validate(console_ns.payload or {})
return ApiToolManageService.update_api_tool_provider(
user_id,
user.id,
tenant_id,
payload.provider,
payload.original_provider,
@ -529,15 +517,13 @@ class ToolApiProviderDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
payload = ApiToolProviderDeletePayload.model_validate(console_ns.payload or {})
return ApiToolManageService.delete_api_tool_provider(
user_id,
user.id,
tenant_id,
payload.provider,
)
@ -548,16 +534,14 @@ class ToolApiProviderGetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account):
raw_args = request.args.to_dict()
query = ProviderQuery.model_validate(raw_args)
return ApiToolManageService.get_api_tool_provider(
user_id,
user.id,
tenant_id,
query.provider,
)
@ -568,9 +552,8 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, credential_type):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, provider, credential_type):
return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema(
provider, CredentialType.of(credential_type), tenant_id
@ -598,9 +581,9 @@ class ToolApiProviderPreviousTestApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = ApiToolTestPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
return ApiToolManageService.test_api_tool_preview(
current_tenant_id,
payload.provider_name or "",
@ -619,15 +602,13 @@ class ToolWorkflowProviderCreateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
payload = WorkflowToolCreatePayload.model_validate(console_ns.payload or {})
return WorkflowToolManageService.create_workflow_tool(
user_id=user_id,
user_id=user.id,
tenant_id=tenant_id,
workflow_app_id=payload.workflow_app_id,
name=payload.name,
@ -647,14 +628,13 @@ class ToolWorkflowProviderUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
payload = WorkflowToolUpdatePayload.model_validate(console_ns.payload or {})
return WorkflowToolManageService.update_workflow_tool(
user_id,
user.id,
tenant_id,
payload.workflow_tool_id,
payload.name,
@ -674,15 +654,13 @@ class ToolWorkflowProviderDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
payload = WorkflowToolDeletePayload.model_validate(console_ns.payload or {})
return WorkflowToolManageService.delete_workflow_tool(
user_id,
user.id,
tenant_id,
payload.workflow_tool_id,
)
@ -693,23 +671,21 @@ class ToolWorkflowProviderGetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account):
raw_args = request.args.to_dict()
query = WorkflowToolGetQuery.model_validate(raw_args)
if query.workflow_tool_id:
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
user_id,
user.id,
tenant_id,
query.workflow_tool_id,
)
elif query.workflow_app_id:
tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
user_id,
user.id,
tenant_id,
query.workflow_app_id,
)
@ -724,17 +700,15 @@ class ToolWorkflowProviderListToolApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account):
raw_args = request.args.to_dict()
query = WorkflowToolListQuery.model_validate(raw_args)
return jsonable_encoder(
WorkflowToolManageService.list_single_workflow_tools(
user_id,
user.id,
tenant_id,
query.workflow_tool_id,
)
@ -746,16 +720,14 @@ class ToolBuiltinListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account):
return jsonable_encoder(
[
provider.to_dict()
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
user.id,
tenant_id,
)
]
@ -767,9 +739,8 @@ class ToolApiListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
return jsonable_encoder(
[
provider.to_dict()
@ -785,16 +756,14 @@ class ToolWorkflowListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
user_id = user.id
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account):
return jsonable_encoder(
[
provider.to_dict()
for provider in WorkflowToolManageService.list_tenant_workflow_tools(
user_id,
user.id,
tenant_id,
)
]
@ -817,13 +786,13 @@ class ToolPluginOAuthApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def get(self, provider: str):
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, provider: str):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
user, tenant_id = current_account_with_tenant()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
@ -912,8 +881,8 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider: str):
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.set_default_provider(
tenant_id=current_tenant_id, provider=provider, id=payload.id
@ -927,11 +896,10 @@ class ToolOAuthCustomClient(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self, provider: str):
@with_current_tenant_id
def post(self, tenant_id: str, provider: str):
payload = ToolOAuthCustomClientPayload.model_validate(console_ns.payload or {})
_, tenant_id = current_account_with_tenant()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=tenant_id,
provider=provider,
@ -944,8 +912,8 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, provider: str):
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@ -953,8 +921,8 @@ class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider: str):
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
)
@ -965,8 +933,8 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, provider: str):
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_tenant_id, provider_name=provider
@ -979,8 +947,9 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str):
user, tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, provider: str):
include_credential_ids = request.args.getlist("include_credential_ids") or [
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
]
@ -1001,9 +970,10 @@ class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, tenant_id: str, user: Account):
payload = MCPProviderCreatePayload.model_validate(console_ns.payload or {})
user, tenant_id = current_account_with_tenant()
# Parse and validate models
configuration = MCPConfiguration.model_validate(payload.configuration or {})
@ -1054,11 +1024,11 @@ class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@account_initialization_required
def put(self):
@with_current_tenant_id
def put(self, current_tenant_id: str):
payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {})
configuration = MCPConfiguration.model_validate(payload.configuration or {})
authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None
_, current_tenant_id = current_account_with_tenant()
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_data = None
@ -1107,9 +1077,9 @@ class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self):
@with_current_tenant_id
def delete(self, current_tenant_id: str):
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
@ -1124,10 +1094,10 @@ class ToolMCPAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_tenant_id
def post(self, tenant_id: str):
payload = MCPAuthPayload.model_validate(console_ns.payload or {})
provider_id = payload.provider_id
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
@ -1197,8 +1167,8 @@ class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, provider_id: str):
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
@ -1210,9 +1180,8 @@ class ToolMCPListAllApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str):
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
# Skip sensitive data decryption for list view to improve performance
@ -1226,8 +1195,8 @@ class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, provider_id: str):
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import json
from datetime import datetime
from inspect import unwrap
from types import SimpleNamespace
from typing import TypedDict, Unpack
from unittest.mock import MagicMock, patch
@ -37,7 +38,8 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.account import Account, TenantAccountRole
from models.dataset import Pipeline
from models.workflow import Workflow
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
@ -48,6 +50,14 @@ DEFAULT_WORKFLOW_CREATED_BY = "00000000-0000-0000-0000-000000000003"
type WorkflowVariablePayload = dict[str, object]
def empty_mapping() -> dict[str, object]:
return {}
def empty_list() -> list[object]:
return []
class WorkflowFactoryPayload(TypedDict):
id: str
tenant_id: str
@ -86,14 +96,8 @@ class WorkflowFactoryOverrides(TypedDict, total=False):
rag_pipeline_variables: list[WorkflowVariablePayload]
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def make_node_execution(**overrides):
payload = {
def make_node_execution(**overrides: object) -> SimpleNamespace:
payload: dict[str, object] = {
"id": "node-exec-1",
"index": 1,
"predecessor_node_id": None,
@ -148,6 +152,27 @@ def make_workflow(**overrides: Unpack[WorkflowFactoryOverrides]) -> Workflow:
return Workflow(**payload)
def make_account(*, id: str = "account-1", role: TenantAccountRole = TenantAccountRole.EDITOR) -> Account:
account = Account(name="Alice", email=f"{id}@example.com")
account.id = id
account.role = role
return account
def make_pipeline(
*,
id: str = "pipeline-1",
tenant_id: str = "tenant-1",
workflow_id: str | None = None,
is_published: bool = False,
) -> Pipeline:
pipeline = Pipeline(tenant_id=tenant_id, name="test-pipeline", description="test")
pipeline.id = id
pipeline.workflow_id = workflow_id
pipeline.is_published = is_published
return pipeline
@pytest.fixture
def workflow_author(db_session_with_containers: Session) -> Account:
account = Account(name="Alice", email=f"alice-{uuid4()}@example.com")
@ -158,14 +183,14 @@ def workflow_author(db_session_with_containers: Session) -> Account:
class TestDraftWorkflowApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_draft_success(self, app: Flask, workflow_author: Account):
def test_get_draft_success(self, app: Flask, workflow_author: Account) -> None:
api = DraftRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
pipeline = make_pipeline()
workflow = make_workflow(created_by=workflow_author.id)
service = MagicMock()
@ -191,11 +216,11 @@ class TestDraftWorkflowApi:
}
assert result["updated_by"] is None
def test_get_draft_not_exist(self, app: Flask):
def test_get_draft_not_exist(self, app: Flask) -> None:
api = DraftRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
pipeline = make_pipeline()
service = MagicMock()
service.get_draft_workflow.return_value = None
@ -209,54 +234,46 @@ class TestDraftWorkflowApi:
with pytest.raises(DraftWorkflowNotExist):
method(api, pipeline)
def test_sync_hash_not_match(self, app: Flask):
def test_sync_hash_not_match(self, app: Flask) -> None:
api = DraftRagPipelineApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
service = MagicMock()
service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError()
with (
app.test_request_context("/", json={"graph": {}, "features": {}}),
patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
app.test_request_context("/", json={"graph": empty_mapping(), "features": empty_mapping()}),
patch.object(type(console_ns), "payload", {"graph": empty_mapping(), "features": empty_mapping()}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(DraftWorkflowNotSync):
method(api, pipeline)
method(api, user, pipeline)
def test_sync_invalid_text_plain(self, app: Flask):
def test_sync_invalid_text_plain(self, app: Flask) -> None:
api = DraftRagPipelineApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
with (
app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
response, status = method(api, pipeline)
response, status = method(api, user, pipeline)
assert status == 400
def test_restore_published_workflow_to_draft_success(self, app: Flask):
def test_restore_published_workflow_to_draft_success(self, app: Flask) -> None:
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="account-1")
pipeline = make_pipeline()
user = make_account(id="account-1")
workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1))
service = MagicMock()
@ -264,50 +281,42 @@ class TestDraftWorkflowApi:
with (
app.test_request_context("/", method="POST"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "published-workflow")
result = method(api, user, pipeline, "published-workflow")
assert result["result"] == "success"
assert result["hash"] == "restored-hash"
def test_restore_published_workflow_to_draft_not_found(self, app: Flask):
def test_restore_published_workflow_to_draft_not_found(self, app: Flask) -> None:
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="account-1")
pipeline = make_pipeline()
user = make_account(id="account-1")
service = MagicMock()
service.restore_published_workflow_to_draft.side_effect = WorkflowNotFoundError("Workflow not found")
with (
app.test_request_context("/", method="POST"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(NotFound):
method(api, pipeline, "published-workflow")
method(api, user, pipeline, "published-workflow")
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask):
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask) -> None:
api = RagPipelineDraftWorkflowRestoreApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="account-1")
pipeline = make_pipeline()
user = make_account(id="account-1")
service = MagicMock()
service.restore_published_workflow_to_draft.side_effect = IsDraftWorkflowError(
@ -316,17 +325,13 @@ class TestDraftWorkflowApi:
with (
app.test_request_context("/", method="POST"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(HTTPException) as exc:
method(api, pipeline, "draft-workflow")
method(api, user, pipeline, "draft-workflow")
assert exc.value.code == 400
assert exc.value.description == "source workflow must be published"
@ -334,23 +339,19 @@ class TestDraftWorkflowApi:
class TestDraftRunNodes:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_iteration_node_success(self, app: Flask):
def test_iteration_node_success(self, app: Flask) -> None:
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
app.test_request_context("/", json={"inputs": empty_mapping()}),
patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
return_value=MagicMock(),
@ -360,45 +361,37 @@ class TestDraftRunNodes:
return_value={"ok": True},
),
):
result = method(api, pipeline, "node")
result = method(api, user, pipeline, "node")
assert result == {"ok": True}
def test_iteration_node_conversation_not_exists(self, app: Flask):
def test_iteration_node_conversation_not_exists(self, app: Flask) -> None:
api = RagPipelineDraftRunIterationNodeApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
app.test_request_context("/", json={"inputs": empty_mapping()}),
patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
side_effect=services.errors.conversation.ConversationNotExistsError(),
),
):
with pytest.raises(NotFound):
method(api, pipeline, "node")
method(api, user, pipeline, "node")
def test_loop_node_success(self, app: Flask):
def test_loop_node_success(self, app: Flask) -> None:
api = RagPipelineDraftRunLoopNodeApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
app.test_request_context("/", json={"inputs": empty_mapping()}),
patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop",
return_value=MagicMock(),
@ -408,35 +401,31 @@ class TestDraftRunNodes:
return_value={"ok": True},
),
):
assert method(api, pipeline, "node") == {"ok": True}
assert method(api, user, pipeline, "node") == {"ok": True}
class TestPipelineRunApis:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_draft_run_success(self, app: Flask):
def test_draft_run_success(self, app: Flask) -> None:
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
payload = {
"inputs": {},
"inputs": empty_mapping(),
"datasource_type": "x",
"datasource_info_list": [],
"datasource_info_list": empty_list(),
"start_node_id": "n",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
return_value=MagicMock(),
@ -446,27 +435,27 @@ class TestPipelineRunApis:
return_value={"ok": True},
),
):
assert method(api, pipeline) == {"ok": True}
assert method(api, user, pipeline) == {"ok": True}
def test_draft_run_rate_limit(self, app: Flask):
def test_draft_run_rate_limit(self, app: Flask) -> None:
api = DraftRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
payload: dict[str, object] = {
"inputs": empty_mapping(),
"datasource_type": "x",
"datasource_info_list": empty_list(),
"start_node_id": "n",
}
with (
app.test_request_context(
"/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}
),
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
{"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"},
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
payload,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
@ -474,48 +463,42 @@ class TestPipelineRunApis:
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(api, pipeline)
method(api, user, pipeline)
class TestDraftNodeRun:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_execution_not_found(self, app: Flask):
def test_execution_not_found(self, app: Flask) -> None:
api = RagPipelineDraftNodeRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
service = MagicMock()
service.run_draft_workflow_node.return_value = None
with (
app.test_request_context("/", json={"inputs": {}}),
patch.object(type(console_ns), "payload", {"inputs": {}}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
app.test_request_context("/", json={"inputs": empty_mapping()}),
patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(ValueError):
method(api, pipeline, "node")
method(api, user, pipeline, "node")
class TestPublishedPipelineApis:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_publish_success(self, app: Flask, db_session_with_containers: Session):
from models.dataset import Pipeline
def test_publish_success(self, app: Flask, db_session_with_containers: Session) -> None:
api = PublishedRagPipelineApi()
method = unwrap(api.post)
@ -530,7 +513,7 @@ class TestPublishedPipelineApis:
db_session_with_containers.commit()
db_session_with_containers.expire_all()
user = MagicMock(id="u1")
user = make_account(id="u1")
workflow = MagicMock(
id=str(uuid4()),
@ -542,16 +525,12 @@ class TestPublishedPipelineApis:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
result = method(api, user, pipeline)
assert result["result"] == "success"
assert "created_at" in result
@ -559,47 +538,39 @@ class TestPublishedPipelineApis:
class TestMiscApis:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_task_stop(self, app: Flask):
def test_task_stop(self, app: Flask) -> None:
api = RagPipelineTaskStopApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="u1")
pipeline = make_pipeline()
user = make_account(id="u1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag"
) as stop_mock,
):
result = method(api, pipeline, "task-1")
result = method(api, user, pipeline, "task-1")
stop_mock.assert_called_once()
assert result["result"] == "success"
def test_transform_forbidden(self, app: Flask):
def test_transform_forbidden(self, app: Flask) -> None:
api = RagPipelineTransformApi()
method = unwrap(api.post)
user = MagicMock(has_edit_permission=False, is_dataset_operator=False)
user = make_account(role=TenantAccountRole.NORMAL)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
with pytest.raises(Forbidden):
method(api, "ds1")
method(api, user, "ds1")
def test_recommended_plugins(self, app: Flask):
def test_recommended_plugins(self, app: Flask) -> None:
api = RagPipelineRecommendedPluginApi()
method = unwrap(api.get)
@ -619,20 +590,20 @@ class TestMiscApis:
class TestPublishedRagPipelineRunApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_published_run_success(self, app: Flask):
def test_published_run_success(self, app: Flask) -> None:
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
payload = {
"inputs": {},
"inputs": empty_mapping(),
"datasource_type": "x",
"datasource_info_list": [],
"datasource_info_list": empty_list(),
"start_node_id": "n",
"response_mode": "blocking",
}
@ -640,10 +611,6 @@ class TestPublishedRagPipelineRunApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
return_value=MagicMock(),
@ -653,49 +620,45 @@ class TestPublishedRagPipelineRunApi:
return_value={"ok": True},
),
):
result = method(api, pipeline)
result = method(api, user, pipeline)
assert result == {"ok": True}
def test_published_run_rate_limit(self, app: Flask):
def test_published_run_rate_limit(self, app: Flask) -> None:
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
payload = {
"inputs": {},
"inputs": empty_mapping(),
"datasource_type": "x",
"datasource_info_list": [],
"datasource_info_list": empty_list(),
"start_node_id": "n",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
side_effect=InvokeRateLimitError("limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(api, pipeline)
method(api, user, pipeline)
class TestDefaultBlockConfigApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_block_config_success(self, app: Flask):
def test_get_block_config_success(self, app: Flask) -> None:
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
pipeline = MagicMock()
pipeline = make_pipeline()
service = MagicMock()
service.get_default_block_config.return_value = {"k": "v"}
@ -710,11 +673,11 @@ class TestDefaultBlockConfigApi:
result = method(api, pipeline, "llm")
assert result == {"k": "v"}
def test_get_block_config_invalid_json(self, app: Flask):
def test_get_block_config_invalid_json(self, app: Flask) -> None:
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
pipeline = MagicMock()
pipeline = make_pipeline()
with app.test_request_context("/?q=bad-json"):
with pytest.raises(ValueError):
@ -723,65 +686,57 @@ class TestDefaultBlockConfigApi:
class TestPublishedAllRagPipelineApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_published_workflows_success(self, app: Flask):
def test_get_published_workflows_success(self, app: Flask) -> None:
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
user = MagicMock(id="u1")
pipeline = make_pipeline()
user = make_account(id="u1")
service = MagicMock()
service.get_all_published_workflow.return_value = ([make_workflow(id="w1")], False)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
result = method(api, user, pipeline)
assert result["items"][0]["id"] == "w1"
assert result["items"][0]["graph"] == {"nodes": [], "edges": []}
assert result["has_more"] is False
def test_get_published_workflows_forbidden(self, app: Flask):
def test_get_published_workflows_forbidden(self, app: Flask) -> None:
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
user = MagicMock(id="u1")
pipeline = make_pipeline()
user = make_account(id="u1")
with (
app.test_request_context("/?user_id=u2"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
with pytest.raises(Forbidden):
method(api, pipeline)
method(api, user, pipeline)
class TestRagPipelineByIdApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_patch_success(self, app: Flask):
def test_patch_success(self, app: Flask) -> None:
api = RagPipelineByIdApi()
method = unwrap(api.patch)
pipeline = MagicMock(tenant_id="t1")
user = MagicMock(id="u1")
pipeline = make_pipeline(tenant_id="t1")
user = make_account(id="u1")
workflow = make_workflow(id="w1", marked_name="test")
@ -793,44 +748,36 @@ class TestRagPipelineByIdApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "w1")
result = method(api, user, pipeline, "w1")
assert result["id"] == "w1"
assert result["marked_name"] == "test"
assert result["hash"] == workflow.unique_hash
def test_patch_no_fields(self, app: Flask):
def test_patch_no_fields(self, app: Flask) -> None:
api = RagPipelineByIdApi()
method = unwrap(api.patch)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
with (
app.test_request_context("/", json={}),
patch.object(type(console_ns), "payload", {}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch.object(type(console_ns), "payload", empty_mapping()),
):
result, status = method(api, pipeline, "w1")
result, status = method(api, user, pipeline, "w1")
assert status == 400
def test_delete_success(self, app: Flask):
def test_delete_success(self, app: Flask) -> None:
api = RagPipelineByIdApi()
method = unwrap(api.delete)
pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1")
pipeline = make_pipeline(tenant_id="t1", workflow_id="active-workflow")
workflow_service = MagicMock()
@ -846,11 +793,11 @@ class TestRagPipelineByIdApi:
workflow_service.delete_workflow.assert_called_once()
assert result == (None, 204)
def test_delete_active_workflow_rejected(self, app: Flask):
def test_delete_active_workflow_rejected(self, app: Flask) -> None:
api = RagPipelineByIdApi()
method = unwrap(api.delete)
pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1")
pipeline = make_pipeline(tenant_id="t1", workflow_id="active-workflow")
with app.test_request_context("/", method="DELETE"):
with pytest.raises(BadRequest, match="currently in use by pipeline"):
@ -859,14 +806,14 @@ class TestRagPipelineByIdApi:
class TestRagPipelineWorkflowLastRunApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_last_run_success(self, app: Flask):
def test_last_run_success(self, app: Flask) -> None:
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
pipeline = MagicMock()
pipeline = make_pipeline()
workflow = MagicMock()
node_exec = make_node_execution()
@ -886,11 +833,11 @@ class TestRagPipelineWorkflowLastRunApi:
assert result["inputs"] == {"query": "hello"}
assert result["outputs"] == {"answer": "world"}
def test_last_run_not_found(self, app: Flask):
def test_last_run_not_found(self, app: Flask) -> None:
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
pipeline = MagicMock()
pipeline = make_pipeline()
service = MagicMock()
service.get_draft_workflow.return_value = None
@ -908,19 +855,19 @@ class TestRagPipelineWorkflowLastRunApi:
class TestRagPipelineDatasourceVariableApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_set_datasource_variables_success(self, app: Flask):
def test_set_datasource_variables_success(self, app: Flask) -> None:
api = RagPipelineDatasourceVariableApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
pipeline = make_pipeline()
user = make_account()
payload = {
"datasource_type": "db",
"datasource_info": {},
"datasource_info": empty_mapping(),
"start_node_id": "n1",
"start_node_title": "Node",
}
@ -931,15 +878,11 @@ class TestRagPipelineDatasourceVariableApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
result = method(api, user, pipeline)
assert result["node_id"] == "n1"
assert result["process_data"] == {}

View File

@ -3,12 +3,13 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from inspect import unwrap
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.tool_providers import (
@ -43,41 +44,56 @@ from controllers.console.workspace.tool_providers import (
ToolWorkflowProviderUpdateApi,
is_valid_url,
)
from models.account import Account, TenantAccountRole
from services.tools.mcp_tools_manage_service import ReconnectResult
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def empty_mapping() -> dict[str, object]:
return {}
def empty_list() -> list[object]:
return []
@pytest.fixture
def _mock_cache():
def _mock_cache() -> None:
return
@pytest.fixture
def _mock_user_tenant():
def _mock_user_tenant() -> None:
return
@pytest.fixture
def client(flask_app_with_containers: Flask):
def client(flask_app_with_containers: Flask) -> FlaskClient:
return flask_app_with_containers.test_client()
@patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "t1"),
autospec=True,
)
def make_account(*, id: str = "u", role: TenantAccountRole = TenantAccountRole.EDITOR) -> Account:
account = Account(name="Alice", email=f"{id}@example.com")
account.id = id
account.role = role
return account
@patch("controllers.console.workspace.tool_providers.sessionmaker", autospec=True)
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True)
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(
mock_reconnect, mock_session, mock_current_account_with_tenant, client: FlaskClient
):
mock_reconnect: MagicMock,
mock_session: MagicMock,
client: FlaskClient,
db_session_with_containers: Session,
) -> None:
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
headers = authenticate_console_client(client, account)
# Arrange: reconnect returns tools immediately
mock_reconnect.return_value = ReconnectResult(
authed=True,
@ -104,21 +120,11 @@ def test_create_mcp_provider_populates_tools(
"icon_background": "#000",
"server_identifier": "demo-sid",
"configuration": {"timeout": 5, "sse_read_timeout": 30},
"headers": {},
"authentication": {},
"headers": empty_mapping(),
"authentication": empty_mapping(),
}
# Act
with (
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "t1"),
autospec=True,
),
patch("libs.login.check_csrf_token", return_value=None, autospec=True), # bypass CSRF in login_required
patch(
"libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True), autospec=True
), # login
patch(
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
@ -128,6 +134,7 @@ def test_create_mcp_provider_populates_tools(
resp = client.post(
"/console/api/workspaces/current/tool-provider/mcp",
data=json.dumps(payload),
headers=headers,
content_type="application/json",
)
@ -141,7 +148,7 @@ def test_create_mcp_provider_populates_tools(
class TestUtils:
def test_is_valid_url(self):
def test_is_valid_url(self) -> None:
assert is_valid_url("https://example.com")
assert is_valid_url("http://example.com")
assert not is_valid_url("")
@ -152,154 +159,121 @@ class TestUtils:
class TestToolProviderListApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_success(self, app: Flask):
def test_get_success(self, app: Flask) -> None:
api = ToolProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "t1"),
),
patch(
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers",
return_value=["p1"],
),
):
assert method(api) == ["p1"]
assert method(api, "t1", make_account(id="u1")) == ["p1"]
class TestBuiltinProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_list_tools(self, app: Flask):
def test_list_tools(self, app: Flask) -> None:
api = ToolBuiltinProviderListToolsApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools",
return_value=[{"a": 1}],
),
):
assert method(api, "provider") == [{"a": 1}]
assert method(api, "t1", "provider") == [{"a": 1}]
def test_info(self, app: Flask):
def test_info(self, app: Flask) -> None:
api = ToolBuiltinProviderInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info",
return_value={"x": 1},
),
):
assert method(api, "provider") == {"x": 1}
assert method(api, "t1", "provider") == {"x": 1}
def test_delete(self, app: Flask):
def test_delete(self, app: Flask) -> None:
api = ToolBuiltinProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credential_id": "cid"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t1"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider",
return_value={"result": "success"},
),
):
assert method(api, "provider")["result"] == "success"
assert method(api, "t1", "provider")["result"] == "success"
def test_add_invalid_type(self, app: Flask):
def test_add_invalid_type(self, app: Flask) -> None:
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"credentials": {}, "type": "invalid"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
app.test_request_context("/", json={"credentials": empty_mapping(), "type": "invalid"}),
):
with pytest.raises(ValueError):
method(api, "provider")
method(api, "t", make_account(), "provider")
def test_add_success(self, app: Flask):
def test_add_success(self, app: Flask) -> None:
api = ToolBuiltinProviderAddApi()
method = unwrap(api.post)
payload = {"credentials": {}, "type": "oauth2", "name": "n"}
payload = {"credentials": empty_mapping(), "type": "oauth2", "name": "n"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider",
return_value={"id": 1},
),
):
assert method(api, "provider")["id"] == 1
assert method(api, "t", make_account(), "provider")["id"] == 1
def test_update(self, app: Flask):
def test_update(self, app: Flask) -> None:
api = ToolBuiltinProviderUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "c1", "credentials": {}, "name": "n"}
payload = {"credential_id": "c1", "credentials": empty_mapping(), "name": "n"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
assert method(api, "t", make_account(), "provider")["ok"]
def test_get_credentials(self, app: Flask):
def test_get_credentials(self, app: Flask) -> None:
api = ToolBuiltinProviderGetCredentialsApi()
method = unwrap(api.get)
mock_user = SimpleNamespace(id="user-1", is_admin_or_owner=False)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(mock_user, "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials",
return_value={"k": "v"},
),
):
assert method(api, "provider") == {"k": "v"}
assert method(api, "t", make_account(id="user-1"), "provider") == {"k": "v"}
def test_icon(self, app: Flask):
def test_icon(self, app: Flask) -> None:
api = ToolBuiltinProviderIconApi()
method = unwrap(api.get)
@ -313,208 +287,168 @@ class TestBuiltinProviderApis:
response = method(api, "provider")
assert response.mimetype == "image/png"
def test_credentials_schema(self, app: Flask):
def test_credentials_schema(self, app: Flask) -> None:
api = ToolBuiltinProviderCredentialsSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema",
return_value={"schema": {}},
),
):
assert method(api, "provider", "oauth2") == {"schema": {}}
assert method(api, "t", "provider", "oauth2") == {"schema": {}}
def test_set_default_credential(self, app: Flask):
def test_set_default_credential(self, app: Flask) -> None:
api = ToolBuiltinProviderSetDefaultApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"id": "c1"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
assert method(api, "t", "provider")["ok"]
def test_get_credential_info(self, app: Flask):
def test_get_credential_info(self, app: Flask) -> None:
api = ToolBuiltinProviderGetCredentialInfoApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info",
return_value={"info": "x"},
),
):
assert method(api, "provider") == {"info": "x"}
assert method(api, "t", make_account(), "provider") == {"info": "x"}
def test_get_oauth_client_schema(self, app: Flask):
def test_get_oauth_client_schema(self, app: Flask) -> None:
api = ToolBuiltinProviderGetOauthClientSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema",
return_value={"schema": {}},
),
):
assert method(api, "provider") == {"schema": {}}
assert method(api, "t", "provider") == {"schema": {}}
class TestApiProviderApis:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_add(self, app: Flask):
def test_add(self, app: Flask) -> None:
api = ToolApiProviderAddApi()
method = unwrap(api.post)
payload = {
"credentials": {},
"credentials": empty_mapping(),
"schema_type": "openapi",
"schema": "{}",
"provider": "p",
"icon": {},
"icon": empty_mapping(),
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider",
return_value={"id": 1},
),
):
assert method(api)["id"] == 1
assert method(api, "t", make_account())["id"] == 1
def test_remote_schema(self, app: Flask):
def test_remote_schema(self, app: Flask) -> None:
api = ToolApiProviderGetRemoteSchemaApi()
method = unwrap(api.get)
with (
app.test_request_context("/?url=http://x.com"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema",
return_value={"schema": "x"},
),
):
assert method(api)["schema"] == "x"
assert method(api, "t", make_account())["schema"] == "x"
def test_list_tools(self, app: Flask):
def test_list_tools(self, app: Flask) -> None:
api = ToolApiProviderListToolsApi()
method = unwrap(api.get)
with (
app.test_request_context("/?provider=p"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools",
return_value=[{"tool": 1}],
),
):
assert method(api) == [{"tool": 1}]
assert method(api, "t", make_account()) == [{"tool": 1}]
def test_update(self, app: Flask):
def test_update(self, app: Flask) -> None:
api = ToolApiProviderUpdateApi()
method = unwrap(api.post)
payload = {
"credentials": {},
"credentials": empty_mapping(),
"schema_type": "openapi",
"schema": "{}",
"provider": "p",
"original_provider": "o",
"icon": {},
"icon": empty_mapping(),
"privacy_policy": "",
"custom_disclaimer": "",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider",
return_value={"ok": True},
),
):
assert method(api)["ok"]
assert method(api, "t", make_account())["ok"]
def test_delete(self, app: Flask):
def test_delete(self, app: Flask) -> None:
api = ToolApiProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"provider": "p"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider",
return_value={"result": "success"},
),
):
assert method(api)["result"] == "success"
assert method(api, "t", make_account())["result"] == "success"
def test_get(self, app: Flask):
def test_get(self, app: Flask) -> None:
api = ToolApiProviderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/?provider=p"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider",
return_value={"x": 1},
),
):
assert method(api) == {"x": 1}
assert method(api, "t", make_account()) == {"x": 1}
class TestWorkflowApis:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_create(self, app: Flask):
def test_create(self, app: Flask) -> None:
api = ToolWorkflowProviderCreateApi()
method = unwrap(api.post)
@ -523,24 +457,20 @@ class TestWorkflowApis:
"name": "n",
"label": "l",
"description": "d",
"icon": {},
"parameters": [],
"icon": empty_mapping(),
"parameters": empty_list(),
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool",
return_value={"id": 1},
),
):
assert method(api)["id"] == 1
assert method(api, "t", make_account())["id"] == 1
def test_update_invalid(self, app: Flask):
def test_update_invalid(self, app: Flask) -> None:
api = ToolWorkflowProviderUpdateApi()
method = unwrap(api.post)
@ -549,61 +479,49 @@ class TestWorkflowApis:
"name": "Tool",
"label": "Tool Label",
"description": "A tool",
"icon": {},
"icon": empty_mapping(),
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool",
return_value={"ok": True},
),
):
result = method(api)
result = method(api, "t", make_account())
assert result["ok"]
def test_delete(self, app: Flask):
def test_delete(self, app: Flask) -> None:
api = ToolWorkflowProviderDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool",
return_value={"ok": True},
),
):
assert method(api)["ok"]
assert method(api, "t", make_account())["ok"]
def test_get_error(self, app: Flask):
def test_get_error(self, app: Flask) -> None:
api = ToolWorkflowProviderGetApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
):
with pytest.raises(ValueError):
method(api)
method(api, "t", make_account())
class TestLists:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_builtin_list(self, app: Flask):
def test_builtin_list(self, app: Flask) -> None:
api = ToolBuiltinListApi()
method = unwrap(api.get)
@ -612,18 +530,14 @@ class TestLists:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
assert method(api, "t", make_account()) == [{"x": 1}]
def test_api_list(self, app: Flask):
def test_api_list(self, app: Flask) -> None:
api = ToolApiListApi()
method = unwrap(api.get)
@ -632,18 +546,14 @@ class TestLists:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(None, "t"),
),
patch(
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
assert method(api, "t") == [{"x": 1}]
def test_workflow_list(self, app: Flask):
def test_workflow_list(self, app: Flask) -> None:
api = ToolWorkflowListApi()
method = unwrap(api.get)
@ -652,24 +562,20 @@ class TestLists:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools",
return_value=[m],
),
):
assert method(api) == [{"x": 1}]
assert method(api, "t", make_account()) == [{"x": 1}]
class TestLabels:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_labels(self, app: Flask):
def test_labels(self, app: Flask) -> None:
api = ToolLabelsApi()
method = unwrap(api.get)
@ -685,28 +591,24 @@ class TestLabels:
class TestOAuth:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_oauth_no_client(self, app: Flask):
def test_oauth_no_client(self, app: Flask) -> None:
api = ToolPluginOAuthApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u"), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client",
return_value=None,
),
):
with pytest.raises(Forbidden):
method(api, "provider")
method(api, "t", make_account(), "provider")
def test_oauth_callback_no_cookie(self, app: Flask):
def test_oauth_callback_no_cookie(self, app: Flask) -> None:
api = ToolOAuthCallback()
method = unwrap(api.get)
@ -717,56 +619,44 @@ class TestOAuth:
class TestOAuthCustomClient:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_save_custom_client(self, app: Flask):
def test_save_custom_client(self, app: Flask) -> None:
api = ToolOAuthCustomClient()
method = unwrap(api.post)
with (
app.test_request_context("/", json={"client_params": {"a": 1}}),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
assert method(api, "t", "provider")["ok"]
def test_get_custom_client(self, app: Flask):
def test_get_custom_client(self, app: Flask) -> None:
api = ToolOAuthCustomClient()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params",
return_value={"client_id": "x"},
),
):
assert method(api, "provider") == {"client_id": "x"}
assert method(api, "t", "provider") == {"client_id": "x"}
def test_delete_custom_client(self, app: Flask):
def test_delete_custom_client(self, app: Flask) -> None:
api = ToolOAuthCustomClient()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(), "t"),
),
patch(
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params",
return_value={"ok": True},
),
):
assert method(api, "provider")["ok"]
assert method(api, "t", "provider")["ok"]

View File

@ -1,8 +1,11 @@
from __future__ import annotations
from inspect import unwrap
from types import SimpleNamespace
from typing import cast
import pytest
from flask import Response
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
from clients.agent_backend.workspace_files_client import (
@ -12,16 +15,10 @@ from clients.agent_backend.workspace_files_client import (
WorkspacePreviewResult,
)
from controllers.console import agent_app_workspace as module
from models.model import App, AppMode, IconType
from services.agent_app_workspace_service import AgentWorkspaceInspectorError
def _unwrapped_get(resource_cls):
func = resource_cls.get
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class _AgentAppService:
def __init__(self) -> None:
self.calls: list[tuple[str, str, str, str, str]] = []
@ -87,6 +84,20 @@ class _WorkflowService:
return WorkspaceDownloadResult(path=path, size=3, truncated=False, content=b"abc")
def _app_model(app_id: str = "app-1") -> App:
return App(
id=app_id,
tenant_id="tenant-1",
name="App",
mode=AppMode.AGENT,
icon_type=IconType.EMOJI,
icon="bot",
icon_background="#fff",
enable_site=False,
enable_api=False,
)
def test_handle_maps_workspace_and_agent_backend_errors() -> None:
assert module._handle(AgentWorkspaceInspectorError("no_sandbox", "no sandbox", status_code=404)) == (
{"code": "no_sandbox", "message": "no sandbox"},
@ -108,8 +119,11 @@ def test_handle_maps_workspace_and_agent_backend_errors() -> None:
def test_download_response_returns_binary_or_too_large_error() -> None:
response = module._download_response(
WorkspaceDownloadResult(path="dir/report.txt", size=3, truncated=False, content=b"abc")
response = cast(
Response,
module._download_response(
WorkspaceDownloadResult(path="dir/report.txt", size=3, truncated=False, content=b"abc")
),
)
assert response.status_code == 200
@ -133,17 +147,16 @@ def test_download_response_returns_binary_or_too_large_error() -> None:
def test_agent_app_workspace_resources_proxy_service(monkeypatch: pytest.MonkeyPatch) -> None:
service = _AgentAppService()
monkeypatch.setattr(module, "AgentAppWorkspaceService", lambda: service)
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr(
module,
"query_params_from_request",
lambda model: SimpleNamespace(conversation_id="conv-1", path="sub/report.txt"),
)
app_model = SimpleNamespace(id="app-1")
app_model = _app_model()
listing = _unwrapped_get(module.AgentAppWorkspaceListResource)(object(), app_model)
preview = _unwrapped_get(module.AgentAppWorkspacePreviewResource)(object(), app_model)
download = _unwrapped_get(module.AgentAppWorkspaceDownloadResource)(object(), app_model)
listing = unwrap(module.AgentAppWorkspaceListResource.get)(object(), "tenant-1", app_model)
preview = unwrap(module.AgentAppWorkspacePreviewResource.get)(object(), "tenant-1", app_model)
download = unwrap(module.AgentAppWorkspaceDownloadResource.get)(object(), "tenant-1", app_model)
assert listing["entries"][0]["name"] == "a.txt"
assert preview["text"] == "hello"
@ -161,14 +174,13 @@ def test_agent_app_workspace_resource_returns_normalized_errors(monkeypatch: pyt
raise AgentWorkspaceInspectorError("no_active_session", "no active session", status_code=404)
monkeypatch.setattr(module, "AgentAppWorkspaceService", FailingService)
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr(
module,
"query_params_from_request",
lambda model: SimpleNamespace(conversation_id="conv-1", path="."),
)
assert _unwrapped_get(module.AgentAppWorkspaceListResource)(object(), SimpleNamespace(id="app-1")) == (
assert unwrap(module.AgentAppWorkspaceListResource.get)(object(), "tenant-1", _app_model()) == (
{"code": "no_active_session", "message": "no active session"},
404,
)
@ -177,17 +189,22 @@ def test_agent_app_workspace_resource_returns_normalized_errors(monkeypatch: pyt
def test_workflow_agent_workspace_resources_proxy_service(monkeypatch: pytest.MonkeyPatch) -> None:
service = _WorkflowService()
monkeypatch.setattr(module, "WorkflowAgentWorkspaceService", lambda: service)
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr(
module,
"query_params_from_request",
lambda model: SimpleNamespace(node_execution_id="exec-1", path="out.txt"),
)
app_model = SimpleNamespace(id="app-1")
app_model = _app_model()
listing = _unwrapped_get(module.WorkflowAgentWorkspaceListResource)(object(), app_model, "run-1", "agent-node")
preview = _unwrapped_get(module.WorkflowAgentWorkspacePreviewResource)(object(), app_model, "run-1", "agent-node")
download = _unwrapped_get(module.WorkflowAgentWorkspaceDownloadResource)(object(), app_model, "run-1", "agent-node")
listing = unwrap(module.WorkflowAgentWorkspaceListResource.get)(
object(), "tenant-1", app_model, "run-1", "agent-node"
)
preview = unwrap(module.WorkflowAgentWorkspacePreviewResource.get)(
object(), "tenant-1", app_model, "run-1", "agent-node"
)
download = unwrap(module.WorkflowAgentWorkspaceDownloadResource.get)(
object(), "tenant-1", app_model, "run-1", "agent-node"
)
assert listing["path"] == "out.txt"
assert preview["text"] == "hello"

View File

@ -542,7 +542,6 @@ def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeyp
app_id_2 = "22222222-2222-2222-2222-222222222222"
signed_avatar_url = "https://files.example.com/signed/avatar-1"
sign_avatar = Mock(return_value=signed_avatar_url)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
@ -596,7 +595,7 @@ def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeyp
method="POST",
json={"app_ids": [app_id_1, app_id_2]},
):
response = handler(api)
response = handler(api, "tenant-1")
assert response == {
"data": [
@ -625,7 +624,6 @@ def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeyp
def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
app_ids = [f"wf-{index}" for index in range(workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE + 1)]
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
@ -647,7 +645,7 @@ def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pyte
method="POST",
json={"app_ids": app_ids},
):
response = handler(api)
response = handler(api, "tenant-1")
assert len(response["data"]) == len(app_ids)
assert redis_pipeline_factory.call_count == 2
@ -656,7 +654,6 @@ def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pyte
def test_workflow_online_users_rejects_excessive_workflow_ids(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
accessible_app_ids = Mock(return_value=set())
monkeypatch.setattr(
workflow_module,
@ -675,7 +672,7 @@ def test_workflow_online_users_rejects_excessive_workflow_ids(app: Flask, monkey
json={"app_ids": excessive_ids},
):
with pytest.raises(HTTPException) as exc:
handler(api)
handler(api, "tenant-1")
assert exc.value.code == 400
assert exc.value.description is not None

View File

@ -39,7 +39,6 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app
monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
monkeypatch.delenv("INIT_PASSWORD", raising=False)

View File

@ -1,13 +1,15 @@
from __future__ import annotations
from datetime import datetime
from inspect import unwrap
from inspect import unwrap as unwrap_all
from types import SimpleNamespace
from unittest.mock import PropertyMock, patch
import pytest
from controllers.console.datasets.rag_pipeline import rag_pipeline_workflow as module
from models.account import Account, TenantAccountRole
from models.dataset import Pipeline
def _make_workflow(**overrides):
@ -33,6 +35,19 @@ def _make_workflow(**overrides):
return workflow
def _account() -> Account:
account = Account(name="Alice", email="alice@example.com")
account.id = "user-1"
account.role = TenantAccountRole.EDITOR
return account
def _pipeline() -> Pipeline:
pipeline = Pipeline(tenant_id="tenant-1", name="Pipeline", description="desc")
pipeline.id = "pipeline-1"
return pipeline
def test_draft_rag_pipeline_workflow_get_serializes_response_model(monkeypatch: pytest.MonkeyPatch) -> None:
workflow = _make_workflow()
monkeypatch.setattr(
@ -40,9 +55,9 @@ def test_draft_rag_pipeline_workflow_get_serializes_response_model(monkeypatch:
)
api = module.DraftRagPipelineApi()
handler = unwrap(api.get)
handler = unwrap_all(api.get)
response = handler(api, pipeline=SimpleNamespace(id="pipeline-1"))
response = handler(api, _pipeline())
assert response["id"] == "workflow-1"
assert response["graph"] == {"nodes": [], "edges": []}
@ -58,7 +73,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
app, monkeypatch: pytest.MonkeyPatch
) -> None:
api = module.PublishedAllRagPipelineApi()
handler = unwrap(api.get)
handler = unwrap_all(api.get)
session_state = {"open": False}
class _SessionContext:
@ -83,7 +98,6 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker())
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (SimpleNamespace(id="user-1"), "tenant-1"))
monkeypatch.setattr(
module,
"RagPipelineService",
@ -95,7 +109,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
method="GET",
query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"},
):
response = handler(api, pipeline=SimpleNamespace(id="pipeline-1"))
response = handler(api, _account(), pipeline=_pipeline())
assert response["items"][0]["id"] == "workflow-1"
assert response["page"] == 1
@ -105,7 +119,6 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow = _make_workflow(marked_name="Updated release")
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (SimpleNamespace(id="user-1"), "tenant-1"))
class _SessionContext:
def __enter__(self):
@ -128,7 +141,7 @@ def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch:
payload: dict[str, object] = {"marked_name": "Updated release"}
api = module.RagPipelineByIdApi()
handler = unwrap(api.patch)
handler = unwrap_all(api.patch)
with (
app.test_request_context("/rag/pipelines/pipeline-1/workflows/workflow-1", method="PATCH", json=payload),
@ -136,7 +149,8 @@ def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch:
):
response = handler(
api,
pipeline=SimpleNamespace(id="pipeline-1", tenant_id="tenant-1"),
_account(),
pipeline=_pipeline(),
workflow_id="workflow-1",
)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from datetime import datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock
@ -8,21 +9,34 @@ import pytest
from werkzeug.exceptions import HTTPException, NotFound
from controllers.console.snippets import snippet_workflow as snippet_workflow_module
from models.account import Account, TenantAccountRole
from models.snippet import CustomizedSnippet
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _account(account_id: str = "account-1") -> Account:
account = Account(name="Test User", email=f"{account_id}@example.com")
account.id = account_id
account.role = TenantAccountRole.EDITOR
return account
def _snippet(**overrides) -> CustomizedSnippet:
data = {
"id": "snippet-1",
"tenant_id": "tenant-1",
"name": "Snippet",
"description": "Description",
"type": "node",
"created_by": "account-1",
}
data.update(overrides)
return CustomizedSnippet(**data)
@pytest.fixture(autouse=True)
def _patch_snippet_service_factory(monkeypatch: pytest.MonkeyPatch) -> None:
def factory():
service_factory = snippet_workflow_module.SnippetService
if isinstance(service_factory, type):
return service_factory.__new__(service_factory)
return service_factory()
return snippet_workflow_module.SnippetService()
monkeypatch.setattr(snippet_workflow_module, "_snippet_service", factory)
monkeypatch.setattr(snippet_workflow_module, "_snippet_session_maker", Mock(return_value=Mock()))
@ -39,7 +53,7 @@ def test_get_snippet_requires_snippet_id(app):
def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPatch) -> None:
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
snippet = _snippet()
@snippet_workflow_module.get_snippet
def view(**kwargs):
@ -48,7 +62,7 @@ def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPat
monkeypatch.setattr(
snippet_workflow_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="account-1"), "tenant-1"),
lambda: (_account("account-1"), "tenant-1"),
)
monkeypatch.setattr(snippet_workflow_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
@ -66,7 +80,7 @@ def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pyt
monkeypatch.setattr(
snippet_workflow_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="account-1"), "tenant-1"),
lambda: (_account("account-1"), "tenant-1"),
)
monkeypatch.setattr(snippet_workflow_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
@ -76,7 +90,7 @@ def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pyt
def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
snippet = _snippet()
monkeypatch.setattr(
snippet_workflow_module,
"SnippetService",
@ -84,7 +98,7 @@ def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyP
)
api = snippet_workflow_module.SnippetDraftWorkflowApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/snippets/snippet-1/workflows/draft"):
with pytest.raises(snippet_workflow_module.DraftWorkflowNotExist):
@ -92,10 +106,9 @@ def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyP
def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(id="account-1")
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
user = _account("account-1")
snippet = _snippet()
sync_draft_workflow = Mock(side_effect=ValueError("invalid graph"))
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(
snippet_workflow_module,
"SnippetService",
@ -103,14 +116,14 @@ def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pyt
)
api = snippet_workflow_module.SnippetDraftWorkflowApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/snippets/snippet-1/workflows/draft",
method="POST",
json={"graph": {"nodes": [], "edges": []}, "hash": "hash-1"},
):
response, status_code = handler(api, snippet=snippet)
response, status_code = handler(api, user, snippet)
assert status_code == 400
assert response == {"message": "invalid graph"}
@ -118,7 +131,7 @@ def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pyt
def test_draft_config_returns_parallel_depth_limit(app) -> None:
api = snippet_workflow_module.SnippetDraftConfigApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/snippets/snippet-1/workflows/draft/config"):
assert handler(api, snippet=SimpleNamespace(id="snippet-1")) == {"parallel_depth_limit": 3}
@ -126,16 +139,16 @@ def test_draft_config_returns_parallel_depth_limit(app) -> None:
def test_published_workflow_get_returns_none_when_not_published(app) -> None:
api = snippet_workflow_module.SnippetPublishedWorkflowApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/snippets/snippet-1/workflows/publish"):
assert handler(api, snippet=SimpleNamespace(id="snippet-1", is_published=False)) is None
def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(id="account-1")
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
merged_snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
user = _account("account-1")
snippet = _snippet()
merged_snippet = _snippet()
session = SimpleNamespace(merge=Mock(return_value=merged_snippet), commit=Mock())
class SessionContext:
@ -148,7 +161,6 @@ def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(snippet_workflow_module, "Session", SessionContext)
monkeypatch.setattr(snippet_workflow_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
@ -158,10 +170,10 @@ def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch
)
api = snippet_workflow_module.SnippetPublishedWorkflowApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context("/snippets/snippet-1/workflows/publish", method="POST", json={}):
response, status_code = handler(api, snippet=snippet)
response, status_code = handler(api, user, snippet)
assert status_code == 400
assert response == {"message": "No valid workflow found."}
@ -177,7 +189,7 @@ def test_default_block_configs_delegates_to_service(app, monkeypatch: pytest.Mon
)
api = snippet_workflow_module.SnippetDefaultBlockConfigsApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/snippets/snippet-1/workflows/default-workflow-block-configs"):
result = handler(api, snippet=SimpleNamespace(id="snippet-1"))
@ -192,10 +204,9 @@ def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: p
updated_at=None,
created_at=datetime(2024, 1, 1),
)
user = SimpleNamespace(id="account-1")
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
user = _account("account-1")
snippet = _snippet()
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(
snippet_workflow_module,
"SnippetService",
@ -203,23 +214,22 @@ def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: p
)
api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/snippets/snippet-1/workflows/published-workflow/restore",
method="POST",
):
response = handler(api, snippet=snippet, workflow_id="published-workflow")
response = handler(api, user, snippet, workflow_id="published-workflow")
assert response["result"] == "success"
assert response["hash"] == "restored-hash"
def test_restore_published_snippet_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(id="account-1")
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
user = _account("account-1")
snippet = _snippet()
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(
snippet_workflow_module,
"SnippetService",
@ -231,23 +241,22 @@ def test_restore_published_snippet_workflow_to_draft_not_found(app, monkeypatch:
)
api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/snippets/snippet-1/workflows/published-workflow/restore",
method="POST",
):
with pytest.raises(NotFound):
handler(api, snippet=snippet, workflow_id="published-workflow")
handler(api, user, snippet, workflow_id="published-workflow")
def test_restore_published_snippet_workflow_to_draft_returns_400_for_draft_source(
app, monkeypatch: pytest.MonkeyPatch
) -> None:
user = SimpleNamespace(id="account-1")
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
user = _account("account-1")
snippet = _snippet()
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(
snippet_workflow_module,
"SnippetService",
@ -259,14 +268,14 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_draft_sourc
)
api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/snippets/snippet-1/workflows/draft-workflow/restore",
method="POST",
):
with pytest.raises(HTTPException) as exc:
handler(api, snippet=snippet, workflow_id="draft-workflow")
handler(api, user, snippet, workflow_id="draft-workflow")
assert exc.value.code == 400
assert exc.value.description == snippet_workflow_module.RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE
@ -275,10 +284,9 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_draft_sourc
def test_restore_published_snippet_workflow_to_draft_returns_400_for_invalid_graph(
app, monkeypatch: pytest.MonkeyPatch
) -> None:
user = SimpleNamespace(id="account-1")
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
user = _account("account-1")
snippet = _snippet()
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(
snippet_workflow_module,
"SnippetService",
@ -290,21 +298,21 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_invalid_gra
)
api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/snippets/snippet-1/workflows/published-workflow/restore",
method="POST",
):
with pytest.raises(HTTPException) as exc:
handler(api, snippet=snippet, workflow_id="published-workflow")
handler(api, user, snippet, workflow_id="published-workflow")
assert exc.value.code == 400
assert exc.value.description == "invalid snippet workflow graph"
def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
snippet = _snippet()
monkeypatch.setattr(
snippet_workflow_module,
"SnippetService",
@ -312,7 +320,7 @@ def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch:
)
api = snippet_workflow_module.SnippetWorkflowRunDetailApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/snippets/snippet-1/workflow-runs/run-1"):
with pytest.raises(NotFound, match="Workflow run not found"):
@ -320,7 +328,7 @@ def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch:
def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
snippet = _snippet()
draft_workflow = SimpleNamespace(id="workflow-1")
monkeypatch.setattr(
snippet_workflow_module,
@ -332,7 +340,7 @@ def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkey
)
api = snippet_workflow_module.SnippetDraftNodeLastRunApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/snippets/snippet-1/workflows/draft/nodes/llm-1/last-run"):
with pytest.raises(NotFound, match="Node last run not found"):
@ -354,7 +362,7 @@ def test_workflow_task_stop_uses_queue_flag_and_graph_command(app, monkeypatch:
)
api = snippet_workflow_module.SnippetWorkflowTaskStopApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context("/snippets/snippet-1/workflow-runs/tasks/task-1/stop", method="POST"):
result = handler(api, snippet=SimpleNamespace(id="snippet-1"), task_id="task-1")

View File

@ -1,4 +1,5 @@
import io
from inspect import unwrap
from unittest.mock import MagicMock, patch
import pytest
@ -39,21 +40,19 @@ from controllers.console.workspace.plugin import (
PluginUploadFromPkgApi,
)
from core.plugin.impl.exc import PluginDaemonClientSideError
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from models.account import Account, TenantAccountRole, TenantPluginAutoUpgradeStrategy, TenantPluginPermission
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _account(role: TenantAccountRole = TenantAccountRole.OWNER) -> Account:
account = Account(name="Test User", email="u1@example.com")
account.id = "u1"
account.role = role
return account
@pytest.fixture
def user():
u = MagicMock()
u.id = "u1"
u.is_admin_or_owner = True
return u
return _account()
@pytest.fixture
@ -102,10 +101,9 @@ class TestPluginDebuggingKeyApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.get_debugging_key", return_value="k"),
):
result = method(api)
result = method(api, "t1")
assert result["key"] == "k"
@ -115,13 +113,12 @@ class TestPluginDebuggingKeyApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.get_debugging_key",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -134,10 +131,9 @@ class TestPluginListApi:
with (
app.test_request_context("/?page=1&page_size=10"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.list_with_total", return_value=mock_list),
):
result = method(api)
result = method(api, "t1")
assert result["total"] == 1
@ -163,10 +159,9 @@ class TestPluginAssetApi:
with (
app.test_request_context("/?plugin_unique_identifier=p&file_name=a.bin"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.extract_asset", return_value=b"x"),
):
response = method(api)
response = method(api, "t1")
assert response.mimetype == "application/octet-stream"
@ -182,10 +177,9 @@ class TestPluginUploadFromPkgApi:
with (
app.test_request_context("/", data=data, content_type="multipart/form-data"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.upload_pkg", return_value={"ok": True}),
):
result = method(api)
result = method(api, "t1")
assert result["ok"] is True
@ -199,12 +193,11 @@ class TestPluginUploadFromPkgApi:
with (
app.test_request_context("/", data=data, content_type="multipart/form-data"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0),
patch("controllers.console.workspace.plugin.PluginService.upload_pkg") as upload_pkg_mock,
):
with pytest.raises(ValueError) as exc_info:
method(api)
method(api, "t1")
assert "File size exceeds the maximum allowed size" in str(exc_info.value)
upload_pkg_mock.assert_not_called()
@ -219,12 +212,11 @@ class TestPluginInstallFromPkgApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.install_from_local_pkg", return_value={"ok": True}
),
):
result = method(api)
result = method(api, "t1")
assert result["ok"] is True
@ -238,10 +230,9 @@ class TestPluginUninstallApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.uninstall", return_value=True),
):
result = method(api)
result = method(api, "t1")
assert result["success"] is True
@ -251,7 +242,7 @@ class TestPluginChangePermissionApi:
api = PluginChangePermissionApi()
method = unwrap(api.post)
user = MagicMock(is_admin_or_owner=False)
user = _account(TenantAccountRole.NORMAL)
payload = {
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
@ -260,16 +251,15 @@ class TestPluginChangePermissionApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(Forbidden):
method(api)
method(api, "t1", user)
def test_change_permission_success(self, app: Flask):
api = PluginChangePermissionApi()
method = unwrap(api.post)
user = MagicMock(is_admin_or_owner=True)
user = _account()
payload = {
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
@ -278,10 +268,9 @@ class TestPluginChangePermissionApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True),
):
result = method(api)
result = method(api, "t1", user)
assert result["success"] is True
@ -293,10 +282,9 @@ class TestPluginFetchPermissionApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=None),
):
result = method(api)
result = method(api, "t1")
assert result["install_permission"] is not None
@ -308,13 +296,12 @@ class TestPluginFetchDynamicSelectOptionsApi:
with (
app.test_request_context("/?plugin_id=p&provider=x&action=y&parameter=z&provider_type=tool"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options",
return_value=[1, 2],
),
):
result = method(api)
result = method(api, "t1", user)
assert result["options"] == [1, 2]
@ -326,16 +313,15 @@ class TestPluginReadmeApi:
with (
app.test_request_context("/?plugin_unique_identifier=p"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_readme", return_value="readme"),
):
result = method(api)
result = method(api, "t1")
assert result["readme"] == "readme"
class TestPluginListInstallationsFromIdsApi:
def test_success(self, app: Flask):
def test_success(self, app: Flask, user):
api = PluginListInstallationsFromIdsApi()
method = unwrap(api.post)
@ -343,13 +329,12 @@ class TestPluginListInstallationsFromIdsApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.list_installations_from_ids",
return_value=[{"id": "p1"}],
),
):
result = method(api)
result = method(api, "t1")
assert "plugins" in result
@ -361,18 +346,17 @@ class TestPluginListInstallationsFromIdsApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.list_installations_from_ids",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
class TestPluginUploadFromGithubApi:
def test_success(self, app: Flask):
def test_success(self, app: Flask, user):
api = PluginUploadFromGithubApi()
method = unwrap(api.post)
@ -380,12 +364,11 @@ class TestPluginUploadFromGithubApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.upload_pkg_from_github", return_value={"ok": True}
),
):
result = method(api)
result = method(api, "t1")
assert result["ok"] is True
@ -397,13 +380,12 @@ class TestPluginUploadFromGithubApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.upload_pkg_from_github",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -424,10 +406,9 @@ class TestPluginUploadFromBundleApi:
data={"bundle": file},
content_type="multipart/form-data",
),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.upload_bundle", return_value={"ok": True}),
):
result = method(api)
result = method(api, "t1")
assert result["ok"] is True
@ -447,12 +428,11 @@ class TestPluginUploadFromBundleApi:
data={"bundle": file},
content_type="multipart/form-data",
),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0),
patch("controllers.console.workspace.plugin.PluginService.upload_bundle") as upload_bundle_mock,
):
with pytest.raises(ValueError) as exc_info:
method(api)
method(api, "t1")
assert "File size exceeds the maximum allowed size" in str(exc_info.value)
upload_bundle_mock.assert_not_called()
@ -472,10 +452,9 @@ class TestPluginInstallFromGithubApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.install_from_github", return_value={"ok": True}),
):
result = method(api)
result = method(api, "t1")
assert result["ok"] is True
@ -492,13 +471,12 @@ class TestPluginInstallFromGithubApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.install_from_github",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -511,13 +489,12 @@ class TestPluginInstallFromMarketplaceApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg",
return_value={"ok": True},
),
):
result = method(api)
result = method(api, "t1")
assert result["ok"] is True
@ -529,13 +506,12 @@ class TestPluginInstallFromMarketplaceApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -546,10 +522,9 @@ class TestPluginFetchMarketplacePkgApi:
with (
app.test_request_context("/?plugin_unique_identifier=p"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg", return_value={"m": 1}),
):
result = method(api)
result = method(api, "t1")
assert "manifest" in result
@ -559,13 +534,12 @@ class TestPluginFetchMarketplacePkgApi:
with (
app.test_request_context("/?plugin_unique_identifier=p"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -579,10 +553,9 @@ class TestPluginFetchManifestApi:
with (
app.test_request_context("/?plugin_unique_identifier=p"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest", return_value=manifest),
):
result = method(api)
result = method(api, "t1")
assert "manifest" in result
@ -592,13 +565,12 @@ class TestPluginFetchManifestApi:
with (
app.test_request_context("/?plugin_unique_identifier=p"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -609,10 +581,9 @@ class TestPluginFetchInstallTasksApi:
with (
app.test_request_context("/?page=1&page_size=10"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.fetch_install_tasks", return_value=[{"id": 1}]),
):
result = method(api)
result = method(api, "t1")
assert "tasks" in result
@ -622,13 +593,12 @@ class TestPluginFetchInstallTasksApi:
with (
app.test_request_context("/?page=1&page_size=10"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.fetch_install_tasks",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -639,10 +609,9 @@ class TestPluginFetchInstallTaskApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.fetch_install_task", return_value={"id": "x"}),
):
result = method(api, "x")
result = method(api, "t1", "x")
assert "task" in result
@ -652,13 +621,12 @@ class TestPluginFetchInstallTaskApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.fetch_install_task",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api, "t")
result = method(api, "t1", "t")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -669,10 +637,9 @@ class TestPluginDeleteInstallTaskApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.delete_install_task", return_value=True),
):
result = method(api, "x")
result = method(api, "t1", "x")
assert result["success"] is True
@ -682,13 +649,12 @@ class TestPluginDeleteInstallTaskApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.delete_install_task",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api, "t")
result = method(api, "t1", "t")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -699,12 +665,11 @@ class TestPluginDeleteAllInstallTaskItemsApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.delete_all_install_task_items", return_value=True
),
):
result = method(api)
result = method(api, "t1")
assert result["success"] is True
@ -714,13 +679,12 @@ class TestPluginDeleteAllInstallTaskItemsApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.delete_all_install_task_items",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -731,10 +695,9 @@ class TestPluginDeleteInstallTaskItemApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginService.delete_install_task_item", return_value=True),
):
result = method(api, "task1", "item1")
result = method(api, "t1", "task1", "item1")
assert result["success"] is True
@ -744,18 +707,17 @@ class TestPluginDeleteInstallTaskItemApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.delete_install_task_item",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api, "task1", "item1")
result = method(api, "t1", "task1", "item1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
class TestPluginUpgradeFromMarketplaceApi:
def test_success(self, app: Flask):
def test_success(self, app: Flask, user):
api = PluginUpgradeFromMarketplaceApi()
method = unwrap(api.post)
@ -766,13 +728,12 @@ class TestPluginUpgradeFromMarketplaceApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace",
return_value={"ok": True},
),
):
result = method(api)
result = method(api, "t1")
assert result["ok"] is True
@ -787,13 +748,12 @@ class TestPluginUpgradeFromMarketplaceApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -812,13 +772,12 @@ class TestPluginUpgradeFromGithubApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github",
return_value={"ok": True},
),
):
result = method(api)
result = method(api, "t1")
assert result["ok"] is True
@ -836,23 +795,20 @@ class TestPluginUpgradeFromGithubApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1")
assert result == ({"code": "plugin_error", "message": "error"}, 400)
class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
def test_success(self, app: Flask):
def test_success(self, app: Flask, user):
api = PluginFetchDynamicSelectOptionsWithCredentialsApi()
method = unwrap(api.post)
user = MagicMock(id="u1", is_admin_or_owner=True)
payload = {
"plugin_id": "p",
"provider": "x",
@ -864,22 +820,19 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials",
return_value=[1],
),
):
result = method(api)
result = method(api, "t1", user)
assert result["options"] == [1]
def test_daemon_error(self, app: Flask):
def test_daemon_error(self, app: Flask, user):
api = PluginFetchDynamicSelectOptionsWithCredentialsApi()
method = unwrap(api.post)
user = MagicMock(id="u1", is_admin_or_owner=True)
payload = {
"plugin_id": "p",
"provider": "x",
@ -891,13 +844,12 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials",
side_effect=PluginDaemonClientSideError("error"),
),
):
result = method(api)
result = method(api, "t1", user)
assert result == ({"code": "plugin_error", "message": "error"}, 400)
@ -906,7 +858,7 @@ class TestPluginChangePreferencesApi:
api = PluginChangePreferencesApi()
method = unwrap(api.post)
user = MagicMock(is_admin_or_owner=True)
user = _account()
payload = {
"permission": {
@ -924,11 +876,10 @@ class TestPluginChangePreferencesApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True),
patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.change_strategy", return_value=True),
):
result = method(api)
result = method(api, "t1", user)
assert result["success"] is True
@ -936,7 +887,7 @@ class TestPluginChangePreferencesApi:
api = PluginChangePreferencesApi()
method = unwrap(api.post)
user = MagicMock(is_admin_or_owner=True)
user = _account()
payload = {
"permission": {
@ -954,10 +905,9 @@ class TestPluginChangePreferencesApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=False),
):
result = method(api)
result = method(api, "t1", user)
assert result["success"] is False
@ -982,7 +932,6 @@ class TestPluginFetchPreferencesApi:
with (
app.test_request_context("/"),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch(
"controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=permission
),
@ -990,7 +939,7 @@ class TestPluginFetchPreferencesApi:
"controllers.console.workspace.plugin.PluginAutoUpgradeService.get_strategy", return_value=auto_upgrade
),
):
result = method(api)
result = method(api, "t1")
assert "permission" in result
assert "auto_upgrade" in result
@ -1005,10 +954,9 @@ class TestPluginAutoUpgradeExcludePluginApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=True),
):
result = method(api)
result = method(api, "t1")
assert result["success"] is True
@ -1020,9 +968,8 @@ class TestPluginAutoUpgradeExcludePluginApi:
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=False),
):
result = method(api)
result = method(api, "t1")
assert result["success"] is False

View File

@ -1,3 +1,4 @@
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock
@ -5,15 +6,11 @@ import pytest
from werkzeug.exceptions import NotFound
from controllers.console.workspace import snippets as snippets_module
from models.account import Account, TenantAccountRole
from models.snippet import CustomizedSnippet
from services.snippet_dsl_service import ImportStatus, SnippetImportInfo
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture(autouse=True)
def _patch_snippet_service_factory(monkeypatch):
def factory():
@ -34,6 +31,26 @@ class _SessionContext:
return False
def _account(account_id: str = "account-1") -> Account:
account = Account(name="Test User", email=f"{account_id}@example.com")
account.id = account_id
account.role = TenantAccountRole.EDITOR
return account
def _snippet(**overrides) -> CustomizedSnippet:
data = {
"id": "snippet-1",
"tenant_id": "tenant-1",
"name": "Snippet",
"description": "Description",
"type": snippets_module.SnippetType.NODE,
"created_by": "account-1",
}
data.update(overrides)
return CustomizedSnippet(**data)
def test_normalize_snippet_list_query_args_sorts_indexed_values():
query_args = snippets_module.MultiDict(
[
@ -53,21 +70,19 @@ def test_normalize_snippet_list_query_args_sorts_indexed_values():
def test_list_snippets_returns_pagination(app, monkeypatch):
user = SimpleNamespace(id="account-1")
snippets = [SimpleNamespace(id="snippet-1")]
snippets = [_snippet()]
tag_id = "11111111-1111-1111-1111-111111111111"
get_snippets = Mock(return_value=(snippets, 1, False))
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "get_snippets", get_snippets)
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value=[{"id": "snippet-1"}]))
api = snippets_module.CustomizedSnippetsApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context(
f"/workspaces/current/customized-snippets?page=2&limit=10&tag_ids[0]={tag_id}&creator_ids[0]=account-2"
):
response, status_code = handler(api)
response, status_code = handler(api, "tenant-1")
assert status_code == 200
assert response == {
@ -89,10 +104,9 @@ def test_list_snippets_returns_pagination(app, monkeypatch):
def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypatch):
user = SimpleNamespace(id="account-1")
snippet = SimpleNamespace(id="snippet-1")
user = _account("account-1")
snippet = _snippet()
create_snippet = Mock(return_value=snippet)
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet)
monkeypatch.setattr(
snippets_module.CreateSnippetPayload,
@ -111,14 +125,14 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypat
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"}))
api = snippets_module.CustomizedSnippetsApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/workspaces/current/customized-snippets",
method="POST",
json={"name": "Snippet", "type": "node", "description": "Description"},
):
response, status_code = handler(api)
response, status_code = handler(api, "tenant-1", user)
assert status_code == 201
assert response == {"id": "snippet-1"}
@ -126,13 +140,12 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypat
def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch):
user = SimpleNamespace(id="account-1")
user = _account("account-1")
create_snippet = Mock()
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet)
api = snippets_module.CustomizedSnippetsApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/workspaces/current/customized-snippets",
@ -148,7 +161,7 @@ def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch):
},
},
):
response, status_code = handler(api)
response, status_code = handler(api, "tenant-1", user)
assert status_code == 400
assert "knowledge-retrieval" in response["message"]
@ -156,60 +169,54 @@ def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch):
def test_get_snippet_detail_raises_when_missing(app, monkeypatch):
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
api = snippets_module.CustomizedSnippetDetailApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1"):
with pytest.raises(NotFound, match="Snippet not found"):
handler(api, snippet_id="snippet-1")
handler(api, "tenant-1", snippet_id="snippet-1")
def test_get_snippet_detail_returns_snippet(app, monkeypatch):
snippet = SimpleNamespace(id="snippet-1")
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
snippet = _snippet()
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"}))
api = snippets_module.CustomizedSnippetDetailApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1"):
response, status_code = handler(api, snippet_id="snippet-1")
response, status_code = handler(api, "tenant-1", snippet_id="snippet-1")
assert status_code == 200
assert response == {"id": "snippet-1"}
def test_patch_snippet_returns_400_for_empty_payload(app, monkeypatch):
snippet = SimpleNamespace(id="snippet-1")
monkeypatch.setattr(
snippets_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
snippet = _snippet()
user = _account("user-1")
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
api = snippets_module.CustomizedSnippetDetailApi()
handler = _unwrap(api.patch)
handler = unwrap(api.patch)
with app.test_request_context(
"/workspaces/current/customized-snippets/snippet-1",
method="PATCH",
json={},
):
response, status_code = handler(api, snippet_id="snippet-1")
response, status_code = handler(api, "tenant-1", user, snippet_id="snippet-1")
assert status_code == 400
assert response == {"message": "No valid fields to update"}
def test_patch_snippet_updates_and_commits(app, monkeypatch):
user = SimpleNamespace(id="account-1")
snippet = SimpleNamespace(id="snippet-1")
updated_snippet = SimpleNamespace(id="snippet-1", name="New")
user = _account("account-1")
snippet = _snippet()
updated_snippet = _snippet(name="New")
session = SimpleNamespace(merge=Mock(return_value=snippet), commit=Mock())
update_snippet = Mock(return_value=updated_snippet)
@ -217,7 +224,6 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch):
def __init__(self, engine, *args, **kwargs):
super().__init__(engine, *args, session=session, **kwargs)
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
monkeypatch.setattr(snippets_module.SnippetService, "update_snippet", update_snippet)
monkeypatch.setattr(snippets_module, "Session", SessionContext)
@ -225,14 +231,14 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch):
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1", "name": "New"}))
api = snippets_module.CustomizedSnippetDetailApi()
handler = _unwrap(api.patch)
handler = unwrap(api.patch)
with app.test_request_context(
"/workspaces/current/customized-snippets/snippet-1",
method="PATCH",
json={"name": "New", "icon_info": {"icon": "star"}},
):
response, status_code = handler(api, snippet_id="snippet-1")
response, status_code = handler(api, "tenant-1", user, snippet_id="snippet-1")
assert status_code == 200
assert response == {"id": "snippet-1", "name": "New"}
@ -245,7 +251,7 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch):
def test_delete_snippet_deletes_and_commits(app, monkeypatch):
snippet = SimpleNamespace(id="snippet-1")
snippet = _snippet()
session = SimpleNamespace(merge=Mock(return_value=snippet), commit=Mock())
delete_snippet = Mock()
@ -253,17 +259,16 @@ def test_delete_snippet_deletes_and_commits(app, monkeypatch):
def __init__(self, engine, *args, **kwargs):
super().__init__(engine, *args, session=session, **kwargs)
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
monkeypatch.setattr(snippets_module.SnippetService, "delete_snippet", delete_snippet)
monkeypatch.setattr(snippets_module, "Session", SessionContext)
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
api = snippets_module.CustomizedSnippetDetailApi()
handler = _unwrap(api.delete)
handler = unwrap(api.delete)
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1", method="DELETE"):
response, status_code = handler(api, snippet_id="snippet-1")
response, status_code = handler(api, "tenant-1", snippet_id="snippet-1")
assert status_code == 204
assert response == ""
@ -272,7 +277,7 @@ def test_delete_snippet_deletes_and_commits(app, monkeypatch):
def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
snippet = SimpleNamespace(id="snippet-1", name="Snippet One")
snippet = _snippet(name="Snippet One")
export_snippet_dsl = Mock(return_value="version: 0.1.0\nkind: snippet\n")
session = SimpleNamespace()
@ -280,7 +285,6 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
def __init__(self, engine, *args, **kwargs):
super().__init__(engine, *args, session=session, **kwargs)
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
monkeypatch.setattr(
snippets_module,
@ -291,10 +295,10 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
api = snippets_module.CustomizedSnippetExportApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1/export?include_secret=true"):
response = handler(api, snippet_id="snippet-1")
response = handler(api, "tenant-1", snippet_id="snippet-1")
assert response.status_code == 200
assert response.get_data(as_text=True) == "version: 0.1.0\nkind: snippet\n"
@ -304,7 +308,7 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
user = SimpleNamespace(id="account-1")
user = _account("account-1")
result = SnippetImportInfo(id="import-1", status=ImportStatus.PENDING, imported_dsl_version="999.0.0")
import_snippet = Mock(return_value=result)
session = SimpleNamespace(commit=Mock())
@ -319,7 +323,6 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(snippets_module, "Session", _SessionContext)
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
@ -329,14 +332,14 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
)
api = snippets_module.CustomizedSnippetImportApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/workspaces/current/customized-snippets/imports",
method="POST",
json={"mode": "yaml-content", "yaml_content": "kind: snippet"},
):
response, status_code = handler(api)
response, status_code = handler(api, user)
assert status_code == 202
assert response["status"] == ImportStatus.PENDING.value
@ -345,7 +348,7 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
user = SimpleNamespace(id="account-1")
user = _account("account-1")
result = SnippetImportInfo(id="import-1", status=ImportStatus.FAILED, error="Invalid DSL")
import_snippet = Mock(return_value=result)
session = SimpleNamespace(commit=Mock())
@ -354,7 +357,6 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
def __init__(self, engine, *args, **kwargs):
super().__init__(engine, *args, session=session, **kwargs)
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(snippets_module, "Session", SessionContext)
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
@ -364,14 +366,14 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
)
api = snippets_module.CustomizedSnippetImportApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/workspaces/current/customized-snippets/imports",
method="POST",
json={"mode": "yaml-content", "yaml_content": "kind: snippet"},
):
response, status_code = handler(api)
response, status_code = handler(api, user)
assert status_code == 400
assert response["error"] == "Invalid DSL"
@ -379,7 +381,7 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
user = SimpleNamespace(id="account-1")
user = _account("account-1")
result = SnippetImportInfo(id="import-1", status=ImportStatus.COMPLETED, snippet_id="snippet-1")
confirm_import = Mock(return_value=result)
session = SimpleNamespace(commit=Mock())
@ -388,7 +390,6 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
def __init__(self, engine, *args, **kwargs):
super().__init__(engine, *args, session=session, **kwargs)
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
monkeypatch.setattr(snippets_module, "Session", SessionContext)
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
@ -398,13 +399,13 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
)
api = snippets_module.CustomizedSnippetImportConfirmApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/workspaces/current/customized-snippets/imports/import-1/confirm",
method="POST",
):
response, status_code = handler(api, import_id="import-1")
response, status_code = handler(api, user, import_id="import-1")
assert status_code == 200
assert response["snippet_id"] == "snippet-1"
@ -413,19 +414,18 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
def test_check_dependencies_raises_when_snippet_missing(app, monkeypatch):
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
api = snippets_module.CustomizedSnippetCheckDependenciesApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1/check-dependencies"):
with pytest.raises(NotFound, match="Snippet not found"):
handler(api, snippet_id="snippet-1")
handler(api, "tenant-1", snippet_id="snippet-1")
def test_check_dependencies_returns_dependency_result(app, monkeypatch):
snippet = SimpleNamespace(id="snippet-1")
snippet = _snippet()
check_dependencies = Mock(
return_value=SimpleNamespace(model_dump=Mock(return_value={"dependencies": [], "missing_dependencies": []}))
)
@ -435,7 +435,6 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch):
def __init__(self, engine, *args, **kwargs):
super().__init__(engine, *args, session=session, **kwargs)
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
monkeypatch.setattr(snippets_module, "Session", SessionContext)
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
@ -446,10 +445,10 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch):
)
api = snippets_module.CustomizedSnippetCheckDependenciesApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1/check-dependencies"):
response, status_code = handler(api, snippet_id="snippet-1")
response, status_code = handler(api, "tenant-1", snippet_id="snippet-1")
assert status_code == 200
assert response == {"dependencies": [], "missing_dependencies": []}
@ -457,18 +456,17 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch):
def test_increment_use_count_raises_when_snippet_missing(app, monkeypatch):
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
api = snippets_module.CustomizedSnippetUseCountIncrementApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/workspaces/current/customized-snippets/snippet-1/use-count/increment",
method="POST",
):
with pytest.raises(NotFound, match="Snippet not found"):
handler(api, snippet_id="snippet-1")
handler(api, "tenant-1", snippet_id="snippet-1")
def test_increment_use_count_returns_refreshed_count(app, monkeypatch):
@ -487,20 +485,19 @@ def test_increment_use_count_returns_refreshed_count(app, monkeypatch):
return False
increment_use_count = Mock()
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
monkeypatch.setattr(snippets_module.SnippetService, "increment_use_count", increment_use_count)
monkeypatch.setattr(snippets_module, "Session", _SessionContext)
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
api = snippets_module.CustomizedSnippetUseCountIncrementApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/workspaces/current/customized-snippets/snippet-1/use-count/increment",
method="POST",
):
response, status_code = handler(api, snippet_id="snippet-1")
response, status_code = handler(api, "tenant-1", snippet_id="snippet-1")
assert status_code == 200
assert response == {"result": "success", "use_count": 3}

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import builtins
import importlib
from contextlib import ExitStack, contextmanager
from inspect import unwrap
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock, patch
@ -12,12 +13,14 @@ import pytest
from flask import Flask
from flask.views import MethodView
from models import Account
from models.account import TenantAccountRole
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
_CONTROLLER_MODULE: ModuleType | None = None
_WRAPS_MODULE: ModuleType | None = None
@contextmanager
@ -69,9 +72,7 @@ def controller_module(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload)
# Ensure decorators that consult deployment edition do not reach the database.
global _WRAPS_MODULE
wraps_module = importlib.import_module("controllers.console.wraps")
_WRAPS_MODULE = wraps_module
monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD")
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
@ -80,46 +81,24 @@ def controller_module(monkeypatch: pytest.MonkeyPatch):
return module
def _mock_account(user_id: str = "user-123") -> SimpleNamespace:
return SimpleNamespace(
id=user_id,
status="active",
is_authenticated=True,
current_tenant_id=None,
is_admin_or_owner=False,
)
def _set_current_account(
monkeypatch: pytest.MonkeyPatch,
controller_module: ModuleType,
user: SimpleNamespace,
tenant_id: str,
) -> None:
def _getter():
return user, tenant_id
user.current_tenant_id = tenant_id
monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter)
if _WRAPS_MODULE is not None:
monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter)
login_module = importlib.import_module("libs.login")
monkeypatch.setattr(login_module, "_get_user", lambda: user)
def _mock_account(user_id: str = "user-123") -> Account:
user = Account(name="Test User", email=f"{user_id}@example.com")
user.id = user_id
user.role = TenantAccountRole.NORMAL
return user
def test_tool_provider_list_calls_service_with_query(
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
service_mock = MagicMock(return_value=[{"provider": "builtin"}])
monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock)
with app.test_request_context("/workspaces/current/tool-providers?type=builtin"):
response = controller_module.ToolProviderListApi().get()
api = controller_module.ToolProviderListApi()
response = unwrap(api.get)(api, "tenant-456", user)
assert response == [{"provider": "builtin"}]
service_mock.assert_called_once_with(user.id, "tenant-456", "builtin")
@ -129,7 +108,6 @@ def test_builtin_provider_add_passes_payload(
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
service_mock = MagicMock(return_value={"status": "ok"})
monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock)
@ -145,7 +123,8 @@ def test_builtin_provider_add_passes_payload(
method="POST",
json=payload,
):
response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai")
api = controller_module.ToolBuiltinProviderAddApi()
response = unwrap(api.post)(api, "tenant-456", user, provider="openai")
assert response == {"status": "ok"}
service_mock.assert_called_once_with(
@ -161,7 +140,6 @@ def test_builtin_provider_add_passes_payload(
def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-789")
_set_current_account(monkeypatch, controller_module, user, "tenant-789")
service_mock = MagicMock(return_value=[{"name": "tool-a"}])
monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock)
@ -171,7 +149,8 @@ def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch:
"/workspaces/current/tool-provider/builtin/my-provider/tools",
method="GET",
):
response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider")
api = controller_module.ToolBuiltinProviderListToolsApi()
response = unwrap(api.get)(api, "tenant-789", provider="my-provider")
assert response == [{"name": "tool-a"}]
service_mock.assert_called_once_with("tenant-789", "my-provider")
@ -179,12 +158,12 @@ def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch:
def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-9")
_set_current_account(monkeypatch, controller_module, user, "tenant-9")
service_mock = MagicMock(return_value={"info": True})
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)
with app.test_request_context("/info", method="GET"):
resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")
api = controller_module.ToolBuiltinProviderInfoApi()
resp = unwrap(api.get)(api, "tenant-9", provider="demo")
assert resp == {"info": True}
service_mock.assert_called_once_with("tenant-9", "demo")
@ -192,7 +171,6 @@ def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: p
def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-cred")
_set_current_account(monkeypatch, controller_module, user, "tenant-cred")
service_mock = MagicMock(return_value=[{"cred": 1}])
monkeypatch.setattr(
controller_module.BuiltinToolManageService,
@ -201,7 +179,8 @@ def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeyp
)
with app.test_request_context("/creds", method="GET"):
resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo")
api = controller_module.ToolBuiltinProviderGetCredentialsApi()
resp = unwrap(api.get)(api, "tenant-cred", user, provider="demo")
assert resp == [{"cred": 1}]
service_mock.assert_called_once_with(
@ -214,12 +193,12 @@ def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeyp
def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-10")
service_mock = MagicMock(return_value={"schema": "ok"})
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock)
with app.test_request_context("/remote?url=https://example.com/"):
resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get()
api = controller_module.ToolApiProviderGetRemoteSchemaApi()
resp = unwrap(api.get)(api, "tenant-10", user)
assert resp == {"schema": "ok"}
service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/")
@ -227,12 +206,12 @@ def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypat
def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-11")
service_mock = MagicMock(return_value=[{"tool": "t"}])
monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock)
with app.test_request_context("/tools?provider=foo"):
resp = controller_module.ToolApiProviderListToolsApi().get()
api = controller_module.ToolApiProviderListToolsApi()
resp = unwrap(api.get)(api, "tenant-11", user)
assert resp == [{"tool": "t"}]
service_mock.assert_called_once_with(user.id, "tenant-11", "foo")
@ -240,12 +219,12 @@ def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch:
def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-12")
service_mock = MagicMock(return_value={"provider": "foo"})
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock)
with app.test_request_context("/get?provider=foo"):
resp = controller_module.ToolApiProviderGetApi().get()
api = controller_module.ToolApiProviderGetApi()
resp = unwrap(api.get)(api, "tenant-12", user)
assert resp == {"provider": "foo"}
service_mock.assert_called_once_with(user.id, "tenant-12", "foo")
@ -253,7 +232,6 @@ def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.Mon
def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-13")
_set_current_account(monkeypatch, controller_module, user, "tenant-13")
service_mock = MagicMock(return_value={"schema": True})
monkeypatch.setattr(
controller_module.BuiltinToolManageService,
@ -262,9 +240,8 @@ def test_builtin_provider_credentials_schema_get(app: Flask, controller_module,
)
with app.test_request_context("/schema", method="GET"):
resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get(
provider="demo", credential_type="api-key"
)
api = controller_module.ToolBuiltinProviderCredentialsSchemaApi()
resp = unwrap(api.get)(api, "tenant-13", provider="demo", credential_type="api-key")
assert resp == {"schema": True}
service_mock.assert_called_once()
@ -272,7 +249,6 @@ def test_builtin_provider_credentials_schema_get(app: Flask, controller_module,
def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-wf")
tool_service = MagicMock(return_value={"wf": 1})
monkeypatch.setattr(
controller_module.WorkflowToolManageService,
@ -282,7 +258,8 @@ def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatc
tool_id = "00000000-0000-0000-0000-000000000001"
with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"):
resp = controller_module.ToolWorkflowProviderGetApi().get()
api = controller_module.ToolWorkflowProviderGetApi()
resp = unwrap(api.get)(api, "tenant-wf", user)
assert resp == {"wf": 1}
tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id)
@ -290,7 +267,6 @@ def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatc
def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-wf2")
service_mock = MagicMock(return_value={"app": 1})
monkeypatch.setattr(
controller_module.WorkflowToolManageService,
@ -300,7 +276,8 @@ def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch
app_id = "00000000-0000-0000-0000-000000000002"
with app.test_request_context(f"/workflow?workflow_app_id={app_id}"):
resp = controller_module.ToolWorkflowProviderGetApi().get()
api = controller_module.ToolWorkflowProviderGetApi()
resp = unwrap(api.get)(api, "tenant-wf2", user)
assert resp == {"app": 1}
service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id)
@ -308,13 +285,13 @@ def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch
def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-wf3")
service_mock = MagicMock(return_value=[{"id": 1}])
monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock)
tool_id = "00000000-0000-0000-0000-000000000003"
with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"):
resp = controller_module.ToolWorkflowProviderListToolApi().get()
api = controller_module.ToolWorkflowProviderListToolApi()
resp = unwrap(api.get)(api, "tenant-wf3", user)
assert resp == [{"id": 1}]
service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id)
@ -322,7 +299,6 @@ def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch
def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-bt")
provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"})
monkeypatch.setattr(
@ -332,14 +308,14 @@ def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.M
)
with app.test_request_context("/tools/builtin"):
resp = controller_module.ToolBuiltinListApi().get()
api = controller_module.ToolBuiltinListApi()
resp = unwrap(api.get)(api, "tenant-bt", user)
assert resp == [{"name": "builtin"}]
def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-tenant-api")
_set_current_account(monkeypatch, controller_module, user, "tenant-api")
provider = SimpleNamespace(to_dict=lambda: {"name": "api"})
monkeypatch.setattr(
@ -349,14 +325,14 @@ def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.Monke
)
with app.test_request_context("/tools/api"):
resp = controller_module.ToolApiListApi().get()
api = controller_module.ToolApiListApi()
resp = unwrap(api.get)(api, "tenant-api")
assert resp == [{"name": "api"}]
def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account()
_set_current_account(monkeypatch, controller_module, user, "tenant-wf4")
provider = SimpleNamespace(to_dict=lambda: {"name": "wf"})
monkeypatch.setattr(
@ -366,18 +342,18 @@ def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.
)
with app.test_request_context("/tools/workflow"):
resp = controller_module.ToolWorkflowListApi().get()
api = controller_module.ToolWorkflowListApi()
resp = unwrap(api.get)(api, "tenant-wf4", user)
assert resp == [{"name": "wf"}]
def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
user = _mock_account("user-label")
_set_current_account(monkeypatch, controller_module, user, "tenant-labels")
monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"])
with app.test_request_context("/tool-labels"):
resp = controller_module.ToolLabelsApi().get()
api = controller_module.ToolLabelsApi()
resp = unwrap(api.get)(api)
assert resp == ["a", "b"]