diff --git a/api/controllers/console/app/agent_app_workspace.py b/api/controllers/console/app/agent_app_workspace.py index 888fa8d40c..0699d08bf7 100644 --- a/api/controllers/console/app/agent_app_workspace.py +++ b/api/controllers/console/app/agent_app_workspace.py @@ -24,9 +24,9 @@ from controllers.common.schema import ( ) from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id from fields.base import ResponseModel -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models.model import App, AppMode from services.agent_app_workspace_service import ( AgentAppWorkspaceService, @@ -142,8 +142,8 @@ class AgentAppWorkspaceListResource(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT]) - def get(self, app_model: App): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, app_model: App): query = query_params_from_request(AgentWorkspaceListQuery) try: result = AgentAppWorkspaceService().list_files( @@ -167,8 +167,8 @@ class AgentAppWorkspacePreviewResource(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT]) - def get(self, app_model: App): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, app_model: App): query = query_params_from_request(AgentWorkspaceFileQuery) try: result = AgentAppWorkspaceService().preview( @@ -194,8 +194,8 @@ class AgentAppWorkspaceDownloadResource(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT]) - def get(self, app_model: App): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, app_model: App): query = query_params_from_request(AgentWorkspaceFileQuery) try: result = AgentAppWorkspaceService().download( @@ -228,8 +228,8 @@ class WorkflowAgentWorkspaceListResource(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def get(self, app_model: App, workflow_run_id: UUID, node_id: str): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str): query = query_params_from_request(WorkflowAgentWorkspaceListQuery) try: result = WorkflowAgentWorkspaceService().list_files( @@ -264,8 +264,8 @@ class WorkflowAgentWorkspacePreviewResource(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def get(self, app_model: App, workflow_run_id: UUID, node_id: str): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str): query = query_params_from_request(WorkflowAgentWorkspaceFileQuery) try: result = WorkflowAgentWorkspaceService().preview( @@ -302,8 +302,8 @@ class WorkflowAgentWorkspaceDownloadResource(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def get(self, app_model: App, workflow_run_id: UUID, node_id: str): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str): query = query_params_from_request(WorkflowAgentWorkspaceFileQuery) try: result = WorkflowAgentWorkspaceService().download( diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index cb26963dbf..5d24506f21 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -25,6 +25,7 @@ from controllers.console.wraps import ( account_initialization_required, edit_permission_required, setup_required, + with_current_tenant_id, with_current_user, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError @@ -58,7 +59,7 @@ from graphon.variables import SecretVariable, SegmentType, VariableBase from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, dump_response, to_timestamp, uuid_value -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models import Account, App from models.model import AppMode from models.workflow import Workflow @@ -1568,7 +1569,8 @@ class WorkflowOnlineUsersApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {}) app_ids = args.app_ids @@ -1578,7 +1580,6 @@ class WorkflowOnlineUsersApi(Resource): if not app_ids: return {"data": []} - _, current_tenant_id = current_account_with_tenant() workflow_service = WorkflowService() accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id) ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids] diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 8e453f96dd..4ee8ae176a 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -22,6 +22,8 @@ from controllers.console.wraps import ( enterprise_license_required, is_admin_or_owner_required, setup_required, + with_current_tenant_id, + with_current_user, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner @@ -36,9 +38,9 @@ from fields.base import ResponseModel from fields.dataset_fields import DatasetDetailResponse from graphon.model_runtime.entities.model_entities import ModelType from libs.helper import build_icon_url, dump_response, to_timestamp -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from libs.url_utils import normalize_api_base_url -from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile +from models import Account, ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum from models.enums import ApiTokenType, SegmentStatus from models.provider_ids import ModelProviderID @@ -389,8 +391,9 @@ class DatasetListApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def get(self): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account): # Convert query parameters to dict, handling list parameters correctly query_params: dict[str, str | list[str]] = dict(request.args.to_dict()) # Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value) @@ -471,9 +474,10 @@ class DatasetListApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def post(self): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account): payload = DatasetCreatePayload.model_validate(console_ns.payload or {}) - current_user, current_tenant_id = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: @@ -512,8 +516,9 @@ class DatasetApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -566,14 +571,15 @@ class DatasetApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id: UUID): + @with_current_user + @with_current_tenant_id + def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") payload = DatasetUpdatePayload.model_validate(console_ns.payload or {}) - current_user, current_tenant_id = current_account_with_tenant() # check embedding model setting if ( payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY @@ -614,9 +620,9 @@ class DatasetApi(Resource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(204, "Dataset deleted successfully") - def delete(self, dataset_id: UUID): + @with_current_user + def delete(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) - current_user, _ = current_account_with_tenant() if not (current_user.has_edit_permission or current_user.is_dataset_operator): raise Forbidden() @@ -664,8 +670,8 @@ class DatasetQueryApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -704,10 +710,10 @@ class DatasetIndexingEstimateApi(Resource): @login_required @account_initialization_required @console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__]) - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): payload = IndexingEstimatePayload.model_validate(console_ns.payload or {}) args = payload.model_dump() - _, current_tenant_id = current_account_with_tenant() # validate args DocumentService.estimate_args_validate(args) extract_settings = [] @@ -804,8 +810,8 @@ class DatasetRelatedAppListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -840,8 +846,8 @@ class DatasetIndexingStatusApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str, dataset_id: UUID): dataset_id_str = str(dataset_id) documents = db.session.scalars( select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id) @@ -898,8 +904,8 @@ class DatasetApiKeyApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str): keys = db.session.scalars( select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id) ).all() @@ -911,9 +917,8 @@ class DatasetApiKeyApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, current_tenant_id: str): current_key_count = ( db.session.scalar( select(func.count(ApiToken.id)).where( @@ -952,8 +957,8 @@ class DatasetApiDeleteApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def delete(self, api_key_id: UUID): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def delete(self, current_tenant_id: str, api_key_id: UUID): api_key_id_str = str(api_key_id) key = db.session.scalar( select(ApiToken) @@ -1079,8 +1084,8 @@ class DatasetPermissionUserListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 5d6a779b5a..ebc3b92dd6 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -29,6 +29,8 @@ from controllers.console.wraps import ( account_initialization_required, edit_permission_required, setup_required, + with_current_tenant_id, + with_current_user, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager @@ -46,7 +48,7 @@ from fields.workflow_run_fields import ( from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty, dump_response -from libs.login import current_account_with_tenant, current_user, login_required +from libs.login import current_user, login_required from models import Account from models.dataset import Pipeline from models.model import EndUser @@ -187,16 +189,14 @@ class DraftRagPipelineApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @get_rag_pipeline @edit_permission_required @console_ns.response(200, "Success", console_ns.models[RagPipelineWorkflowSyncResponse.__name__]) - def post(self, pipeline: Pipeline): + def post(self, current_user: Account, pipeline: Pipeline): """ Sync draft workflow """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - content_type = request.headers.get("Content-Type", "") if "application/json" in content_type: @@ -247,15 +247,13 @@ class RagPipelineDraftRunIterationNodeApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @get_rag_pipeline @edit_permission_required - def post(self, pipeline: Pipeline, node_id: str): + def post(self, current_user: Account, pipeline: Pipeline, node_id: str): """ Run draft workflow iteration node """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - payload = NodeRunPayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) @@ -283,14 +281,12 @@ class RagPipelineDraftRunLoopNodeApi(Resource): @login_required @account_initialization_required @edit_permission_required + @with_current_user @get_rag_pipeline - def post(self, pipeline: Pipeline, node_id: str): + def post(self, current_user: Account, pipeline: Pipeline, node_id: str): """ Run draft workflow loop node """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - payload = NodeRunPayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) @@ -318,14 +314,12 @@ class DraftRagPipelineRunApi(Resource): @login_required @account_initialization_required @edit_permission_required + @with_current_user @get_rag_pipeline - def post(self, pipeline: Pipeline): + def post(self, current_user: Account, pipeline: Pipeline): """ Run draft workflow """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}) args = payload.model_dump() @@ -350,14 +344,12 @@ class PublishedRagPipelineRunApi(Resource): @login_required @account_initialization_required @edit_permission_required + @with_current_user @get_rag_pipeline - def post(self, pipeline: Pipeline): + def post(self, current_user: Account, pipeline: Pipeline): """ Run published workflow """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) streaming = payload.response_mode == "streaming" @@ -383,14 +375,12 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource): @login_required @account_initialization_required @edit_permission_required + @with_current_user @get_rag_pipeline - def post(self, pipeline: Pipeline, node_id: str): + def post(self, current_user: Account, pipeline: Pipeline, node_id: str): """ Run rag pipeline datasource """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {}) rag_pipeline_service = RagPipelineService() @@ -416,14 +406,12 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): @login_required @edit_permission_required @account_initialization_required + @with_current_user @get_rag_pipeline - def post(self, pipeline: Pipeline, node_id: str): + def post(self, current_user: Account, pipeline: Pipeline, node_id: str): """ Run rag pipeline datasource """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {}) rag_pipeline_service = RagPipelineService() @@ -454,14 +442,12 @@ class RagPipelineDraftNodeRunApi(Resource): @login_required @edit_permission_required @account_initialization_required + @with_current_user @get_rag_pipeline - def post(self, pipeline: Pipeline, node_id: str): + def post(self, current_user: Account, pipeline: Pipeline, node_id: str): """ Run draft workflow node """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {}) inputs = payload.inputs @@ -485,14 +471,12 @@ class RagPipelineTaskStopApi(Resource): @login_required @edit_permission_required @account_initialization_required + @with_current_user @get_rag_pipeline - def post(self, pipeline: Pipeline, task_id: str): + def post(self, current_user: Account, pipeline: Pipeline, task_id: str): """ Stop workflow task """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() - AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) return {"result": "success"} @@ -532,13 +516,12 @@ class PublishedRagPipelineApi(Resource): @login_required @account_initialization_required @edit_permission_required + @with_current_user @get_rag_pipeline - def post(self, pipeline: Pipeline): + def post(self, current_user: Account, pipeline: Pipeline): """ Publish workflow """ - # The role of the current user in the ta table must be admin, owner, or editor - current_user, _ = current_account_with_tenant() rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.publish_workflow( session=db.session, # type: ignore[reportArgumentType,arg-type] @@ -609,13 +592,12 @@ class PublishedAllRagPipelineApi(Resource): @login_required @account_initialization_required @edit_permission_required + @with_current_user @get_rag_pipeline - def get(self, pipeline: Pipeline): + def get(self, current_user: Account, pipeline: Pipeline): """ Get published workflows """ - current_user, _ = current_account_with_tenant() - query = WorkflowListQuery.model_validate(request.args.to_dict()) page = query.page @@ -655,9 +637,9 @@ class RagPipelineDraftWorkflowRestoreApi(Resource): @login_required @account_initialization_required @edit_permission_required + @with_current_user @get_rag_pipeline - def post(self, pipeline: Pipeline, workflow_id: str): - current_user, _ = current_account_with_tenant() + def post(self, current_user: Account, pipeline: Pipeline, workflow_id: str): rag_pipeline_service = RagPipelineService() try: @@ -689,14 +671,12 @@ class RagPipelineByIdApi(Resource): @login_required @account_initialization_required @edit_permission_required + @with_current_user @get_rag_pipeline - def patch(self, pipeline: Pipeline, workflow_id: str): + def patch(self, current_user: Account, pipeline: Pipeline, workflow_id: str): """ Update workflow attributes """ - # Check permission - current_user, _ = current_account_with_tenant() - payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {}) update_data = payload.model_dump(exclude_unset=True) @@ -925,8 +905,8 @@ class DatasourceListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str): return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id)) @@ -961,9 +941,8 @@ class RagPipelineTransformApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, dataset_id: UUID): - current_user, _ = current_account_with_tenant() - + @with_current_user + def post(self, current_user: Account, dataset_id: UUID): if not (current_user.has_edit_permission or current_user.is_dataset_operator): raise Forbidden() @@ -984,13 +963,13 @@ class RagPipelineDatasourceVariableApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @get_rag_pipeline @edit_permission_required - def post(self, pipeline: Pipeline): + def post(self, current_user: Account, pipeline: Pipeline): """ Set datasource variables """ - current_user, _ = current_account_with_tenant() args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump() rag_pipeline_service = RagPipelineService() diff --git a/api/controllers/console/snippets/snippet_workflow.py b/api/controllers/console/snippets/snippet_workflow.py index 5e2421275b..59608afc55 100644 --- a/api/controllers/console/snippets/snippet_workflow.py +++ b/api/controllers/console/snippets/snippet_workflow.py @@ -30,6 +30,7 @@ from controllers.console.wraps import ( account_initialization_required, edit_permission_required, setup_required, + with_current_user, ) from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom @@ -45,6 +46,7 @@ from graphon.graph_engine.manager import GraphEngineManager from libs import helper from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required +from models import Account from models.snippet import CustomizedSnippet from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.snippet_generate_service import SnippetGenerateService @@ -158,12 +160,11 @@ class SnippetDraftWorkflowApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @get_snippet @edit_permission_required - def post(self, snippet: CustomizedSnippet): + def post(self, current_user: Account, snippet: CustomizedSnippet): """Sync draft workflow for snippet.""" - current_user, _ = current_account_with_tenant() - payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {}) try: @@ -239,11 +240,11 @@ class SnippetPublishedWorkflowApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @get_snippet @edit_permission_required - def post(self, snippet: CustomizedSnippet): + def post(self, current_user: Account, snippet: CustomizedSnippet): """Publish snippet workflow.""" - current_user, _ = current_account_with_tenant() snippet_service = _snippet_service() with Session(db.engine) as session: @@ -331,11 +332,11 @@ class SnippetDraftWorkflowRestoreApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @get_snippet @edit_permission_required - def post(self, snippet: CustomizedSnippet, workflow_id: str): + def post(self, current_user: Account, snippet: CustomizedSnippet, workflow_id: str): """Restore a published snippet workflow version into the draft workflow.""" - current_user, _ = current_account_with_tenant() snippet_service = _snippet_service() try: @@ -455,16 +456,16 @@ class SnippetDraftNodeRunApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @get_snippet @edit_permission_required - def post(self, snippet: CustomizedSnippet, node_id: str): + def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str): """ Run a single node in snippet draft workflow. Executes a specific node with provided inputs for single-step debugging. Returns the node execution result including status, outputs, and timing. """ - current_user, _ = current_account_with_tenant() payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {}) user_inputs = payload.inputs @@ -539,16 +540,16 @@ class SnippetDraftRunIterationNodeApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @get_snippet @edit_permission_required - def post(self, snippet: CustomizedSnippet, node_id: str): + def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str): """ Run a draft workflow iteration node for snippet. Iteration nodes execute their internal sub-graph multiple times over an input list. Returns an SSE event stream with iteration progress and results. """ - current_user, _ = current_account_with_tenant() args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) try: @@ -580,16 +581,16 @@ class SnippetDraftRunLoopNodeApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @get_snippet @edit_permission_required - def post(self, snippet: CustomizedSnippet, node_id: str): + def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str): """ Run a draft workflow loop node for snippet. Loop nodes execute their internal sub-graph repeatedly until a condition is met. Returns an SSE event stream with loop progress and results. """ - current_user, _ = current_account_with_tenant() args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {}) try: @@ -619,17 +620,16 @@ class SnippetDraftWorkflowRunApi(Resource): @setup_required @login_required @account_initialization_required + @with_current_user @get_snippet @edit_permission_required - def post(self, snippet: CustomizedSnippet): + def post(self, current_user: Account, snippet: CustomizedSnippet): """ Run draft workflow for snippet. Executes the snippet's draft workflow with the provided inputs and returns an SSE event stream with execution progress and results. """ - current_user, _ = current_account_with_tenant() - payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index c41bf99563..a30a54b945 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -13,13 +13,19 @@ from controllers.common.fields import SuccessResponse from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required -from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + is_admin_or_owner_required, + setup_required, + with_current_tenant_id, + with_current_user, +) from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.plugin_service import PluginService from fields.base import ResponseModel from graphon.model_runtime.utils.encoders import jsonable_encoder -from libs.login import current_account_with_tenant, login_required -from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission +from libs.login import login_required +from models.account import Account, TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_permission_service import PluginPermissionService @@ -200,9 +206,8 @@ class PluginDebuggingKeyApi(Resource): @login_required @account_initialization_required @plugin_permission_required(debug_required=True) - def get(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str): try: return { "key": PluginService.get_debugging_key(tenant_id), @@ -219,8 +224,8 @@ class PluginListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str): args = ParserList.model_validate(request.args.to_dict(flat=True)) try: plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) @@ -253,9 +258,8 @@ class PluginListInstallationsFromIdsApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str): args = ParserLatest.model_validate(console_ns.payload) try: @@ -288,10 +292,10 @@ class PluginAssetApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): + @with_current_tenant_id + def get(self, tenant_id: str): args = ParserAsset.model_validate(request.args.to_dict(flat=True)) - _, tenant_id = current_account_with_tenant() try: binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name) return send_file(io.BytesIO(binary), mimetype="application/octet-stream") @@ -305,9 +309,8 @@ class PluginUploadFromPkgApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str): file = request.files["pkg"] content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE) try: @@ -325,9 +328,8 @@ class PluginUploadFromGithubApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str): args = ParserGithubUpload.model_validate(console_ns.payload) try: @@ -344,9 +346,8 @@ class PluginUploadFromBundleApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str): file = request.files["bundle"] content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE) try: @@ -364,8 +365,8 @@ class PluginInstallFromPkgApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, tenant_id: str): args = ParserPluginIdentifiers.model_validate(console_ns.payload) try: @@ -383,9 +384,8 @@ class PluginInstallFromGithubApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str): args = ParserGithubInstall.model_validate(console_ns.payload) try: @@ -409,9 +409,8 @@ class PluginInstallFromMarketplaceApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str): args = ParserPluginIdentifiers.model_validate(console_ns.payload) try: @@ -429,8 +428,8 @@ class PluginFetchMarketplacePkgApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def get(self): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str): args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) try: @@ -453,9 +452,8 @@ class PluginFetchManifestApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def get(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str): args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) try: @@ -473,9 +471,8 @@ class PluginFetchInstallTasksApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def get(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str): args = ParserTasks.model_validate(request.args.to_dict(flat=True)) try: @@ -490,9 +487,8 @@ class PluginFetchInstallTaskApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def get(self, task_id: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str, task_id: str): try: return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) except PluginDaemonClientSideError as e: @@ -506,9 +502,8 @@ class PluginDeleteInstallTaskApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self, task_id: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str, task_id: str): try: return {"success": PluginService.delete_install_task(tenant_id, task_id)} except PluginDaemonClientSideError as e: @@ -522,9 +517,8 @@ class PluginDeleteAllInstallTaskItemsApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str): try: return {"success": PluginService.delete_all_install_task_items(tenant_id)} except PluginDaemonClientSideError as e: @@ -538,9 +532,8 @@ class PluginDeleteInstallTaskItemApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self, task_id: str, identifier: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str, task_id: str, identifier: str): try: return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} except PluginDaemonClientSideError as e: @@ -554,9 +547,8 @@ class PluginUpgradeFromMarketplaceApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str): args = ParserMarketplaceUpgrade.model_validate(console_ns.payload) try: @@ -576,9 +568,8 @@ class PluginUpgradeFromGithubApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str): args = ParserGithubUpgrade.model_validate(console_ns.payload) try: @@ -604,11 +595,10 @@ class PluginUninstallApi(Resource): @login_required @account_initialization_required @plugin_permission_required(install_required=True) - def post(self): + @with_current_tenant_id + def post(self, tenant_id: str): args = ParserUninstall.model_validate(console_ns.payload) - _, tenant_id = current_account_with_tenant() - try: return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)} except PluginDaemonClientSideError as e: @@ -622,16 +612,14 @@ class PluginChangePermissionApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - current_user, current_tenant_id = current_account_with_tenant() - user = current_user + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account): if not user.is_admin_or_owner: raise Forbidden() args = ParserPermissionChange.model_validate(console_ns.payload) - tenant_id = current_tenant_id - return { "success": PluginPermissionService.change_permission( tenant_id, args.install_permission, args.debug_permission @@ -644,9 +632,8 @@ class PluginFetchPermissionApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str): permission = PluginPermissionService.get_permission(tenant_id) if not permission: return jsonable_encoder( @@ -671,16 +658,15 @@ class PluginFetchDynamicSelectOptionsApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def get(self): - current_user, tenant_id = current_account_with_tenant() - user_id = current_user.id - + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, current_user: Account): args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) try: options = PluginParameterService.get_dynamic_select_options( tenant_id=tenant_id, - user_id=user_id, + user_id=current_user.id, plugin_id=args.plugin_id, provider=args.provider, action=args.action, @@ -701,17 +687,16 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, current_user: Account): """Fetch dynamic options using credentials directly (for edit mode).""" - current_user, tenant_id = current_account_with_tenant() - user_id = current_user.id - args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload) try: options = PluginParameterService.get_dynamic_select_options_with_credentials( tenant_id=tenant_id, - user_id=user_id, + user_id=current_user.id, plugin_id=args.plugin_id, provider=args.provider, action=args.action, @@ -731,8 +716,9 @@ class PluginChangePreferencesApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - user, tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account): if not user.is_admin_or_owner: raise Forbidden() @@ -780,9 +766,8 @@ class PluginFetchPreferencesApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str): permission = PluginPermissionService.get_permission(tenant_id) permission_dict = { "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, @@ -820,10 +805,9 @@ class PluginAutoUpgradeExcludePluginApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_tenant_id + def post(self, tenant_id: str): # exclude one single plugin - _, tenant_id = current_account_with_tenant() - args = ParserExcludePlugin.model_validate(console_ns.payload) return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)}) @@ -835,8 +819,8 @@ class PluginReadmeApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str): args = ParserReadme.model_validate(request.args.to_dict(flat=True)) return jsonable_encoder( {"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)} diff --git a/api/controllers/console/workspace/snippets.py b/api/controllers/console/workspace/snippets.py index 509dcd4584..4bec22e091 100644 --- a/api/controllers/console/workspace/snippets.py +++ b/api/controllers/console/workspace/snippets.py @@ -21,10 +21,13 @@ from controllers.console.wraps import ( account_initialization_required, edit_permission_required, setup_required, + with_current_tenant_id, + with_current_user, ) from extensions.ext_database import db from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from models.snippet import SnippetType from services.app_dsl_service import ImportStatus from services.snippet_dsl_service import SnippetDslService @@ -91,10 +94,9 @@ class CustomizedSnippetsApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): + @with_current_tenant_id + def get(self, current_tenant_id: str): """List customized snippets with pagination and search.""" - _, current_tenant_id = current_account_with_tenant() - query = SnippetListQuery.model_validate(_normalize_snippet_list_query_args(request.args)) snippet_service = _snippet_service() @@ -124,10 +126,10 @@ class CustomizedSnippetsApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account): """Create a new customized snippet.""" - current_user, current_tenant_id = current_account_with_tenant() - payload = CreateSnippetPayload.model_validate(console_ns.payload or {}) try: @@ -163,10 +165,9 @@ class CustomizedSnippetDetailApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, snippet_id: str): + @with_current_tenant_id + def get(self, current_tenant_id: str, snippet_id: str): """Get customized snippet details.""" - _, current_tenant_id = current_account_with_tenant() - snippet_service = _snippet_service() snippet = snippet_service.get_snippet_by_id( snippet_id=str(snippet_id), @@ -187,10 +188,10 @@ class CustomizedSnippetDetailApi(Resource): @login_required @account_initialization_required @edit_permission_required - def patch(self, snippet_id: str): + @with_current_user + @with_current_tenant_id + def patch(self, current_tenant_id: str, current_user: Account, snippet_id: str): """Update customized snippet.""" - current_user, current_tenant_id = current_account_with_tenant() - snippet_service = _snippet_service() snippet = snippet_service.get_snippet_by_id( snippet_id=str(snippet_id), @@ -231,10 +232,9 @@ class CustomizedSnippetDetailApi(Resource): @login_required @account_initialization_required @edit_permission_required - def delete(self, snippet_id: str): + @with_current_tenant_id + def delete(self, current_tenant_id: str, snippet_id: str): """Delete customized snippet.""" - _, current_tenant_id = current_account_with_tenant() - snippet_service = _snippet_service() snippet = snippet_service.get_snippet_by_id( snippet_id=str(snippet_id), @@ -266,10 +266,9 @@ class CustomizedSnippetExportApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, snippet_id: str): + @with_current_tenant_id + def get(self, current_tenant_id: str, snippet_id: str): """Export snippet as DSL.""" - _, current_tenant_id = current_account_with_tenant() - snippet_service = _snippet_service() snippet = snippet_service.get_snippet_by_id( snippet_id=str(snippet_id), @@ -312,9 +311,9 @@ class CustomizedSnippetImportApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self): + @with_current_user + def post(self, current_user: Account): """Import snippet from DSL.""" - current_user, _ = current_account_with_tenant() payload = SnippetImportPayload.model_validate(console_ns.payload or {}) with Session(db.engine) as session: @@ -350,10 +349,9 @@ class CustomizedSnippetImportConfirmApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, import_id: str): + @with_current_user + def post(self, current_user: Account, import_id: str): """Confirm a pending snippet import.""" - current_user, _ = current_account_with_tenant() - with Session(db.engine) as session: import_service = SnippetDslService(session) result = import_service.confirm_import(import_id=import_id, account=current_user) @@ -375,10 +373,9 @@ class CustomizedSnippetCheckDependenciesApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, snippet_id: str): + @with_current_tenant_id + def get(self, current_tenant_id: str, snippet_id: str): """Check dependencies for a snippet.""" - _, current_tenant_id = current_account_with_tenant() - snippet_service = _snippet_service() snippet = snippet_service.get_snippet_by_id( snippet_id=str(snippet_id), @@ -406,10 +403,9 @@ class CustomizedSnippetUseCountIncrementApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, snippet_id: str): + @with_current_tenant_id + def post(self, current_tenant_id: str, snippet_id: str): """Increment snippet use count when it is inserted into a workflow.""" - _, current_tenant_id = current_account_with_tenant() - snippet_service = _snippet_service() snippet = snippet_service.get_snippet_by_id( snippet_id=str(snippet_id), diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index f9a6923f05..4f1fd6be0a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -18,6 +18,8 @@ from controllers.console.wraps import ( enterprise_license_required, is_admin_or_owner_required, setup_required, + with_current_tenant_id, + with_current_user, ) from core.db.session_factory import session_factory from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration @@ -30,7 +32,8 @@ from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToo from extensions.ext_database import db from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from models.provider_ids import ToolProviderID # from models.provider_ids import ToolProviderID @@ -286,15 +289,13 @@ class ToolProviderListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account): raw_args = request.args.to_dict() query = ToolProviderListQuery.model_validate(raw_args) - return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore + return ToolCommonService.list_tool_providers(user.id, tenant_id, query.type) # type: ignore @console_ns.route("/workspaces/current/tool-provider/builtin//tools") @@ -302,9 +303,8 @@ class ToolBuiltinProviderListToolsApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str, provider: str): return jsonable_encoder( BuiltinToolManageService.list_builtin_tool_provider_tools( tenant_id, @@ -318,9 +318,8 @@ class ToolBuiltinProviderInfoApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str, provider: str): return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) @@ -331,9 +330,8 @@ class ToolBuiltinProviderDeleteApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, tenant_id: str, provider: str): payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.delete_builtin_tool_provider( @@ -349,15 +347,13 @@ class ToolBuiltinProviderAddApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, provider: str): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account, provider: str): payload = BuiltinToolAddPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.add_builtin_tool_provider( - user_id=user_id, + user_id=user.id, tenant_id=tenant_id, provider=provider, credentials=payload.credentials, @@ -374,14 +370,13 @@ class ToolBuiltinProviderUpdateApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): - user, tenant_id = current_account_with_tenant() - user_id = user.id - + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account, provider: str): payload = BuiltinToolUpdatePayload.model_validate(console_ns.payload or {}) result = BuiltinToolManageService.update_builtin_tool_provider( - user_id=user_id, + user_id=user.id, tenant_id=tenant_id, provider=provider, credential_id=payload.credential_id, @@ -396,8 +391,9 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): - user, tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account, provider: str): # Optional list of credential IDs to include even if visibility would hide them # (used when a workflow/agent node still references another member's only_me credential). include_credential_ids = request.args.getlist("include_credential_ids") or [ @@ -430,15 +426,13 @@ class ToolApiProviderAddApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account): payload = ApiToolProviderAddPayload.model_validate(console_ns.payload or {}) return ApiToolManageService.create_api_tool_provider( - user_id, + user.id, tenant_id, payload.provider, payload.icon, @@ -456,16 +450,14 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account): raw_args = request.args.to_dict() query = UrlQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider_remote_schema( - user_id, + user.id, tenant_id, str(query.url), ) @@ -476,17 +468,15 @@ class ToolApiProviderListToolsApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account): raw_args = request.args.to_dict() query = ProviderQuery.model_validate(raw_args) return jsonable_encoder( ApiToolManageService.list_api_tool_provider_tools( - user_id, + user.id, tenant_id, query.provider, ) @@ -500,15 +490,13 @@ class ToolApiProviderUpdateApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account): payload = ApiToolProviderUpdatePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.update_api_tool_provider( - user_id, + user.id, tenant_id, payload.provider, payload.original_provider, @@ -529,15 +517,13 @@ class ToolApiProviderDeleteApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account): payload = ApiToolProviderDeletePayload.model_validate(console_ns.payload or {}) return ApiToolManageService.delete_api_tool_provider( - user_id, + user.id, tenant_id, payload.provider, ) @@ -548,16 +534,14 @@ class ToolApiProviderGetApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account): raw_args = request.args.to_dict() query = ProviderQuery.model_validate(raw_args) return ApiToolManageService.get_api_tool_provider( - user_id, + user.id, tenant_id, query.provider, ) @@ -568,9 +552,8 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider, credential_type): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str, provider, credential_type): return jsonable_encoder( BuiltinToolManageService.list_builtin_provider_credentials_schema( provider, CredentialType.of(credential_type), tenant_id @@ -598,9 +581,9 @@ class ToolApiProviderPreviousTestApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_tenant_id + def post(self, current_tenant_id: str): payload = ApiToolTestPayload.model_validate(console_ns.payload or {}) - _, current_tenant_id = current_account_with_tenant() return ApiToolManageService.test_api_tool_preview( current_tenant_id, payload.provider_name or "", @@ -619,15 +602,13 @@ class ToolWorkflowProviderCreateApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account): payload = WorkflowToolCreatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.create_workflow_tool( - user_id=user_id, + user_id=user.id, tenant_id=tenant_id, workflow_app_id=payload.workflow_app_id, name=payload.name, @@ -647,14 +628,13 @@ class ToolWorkflowProviderUpdateApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - user, tenant_id = current_account_with_tenant() - user_id = user.id - + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account): payload = WorkflowToolUpdatePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.update_workflow_tool( - user_id, + user.id, tenant_id, payload.workflow_tool_id, payload.name, @@ -674,15 +654,13 @@ class ToolWorkflowProviderDeleteApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account): payload = WorkflowToolDeletePayload.model_validate(console_ns.payload or {}) return WorkflowToolManageService.delete_workflow_tool( - user_id, + user.id, tenant_id, payload.workflow_tool_id, ) @@ -693,23 +671,21 @@ class ToolWorkflowProviderGetApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account): raw_args = request.args.to_dict() query = WorkflowToolGetQuery.model_validate(raw_args) if query.workflow_tool_id: tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( - user_id, + user.id, tenant_id, query.workflow_tool_id, ) elif query.workflow_app_id: tool = WorkflowToolManageService.get_workflow_tool_by_app_id( - user_id, + user.id, tenant_id, query.workflow_app_id, ) @@ -724,17 +700,15 @@ class ToolWorkflowProviderListToolApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account): raw_args = request.args.to_dict() query = WorkflowToolListQuery.model_validate(raw_args) return jsonable_encoder( WorkflowToolManageService.list_single_workflow_tools( - user_id, + user.id, tenant_id, query.workflow_tool_id, ) @@ -746,16 +720,14 @@ class ToolBuiltinListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account): return jsonable_encoder( [ provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( - user_id, + user.id, tenant_id, ) ] @@ -767,9 +739,8 @@ class ToolApiListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str): return jsonable_encoder( [ provider.to_dict() @@ -785,16 +756,14 @@ class ToolWorkflowListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - user, tenant_id = current_account_with_tenant() - - user_id = user.id - + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account): return jsonable_encoder( [ provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( - user_id, + user.id, tenant_id, ) ] @@ -817,13 +786,13 @@ class ToolPluginOAuthApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def get(self, provider: str): + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account, provider: str): tool_provider = ToolProviderID(provider) plugin_id = tool_provider.plugin_id provider_name = tool_provider.provider_name - user, tenant_id = current_account_with_tenant() - oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) if oauth_client_params is None: raise Forbidden("no oauth available client config found for this tool provider") @@ -912,8 +881,8 @@ class ToolBuiltinProviderSetDefaultApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, current_tenant_id: str, provider: str): payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.set_default_provider( tenant_id=current_tenant_id, provider=provider, id=payload.id @@ -927,11 +896,10 @@ class ToolOAuthCustomClient(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): + @with_current_tenant_id + def post(self, tenant_id: str, provider: str): payload = ToolOAuthCustomClientPayload.model_validate(console_ns.payload or {}) - _, tenant_id = current_account_with_tenant() - return BuiltinToolManageService.save_custom_oauth_client_params( tenant_id=tenant_id, provider=provider, @@ -944,8 +912,8 @@ class ToolOAuthCustomClient(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str, provider: str): return jsonable_encoder( BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider) ) @@ -953,8 +921,8 @@ class ToolOAuthCustomClient(Resource): @setup_required @login_required @account_initialization_required - def delete(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def delete(self, current_tenant_id: str, provider: str): return jsonable_encoder( BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider) ) @@ -965,8 +933,8 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str, provider: str): return jsonable_encoder( BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema( tenant_id=current_tenant_id, provider_name=provider @@ -979,8 +947,9 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): - user, tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account, provider: str): include_credential_ids = request.args.getlist("include_credential_ids") or [ s for s in (request.args.get("include_credential_ids") or "").split(",") if s ] @@ -1001,9 +970,10 @@ class ToolProviderMCPApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account): payload = MCPProviderCreatePayload.model_validate(console_ns.payload or {}) - user, tenant_id = current_account_with_tenant() # Parse and validate models configuration = MCPConfiguration.model_validate(payload.configuration or {}) @@ -1054,11 +1024,11 @@ class ToolProviderMCPApi(Resource): @setup_required @login_required @account_initialization_required - def put(self): + @with_current_tenant_id + def put(self, current_tenant_id: str): payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {}) configuration = MCPConfiguration.model_validate(payload.configuration or {}) authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None - _, current_tenant_id = current_account_with_tenant() # Step 1: Get provider data for URL validation (short-lived session, no network I/O) validation_data = None @@ -1107,9 +1077,9 @@ class ToolProviderMCPApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self): + @with_current_tenant_id + def delete(self, current_tenant_id: str): payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {}) - _, current_tenant_id = current_account_with_tenant() with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) @@ -1124,10 +1094,10 @@ class ToolMCPAuthApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): + @with_current_tenant_id + def post(self, tenant_id: str): payload = MCPAuthPayload.model_validate(console_ns.payload or {}) provider_id = payload.provider_id - _, tenant_id = current_account_with_tenant() with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) @@ -1197,8 +1167,8 @@ class ToolMCPDetailApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider_id: str): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, provider_id: str): with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id) @@ -1210,9 +1180,8 @@ class ToolMCPListAllApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, tenant_id: str): with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) # Skip sensitive data decryption for list view to improve performance @@ -1226,8 +1195,8 @@ class ToolMCPUpdateApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider_id: str): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, provider_id: str): with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) tools = service.list_provider_tools( diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 77b3c72e5e..d1d8e6fd75 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -4,6 +4,7 @@ from __future__ import annotations import json from datetime import datetime +from inspect import unwrap from types import SimpleNamespace from typing import TypedDict, Unpack from unittest.mock import MagicMock, patch @@ -37,7 +38,8 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import ( ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from libs.datetime_utils import naive_utc_now -from models.account import Account +from models.account import Account, TenantAccountRole +from models.dataset import Pipeline from models.workflow import Workflow from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -48,6 +50,14 @@ DEFAULT_WORKFLOW_CREATED_BY = "00000000-0000-0000-0000-000000000003" type WorkflowVariablePayload = dict[str, object] +def empty_mapping() -> dict[str, object]: + return {} + + +def empty_list() -> list[object]: + return [] + + class WorkflowFactoryPayload(TypedDict): id: str tenant_id: str @@ -86,14 +96,8 @@ class WorkflowFactoryOverrides(TypedDict, total=False): rag_pipeline_variables: list[WorkflowVariablePayload] -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - -def make_node_execution(**overrides): - payload = { +def make_node_execution(**overrides: object) -> SimpleNamespace: + payload: dict[str, object] = { "id": "node-exec-1", "index": 1, "predecessor_node_id": None, @@ -148,6 +152,27 @@ def make_workflow(**overrides: Unpack[WorkflowFactoryOverrides]) -> Workflow: return Workflow(**payload) +def make_account(*, id: str = "account-1", role: TenantAccountRole = TenantAccountRole.EDITOR) -> Account: + account = Account(name="Alice", email=f"{id}@example.com") + account.id = id + account.role = role + return account + + +def make_pipeline( + *, + id: str = "pipeline-1", + tenant_id: str = "tenant-1", + workflow_id: str | None = None, + is_published: bool = False, +) -> Pipeline: + pipeline = Pipeline(tenant_id=tenant_id, name="test-pipeline", description="test") + pipeline.id = id + pipeline.workflow_id = workflow_id + pipeline.is_published = is_published + return pipeline + + @pytest.fixture def workflow_author(db_session_with_containers: Session) -> Account: account = Account(name="Alice", email=f"alice-{uuid4()}@example.com") @@ -158,14 +183,14 @@ def workflow_author(db_session_with_containers: Session) -> Account: class TestDraftWorkflowApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_draft_success(self, app: Flask, workflow_author: Account): + def test_get_draft_success(self, app: Flask, workflow_author: Account) -> None: api = DraftRagPipelineApi() method = unwrap(api.get) - pipeline = MagicMock() + pipeline = make_pipeline() workflow = make_workflow(created_by=workflow_author.id) service = MagicMock() @@ -191,11 +216,11 @@ class TestDraftWorkflowApi: } assert result["updated_by"] is None - def test_get_draft_not_exist(self, app: Flask): + def test_get_draft_not_exist(self, app: Flask) -> None: api = DraftRagPipelineApi() method = unwrap(api.get) - pipeline = MagicMock() + pipeline = make_pipeline() service = MagicMock() service.get_draft_workflow.return_value = None @@ -209,54 +234,46 @@ class TestDraftWorkflowApi: with pytest.raises(DraftWorkflowNotExist): method(api, pipeline) - def test_sync_hash_not_match(self, app: Flask): + def test_sync_hash_not_match(self, app: Flask) -> None: api = DraftRagPipelineApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() service = MagicMock() service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError() with ( - app.test_request_context("/", json={"graph": {}, "features": {}}), - patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), + app.test_request_context("/", json={"graph": empty_mapping(), "features": empty_mapping()}), + patch.object(type(console_ns), "payload", {"graph": empty_mapping(), "features": empty_mapping()}), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, ), ): with pytest.raises(DraftWorkflowNotSync): - method(api, pipeline) + method(api, user, pipeline) - def test_sync_invalid_text_plain(self, app: Flask): + def test_sync_invalid_text_plain(self, app: Flask) -> None: api = DraftRagPipelineApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() with ( app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), ): - response, status = method(api, pipeline) + response, status = method(api, user, pipeline) assert status == 400 - def test_restore_published_workflow_to_draft_success(self, app: Flask): + def test_restore_published_workflow_to_draft_success(self, app: Flask) -> None: api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock(id="account-1") + pipeline = make_pipeline() + user = make_account(id="account-1") workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1)) service = MagicMock() @@ -264,50 +281,42 @@ class TestDraftWorkflowApi: with ( app.test_request_context("/", method="POST"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, ), ): - result = method(api, pipeline, "published-workflow") + result = method(api, user, pipeline, "published-workflow") assert result["result"] == "success" assert result["hash"] == "restored-hash" - def test_restore_published_workflow_to_draft_not_found(self, app: Flask): + def test_restore_published_workflow_to_draft_not_found(self, app: Flask) -> None: api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock(id="account-1") + pipeline = make_pipeline() + user = make_account(id="account-1") service = MagicMock() service.restore_published_workflow_to_draft.side_effect = WorkflowNotFoundError("Workflow not found") with ( app.test_request_context("/", method="POST"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, ), ): with pytest.raises(NotFound): - method(api, pipeline, "published-workflow") + method(api, user, pipeline, "published-workflow") - def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask): + def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask) -> None: api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock(id="account-1") + pipeline = make_pipeline() + user = make_account(id="account-1") service = MagicMock() service.restore_published_workflow_to_draft.side_effect = IsDraftWorkflowError( @@ -316,17 +325,13 @@ class TestDraftWorkflowApi: with ( app.test_request_context("/", method="POST"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, ), ): with pytest.raises(HTTPException) as exc: - method(api, pipeline, "draft-workflow") + method(api, user, pipeline, "draft-workflow") assert exc.value.code == 400 assert exc.value.description == "source workflow must be published" @@ -334,23 +339,19 @@ class TestDraftWorkflowApi: class TestDraftRunNodes: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_iteration_node_success(self, app: Flask): + def test_iteration_node_success(self, app: Flask) -> None: api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() with ( - app.test_request_context("/", json={"inputs": {}}), - patch.object(type(console_ns), "payload", {"inputs": {}}), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), + app.test_request_context("/", json={"inputs": empty_mapping()}), + patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration", return_value=MagicMock(), @@ -360,45 +361,37 @@ class TestDraftRunNodes: return_value={"ok": True}, ), ): - result = method(api, pipeline, "node") + result = method(api, user, pipeline, "node") assert result == {"ok": True} - def test_iteration_node_conversation_not_exists(self, app: Flask): + def test_iteration_node_conversation_not_exists(self, app: Flask) -> None: api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() with ( - app.test_request_context("/", json={"inputs": {}}), - patch.object(type(console_ns), "payload", {"inputs": {}}), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), + app.test_request_context("/", json={"inputs": empty_mapping()}), + patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration", side_effect=services.errors.conversation.ConversationNotExistsError(), ), ): with pytest.raises(NotFound): - method(api, pipeline, "node") + method(api, user, pipeline, "node") - def test_loop_node_success(self, app: Flask): + def test_loop_node_success(self, app: Flask) -> None: api = RagPipelineDraftRunLoopNodeApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() with ( - app.test_request_context("/", json={"inputs": {}}), - patch.object(type(console_ns), "payload", {"inputs": {}}), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), + app.test_request_context("/", json={"inputs": empty_mapping()}), + patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop", return_value=MagicMock(), @@ -408,35 +401,31 @@ class TestDraftRunNodes: return_value={"ok": True}, ), ): - assert method(api, pipeline, "node") == {"ok": True} + assert method(api, user, pipeline, "node") == {"ok": True} class TestPipelineRunApis: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_draft_run_success(self, app: Flask): + def test_draft_run_success(self, app: Flask) -> None: api = DraftRagPipelineRunApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() payload = { - "inputs": {}, + "inputs": empty_mapping(), "datasource_type": "x", - "datasource_info_list": [], + "datasource_info_list": empty_list(), "start_node_id": "n", } with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", return_value=MagicMock(), @@ -446,27 +435,27 @@ class TestPipelineRunApis: return_value={"ok": True}, ), ): - assert method(api, pipeline) == {"ok": True} + assert method(api, user, pipeline) == {"ok": True} - def test_draft_run_rate_limit(self, app: Flask): + def test_draft_run_rate_limit(self, app: Flask) -> None: api = DraftRagPipelineRunApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() + payload: dict[str, object] = { + "inputs": empty_mapping(), + "datasource_type": "x", + "datasource_info_list": empty_list(), + "start_node_id": "n", + } with ( - app.test_request_context( - "/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"} - ), + app.test_request_context("/", json=payload), patch.object( type(console_ns), "payload", - {"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}, - ), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), + payload, ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", @@ -474,48 +463,42 @@ class TestPipelineRunApis: ), ): with pytest.raises(InvokeRateLimitHttpError): - method(api, pipeline) + method(api, user, pipeline) class TestDraftNodeRun: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_execution_not_found(self, app: Flask): + def test_execution_not_found(self, app: Flask) -> None: api = RagPipelineDraftNodeRunApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() service = MagicMock() service.run_draft_workflow_node.return_value = None with ( - app.test_request_context("/", json={"inputs": {}}), - patch.object(type(console_ns), "payload", {"inputs": {}}), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), + app.test_request_context("/", json={"inputs": empty_mapping()}), + patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, ), ): with pytest.raises(ValueError): - method(api, pipeline, "node") + method(api, user, pipeline, "node") class TestPublishedPipelineApis: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_publish_success(self, app: Flask, db_session_with_containers: Session): - from models.dataset import Pipeline - + def test_publish_success(self, app: Flask, db_session_with_containers: Session) -> None: api = PublishedRagPipelineApi() method = unwrap(api.post) @@ -530,7 +513,7 @@ class TestPublishedPipelineApis: db_session_with_containers.commit() db_session_with_containers.expire_all() - user = MagicMock(id="u1") + user = make_account(id="u1") workflow = MagicMock( id=str(uuid4()), @@ -542,16 +525,12 @@ class TestPublishedPipelineApis: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, ), ): - result = method(api, pipeline) + result = method(api, user, pipeline) assert result["result"] == "success" assert "created_at" in result @@ -559,47 +538,39 @@ class TestPublishedPipelineApis: class TestMiscApis: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_task_stop(self, app: Flask): + def test_task_stop(self, app: Flask) -> None: api = RagPipelineTaskStopApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock(id="u1") + pipeline = make_pipeline() + user = make_account(id="u1") with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag" ) as stop_mock, ): - result = method(api, pipeline, "task-1") + result = method(api, user, pipeline, "task-1") stop_mock.assert_called_once() assert result["result"] == "success" - def test_transform_forbidden(self, app: Flask): + def test_transform_forbidden(self, app: Flask) -> None: api = RagPipelineTransformApi() method = unwrap(api.post) - user = MagicMock(has_edit_permission=False, is_dataset_operator=False) + user = make_account(role=TenantAccountRole.NORMAL) with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), ): with pytest.raises(Forbidden): - method(api, "ds1") + method(api, user, "ds1") - def test_recommended_plugins(self, app: Flask): + def test_recommended_plugins(self, app: Flask) -> None: api = RagPipelineRecommendedPluginApi() method = unwrap(api.get) @@ -619,20 +590,20 @@ class TestMiscApis: class TestPublishedRagPipelineRunApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_published_run_success(self, app: Flask): + def test_published_run_success(self, app: Flask) -> None: api = PublishedRagPipelineRunApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() payload = { - "inputs": {}, + "inputs": empty_mapping(), "datasource_type": "x", - "datasource_info_list": [], + "datasource_info_list": empty_list(), "start_node_id": "n", "response_mode": "blocking", } @@ -640,10 +611,6 @@ class TestPublishedRagPipelineRunApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", return_value=MagicMock(), @@ -653,49 +620,45 @@ class TestPublishedRagPipelineRunApi: return_value={"ok": True}, ), ): - result = method(api, pipeline) + result = method(api, user, pipeline) assert result == {"ok": True} - def test_published_run_rate_limit(self, app: Flask): + def test_published_run_rate_limit(self, app: Flask) -> None: api = PublishedRagPipelineRunApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() payload = { - "inputs": {}, + "inputs": empty_mapping(), "datasource_type": "x", - "datasource_info_list": [], + "datasource_info_list": empty_list(), "start_node_id": "n", } with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", side_effect=InvokeRateLimitError("limit"), ), ): with pytest.raises(InvokeRateLimitHttpError): - method(api, pipeline) + method(api, user, pipeline) class TestDefaultBlockConfigApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_block_config_success(self, app: Flask): + def test_get_block_config_success(self, app: Flask) -> None: api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) - pipeline = MagicMock() + pipeline = make_pipeline() service = MagicMock() service.get_default_block_config.return_value = {"k": "v"} @@ -710,11 +673,11 @@ class TestDefaultBlockConfigApi: result = method(api, pipeline, "llm") assert result == {"k": "v"} - def test_get_block_config_invalid_json(self, app: Flask): + def test_get_block_config_invalid_json(self, app: Flask) -> None: api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) - pipeline = MagicMock() + pipeline = make_pipeline() with app.test_request_context("/?q=bad-json"): with pytest.raises(ValueError): @@ -723,65 +686,57 @@ class TestDefaultBlockConfigApi: class TestPublishedAllRagPipelineApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_published_workflows_success(self, app: Flask): + def test_get_published_workflows_success(self, app: Flask) -> None: api = PublishedAllRagPipelineApi() method = unwrap(api.get) - pipeline = MagicMock() - user = MagicMock(id="u1") + pipeline = make_pipeline() + user = make_account(id="u1") service = MagicMock() service.get_all_published_workflow.return_value = ([make_workflow(id="w1")], False) with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, ), ): - result = method(api, pipeline) + result = method(api, user, pipeline) assert result["items"][0]["id"] == "w1" assert result["items"][0]["graph"] == {"nodes": [], "edges": []} assert result["has_more"] is False - def test_get_published_workflows_forbidden(self, app: Flask): + def test_get_published_workflows_forbidden(self, app: Flask) -> None: api = PublishedAllRagPipelineApi() method = unwrap(api.get) - pipeline = MagicMock() - user = MagicMock(id="u1") + pipeline = make_pipeline() + user = make_account(id="u1") with ( app.test_request_context("/?user_id=u2"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), ): with pytest.raises(Forbidden): - method(api, pipeline) + method(api, user, pipeline) class TestRagPipelineByIdApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_patch_success(self, app: Flask): + def test_patch_success(self, app: Flask) -> None: api = RagPipelineByIdApi() method = unwrap(api.patch) - pipeline = MagicMock(tenant_id="t1") - user = MagicMock(id="u1") + pipeline = make_pipeline(tenant_id="t1") + user = make_account(id="u1") workflow = make_workflow(id="w1", marked_name="test") @@ -793,44 +748,36 @@ class TestRagPipelineByIdApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, ), ): - result = method(api, pipeline, "w1") + result = method(api, user, pipeline, "w1") assert result["id"] == "w1" assert result["marked_name"] == "test" assert result["hash"] == workflow.unique_hash - def test_patch_no_fields(self, app: Flask): + def test_patch_no_fields(self, app: Flask) -> None: api = RagPipelineByIdApi() method = unwrap(api.patch) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() with ( app.test_request_context("/", json={}), - patch.object(type(console_ns), "payload", {}), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), + patch.object(type(console_ns), "payload", empty_mapping()), ): - result, status = method(api, pipeline, "w1") + result, status = method(api, user, pipeline, "w1") assert status == 400 - def test_delete_success(self, app: Flask): + def test_delete_success(self, app: Flask) -> None: api = RagPipelineByIdApi() method = unwrap(api.delete) - pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1") + pipeline = make_pipeline(tenant_id="t1", workflow_id="active-workflow") workflow_service = MagicMock() @@ -846,11 +793,11 @@ class TestRagPipelineByIdApi: workflow_service.delete_workflow.assert_called_once() assert result == (None, 204) - def test_delete_active_workflow_rejected(self, app: Flask): + def test_delete_active_workflow_rejected(self, app: Flask) -> None: api = RagPipelineByIdApi() method = unwrap(api.delete) - pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1") + pipeline = make_pipeline(tenant_id="t1", workflow_id="active-workflow") with app.test_request_context("/", method="DELETE"): with pytest.raises(BadRequest, match="currently in use by pipeline"): @@ -859,14 +806,14 @@ class TestRagPipelineByIdApi: class TestRagPipelineWorkflowLastRunApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_last_run_success(self, app: Flask): + def test_last_run_success(self, app: Flask) -> None: api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) - pipeline = MagicMock() + pipeline = make_pipeline() workflow = MagicMock() node_exec = make_node_execution() @@ -886,11 +833,11 @@ class TestRagPipelineWorkflowLastRunApi: assert result["inputs"] == {"query": "hello"} assert result["outputs"] == {"answer": "world"} - def test_last_run_not_found(self, app: Flask): + def test_last_run_not_found(self, app: Flask) -> None: api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) - pipeline = MagicMock() + pipeline = make_pipeline() service = MagicMock() service.get_draft_workflow.return_value = None @@ -908,19 +855,19 @@ class TestRagPipelineWorkflowLastRunApi: class TestRagPipelineDatasourceVariableApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_set_datasource_variables_success(self, app: Flask): + def test_set_datasource_variables_success(self, app: Flask) -> None: api = RagPipelineDatasourceVariableApi() method = unwrap(api.post) - pipeline = MagicMock() - user = MagicMock() + pipeline = make_pipeline() + user = make_account() payload = { "datasource_type": "db", - "datasource_info": {}, + "datasource_info": empty_mapping(), "start_node_id": "n1", "start_node_title": "Node", } @@ -931,15 +878,11 @@ class TestRagPipelineDatasourceVariableApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", - return_value=(user, "t"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", return_value=service, ), ): - result = method(api, pipeline) + result = method(api, user, pipeline) assert result["node_id"] == "n1" assert result["process_data"] == {} diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py index d5eb53ad99..8739ca28bd 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py @@ -3,12 +3,13 @@ from __future__ import annotations import json -from types import SimpleNamespace +from inspect import unwrap from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask.testing import FlaskClient +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from controllers.console.workspace.tool_providers import ( @@ -43,41 +44,56 @@ from controllers.console.workspace.tool_providers import ( ToolWorkflowProviderUpdateApi, is_valid_url, ) +from models.account import Account, TenantAccountRole from services.tools.mcp_tools_manage_service import ReconnectResult +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func +def empty_mapping() -> dict[str, object]: + return {} + + +def empty_list() -> list[object]: + return [] @pytest.fixture -def _mock_cache(): +def _mock_cache() -> None: return @pytest.fixture -def _mock_user_tenant(): +def _mock_user_tenant() -> None: return @pytest.fixture -def client(flask_app_with_containers: Flask): +def client(flask_app_with_containers: Flask) -> FlaskClient: return flask_app_with_containers.test_client() -@patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u1"), "t1"), - autospec=True, -) +def make_account(*, id: str = "u", role: TenantAccountRole = TenantAccountRole.EDITOR) -> Account: + account = Account(name="Alice", email=f"{id}@example.com") + account.id = id + account.role = role + return account + + @patch("controllers.console.workspace.tool_providers.sessionmaker", autospec=True) @patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True) @pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant") def test_create_mcp_provider_populates_tools( - mock_reconnect, mock_session, mock_current_account_with_tenant, client: FlaskClient -): + mock_reconnect: MagicMock, + mock_session: MagicMock, + client: FlaskClient, + db_session_with_containers: Session, +) -> None: + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + headers = authenticate_console_client(client, account) + # Arrange: reconnect returns tools immediately mock_reconnect.return_value = ReconnectResult( authed=True, @@ -104,21 +120,11 @@ def test_create_mcp_provider_populates_tools( "icon_background": "#000", "server_identifier": "demo-sid", "configuration": {"timeout": 5, "sse_read_timeout": 30}, - "headers": {}, - "authentication": {}, + "headers": empty_mapping(), + "authentication": empty_mapping(), } # Act with ( - patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check - patch( - "controllers.console.wraps.current_account_with_tenant", - return_value=(MagicMock(id="u1"), "t1"), - autospec=True, - ), - patch("libs.login.check_csrf_token", return_value=None, autospec=True), # bypass CSRF in login_required - patch( - "libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True), autospec=True - ), # login patch( "services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider", return_value={"id": "provider-1", "tools": [{"name": "ping"}]}, @@ -128,6 +134,7 @@ def test_create_mcp_provider_populates_tools( resp = client.post( "/console/api/workspaces/current/tool-provider/mcp", data=json.dumps(payload), + headers=headers, content_type="application/json", ) @@ -141,7 +148,7 @@ def test_create_mcp_provider_populates_tools( class TestUtils: - def test_is_valid_url(self): + def test_is_valid_url(self) -> None: assert is_valid_url("https://example.com") assert is_valid_url("http://example.com") assert not is_valid_url("") @@ -152,154 +159,121 @@ class TestUtils: class TestToolProviderListApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_success(self, app: Flask): + def test_get_success(self, app: Flask) -> None: api = ToolProviderListApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u1"), "t1"), - ), patch( "controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers", return_value=["p1"], ), ): - assert method(api) == ["p1"] + assert method(api, "t1", make_account(id="u1")) == ["p1"] class TestBuiltinProviderApis: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_list_tools(self, app: Flask): + def test_list_tools(self, app: Flask) -> None: api = ToolBuiltinProviderListToolsApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(None, "t1"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools", return_value=[{"a": 1}], ), ): - assert method(api, "provider") == [{"a": 1}] + assert method(api, "t1", "provider") == [{"a": 1}] - def test_info(self, app: Flask): + def test_info(self, app: Flask) -> None: api = ToolBuiltinProviderInfoApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(None, "t1"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info", return_value={"x": 1}, ), ): - assert method(api, "provider") == {"x": 1} + assert method(api, "t1", "provider") == {"x": 1} - def test_delete(self, app: Flask): + def test_delete(self, app: Flask) -> None: api = ToolBuiltinProviderDeleteApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"credential_id": "cid"}), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(None, "t1"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider", return_value={"result": "success"}, ), ): - assert method(api, "provider")["result"] == "success" + assert method(api, "t1", "provider")["result"] == "success" - def test_add_invalid_type(self, app: Flask): + def test_add_invalid_type(self, app: Flask) -> None: api = ToolBuiltinProviderAddApi() method = unwrap(api.post) with ( - app.test_request_context("/", json={"credentials": {}, "type": "invalid"}), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), + app.test_request_context("/", json={"credentials": empty_mapping(), "type": "invalid"}), ): with pytest.raises(ValueError): - method(api, "provider") + method(api, "t", make_account(), "provider") - def test_add_success(self, app: Flask): + def test_add_success(self, app: Flask) -> None: api = ToolBuiltinProviderAddApi() method = unwrap(api.post) - payload = {"credentials": {}, "type": "oauth2", "name": "n"} + payload = {"credentials": empty_mapping(), "type": "oauth2", "name": "n"} with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider", return_value={"id": 1}, ), ): - assert method(api, "provider")["id"] == 1 + assert method(api, "t", make_account(), "provider")["id"] == 1 - def test_update(self, app: Flask): + def test_update(self, app: Flask) -> None: api = ToolBuiltinProviderUpdateApi() method = unwrap(api.post) - payload = {"credential_id": "c1", "credentials": {}, "name": "n"} + payload = {"credential_id": "c1", "credentials": empty_mapping(), "name": "n"} with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider", return_value={"ok": True}, ), ): - assert method(api, "provider")["ok"] + assert method(api, "t", make_account(), "provider")["ok"] - def test_get_credentials(self, app: Flask): + def test_get_credentials(self, app: Flask) -> None: api = ToolBuiltinProviderGetCredentialsApi() method = unwrap(api.get) - mock_user = SimpleNamespace(id="user-1", is_admin_or_owner=False) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(mock_user, "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials", return_value={"k": "v"}, ), ): - assert method(api, "provider") == {"k": "v"} + assert method(api, "t", make_account(id="user-1"), "provider") == {"k": "v"} - def test_icon(self, app: Flask): + def test_icon(self, app: Flask) -> None: api = ToolBuiltinProviderIconApi() method = unwrap(api.get) @@ -313,208 +287,168 @@ class TestBuiltinProviderApis: response = method(api, "provider") assert response.mimetype == "image/png" - def test_credentials_schema(self, app: Flask): + def test_credentials_schema(self, app: Flask) -> None: api = ToolBuiltinProviderCredentialsSchemaApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema", return_value={"schema": {}}, ), ): - assert method(api, "provider", "oauth2") == {"schema": {}} + assert method(api, "t", "provider", "oauth2") == {"schema": {}} - def test_set_default_credential(self, app: Flask): + def test_set_default_credential(self, app: Flask) -> None: api = ToolBuiltinProviderSetDefaultApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"id": "c1"}), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider", return_value={"ok": True}, ), ): - assert method(api, "provider")["ok"] + assert method(api, "t", "provider")["ok"] - def test_get_credential_info(self, app: Flask): + def test_get_credential_info(self, app: Flask) -> None: api = ToolBuiltinProviderGetCredentialInfoApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info", return_value={"info": "x"}, ), ): - assert method(api, "provider") == {"info": "x"} + assert method(api, "t", make_account(), "provider") == {"info": "x"} - def test_get_oauth_client_schema(self, app: Flask): + def test_get_oauth_client_schema(self, app: Flask) -> None: api = ToolBuiltinProviderGetOauthClientSchemaApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema", return_value={"schema": {}}, ), ): - assert method(api, "provider") == {"schema": {}} + assert method(api, "t", "provider") == {"schema": {}} class TestApiProviderApis: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_add(self, app: Flask): + def test_add(self, app: Flask) -> None: api = ToolApiProviderAddApi() method = unwrap(api.post) payload = { - "credentials": {}, + "credentials": empty_mapping(), "schema_type": "openapi", "schema": "{}", "provider": "p", - "icon": {}, + "icon": empty_mapping(), } with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider", return_value={"id": 1}, ), ): - assert method(api)["id"] == 1 + assert method(api, "t", make_account())["id"] == 1 - def test_remote_schema(self, app: Flask): + def test_remote_schema(self, app: Flask) -> None: api = ToolApiProviderGetRemoteSchemaApi() method = unwrap(api.get) with ( app.test_request_context("/?url=http://x.com"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema", return_value={"schema": "x"}, ), ): - assert method(api)["schema"] == "x" + assert method(api, "t", make_account())["schema"] == "x" - def test_list_tools(self, app: Flask): + def test_list_tools(self, app: Flask) -> None: api = ToolApiProviderListToolsApi() method = unwrap(api.get) with ( app.test_request_context("/?provider=p"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools", return_value=[{"tool": 1}], ), ): - assert method(api) == [{"tool": 1}] + assert method(api, "t", make_account()) == [{"tool": 1}] - def test_update(self, app: Flask): + def test_update(self, app: Flask) -> None: api = ToolApiProviderUpdateApi() method = unwrap(api.post) payload = { - "credentials": {}, + "credentials": empty_mapping(), "schema_type": "openapi", "schema": "{}", "provider": "p", "original_provider": "o", - "icon": {}, + "icon": empty_mapping(), "privacy_policy": "", "custom_disclaimer": "", } with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider", return_value={"ok": True}, ), ): - assert method(api)["ok"] + assert method(api, "t", make_account())["ok"] - def test_delete(self, app: Flask): + def test_delete(self, app: Flask) -> None: api = ToolApiProviderDeleteApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"provider": "p"}), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider", return_value={"result": "success"}, ), ): - assert method(api)["result"] == "success" + assert method(api, "t", make_account())["result"] == "success" - def test_get(self, app: Flask): + def test_get(self, app: Flask) -> None: api = ToolApiProviderGetApi() method = unwrap(api.get) with ( app.test_request_context("/?provider=p"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider", return_value={"x": 1}, ), ): - assert method(api) == {"x": 1} + assert method(api, "t", make_account()) == {"x": 1} class TestWorkflowApis: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_create(self, app: Flask): + def test_create(self, app: Flask) -> None: api = ToolWorkflowProviderCreateApi() method = unwrap(api.post) @@ -523,24 +457,20 @@ class TestWorkflowApis: "name": "n", "label": "l", "description": "d", - "icon": {}, - "parameters": [], + "icon": empty_mapping(), + "parameters": empty_list(), } with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool", return_value={"id": 1}, ), ): - assert method(api)["id"] == 1 + assert method(api, "t", make_account())["id"] == 1 - def test_update_invalid(self, app: Flask): + def test_update_invalid(self, app: Flask) -> None: api = ToolWorkflowProviderUpdateApi() method = unwrap(api.post) @@ -549,61 +479,49 @@ class TestWorkflowApis: "name": "Tool", "label": "Tool Label", "description": "A tool", - "icon": {}, + "icon": empty_mapping(), } with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool", return_value={"ok": True}, ), ): - result = method(api) + result = method(api, "t", make_account()) assert result["ok"] - def test_delete(self, app: Flask): + def test_delete(self, app: Flask) -> None: api = ToolWorkflowProviderDeleteApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool", return_value={"ok": True}, ), ): - assert method(api)["ok"] + assert method(api, "t", make_account())["ok"] - def test_get_error(self, app: Flask): + def test_get_error(self, app: Flask) -> None: api = ToolWorkflowProviderGetApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), ): with pytest.raises(ValueError): - method(api) + method(api, "t", make_account()) class TestLists: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_builtin_list(self, app: Flask): + def test_builtin_list(self, app: Flask) -> None: api = ToolBuiltinListApi() method = unwrap(api.get) @@ -612,18 +530,14 @@ class TestLists: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools", return_value=[m], ), ): - assert method(api) == [{"x": 1}] + assert method(api, "t", make_account()) == [{"x": 1}] - def test_api_list(self, app: Flask): + def test_api_list(self, app: Flask) -> None: api = ToolApiListApi() method = unwrap(api.get) @@ -632,18 +546,14 @@ class TestLists: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(None, "t"), - ), patch( "controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools", return_value=[m], ), ): - assert method(api) == [{"x": 1}] + assert method(api, "t") == [{"x": 1}] - def test_workflow_list(self, app: Flask): + def test_workflow_list(self, app: Flask) -> None: api = ToolWorkflowListApi() method = unwrap(api.get) @@ -652,24 +562,20 @@ class TestLists: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools", return_value=[m], ), ): - assert method(api) == [{"x": 1}] + assert method(api, "t", make_account()) == [{"x": 1}] class TestLabels: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_labels(self, app: Flask): + def test_labels(self, app: Flask) -> None: api = ToolLabelsApi() method = unwrap(api.get) @@ -685,28 +591,24 @@ class TestLabels: class TestOAuth: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_oauth_no_client(self, app: Flask): + def test_oauth_no_client(self, app: Flask) -> None: api = ToolPluginOAuthApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(id="u"), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client", return_value=None, ), ): with pytest.raises(Forbidden): - method(api, "provider") + method(api, "t", make_account(), "provider") - def test_oauth_callback_no_cookie(self, app: Flask): + def test_oauth_callback_no_cookie(self, app: Flask) -> None: api = ToolOAuthCallback() method = unwrap(api.get) @@ -717,56 +619,44 @@ class TestOAuth: class TestOAuthCustomClient: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_save_custom_client(self, app: Flask): + def test_save_custom_client(self, app: Flask) -> None: api = ToolOAuthCustomClient() method = unwrap(api.post) with ( app.test_request_context("/", json={"client_params": {"a": 1}}), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params", return_value={"ok": True}, ), ): - assert method(api, "provider")["ok"] + assert method(api, "t", "provider")["ok"] - def test_get_custom_client(self, app: Flask): + def test_get_custom_client(self, app: Flask) -> None: api = ToolOAuthCustomClient() method = unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params", return_value={"client_id": "x"}, ), ): - assert method(api, "provider") == {"client_id": "x"} + assert method(api, "t", "provider") == {"client_id": "x"} - def test_delete_custom_client(self, app: Flask): + def test_delete_custom_client(self, app: Flask) -> None: api = ToolOAuthCustomClient() method = unwrap(api.delete) with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.tool_providers.current_account_with_tenant", - return_value=(MagicMock(), "t"), - ), patch( "controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params", return_value={"ok": True}, ), ): - assert method(api, "provider")["ok"] + assert method(api, "t", "provider")["ok"] diff --git a/api/tests/unit_tests/controllers/console/app/test_agent_app_workspace.py b/api/tests/unit_tests/controllers/console/app/test_agent_app_workspace.py index 352001710b..e31beb100b 100644 --- a/api/tests/unit_tests/controllers/console/app/test_agent_app_workspace.py +++ b/api/tests/unit_tests/controllers/console/app/test_agent_app_workspace.py @@ -1,8 +1,11 @@ from __future__ import annotations +from inspect import unwrap from types import SimpleNamespace +from typing import cast import pytest +from flask import Response from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError from clients.agent_backend.workspace_files_client import ( @@ -12,16 +15,10 @@ from clients.agent_backend.workspace_files_client import ( WorkspacePreviewResult, ) from controllers.console import agent_app_workspace as module +from models.model import App, AppMode, IconType from services.agent_app_workspace_service import AgentWorkspaceInspectorError -def _unwrapped_get(resource_cls): - func = resource_cls.get - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - class _AgentAppService: def __init__(self) -> None: self.calls: list[tuple[str, str, str, str, str]] = [] @@ -87,6 +84,20 @@ class _WorkflowService: return WorkspaceDownloadResult(path=path, size=3, truncated=False, content=b"abc") +def _app_model(app_id: str = "app-1") -> App: + return App( + id=app_id, + tenant_id="tenant-1", + name="App", + mode=AppMode.AGENT, + icon_type=IconType.EMOJI, + icon="bot", + icon_background="#fff", + enable_site=False, + enable_api=False, + ) + + def test_handle_maps_workspace_and_agent_backend_errors() -> None: assert module._handle(AgentWorkspaceInspectorError("no_sandbox", "no sandbox", status_code=404)) == ( {"code": "no_sandbox", "message": "no sandbox"}, @@ -108,8 +119,11 @@ def test_handle_maps_workspace_and_agent_backend_errors() -> None: def test_download_response_returns_binary_or_too_large_error() -> None: - response = module._download_response( - WorkspaceDownloadResult(path="dir/report.txt", size=3, truncated=False, content=b"abc") + response = cast( + Response, + module._download_response( + WorkspaceDownloadResult(path="dir/report.txt", size=3, truncated=False, content=b"abc") + ), ) assert response.status_code == 200 @@ -133,17 +147,16 @@ def test_download_response_returns_binary_or_too_large_error() -> None: def test_agent_app_workspace_resources_proxy_service(monkeypatch: pytest.MonkeyPatch) -> None: service = _AgentAppService() monkeypatch.setattr(module, "AgentAppWorkspaceService", lambda: service) - monkeypatch.setattr(module, "current_account_with_tenant", lambda: (None, "tenant-1")) monkeypatch.setattr( module, "query_params_from_request", lambda model: SimpleNamespace(conversation_id="conv-1", path="sub/report.txt"), ) - app_model = SimpleNamespace(id="app-1") + app_model = _app_model() - listing = _unwrapped_get(module.AgentAppWorkspaceListResource)(object(), app_model) - preview = _unwrapped_get(module.AgentAppWorkspacePreviewResource)(object(), app_model) - download = _unwrapped_get(module.AgentAppWorkspaceDownloadResource)(object(), app_model) + listing = unwrap(module.AgentAppWorkspaceListResource.get)(object(), "tenant-1", app_model) + preview = unwrap(module.AgentAppWorkspacePreviewResource.get)(object(), "tenant-1", app_model) + download = unwrap(module.AgentAppWorkspaceDownloadResource.get)(object(), "tenant-1", app_model) assert listing["entries"][0]["name"] == "a.txt" assert preview["text"] == "hello" @@ -161,14 +174,13 @@ def test_agent_app_workspace_resource_returns_normalized_errors(monkeypatch: pyt raise AgentWorkspaceInspectorError("no_active_session", "no active session", status_code=404) monkeypatch.setattr(module, "AgentAppWorkspaceService", FailingService) - monkeypatch.setattr(module, "current_account_with_tenant", lambda: (None, "tenant-1")) monkeypatch.setattr( module, "query_params_from_request", lambda model: SimpleNamespace(conversation_id="conv-1", path="."), ) - assert _unwrapped_get(module.AgentAppWorkspaceListResource)(object(), SimpleNamespace(id="app-1")) == ( + assert unwrap(module.AgentAppWorkspaceListResource.get)(object(), "tenant-1", _app_model()) == ( {"code": "no_active_session", "message": "no active session"}, 404, ) @@ -177,17 +189,22 @@ def test_agent_app_workspace_resource_returns_normalized_errors(monkeypatch: pyt def test_workflow_agent_workspace_resources_proxy_service(monkeypatch: pytest.MonkeyPatch) -> None: service = _WorkflowService() monkeypatch.setattr(module, "WorkflowAgentWorkspaceService", lambda: service) - monkeypatch.setattr(module, "current_account_with_tenant", lambda: (None, "tenant-1")) monkeypatch.setattr( module, "query_params_from_request", lambda model: SimpleNamespace(node_execution_id="exec-1", path="out.txt"), ) - app_model = SimpleNamespace(id="app-1") + app_model = _app_model() - listing = _unwrapped_get(module.WorkflowAgentWorkspaceListResource)(object(), app_model, "run-1", "agent-node") - preview = _unwrapped_get(module.WorkflowAgentWorkspacePreviewResource)(object(), app_model, "run-1", "agent-node") - download = _unwrapped_get(module.WorkflowAgentWorkspaceDownloadResource)(object(), app_model, "run-1", "agent-node") + listing = unwrap(module.WorkflowAgentWorkspaceListResource.get)( + object(), "tenant-1", app_model, "run-1", "agent-node" + ) + preview = unwrap(module.WorkflowAgentWorkspacePreviewResource.get)( + object(), "tenant-1", app_model, "run-1", "agent-node" + ) + download = unwrap(module.WorkflowAgentWorkspaceDownloadResource.get)( + object(), "tenant-1", app_model, "run-1", "agent-node" + ) assert listing["path"] == "out.txt" assert preview["text"] == "hello" diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index a03f09e91e..272337d0c5 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -542,7 +542,6 @@ def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeyp app_id_2 = "22222222-2222-2222-2222-222222222222" signed_avatar_url = "https://files.example.com/signed/avatar-1" sign_avatar = Mock(return_value=signed_avatar_url) - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) monkeypatch.setattr( workflow_module, "WorkflowService", @@ -596,7 +595,7 @@ def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeyp method="POST", json={"app_ids": [app_id_1, app_id_2]}, ): - response = handler(api) + response = handler(api, "tenant-1") assert response == { "data": [ @@ -625,7 +624,6 @@ def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeyp def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: app_ids = [f"wf-{index}" for index in range(workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE + 1)] - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) monkeypatch.setattr( workflow_module, "WorkflowService", @@ -647,7 +645,7 @@ def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pyte method="POST", json={"app_ids": app_ids}, ): - response = handler(api) + response = handler(api, "tenant-1") assert len(response["data"]) == len(app_ids) assert redis_pipeline_factory.call_count == 2 @@ -656,7 +654,6 @@ def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pyte def test_workflow_online_users_rejects_excessive_workflow_ids(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) accessible_app_ids = Mock(return_value=set()) monkeypatch.setattr( workflow_module, @@ -675,7 +672,7 @@ def test_workflow_online_users_rejects_excessive_workflow_ids(app: Flask, monkey json={"app_ids": excessive_ids}, ): with pytest.raises(HTTPException) as exc: - handler(api) + handler(api, "tenant-1") assert exc.value.code == 400 assert exc.value.description is not None diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py index 564682b1b3..f04ab6d6e7 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py @@ -39,7 +39,6 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") monkeypatch.delenv("INIT_PASSWORD", raising=False) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 52e36fd521..a2748ad323 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -1,13 +1,15 @@ from __future__ import annotations from datetime import datetime -from inspect import unwrap +from inspect import unwrap as unwrap_all from types import SimpleNamespace from unittest.mock import PropertyMock, patch import pytest from controllers.console.datasets.rag_pipeline import rag_pipeline_workflow as module +from models.account import Account, TenantAccountRole +from models.dataset import Pipeline def _make_workflow(**overrides): @@ -33,6 +35,19 @@ def _make_workflow(**overrides): return workflow +def _account() -> Account: + account = Account(name="Alice", email="alice@example.com") + account.id = "user-1" + account.role = TenantAccountRole.EDITOR + return account + + +def _pipeline() -> Pipeline: + pipeline = Pipeline(tenant_id="tenant-1", name="Pipeline", description="desc") + pipeline.id = "pipeline-1" + return pipeline + + def test_draft_rag_pipeline_workflow_get_serializes_response_model(monkeypatch: pytest.MonkeyPatch) -> None: workflow = _make_workflow() monkeypatch.setattr( @@ -40,9 +55,9 @@ def test_draft_rag_pipeline_workflow_get_serializes_response_model(monkeypatch: ) api = module.DraftRagPipelineApi() - handler = unwrap(api.get) + handler = unwrap_all(api.get) - response = handler(api, pipeline=SimpleNamespace(id="pipeline-1")) + response = handler(api, _pipeline()) assert response["id"] == "workflow-1" assert response["graph"] == {"nodes": [], "edges": []} @@ -58,7 +73,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes( app, monkeypatch: pytest.MonkeyPatch ) -> None: api = module.PublishedAllRagPipelineApi() - handler = unwrap(api.get) + handler = unwrap_all(api.get) session_state = {"open": False} class _SessionContext: @@ -83,7 +98,6 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes( monkeypatch.setattr(module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr(module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker()) - monkeypatch.setattr(module, "current_account_with_tenant", lambda: (SimpleNamespace(id="user-1"), "tenant-1")) monkeypatch.setattr( module, "RagPipelineService", @@ -95,7 +109,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes( method="GET", query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"}, ): - response = handler(api, pipeline=SimpleNamespace(id="pipeline-1")) + response = handler(api, _account(), pipeline=_pipeline()) assert response["items"][0]["id"] == "workflow-1" assert response["page"] == 1 @@ -105,7 +119,6 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes( def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch: pytest.MonkeyPatch) -> None: workflow = _make_workflow(marked_name="Updated release") - monkeypatch.setattr(module, "current_account_with_tenant", lambda: (SimpleNamespace(id="user-1"), "tenant-1")) class _SessionContext: def __enter__(self): @@ -128,7 +141,7 @@ def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch: payload: dict[str, object] = {"marked_name": "Updated release"} api = module.RagPipelineByIdApi() - handler = unwrap(api.patch) + handler = unwrap_all(api.patch) with ( app.test_request_context("/rag/pipelines/pipeline-1/workflows/workflow-1", method="PATCH", json=payload), @@ -136,7 +149,8 @@ def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch: ): response = handler( api, - pipeline=SimpleNamespace(id="pipeline-1", tenant_id="tenant-1"), + _account(), + pipeline=_pipeline(), workflow_id="workflow-1", ) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index ff1a664515..101a640699 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -1,6 +1,7 @@ import datetime import json from contextlib import ExitStack +from inspect import unwrap from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -33,18 +34,13 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.provider_manager import ProviderManager from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType +from models.account import Account, TenantAccountRole from models.dataset import Dataset, DatasetQuery, Document from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus from models.model import ApiToken, App, AppMode, IconType, UploadFile from services.dataset_service import DatasetPermissionService, DatasetService -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - @pytest.fixture(autouse=True) def dataset_model_property_defaults(): properties: dict[str, object] = { @@ -98,6 +94,13 @@ def make_dataset(**overrides) -> Dataset: return Dataset(**base) +def make_account(role: TenantAccountRole = TenantAccountRole.EDITOR) -> Account: + account = Account(name="Test User", email="user@example.com") + account.id = "account-1" + account.role = role + return account + + def make_related_app(**overrides) -> App: base = { "id": "app-1", @@ -147,8 +150,7 @@ def make_document_status(**overrides) -> Document: class TestDatasetList: def _mock_user(self): - user = MagicMock() - user.is_dataset_editor = True + user = make_account() return user def test_get_success_basic(self, app: Flask): @@ -160,10 +162,6 @@ class TestDatasetList: with app.test_request_context("/datasets"): with ( - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_datasets", @@ -175,7 +173,7 @@ class TestDatasetList: return_value=MagicMock(get_models=lambda **_: []), ), ): - resp, status = method(api) + resp, status = method(api, "tenant-1", current_user) assert status == 200 assert resp["total"] == 1 @@ -196,10 +194,6 @@ class TestDatasetList: with app.test_request_context("/datasets?ids=1&ids=2"): with ( - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_datasets_by_ids", @@ -211,7 +205,7 @@ class TestDatasetList: return_value=MagicMock(get_models=lambda **_: []), ), ): - resp, status = method(api) + resp, status = method(api, "tenant-1", current_user) by_ids_mock.assert_called_once() assert status == 200 @@ -226,10 +220,6 @@ class TestDatasetList: with app.test_request_context("/datasets?tag_ids=tag1"): with ( - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_datasets", @@ -241,7 +231,7 @@ class TestDatasetList: return_value=MagicMock(get_models=lambda **_: []), ), ): - resp, status = method(api) + resp, status = method(api, "tenant-1", current_user) assert status == 200 @@ -274,10 +264,6 @@ class TestDatasetList: with app.test_request_context("/datasets"): with ( - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_datasets", @@ -289,7 +275,7 @@ class TestDatasetList: return_value=MagicMock(get_models=lambda **_: []), ), ): - resp, status = method(api) + resp, status = method(api, "tenant-1", current_user) assert status == 200 assert resp["data"][0]["retrieval_model_dict"]["weights"]["weight_type"] is None @@ -303,10 +289,6 @@ class TestDatasetList: with app.test_request_context("/datasets"): with ( - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_datasets", @@ -318,7 +300,7 @@ class TestDatasetList: return_value=MagicMock(get_models=lambda **_: []), ), ): - resp, status = method(api) + resp, status = method(api, "tenant-1", current_user) assert status == 200 retrieval_model = resp["data"][0]["retrieval_model_dict"] @@ -345,10 +327,6 @@ class TestDatasetList: with app.test_request_context("/datasets"): with ( - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_datasets", @@ -360,7 +338,7 @@ class TestDatasetList: return_value=config, ), ): - resp, status = method(api) + resp, status = method(api, "tenant-1", current_user) assert resp["data"][0]["embedding_available"] is False @@ -373,10 +351,6 @@ class TestDatasetList: with app.test_request_context("/datasets"): with ( - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_datasets", @@ -392,7 +366,7 @@ class TestDatasetList: return_value=MagicMock(get_models=lambda **_: []), ), ): - resp, status = method(api) + resp, status = method(api, "tenant-1", current_user) assert resp["data"][0]["partial_member_list"] == ["u1"] @@ -409,25 +383,20 @@ class TestDatasetListApiPost: "provider": "vendor", } - user = MagicMock() - user.is_dataset_editor = True + user = make_account() dataset = make_dataset(name=payload["name"], description=payload["description"]) with ( app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch.object( DatasetService, "create_empty_dataset", return_value=dataset, ), ): - _, status = method(api) + _, status = method(api, "tenant-1", user) assert status == 201 @@ -437,19 +406,14 @@ class TestDatasetListApiPost: payload = {"name": "test"} - user = MagicMock() - user.is_dataset_editor = False + user = make_account(TenantAccountRole.NORMAL) with ( app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), ): with pytest.raises(Forbidden): - method(api) + method(api, "tenant-1", user) def test_post_duplicate_name(self, app: Flask): api = DatasetListApi() @@ -457,16 +421,11 @@ class TestDatasetListApiPost: payload = {"name": "duplicate"} - user = MagicMock() - user.is_dataset_editor = True + user = make_account() with ( app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch.object( DatasetService, "create_empty_dataset", @@ -474,7 +433,7 @@ class TestDatasetListApiPost: ), ): with pytest.raises(DatasetNameDuplicateError): - method(api) + method(api, "tenant-1", user) def test_post_invalid_payload_missing_name(self, app: Flask): api = DatasetListApi() @@ -482,7 +441,7 @@ class TestDatasetListApiPost: with app.test_request_context("/datasets", json={}), patch.object(type(console_ns), "payload", {}): with pytest.raises(ValueError): - method(api) + method(api, "tenant-1", make_account()) def test_post_invalid_indexing_technique(self, app: Flask): api = DatasetListApi() @@ -495,7 +454,7 @@ class TestDatasetListApiPost: with app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload): with pytest.raises(ValueError, match="Invalid indexing technique"): - method(api) + method(api, "tenant-1", make_account()) def test_post_invalid_provider(self, app: Flask): api = DatasetListApi() @@ -508,7 +467,7 @@ class TestDatasetListApiPost: with app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload): with pytest.raises(ValueError, match="Invalid provider"): - method(api) + method(api, "tenant-1", make_account()) class TestDatasetApiGet: @@ -518,17 +477,13 @@ class TestDatasetApiGet: dataset_id = "123e4567-e89b-12d3-a456-426614174000" - user = MagicMock() + user = make_account() tenant_id = "tenant-1" dataset = make_dataset(id=dataset_id) with ( app.test_request_context(f"/datasets/{dataset_id}"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch.object( DatasetService, "get_dataset", @@ -544,7 +499,7 @@ class TestDatasetApiGet: # embedding models exist → embedding_available stays True provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] - data, status = method(api, dataset_id) + data, status = method(api, tenant_id, user, dataset_id) assert status == 200 assert data["embedding_available"] is True @@ -558,10 +513,6 @@ class TestDatasetApiGet: with ( app.test_request_context(f"/datasets/{dataset_id}"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant"), - ), patch.object( DatasetService, "get_dataset", @@ -576,7 +527,7 @@ class TestDatasetApiGet: ): provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] - data, status = method(api, dataset_id) + data, status = method(api, "tenant", make_account(), dataset_id) assert status == 200 assert data["external_retrieval_model"] == { @@ -593,10 +544,6 @@ class TestDatasetApiGet: with ( app.test_request_context(f"/datasets/{dataset_id}"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant"), - ), patch.object( DatasetService, "get_dataset", @@ -604,21 +551,17 @@ class TestDatasetApiGet: ), ): with pytest.raises(NotFound, match="Dataset not found"): - method(api, dataset_id) + method(api, "tenant", make_account(), dataset_id) def test_get_permission_denied(self, app: Flask): api = DatasetApi() method = unwrap(api.get) dataset_id = "dataset-id" - dataset = MagicMock() + dataset = make_dataset(id=dataset_id) with ( app.test_request_context(f"/datasets/{dataset_id}"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant"), - ), patch.object( DatasetService, "get_dataset", @@ -631,14 +574,14 @@ class TestDatasetApiGet: ), ): with pytest.raises(Forbidden, match="no access"): - method(api, dataset_id) + method(api, "tenant", make_account(), dataset_id) def test_get_high_quality_embedding_unavailable(self, app: Flask): api = DatasetApi() method = unwrap(api.get) dataset_id = "dataset-id" - user = MagicMock() + user = make_account() tenant_id = "tenant-1" dataset = make_dataset( @@ -650,10 +593,6 @@ class TestDatasetApiGet: with ( app.test_request_context(f"/datasets/{dataset_id}"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch.object( DatasetService, "get_dataset", @@ -669,7 +608,7 @@ class TestDatasetApiGet: # embedding model NOT configured provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] - data, _ = method(api, dataset_id) + data, _ = method(api, tenant_id, user, dataset_id) assert data["embedding_available"] is False @@ -685,10 +624,6 @@ class TestDatasetApiGet: with ( app.test_request_context(f"/datasets/{dataset_id}"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant"), - ), patch.object( DatasetService, "get_dataset", @@ -708,7 +643,7 @@ class TestDatasetApiGet: ): provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] - data, _ = method(api, dataset_id) + data, _ = method(api, "tenant", make_account(), dataset_id) assert data["partial_member_list"] == partial_members @@ -725,7 +660,7 @@ class TestDatasetApiPatch: "description": "updated description", } - user = MagicMock() + user = make_account() tenant_id = "tenant-1" dataset = make_dataset(id=dataset_id, tenant_id=tenant_id) @@ -733,10 +668,6 @@ class TestDatasetApiPatch: with ( app.test_request_context(f"/datasets/{dataset_id}"), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch.object( DatasetService, "get_dataset", @@ -758,7 +689,7 @@ class TestDatasetApiPatch: return_value=[], ), ): - result, status = method(api, dataset_id) + result, status = method(api, tenant_id, user, dataset_id) assert status == 200 assert result["partial_member_list"] == [] @@ -776,14 +707,14 @@ class TestDatasetApiPatch: ), ): with pytest.raises(NotFound, match="Dataset not found"): - method(api, "missing") + method(api, "tenant-1", make_account(), "missing") def test_patch_permission_denied(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) dataset_id = "dataset-id" - dataset = MagicMock() + dataset = make_dataset(id=dataset_id) payload = {"name": "x"} @@ -795,10 +726,6 @@ class TestDatasetApiPatch: "get_dataset", return_value=dataset, ), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant"), - ), patch.object( DatasetPermissionService, "check_permission", @@ -806,7 +733,7 @@ class TestDatasetApiPatch: ), ): with pytest.raises(Forbidden): - method(api, dataset_id) + method(api, "tenant", make_account(), dataset_id) def test_patch_partial_members_update(self, app: Flask): api = DatasetApi() @@ -824,10 +751,6 @@ class TestDatasetApiPatch: with ( app.test_request_context(f"/datasets/{dataset_id}"), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant"), - ), patch.object( DatasetService, "get_dataset", @@ -854,7 +777,7 @@ class TestDatasetApiPatch: return_value=["u1", "u2"], ), ): - result, _ = method(api, dataset_id) + result, _ = method(api, "tenant", make_account(), dataset_id) assert result["partial_member_list"] == ["u1", "u2"] @@ -873,10 +796,6 @@ class TestDatasetApiPatch: with ( app.test_request_context(f"/datasets/{dataset_id}"), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant"), - ), patch.object( DatasetService, "get_dataset", @@ -903,7 +822,7 @@ class TestDatasetApiPatch: return_value=[], ), ): - result, _ = method(api, dataset_id) + result, _ = method(api, "tenant", make_account(), dataset_id) assert result["partial_member_list"] == [] @@ -914,16 +833,10 @@ class TestDatasetApiDelete: method = unwrap(api.delete) dataset_id = "dataset-id" - user = MagicMock() - user.has_edit_permission = True - user.is_dataset_operator = False + user = make_account() with ( app.test_request_context(f"/datasets/{dataset_id}"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(user, "tenant"), - ), patch.object( DatasetService, "delete_dataset", @@ -935,7 +848,7 @@ class TestDatasetApiDelete: return_value=None, ), ): - result, status = method(api, dataset_id) + result, status = method(api, user, dataset_id) assert status == 204 assert result == "" @@ -945,35 +858,21 @@ class TestDatasetApiDelete: method = unwrap(api.delete) dataset_id = "dataset-id" - user = MagicMock() - user.has_edit_permission = False - user.is_dataset_operator = False + user = make_account(TenantAccountRole.NORMAL) - with ( - app.test_request_context(f"/datasets/{dataset_id}"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(user, "tenant"), - ), - ): + with app.test_request_context(f"/datasets/{dataset_id}"): with pytest.raises(Forbidden): - method(api, dataset_id) + method(api, user, dataset_id) def test_delete_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) dataset_id = "missing-dataset" - user = MagicMock() - user.has_edit_permission = True - user.is_dataset_operator = False + user = make_account() with ( app.test_request_context(f"/datasets/{dataset_id}"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(user, "tenant"), - ), patch.object( DatasetService, "delete_dataset", @@ -981,23 +880,17 @@ class TestDatasetApiDelete: ), ): with pytest.raises(NotFound, match="Dataset not found"): - method(api, dataset_id) + method(api, user, dataset_id) def test_delete_dataset_in_use(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) dataset_id = "dataset-id" - user = MagicMock() - user.has_edit_permission = True - user.is_dataset_operator = False + user = make_account() with ( app.test_request_context(f"/datasets/{dataset_id}"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(user, "tenant"), - ), patch.object( DatasetService, "delete_dataset", @@ -1005,7 +898,7 @@ class TestDatasetApiDelete: ), ): with pytest.raises(DatasetInUseError): - method(api, dataset_id) + method(api, user, dataset_id) class TestDatasetUseCheckApi: @@ -1076,19 +969,14 @@ class TestDatasetQueryApi: dataset_id = "dataset-id" - current_user = MagicMock() + current_user = make_account() - dataset = MagicMock() - dataset.id = dataset_id + dataset = make_dataset(id=dataset_id) queries = [self._query_record(1), self._query_record(2)] with ( app.test_request_context("/datasets/queries?page=1&limit=20"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_dataset", @@ -1105,7 +993,7 @@ class TestDatasetQueryApi: return_value=(queries, 2), ), ): - response, status = method(api, dataset_id) + response, status = method(api, current_user, dataset_id) assert status == 200 assert response["total"] == 2 @@ -1134,14 +1022,10 @@ class TestDatasetQueryApi: method = unwrap(api.get) dataset_id = "dataset-id" - current_user = MagicMock() + current_user = make_account() with ( app.test_request_context("/datasets/queries"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_dataset", @@ -1149,23 +1033,19 @@ class TestDatasetQueryApi: ), ): with pytest.raises(NotFound, match="Dataset not found"): - method(api, dataset_id) + method(api, current_user, dataset_id) def test_get_queries_permission_denied(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) dataset_id = "dataset-id" - current_user = MagicMock() + current_user = make_account() - dataset = MagicMock() + dataset = make_dataset(id=dataset_id) with ( app.test_request_context("/datasets/queries"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_dataset", @@ -1178,26 +1058,21 @@ class TestDatasetQueryApi: ), ): with pytest.raises(Forbidden): - method(api, dataset_id) + method(api, current_user, dataset_id) def test_get_queries_pagination_has_more(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) dataset_id = "dataset-id" - current_user = MagicMock() + current_user = make_account() - dataset = MagicMock() - dataset.id = dataset_id + dataset = make_dataset(id=dataset_id) queries = [self._query_record(index) for index in range(1, 21)] with ( app.test_request_context("/datasets/queries?page=1&limit=20"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(current_user, "tenant-1"), - ), patch.object( DatasetService, "get_dataset", @@ -1214,7 +1089,7 @@ class TestDatasetQueryApi: return_value=(queries, 40), ), ): - response, status = method(api, dataset_id) + response, status = method(api, current_user, dataset_id) assert status == 200 assert response["has_more"] is True @@ -1267,10 +1142,6 @@ class TestDatasetIndexingEstimateApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( type(console_ns), "payload", @@ -1290,7 +1161,7 @@ class TestDatasetIndexingEstimateApi: return_value=mock_response, ), ): - response, status = method(api) + response, status = method(api, "tenant-1") assert status == 200 assert response == {"tokens": 100} @@ -1303,10 +1174,6 @@ class TestDatasetIndexingEstimateApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( type(console_ns), "payload", @@ -1323,7 +1190,7 @@ class TestDatasetIndexingEstimateApi: ), ): with pytest.raises(NotFound): - method(api) + method(api, "tenant-1") def test_post_llm_bad_request_error(self, app: Flask): api = DatasetIndexingEstimateApi() @@ -1334,10 +1201,6 @@ class TestDatasetIndexingEstimateApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( type(console_ns), "payload", @@ -1358,7 +1221,7 @@ class TestDatasetIndexingEstimateApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(api) + method(api, "tenant-1") def test_post_provider_token_not_init(self, app: Flask): api = DatasetIndexingEstimateApi() @@ -1369,10 +1232,6 @@ class TestDatasetIndexingEstimateApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( type(console_ns), "payload", @@ -1393,7 +1252,7 @@ class TestDatasetIndexingEstimateApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(api) + method(api, "tenant-1") def test_post_generic_exception(self, app: Flask): api = DatasetIndexingEstimateApi() @@ -1404,10 +1263,6 @@ class TestDatasetIndexingEstimateApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( type(console_ns), "payload", @@ -1428,7 +1283,7 @@ class TestDatasetIndexingEstimateApi: ), ): with pytest.raises(IndexingEstimateError): - method(api) + method(api, "tenant-1") class TestDatasetRelatedAppListApi: @@ -1436,8 +1291,7 @@ class TestDatasetRelatedAppListApi: api = DatasetRelatedAppListApi() method = unwrap(api.get) - dataset = MagicMock() - dataset.id = "dataset-1" + dataset = make_dataset(id="dataset-1") app1 = make_related_app(id="app-1", name="App 1") app2 = make_related_app(id="app-2", name="App 2") @@ -1447,10 +1301,6 @@ class TestDatasetRelatedAppListApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, @@ -1464,7 +1314,7 @@ class TestDatasetRelatedAppListApi: return_value=[join1, join2], ), ): - response, status = method(api, "dataset-1") + response, status = method(api, make_account(), "dataset-1") assert status == 200 assert response["total"] == 2 @@ -1497,30 +1347,22 @@ class TestDatasetRelatedAppListApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=None, ), ): with pytest.raises(NotFound): - method(api, "dataset-1") + method(api, make_account(), "dataset-1") def test_get_permission_denied(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) - dataset = MagicMock() + dataset = make_dataset(id="dataset-1") with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, @@ -1531,14 +1373,13 @@ class TestDatasetRelatedAppListApi: ), ): with pytest.raises(Forbidden): - method(api, "dataset-1") + method(api, make_account(), "dataset-1") def test_get_filters_none_apps(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) - dataset = MagicMock() - dataset.id = "dataset-1" + dataset = make_dataset(id="dataset-1") app1 = make_related_app() @@ -1547,10 +1388,6 @@ class TestDatasetRelatedAppListApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, @@ -1564,7 +1401,7 @@ class TestDatasetRelatedAppListApi: return_value=[join1, join2], ), ): - response, status = method(api, "dataset-1") + response, status = method(api, make_account(), "dataset-1") assert status == 200 assert response["total"] == 1 @@ -1601,10 +1438,6 @@ class TestDatasetIndexingStatusApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: [document]), @@ -1614,7 +1447,7 @@ class TestDatasetIndexingStatusApi: return_value=3, ), ): - response, status = method(api, "dataset-1") + response, status = method(api, "tenant-1", "dataset-1") assert status == 200 assert "data" in response @@ -1630,16 +1463,12 @@ class TestDatasetIndexingStatusApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: []), ), ): - response, status = method(api, "dataset-1") + response, status = method(api, "tenant-1", "dataset-1") assert status == 200 assert response == {"data": []} @@ -1662,10 +1491,6 @@ class TestDatasetIndexingStatusApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: [document]), @@ -1675,7 +1500,7 @@ class TestDatasetIndexingStatusApi: side_effect=[2, 5], ), ): - response, status = method(api, "dataset-1") + response, status = method(api, "tenant-1", "dataset-1") assert status == 200 item = response["data"][0] @@ -1703,16 +1528,12 @@ class TestDatasetApiKeyApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: [mock_key_1, mock_key_2]), ), ): - response = method(api) + response = method(api, "tenant-1") assert "data" in response assert len(response["data"]) == 2 @@ -1736,10 +1557,6 @@ class TestDatasetApiKeyApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.db.session.scalar", return_value=3, @@ -1757,7 +1574,7 @@ class TestDatasetApiKeyApi: return_value=None, ), ): - response, status = method(api) + response, status = method(api, "tenant-1") assert status == 200 assert isinstance(response, dict) @@ -1772,17 +1589,13 @@ class TestDatasetApiKeyApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.db.session.scalar", return_value=10, ), ): with pytest.raises(BadRequest) as exc_info: - method(api) + method(api, "tenant-1") assert exc_info.value.code == 400 assert vars(exc_info.value)["data"] == { @@ -1800,10 +1613,6 @@ class TestDatasetApiDeleteApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.db.session.scalar", return_value=mock_key, @@ -1817,7 +1626,7 @@ class TestDatasetApiDeleteApi: return_value=None, ), ): - response, status = method(api, "api-key-id") + response, status = method(api, "tenant-1", "api-key-id") assert status == 204 assert response == "" @@ -1828,17 +1637,13 @@ class TestDatasetApiDeleteApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.db.session.scalar", return_value=None, ), ): with pytest.raises(NotFound): - method(api, "api-key-id") + method(api, "tenant-1", "api-key-id") class TestDatasetEnableApiApi: @@ -1965,7 +1770,7 @@ class TestDatasetErrorDocs: api = DatasetErrorDocs() method = unwrap(api.get) - dataset = MagicMock() + dataset = make_dataset(id="dataset-1") error_doc = make_document_status(id="error-doc", indexing_status=IndexingStatus.ERROR, error="failed") with ( @@ -2004,15 +1809,11 @@ class TestDatasetPermissionUserListApi: api = DatasetPermissionUserListApi() method = unwrap(api.get) - dataset = MagicMock() + dataset = make_dataset(id="dataset-1") users = ["u1", "u2"] with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, @@ -2026,7 +1827,7 @@ class TestDatasetPermissionUserListApi: return_value=users, ), ): - response, status = method(api, "dataset-1") + response, status = method(api, make_account(), "dataset-1") assert status == 200 assert response["data"] == users @@ -2035,14 +1836,10 @@ class TestDatasetPermissionUserListApi: api = DatasetPermissionUserListApi() method = unwrap(api.get) - dataset = MagicMock() + dataset = make_dataset(id="dataset-1") with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, @@ -2053,7 +1850,7 @@ class TestDatasetPermissionUserListApi: ), ): with pytest.raises(Forbidden): - method(api, "dataset-1") + method(api, make_account(), "dataset-1") class TestDatasetAutoDisableLogApi: @@ -2061,7 +1858,7 @@ class TestDatasetAutoDisableLogApi: api = DatasetAutoDisableLogApi() method = unwrap(api.get) - dataset = MagicMock() + dataset = make_dataset(id="dataset-1") logs = {"document_ids": ["doc-1"], "count": 1} with ( diff --git a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py index e43ee84bb0..cfc0299cc2 100644 --- a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py +++ b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow.py @@ -1,6 +1,7 @@ from __future__ import annotations from datetime import datetime +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock @@ -8,21 +9,34 @@ import pytest from werkzeug.exceptions import HTTPException, NotFound from controllers.console.snippets import snippet_workflow as snippet_workflow_module +from models.account import Account, TenantAccountRole +from models.snippet import CustomizedSnippet -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func +def _account(account_id: str = "account-1") -> Account: + account = Account(name="Test User", email=f"{account_id}@example.com") + account.id = account_id + account.role = TenantAccountRole.EDITOR + return account + + +def _snippet(**overrides) -> CustomizedSnippet: + data = { + "id": "snippet-1", + "tenant_id": "tenant-1", + "name": "Snippet", + "description": "Description", + "type": "node", + "created_by": "account-1", + } + data.update(overrides) + return CustomizedSnippet(**data) @pytest.fixture(autouse=True) def _patch_snippet_service_factory(monkeypatch: pytest.MonkeyPatch) -> None: def factory(): - service_factory = snippet_workflow_module.SnippetService - if isinstance(service_factory, type): - return service_factory.__new__(service_factory) - return service_factory() + return snippet_workflow_module.SnippetService() monkeypatch.setattr(snippet_workflow_module, "_snippet_service", factory) monkeypatch.setattr(snippet_workflow_module, "_snippet_session_maker", Mock(return_value=Mock())) @@ -39,7 +53,7 @@ def test_get_snippet_requires_snippet_id(app): def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPatch) -> None: - snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + snippet = _snippet() @snippet_workflow_module.get_snippet def view(**kwargs): @@ -48,7 +62,7 @@ def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPat monkeypatch.setattr( snippet_workflow_module, "current_account_with_tenant", - lambda: (SimpleNamespace(id="account-1"), "tenant-1"), + lambda: (_account("account-1"), "tenant-1"), ) monkeypatch.setattr(snippet_workflow_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) @@ -66,7 +80,7 @@ def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pyt monkeypatch.setattr( snippet_workflow_module, "current_account_with_tenant", - lambda: (SimpleNamespace(id="account-1"), "tenant-1"), + lambda: (_account("account-1"), "tenant-1"), ) monkeypatch.setattr(snippet_workflow_module.SnippetService, "get_snippet_by_id", Mock(return_value=None)) @@ -76,7 +90,7 @@ def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pyt def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyPatch) -> None: - snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + snippet = _snippet() monkeypatch.setattr( snippet_workflow_module, "SnippetService", @@ -84,7 +98,7 @@ def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyP ) api = snippet_workflow_module.SnippetDraftWorkflowApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/snippets/snippet-1/workflows/draft"): with pytest.raises(snippet_workflow_module.DraftWorkflowNotExist): @@ -92,10 +106,9 @@ def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyP def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pytest.MonkeyPatch) -> None: - user = SimpleNamespace(id="account-1") - snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + user = _account("account-1") + snippet = _snippet() sync_draft_workflow = Mock(side_effect=ValueError("invalid graph")) - monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr( snippet_workflow_module, "SnippetService", @@ -103,14 +116,14 @@ def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pyt ) api = snippet_workflow_module.SnippetDraftWorkflowApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/snippets/snippet-1/workflows/draft", method="POST", json={"graph": {"nodes": [], "edges": []}, "hash": "hash-1"}, ): - response, status_code = handler(api, snippet=snippet) + response, status_code = handler(api, user, snippet) assert status_code == 400 assert response == {"message": "invalid graph"} @@ -118,7 +131,7 @@ def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pyt def test_draft_config_returns_parallel_depth_limit(app) -> None: api = snippet_workflow_module.SnippetDraftConfigApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/snippets/snippet-1/workflows/draft/config"): assert handler(api, snippet=SimpleNamespace(id="snippet-1")) == {"parallel_depth_limit": 3} @@ -126,16 +139,16 @@ def test_draft_config_returns_parallel_depth_limit(app) -> None: def test_published_workflow_get_returns_none_when_not_published(app) -> None: api = snippet_workflow_module.SnippetPublishedWorkflowApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/snippets/snippet-1/workflows/publish"): assert handler(api, snippet=SimpleNamespace(id="snippet-1", is_published=False)) is None def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch: pytest.MonkeyPatch) -> None: - user = SimpleNamespace(id="account-1") - snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") - merged_snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + user = _account("account-1") + snippet = _snippet() + merged_snippet = _snippet() session = SimpleNamespace(merge=Mock(return_value=merged_snippet), commit=Mock()) class SessionContext: @@ -148,7 +161,6 @@ def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch def __exit__(self, exc_type, exc, tb): return False - monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr(snippet_workflow_module, "Session", SessionContext) monkeypatch.setattr(snippet_workflow_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr( @@ -158,10 +170,10 @@ def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch ) api = snippet_workflow_module.SnippetPublishedWorkflowApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context("/snippets/snippet-1/workflows/publish", method="POST", json={}): - response, status_code = handler(api, snippet=snippet) + response, status_code = handler(api, user, snippet) assert status_code == 400 assert response == {"message": "No valid workflow found."} @@ -177,7 +189,7 @@ def test_default_block_configs_delegates_to_service(app, monkeypatch: pytest.Mon ) api = snippet_workflow_module.SnippetDefaultBlockConfigsApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/snippets/snippet-1/workflows/default-workflow-block-configs"): result = handler(api, snippet=SimpleNamespace(id="snippet-1")) @@ -192,10 +204,9 @@ def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: p updated_at=None, created_at=datetime(2024, 1, 1), ) - user = SimpleNamespace(id="account-1") - snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + user = _account("account-1") + snippet = _snippet() - monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr( snippet_workflow_module, "SnippetService", @@ -203,23 +214,22 @@ def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: p ) api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/snippets/snippet-1/workflows/published-workflow/restore", method="POST", ): - response = handler(api, snippet=snippet, workflow_id="published-workflow") + response = handler(api, user, snippet, workflow_id="published-workflow") assert response["result"] == "success" assert response["hash"] == "restored-hash" def test_restore_published_snippet_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: - user = SimpleNamespace(id="account-1") - snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + user = _account("account-1") + snippet = _snippet() - monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr( snippet_workflow_module, "SnippetService", @@ -231,23 +241,22 @@ def test_restore_published_snippet_workflow_to_draft_not_found(app, monkeypatch: ) api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/snippets/snippet-1/workflows/published-workflow/restore", method="POST", ): with pytest.raises(NotFound): - handler(api, snippet=snippet, workflow_id="published-workflow") + handler(api, user, snippet, workflow_id="published-workflow") def test_restore_published_snippet_workflow_to_draft_returns_400_for_draft_source( app, monkeypatch: pytest.MonkeyPatch ) -> None: - user = SimpleNamespace(id="account-1") - snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + user = _account("account-1") + snippet = _snippet() - monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr( snippet_workflow_module, "SnippetService", @@ -259,14 +268,14 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_draft_sourc ) api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/snippets/snippet-1/workflows/draft-workflow/restore", method="POST", ): with pytest.raises(HTTPException) as exc: - handler(api, snippet=snippet, workflow_id="draft-workflow") + handler(api, user, snippet, workflow_id="draft-workflow") assert exc.value.code == 400 assert exc.value.description == snippet_workflow_module.RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE @@ -275,10 +284,9 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_draft_sourc def test_restore_published_snippet_workflow_to_draft_returns_400_for_invalid_graph( app, monkeypatch: pytest.MonkeyPatch ) -> None: - user = SimpleNamespace(id="account-1") - snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + user = _account("account-1") + snippet = _snippet() - monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr( snippet_workflow_module, "SnippetService", @@ -290,21 +298,21 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_invalid_gra ) api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/snippets/snippet-1/workflows/published-workflow/restore", method="POST", ): with pytest.raises(HTTPException) as exc: - handler(api, snippet=snippet, workflow_id="published-workflow") + handler(api, user, snippet, workflow_id="published-workflow") assert exc.value.code == 400 assert exc.value.description == "invalid snippet workflow graph" def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch: pytest.MonkeyPatch) -> None: - snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + snippet = _snippet() monkeypatch.setattr( snippet_workflow_module, "SnippetService", @@ -312,7 +320,7 @@ def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch: ) api = snippet_workflow_module.SnippetWorkflowRunDetailApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/snippets/snippet-1/workflow-runs/run-1"): with pytest.raises(NotFound, match="Workflow run not found"): @@ -320,7 +328,7 @@ def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch: def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkeypatch: pytest.MonkeyPatch) -> None: - snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1") + snippet = _snippet() draft_workflow = SimpleNamespace(id="workflow-1") monkeypatch.setattr( snippet_workflow_module, @@ -332,7 +340,7 @@ def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkey ) api = snippet_workflow_module.SnippetDraftNodeLastRunApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/snippets/snippet-1/workflows/draft/nodes/llm-1/last-run"): with pytest.raises(NotFound, match="Node last run not found"): @@ -354,7 +362,7 @@ def test_workflow_task_stop_uses_queue_flag_and_graph_command(app, monkeypatch: ) api = snippet_workflow_module.SnippetWorkflowTaskStopApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context("/snippets/snippet-1/workflow-runs/tasks/task-1/stop", method="POST"): result = handler(api, snippet=SimpleNamespace(id="snippet-1"), task_id="task-1") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py index 83915a0b74..da010558bc 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -1,4 +1,5 @@ import io +from inspect import unwrap from unittest.mock import MagicMock, patch import pytest @@ -39,21 +40,19 @@ from controllers.console.workspace.plugin import ( PluginUploadFromPkgApi, ) from core.plugin.impl.exc import PluginDaemonClientSideError -from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission +from models.account import Account, TenantAccountRole, TenantPluginAutoUpgradeStrategy, TenantPluginPermission -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func +def _account(role: TenantAccountRole = TenantAccountRole.OWNER) -> Account: + account = Account(name="Test User", email="u1@example.com") + account.id = "u1" + account.role = role + return account @pytest.fixture def user(): - u = MagicMock() - u.id = "u1" - u.is_admin_or_owner = True - return u + return _account() @pytest.fixture @@ -102,10 +101,9 @@ class TestPluginDebuggingKeyApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.get_debugging_key", return_value="k"), ): - result = method(api) + result = method(api, "t1") assert result["key"] == "k" @@ -115,13 +113,12 @@ class TestPluginDebuggingKeyApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.get_debugging_key", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -134,10 +131,9 @@ class TestPluginListApi: with ( app.test_request_context("/?page=1&page_size=10"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.list_with_total", return_value=mock_list), ): - result = method(api) + result = method(api, "t1") assert result["total"] == 1 @@ -163,10 +159,9 @@ class TestPluginAssetApi: with ( app.test_request_context("/?plugin_unique_identifier=p&file_name=a.bin"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.extract_asset", return_value=b"x"), ): - response = method(api) + response = method(api, "t1") assert response.mimetype == "application/octet-stream" @@ -182,10 +177,9 @@ class TestPluginUploadFromPkgApi: with ( app.test_request_context("/", data=data, content_type="multipart/form-data"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.upload_pkg", return_value={"ok": True}), ): - result = method(api) + result = method(api, "t1") assert result["ok"] is True @@ -199,12 +193,11 @@ class TestPluginUploadFromPkgApi: with ( app.test_request_context("/", data=data, content_type="multipart/form-data"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0), patch("controllers.console.workspace.plugin.PluginService.upload_pkg") as upload_pkg_mock, ): with pytest.raises(ValueError) as exc_info: - method(api) + method(api, "t1") assert "File size exceeds the maximum allowed size" in str(exc_info.value) upload_pkg_mock.assert_not_called() @@ -219,12 +212,11 @@ class TestPluginInstallFromPkgApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.install_from_local_pkg", return_value={"ok": True} ), ): - result = method(api) + result = method(api, "t1") assert result["ok"] is True @@ -238,10 +230,9 @@ class TestPluginUninstallApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.uninstall", return_value=True), ): - result = method(api) + result = method(api, "t1") assert result["success"] is True @@ -251,7 +242,7 @@ class TestPluginChangePermissionApi: api = PluginChangePermissionApi() method = unwrap(api.post) - user = MagicMock(is_admin_or_owner=False) + user = _account(TenantAccountRole.NORMAL) payload = { "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, @@ -260,16 +251,15 @@ class TestPluginChangePermissionApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), ): with pytest.raises(Forbidden): - method(api) + method(api, "t1", user) def test_change_permission_success(self, app: Flask): api = PluginChangePermissionApi() method = unwrap(api.post) - user = MagicMock(is_admin_or_owner=True) + user = _account() payload = { "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, @@ -278,10 +268,9 @@ class TestPluginChangePermissionApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True), ): - result = method(api) + result = method(api, "t1", user) assert result["success"] is True @@ -293,10 +282,9 @@ class TestPluginFetchPermissionApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=None), ): - result = method(api) + result = method(api, "t1") assert result["install_permission"] is not None @@ -308,13 +296,12 @@ class TestPluginFetchDynamicSelectOptionsApi: with ( app.test_request_context("/?plugin_id=p&provider=x&action=y¶meter=z&provider_type=tool"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), patch( "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options", return_value=[1, 2], ), ): - result = method(api) + result = method(api, "t1", user) assert result["options"] == [1, 2] @@ -326,16 +313,15 @@ class TestPluginReadmeApi: with ( app.test_request_context("/?plugin_unique_identifier=p"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_readme", return_value="readme"), ): - result = method(api) + result = method(api, "t1") assert result["readme"] == "readme" class TestPluginListInstallationsFromIdsApi: - def test_success(self, app: Flask): + def test_success(self, app: Flask, user): api = PluginListInstallationsFromIdsApi() method = unwrap(api.post) @@ -343,13 +329,12 @@ class TestPluginListInstallationsFromIdsApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.list_installations_from_ids", return_value=[{"id": "p1"}], ), ): - result = method(api) + result = method(api, "t1") assert "plugins" in result @@ -361,18 +346,17 @@ class TestPluginListInstallationsFromIdsApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.list_installations_from_ids", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUploadFromGithubApi: - def test_success(self, app: Flask): + def test_success(self, app: Flask, user): api = PluginUploadFromGithubApi() method = unwrap(api.post) @@ -380,12 +364,11 @@ class TestPluginUploadFromGithubApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.upload_pkg_from_github", return_value={"ok": True} ), ): - result = method(api) + result = method(api, "t1") assert result["ok"] is True @@ -397,13 +380,12 @@ class TestPluginUploadFromGithubApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.upload_pkg_from_github", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -424,10 +406,9 @@ class TestPluginUploadFromBundleApi: data={"bundle": file}, content_type="multipart/form-data", ), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.upload_bundle", return_value={"ok": True}), ): - result = method(api) + result = method(api, "t1") assert result["ok"] is True @@ -447,12 +428,11 @@ class TestPluginUploadFromBundleApi: data={"bundle": file}, content_type="multipart/form-data", ), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0), patch("controllers.console.workspace.plugin.PluginService.upload_bundle") as upload_bundle_mock, ): with pytest.raises(ValueError) as exc_info: - method(api) + method(api, "t1") assert "File size exceeds the maximum allowed size" in str(exc_info.value) upload_bundle_mock.assert_not_called() @@ -472,10 +452,9 @@ class TestPluginInstallFromGithubApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.install_from_github", return_value={"ok": True}), ): - result = method(api) + result = method(api, "t1") assert result["ok"] is True @@ -492,13 +471,12 @@ class TestPluginInstallFromGithubApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.install_from_github", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -511,13 +489,12 @@ class TestPluginInstallFromMarketplaceApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg", return_value={"ok": True}, ), ): - result = method(api) + result = method(api, "t1") assert result["ok"] is True @@ -529,13 +506,12 @@ class TestPluginInstallFromMarketplaceApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -546,10 +522,9 @@ class TestPluginFetchMarketplacePkgApi: with ( app.test_request_context("/?plugin_unique_identifier=p"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg", return_value={"m": 1}), ): - result = method(api) + result = method(api, "t1") assert "manifest" in result @@ -559,13 +534,12 @@ class TestPluginFetchMarketplacePkgApi: with ( app.test_request_context("/?plugin_unique_identifier=p"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -579,10 +553,9 @@ class TestPluginFetchManifestApi: with ( app.test_request_context("/?plugin_unique_identifier=p"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest", return_value=manifest), ): - result = method(api) + result = method(api, "t1") assert "manifest" in result @@ -592,13 +565,12 @@ class TestPluginFetchManifestApi: with ( app.test_request_context("/?plugin_unique_identifier=p"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -609,10 +581,9 @@ class TestPluginFetchInstallTasksApi: with ( app.test_request_context("/?page=1&page_size=10"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.fetch_install_tasks", return_value=[{"id": 1}]), ): - result = method(api) + result = method(api, "t1") assert "tasks" in result @@ -622,13 +593,12 @@ class TestPluginFetchInstallTasksApi: with ( app.test_request_context("/?page=1&page_size=10"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.fetch_install_tasks", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -639,10 +609,9 @@ class TestPluginFetchInstallTaskApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.fetch_install_task", return_value={"id": "x"}), ): - result = method(api, "x") + result = method(api, "t1", "x") assert "task" in result @@ -652,13 +621,12 @@ class TestPluginFetchInstallTaskApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.fetch_install_task", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api, "t") + result = method(api, "t1", "t") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -669,10 +637,9 @@ class TestPluginDeleteInstallTaskApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.delete_install_task", return_value=True), ): - result = method(api, "x") + result = method(api, "t1", "x") assert result["success"] is True @@ -682,13 +649,12 @@ class TestPluginDeleteInstallTaskApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.delete_install_task", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api, "t") + result = method(api, "t1", "t") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -699,12 +665,11 @@ class TestPluginDeleteAllInstallTaskItemsApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.delete_all_install_task_items", return_value=True ), ): - result = method(api) + result = method(api, "t1") assert result["success"] is True @@ -714,13 +679,12 @@ class TestPluginDeleteAllInstallTaskItemsApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.delete_all_install_task_items", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -731,10 +695,9 @@ class TestPluginDeleteInstallTaskItemApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginService.delete_install_task_item", return_value=True), ): - result = method(api, "task1", "item1") + result = method(api, "t1", "task1", "item1") assert result["success"] is True @@ -744,18 +707,17 @@ class TestPluginDeleteInstallTaskItemApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.delete_install_task_item", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api, "task1", "item1") + result = method(api, "t1", "task1", "item1") assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUpgradeFromMarketplaceApi: - def test_success(self, app: Flask): + def test_success(self, app: Flask, user): api = PluginUpgradeFromMarketplaceApi() method = unwrap(api.post) @@ -766,13 +728,12 @@ class TestPluginUpgradeFromMarketplaceApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace", return_value={"ok": True}, ), ): - result = method(api) + result = method(api, "t1") assert result["ok"] is True @@ -787,13 +748,12 @@ class TestPluginUpgradeFromMarketplaceApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -812,13 +772,12 @@ class TestPluginUpgradeFromGithubApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github", return_value={"ok": True}, ), ): - result = method(api) + result = method(api, "t1") assert result["ok"] is True @@ -836,23 +795,20 @@ class TestPluginUpgradeFromGithubApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1") assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: - def test_success(self, app: Flask): + def test_success(self, app: Flask, user): api = PluginFetchDynamicSelectOptionsWithCredentialsApi() method = unwrap(api.post) - user = MagicMock(id="u1", is_admin_or_owner=True) - payload = { "plugin_id": "p", "provider": "x", @@ -864,22 +820,19 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), patch( "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials", return_value=[1], ), ): - result = method(api) + result = method(api, "t1", user) assert result["options"] == [1] - def test_daemon_error(self, app: Flask): + def test_daemon_error(self, app: Flask, user): api = PluginFetchDynamicSelectOptionsWithCredentialsApi() method = unwrap(api.post) - user = MagicMock(id="u1", is_admin_or_owner=True) - payload = { "plugin_id": "p", "provider": "x", @@ -891,13 +844,12 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), patch( "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials", side_effect=PluginDaemonClientSideError("error"), ), ): - result = method(api) + result = method(api, "t1", user) assert result == ({"code": "plugin_error", "message": "error"}, 400) @@ -906,7 +858,7 @@ class TestPluginChangePreferencesApi: api = PluginChangePreferencesApi() method = unwrap(api.post) - user = MagicMock(is_admin_or_owner=True) + user = _account() payload = { "permission": { @@ -924,11 +876,10 @@ class TestPluginChangePreferencesApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True), patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.change_strategy", return_value=True), ): - result = method(api) + result = method(api, "t1", user) assert result["success"] is True @@ -936,7 +887,7 @@ class TestPluginChangePreferencesApi: api = PluginChangePreferencesApi() method = unwrap(api.post) - user = MagicMock(is_admin_or_owner=True) + user = _account() payload = { "permission": { @@ -954,10 +905,9 @@ class TestPluginChangePreferencesApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=False), ): - result = method(api) + result = method(api, "t1", user) assert result["success"] is False @@ -982,7 +932,6 @@ class TestPluginFetchPreferencesApi: with ( app.test_request_context("/"), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch( "controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=permission ), @@ -990,7 +939,7 @@ class TestPluginFetchPreferencesApi: "controllers.console.workspace.plugin.PluginAutoUpgradeService.get_strategy", return_value=auto_upgrade ), ): - result = method(api) + result = method(api, "t1") assert "permission" in result assert "auto_upgrade" in result @@ -1005,10 +954,9 @@ class TestPluginAutoUpgradeExcludePluginApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=True), ): - result = method(api) + result = method(api, "t1") assert result["success"] is True @@ -1020,9 +968,8 @@ class TestPluginAutoUpgradeExcludePluginApi: with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=False), ): - result = method(api) + result = method(api, "t1") assert result["success"] is False diff --git a/api/tests/unit_tests/controllers/console/workspace/test_snippets.py b/api/tests/unit_tests/controllers/console/workspace/test_snippets.py index a712a46404..b8914fc26c 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_snippets.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_snippets.py @@ -1,3 +1,4 @@ +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock @@ -5,15 +6,11 @@ import pytest from werkzeug.exceptions import NotFound from controllers.console.workspace import snippets as snippets_module +from models.account import Account, TenantAccountRole +from models.snippet import CustomizedSnippet from services.snippet_dsl_service import ImportStatus, SnippetImportInfo -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - @pytest.fixture(autouse=True) def _patch_snippet_service_factory(monkeypatch): def factory(): @@ -34,6 +31,26 @@ class _SessionContext: return False +def _account(account_id: str = "account-1") -> Account: + account = Account(name="Test User", email=f"{account_id}@example.com") + account.id = account_id + account.role = TenantAccountRole.EDITOR + return account + + +def _snippet(**overrides) -> CustomizedSnippet: + data = { + "id": "snippet-1", + "tenant_id": "tenant-1", + "name": "Snippet", + "description": "Description", + "type": snippets_module.SnippetType.NODE, + "created_by": "account-1", + } + data.update(overrides) + return CustomizedSnippet(**data) + + def test_normalize_snippet_list_query_args_sorts_indexed_values(): query_args = snippets_module.MultiDict( [ @@ -53,21 +70,19 @@ def test_normalize_snippet_list_query_args_sorts_indexed_values(): def test_list_snippets_returns_pagination(app, monkeypatch): - user = SimpleNamespace(id="account-1") - snippets = [SimpleNamespace(id="snippet-1")] + snippets = [_snippet()] tag_id = "11111111-1111-1111-1111-111111111111" get_snippets = Mock(return_value=(snippets, 1, False)) - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "get_snippets", get_snippets) monkeypatch.setattr(snippets_module, "marshal", Mock(return_value=[{"id": "snippet-1"}])) api = snippets_module.CustomizedSnippetsApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context( f"/workspaces/current/customized-snippets?page=2&limit=10&tag_ids[0]={tag_id}&creator_ids[0]=account-2" ): - response, status_code = handler(api) + response, status_code = handler(api, "tenant-1") assert status_code == 200 assert response == { @@ -89,10 +104,9 @@ def test_list_snippets_returns_pagination(app, monkeypatch): def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypatch): - user = SimpleNamespace(id="account-1") - snippet = SimpleNamespace(id="snippet-1") + user = _account("account-1") + snippet = _snippet() create_snippet = Mock(return_value=snippet) - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet) monkeypatch.setattr( snippets_module.CreateSnippetPayload, @@ -111,14 +125,14 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypat monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"})) api = snippets_module.CustomizedSnippetsApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/workspaces/current/customized-snippets", method="POST", json={"name": "Snippet", "type": "node", "description": "Description"}, ): - response, status_code = handler(api) + response, status_code = handler(api, "tenant-1", user) assert status_code == 201 assert response == {"id": "snippet-1"} @@ -126,13 +140,12 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypat def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch): - user = SimpleNamespace(id="account-1") + user = _account("account-1") create_snippet = Mock() - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet) api = snippets_module.CustomizedSnippetsApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/workspaces/current/customized-snippets", @@ -148,7 +161,7 @@ def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch): }, }, ): - response, status_code = handler(api) + response, status_code = handler(api, "tenant-1", user) assert status_code == 400 assert "knowledge-retrieval" in response["message"] @@ -156,60 +169,54 @@ def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch): def test_get_snippet_detail_raises_when_missing(app, monkeypatch): - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None)) api = snippets_module.CustomizedSnippetDetailApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/workspaces/current/customized-snippets/snippet-1"): with pytest.raises(NotFound, match="Snippet not found"): - handler(api, snippet_id="snippet-1") + handler(api, "tenant-1", snippet_id="snippet-1") def test_get_snippet_detail_returns_snippet(app, monkeypatch): - snippet = SimpleNamespace(id="snippet-1") - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) + snippet = _snippet() monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"})) api = snippets_module.CustomizedSnippetDetailApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/workspaces/current/customized-snippets/snippet-1"): - response, status_code = handler(api, snippet_id="snippet-1") + response, status_code = handler(api, "tenant-1", snippet_id="snippet-1") assert status_code == 200 assert response == {"id": "snippet-1"} def test_patch_snippet_returns_400_for_empty_payload(app, monkeypatch): - snippet = SimpleNamespace(id="snippet-1") - monkeypatch.setattr( - snippets_module, - "current_account_with_tenant", - lambda: (SimpleNamespace(id="user-1"), "tenant-1"), - ) + snippet = _snippet() + user = _account("user-1") monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) api = snippets_module.CustomizedSnippetDetailApi() - handler = _unwrap(api.patch) + handler = unwrap(api.patch) with app.test_request_context( "/workspaces/current/customized-snippets/snippet-1", method="PATCH", json={}, ): - response, status_code = handler(api, snippet_id="snippet-1") + response, status_code = handler(api, "tenant-1", user, snippet_id="snippet-1") assert status_code == 400 assert response == {"message": "No valid fields to update"} def test_patch_snippet_updates_and_commits(app, monkeypatch): - user = SimpleNamespace(id="account-1") - snippet = SimpleNamespace(id="snippet-1") - updated_snippet = SimpleNamespace(id="snippet-1", name="New") + user = _account("account-1") + snippet = _snippet() + updated_snippet = _snippet(name="New") session = SimpleNamespace(merge=Mock(return_value=snippet), commit=Mock()) update_snippet = Mock(return_value=updated_snippet) @@ -217,7 +224,6 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch): def __init__(self, engine, *args, **kwargs): super().__init__(engine, *args, session=session, **kwargs) - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) monkeypatch.setattr(snippets_module.SnippetService, "update_snippet", update_snippet) monkeypatch.setattr(snippets_module, "Session", SessionContext) @@ -225,14 +231,14 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch): monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1", "name": "New"})) api = snippets_module.CustomizedSnippetDetailApi() - handler = _unwrap(api.patch) + handler = unwrap(api.patch) with app.test_request_context( "/workspaces/current/customized-snippets/snippet-1", method="PATCH", json={"name": "New", "icon_info": {"icon": "star"}}, ): - response, status_code = handler(api, snippet_id="snippet-1") + response, status_code = handler(api, "tenant-1", user, snippet_id="snippet-1") assert status_code == 200 assert response == {"id": "snippet-1", "name": "New"} @@ -245,7 +251,7 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch): def test_delete_snippet_deletes_and_commits(app, monkeypatch): - snippet = SimpleNamespace(id="snippet-1") + snippet = _snippet() session = SimpleNamespace(merge=Mock(return_value=snippet), commit=Mock()) delete_snippet = Mock() @@ -253,17 +259,16 @@ def test_delete_snippet_deletes_and_commits(app, monkeypatch): def __init__(self, engine, *args, **kwargs): super().__init__(engine, *args, session=session, **kwargs) - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) monkeypatch.setattr(snippets_module.SnippetService, "delete_snippet", delete_snippet) monkeypatch.setattr(snippets_module, "Session", SessionContext) monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object())) api = snippets_module.CustomizedSnippetDetailApi() - handler = _unwrap(api.delete) + handler = unwrap(api.delete) with app.test_request_context("/workspaces/current/customized-snippets/snippet-1", method="DELETE"): - response, status_code = handler(api, snippet_id="snippet-1") + response, status_code = handler(api, "tenant-1", snippet_id="snippet-1") assert status_code == 204 assert response == "" @@ -272,7 +277,7 @@ def test_delete_snippet_deletes_and_commits(app, monkeypatch): def test_export_snippet_returns_yaml_attachment(app, monkeypatch): - snippet = SimpleNamespace(id="snippet-1", name="Snippet One") + snippet = _snippet(name="Snippet One") export_snippet_dsl = Mock(return_value="version: 0.1.0\nkind: snippet\n") session = SimpleNamespace() @@ -280,7 +285,6 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch): def __init__(self, engine, *args, **kwargs): super().__init__(engine, *args, session=session, **kwargs) - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) monkeypatch.setattr( snippets_module, @@ -291,10 +295,10 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch): monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object())) api = snippets_module.CustomizedSnippetExportApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/workspaces/current/customized-snippets/snippet-1/export?include_secret=true"): - response = handler(api, snippet_id="snippet-1") + response = handler(api, "tenant-1", snippet_id="snippet-1") assert response.status_code == 200 assert response.get_data(as_text=True) == "version: 0.1.0\nkind: snippet\n" @@ -304,7 +308,7 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch): def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch): - user = SimpleNamespace(id="account-1") + user = _account("account-1") result = SnippetImportInfo(id="import-1", status=ImportStatus.PENDING, imported_dsl_version="999.0.0") import_snippet = Mock(return_value=result) session = SimpleNamespace(commit=Mock()) @@ -319,7 +323,6 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch): def __exit__(self, exc_type, exc, tb): return False - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr(snippets_module, "Session", _SessionContext) monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr( @@ -329,14 +332,14 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch): ) api = snippets_module.CustomizedSnippetImportApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/workspaces/current/customized-snippets/imports", method="POST", json={"mode": "yaml-content", "yaml_content": "kind: snippet"}, ): - response, status_code = handler(api) + response, status_code = handler(api, user) assert status_code == 202 assert response["status"] == ImportStatus.PENDING.value @@ -345,7 +348,7 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch): def test_import_snippet_returns_400_for_failed_import(app, monkeypatch): - user = SimpleNamespace(id="account-1") + user = _account("account-1") result = SnippetImportInfo(id="import-1", status=ImportStatus.FAILED, error="Invalid DSL") import_snippet = Mock(return_value=result) session = SimpleNamespace(commit=Mock()) @@ -354,7 +357,6 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch): def __init__(self, engine, *args, **kwargs): super().__init__(engine, *args, session=session, **kwargs) - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr(snippets_module, "Session", SessionContext) monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr( @@ -364,14 +366,14 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch): ) api = snippets_module.CustomizedSnippetImportApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/workspaces/current/customized-snippets/imports", method="POST", json={"mode": "yaml-content", "yaml_content": "kind: snippet"}, ): - response, status_code = handler(api) + response, status_code = handler(api, user) assert status_code == 400 assert response["error"] == "Invalid DSL" @@ -379,7 +381,7 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch): def test_import_confirm_returns_200_for_completed_import(app, monkeypatch): - user = SimpleNamespace(id="account-1") + user = _account("account-1") result = SnippetImportInfo(id="import-1", status=ImportStatus.COMPLETED, snippet_id="snippet-1") confirm_import = Mock(return_value=result) session = SimpleNamespace(commit=Mock()) @@ -388,7 +390,6 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch): def __init__(self, engine, *args, **kwargs): super().__init__(engine, *args, session=session, **kwargs) - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1")) monkeypatch.setattr(snippets_module, "Session", SessionContext) monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr( @@ -398,13 +399,13 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch): ) api = snippets_module.CustomizedSnippetImportConfirmApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/workspaces/current/customized-snippets/imports/import-1/confirm", method="POST", ): - response, status_code = handler(api, import_id="import-1") + response, status_code = handler(api, user, import_id="import-1") assert status_code == 200 assert response["snippet_id"] == "snippet-1" @@ -413,19 +414,18 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch): def test_check_dependencies_raises_when_snippet_missing(app, monkeypatch): - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None)) api = snippets_module.CustomizedSnippetCheckDependenciesApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/workspaces/current/customized-snippets/snippet-1/check-dependencies"): with pytest.raises(NotFound, match="Snippet not found"): - handler(api, snippet_id="snippet-1") + handler(api, "tenant-1", snippet_id="snippet-1") def test_check_dependencies_returns_dependency_result(app, monkeypatch): - snippet = SimpleNamespace(id="snippet-1") + snippet = _snippet() check_dependencies = Mock( return_value=SimpleNamespace(model_dump=Mock(return_value={"dependencies": [], "missing_dependencies": []})) ) @@ -435,7 +435,6 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch): def __init__(self, engine, *args, **kwargs): super().__init__(engine, *args, session=session, **kwargs) - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) monkeypatch.setattr(snippets_module, "Session", SessionContext) monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object())) @@ -446,10 +445,10 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch): ) api = snippets_module.CustomizedSnippetCheckDependenciesApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/workspaces/current/customized-snippets/snippet-1/check-dependencies"): - response, status_code = handler(api, snippet_id="snippet-1") + response, status_code = handler(api, "tenant-1", snippet_id="snippet-1") assert status_code == 200 assert response == {"dependencies": [], "missing_dependencies": []} @@ -457,18 +456,17 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch): def test_increment_use_count_raises_when_snippet_missing(app, monkeypatch): - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None)) api = snippets_module.CustomizedSnippetUseCountIncrementApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/workspaces/current/customized-snippets/snippet-1/use-count/increment", method="POST", ): with pytest.raises(NotFound, match="Snippet not found"): - handler(api, snippet_id="snippet-1") + handler(api, "tenant-1", snippet_id="snippet-1") def test_increment_use_count_returns_refreshed_count(app, monkeypatch): @@ -487,20 +485,19 @@ def test_increment_use_count_returns_refreshed_count(app, monkeypatch): return False increment_use_count = Mock() - monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet)) monkeypatch.setattr(snippets_module.SnippetService, "increment_use_count", increment_use_count) monkeypatch.setattr(snippets_module, "Session", _SessionContext) monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object())) api = snippets_module.CustomizedSnippetUseCountIncrementApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/workspaces/current/customized-snippets/snippet-1/use-count/increment", method="POST", ): - response, status_code = handler(api, snippet_id="snippet-1") + response, status_code = handler(api, "tenant-1", snippet_id="snippet-1") assert status_code == 200 assert response == {"result": "success", "use_count": 3} diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py index 2410b912bd..e9c23da428 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -5,6 +5,7 @@ from __future__ import annotations import builtins import importlib from contextlib import ExitStack, contextmanager +from inspect import unwrap from types import ModuleType, SimpleNamespace from unittest.mock import MagicMock, patch @@ -12,12 +13,14 @@ import pytest from flask import Flask from flask.views import MethodView +from models import Account +from models.account import TenantAccountRole + if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] _CONTROLLER_MODULE: ModuleType | None = None -_WRAPS_MODULE: ModuleType | None = None @contextmanager @@ -69,9 +72,7 @@ def controller_module(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) # Ensure decorators that consult deployment edition do not reach the database. - global _WRAPS_MODULE wraps_module = importlib.import_module("controllers.console.wraps") - _WRAPS_MODULE = wraps_module monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD") monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD") @@ -80,46 +81,24 @@ def controller_module(monkeypatch: pytest.MonkeyPatch): return module -def _mock_account(user_id: str = "user-123") -> SimpleNamespace: - return SimpleNamespace( - id=user_id, - status="active", - is_authenticated=True, - current_tenant_id=None, - is_admin_or_owner=False, - ) - - -def _set_current_account( - monkeypatch: pytest.MonkeyPatch, - controller_module: ModuleType, - user: SimpleNamespace, - tenant_id: str, -) -> None: - def _getter(): - return user, tenant_id - - user.current_tenant_id = tenant_id - - monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter) - if _WRAPS_MODULE is not None: - monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter) - - login_module = importlib.import_module("libs.login") - monkeypatch.setattr(login_module, "_get_user", lambda: user) +def _mock_account(user_id: str = "user-123") -> Account: + user = Account(name="Test User", email=f"{user_id}@example.com") + user.id = user_id + user.role = TenantAccountRole.NORMAL + return user def test_tool_provider_list_calls_service_with_query( app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch ): user = _mock_account() - _set_current_account(monkeypatch, controller_module, user, "tenant-456") service_mock = MagicMock(return_value=[{"provider": "builtin"}]) monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock) with app.test_request_context("/workspaces/current/tool-providers?type=builtin"): - response = controller_module.ToolProviderListApi().get() + api = controller_module.ToolProviderListApi() + response = unwrap(api.get)(api, "tenant-456", user) assert response == [{"provider": "builtin"}] service_mock.assert_called_once_with(user.id, "tenant-456", "builtin") @@ -129,7 +108,6 @@ def test_builtin_provider_add_passes_payload( app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch ): user = _mock_account() - _set_current_account(monkeypatch, controller_module, user, "tenant-456") service_mock = MagicMock(return_value={"status": "ok"}) monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock) @@ -145,7 +123,8 @@ def test_builtin_provider_add_passes_payload( method="POST", json=payload, ): - response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai") + api = controller_module.ToolBuiltinProviderAddApi() + response = unwrap(api.post)(api, "tenant-456", user, provider="openai") assert response == {"status": "ok"} service_mock.assert_called_once_with( @@ -161,7 +140,6 @@ def test_builtin_provider_add_passes_payload( def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account("user-tenant-789") - _set_current_account(monkeypatch, controller_module, user, "tenant-789") service_mock = MagicMock(return_value=[{"name": "tool-a"}]) monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock) @@ -171,7 +149,8 @@ def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: "/workspaces/current/tool-provider/builtin/my-provider/tools", method="GET", ): - response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider") + api = controller_module.ToolBuiltinProviderListToolsApi() + response = unwrap(api.get)(api, "tenant-789", provider="my-provider") assert response == [{"name": "tool-a"}] service_mock.assert_called_once_with("tenant-789", "my-provider") @@ -179,12 +158,12 @@ def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account("user-tenant-9") - _set_current_account(monkeypatch, controller_module, user, "tenant-9") service_mock = MagicMock(return_value={"info": True}) monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock) with app.test_request_context("/info", method="GET"): - resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo") + api = controller_module.ToolBuiltinProviderInfoApi() + resp = unwrap(api.get)(api, "tenant-9", provider="demo") assert resp == {"info": True} service_mock.assert_called_once_with("tenant-9", "demo") @@ -192,7 +171,6 @@ def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: p def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account("user-tenant-cred") - _set_current_account(monkeypatch, controller_module, user, "tenant-cred") service_mock = MagicMock(return_value=[{"cred": 1}]) monkeypatch.setattr( controller_module.BuiltinToolManageService, @@ -201,7 +179,8 @@ def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeyp ) with app.test_request_context("/creds", method="GET"): - resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo") + api = controller_module.ToolBuiltinProviderGetCredentialsApi() + resp = unwrap(api.get)(api, "tenant-cred", user, provider="demo") assert resp == [{"cred": 1}] service_mock.assert_called_once_with( @@ -214,12 +193,12 @@ def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeyp def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account() - _set_current_account(monkeypatch, controller_module, user, "tenant-10") service_mock = MagicMock(return_value={"schema": "ok"}) monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock) with app.test_request_context("/remote?url=https://example.com/"): - resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get() + api = controller_module.ToolApiProviderGetRemoteSchemaApi() + resp = unwrap(api.get)(api, "tenant-10", user) assert resp == {"schema": "ok"} service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/") @@ -227,12 +206,12 @@ def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypat def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account() - _set_current_account(monkeypatch, controller_module, user, "tenant-11") service_mock = MagicMock(return_value=[{"tool": "t"}]) monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock) with app.test_request_context("/tools?provider=foo"): - resp = controller_module.ToolApiProviderListToolsApi().get() + api = controller_module.ToolApiProviderListToolsApi() + resp = unwrap(api.get)(api, "tenant-11", user) assert resp == [{"tool": "t"}] service_mock.assert_called_once_with(user.id, "tenant-11", "foo") @@ -240,12 +219,12 @@ def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account() - _set_current_account(monkeypatch, controller_module, user, "tenant-12") service_mock = MagicMock(return_value={"provider": "foo"}) monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock) with app.test_request_context("/get?provider=foo"): - resp = controller_module.ToolApiProviderGetApi().get() + api = controller_module.ToolApiProviderGetApi() + resp = unwrap(api.get)(api, "tenant-12", user) assert resp == {"provider": "foo"} service_mock.assert_called_once_with(user.id, "tenant-12", "foo") @@ -253,7 +232,6 @@ def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.Mon def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account("user-tenant-13") - _set_current_account(monkeypatch, controller_module, user, "tenant-13") service_mock = MagicMock(return_value={"schema": True}) monkeypatch.setattr( controller_module.BuiltinToolManageService, @@ -262,9 +240,8 @@ def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, ) with app.test_request_context("/schema", method="GET"): - resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get( - provider="demo", credential_type="api-key" - ) + api = controller_module.ToolBuiltinProviderCredentialsSchemaApi() + resp = unwrap(api.get)(api, "tenant-13", provider="demo", credential_type="api-key") assert resp == {"schema": True} service_mock.assert_called_once() @@ -272,7 +249,6 @@ def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account() - _set_current_account(monkeypatch, controller_module, user, "tenant-wf") tool_service = MagicMock(return_value={"wf": 1}) monkeypatch.setattr( controller_module.WorkflowToolManageService, @@ -282,7 +258,8 @@ def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatc tool_id = "00000000-0000-0000-0000-000000000001" with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"): - resp = controller_module.ToolWorkflowProviderGetApi().get() + api = controller_module.ToolWorkflowProviderGetApi() + resp = unwrap(api.get)(api, "tenant-wf", user) assert resp == {"wf": 1} tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id) @@ -290,7 +267,6 @@ def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatc def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account() - _set_current_account(monkeypatch, controller_module, user, "tenant-wf2") service_mock = MagicMock(return_value={"app": 1}) monkeypatch.setattr( controller_module.WorkflowToolManageService, @@ -300,7 +276,8 @@ def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch app_id = "00000000-0000-0000-0000-000000000002" with app.test_request_context(f"/workflow?workflow_app_id={app_id}"): - resp = controller_module.ToolWorkflowProviderGetApi().get() + api = controller_module.ToolWorkflowProviderGetApi() + resp = unwrap(api.get)(api, "tenant-wf2", user) assert resp == {"app": 1} service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id) @@ -308,13 +285,13 @@ def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account() - _set_current_account(monkeypatch, controller_module, user, "tenant-wf3") service_mock = MagicMock(return_value=[{"id": 1}]) monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock) tool_id = "00000000-0000-0000-0000-000000000003" with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"): - resp = controller_module.ToolWorkflowProviderListToolApi().get() + api = controller_module.ToolWorkflowProviderListToolApi() + resp = unwrap(api.get)(api, "tenant-wf3", user) assert resp == [{"id": 1}] service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id) @@ -322,7 +299,6 @@ def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account() - _set_current_account(monkeypatch, controller_module, user, "tenant-bt") provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"}) monkeypatch.setattr( @@ -332,14 +308,14 @@ def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.M ) with app.test_request_context("/tools/builtin"): - resp = controller_module.ToolBuiltinListApi().get() + api = controller_module.ToolBuiltinListApi() + resp = unwrap(api.get)(api, "tenant-bt", user) assert resp == [{"name": "builtin"}] def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account("user-tenant-api") - _set_current_account(monkeypatch, controller_module, user, "tenant-api") provider = SimpleNamespace(to_dict=lambda: {"name": "api"}) monkeypatch.setattr( @@ -349,14 +325,14 @@ def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.Monke ) with app.test_request_context("/tools/api"): - resp = controller_module.ToolApiListApi().get() + api = controller_module.ToolApiListApi() + resp = unwrap(api.get)(api, "tenant-api") assert resp == [{"name": "api"}] def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): user = _mock_account() - _set_current_account(monkeypatch, controller_module, user, "tenant-wf4") provider = SimpleNamespace(to_dict=lambda: {"name": "wf"}) monkeypatch.setattr( @@ -366,18 +342,18 @@ def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest. ) with app.test_request_context("/tools/workflow"): - resp = controller_module.ToolWorkflowListApi().get() + api = controller_module.ToolWorkflowListApi() + resp = unwrap(api.get)(api, "tenant-wf4", user) assert resp == [{"name": "wf"}] def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch): - user = _mock_account("user-label") - _set_current_account(monkeypatch, controller_module, user, "tenant-labels") monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"]) with app.test_request_context("/tool-labels"): - resp = controller_module.ToolLabelsApi().get() + api = controller_module.ToolLabelsApi() + resp = unwrap(api.get)(api) assert resp == ["a", "b"]