mirror of
https://github.com/langgenius/dify.git
synced 2026-06-11 02:31:13 +08:00
refactor(api): migrate tenant/user via DI for several endpoints (#37240)
This commit is contained in:
parent
dad2e64a62
commit
d849d60822
@ -24,9 +24,9 @@ from controllers.common.schema import (
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent_app_workspace_service import (
|
||||
AgentAppWorkspaceService,
|
||||
@ -142,8 +142,8 @@ class AgentAppWorkspaceListResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
query = query_params_from_request(AgentWorkspaceListQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().list_files(
|
||||
@ -167,8 +167,8 @@ class AgentAppWorkspacePreviewResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
query = query_params_from_request(AgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().preview(
|
||||
@ -194,8 +194,8 @@ class AgentAppWorkspaceDownloadResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
query = query_params_from_request(AgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().download(
|
||||
@ -228,8 +228,8 @@ class WorkflowAgentWorkspaceListResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceListQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().list_files(
|
||||
@ -264,8 +264,8 @@ class WorkflowAgentWorkspacePreviewResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().preview(
|
||||
@ -302,8 +302,8 @@ class WorkflowAgentWorkspaceDownloadResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().download(
|
||||
|
||||
@ -25,6 +25,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
@ -58,7 +59,7 @@ from graphon.variables import SecretVariable, SegmentType, VariableBase
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import TimestampField, dump_response, to_timestamp, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
@ -1568,7 +1569,8 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_ids = args.app_ids
|
||||
@ -1578,7 +1580,6 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
if not app_ids:
|
||||
return {"data": []}
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
|
||||
ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids]
|
||||
|
||||
@ -22,6 +22,8 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
@ -36,9 +38,9 @@ from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import build_icon_url, dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from libs.url_utils import normalize_api_base_url
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models import Account, ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import ApiTokenType, SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
@ -389,8 +391,9 @@ class DatasetListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
# Convert query parameters to dict, handling list parameters correctly
|
||||
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
|
||||
# Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value)
|
||||
@ -471,9 +474,10 @@ class DatasetListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -512,8 +516,9 @@ class DatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -566,14 +571,15 @@ class DatasetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# check embedding model setting
|
||||
if (
|
||||
payload.indexing_technique == IndexTechniqueType.HIGH_QUALITY
|
||||
@ -614,9 +620,9 @@ class DatasetApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Dataset deleted successfully")
|
||||
def delete(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
@ -664,8 +670,8 @@ class DatasetQueryApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -704,10 +710,10 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
extract_settings = []
|
||||
@ -804,8 +810,8 @@ class DatasetRelatedAppListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -840,8 +846,8 @@ class DatasetIndexingStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
|
||||
@ -898,8 +904,8 @@ class DatasetApiKeyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
@ -911,9 +917,8 @@ class DatasetApiKeyApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
current_key_count = (
|
||||
db.session.scalar(
|
||||
select(func.count(ApiToken.id)).where(
|
||||
@ -952,8 +957,8 @@ class DatasetApiDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, api_key_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, api_key_id: UUID):
|
||||
api_key_id_str = str(api_key_id)
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
@ -1079,8 +1084,8 @@ class DatasetPermissionUserListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -29,6 +29,8 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -46,7 +48,7 @@ from fields.workflow_run_fields import (
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty, dump_response
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.model import EndUser
|
||||
@ -187,16 +189,14 @@ class DraftRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
@edit_permission_required
|
||||
@console_ns.response(200, "Success", console_ns.models[RagPipelineWorkflowSyncResponse.__name__])
|
||||
def post(self, pipeline: Pipeline):
|
||||
def post(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Sync draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
@ -247,15 +247,13 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
@edit_permission_required
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = NodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
@ -283,14 +281,12 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = NodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
@ -318,14 +314,12 @@ class DraftRagPipelineRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
def post(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
|
||||
@ -350,14 +344,12 @@ class PublishedRagPipelineRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
def post(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Run published workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
streaming = payload.response_mode == "streaming"
|
||||
@ -383,14 +375,12 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run rag pipeline datasource
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
@ -416,14 +406,12 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run rag pipeline datasource
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
@ -454,14 +442,12 @@ class RagPipelineDraftNodeRunApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
|
||||
inputs = payload.inputs
|
||||
|
||||
@ -485,14 +471,12 @@ class RagPipelineTaskStopApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, task_id: str):
|
||||
def post(self, current_user: Account, pipeline: Pipeline, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}
|
||||
@ -532,13 +516,12 @@ class PublishedRagPipelineApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline):
|
||||
def post(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=db.session, # type: ignore[reportArgumentType,arg-type]
|
||||
@ -609,13 +592,12 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Get published workflows
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
query = WorkflowListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
page = query.page
|
||||
@ -655,9 +637,9 @@ class RagPipelineDraftWorkflowRestoreApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def post(self, current_user: Account, pipeline: Pipeline, workflow_id: str):
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
try:
|
||||
@ -689,14 +671,12 @@ class RagPipelineByIdApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
def patch(self, pipeline: Pipeline, workflow_id: str):
|
||||
def patch(self, current_user: Account, pipeline: Pipeline, workflow_id: str):
|
||||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
# Check permission
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
@ -925,8 +905,8 @@ class DatasourceListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
|
||||
|
||||
|
||||
@ -961,9 +941,8 @@ class RagPipelineTransformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
@ -984,13 +963,13 @@ class RagPipelineDatasourceVariableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_rag_pipeline
|
||||
@edit_permission_required
|
||||
def post(self, pipeline: Pipeline):
|
||||
def post(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Set datasource variables
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
@ -30,6 +30,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -45,6 +46,7 @@ from graphon.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
@ -158,12 +160,11 @@ class SnippetDraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet):
|
||||
"""Sync draft workflow for snippet."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
@ -239,11 +240,11 @@ class SnippetPublishedWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet):
|
||||
"""Publish snippet workflow."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
snippet_service = _snippet_service()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
@ -331,11 +332,11 @@ class SnippetDraftWorkflowRestoreApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, workflow_id: str):
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet, workflow_id: str):
|
||||
"""Restore a published snippet workflow version into the draft workflow."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
snippet_service = _snippet_service()
|
||||
|
||||
try:
|
||||
@ -455,16 +456,16 @@ class SnippetDraftNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a single node in snippet draft workflow.
|
||||
|
||||
Executes a specific node with provided inputs for single-step debugging.
|
||||
Returns the node execution result including status, outputs, and timing.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
user_inputs = payload.inputs
|
||||
@ -539,16 +540,16 @@ class SnippetDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow iteration node for snippet.
|
||||
|
||||
Iteration nodes execute their internal sub-graph multiple times over an input list.
|
||||
Returns an SSE event stream with iteration progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
@ -580,16 +581,16 @@ class SnippetDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow loop node for snippet.
|
||||
|
||||
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
|
||||
Returns an SSE event stream with loop progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
@ -619,17 +620,16 @@ class SnippetDraftWorkflowRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
def post(self, current_user: Account, snippet: CustomizedSnippet):
|
||||
"""
|
||||
Run draft workflow for snippet.
|
||||
|
||||
Executes the snippet's draft workflow with the provided inputs
|
||||
and returns an SSE event stream with execution progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
|
||||
@ -13,13 +13,19 @@ from controllers.common.fields import SuccessResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from fields.base import ResponseModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from libs.login import login_required
|
||||
from models.account import Account, TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
@ -200,9 +206,8 @@ class PluginDebuggingKeyApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
try:
|
||||
return {
|
||||
"key": PluginService.get_debugging_key(tenant_id),
|
||||
@ -219,8 +224,8 @@ class PluginListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserList.model_validate(request.args.to_dict(flat=True))
|
||||
try:
|
||||
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size)
|
||||
@ -253,9 +258,8 @@ class PluginListInstallationsFromIdsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserLatest.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -288,10 +292,10 @@ class PluginAssetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserAsset.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
try:
|
||||
binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name)
|
||||
return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
|
||||
@ -305,9 +309,8 @@ class PluginUploadFromPkgApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
file = request.files["pkg"]
|
||||
content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
|
||||
try:
|
||||
@ -325,9 +328,8 @@ class PluginUploadFromGithubApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserGithubUpload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -344,9 +346,8 @@ class PluginUploadFromBundleApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
file = request.files["bundle"]
|
||||
content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE)
|
||||
try:
|
||||
@ -364,8 +365,8 @@ class PluginInstallFromPkgApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -383,9 +384,8 @@ class PluginInstallFromGithubApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserGithubInstall.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -409,9 +409,8 @@ class PluginInstallFromMarketplaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserPluginIdentifiers.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -429,8 +428,8 @@ class PluginFetchMarketplacePkgApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
@ -453,9 +452,8 @@ class PluginFetchManifestApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
@ -473,9 +471,8 @@ class PluginFetchInstallTasksApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserTasks.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
@ -490,9 +487,8 @@ class PluginFetchInstallTaskApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def get(self, task_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, task_id: str):
|
||||
try:
|
||||
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -506,9 +502,8 @@ class PluginDeleteInstallTaskApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self, task_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, task_id: str):
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -522,9 +517,8 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
try:
|
||||
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -538,9 +532,8 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self, task_id: str, identifier: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, task_id: str, identifier: str):
|
||||
try:
|
||||
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -554,9 +547,8 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserMarketplaceUpgrade.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -576,9 +568,8 @@ class PluginUpgradeFromGithubApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserGithubUpgrade.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
@ -604,11 +595,10 @@ class PluginUninstallApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserUninstall.model_validate(console_ns.payload)
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)}
|
||||
except PluginDaemonClientSideError as e:
|
||||
@ -622,16 +612,14 @@ class PluginChangePermissionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = ParserPermissionChange.model_validate(console_ns.payload)
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
return {
|
||||
"success": PluginPermissionService.change_permission(
|
||||
tenant_id, args.install_permission, args.debug_permission
|
||||
@ -644,9 +632,8 @@ class PluginFetchPermissionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
if not permission:
|
||||
return jsonable_encoder(
|
||||
@ -671,16 +658,15 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
user_id = current_user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, current_user: Account):
|
||||
args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=args.plugin_id,
|
||||
provider=args.provider,
|
||||
action=args.action,
|
||||
@ -701,17 +687,16 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, current_user: Account):
|
||||
"""Fetch dynamic options using credentials directly (for edit mode)."""
|
||||
current_user, tenant_id = current_account_with_tenant()
|
||||
user_id = current_user.id
|
||||
|
||||
args = ParserDynamicOptionsWithCredentials.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
options = PluginParameterService.get_dynamic_select_options_with_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=args.plugin_id,
|
||||
provider=args.provider,
|
||||
action=args.action,
|
||||
@ -731,8 +716,9 @@ class PluginChangePreferencesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
@ -780,9 +766,8 @@ class PluginFetchPreferencesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
permission = PluginPermissionService.get_permission(tenant_id)
|
||||
permission_dict = {
|
||||
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
@ -820,10 +805,9 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
# exclude one single plugin
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserExcludePlugin.model_validate(console_ns.payload)
|
||||
|
||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
|
||||
@ -835,8 +819,8 @@ class PluginReadmeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserReadme.model_validate(request.args.to_dict(flat=True))
|
||||
return jsonable_encoder(
|
||||
{"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
|
||||
|
||||
@ -21,10 +21,13 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.snippet import SnippetType
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.snippet_dsl_service import SnippetDslService
|
||||
@ -91,10 +94,9 @@ class CustomizedSnippetsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
"""List customized snippets with pagination and search."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query = SnippetListQuery.model_validate(_normalize_snippet_list_query_args(request.args))
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
@ -124,10 +126,10 @@ class CustomizedSnippetsApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
"""Create a new customized snippet."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
@ -163,10 +165,9 @@ class CustomizedSnippetDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, snippet_id: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, snippet_id: str):
|
||||
"""Get customized snippet details."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
@ -187,10 +188,10 @@ class CustomizedSnippetDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, snippet_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(self, current_tenant_id: str, current_user: Account, snippet_id: str):
|
||||
"""Update customized snippet."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
@ -231,10 +232,9 @@ class CustomizedSnippetDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, snippet_id: str):
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, snippet_id: str):
|
||||
"""Delete customized snippet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
@ -266,10 +266,9 @@ class CustomizedSnippetExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, snippet_id: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, snippet_id: str):
|
||||
"""Export snippet as DSL."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
@ -312,9 +311,9 @@ class CustomizedSnippetImportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
"""Import snippet from DSL."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = SnippetImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine) as session:
|
||||
@ -350,10 +349,9 @@ class CustomizedSnippetImportConfirmApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str):
|
||||
"""Confirm a pending snippet import."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = SnippetDslService(session)
|
||||
result = import_service.confirm_import(import_id=import_id, account=current_user)
|
||||
@ -375,10 +373,9 @@ class CustomizedSnippetCheckDependenciesApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, snippet_id: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, snippet_id: str):
|
||||
"""Check dependencies for a snippet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
@ -406,10 +403,9 @@ class CustomizedSnippetUseCountIncrementApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, snippet_id: str):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, snippet_id: str):
|
||||
"""Increment snippet use count when it is inserted into a workflow."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=str(snippet_id),
|
||||
|
||||
@ -18,6 +18,8 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.mcp_provider import IdentityMode, MCPAuthentication, MCPConfiguration
|
||||
@ -30,7 +32,8 @@ from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToo
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import alphanumeric, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.provider_ids import ToolProviderID
|
||||
|
||||
# from models.provider_ids import ToolProviderID
|
||||
@ -286,15 +289,13 @@ class ToolProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = ToolProviderListQuery.model_validate(raw_args)
|
||||
|
||||
return ToolCommonService.list_tool_providers(user_id, tenant_id, query.type) # type: ignore
|
||||
return ToolCommonService.list_tool_providers(user.id, tenant_id, query.type) # type: ignore
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||
@ -302,9 +303,8 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.list_builtin_tool_provider_tools(
|
||||
tenant_id,
|
||||
@ -318,9 +318,8 @@ class ToolBuiltinProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
|
||||
|
||||
@ -331,9 +330,8 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return BuiltinToolManageService.delete_builtin_tool_provider(
|
||||
@ -349,15 +347,13 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str):
|
||||
payload = BuiltinToolAddPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return BuiltinToolManageService.add_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credentials=payload.credentials,
|
||||
@ -374,14 +370,13 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account, provider: str):
|
||||
payload = BuiltinToolUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = BuiltinToolManageService.update_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credential_id=payload.credential_id,
|
||||
@ -396,8 +391,9 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
# Optional list of credential IDs to include even if visibility would hide them
|
||||
# (used when a workflow/agent node still references another member's only_me credential).
|
||||
include_credential_ids = request.args.getlist("include_credential_ids") or [
|
||||
@ -430,15 +426,13 @@ class ToolApiProviderAddApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = ApiToolProviderAddPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return ApiToolManageService.create_api_tool_provider(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
payload.provider,
|
||||
payload.icon,
|
||||
@ -456,16 +450,14 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = UrlQuery.model_validate(raw_args)
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider_remote_schema(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
str(query.url),
|
||||
)
|
||||
@ -476,17 +468,15 @@ class ToolApiProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = ProviderQuery.model_validate(raw_args)
|
||||
|
||||
return jsonable_encoder(
|
||||
ApiToolManageService.list_api_tool_provider_tools(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
query.provider,
|
||||
)
|
||||
@ -500,15 +490,13 @@ class ToolApiProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = ApiToolProviderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return ApiToolManageService.update_api_tool_provider(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
payload.provider,
|
||||
payload.original_provider,
|
||||
@ -529,15 +517,13 @@ class ToolApiProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = ApiToolProviderDeletePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return ApiToolManageService.delete_api_tool_provider(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
payload.provider,
|
||||
)
|
||||
@ -548,16 +534,14 @@ class ToolApiProviderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = ProviderQuery.model_validate(raw_args)
|
||||
|
||||
return ApiToolManageService.get_api_tool_provider(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
query.provider,
|
||||
)
|
||||
@ -568,9 +552,8 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, credential_type):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider, credential_type):
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.list_builtin_provider_credentials_schema(
|
||||
provider, CredentialType.of(credential_type), tenant_id
|
||||
@ -598,9 +581,9 @@ class ToolApiProviderPreviousTestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = ApiToolTestPayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return ApiToolManageService.test_api_tool_preview(
|
||||
current_tenant_id,
|
||||
payload.provider_name or "",
|
||||
@ -619,15 +602,13 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = WorkflowToolCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return WorkflowToolManageService.create_workflow_tool(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_app_id=payload.workflow_app_id,
|
||||
name=payload.name,
|
||||
@ -647,14 +628,13 @@ class ToolWorkflowProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = WorkflowToolUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return WorkflowToolManageService.update_workflow_tool(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
payload.workflow_tool_id,
|
||||
payload.name,
|
||||
@ -674,15 +654,13 @@ class ToolWorkflowProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = WorkflowToolDeletePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
return WorkflowToolManageService.delete_workflow_tool(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
payload.workflow_tool_id,
|
||||
)
|
||||
@ -693,23 +671,21 @@ class ToolWorkflowProviderGetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = WorkflowToolGetQuery.model_validate(raw_args)
|
||||
|
||||
if query.workflow_tool_id:
|
||||
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
query.workflow_tool_id,
|
||||
)
|
||||
elif query.workflow_app_id:
|
||||
tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
query.workflow_app_id,
|
||||
)
|
||||
@ -724,17 +700,15 @@ class ToolWorkflowProviderListToolApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
raw_args = request.args.to_dict()
|
||||
query = WorkflowToolListQuery.model_validate(raw_args)
|
||||
|
||||
return jsonable_encoder(
|
||||
WorkflowToolManageService.list_single_workflow_tools(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
query.workflow_tool_id,
|
||||
)
|
||||
@ -746,16 +720,14 @@ class ToolBuiltinListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
return jsonable_encoder(
|
||||
[
|
||||
provider.to_dict()
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
@ -767,9 +739,8 @@ class ToolApiListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
return jsonable_encoder(
|
||||
[
|
||||
provider.to_dict()
|
||||
@ -785,16 +756,14 @@ class ToolWorkflowListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account):
|
||||
return jsonable_encoder(
|
||||
[
|
||||
provider.to_dict()
|
||||
for provider in WorkflowToolManageService.list_tenant_workflow_tools(
|
||||
user_id,
|
||||
user.id,
|
||||
tenant_id,
|
||||
)
|
||||
]
|
||||
@ -817,13 +786,13 @@ class ToolPluginOAuthApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
|
||||
if oauth_client_params is None:
|
||||
raise Forbidden("no oauth available client config found for this tool provider")
|
||||
@ -912,8 +881,8 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_tenant_id, provider=provider, id=payload.id
|
||||
@ -927,11 +896,10 @@ class ToolOAuthCustomClient(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
payload = ToolOAuthCustomClientPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return BuiltinToolManageService.save_custom_oauth_client_params(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
@ -944,8 +912,8 @@ class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, provider: str):
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
)
|
||||
@ -953,8 +921,8 @@ class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider: str):
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
)
|
||||
@ -965,8 +933,8 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, provider: str):
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
||||
tenant_id=current_tenant_id, provider_name=provider
|
||||
@ -979,8 +947,9 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
include_credential_ids = request.args.getlist("include_credential_ids") or [
|
||||
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
|
||||
]
|
||||
@ -1001,9 +970,10 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
payload = MCPProviderCreatePayload.model_validate(console_ns.payload or {})
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
# Parse and validate models
|
||||
configuration = MCPConfiguration.model_validate(payload.configuration or {})
|
||||
@ -1054,11 +1024,11 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self):
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str):
|
||||
payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
configuration = MCPConfiguration.model_validate(payload.configuration or {})
|
||||
authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
|
||||
validation_data = None
|
||||
@ -1107,9 +1077,9 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self):
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str):
|
||||
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
@ -1124,10 +1094,10 @@ class ToolMCPAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
payload = MCPAuthPayload.model_validate(console_ns.payload or {})
|
||||
provider_id = payload.provider_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
@ -1197,8 +1167,8 @@ class ToolMCPDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider_id: str):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
@ -1210,9 +1180,8 @@ class ToolMCPListAllApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
# Skip sensitive data decryption for list view to improve performance
|
||||
@ -1226,8 +1195,8 @@ class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider_id: str):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_provider_tools(
|
||||
|
||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from typing import TypedDict, Unpack
|
||||
from unittest.mock import MagicMock, patch
|
||||
@ -37,7 +38,8 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.account import Account
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import Workflow
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -48,6 +50,14 @@ DEFAULT_WORKFLOW_CREATED_BY = "00000000-0000-0000-0000-000000000003"
|
||||
type WorkflowVariablePayload = dict[str, object]
|
||||
|
||||
|
||||
def empty_mapping() -> dict[str, object]:
|
||||
return {}
|
||||
|
||||
|
||||
def empty_list() -> list[object]:
|
||||
return []
|
||||
|
||||
|
||||
class WorkflowFactoryPayload(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
@ -86,14 +96,8 @@ class WorkflowFactoryOverrides(TypedDict, total=False):
|
||||
rag_pipeline_variables: list[WorkflowVariablePayload]
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def make_node_execution(**overrides):
|
||||
payload = {
|
||||
def make_node_execution(**overrides: object) -> SimpleNamespace:
|
||||
payload: dict[str, object] = {
|
||||
"id": "node-exec-1",
|
||||
"index": 1,
|
||||
"predecessor_node_id": None,
|
||||
@ -148,6 +152,27 @@ def make_workflow(**overrides: Unpack[WorkflowFactoryOverrides]) -> Workflow:
|
||||
return Workflow(**payload)
|
||||
|
||||
|
||||
def make_account(*, id: str = "account-1", role: TenantAccountRole = TenantAccountRole.EDITOR) -> Account:
|
||||
account = Account(name="Alice", email=f"{id}@example.com")
|
||||
account.id = id
|
||||
account.role = role
|
||||
return account
|
||||
|
||||
|
||||
def make_pipeline(
|
||||
*,
|
||||
id: str = "pipeline-1",
|
||||
tenant_id: str = "tenant-1",
|
||||
workflow_id: str | None = None,
|
||||
is_published: bool = False,
|
||||
) -> Pipeline:
|
||||
pipeline = Pipeline(tenant_id=tenant_id, name="test-pipeline", description="test")
|
||||
pipeline.id = id
|
||||
pipeline.workflow_id = workflow_id
|
||||
pipeline.is_published = is_published
|
||||
return pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workflow_author(db_session_with_containers: Session) -> Account:
|
||||
account = Account(name="Alice", email=f"alice-{uuid4()}@example.com")
|
||||
@ -158,14 +183,14 @@ def workflow_author(db_session_with_containers: Session) -> Account:
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_draft_success(self, app: Flask, workflow_author: Account):
|
||||
def test_get_draft_success(self, app: Flask, workflow_author: Account) -> None:
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
workflow = make_workflow(created_by=workflow_author.id)
|
||||
|
||||
service = MagicMock()
|
||||
@ -191,11 +216,11 @@ class TestDraftWorkflowApi:
|
||||
}
|
||||
assert result["updated_by"] is None
|
||||
|
||||
def test_get_draft_not_exist(self, app: Flask):
|
||||
def test_get_draft_not_exist(self, app: Flask) -> None:
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
@ -209,54 +234,46 @@ class TestDraftWorkflowApi:
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_hash_not_match(self, app: Flask):
|
||||
def test_sync_hash_not_match(self, app: Flask) -> None:
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
service = MagicMock()
|
||||
service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"graph": {}, "features": {}}),
|
||||
patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
app.test_request_context("/", json={"graph": empty_mapping(), "features": empty_mapping()}),
|
||||
patch.object(type(console_ns), "payload", {"graph": empty_mapping(), "features": empty_mapping()}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
method(api, pipeline)
|
||||
method(api, user, pipeline)
|
||||
|
||||
def test_sync_invalid_text_plain(self, app: Flask):
|
||||
def test_sync_invalid_text_plain(self, app: Flask) -> None:
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
response, status = method(api, user, pipeline)
|
||||
assert status == 400
|
||||
|
||||
def test_restore_published_workflow_to_draft_success(self, app: Flask):
|
||||
def test_restore_published_workflow_to_draft_success(self, app: Flask) -> None:
|
||||
api = RagPipelineDraftWorkflowRestoreApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="account-1")
|
||||
pipeline = make_pipeline()
|
||||
user = make_account(id="account-1")
|
||||
workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1))
|
||||
|
||||
service = MagicMock()
|
||||
@ -264,50 +281,42 @@ class TestDraftWorkflowApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="POST"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "published-workflow")
|
||||
result = method(api, user, pipeline, "published-workflow")
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert result["hash"] == "restored-hash"
|
||||
|
||||
def test_restore_published_workflow_to_draft_not_found(self, app: Flask):
|
||||
def test_restore_published_workflow_to_draft_not_found(self, app: Flask) -> None:
|
||||
api = RagPipelineDraftWorkflowRestoreApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="account-1")
|
||||
pipeline = make_pipeline()
|
||||
user = make_account(id="account-1")
|
||||
|
||||
service = MagicMock()
|
||||
service.restore_published_workflow_to_draft.side_effect = WorkflowNotFoundError("Workflow not found")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="POST"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "published-workflow")
|
||||
method(api, user, pipeline, "published-workflow")
|
||||
|
||||
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask):
|
||||
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask) -> None:
|
||||
api = RagPipelineDraftWorkflowRestoreApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="account-1")
|
||||
pipeline = make_pipeline()
|
||||
user = make_account(id="account-1")
|
||||
|
||||
service = MagicMock()
|
||||
service.restore_published_workflow_to_draft.side_effect = IsDraftWorkflowError(
|
||||
@ -316,17 +325,13 @@ class TestDraftWorkflowApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="POST"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
method(api, pipeline, "draft-workflow")
|
||||
method(api, user, pipeline, "draft-workflow")
|
||||
|
||||
assert exc.value.code == 400
|
||||
assert exc.value.description == "source workflow must be published"
|
||||
@ -334,23 +339,19 @@ class TestDraftWorkflowApi:
|
||||
|
||||
class TestDraftRunNodes:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_iteration_node_success(self, app: Flask):
|
||||
def test_iteration_node_success(self, app: Flask) -> None:
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
app.test_request_context("/", json={"inputs": empty_mapping()}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
return_value=MagicMock(),
|
||||
@ -360,45 +361,37 @@ class TestDraftRunNodes:
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node")
|
||||
result = method(api, user, pipeline, "node")
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_iteration_node_conversation_not_exists(self, app: Flask):
|
||||
def test_iteration_node_conversation_not_exists(self, app: Flask) -> None:
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
app.test_request_context("/", json={"inputs": empty_mapping()}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
side_effect=services.errors.conversation.ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node")
|
||||
method(api, user, pipeline, "node")
|
||||
|
||||
def test_loop_node_success(self, app: Flask):
|
||||
def test_loop_node_success(self, app: Flask) -> None:
|
||||
api = RagPipelineDraftRunLoopNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
app.test_request_context("/", json={"inputs": empty_mapping()}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop",
|
||||
return_value=MagicMock(),
|
||||
@ -408,35 +401,31 @@ class TestDraftRunNodes:
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline, "node") == {"ok": True}
|
||||
assert method(api, user, pipeline, "node") == {"ok": True}
|
||||
|
||||
|
||||
class TestPipelineRunApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_draft_run_success(self, app: Flask):
|
||||
def test_draft_run_success(self, app: Flask) -> None:
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"inputs": empty_mapping(),
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"datasource_info_list": empty_list(),
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
@ -446,27 +435,27 @@ class TestPipelineRunApis:
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline) == {"ok": True}
|
||||
assert method(api, user, pipeline) == {"ok": True}
|
||||
|
||||
def test_draft_run_rate_limit(self, app: Flask):
|
||||
def test_draft_run_rate_limit(self, app: Flask) -> None:
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
payload: dict[str, object] = {
|
||||
"inputs": empty_mapping(),
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": empty_list(),
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}
|
||||
),
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
{"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
@ -474,48 +463,42 @@ class TestPipelineRunApis:
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
method(api, user, pipeline)
|
||||
|
||||
|
||||
class TestDraftNodeRun:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_execution_not_found(self, app: Flask):
|
||||
def test_execution_not_found(self, app: Flask) -> None:
|
||||
api = RagPipelineDraftNodeRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
service = MagicMock()
|
||||
service.run_draft_workflow_node.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
app.test_request_context("/", json={"inputs": empty_mapping()}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": empty_mapping()}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node")
|
||||
method(api, user, pipeline, "node")
|
||||
|
||||
|
||||
class TestPublishedPipelineApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_publish_success(self, app: Flask, db_session_with_containers: Session):
|
||||
from models.dataset import Pipeline
|
||||
|
||||
def test_publish_success(self, app: Flask, db_session_with_containers: Session) -> None:
|
||||
api = PublishedRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -530,7 +513,7 @@ class TestPublishedPipelineApis:
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
user = MagicMock(id="u1")
|
||||
user = make_account(id="u1")
|
||||
|
||||
workflow = MagicMock(
|
||||
id=str(uuid4()),
|
||||
@ -542,16 +525,12 @@ class TestPublishedPipelineApis:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
result = method(api, user, pipeline)
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert "created_at" in result
|
||||
@ -559,47 +538,39 @@ class TestPublishedPipelineApis:
|
||||
|
||||
class TestMiscApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_task_stop(self, app: Flask):
|
||||
def test_task_stop(self, app: Flask) -> None:
|
||||
api = RagPipelineTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
pipeline = make_pipeline()
|
||||
user = make_account(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag"
|
||||
) as stop_mock,
|
||||
):
|
||||
result = method(api, pipeline, "task-1")
|
||||
result = method(api, user, pipeline, "task-1")
|
||||
stop_mock.assert_called_once()
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_transform_forbidden(self, app: Flask):
|
||||
def test_transform_forbidden(self, app: Flask) -> None:
|
||||
api = RagPipelineTransformApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(has_edit_permission=False, is_dataset_operator=False)
|
||||
user = make_account(role=TenantAccountRole.NORMAL)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds1")
|
||||
method(api, user, "ds1")
|
||||
|
||||
def test_recommended_plugins(self, app: Flask):
|
||||
def test_recommended_plugins(self, app: Flask) -> None:
|
||||
api = RagPipelineRecommendedPluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -619,20 +590,20 @@ class TestMiscApis:
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_published_run_success(self, app: Flask):
|
||||
def test_published_run_success(self, app: Flask) -> None:
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"inputs": empty_mapping(),
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"datasource_info_list": empty_list(),
|
||||
"start_node_id": "n",
|
||||
"response_mode": "blocking",
|
||||
}
|
||||
@ -640,10 +611,6 @@ class TestPublishedRagPipelineRunApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
@ -653,49 +620,45 @@ class TestPublishedRagPipelineRunApi:
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
result = method(api, user, pipeline)
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_published_run_rate_limit(self, app: Flask):
|
||||
def test_published_run_rate_limit(self, app: Flask) -> None:
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"inputs": empty_mapping(),
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"datasource_info_list": empty_list(),
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
method(api, user, pipeline)
|
||||
|
||||
|
||||
class TestDefaultBlockConfigApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_block_config_success(self, app: Flask):
|
||||
def test_get_block_config_success(self, app: Flask) -> None:
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_default_block_config.return_value = {"k": "v"}
|
||||
@ -710,11 +673,11 @@ class TestDefaultBlockConfigApi:
|
||||
result = method(api, pipeline, "llm")
|
||||
assert result == {"k": "v"}
|
||||
|
||||
def test_get_block_config_invalid_json(self, app: Flask):
|
||||
def test_get_block_config_invalid_json(self, app: Flask) -> None:
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
|
||||
with app.test_request_context("/?q=bad-json"):
|
||||
with pytest.raises(ValueError):
|
||||
@ -723,65 +686,57 @@ class TestDefaultBlockConfigApi:
|
||||
|
||||
class TestPublishedAllRagPipelineApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_published_workflows_success(self, app: Flask):
|
||||
def test_get_published_workflows_success(self, app: Flask) -> None:
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
pipeline = make_pipeline()
|
||||
user = make_account(id="u1")
|
||||
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([make_workflow(id="w1")], False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
result = method(api, user, pipeline)
|
||||
|
||||
assert result["items"][0]["id"] == "w1"
|
||||
assert result["items"][0]["graph"] == {"nodes": [], "edges": []}
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_published_workflows_forbidden(self, app: Flask):
|
||||
def test_get_published_workflows_forbidden(self, app: Flask) -> None:
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
pipeline = make_pipeline()
|
||||
user = make_account(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?user_id=u2"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline)
|
||||
method(api, user, pipeline)
|
||||
|
||||
|
||||
class TestRagPipelineByIdApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_patch_success(self, app: Flask):
|
||||
def test_patch_success(self, app: Flask) -> None:
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(tenant_id="t1")
|
||||
user = MagicMock(id="u1")
|
||||
pipeline = make_pipeline(tenant_id="t1")
|
||||
user = make_account(id="u1")
|
||||
|
||||
workflow = make_workflow(id="w1", marked_name="test")
|
||||
|
||||
@ -793,44 +748,36 @@ class TestRagPipelineByIdApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "w1")
|
||||
result = method(api, user, pipeline, "w1")
|
||||
|
||||
assert result["id"] == "w1"
|
||||
assert result["marked_name"] == "test"
|
||||
assert result["hash"] == workflow.unique_hash
|
||||
|
||||
def test_patch_no_fields(self, app: Flask):
|
||||
def test_patch_no_fields(self, app: Flask) -> None:
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch.object(type(console_ns), "payload", {}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch.object(type(console_ns), "payload", empty_mapping()),
|
||||
):
|
||||
result, status = method(api, pipeline, "w1")
|
||||
result, status = method(api, user, pipeline, "w1")
|
||||
assert status == 400
|
||||
|
||||
def test_delete_success(self, app: Flask):
|
||||
def test_delete_success(self, app: Flask) -> None:
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1")
|
||||
pipeline = make_pipeline(tenant_id="t1", workflow_id="active-workflow")
|
||||
|
||||
workflow_service = MagicMock()
|
||||
|
||||
@ -846,11 +793,11 @@ class TestRagPipelineByIdApi:
|
||||
workflow_service.delete_workflow.assert_called_once()
|
||||
assert result == (None, 204)
|
||||
|
||||
def test_delete_active_workflow_rejected(self, app: Flask):
|
||||
def test_delete_active_workflow_rejected(self, app: Flask) -> None:
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1")
|
||||
pipeline = make_pipeline(tenant_id="t1", workflow_id="active-workflow")
|
||||
|
||||
with app.test_request_context("/", method="DELETE"):
|
||||
with pytest.raises(BadRequest, match="currently in use by pipeline"):
|
||||
@ -859,14 +806,14 @@ class TestRagPipelineByIdApi:
|
||||
|
||||
class TestRagPipelineWorkflowLastRunApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_last_run_success(self, app: Flask):
|
||||
def test_last_run_success(self, app: Flask) -> None:
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
workflow = MagicMock()
|
||||
node_exec = make_node_execution()
|
||||
|
||||
@ -886,11 +833,11 @@ class TestRagPipelineWorkflowLastRunApi:
|
||||
assert result["inputs"] == {"query": "hello"}
|
||||
assert result["outputs"] == {"answer": "world"}
|
||||
|
||||
def test_last_run_not_found(self, app: Flask):
|
||||
def test_last_run_not_found(self, app: Flask) -> None:
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
@ -908,19 +855,19 @@ class TestRagPipelineWorkflowLastRunApi:
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_set_datasource_variables_success(self, app: Flask):
|
||||
def test_set_datasource_variables_success(self, app: Flask) -> None:
|
||||
api = RagPipelineDatasourceVariableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
pipeline = make_pipeline()
|
||||
user = make_account()
|
||||
|
||||
payload = {
|
||||
"datasource_type": "db",
|
||||
"datasource_info": {},
|
||||
"datasource_info": empty_mapping(),
|
||||
"start_node_id": "n1",
|
||||
"start_node_title": "Node",
|
||||
}
|
||||
@ -931,15 +878,11 @@ class TestRagPipelineDatasourceVariableApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
result = method(api, user, pipeline)
|
||||
assert result["node_id"] == "n1"
|
||||
assert result["process_data"] == {}
|
||||
|
||||
@ -3,12 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from inspect import unwrap
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace.tool_providers import (
|
||||
@ -43,41 +44,56 @@ from controllers.console.workspace.tool_providers import (
|
||||
ToolWorkflowProviderUpdateApi,
|
||||
is_valid_url,
|
||||
)
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
def empty_mapping() -> dict[str, object]:
|
||||
return {}
|
||||
|
||||
|
||||
def empty_list() -> list[object]:
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_cache():
|
||||
def _mock_cache() -> None:
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_user_tenant():
|
||||
def _mock_user_tenant() -> None:
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(flask_app_with_containers: Flask):
|
||||
def client(flask_app_with_containers: Flask) -> FlaskClient:
|
||||
return flask_app_with_containers.test_client()
|
||||
|
||||
|
||||
@patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "t1"),
|
||||
autospec=True,
|
||||
)
|
||||
def make_account(*, id: str = "u", role: TenantAccountRole = TenantAccountRole.EDITOR) -> Account:
|
||||
account = Account(name="Alice", email=f"{id}@example.com")
|
||||
account.id = id
|
||||
account.role = role
|
||||
return account
|
||||
|
||||
|
||||
@patch("controllers.console.workspace.tool_providers.sessionmaker", autospec=True)
|
||||
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True)
|
||||
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
|
||||
def test_create_mcp_provider_populates_tools(
|
||||
mock_reconnect, mock_session, mock_current_account_with_tenant, client: FlaskClient
|
||||
):
|
||||
mock_reconnect: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
client: FlaskClient,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
headers = authenticate_console_client(client, account)
|
||||
|
||||
# Arrange: reconnect returns tools immediately
|
||||
mock_reconnect.return_value = ReconnectResult(
|
||||
authed=True,
|
||||
@ -104,21 +120,11 @@ def test_create_mcp_provider_populates_tools(
|
||||
"icon_background": "#000",
|
||||
"server_identifier": "demo-sid",
|
||||
"configuration": {"timeout": 5, "sse_read_timeout": 30},
|
||||
"headers": {},
|
||||
"authentication": {},
|
||||
"headers": empty_mapping(),
|
||||
"authentication": empty_mapping(),
|
||||
}
|
||||
# Act
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
|
||||
patch(
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "t1"),
|
||||
autospec=True,
|
||||
),
|
||||
patch("libs.login.check_csrf_token", return_value=None, autospec=True), # bypass CSRF in login_required
|
||||
patch(
|
||||
"libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True), autospec=True
|
||||
), # login
|
||||
patch(
|
||||
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
|
||||
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
|
||||
@ -128,6 +134,7 @@ def test_create_mcp_provider_populates_tools(
|
||||
resp = client.post(
|
||||
"/console/api/workspaces/current/tool-provider/mcp",
|
||||
data=json.dumps(payload),
|
||||
headers=headers,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
@ -141,7 +148,7 @@ def test_create_mcp_provider_populates_tools(
|
||||
|
||||
|
||||
class TestUtils:
|
||||
def test_is_valid_url(self):
|
||||
def test_is_valid_url(self) -> None:
|
||||
assert is_valid_url("https://example.com")
|
||||
assert is_valid_url("http://example.com")
|
||||
assert not is_valid_url("")
|
||||
@ -152,154 +159,121 @@ class TestUtils:
|
||||
|
||||
class TestToolProviderListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app: Flask):
|
||||
def test_get_success(self, app: Flask) -> None:
|
||||
api = ToolProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers",
|
||||
return_value=["p1"],
|
||||
),
|
||||
):
|
||||
assert method(api) == ["p1"]
|
||||
assert method(api, "t1", make_account(id="u1")) == ["p1"]
|
||||
|
||||
|
||||
class TestBuiltinProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_list_tools(self, app: Flask):
|
||||
def test_list_tools(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools",
|
||||
return_value=[{"a": 1}],
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == [{"a": 1}]
|
||||
assert method(api, "t1", "provider") == [{"a": 1}]
|
||||
|
||||
def test_info(self, app: Flask):
|
||||
def test_info(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info",
|
||||
return_value={"x": 1},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"x": 1}
|
||||
assert method(api, "t1", "provider") == {"x": 1}
|
||||
|
||||
def test_delete(self, app: Flask):
|
||||
def test_delete(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credential_id": "cid"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider",
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["result"] == "success"
|
||||
assert method(api, "t1", "provider")["result"] == "success"
|
||||
|
||||
def test_add_invalid_type(self, app: Flask):
|
||||
def test_add_invalid_type(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}, "type": "invalid"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
app.test_request_context("/", json={"credentials": empty_mapping(), "type": "invalid"}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "provider")
|
||||
method(api, "t", make_account(), "provider")
|
||||
|
||||
def test_add_success(self, app: Flask):
|
||||
def test_add_success(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {}, "type": "oauth2", "name": "n"}
|
||||
payload = {"credentials": empty_mapping(), "type": "oauth2", "name": "n"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["id"] == 1
|
||||
assert method(api, "t", make_account(), "provider")["id"] == 1
|
||||
|
||||
def test_update(self, app: Flask):
|
||||
def test_update(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "c1", "credentials": {}, "name": "n"}
|
||||
payload = {"credential_id": "c1", "credentials": empty_mapping(), "name": "n"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
assert method(api, "t", make_account(), "provider")["ok"]
|
||||
|
||||
def test_get_credentials(self, app: Flask):
|
||||
def test_get_credentials(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderGetCredentialsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
mock_user = SimpleNamespace(id="user-1", is_admin_or_owner=False)
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(mock_user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials",
|
||||
return_value={"k": "v"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"k": "v"}
|
||||
assert method(api, "t", make_account(id="user-1"), "provider") == {"k": "v"}
|
||||
|
||||
def test_icon(self, app: Flask):
|
||||
def test_icon(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -313,208 +287,168 @@ class TestBuiltinProviderApis:
|
||||
response = method(api, "provider")
|
||||
assert response.mimetype == "image/png"
|
||||
|
||||
def test_credentials_schema(self, app: Flask):
|
||||
def test_credentials_schema(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderCredentialsSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema",
|
||||
return_value={"schema": {}},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider", "oauth2") == {"schema": {}}
|
||||
assert method(api, "t", "provider", "oauth2") == {"schema": {}}
|
||||
|
||||
def test_set_default_credential(self, app: Flask):
|
||||
def test_set_default_credential(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderSetDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"id": "c1"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
assert method(api, "t", "provider")["ok"]
|
||||
|
||||
def test_get_credential_info(self, app: Flask):
|
||||
def test_get_credential_info(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderGetCredentialInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info",
|
||||
return_value={"info": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"info": "x"}
|
||||
assert method(api, "t", make_account(), "provider") == {"info": "x"}
|
||||
|
||||
def test_get_oauth_client_schema(self, app: Flask):
|
||||
def test_get_oauth_client_schema(self, app: Flask) -> None:
|
||||
api = ToolBuiltinProviderGetOauthClientSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema",
|
||||
return_value={"schema": {}},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"schema": {}}
|
||||
assert method(api, "t", "provider") == {"schema": {}}
|
||||
|
||||
|
||||
class TestApiProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_add(self, app: Flask):
|
||||
def test_add(self, app: Flask) -> None:
|
||||
api = ToolApiProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credentials": {},
|
||||
"credentials": empty_mapping(),
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"provider": "p",
|
||||
"icon": {},
|
||||
"icon": empty_mapping(),
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api)["id"] == 1
|
||||
assert method(api, "t", make_account())["id"] == 1
|
||||
|
||||
def test_remote_schema(self, app: Flask):
|
||||
def test_remote_schema(self, app: Flask) -> None:
|
||||
api = ToolApiProviderGetRemoteSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?url=http://x.com"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema",
|
||||
return_value={"schema": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api)["schema"] == "x"
|
||||
assert method(api, "t", make_account())["schema"] == "x"
|
||||
|
||||
def test_list_tools(self, app: Flask):
|
||||
def test_list_tools(self, app: Flask) -> None:
|
||||
api = ToolApiProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?provider=p"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools",
|
||||
return_value=[{"tool": 1}],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"tool": 1}]
|
||||
assert method(api, "t", make_account()) == [{"tool": 1}]
|
||||
|
||||
def test_update(self, app: Flask):
|
||||
def test_update(self, app: Flask) -> None:
|
||||
api = ToolApiProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credentials": {},
|
||||
"credentials": empty_mapping(),
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"provider": "p",
|
||||
"original_provider": "o",
|
||||
"icon": {},
|
||||
"icon": empty_mapping(),
|
||||
"privacy_policy": "",
|
||||
"custom_disclaimer": "",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api)["ok"]
|
||||
assert method(api, "t", make_account())["ok"]
|
||||
|
||||
def test_delete(self, app: Flask):
|
||||
def test_delete(self, app: Flask) -> None:
|
||||
api = ToolApiProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"provider": "p"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider",
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api)["result"] == "success"
|
||||
assert method(api, "t", make_account())["result"] == "success"
|
||||
|
||||
def test_get(self, app: Flask):
|
||||
def test_get(self, app: Flask) -> None:
|
||||
api = ToolApiProviderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?provider=p"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider",
|
||||
return_value={"x": 1},
|
||||
),
|
||||
):
|
||||
assert method(api) == {"x": 1}
|
||||
assert method(api, "t", make_account()) == {"x": 1}
|
||||
|
||||
|
||||
class TestWorkflowApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_create(self, app: Flask):
|
||||
def test_create(self, app: Flask) -> None:
|
||||
api = ToolWorkflowProviderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -523,24 +457,20 @@ class TestWorkflowApis:
|
||||
"name": "n",
|
||||
"label": "l",
|
||||
"description": "d",
|
||||
"icon": {},
|
||||
"parameters": [],
|
||||
"icon": empty_mapping(),
|
||||
"parameters": empty_list(),
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api)["id"] == 1
|
||||
assert method(api, "t", make_account())["id"] == 1
|
||||
|
||||
def test_update_invalid(self, app: Flask):
|
||||
def test_update_invalid(self, app: Flask) -> None:
|
||||
api = ToolWorkflowProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -549,61 +479,49 @@ class TestWorkflowApis:
|
||||
"name": "Tool",
|
||||
"label": "Tool Label",
|
||||
"description": "A tool",
|
||||
"icon": {},
|
||||
"icon": empty_mapping(),
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t", make_account())
|
||||
assert result["ok"]
|
||||
|
||||
def test_delete(self, app: Flask):
|
||||
def test_delete(self, app: Flask) -> None:
|
||||
api = ToolWorkflowProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api)["ok"]
|
||||
assert method(api, "t", make_account())["ok"]
|
||||
|
||||
def test_get_error(self, app: Flask):
|
||||
def test_get_error(self, app: Flask) -> None:
|
||||
api = ToolWorkflowProviderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "t", make_account())
|
||||
|
||||
|
||||
class TestLists:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_builtin_list(self, app: Flask):
|
||||
def test_builtin_list(self, app: Flask) -> None:
|
||||
api = ToolBuiltinListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -612,18 +530,14 @@ class TestLists:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
assert method(api, "t", make_account()) == [{"x": 1}]
|
||||
|
||||
def test_api_list(self, app: Flask):
|
||||
def test_api_list(self, app: Flask) -> None:
|
||||
api = ToolApiListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -632,18 +546,14 @@ class TestLists:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
assert method(api, "t") == [{"x": 1}]
|
||||
|
||||
def test_workflow_list(self, app: Flask):
|
||||
def test_workflow_list(self, app: Flask) -> None:
|
||||
api = ToolWorkflowListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -652,24 +562,20 @@ class TestLists:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
assert method(api, "t", make_account()) == [{"x": 1}]
|
||||
|
||||
|
||||
class TestLabels:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_labels(self, app: Flask):
|
||||
def test_labels(self, app: Flask) -> None:
|
||||
api = ToolLabelsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -685,28 +591,24 @@ class TestLabels:
|
||||
|
||||
class TestOAuth:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_oauth_no_client(self, app: Flask):
|
||||
def test_oauth_no_client(self, app: Flask) -> None:
|
||||
api = ToolPluginOAuthApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "provider")
|
||||
method(api, "t", make_account(), "provider")
|
||||
|
||||
def test_oauth_callback_no_cookie(self, app: Flask):
|
||||
def test_oauth_callback_no_cookie(self, app: Flask) -> None:
|
||||
api = ToolOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@ -717,56 +619,44 @@ class TestOAuth:
|
||||
|
||||
class TestOAuthCustomClient:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_save_custom_client(self, app: Flask):
|
||||
def test_save_custom_client(self, app: Flask) -> None:
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"client_params": {"a": 1}}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
assert method(api, "t", "provider")["ok"]
|
||||
|
||||
def test_get_custom_client(self, app: Flask):
|
||||
def test_get_custom_client(self, app: Flask) -> None:
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params",
|
||||
return_value={"client_id": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"client_id": "x"}
|
||||
assert method(api, "t", "provider") == {"client_id": "x"}
|
||||
|
||||
def test_delete_custom_client(self, app: Flask):
|
||||
def test_delete_custom_client(self, app: Flask) -> None:
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
assert method(api, "t", "provider")["ok"]
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
|
||||
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
|
||||
from clients.agent_backend.workspace_files_client import (
|
||||
@ -12,16 +15,10 @@ from clients.agent_backend.workspace_files_client import (
|
||||
WorkspacePreviewResult,
|
||||
)
|
||||
from controllers.console import agent_app_workspace as module
|
||||
from models.model import App, AppMode, IconType
|
||||
from services.agent_app_workspace_service import AgentWorkspaceInspectorError
|
||||
|
||||
|
||||
def _unwrapped_get(resource_cls):
|
||||
func = resource_cls.get
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class _AgentAppService:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[tuple[str, str, str, str, str]] = []
|
||||
@ -87,6 +84,20 @@ class _WorkflowService:
|
||||
return WorkspaceDownloadResult(path=path, size=3, truncated=False, content=b"abc")
|
||||
|
||||
|
||||
def _app_model(app_id: str = "app-1") -> App:
|
||||
return App(
|
||||
id=app_id,
|
||||
tenant_id="tenant-1",
|
||||
name="App",
|
||||
mode=AppMode.AGENT,
|
||||
icon_type=IconType.EMOJI,
|
||||
icon="bot",
|
||||
icon_background="#fff",
|
||||
enable_site=False,
|
||||
enable_api=False,
|
||||
)
|
||||
|
||||
|
||||
def test_handle_maps_workspace_and_agent_backend_errors() -> None:
|
||||
assert module._handle(AgentWorkspaceInspectorError("no_sandbox", "no sandbox", status_code=404)) == (
|
||||
{"code": "no_sandbox", "message": "no sandbox"},
|
||||
@ -108,8 +119,11 @@ def test_handle_maps_workspace_and_agent_backend_errors() -> None:
|
||||
|
||||
|
||||
def test_download_response_returns_binary_or_too_large_error() -> None:
|
||||
response = module._download_response(
|
||||
WorkspaceDownloadResult(path="dir/report.txt", size=3, truncated=False, content=b"abc")
|
||||
response = cast(
|
||||
Response,
|
||||
module._download_response(
|
||||
WorkspaceDownloadResult(path="dir/report.txt", size=3, truncated=False, content=b"abc")
|
||||
),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@ -133,17 +147,16 @@ def test_download_response_returns_binary_or_too_large_error() -> None:
|
||||
def test_agent_app_workspace_resources_proxy_service(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = _AgentAppService()
|
||||
monkeypatch.setattr(module, "AgentAppWorkspaceService", lambda: service)
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"query_params_from_request",
|
||||
lambda model: SimpleNamespace(conversation_id="conv-1", path="sub/report.txt"),
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
app_model = _app_model()
|
||||
|
||||
listing = _unwrapped_get(module.AgentAppWorkspaceListResource)(object(), app_model)
|
||||
preview = _unwrapped_get(module.AgentAppWorkspacePreviewResource)(object(), app_model)
|
||||
download = _unwrapped_get(module.AgentAppWorkspaceDownloadResource)(object(), app_model)
|
||||
listing = unwrap(module.AgentAppWorkspaceListResource.get)(object(), "tenant-1", app_model)
|
||||
preview = unwrap(module.AgentAppWorkspacePreviewResource.get)(object(), "tenant-1", app_model)
|
||||
download = unwrap(module.AgentAppWorkspaceDownloadResource.get)(object(), "tenant-1", app_model)
|
||||
|
||||
assert listing["entries"][0]["name"] == "a.txt"
|
||||
assert preview["text"] == "hello"
|
||||
@ -161,14 +174,13 @@ def test_agent_app_workspace_resource_returns_normalized_errors(monkeypatch: pyt
|
||||
raise AgentWorkspaceInspectorError("no_active_session", "no active session", status_code=404)
|
||||
|
||||
monkeypatch.setattr(module, "AgentAppWorkspaceService", FailingService)
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"query_params_from_request",
|
||||
lambda model: SimpleNamespace(conversation_id="conv-1", path="."),
|
||||
)
|
||||
|
||||
assert _unwrapped_get(module.AgentAppWorkspaceListResource)(object(), SimpleNamespace(id="app-1")) == (
|
||||
assert unwrap(module.AgentAppWorkspaceListResource.get)(object(), "tenant-1", _app_model()) == (
|
||||
{"code": "no_active_session", "message": "no active session"},
|
||||
404,
|
||||
)
|
||||
@ -177,17 +189,22 @@ def test_agent_app_workspace_resource_returns_normalized_errors(monkeypatch: pyt
|
||||
def test_workflow_agent_workspace_resources_proxy_service(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
service = _WorkflowService()
|
||||
monkeypatch.setattr(module, "WorkflowAgentWorkspaceService", lambda: service)
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"query_params_from_request",
|
||||
lambda model: SimpleNamespace(node_execution_id="exec-1", path="out.txt"),
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
app_model = _app_model()
|
||||
|
||||
listing = _unwrapped_get(module.WorkflowAgentWorkspaceListResource)(object(), app_model, "run-1", "agent-node")
|
||||
preview = _unwrapped_get(module.WorkflowAgentWorkspacePreviewResource)(object(), app_model, "run-1", "agent-node")
|
||||
download = _unwrapped_get(module.WorkflowAgentWorkspaceDownloadResource)(object(), app_model, "run-1", "agent-node")
|
||||
listing = unwrap(module.WorkflowAgentWorkspaceListResource.get)(
|
||||
object(), "tenant-1", app_model, "run-1", "agent-node"
|
||||
)
|
||||
preview = unwrap(module.WorkflowAgentWorkspacePreviewResource.get)(
|
||||
object(), "tenant-1", app_model, "run-1", "agent-node"
|
||||
)
|
||||
download = unwrap(module.WorkflowAgentWorkspaceDownloadResource.get)(
|
||||
object(), "tenant-1", app_model, "run-1", "agent-node"
|
||||
)
|
||||
|
||||
assert listing["path"] == "out.txt"
|
||||
assert preview["text"] == "hello"
|
||||
|
||||
@ -542,7 +542,6 @@ def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeyp
|
||||
app_id_2 = "22222222-2222-2222-2222-222222222222"
|
||||
signed_avatar_url = "https://files.example.com/signed/avatar-1"
|
||||
sign_avatar = Mock(return_value=signed_avatar_url)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
workflow_module,
|
||||
"WorkflowService",
|
||||
@ -596,7 +595,7 @@ def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeyp
|
||||
method="POST",
|
||||
json={"app_ids": [app_id_1, app_id_2]},
|
||||
):
|
||||
response = handler(api)
|
||||
response = handler(api, "tenant-1")
|
||||
|
||||
assert response == {
|
||||
"data": [
|
||||
@ -625,7 +624,6 @@ def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeyp
|
||||
|
||||
def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_ids = [f"wf-{index}" for index in range(workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE + 1)]
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
workflow_module,
|
||||
"WorkflowService",
|
||||
@ -647,7 +645,7 @@ def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pyte
|
||||
method="POST",
|
||||
json={"app_ids": app_ids},
|
||||
):
|
||||
response = handler(api)
|
||||
response = handler(api, "tenant-1")
|
||||
|
||||
assert len(response["data"]) == len(app_ids)
|
||||
assert redis_pipeline_factory.call_count == 2
|
||||
@ -656,7 +654,6 @@ def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pyte
|
||||
|
||||
|
||||
def test_workflow_online_users_rejects_excessive_workflow_ids(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
accessible_app_ids = Mock(return_value=set())
|
||||
monkeypatch.setattr(
|
||||
workflow_module,
|
||||
@ -675,7 +672,7 @@ def test_workflow_online_users_rejects_excessive_workflow_ids(app: Flask, monkey
|
||||
json={"app_ids": excessive_ids},
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
handler(api)
|
||||
handler(api, "tenant-1")
|
||||
|
||||
assert exc.value.code == 400
|
||||
assert exc.value.description is not None
|
||||
|
||||
@ -39,7 +39,6 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app
|
||||
monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
|
||||
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from inspect import unwrap
|
||||
from inspect import unwrap as unwrap_all
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.datasets.rag_pipeline import rag_pipeline_workflow as module
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def _make_workflow(**overrides):
|
||||
@ -33,6 +35,19 @@ def _make_workflow(**overrides):
|
||||
return workflow
|
||||
|
||||
|
||||
def _account() -> Account:
|
||||
account = Account(name="Alice", email="alice@example.com")
|
||||
account.id = "user-1"
|
||||
account.role = TenantAccountRole.EDITOR
|
||||
return account
|
||||
|
||||
|
||||
def _pipeline() -> Pipeline:
|
||||
pipeline = Pipeline(tenant_id="tenant-1", name="Pipeline", description="desc")
|
||||
pipeline.id = "pipeline-1"
|
||||
return pipeline
|
||||
|
||||
|
||||
def test_draft_rag_pipeline_workflow_get_serializes_response_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = _make_workflow()
|
||||
monkeypatch.setattr(
|
||||
@ -40,9 +55,9 @@ def test_draft_rag_pipeline_workflow_get_serializes_response_model(monkeypatch:
|
||||
)
|
||||
|
||||
api = module.DraftRagPipelineApi()
|
||||
handler = unwrap(api.get)
|
||||
handler = unwrap_all(api.get)
|
||||
|
||||
response = handler(api, pipeline=SimpleNamespace(id="pipeline-1"))
|
||||
response = handler(api, _pipeline())
|
||||
|
||||
assert response["id"] == "workflow-1"
|
||||
assert response["graph"] == {"nodes": [], "edges": []}
|
||||
@ -58,7 +73,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
|
||||
app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = module.PublishedAllRagPipelineApi()
|
||||
handler = unwrap(api.get)
|
||||
handler = unwrap_all(api.get)
|
||||
session_state = {"open": False}
|
||||
|
||||
class _SessionContext:
|
||||
@ -83,7 +98,6 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
|
||||
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker())
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (SimpleNamespace(id="user-1"), "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"RagPipelineService",
|
||||
@ -95,7 +109,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
|
||||
method="GET",
|
||||
query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"},
|
||||
):
|
||||
response = handler(api, pipeline=SimpleNamespace(id="pipeline-1"))
|
||||
response = handler(api, _account(), pipeline=_pipeline())
|
||||
|
||||
assert response["items"][0]["id"] == "workflow-1"
|
||||
assert response["page"] == 1
|
||||
@ -105,7 +119,6 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
|
||||
|
||||
def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = _make_workflow(marked_name="Updated release")
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (SimpleNamespace(id="user-1"), "tenant-1"))
|
||||
|
||||
class _SessionContext:
|
||||
def __enter__(self):
|
||||
@ -128,7 +141,7 @@ def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch:
|
||||
payload: dict[str, object] = {"marked_name": "Updated release"}
|
||||
|
||||
api = module.RagPipelineByIdApi()
|
||||
handler = unwrap(api.patch)
|
||||
handler = unwrap_all(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/rag/pipelines/pipeline-1/workflows/workflow-1", method="PATCH", json=payload),
|
||||
@ -136,7 +149,8 @@ def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch:
|
||||
):
|
||||
response = handler(
|
||||
api,
|
||||
pipeline=SimpleNamespace(id="pipeline-1", tenant_id="tenant-1"),
|
||||
_account(),
|
||||
pipeline=_pipeline(),
|
||||
workflow_id="workflow-1",
|
||||
)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
@ -8,21 +9,34 @@ import pytest
|
||||
from werkzeug.exceptions import HTTPException, NotFound
|
||||
|
||||
from controllers.console.snippets import snippet_workflow as snippet_workflow_module
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.snippet import CustomizedSnippet
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
def _account(account_id: str = "account-1") -> Account:
|
||||
account = Account(name="Test User", email=f"{account_id}@example.com")
|
||||
account.id = account_id
|
||||
account.role = TenantAccountRole.EDITOR
|
||||
return account
|
||||
|
||||
|
||||
def _snippet(**overrides) -> CustomizedSnippet:
|
||||
data = {
|
||||
"id": "snippet-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"name": "Snippet",
|
||||
"description": "Description",
|
||||
"type": "node",
|
||||
"created_by": "account-1",
|
||||
}
|
||||
data.update(overrides)
|
||||
return CustomizedSnippet(**data)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_snippet_service_factory(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def factory():
|
||||
service_factory = snippet_workflow_module.SnippetService
|
||||
if isinstance(service_factory, type):
|
||||
return service_factory.__new__(service_factory)
|
||||
return service_factory()
|
||||
return snippet_workflow_module.SnippetService()
|
||||
|
||||
monkeypatch.setattr(snippet_workflow_module, "_snippet_service", factory)
|
||||
monkeypatch.setattr(snippet_workflow_module, "_snippet_session_maker", Mock(return_value=Mock()))
|
||||
@ -39,7 +53,7 @@ def test_get_snippet_requires_snippet_id(app):
|
||||
|
||||
|
||||
def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
snippet = _snippet()
|
||||
|
||||
@snippet_workflow_module.get_snippet
|
||||
def view(**kwargs):
|
||||
@ -48,7 +62,7 @@ def test_get_snippet_injects_resolved_snippet(app, monkeypatch: pytest.MonkeyPat
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="account-1"), "tenant-1"),
|
||||
lambda: (_account("account-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr(snippet_workflow_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
|
||||
@ -66,7 +80,7 @@ def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pyt
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="account-1"), "tenant-1"),
|
||||
lambda: (_account("account-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr(snippet_workflow_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
|
||||
|
||||
@ -76,7 +90,7 @@ def test_get_snippet_raises_not_found_when_snippet_missing(app, monkeypatch: pyt
|
||||
|
||||
|
||||
def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
snippet = _snippet()
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
"SnippetService",
|
||||
@ -84,7 +98,7 @@ def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyP
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetDraftWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/snippets/snippet-1/workflows/draft"):
|
||||
with pytest.raises(snippet_workflow_module.DraftWorkflowNotExist):
|
||||
@ -92,10 +106,9 @@ def test_draft_workflow_get_raises_when_missing(app, monkeypatch: pytest.MonkeyP
|
||||
|
||||
|
||||
def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(id="account-1")
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
sync_draft_workflow = Mock(side_effect=ValueError("invalid graph"))
|
||||
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
"SnippetService",
|
||||
@ -103,14 +116,14 @@ def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pyt
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetDraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/snippets/snippet-1/workflows/draft",
|
||||
method="POST",
|
||||
json={"graph": {"nodes": [], "edges": []}, "hash": "hash-1"},
|
||||
):
|
||||
response, status_code = handler(api, snippet=snippet)
|
||||
response, status_code = handler(api, user, snippet)
|
||||
|
||||
assert status_code == 400
|
||||
assert response == {"message": "invalid graph"}
|
||||
@ -118,7 +131,7 @@ def test_draft_workflow_post_returns_400_for_invalid_graph(app, monkeypatch: pyt
|
||||
|
||||
def test_draft_config_returns_parallel_depth_limit(app) -> None:
|
||||
api = snippet_workflow_module.SnippetDraftConfigApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/snippets/snippet-1/workflows/draft/config"):
|
||||
assert handler(api, snippet=SimpleNamespace(id="snippet-1")) == {"parallel_depth_limit": 3}
|
||||
@ -126,16 +139,16 @@ def test_draft_config_returns_parallel_depth_limit(app) -> None:
|
||||
|
||||
def test_published_workflow_get_returns_none_when_not_published(app) -> None:
|
||||
api = snippet_workflow_module.SnippetPublishedWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/snippets/snippet-1/workflows/publish"):
|
||||
assert handler(api, snippet=SimpleNamespace(id="snippet-1", is_published=False)) is None
|
||||
|
||||
|
||||
def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(id="account-1")
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
merged_snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
merged_snippet = _snippet()
|
||||
session = SimpleNamespace(merge=Mock(return_value=merged_snippet), commit=Mock())
|
||||
|
||||
class SessionContext:
|
||||
@ -148,7 +161,6 @@ def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(snippet_workflow_module, "Session", SessionContext)
|
||||
monkeypatch.setattr(snippet_workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
@ -158,10 +170,10 @@ def test_published_workflow_post_returns_400_when_publish_fails(app, monkeypatch
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetPublishedWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/snippets/snippet-1/workflows/publish", method="POST", json={}):
|
||||
response, status_code = handler(api, snippet=snippet)
|
||||
response, status_code = handler(api, user, snippet)
|
||||
|
||||
assert status_code == 400
|
||||
assert response == {"message": "No valid workflow found."}
|
||||
@ -177,7 +189,7 @@ def test_default_block_configs_delegates_to_service(app, monkeypatch: pytest.Mon
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetDefaultBlockConfigsApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/snippets/snippet-1/workflows/default-workflow-block-configs"):
|
||||
result = handler(api, snippet=SimpleNamespace(id="snippet-1"))
|
||||
@ -192,10 +204,9 @@ def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: p
|
||||
updated_at=None,
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
user = SimpleNamespace(id="account-1")
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
|
||||
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
"SnippetService",
|
||||
@ -203,23 +214,22 @@ def test_restore_published_snippet_workflow_to_draft_success(app, monkeypatch: p
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/snippets/snippet-1/workflows/published-workflow/restore",
|
||||
method="POST",
|
||||
):
|
||||
response = handler(api, snippet=snippet, workflow_id="published-workflow")
|
||||
response = handler(api, user, snippet, workflow_id="published-workflow")
|
||||
|
||||
assert response["result"] == "success"
|
||||
assert response["hash"] == "restored-hash"
|
||||
|
||||
|
||||
def test_restore_published_snippet_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(id="account-1")
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
|
||||
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
"SnippetService",
|
||||
@ -231,23 +241,22 @@ def test_restore_published_snippet_workflow_to_draft_not_found(app, monkeypatch:
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/snippets/snippet-1/workflows/published-workflow/restore",
|
||||
method="POST",
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, snippet=snippet, workflow_id="published-workflow")
|
||||
handler(api, user, snippet, workflow_id="published-workflow")
|
||||
|
||||
|
||||
def test_restore_published_snippet_workflow_to_draft_returns_400_for_draft_source(
|
||||
app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
user = SimpleNamespace(id="account-1")
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
|
||||
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
"SnippetService",
|
||||
@ -259,14 +268,14 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_draft_sourc
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/snippets/snippet-1/workflows/draft-workflow/restore",
|
||||
method="POST",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
handler(api, snippet=snippet, workflow_id="draft-workflow")
|
||||
handler(api, user, snippet, workflow_id="draft-workflow")
|
||||
|
||||
assert exc.value.code == 400
|
||||
assert exc.value.description == snippet_workflow_module.RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE
|
||||
@ -275,10 +284,9 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_draft_sourc
|
||||
def test_restore_published_snippet_workflow_to_draft_returns_400_for_invalid_graph(
|
||||
app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
user = SimpleNamespace(id="account-1")
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
|
||||
monkeypatch.setattr(snippet_workflow_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
"SnippetService",
|
||||
@ -290,21 +298,21 @@ def test_restore_published_snippet_workflow_to_draft_returns_400_for_invalid_gra
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetDraftWorkflowRestoreApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/snippets/snippet-1/workflows/published-workflow/restore",
|
||||
method="POST",
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
handler(api, snippet=snippet, workflow_id="published-workflow")
|
||||
handler(api, user, snippet, workflow_id="published-workflow")
|
||||
|
||||
assert exc.value.code == 400
|
||||
assert exc.value.description == "invalid snippet workflow graph"
|
||||
|
||||
|
||||
def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
snippet = _snippet()
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
"SnippetService",
|
||||
@ -312,7 +320,7 @@ def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch:
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetWorkflowRunDetailApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/snippets/snippet-1/workflow-runs/run-1"):
|
||||
with pytest.raises(NotFound, match="Workflow run not found"):
|
||||
@ -320,7 +328,7 @@ def test_workflow_run_detail_raises_not_found_when_run_missing(app, monkeypatch:
|
||||
|
||||
|
||||
def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
snippet = SimpleNamespace(id="snippet-1", tenant_id="tenant-1")
|
||||
snippet = _snippet()
|
||||
draft_workflow = SimpleNamespace(id="workflow-1")
|
||||
monkeypatch.setattr(
|
||||
snippet_workflow_module,
|
||||
@ -332,7 +340,7 @@ def test_draft_node_last_run_raises_not_found_when_execution_missing(app, monkey
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetDraftNodeLastRunApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/snippets/snippet-1/workflows/draft/nodes/llm-1/last-run"):
|
||||
with pytest.raises(NotFound, match="Node last run not found"):
|
||||
@ -354,7 +362,7 @@ def test_workflow_task_stop_uses_queue_flag_and_graph_command(app, monkeypatch:
|
||||
)
|
||||
|
||||
api = snippet_workflow_module.SnippetWorkflowTaskStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/snippets/snippet-1/workflow-runs/tasks/task-1/stop", method="POST"):
|
||||
result = handler(api, snippet=SimpleNamespace(id="snippet-1"), task_id="task-1")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import io
|
||||
from inspect import unwrap
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -39,21 +40,19 @@ from controllers.console.workspace.plugin import (
|
||||
PluginUploadFromPkgApi,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from models.account import Account, TenantAccountRole, TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
def _account(role: TenantAccountRole = TenantAccountRole.OWNER) -> Account:
|
||||
account = Account(name="Test User", email="u1@example.com")
|
||||
account.id = "u1"
|
||||
account.role = role
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user():
|
||||
u = MagicMock()
|
||||
u.id = "u1"
|
||||
u.is_admin_or_owner = True
|
||||
return u
|
||||
return _account()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -102,10 +101,9 @@ class TestPluginDebuggingKeyApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.get_debugging_key", return_value="k"),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["key"] == "k"
|
||||
|
||||
@ -115,13 +113,12 @@ class TestPluginDebuggingKeyApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.get_debugging_key",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -134,10 +131,9 @@ class TestPluginListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.list_with_total", return_value=mock_list),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["total"] == 1
|
||||
|
||||
@ -163,10 +159,9 @@ class TestPluginAssetApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?plugin_unique_identifier=p&file_name=a.bin"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.extract_asset", return_value=b"x"),
|
||||
):
|
||||
response = method(api)
|
||||
response = method(api, "t1")
|
||||
|
||||
assert response.mimetype == "application/octet-stream"
|
||||
|
||||
@ -182,10 +177,9 @@ class TestPluginUploadFromPkgApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", data=data, content_type="multipart/form-data"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.upload_pkg", return_value={"ok": True}),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["ok"] is True
|
||||
|
||||
@ -199,12 +193,11 @@ class TestPluginUploadFromPkgApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", data=data, content_type="multipart/form-data"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0),
|
||||
patch("controllers.console.workspace.plugin.PluginService.upload_pkg") as upload_pkg_mock,
|
||||
):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
method(api)
|
||||
method(api, "t1")
|
||||
assert "File size exceeds the maximum allowed size" in str(exc_info.value)
|
||||
|
||||
upload_pkg_mock.assert_not_called()
|
||||
@ -219,12 +212,11 @@ class TestPluginInstallFromPkgApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.install_from_local_pkg", return_value={"ok": True}
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["ok"] is True
|
||||
|
||||
@ -238,10 +230,9 @@ class TestPluginUninstallApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.uninstall", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
@ -251,7 +242,7 @@ class TestPluginChangePermissionApi:
|
||||
api = PluginChangePermissionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_admin_or_owner=False)
|
||||
user = _account(TenantAccountRole.NORMAL)
|
||||
|
||||
payload = {
|
||||
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
@ -260,16 +251,15 @@ class TestPluginChangePermissionApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, "t1", user)
|
||||
|
||||
def test_change_permission_success(self, app: Flask):
|
||||
api = PluginChangePermissionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_admin_or_owner=True)
|
||||
user = _account()
|
||||
|
||||
payload = {
|
||||
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
@ -278,10 +268,9 @@ class TestPluginChangePermissionApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", user)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
@ -293,10 +282,9 @@ class TestPluginFetchPermissionApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=None),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["install_permission"] is not None
|
||||
|
||||
@ -308,13 +296,12 @@ class TestPluginFetchDynamicSelectOptionsApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?plugin_id=p&provider=x&action=y¶meter=z&provider_type=tool"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options",
|
||||
return_value=[1, 2],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", user)
|
||||
|
||||
assert result["options"] == [1, 2]
|
||||
|
||||
@ -326,16 +313,15 @@ class TestPluginReadmeApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?plugin_unique_identifier=p"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_readme", return_value="readme"),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["readme"] == "readme"
|
||||
|
||||
|
||||
class TestPluginListInstallationsFromIdsApi:
|
||||
def test_success(self, app: Flask):
|
||||
def test_success(self, app: Flask, user):
|
||||
api = PluginListInstallationsFromIdsApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -343,13 +329,12 @@ class TestPluginListInstallationsFromIdsApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.list_installations_from_ids",
|
||||
return_value=[{"id": "p1"}],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert "plugins" in result
|
||||
|
||||
@ -361,18 +346,17 @@ class TestPluginListInstallationsFromIdsApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.list_installations_from_ids",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
class TestPluginUploadFromGithubApi:
|
||||
def test_success(self, app: Flask):
|
||||
def test_success(self, app: Flask, user):
|
||||
api = PluginUploadFromGithubApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -380,12 +364,11 @@ class TestPluginUploadFromGithubApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.upload_pkg_from_github", return_value={"ok": True}
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["ok"] is True
|
||||
|
||||
@ -397,13 +380,12 @@ class TestPluginUploadFromGithubApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.upload_pkg_from_github",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -424,10 +406,9 @@ class TestPluginUploadFromBundleApi:
|
||||
data={"bundle": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.upload_bundle", return_value={"ok": True}),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["ok"] is True
|
||||
|
||||
@ -447,12 +428,11 @@ class TestPluginUploadFromBundleApi:
|
||||
data={"bundle": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0),
|
||||
patch("controllers.console.workspace.plugin.PluginService.upload_bundle") as upload_bundle_mock,
|
||||
):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
method(api)
|
||||
method(api, "t1")
|
||||
assert "File size exceeds the maximum allowed size" in str(exc_info.value)
|
||||
|
||||
upload_bundle_mock.assert_not_called()
|
||||
@ -472,10 +452,9 @@ class TestPluginInstallFromGithubApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.install_from_github", return_value={"ok": True}),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["ok"] is True
|
||||
|
||||
@ -492,13 +471,12 @@ class TestPluginInstallFromGithubApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.install_from_github",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -511,13 +489,12 @@ class TestPluginInstallFromMarketplaceApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["ok"] is True
|
||||
|
||||
@ -529,13 +506,12 @@ class TestPluginInstallFromMarketplaceApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -546,10 +522,9 @@ class TestPluginFetchMarketplacePkgApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?plugin_unique_identifier=p"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg", return_value={"m": 1}),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert "manifest" in result
|
||||
|
||||
@ -559,13 +534,12 @@ class TestPluginFetchMarketplacePkgApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?plugin_unique_identifier=p"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -579,10 +553,9 @@ class TestPluginFetchManifestApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?plugin_unique_identifier=p"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest", return_value=manifest),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert "manifest" in result
|
||||
|
||||
@ -592,13 +565,12 @@ class TestPluginFetchManifestApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?plugin_unique_identifier=p"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -609,10 +581,9 @@ class TestPluginFetchInstallTasksApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.fetch_install_tasks", return_value=[{"id": 1}]),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert "tasks" in result
|
||||
|
||||
@ -622,13 +593,12 @@ class TestPluginFetchInstallTasksApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.fetch_install_tasks",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -639,10 +609,9 @@ class TestPluginFetchInstallTaskApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.fetch_install_task", return_value={"id": "x"}),
|
||||
):
|
||||
result = method(api, "x")
|
||||
result = method(api, "t1", "x")
|
||||
|
||||
assert "task" in result
|
||||
|
||||
@ -652,13 +621,12 @@ class TestPluginFetchInstallTaskApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.fetch_install_task",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api, "t")
|
||||
result = method(api, "t1", "t")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -669,10 +637,9 @@ class TestPluginDeleteInstallTaskApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.delete_install_task", return_value=True),
|
||||
):
|
||||
result = method(api, "x")
|
||||
result = method(api, "t1", "x")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
@ -682,13 +649,12 @@ class TestPluginDeleteInstallTaskApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.delete_install_task",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api, "t")
|
||||
result = method(api, "t1", "t")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -699,12 +665,11 @@ class TestPluginDeleteAllInstallTaskItemsApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.delete_all_install_task_items", return_value=True
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
@ -714,13 +679,12 @@ class TestPluginDeleteAllInstallTaskItemsApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.delete_all_install_task_items",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -731,10 +695,9 @@ class TestPluginDeleteInstallTaskItemApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginService.delete_install_task_item", return_value=True),
|
||||
):
|
||||
result = method(api, "task1", "item1")
|
||||
result = method(api, "t1", "task1", "item1")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
@ -744,18 +707,17 @@ class TestPluginDeleteInstallTaskItemApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.delete_install_task_item",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api, "task1", "item1")
|
||||
result = method(api, "t1", "task1", "item1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
class TestPluginUpgradeFromMarketplaceApi:
|
||||
def test_success(self, app: Flask):
|
||||
def test_success(self, app: Flask, user):
|
||||
api = PluginUpgradeFromMarketplaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@ -766,13 +728,12 @@ class TestPluginUpgradeFromMarketplaceApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["ok"] is True
|
||||
|
||||
@ -787,13 +748,12 @@ class TestPluginUpgradeFromMarketplaceApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -812,13 +772,12 @@ class TestPluginUpgradeFromGithubApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["ok"] is True
|
||||
|
||||
@ -836,23 +795,20 @@ class TestPluginUpgradeFromGithubApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
|
||||
def test_success(self, app: Flask):
|
||||
def test_success(self, app: Flask, user):
|
||||
api = PluginFetchDynamicSelectOptionsWithCredentialsApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(id="u1", is_admin_or_owner=True)
|
||||
|
||||
payload = {
|
||||
"plugin_id": "p",
|
||||
"provider": "x",
|
||||
@ -864,22 +820,19 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials",
|
||||
return_value=[1],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", user)
|
||||
|
||||
assert result["options"] == [1]
|
||||
|
||||
def test_daemon_error(self, app: Flask):
|
||||
def test_daemon_error(self, app: Flask, user):
|
||||
api = PluginFetchDynamicSelectOptionsWithCredentialsApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(id="u1", is_admin_or_owner=True)
|
||||
|
||||
payload = {
|
||||
"plugin_id": "p",
|
||||
"provider": "x",
|
||||
@ -891,13 +844,12 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials",
|
||||
side_effect=PluginDaemonClientSideError("error"),
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", user)
|
||||
assert result == ({"code": "plugin_error", "message": "error"}, 400)
|
||||
|
||||
|
||||
@ -906,7 +858,7 @@ class TestPluginChangePreferencesApi:
|
||||
api = PluginChangePreferencesApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_admin_or_owner=True)
|
||||
user = _account()
|
||||
|
||||
payload = {
|
||||
"permission": {
|
||||
@ -924,11 +876,10 @@ class TestPluginChangePreferencesApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True),
|
||||
patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.change_strategy", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", user)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
@ -936,7 +887,7 @@ class TestPluginChangePreferencesApi:
|
||||
api = PluginChangePreferencesApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_admin_or_owner=True)
|
||||
user = _account()
|
||||
|
||||
payload = {
|
||||
"permission": {
|
||||
@ -954,10 +905,9 @@ class TestPluginChangePreferencesApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", user)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
@ -982,7 +932,6 @@ class TestPluginFetchPreferencesApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=permission
|
||||
),
|
||||
@ -990,7 +939,7 @@ class TestPluginFetchPreferencesApi:
|
||||
"controllers.console.workspace.plugin.PluginAutoUpgradeService.get_strategy", return_value=auto_upgrade
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert "permission" in result
|
||||
assert "auto_upgrade" in result
|
||||
@ -1005,10 +954,9 @@ class TestPluginAutoUpgradeExcludePluginApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
@ -1020,9 +968,8 @@ class TestPluginAutoUpgradeExcludePluginApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")),
|
||||
patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
@ -5,15 +6,11 @@ import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.workspace import snippets as snippets_module
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.snippet_dsl_service import ImportStatus, SnippetImportInfo
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_snippet_service_factory(monkeypatch):
|
||||
def factory():
|
||||
@ -34,6 +31,26 @@ class _SessionContext:
|
||||
return False
|
||||
|
||||
|
||||
def _account(account_id: str = "account-1") -> Account:
|
||||
account = Account(name="Test User", email=f"{account_id}@example.com")
|
||||
account.id = account_id
|
||||
account.role = TenantAccountRole.EDITOR
|
||||
return account
|
||||
|
||||
|
||||
def _snippet(**overrides) -> CustomizedSnippet:
|
||||
data = {
|
||||
"id": "snippet-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"name": "Snippet",
|
||||
"description": "Description",
|
||||
"type": snippets_module.SnippetType.NODE,
|
||||
"created_by": "account-1",
|
||||
}
|
||||
data.update(overrides)
|
||||
return CustomizedSnippet(**data)
|
||||
|
||||
|
||||
def test_normalize_snippet_list_query_args_sorts_indexed_values():
|
||||
query_args = snippets_module.MultiDict(
|
||||
[
|
||||
@ -53,21 +70,19 @@ def test_normalize_snippet_list_query_args_sorts_indexed_values():
|
||||
|
||||
|
||||
def test_list_snippets_returns_pagination(app, monkeypatch):
|
||||
user = SimpleNamespace(id="account-1")
|
||||
snippets = [SimpleNamespace(id="snippet-1")]
|
||||
snippets = [_snippet()]
|
||||
tag_id = "11111111-1111-1111-1111-111111111111"
|
||||
get_snippets = Mock(return_value=(snippets, 1, False))
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippets", get_snippets)
|
||||
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value=[{"id": "snippet-1"}]))
|
||||
|
||||
api = snippets_module.CustomizedSnippetsApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context(
|
||||
f"/workspaces/current/customized-snippets?page=2&limit=10&tag_ids[0]={tag_id}&creator_ids[0]=account-2"
|
||||
):
|
||||
response, status_code = handler(api)
|
||||
response, status_code = handler(api, "tenant-1")
|
||||
|
||||
assert status_code == 200
|
||||
assert response == {
|
||||
@ -89,10 +104,9 @@ def test_list_snippets_returns_pagination(app, monkeypatch):
|
||||
|
||||
|
||||
def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypatch):
|
||||
user = SimpleNamespace(id="account-1")
|
||||
snippet = SimpleNamespace(id="snippet-1")
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
create_snippet = Mock(return_value=snippet)
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet)
|
||||
monkeypatch.setattr(
|
||||
snippets_module.CreateSnippetPayload,
|
||||
@ -111,14 +125,14 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypat
|
||||
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"}))
|
||||
|
||||
api = snippets_module.CustomizedSnippetsApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/customized-snippets",
|
||||
method="POST",
|
||||
json={"name": "Snippet", "type": "node", "description": "Description"},
|
||||
):
|
||||
response, status_code = handler(api)
|
||||
response, status_code = handler(api, "tenant-1", user)
|
||||
|
||||
assert status_code == 201
|
||||
assert response == {"id": "snippet-1"}
|
||||
@ -126,13 +140,12 @@ def test_create_snippet_defaults_unknown_type_and_returns_created(app, monkeypat
|
||||
|
||||
|
||||
def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch):
|
||||
user = SimpleNamespace(id="account-1")
|
||||
user = _account("account-1")
|
||||
create_snippet = Mock()
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "create_snippet", create_snippet)
|
||||
|
||||
api = snippets_module.CustomizedSnippetsApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/customized-snippets",
|
||||
@ -148,7 +161,7 @@ def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch):
|
||||
},
|
||||
},
|
||||
):
|
||||
response, status_code = handler(api)
|
||||
response, status_code = handler(api, "tenant-1", user)
|
||||
|
||||
assert status_code == 400
|
||||
assert "knowledge-retrieval" in response["message"]
|
||||
@ -156,60 +169,54 @@ def test_create_snippet_rejects_forbidden_nodes(app, monkeypatch):
|
||||
|
||||
|
||||
def test_get_snippet_detail_raises_when_missing(app, monkeypatch):
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
|
||||
|
||||
api = snippets_module.CustomizedSnippetDetailApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1"):
|
||||
with pytest.raises(NotFound, match="Snippet not found"):
|
||||
handler(api, snippet_id="snippet-1")
|
||||
handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
|
||||
def test_get_snippet_detail_returns_snippet(app, monkeypatch):
|
||||
snippet = SimpleNamespace(id="snippet-1")
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
snippet = _snippet()
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1"}))
|
||||
|
||||
api = snippets_module.CustomizedSnippetDetailApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1"):
|
||||
response, status_code = handler(api, snippet_id="snippet-1")
|
||||
response, status_code = handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
assert status_code == 200
|
||||
assert response == {"id": "snippet-1"}
|
||||
|
||||
|
||||
def test_patch_snippet_returns_400_for_empty_payload(app, monkeypatch):
|
||||
snippet = SimpleNamespace(id="snippet-1")
|
||||
monkeypatch.setattr(
|
||||
snippets_module,
|
||||
"current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
snippet = _snippet()
|
||||
user = _account("user-1")
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
|
||||
api = snippets_module.CustomizedSnippetDetailApi()
|
||||
handler = _unwrap(api.patch)
|
||||
handler = unwrap(api.patch)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/customized-snippets/snippet-1",
|
||||
method="PATCH",
|
||||
json={},
|
||||
):
|
||||
response, status_code = handler(api, snippet_id="snippet-1")
|
||||
response, status_code = handler(api, "tenant-1", user, snippet_id="snippet-1")
|
||||
|
||||
assert status_code == 400
|
||||
assert response == {"message": "No valid fields to update"}
|
||||
|
||||
|
||||
def test_patch_snippet_updates_and_commits(app, monkeypatch):
|
||||
user = SimpleNamespace(id="account-1")
|
||||
snippet = SimpleNamespace(id="snippet-1")
|
||||
updated_snippet = SimpleNamespace(id="snippet-1", name="New")
|
||||
user = _account("account-1")
|
||||
snippet = _snippet()
|
||||
updated_snippet = _snippet(name="New")
|
||||
session = SimpleNamespace(merge=Mock(return_value=snippet), commit=Mock())
|
||||
update_snippet = Mock(return_value=updated_snippet)
|
||||
|
||||
@ -217,7 +224,6 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch):
|
||||
def __init__(self, engine, *args, **kwargs):
|
||||
super().__init__(engine, *args, session=session, **kwargs)
|
||||
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "update_snippet", update_snippet)
|
||||
monkeypatch.setattr(snippets_module, "Session", SessionContext)
|
||||
@ -225,14 +231,14 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch):
|
||||
monkeypatch.setattr(snippets_module, "marshal", Mock(return_value={"id": "snippet-1", "name": "New"}))
|
||||
|
||||
api = snippets_module.CustomizedSnippetDetailApi()
|
||||
handler = _unwrap(api.patch)
|
||||
handler = unwrap(api.patch)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/customized-snippets/snippet-1",
|
||||
method="PATCH",
|
||||
json={"name": "New", "icon_info": {"icon": "star"}},
|
||||
):
|
||||
response, status_code = handler(api, snippet_id="snippet-1")
|
||||
response, status_code = handler(api, "tenant-1", user, snippet_id="snippet-1")
|
||||
|
||||
assert status_code == 200
|
||||
assert response == {"id": "snippet-1", "name": "New"}
|
||||
@ -245,7 +251,7 @@ def test_patch_snippet_updates_and_commits(app, monkeypatch):
|
||||
|
||||
|
||||
def test_delete_snippet_deletes_and_commits(app, monkeypatch):
|
||||
snippet = SimpleNamespace(id="snippet-1")
|
||||
snippet = _snippet()
|
||||
session = SimpleNamespace(merge=Mock(return_value=snippet), commit=Mock())
|
||||
delete_snippet = Mock()
|
||||
|
||||
@ -253,17 +259,16 @@ def test_delete_snippet_deletes_and_commits(app, monkeypatch):
|
||||
def __init__(self, engine, *args, **kwargs):
|
||||
super().__init__(engine, *args, session=session, **kwargs)
|
||||
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "delete_snippet", delete_snippet)
|
||||
monkeypatch.setattr(snippets_module, "Session", SessionContext)
|
||||
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = snippets_module.CustomizedSnippetDetailApi()
|
||||
handler = _unwrap(api.delete)
|
||||
handler = unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1", method="DELETE"):
|
||||
response, status_code = handler(api, snippet_id="snippet-1")
|
||||
response, status_code = handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
assert status_code == 204
|
||||
assert response == ""
|
||||
@ -272,7 +277,7 @@ def test_delete_snippet_deletes_and_commits(app, monkeypatch):
|
||||
|
||||
|
||||
def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
|
||||
snippet = SimpleNamespace(id="snippet-1", name="Snippet One")
|
||||
snippet = _snippet(name="Snippet One")
|
||||
export_snippet_dsl = Mock(return_value="version: 0.1.0\nkind: snippet\n")
|
||||
session = SimpleNamespace()
|
||||
|
||||
@ -280,7 +285,6 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
|
||||
def __init__(self, engine, *args, **kwargs):
|
||||
super().__init__(engine, *args, session=session, **kwargs)
|
||||
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
monkeypatch.setattr(
|
||||
snippets_module,
|
||||
@ -291,10 +295,10 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
|
||||
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = snippets_module.CustomizedSnippetExportApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1/export?include_secret=true"):
|
||||
response = handler(api, snippet_id="snippet-1")
|
||||
response = handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_data(as_text=True) == "version: 0.1.0\nkind: snippet\n"
|
||||
@ -304,7 +308,7 @@ def test_export_snippet_returns_yaml_attachment(app, monkeypatch):
|
||||
|
||||
|
||||
def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
|
||||
user = SimpleNamespace(id="account-1")
|
||||
user = _account("account-1")
|
||||
result = SnippetImportInfo(id="import-1", status=ImportStatus.PENDING, imported_dsl_version="999.0.0")
|
||||
import_snippet = Mock(return_value=result)
|
||||
session = SimpleNamespace(commit=Mock())
|
||||
@ -319,7 +323,6 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module, "Session", _SessionContext)
|
||||
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
@ -329,14 +332,14 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
|
||||
)
|
||||
|
||||
api = snippets_module.CustomizedSnippetImportApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/customized-snippets/imports",
|
||||
method="POST",
|
||||
json={"mode": "yaml-content", "yaml_content": "kind: snippet"},
|
||||
):
|
||||
response, status_code = handler(api)
|
||||
response, status_code = handler(api, user)
|
||||
|
||||
assert status_code == 202
|
||||
assert response["status"] == ImportStatus.PENDING.value
|
||||
@ -345,7 +348,7 @@ def test_import_snippet_returns_202_for_pending_confirmation(app, monkeypatch):
|
||||
|
||||
|
||||
def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
|
||||
user = SimpleNamespace(id="account-1")
|
||||
user = _account("account-1")
|
||||
result = SnippetImportInfo(id="import-1", status=ImportStatus.FAILED, error="Invalid DSL")
|
||||
import_snippet = Mock(return_value=result)
|
||||
session = SimpleNamespace(commit=Mock())
|
||||
@ -354,7 +357,6 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
|
||||
def __init__(self, engine, *args, **kwargs):
|
||||
super().__init__(engine, *args, session=session, **kwargs)
|
||||
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module, "Session", SessionContext)
|
||||
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
@ -364,14 +366,14 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
|
||||
)
|
||||
|
||||
api = snippets_module.CustomizedSnippetImportApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/customized-snippets/imports",
|
||||
method="POST",
|
||||
json={"mode": "yaml-content", "yaml_content": "kind: snippet"},
|
||||
):
|
||||
response, status_code = handler(api)
|
||||
response, status_code = handler(api, user)
|
||||
|
||||
assert status_code == 400
|
||||
assert response["error"] == "Invalid DSL"
|
||||
@ -379,7 +381,7 @@ def test_import_snippet_returns_400_for_failed_import(app, monkeypatch):
|
||||
|
||||
|
||||
def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
|
||||
user = SimpleNamespace(id="account-1")
|
||||
user = _account("account-1")
|
||||
result = SnippetImportInfo(id="import-1", status=ImportStatus.COMPLETED, snippet_id="snippet-1")
|
||||
confirm_import = Mock(return_value=result)
|
||||
session = SimpleNamespace(commit=Mock())
|
||||
@ -388,7 +390,6 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
|
||||
def __init__(self, engine, *args, **kwargs):
|
||||
super().__init__(engine, *args, session=session, **kwargs)
|
||||
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (user, "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module, "Session", SessionContext)
|
||||
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
@ -398,13 +399,13 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
|
||||
)
|
||||
|
||||
api = snippets_module.CustomizedSnippetImportConfirmApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/customized-snippets/imports/import-1/confirm",
|
||||
method="POST",
|
||||
):
|
||||
response, status_code = handler(api, import_id="import-1")
|
||||
response, status_code = handler(api, user, import_id="import-1")
|
||||
|
||||
assert status_code == 200
|
||||
assert response["snippet_id"] == "snippet-1"
|
||||
@ -413,19 +414,18 @@ def test_import_confirm_returns_200_for_completed_import(app, monkeypatch):
|
||||
|
||||
|
||||
def test_check_dependencies_raises_when_snippet_missing(app, monkeypatch):
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
|
||||
|
||||
api = snippets_module.CustomizedSnippetCheckDependenciesApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1/check-dependencies"):
|
||||
with pytest.raises(NotFound, match="Snippet not found"):
|
||||
handler(api, snippet_id="snippet-1")
|
||||
handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
|
||||
def test_check_dependencies_returns_dependency_result(app, monkeypatch):
|
||||
snippet = SimpleNamespace(id="snippet-1")
|
||||
snippet = _snippet()
|
||||
check_dependencies = Mock(
|
||||
return_value=SimpleNamespace(model_dump=Mock(return_value={"dependencies": [], "missing_dependencies": []}))
|
||||
)
|
||||
@ -435,7 +435,6 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch):
|
||||
def __init__(self, engine, *args, **kwargs):
|
||||
super().__init__(engine, *args, session=session, **kwargs)
|
||||
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
monkeypatch.setattr(snippets_module, "Session", SessionContext)
|
||||
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
|
||||
@ -446,10 +445,10 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch):
|
||||
)
|
||||
|
||||
api = snippets_module.CustomizedSnippetCheckDependenciesApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/workspaces/current/customized-snippets/snippet-1/check-dependencies"):
|
||||
response, status_code = handler(api, snippet_id="snippet-1")
|
||||
response, status_code = handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
assert status_code == 200
|
||||
assert response == {"dependencies": [], "missing_dependencies": []}
|
||||
@ -457,18 +456,17 @@ def test_check_dependencies_returns_dependency_result(app, monkeypatch):
|
||||
|
||||
|
||||
def test_increment_use_count_raises_when_snippet_missing(app, monkeypatch):
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=None))
|
||||
|
||||
api = snippets_module.CustomizedSnippetUseCountIncrementApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/customized-snippets/snippet-1/use-count/increment",
|
||||
method="POST",
|
||||
):
|
||||
with pytest.raises(NotFound, match="Snippet not found"):
|
||||
handler(api, snippet_id="snippet-1")
|
||||
handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
|
||||
def test_increment_use_count_returns_refreshed_count(app, monkeypatch):
|
||||
@ -487,20 +485,19 @@ def test_increment_use_count_returns_refreshed_count(app, monkeypatch):
|
||||
return False
|
||||
|
||||
increment_use_count = Mock()
|
||||
monkeypatch.setattr(snippets_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "get_snippet_by_id", Mock(return_value=snippet))
|
||||
monkeypatch.setattr(snippets_module.SnippetService, "increment_use_count", increment_use_count)
|
||||
monkeypatch.setattr(snippets_module, "Session", _SessionContext)
|
||||
monkeypatch.setattr(snippets_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = snippets_module.CustomizedSnippetUseCountIncrementApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/customized-snippets/snippet-1/use-count/increment",
|
||||
method="POST",
|
||||
):
|
||||
response, status_code = handler(api, snippet_id="snippet-1")
|
||||
response, status_code = handler(api, "tenant-1", snippet_id="snippet-1")
|
||||
|
||||
assert status_code == 200
|
||||
assert response == {"result": "success", "use_count": 3}
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import builtins
|
||||
import importlib
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from inspect import unwrap
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -12,12 +13,14 @@ import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from models import Account
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
_CONTROLLER_MODULE: ModuleType | None = None
|
||||
_WRAPS_MODULE: ModuleType | None = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
@ -69,9 +72,7 @@ def controller_module(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload)
|
||||
|
||||
# Ensure decorators that consult deployment edition do not reach the database.
|
||||
global _WRAPS_MODULE
|
||||
wraps_module = importlib.import_module("controllers.console.wraps")
|
||||
_WRAPS_MODULE = wraps_module
|
||||
monkeypatch.setattr(module.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.setattr(wraps_module.dify_config, "EDITION", "CLOUD")
|
||||
|
||||
@ -80,46 +81,24 @@ def controller_module(monkeypatch: pytest.MonkeyPatch):
|
||||
return module
|
||||
|
||||
|
||||
def _mock_account(user_id: str = "user-123") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=user_id,
|
||||
status="active",
|
||||
is_authenticated=True,
|
||||
current_tenant_id=None,
|
||||
is_admin_or_owner=False,
|
||||
)
|
||||
|
||||
|
||||
def _set_current_account(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
controller_module: ModuleType,
|
||||
user: SimpleNamespace,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
def _getter():
|
||||
return user, tenant_id
|
||||
|
||||
user.current_tenant_id = tenant_id
|
||||
|
||||
monkeypatch.setattr(controller_module, "current_account_with_tenant", _getter)
|
||||
if _WRAPS_MODULE is not None:
|
||||
monkeypatch.setattr(_WRAPS_MODULE, "current_account_with_tenant", _getter)
|
||||
|
||||
login_module = importlib.import_module("libs.login")
|
||||
monkeypatch.setattr(login_module, "_get_user", lambda: user)
|
||||
def _mock_account(user_id: str = "user-123") -> Account:
|
||||
user = Account(name="Test User", email=f"{user_id}@example.com")
|
||||
user.id = user_id
|
||||
user.role = TenantAccountRole.NORMAL
|
||||
return user
|
||||
|
||||
|
||||
def test_tool_provider_list_calls_service_with_query(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
|
||||
|
||||
service_mock = MagicMock(return_value=[{"provider": "builtin"}])
|
||||
monkeypatch.setattr(controller_module.ToolCommonService, "list_tool_providers", service_mock)
|
||||
|
||||
with app.test_request_context("/workspaces/current/tool-providers?type=builtin"):
|
||||
response = controller_module.ToolProviderListApi().get()
|
||||
api = controller_module.ToolProviderListApi()
|
||||
response = unwrap(api.get)(api, "tenant-456", user)
|
||||
|
||||
assert response == [{"provider": "builtin"}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-456", "builtin")
|
||||
@ -129,7 +108,6 @@ def test_builtin_provider_add_passes_payload(
|
||||
app: Flask, controller_module: ModuleType, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-456")
|
||||
|
||||
service_mock = MagicMock(return_value={"status": "ok"})
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "add_builtin_tool_provider", service_mock)
|
||||
@ -145,7 +123,8 @@ def test_builtin_provider_add_passes_payload(
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
response = controller_module.ToolBuiltinProviderAddApi().post(provider="openai")
|
||||
api = controller_module.ToolBuiltinProviderAddApi()
|
||||
response = unwrap(api.post)(api, "tenant-456", user, provider="openai")
|
||||
|
||||
assert response == {"status": "ok"}
|
||||
service_mock.assert_called_once_with(
|
||||
@ -161,7 +140,6 @@ def test_builtin_provider_add_passes_payload(
|
||||
|
||||
def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-789")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-789")
|
||||
|
||||
service_mock = MagicMock(return_value=[{"name": "tool-a"}])
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "list_builtin_tool_provider_tools", service_mock)
|
||||
@ -171,7 +149,8 @@ def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch:
|
||||
"/workspaces/current/tool-provider/builtin/my-provider/tools",
|
||||
method="GET",
|
||||
):
|
||||
response = controller_module.ToolBuiltinProviderListToolsApi().get(provider="my-provider")
|
||||
api = controller_module.ToolBuiltinProviderListToolsApi()
|
||||
response = unwrap(api.get)(api, "tenant-789", provider="my-provider")
|
||||
|
||||
assert response == [{"name": "tool-a"}]
|
||||
service_mock.assert_called_once_with("tenant-789", "my-provider")
|
||||
@ -179,12 +158,12 @@ def test_builtin_provider_tools_get(app: Flask, controller_module, monkeypatch:
|
||||
|
||||
def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-9")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-9")
|
||||
service_mock = MagicMock(return_value={"info": True})
|
||||
monkeypatch.setattr(controller_module.BuiltinToolManageService, "get_builtin_tool_provider_info", service_mock)
|
||||
|
||||
with app.test_request_context("/info", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderInfoApi().get(provider="demo")
|
||||
api = controller_module.ToolBuiltinProviderInfoApi()
|
||||
resp = unwrap(api.get)(api, "tenant-9", provider="demo")
|
||||
|
||||
assert resp == {"info": True}
|
||||
service_mock.assert_called_once_with("tenant-9", "demo")
|
||||
@ -192,7 +171,6 @@ def test_builtin_provider_info_get(app: Flask, controller_module, monkeypatch: p
|
||||
|
||||
def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-cred")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-cred")
|
||||
service_mock = MagicMock(return_value=[{"cred": 1}])
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
@ -201,7 +179,8 @@ def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeyp
|
||||
)
|
||||
|
||||
with app.test_request_context("/creds", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderGetCredentialsApi().get(provider="demo")
|
||||
api = controller_module.ToolBuiltinProviderGetCredentialsApi()
|
||||
resp = unwrap(api.get)(api, "tenant-cred", user, provider="demo")
|
||||
|
||||
assert resp == [{"cred": 1}]
|
||||
service_mock.assert_called_once_with(
|
||||
@ -214,12 +193,12 @@ def test_builtin_provider_credentials_get(app: Flask, controller_module, monkeyp
|
||||
|
||||
def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-10")
|
||||
service_mock = MagicMock(return_value={"schema": "ok"})
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider_remote_schema", service_mock)
|
||||
|
||||
with app.test_request_context("/remote?url=https://example.com/"):
|
||||
resp = controller_module.ToolApiProviderGetRemoteSchemaApi().get()
|
||||
api = controller_module.ToolApiProviderGetRemoteSchemaApi()
|
||||
resp = unwrap(api.get)(api, "tenant-10", user)
|
||||
|
||||
assert resp == {"schema": "ok"}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-10", "https://example.com/")
|
||||
@ -227,12 +206,12 @@ def test_api_provider_remote_schema_get(app: Flask, controller_module, monkeypat
|
||||
|
||||
def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-11")
|
||||
service_mock = MagicMock(return_value=[{"tool": "t"}])
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "list_api_tool_provider_tools", service_mock)
|
||||
|
||||
with app.test_request_context("/tools?provider=foo"):
|
||||
resp = controller_module.ToolApiProviderListToolsApi().get()
|
||||
api = controller_module.ToolApiProviderListToolsApi()
|
||||
resp = unwrap(api.get)(api, "tenant-11", user)
|
||||
|
||||
assert resp == [{"tool": "t"}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-11", "foo")
|
||||
@ -240,12 +219,12 @@ def test_api_provider_list_tools_get(app: Flask, controller_module, monkeypatch:
|
||||
|
||||
def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-12")
|
||||
service_mock = MagicMock(return_value={"provider": "foo"})
|
||||
monkeypatch.setattr(controller_module.ApiToolManageService, "get_api_tool_provider", service_mock)
|
||||
|
||||
with app.test_request_context("/get?provider=foo"):
|
||||
resp = controller_module.ToolApiProviderGetApi().get()
|
||||
api = controller_module.ToolApiProviderGetApi()
|
||||
resp = unwrap(api.get)(api, "tenant-12", user)
|
||||
|
||||
assert resp == {"provider": "foo"}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-12", "foo")
|
||||
@ -253,7 +232,6 @@ def test_api_provider_get(app: Flask, controller_module, monkeypatch: pytest.Mon
|
||||
|
||||
def test_builtin_provider_credentials_schema_get(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-13")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-13")
|
||||
service_mock = MagicMock(return_value={"schema": True})
|
||||
monkeypatch.setattr(
|
||||
controller_module.BuiltinToolManageService,
|
||||
@ -262,9 +240,8 @@ def test_builtin_provider_credentials_schema_get(app: Flask, controller_module,
|
||||
)
|
||||
|
||||
with app.test_request_context("/schema", method="GET"):
|
||||
resp = controller_module.ToolBuiltinProviderCredentialsSchemaApi().get(
|
||||
provider="demo", credential_type="api-key"
|
||||
)
|
||||
api = controller_module.ToolBuiltinProviderCredentialsSchemaApi()
|
||||
resp = unwrap(api.get)(api, "tenant-13", provider="demo", credential_type="api-key")
|
||||
|
||||
assert resp == {"schema": True}
|
||||
service_mock.assert_called_once()
|
||||
@ -272,7 +249,6 @@ def test_builtin_provider_credentials_schema_get(app: Flask, controller_module,
|
||||
|
||||
def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf")
|
||||
tool_service = MagicMock(return_value={"wf": 1})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
@ -282,7 +258,8 @@ def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatc
|
||||
|
||||
tool_id = "00000000-0000-0000-0000-000000000001"
|
||||
with app.test_request_context(f"/workflow?workflow_tool_id={tool_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderGetApi().get()
|
||||
api = controller_module.ToolWorkflowProviderGetApi()
|
||||
resp = unwrap(api.get)(api, "tenant-wf", user)
|
||||
|
||||
assert resp == {"wf": 1}
|
||||
tool_service.assert_called_once_with(user.id, "tenant-wf", tool_id)
|
||||
@ -290,7 +267,6 @@ def test_workflow_provider_get_by_tool(app: Flask, controller_module, monkeypatc
|
||||
|
||||
def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf2")
|
||||
service_mock = MagicMock(return_value={"app": 1})
|
||||
monkeypatch.setattr(
|
||||
controller_module.WorkflowToolManageService,
|
||||
@ -300,7 +276,8 @@ def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch
|
||||
|
||||
app_id = "00000000-0000-0000-0000-000000000002"
|
||||
with app.test_request_context(f"/workflow?workflow_app_id={app_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderGetApi().get()
|
||||
api = controller_module.ToolWorkflowProviderGetApi()
|
||||
resp = unwrap(api.get)(api, "tenant-wf2", user)
|
||||
|
||||
assert resp == {"app": 1}
|
||||
service_mock.assert_called_once_with(user.id, "tenant-wf2", app_id)
|
||||
@ -308,13 +285,13 @@ def test_workflow_provider_get_by_app(app: Flask, controller_module, monkeypatch
|
||||
|
||||
def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf3")
|
||||
service_mock = MagicMock(return_value=[{"id": 1}])
|
||||
monkeypatch.setattr(controller_module.WorkflowToolManageService, "list_single_workflow_tools", service_mock)
|
||||
|
||||
tool_id = "00000000-0000-0000-0000-000000000003"
|
||||
with app.test_request_context(f"/workflow/tools?workflow_tool_id={tool_id}"):
|
||||
resp = controller_module.ToolWorkflowProviderListToolApi().get()
|
||||
api = controller_module.ToolWorkflowProviderListToolApi()
|
||||
resp = unwrap(api.get)(api, "tenant-wf3", user)
|
||||
|
||||
assert resp == [{"id": 1}]
|
||||
service_mock.assert_called_once_with(user.id, "tenant-wf3", tool_id)
|
||||
@ -322,7 +299,6 @@ def test_workflow_provider_list_tools(app: Flask, controller_module, monkeypatch
|
||||
|
||||
def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-bt")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "builtin"})
|
||||
monkeypatch.setattr(
|
||||
@ -332,14 +308,14 @@ def test_builtin_tools_list(app: Flask, controller_module, monkeypatch: pytest.M
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/builtin"):
|
||||
resp = controller_module.ToolBuiltinListApi().get()
|
||||
api = controller_module.ToolBuiltinListApi()
|
||||
resp = unwrap(api.get)(api, "tenant-bt", user)
|
||||
|
||||
assert resp == [{"name": "builtin"}]
|
||||
|
||||
|
||||
def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-tenant-api")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-api")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "api"})
|
||||
monkeypatch.setattr(
|
||||
@ -349,14 +325,14 @@ def test_api_tools_list(app: Flask, controller_module, monkeypatch: pytest.Monke
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/api"):
|
||||
resp = controller_module.ToolApiListApi().get()
|
||||
api = controller_module.ToolApiListApi()
|
||||
resp = unwrap(api.get)(api, "tenant-api")
|
||||
|
||||
assert resp == [{"name": "api"}]
|
||||
|
||||
|
||||
def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account()
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-wf4")
|
||||
|
||||
provider = SimpleNamespace(to_dict=lambda: {"name": "wf"})
|
||||
monkeypatch.setattr(
|
||||
@ -366,18 +342,18 @@ def test_workflow_tools_list(app: Flask, controller_module, monkeypatch: pytest.
|
||||
)
|
||||
|
||||
with app.test_request_context("/tools/workflow"):
|
||||
resp = controller_module.ToolWorkflowListApi().get()
|
||||
api = controller_module.ToolWorkflowListApi()
|
||||
resp = unwrap(api.get)(api, "tenant-wf4", user)
|
||||
|
||||
assert resp == [{"name": "wf"}]
|
||||
|
||||
|
||||
def test_tool_labels_list(app: Flask, controller_module, monkeypatch: pytest.MonkeyPatch):
|
||||
user = _mock_account("user-label")
|
||||
_set_current_account(monkeypatch, controller_module, user, "tenant-labels")
|
||||
monkeypatch.setattr(controller_module.ToolLabelsService, "list_tool_labels", lambda: ["a", "b"])
|
||||
|
||||
with app.test_request_context("/tool-labels"):
|
||||
resp = controller_module.ToolLabelsApi().get()
|
||||
api = controller_module.ToolLabelsApi()
|
||||
resp = unwrap(api.get)(api)
|
||||
|
||||
assert resp == ["a", "b"]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user