diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 78643ad076..e5eabf6c49 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -23,6 +23,7 @@ from controllers.console.wraps import ( account_initialization_required, edit_permission_required, setup_required, + with_current_user, with_current_user_id, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError @@ -36,7 +37,7 @@ from core.helper.trace_id_helper import get_external_trace_id from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value -from libs.login import current_user, login_required +from libs.login import login_required from models import Account from models.model import App, AppMode from services.app_generate_service import AppGenerateService @@ -104,7 +105,8 @@ class CompletionMessageApi(Resource): @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - def post(self, app_model: App): + @with_current_user + def post(self, current_user: Account, app_model: App): args_model = CompletionMessagePayload.model_validate(console_ns.payload) args = args_model.model_dump(exclude_none=True, by_alias=True) @@ -112,8 +114,6 @@ class CompletionMessageApi(Resource): args["auto_generate_name"] = False try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account or EndUser instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) @@ -178,7 +178,8 @@ class ChatMessageApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT]) @edit_permission_required - def post(self, app_model: App): + @with_current_user + def post(self, current_user: Account, app_model: App): raw_payload = console_ns.payload or {} args_model = ChatMessagePayload.model_validate(raw_payload) args = args_model.model_dump(exclude_none=True, by_alias=True) @@ -197,8 +198,6 @@ class ChatMessageApi(Resource): args["external_trace_id"] = external_trace_id try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account or EndUser instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py index 27009953b7..3cecd45054 100644 --- a/api/controllers/console/app/workflow_comment.py +++ b/api/controllers/console/app/workflow_comment.py @@ -7,12 +7,18 @@ from pydantic import BaseModel, Field, TypeAdapter, computed_field, field_valida from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_tenant_id, + with_current_user, +) from fields.base import ResponseModel from fields.member_fields import AccountWithRole from libs.helper import build_avatar_url, dump_response, to_timestamp -from libs.login import current_user, login_required -from models import App +from libs.login import login_required +from models import Account, App from services.account_service import TenantService from services.workflow_comment_service import WorkflowCommentService @@ -213,9 +219,10 @@ class WorkflowCommentListApi(Resource): @setup_required @account_initialization_required @get_app_model() - def get(self, app_model: App): + @with_current_tenant_id + def get(self, current_tenant_id: str, app_model: App): """Get all comments for a workflow.""" - comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id) + comments = WorkflowCommentService.get_comments(tenant_id=current_tenant_id, app_id=app_model.id) return WorkflowCommentBasicList.model_validate({"data": comments}).model_dump(mode="json") @@ -229,12 +236,14 @@ class WorkflowCommentListApi(Resource): @account_initialization_required @get_app_model() @edit_permission_required - def post(self, app_model: App): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account, app_model: App): """Create a new workflow comment.""" payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {}) result = WorkflowCommentService.create_comment( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, app_id=app_model.id, created_by=current_user.id, content=payload.content, @@ -258,10 +267,11 @@ class WorkflowCommentDetailApi(Resource): @setup_required @account_initialization_required @get_app_model() - def get(self, app_model: App, comment_id: str): + @with_current_tenant_id + def get(self, current_tenant_id: str, app_model: App, comment_id: str): """Get a specific workflow comment.""" comment = WorkflowCommentService.get_comment( - tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id + tenant_id=current_tenant_id, app_id=app_model.id, comment_id=comment_id ) return dump_response(WorkflowCommentDetail, comment) @@ -276,12 +286,14 @@ class WorkflowCommentDetailApi(Resource): @account_initialization_required @get_app_model() @edit_permission_required - def put(self, app_model: App, comment_id: str): + @with_current_user + @with_current_tenant_id + def put(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str): """Update a workflow comment.""" payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {}) result = WorkflowCommentService.update_comment( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, app_id=app_model.id, comment_id=comment_id, user_id=current_user.id, @@ -302,10 +314,12 @@ class WorkflowCommentDetailApi(Resource): @account_initialization_required @get_app_model() @edit_permission_required - def delete(self, app_model: App, comment_id: str): + @with_current_user + @with_current_tenant_id + def delete(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str): """Delete a workflow comment.""" WorkflowCommentService.delete_comment( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, app_id=app_model.id, comment_id=comment_id, user_id=current_user.id, @@ -327,10 +341,12 @@ class WorkflowCommentResolveApi(Resource): @account_initialization_required @get_app_model() @edit_permission_required - def post(self, app_model: App, comment_id: str): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str): """Resolve a workflow comment.""" comment = WorkflowCommentService.resolve_comment( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, app_id=app_model.id, comment_id=comment_id, user_id=current_user.id, @@ -353,11 +369,13 @@ class WorkflowCommentReplyApi(Resource): @account_initialization_required @get_app_model() @edit_permission_required - def post(self, app_model: App, comment_id: str): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str): """Add a reply to a workflow comment.""" # Validate comment access first WorkflowCommentService.validate_comment_access( - comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id ) payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {}) @@ -386,17 +404,19 @@ class WorkflowCommentReplyDetailApi(Resource): @account_initialization_required @get_app_model() @edit_permission_required - def put(self, app_model: App, comment_id: str, reply_id: str): + @with_current_user + @with_current_tenant_id + def put(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str, reply_id: str): """Update a comment reply.""" # Validate comment access first WorkflowCommentService.validate_comment_access( - comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id ) payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {}) reply = WorkflowCommentService.update_reply( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, app_id=app_model.id, comment_id=comment_id, reply_id=reply_id, @@ -416,15 +436,17 @@ class WorkflowCommentReplyDetailApi(Resource): @account_initialization_required @get_app_model() @edit_permission_required - def delete(self, app_model: App, comment_id: str, reply_id: str): + @with_current_user + @with_current_tenant_id + def delete(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str, reply_id: str): """Delete a comment reply.""" # Validate comment access first WorkflowCommentService.validate_comment_access( - comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id ) WorkflowCommentService.delete_reply( - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, app_id=app_model.id, comment_id=comment_id, reply_id=reply_id, @@ -448,9 +470,13 @@ class WorkflowCommentMentionUsersApi(Resource): @setup_required @account_initialization_required @get_app_model() - def get(self, app_model: App): + @with_current_user + def get(self, current_user: Account, app_model: App): """Get all users in current tenant for mentions.""" - members = TenantService.get_tenant_members(current_user.current_tenant) + current_tenant = current_user.current_tenant # need the tenant object here + if current_tenant is None: + raise ValueError("current tenant is required") + members = TenantService.get_tenant_members(current_tenant) users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) response = WorkflowCommentMentionUsersPayload(users=users) return response.model_dump(mode="json"), 200 diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 4c3cbce832..f6bb2aa008 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -15,7 +15,12 @@ from controllers.console.app.error import ( DraftWorkflowNotExist, ) from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_user, +) from controllers.web.error import InvalidArgumentError, NotFoundError from core.app.file_access import DatabaseFileAccessController from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID @@ -27,8 +32,8 @@ from graphon.file import helpers as file_helpers from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment from graphon.variables.types import SegmentType -from libs.login import current_user, login_required -from models import App, AppMode +from libs.login import login_required +from models import Account, App, AppMode from models.workflow import WorkflowDraftVariable from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService from services.workflow_service import WorkflowService @@ -123,14 +128,15 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict return result -def _ensure_variable_access( +def ensure_variable_access( variable: WorkflowDraftVariable | None, app_id: str, variable_id: str, + current_user_id: str, ) -> WorkflowDraftVariable: if variable is None: raise NotFoundError(description=f"variable not found, id={variable_id}") - if variable.app_id != app_id or variable.user_id != current_user.id: + if variable.app_id != app_id or variable.user_id != current_user_id: raise NotFoundError(description=f"variable not found, id={variable_id}") return variable @@ -215,7 +221,7 @@ workflow_draft_variable_list_model = console_ns.model( def _api_prerequisite[T, **P, R]( - f: Callable[Concatenate[T, P], R], + f: Callable[Concatenate[T, Account, P], R], ) -> Callable[Concatenate[T, P], R | Response]: """Common prerequisites for all draft workflow variable APIs. @@ -232,9 +238,10 @@ def _api_prerequisite[T, **P, R]( @account_initialization_required @edit_permission_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @with_current_user @wraps(f) - def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response: - return f(self, *args, **kwargs) + def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response: + return f(self, current_user, *args, **kwargs) return wrapper @@ -251,7 +258,7 @@ class WorkflowVariableCollectionApi(Resource): ) @_api_prerequisite @marshal_with(workflow_draft_variable_list_without_value_model) - def get(self, app_model: App): + def get(self, current_user: Account, app_model: App): """ Get draft workflow """ @@ -281,7 +288,7 @@ class WorkflowVariableCollectionApi(Resource): @console_ns.doc(description="Delete all draft workflow variables") @console_ns.response(204, "Workflow variables deleted successfully") @_api_prerequisite - def delete(self, app_model: App): + def delete(self, current_user: Account, app_model: App): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) @@ -315,7 +322,7 @@ class NodeVariableCollectionApi(Resource): @console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model) @_api_prerequisite @marshal_with(workflow_draft_variable_list_model) - def get(self, app_model: App, node_id: str): + def get(self, current_user: Account, app_model: App, node_id: str): validate_node_id(node_id) with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: draft_var_srv = WorkflowDraftVariableService( @@ -329,7 +336,7 @@ class NodeVariableCollectionApi(Resource): @console_ns.doc(description="Delete all variables for a specific node") @console_ns.response(204, "Node variables deleted successfully") @_api_prerequisite - def delete(self, app_model: App, node_id: str): + def delete(self, current_user: Account, app_model: App, node_id: str): validate_node_id(node_id) srv = WorkflowDraftVariableService(db.session()) srv.delete_node_variables(app_model.id, node_id, user_id=current_user.id) @@ -349,15 +356,16 @@ class VariableApi(Resource): @console_ns.response(404, "Variable not found") @_api_prerequisite @marshal_with(workflow_draft_variable_model) - def get(self, app_model: App, variable_id: UUID): + def get(self, current_user: Account, app_model: App, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) variable_id_str = str(variable_id) - variable = _ensure_variable_access( + variable = ensure_variable_access( variable=draft_var_srv.get_variable(variable_id=variable_id_str), app_id=app_model.id, variable_id=variable_id_str, + current_user_id=current_user.id, ) return variable @@ -368,7 +376,7 @@ class VariableApi(Resource): @console_ns.response(404, "Variable not found") @_api_prerequisite @marshal_with(workflow_draft_variable_model) - def patch(self, app_model: App, variable_id: UUID): + def patch(self, current_user: Account, app_model: App, variable_id: UUID): # Request payload for file types: # # Local File: @@ -396,10 +404,11 @@ class VariableApi(Resource): args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {}) variable_id_str = str(variable_id) - variable = _ensure_variable_access( + variable = ensure_variable_access( variable=draft_var_srv.get_variable(variable_id=variable_id_str), app_id=app_model.id, variable_id=variable_id_str, + current_user_id=current_user.id, ) new_name = args_model.name @@ -440,15 +449,16 @@ class VariableApi(Resource): @console_ns.response(204, "Variable deleted successfully") @console_ns.response(404, "Variable not found") @_api_prerequisite - def delete(self, app_model: App, variable_id: UUID): + def delete(self, current_user: Account, app_model: App, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) variable_id_str = str(variable_id) - variable = _ensure_variable_access( + variable = ensure_variable_access( variable=draft_var_srv.get_variable(variable_id=variable_id_str), app_id=app_model.id, variable_id=variable_id_str, + current_user_id=current_user.id, ) draft_var_srv.delete_variable(variable) db.session.commit() @@ -464,7 +474,7 @@ class VariableResetApi(Resource): @console_ns.response(204, "Variable reset (no content)") @console_ns.response(404, "Variable not found") @_api_prerequisite - def put(self, app_model: App, variable_id: UUID): + def put(self, current_user: Account, app_model: App, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) @@ -476,10 +486,11 @@ class VariableResetApi(Resource): f"Draft workflow not found, app_id={app_model.id}", ) variable_id_str = str(variable_id) - variable = _ensure_variable_access( + variable = ensure_variable_access( variable=draft_var_srv.get_variable(variable_id=variable_id_str), app_id=app_model.id, variable_id=variable_id_str, + current_user_id=current_user.id, ) resetted = draft_var_srv.reset_variable(draft_workflow, variable) @@ -490,20 +501,20 @@ class VariableResetApi(Resource): return marshal(resetted, workflow_draft_variable_model) -def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: +def _get_variable_list(app_model: App, node_id: str, current_user_id: str) -> WorkflowDraftVariableList: with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: draft_var_srv = WorkflowDraftVariableService( session=session, ) if node_id == CONVERSATION_VARIABLE_NODE_ID: - draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user.id) + draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user_id) elif node_id == SYSTEM_VARIABLE_NODE_ID: - draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user.id) + draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user_id) else: draft_vars = draft_var_srv.list_node_variables( app_id=app_model.id, node_id=node_id, - user_id=current_user.id, + user_id=current_user_id, ) return draft_vars @@ -517,7 +528,7 @@ class ConversationVariableCollectionApi(Resource): @console_ns.response(404, "Draft workflow not found") @_api_prerequisite @marshal_with(workflow_draft_variable_list_model) - def get(self, app_model: App): + def get(self, current_user: Account, app_model: App): # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table # so their IDs can be returned to the caller. workflow_srv = WorkflowService() @@ -527,7 +538,7 @@ class ConversationVariableCollectionApi(Resource): draft_var_srv = WorkflowDraftVariableService(db.session()) draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=current_user.id) db.session.commit() - return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) + return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID, current_user.id) @console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__]) @console_ns.doc("update_conversation_variables") @@ -539,7 +550,8 @@ class ConversationVariableCollectionApi(Resource): @account_initialization_required @edit_permission_required @get_app_model(mode=AppMode.ADVANCED_CHAT) - def post(self, app_model: App): + @with_current_user + def post(self, current_user: Account, app_model: App): payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {}) workflow_service = WorkflowService() @@ -566,8 +578,8 @@ class SystemVariableCollectionApi(Resource): @console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model) @_api_prerequisite @marshal_with(workflow_draft_variable_list_model) - def get(self, app_model: App): - return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID) + def get(self, current_user: Account, app_model: App): + return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID, current_user.id) @console_ns.route("/apps//workflows/draft/environment-variables") @@ -578,7 +590,7 @@ class EnvironmentVariableCollectionApi(Resource): @console_ns.response(200, "Environment variables retrieved successfully") @console_ns.response(404, "Draft workflow not found") @_api_prerequisite - def get(self, app_model: App): + def get(self, _current_user: Account, app_model: App): """ Get draft workflow """ @@ -619,7 +631,8 @@ class EnvironmentVariableCollectionApi(Resource): @account_initialization_required @edit_permission_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def post(self, app_model: App): + @with_current_user + def post(self, current_user: Account, app_model: App): payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {}) workflow_service = WorkflowService() diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 9b46aeacb8..0ae9fb6309 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,5 +1,5 @@ from datetime import UTC, datetime, timedelta -from typing import Literal, cast +from typing import Literal from uuid import UUID from flask import request @@ -12,7 +12,12 @@ from configs import dify_config from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + setup_required, + with_current_tenant_id, + with_current_user, +) from controllers.web.error import NotFoundError from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id from extensions.ext_database import db @@ -30,8 +35,8 @@ from graphon.enums import WorkflowExecutionStatus from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value -from libs.login import current_user, login_required -from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom +from libs.login import login_required +from models import Account, App, AppMode, WorkflowArchiveLog, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME @@ -190,8 +195,8 @@ class WorkflowRunExportApi(Resource): @account_initialization_required @get_app_model() def get(self, app_model: App, run_id: UUID): - tenant_id = str(app_model.tenant_id) - app_id = str(app_model.id) + tenant_id = app_model.tenant_id + app_id = app_model.id run_id_str = str(run_id) run_created_at = db.session.scalar( @@ -397,18 +402,18 @@ class WorkflowRunNodeExecutionListApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - def get(self, app_model: App, run_id: UUID): + @with_current_user + def get(self, current_user: Account, app_model: App, run_id: UUID): """ Get workflow run node execution list """ run_id_str = str(run_id) workflow_run_service = WorkflowRunService() - user = cast("Account | EndUser", current_user) node_executions = workflow_run_service.get_workflow_run_node_executions( app_model=app_model, run_id=run_id_str, - user=user, + user=current_user, ) return WorkflowRunNodeExecutionListResponse.model_validate( @@ -432,7 +437,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, workflow_run_id: str): + @with_current_tenant_id + def get(self, current_tenant_id: str, workflow_run_id: str): """ Get workflow pause details. @@ -449,7 +455,7 @@ class ConsoleWorkflowPauseDetailsApi(Resource): if not workflow_run: raise NotFoundError("Workflow run not found") - if workflow_run.tenant_id != current_user.current_tenant_id: + if workflow_run.tenant_id != current_tenant_id: raise NotFoundError("Workflow run not found") # Check if workflow is suspended diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index a80b4f5d0c..11c8b2ee55 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -12,14 +12,14 @@ from configs import dify_config from controllers.common.schema import register_schema_models from extensions.ext_database import db from fields.base import ResponseModel -from libs.login import current_user, login_required +from libs.login import login_required from models.enums import AppTriggerStatus -from models.model import Account, App, AppMode +from models.model import App, AppMode from models.trigger import AppTrigger, WorkflowWebhookTrigger from .. import console_ns from ..app.wraps import get_app_model -from ..wraps import account_initialization_required, edit_permission_required, setup_required +from ..wraps import account_initialization_required, edit_permission_required, setup_required, with_current_tenant_id logger = logging.getLogger(__name__) @@ -124,18 +124,16 @@ class AppTriggersApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.WORKFLOW) @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__]) - def get(self, app_model: App): + @with_current_tenant_id + def get(self, current_tenant_id: str, app_model: App): """Get app triggers list""" - assert isinstance(current_user, Account) - assert current_user.current_tenant_id is not None - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: # Get all triggers for this app using select API triggers = ( session.execute( select(AppTrigger) .where( - AppTrigger.tenant_id == current_user.current_tenant_id, + AppTrigger.tenant_id == current_tenant_id, AppTrigger.app_id == app_model.id, ) .order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc()) @@ -166,19 +164,18 @@ class AppTriggerEnableApi(Resource): @edit_permission_required @get_app_model(mode=AppMode.WORKFLOW) @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__]) - def post(self, app_model: App): + @with_current_tenant_id + def post(self, current_tenant_id: str, app_model: App): """Update app trigger (enable/disable)""" args = ParserEnable.model_validate(console_ns.payload) - assert current_user.current_tenant_id is not None - trigger_id = args.trigger_id with sessionmaker(db.engine, expire_on_commit=False).begin() as session: # Find the trigger using select trigger = session.execute( select(AppTrigger).where( AppTrigger.id == trigger_id, - AppTrigger.tenant_id == current_user.current_tenant_id, + AppTrigger.tenant_id == current_tenant_id, AppTrigger.app_id == app_model.id, ) ).scalar_one_or_none() diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index a43caa8f56..a872eb8861 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -2,13 +2,12 @@ from flask_restx import ( # type: ignore Resource, # type: ignore ) from pydantic import BaseModel -from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline -from controllers.console.wraps import account_initialization_required, setup_required -from libs.login import current_user, login_required +from controllers.console.wraps import account_initialization_required, setup_required, with_current_user +from libs.login import login_required from models import Account from models.dataset import Pipeline from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -30,13 +29,11 @@ class DataSourceContentPreviewApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - def post(self, pipeline: Pipeline, node_id: str): + @with_current_user + def post(self, current_user: Account, pipeline: Pipeline, node_id: str): """ Run datasource content preview """ - if not isinstance(current_user, Account): - raise Forbidden() - args = Parser.model_validate(console_ns.payload) inputs = args.inputs diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index f6fc5afc78..4643cfa15c 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -1,5 +1,6 @@ import logging from collections.abc import Callable +from functools import wraps from typing import Any, Concatenate, NoReturn from uuid import UUID @@ -21,7 +22,7 @@ from controllers.console.app.workflow_draft_variable import ( workflow_draft_variable_model, ) from controllers.console.datasets.wraps import get_rag_pipeline -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, setup_required, with_current_user from controllers.web.error import InvalidArgumentError, NotFoundError from core.app.file_access import DatabaseFileAccessController from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID @@ -29,7 +30,7 @@ from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type from graphon.variables.types import SegmentType -from libs.login import current_user, login_required +from libs.login import login_required from models import Account from models.dataset import Pipeline from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -58,7 +59,7 @@ register_schema_models(console_ns, WorkflowDraftVariablePatchPayload) def _api_prerequisite[T, **P, R]( - f: Callable[Concatenate[T, P], R], + f: Callable[Concatenate[T, Account, P], R], ) -> Callable[Concatenate[T, P], R | Response]: """Common prerequisites for all draft workflow variable APIs. @@ -74,10 +75,12 @@ def _api_prerequisite[T, **P, R]( @login_required @account_initialization_required @get_rag_pipeline - def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response: - if not isinstance(current_user, Account) or not current_user.has_edit_permission: + @with_current_user + @wraps(f) + def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response: + if not current_user.has_edit_permission: raise Forbidden() - return f(self, *args, **kwargs) + return f(self, current_user, *args, **kwargs) return wrapper @@ -86,7 +89,7 @@ def _api_prerequisite[T, **P, R]( class RagPipelineVariableCollectionApi(Resource): @_api_prerequisite @marshal_with(workflow_draft_variable_list_without_value_model) - def get(self, pipeline: Pipeline): + def get(self, current_user: Account, pipeline: Pipeline): """ Get draft workflow """ @@ -114,7 +117,7 @@ class RagPipelineVariableCollectionApi(Resource): return workflow_vars @_api_prerequisite - def delete(self, pipeline: Pipeline): + def delete(self, current_user: Account, pipeline: Pipeline): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) @@ -145,7 +148,7 @@ def validate_node_id(node_id: str) -> NoReturn | None: class RagPipelineNodeVariableCollectionApi(Resource): @_api_prerequisite @marshal_with(workflow_draft_variable_list_model) - def get(self, pipeline: Pipeline, node_id: str): + def get(self, current_user: Account, pipeline: Pipeline, node_id: str): validate_node_id(node_id) with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: draft_var_srv = WorkflowDraftVariableService( @@ -156,7 +159,7 @@ class RagPipelineNodeVariableCollectionApi(Resource): return node_vars @_api_prerequisite - def delete(self, pipeline: Pipeline, node_id: str): + def delete(self, current_user: Account, pipeline: Pipeline, node_id: str): validate_node_id(node_id) srv = WorkflowDraftVariableService(db.session()) srv.delete_node_variables(pipeline.id, node_id, user_id=current_user.id) @@ -171,7 +174,7 @@ class RagPipelineVariableApi(Resource): @_api_prerequisite @marshal_with(workflow_draft_variable_model) - def get(self, pipeline: Pipeline, variable_id: UUID): + def get(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) @@ -186,7 +189,7 @@ class RagPipelineVariableApi(Resource): @_api_prerequisite @marshal_with(workflow_draft_variable_model) @console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__]) - def patch(self, pipeline: Pipeline, variable_id: UUID): + def patch(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID): # Request payload for file types: # # Local File: @@ -255,7 +258,7 @@ class RagPipelineVariableApi(Resource): return variable @_api_prerequisite - def delete(self, pipeline: Pipeline, variable_id: UUID): + def delete(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) @@ -273,7 +276,7 @@ class RagPipelineVariableApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/variables//reset") class RagPipelineVariableResetApi(Resource): @_api_prerequisite - def put(self, pipeline: Pipeline, variable_id: UUID): + def put(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) @@ -299,17 +302,17 @@ class RagPipelineVariableResetApi(Resource): return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS) -def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList: +def _get_variable_list(pipeline: Pipeline, node_id: str, current_user_id: str) -> WorkflowDraftVariableList: with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: draft_var_srv = WorkflowDraftVariableService( session=session, ) if node_id == CONVERSATION_VARIABLE_NODE_ID: - draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user.id) + draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user_id) elif node_id == SYSTEM_VARIABLE_NODE_ID: - draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user.id) + draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user_id) else: - draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user.id) + draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user_id) return draft_vars @@ -317,14 +320,14 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList class RagPipelineSystemVariableCollectionApi(Resource): @_api_prerequisite @marshal_with(workflow_draft_variable_list_model) - def get(self, pipeline: Pipeline): - return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID) + def get(self, current_user: Account, pipeline: Pipeline): + return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID, current_user.id) @console_ns.route("/rag/pipelines//workflows/draft/environment-variables") class RagPipelineEnvironmentVariableCollectionApi(Resource): @_api_prerequisite - def get(self, pipeline: Pipeline): + def get(self, _current_user: Account, pipeline: Pipeline): """ Get draft workflow """ diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 5821b91489..44f9c1e7e3 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -8,10 +8,11 @@ from pydantic import BaseModel, Field, computed_field, field_validator from constants.languages import languages from controllers.common.schema import query_params_from_model, register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, with_current_user from fields.base import ResponseModel from libs.helper import build_icon_url -from libs.login import current_user, login_required +from libs.login import login_required +from models import Account from services.recommended_app_service import RecommendedAppService @@ -79,13 +80,14 @@ class RecommendedAppListApi(Resource): @console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__]) @login_required @account_initialization_required - def get(self): + @with_current_user + def get(self, current_user: Account): # language args args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) language = args.language if language and language in languages: language_prefix = language - elif current_user and current_user.interface_language: + elif current_user.interface_language: language_prefix = current_user.interface_language else: language_prefix = languages[0] diff --git a/api/controllers/console/snippets/snippet_workflow_draft_variable.py b/api/controllers/console/snippets/snippet_workflow_draft_variable.py index 2ee69eeac7..323fb0b333 100644 --- a/api/controllers/console/snippets/snippet_workflow_draft_variable.py +++ b/api/controllers/console/snippets/snippet_workflow_draft_variable.py @@ -12,7 +12,7 @@ Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`. from collections.abc import Callable from functools import wraps -from typing import Any +from typing import Any, Concatenate from flask import Response, request from flask_restx import Resource, marshal, marshal_with @@ -23,22 +23,28 @@ from controllers.console.app.error import DraftWorkflowNotExist from controllers.console.app.workflow_draft_variable import ( WorkflowDraftVariableListQuery, WorkflowDraftVariableUpdatePayload, - _ensure_variable_access, - _file_access_controller, + ensure_variable_access, validate_node_id, workflow_draft_variable_list_model, workflow_draft_variable_list_without_value_model, workflow_draft_variable_model, ) from controllers.console.snippets.snippet_workflow import get_snippet -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_user, +) from controllers.web.error import InvalidArgumentError, NotFoundError +from core.app.file_access import DatabaseFileAccessController from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type from graphon.variables.types import SegmentType -from libs.login import current_user, login_required +from libs.login import login_required +from models import Account from models.snippet import CustomizedSnippet from models.workflow import WorkflowDraftVariable from services.snippet_service import SnippetService @@ -47,6 +53,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList, _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset( {SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID} ) +_file_access_controller = DatabaseFileAccessController() def _snippet_service() -> SnippetService: @@ -63,7 +70,9 @@ def _ensure_snippet_draft_variable_row_allowed( raise NotFoundError(description=f"variable not found, id={variable_id}") -def _snippet_draft_var_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]: +def _snippet_draft_var_prerequisite[T, **P, R]( + f: Callable[Concatenate[T, Account, P], R], +) -> Callable[Concatenate[T, P], R | Response]: """Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs).""" @setup_required @@ -71,9 +80,10 @@ def _snippet_draft_var_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R @account_initialization_required @get_snippet @edit_permission_required + @with_current_user @wraps(f) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response: - return f(*args, **kwargs) + def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response: + return f(self, current_user, *args, **kwargs) return wrapper @@ -90,7 +100,7 @@ class SnippetWorkflowVariableCollectionApi(Resource): ) @_snippet_draft_var_prerequisite @marshal_with(workflow_draft_variable_list_without_value_model) - def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList: + def get(self, current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList: args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore snippet_service = _snippet_service() @@ -113,7 +123,7 @@ class SnippetWorkflowVariableCollectionApi(Resource): @console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)") @console_ns.response(204, "Workflow variables deleted successfully") @_snippet_draft_var_prerequisite - def delete(self, snippet: CustomizedSnippet) -> Response: + def delete(self, current_user: Account, snippet: CustomizedSnippet) -> Response: draft_var_srv = WorkflowDraftVariableService(session=db.session()) draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id) db.session.commit() @@ -127,7 +137,7 @@ class SnippetNodeVariableCollectionApi(Resource): @console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model) @_snippet_draft_var_prerequisite @marshal_with(workflow_draft_variable_list_model) - def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList: + def get(self, current_user: Account, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList: validate_node_id(node_id) with Session(bind=db.engine, expire_on_commit=False) as session: draft_var_srv = WorkflowDraftVariableService(session=session) @@ -139,7 +149,7 @@ class SnippetNodeVariableCollectionApi(Resource): @console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)") @console_ns.response(204, "Node variables deleted successfully") @_snippet_draft_var_prerequisite - def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response: + def delete(self, current_user: Account, snippet: CustomizedSnippet, node_id: str) -> Response: validate_node_id(node_id) srv = WorkflowDraftVariableService(db.session()) srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id) @@ -155,12 +165,13 @@ class SnippetVariableApi(Resource): @console_ns.response(404, "Variable not found") @_snippet_draft_var_prerequisite @marshal_with(workflow_draft_variable_model) - def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable: + def get(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable: draft_var_srv = WorkflowDraftVariableService(session=db.session()) - variable = _ensure_variable_access( + variable = ensure_variable_access( variable=draft_var_srv.get_variable(variable_id=variable_id), app_id=snippet.id, variable_id=variable_id, + current_user_id=current_user.id, ) _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id) return variable @@ -172,14 +183,15 @@ class SnippetVariableApi(Resource): @console_ns.response(404, "Variable not found") @_snippet_draft_var_prerequisite @marshal_with(workflow_draft_variable_model) - def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable: + def patch(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable: draft_var_srv = WorkflowDraftVariableService(session=db.session()) args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {}) - variable = _ensure_variable_access( + variable = ensure_variable_access( variable=draft_var_srv.get_variable(variable_id=variable_id), app_id=snippet.id, variable_id=variable_id, + current_user_id=current_user.id, ) _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id) @@ -193,21 +205,21 @@ class SnippetVariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping( - mapping=raw_value, - tenant_id=snippet.tenant_id, - access_controller=_file_access_controller, - ) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=snippet.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings( - mappings=raw_value, - tenant_id=snippet.tenant_id, - access_controller=_file_access_controller, - ) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=snippet.tenant_id, + access_controller=_file_access_controller, + ) new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() @@ -218,12 +230,13 @@ class SnippetVariableApi(Resource): @console_ns.response(204, "Variable deleted successfully") @console_ns.response(404, "Variable not found") @_snippet_draft_var_prerequisite - def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response: + def delete(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> Response: draft_var_srv = WorkflowDraftVariableService(session=db.session()) - variable = _ensure_variable_access( + variable = ensure_variable_access( variable=draft_var_srv.get_variable(variable_id=variable_id), app_id=snippet.id, variable_id=variable_id, + current_user_id=current_user.id, ) _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id) draft_var_srv.delete_variable(variable) @@ -239,7 +252,7 @@ class SnippetVariableResetApi(Resource): @console_ns.response(204, "Variable reset (no content)") @console_ns.response(404, "Variable not found") @_snippet_draft_var_prerequisite - def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any: + def put(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> Response | Any: draft_var_srv = WorkflowDraftVariableService(session=db.session()) snippet_service = _snippet_service() draft_workflow = snippet_service.get_draft_workflow(snippet=snippet) @@ -247,10 +260,11 @@ class SnippetVariableResetApi(Resource): raise NotFoundError( f"Draft workflow not found, snippet_id={snippet.id}", ) - variable = _ensure_variable_access( + variable = ensure_variable_access( variable=draft_var_srv.get_variable(variable_id=variable_id), app_id=snippet.id, variable_id=variable_id, + current_user_id=current_user.id, ) _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id) @@ -270,7 +284,7 @@ class SnippetConversationVariableCollectionApi(Resource): @console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model) @_snippet_draft_var_prerequisite @marshal_with(workflow_draft_variable_list_model) - def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList: + def get(self, _current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList: return WorkflowDraftVariableList(variables=[]) @@ -283,7 +297,7 @@ class SnippetSystemVariableCollectionApi(Resource): @console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model) @_snippet_draft_var_prerequisite @marshal_with(workflow_draft_variable_list_model) - def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList: + def get(self, _current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList: return WorkflowDraftVariableList(variables=[]) @@ -294,7 +308,7 @@ class SnippetEnvironmentVariableCollectionApi(Resource): @console_ns.response(200, "Environment variables retrieved successfully") @console_ns.response(404, "Draft workflow not found") @_snippet_draft_var_prerequisite - def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]: + def get(self, _current_user: Account, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]: snippet_service = _snippet_service() workflow = snippet_service.get_draft_workflow(snippet=snippet) if workflow is None: diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index 7ea7d92f49..b3c199d558 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -7476,6 +7476,10 @@ Set datasource variables ### /rag/pipelines/{pipeline_id}/workflows/draft/environment-variables #### GET +##### Summary + +Get draft workflow + ##### Parameters | Name | Located in | Description | Required | Schema | @@ -7686,6 +7690,10 @@ Run draft workflow | 200 | Success | #### GET +##### Summary + +Get draft workflow + ##### Parameters | Name | Located in | Description | Required | Schema | @@ -7735,6 +7743,7 @@ Run draft workflow | ---- | ---------- | ----------- | -------- | ------ | | pipeline_id | path | | Yes | string | | variable_id | path | | Yes | string | +| payload | body | | Yes | [WorkflowDraftVariablePatchPayload](#workflowdraftvariablepatchpayload) | ##### Responses diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index 5edae75f52..1baac42368 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -59,6 +59,7 @@ from controllers.console.app.workflow_app_log import WorkflowAppLogQuery from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload from controllers.console.app.workflow_statistic import WorkflowStatisticQuery from controllers.console.app.workflow_trigger import Parser, ParserEnable +from models.account import Account, AccountStatus from models.model import AppMode from tests.test_containers_integration_tests.controllers.console.helpers import ( authenticate_console_client, @@ -76,6 +77,16 @@ def _unwrap(func): return func +def _make_account() -> Account: + account = Account( + name="tester", + email="tester@example.com", + status=AccountStatus.ACTIVE, + ) + account.id = "user-1" # type: ignore[assignment] + return account + + class TestCompletionEndpoints: @pytest.fixture def app(self, flask_app_with_containers: Flask): @@ -99,13 +110,6 @@ class TestCompletionEndpoints: api = completion_module.CompletionMessageApi() method = _unwrap(api.post) - class DummyAccount: - pass - - dummy_account = DummyAccount() - - monkeypatch.setattr(completion_module, "current_user", dummy_account) - monkeypatch.setattr(completion_module, "Account", DummyAccount) monkeypatch.setattr( completion_module.AppGenerateService, "generate", @@ -121,7 +125,7 @@ class TestCompletionEndpoints: "/", json={"inputs": {}, "model_config": {}, "query": "hi"}, ): - resp = method(app_model=MagicMock(id="app-1")) + resp = method(_make_account(), app_model=MagicMock(id="app-1")) assert resp == {"result": {"text": "ok"}} @@ -129,13 +133,6 @@ class TestCompletionEndpoints: api = completion_module.CompletionMessageApi() method = _unwrap(api.post) - class DummyAccount: - pass - - dummy_account = DummyAccount() - - monkeypatch.setattr(completion_module, "current_user", dummy_account) - monkeypatch.setattr(completion_module, "Account", DummyAccount) monkeypatch.setattr( completion_module.AppGenerateService, "generate", @@ -149,19 +146,12 @@ class TestCompletionEndpoints: json={"inputs": {}, "model_config": {}, "query": "hi"}, ): with pytest.raises(NotFound): - method(app_model=MagicMock(id="app-1")) + method(_make_account(), app_model=MagicMock(id="app-1")) def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) - class DummyAccount: - pass - - dummy_account = DummyAccount() - - monkeypatch.setattr(completion_module, "current_user", dummy_account) - monkeypatch.setattr(completion_module, "Account", DummyAccount) monkeypatch.setattr( completion_module.AppGenerateService, "generate", @@ -173,19 +163,12 @@ class TestCompletionEndpoints: json={"inputs": {}, "model_config": {}, "query": "hi"}, ): with pytest.raises(completion_module.ProviderNotInitializeError): - method(app_model=MagicMock(id="app-1")) + method(_make_account(), app_model=MagicMock(id="app-1")) def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) - class DummyAccount: - pass - - dummy_account = DummyAccount() - - monkeypatch.setattr(completion_module, "current_user", dummy_account) - monkeypatch.setattr(completion_module, "Account", DummyAccount) monkeypatch.setattr( completion_module.AppGenerateService, "generate", @@ -197,7 +180,7 @@ class TestCompletionEndpoints: json={"inputs": {}, "model_config": {}, "query": "hi"}, ): with pytest.raises(completion_module.ProviderQuotaExceededError): - method(app_model=MagicMock(id="app-1")) + method(_make_account(), app_model=MagicMock(id="app-1")) class TestAppEndpoints: @@ -517,7 +500,6 @@ class TestWorkflowDraftVariableEndpoints: method = _unwrap(api.get) monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock())) - monkeypatch.setattr(workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1")) class DummySessionCtx: def __enter__(self): @@ -550,7 +532,7 @@ class TestWorkflowDraftVariableEndpoints: monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService) with app.test_request_context("/?page=1&limit=20"): - result = method(app_model=SimpleNamespace(id="app-1")) + result = method(_make_account(), app_model=SimpleNamespace(id="app-1")) assert result == {"items": [], "total": 0} diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py index d29946b65e..cd0ceee2b1 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py @@ -50,7 +50,6 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) monkeypatch.setattr(app_wraps, "_load_app_model_from_scoped_session", lambda _app_id: app_model) - monkeypatch.setattr(workflow_comment_module, "current_user", account) def _patch_write_services(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index 78e1b0c46f..aa06aeabc8 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from datetime import datetime from types import SimpleNamespace from unittest.mock import Mock @@ -7,40 +8,14 @@ from unittest.mock import Mock import pytest from flask import Flask -from controllers.console import wraps as console_wraps from controllers.console.app import workflow_run as workflow_run_module from controllers.web.error import NotFoundError from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import WorkflowExecutionStatus from graphon.nodes.human_input.entities import ParagraphInputConfig, UserActionConfig -from libs import login as login_lib -from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun -def _make_account() -> Account: - account = Account( - name="tester", - email="tester@example.com", - status=AccountStatus.ACTIVE, - ) - account.role = TenantAccountRole.OWNER - account.id = "account-123" # type: ignore[assignment] - account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] - account._get_current_object = lambda: account # type: ignore[attr-defined] - return account - - -def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account) -> None: - monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) - monkeypatch.setattr(login_lib, "current_user", account) - monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) - monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) - monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) - monkeypatch.setattr(workflow_run_module, "current_user", account) - monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") - - class _PauseEntity: def __init__(self, paused_at: datetime, reasons: list[HumanInputRequired]): self.paused_at = paused_at @@ -51,8 +26,6 @@ class _PauseEntity: def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: - account = _make_account() - _patch_console_guards(monkeypatch, account) monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com") workflow_run = Mock(spec=WorkflowRun) @@ -86,7 +59,12 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte ) with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): - response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") + handler = inspect.unwrap(workflow_run_module.ConsoleWorkflowPauseDetailsApi.get) + response, status = handler( + workflow_run_module.ConsoleWorkflowPauseDetailsApi(), + "tenant-123", + workflow_run_id="run-1", + ) assert status == 200 assert response["paused_at"] == "2024-01-01T12:00:00Z" @@ -100,8 +78,6 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: - account = _make_account() - _patch_console_guards(monkeypatch, account) monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com") workflow_run = Mock(spec=WorkflowRun) @@ -111,15 +87,17 @@ def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPa fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run)) monkeypatch.setattr(workflow_run_module, "db", fake_db) - with pytest.raises(NotFoundError): - with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): - response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") + handler = inspect.unwrap(workflow_run_module.ConsoleWorkflowPauseDetailsApi.get) + with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): + with pytest.raises(NotFoundError): + handler( + workflow_run_module.ConsoleWorkflowPauseDetailsApi(), + "tenant-123", + workflow_run_id="run-1", + ) def test_pause_details_returns_empty_response_for_non_paused_run(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: - account = _make_account() - _patch_console_guards(monkeypatch, account) - workflow_run = Mock(spec=WorkflowRun) workflow_run.tenant_id = "tenant-123" workflow_run.status = WorkflowExecutionStatus.RUNNING @@ -127,7 +105,12 @@ def test_pause_details_returns_empty_response_for_non_paused_run(app: Flask, mon monkeypatch.setattr(workflow_run_module, "db", fake_db) with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): - response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") + handler = inspect.unwrap(workflow_run_module.ConsoleWorkflowPauseDetailsApi.get) + response, status = handler( + workflow_run_module.ConsoleWorkflowPauseDetailsApi(), + "tenant-123", + workflow_run_id="run-1", + ) assert status == 200 assert response == {"paused_at": None, "paused_nodes": []} diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py index e225e31563..c76cb8d5d7 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py @@ -9,6 +9,7 @@ from flask import Flask from flask_restx import marshal from controllers.console.app import workflow_run as workflow_run_module +from models import Account def _unwrap(func): @@ -32,6 +33,12 @@ def _account() -> SimpleNamespace: return SimpleNamespace(id="account-1", name="Alice", email="alice@example.com") +def _current_account() -> Account: + account = Account(name="Alice", email="alice@example.com") + account.id = "account-1" + return account + + def _workflow_run_summary(**overrides) -> SimpleNamespace: created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) payload = { @@ -208,13 +215,14 @@ def test_workflow_run_node_executions_return_frontend_trace_contract( return [_workflow_run_node_execution()] monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) - monkeypatch.setattr(workflow_run_module, "current_user", SimpleNamespace(id="account-1")) api = workflow_run_module.WorkflowRunNodeExecutionListApi() handler = _unwrap(api.get) with app.test_request_context("/apps/app-1/workflow-runs/run-1/node-executions", method="GET"): - payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1") + payload = handler( + api, _current_account(), app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1" + ) response = _serialize_200_response(api.get, payload) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py index 5363aa154f..14386efda3 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py @@ -1,8 +1,13 @@ from __future__ import annotations +import inspect from datetime import UTC, datetime from types import SimpleNamespace +from unittest.mock import MagicMock, PropertyMock, patch +from flask import Flask + +from controllers.console import console_ns from controllers.console.app import workflow_trigger as workflow_trigger_module @@ -52,3 +57,65 @@ def test_webhook_trigger_response_serializes_datetime(): payload = workflow_trigger_module.WebhookTriggerResponse.model_validate(webhook).model_dump(mode="json") assert payload["webhook_id"] == "whk-1" assert payload["created_at"] == "2026-01-02T03:04:05Z" + + +def test_app_triggers_get_uses_injected_tenant_id(app: Flask) -> None: + trigger = SimpleNamespace( + id="trigger-1", + trigger_type="trigger-plugin", + title="Trigger", + node_id="node-1", + provider_name="provider", + icon="", + status="enabled", + created_at=None, + updated_at=None, + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.all.return_value = [trigger] + + api = workflow_trigger_module.AppTriggersApi() + method = inspect.unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object(type(workflow_trigger_module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()), + patch("controllers.console.app.workflow_trigger.sessionmaker") as sessionmaker_mock, + ): + sessionmaker_mock.return_value.begin.return_value.__enter__.return_value = session + response = method(api, "tenant-1", SimpleNamespace(id="app-1")) + + assert response["data"][0]["id"] == "trigger-1" + assert response["data"][0]["icon"].endswith("/provider/icon") + + +def test_app_trigger_enable_uses_injected_tenant_id(app: Flask) -> None: + trigger = SimpleNamespace( + id="trigger-1", + trigger_type="trigger-plugin", + title="Trigger", + node_id="node-1", + provider_name="provider", + icon="", + status="disabled", + created_at=None, + updated_at=None, + ) + session = MagicMock() + session.execute.return_value.scalar_one_or_none.return_value = trigger + payload = {"trigger_id": "trigger-1", "enable_trigger": True} + + api = workflow_trigger_module.AppTriggerEnableApi() + method = inspect.unwrap(api.post) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(type(workflow_trigger_module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()), + patch("controllers.console.app.workflow_trigger.sessionmaker") as sessionmaker_mock, + ): + sessionmaker_mock.return_value.begin.return_value.__enter__.return_value = session + response = method(api, "tenant-1", SimpleNamespace(id="app-1")) + + assert response["id"] == "trigger-1" + assert response["status"] == "enabled" diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py index d4c6a775ec..2cf8947014 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from werkzeug.exceptions import Forbidden from controllers.console import console_ns from controllers.console.datasets.rag_pipeline.datasource_content_preview import ( @@ -18,6 +17,12 @@ def unwrap(func): return func +def make_account() -> Account: + account = Account(name="Test User", email="user@example.com") + account.id = "account-1" + return account + + class TestDataSourceContentPreviewApi: def _valid_payload(self): return { @@ -34,7 +39,7 @@ class TestDataSourceContentPreviewApi: pipeline = MagicMock(spec=Pipeline) node_id = "node-1" - account = MagicMock(spec=Account) + account = make_account() preview_result = {"content": "preview data"} @@ -44,16 +49,12 @@ class TestDataSourceContentPreviewApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", - account, - ), patch( "controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService", return_value=service_instance, ), ): - response, status = method(api, pipeline, node_id) + response, status = method(api, account, pipeline, node_id) service_instance.run_datasource_node_preview.assert_called_once_with( pipeline=pipeline, @@ -67,25 +68,6 @@ class TestDataSourceContentPreviewApi: assert status == 200 assert response == preview_result - def test_post_forbidden_non_account_user(self, app: Flask): - api = DataSourceContentPreviewApi() - method = unwrap(api.post) - - payload = self._valid_payload() - - pipeline = MagicMock(spec=Pipeline) - - with ( - app.test_request_context("/", json=payload), - patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", - MagicMock(), # NOT Account - ), - ): - with pytest.raises(Forbidden): - method(api, pipeline, "node-1") - def test_post_invalid_payload(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) @@ -96,18 +78,14 @@ class TestDataSourceContentPreviewApi: } pipeline = MagicMock(spec=Pipeline) - account = MagicMock(spec=Account) + account = make_account() with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", - account, - ), ): with pytest.raises(ValueError): - method(api, pipeline, "node-1") + method(api, account, pipeline, "node-1") def test_post_without_credential_id(self, app: Flask): api = DataSourceContentPreviewApi() @@ -120,7 +98,7 @@ class TestDataSourceContentPreviewApi: } pipeline = MagicMock(spec=Pipeline) - account = MagicMock(spec=Account) + account = make_account() service_instance = MagicMock() service_instance.run_datasource_node_preview.return_value = {"ok": True} @@ -128,16 +106,12 @@ class TestDataSourceContentPreviewApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", - account, - ), patch( "controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService", return_value=service_instance, ), ): - response, status = method(api, pipeline, "node-1") + response, status = method(api, account, pipeline, "node-1") service_instance.run_datasource_node_preview.assert_called_once() assert status == 200 diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index 8a65c4bbe5..9b491d63aa 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -16,7 +16,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor from controllers.web.error import InvalidArgumentError, NotFoundError from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from graphon.variables.types import SegmentType -from models.account import Account +from models.account import Account, TenantAccountRole def unwrap(func): @@ -34,9 +34,10 @@ def fake_db(): @pytest.fixture -def editor_user(): - user = MagicMock(spec=Account) - user.has_edit_permission = True +def editor_user() -> Account: + user = Account(name="Test User", email="user@example.com") + user.id = "account-1" + user.role = TenantAccountRole.EDITOR return user @@ -65,7 +66,6 @@ class TestRagPipelineVariableCollectionApi: with ( app.test_request_context("/?page=1&limit=10"), restx_config, - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", @@ -76,9 +76,15 @@ class TestRagPipelineVariableCollectionApi: return_value=draft_srv, ), ): - result = method(api, pipeline) + result = method(api, editor_user, pipeline) - assert result["items"] == [] + assert result is var_list + draft_srv.list_variables_without_values.assert_called_once_with( + app_id="p1", + page=1, + limit=10, + user_id="account-1", + ) def test_get_variables_workflow_not_exist(self, app: Flask, fake_db, editor_user): api = RagPipelineVariableCollectionApi() @@ -91,7 +97,6 @@ class TestRagPipelineVariableCollectionApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", @@ -99,7 +104,7 @@ class TestRagPipelineVariableCollectionApi: ), ): with pytest.raises(DraftWorkflowNotExist): - method(api, pipeline) + method(api, editor_user, pipeline) def test_delete_variables_success(self, app: Flask, fake_db, editor_user): api = RagPipelineVariableCollectionApi() @@ -109,11 +114,10 @@ class TestRagPipelineVariableCollectionApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"), ): - result = method(api, pipeline) + result = method(api, editor_user, pipeline) assert isinstance(result, Response) assert result.status_code == 204 @@ -135,16 +139,16 @@ class TestRagPipelineNodeVariableCollectionApi: with ( app.test_request_context("/"), restx_config, - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", return_value=srv, ), ): - result = method(api, pipeline, "node1") + result = method(api, editor_user, pipeline, "node1") - assert result["items"] == [] + assert result is var_list + srv.list_node_variables.assert_called_once_with("p1", "node1", user_id="account-1") def test_get_node_variables_invalid_node(self, app: Flask, editor_user): api = RagPipelineNodeVariableCollectionApi() @@ -152,10 +156,9 @@ class TestRagPipelineNodeVariableCollectionApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), ): with pytest.raises(InvalidArgumentError): - method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID) + method(api, editor_user, MagicMock(), SYSTEM_VARIABLE_NODE_ID) class TestRagPipelineVariableApi: @@ -168,7 +171,6 @@ class TestRagPipelineVariableApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", @@ -176,7 +178,7 @@ class TestRagPipelineVariableApi: ), ): with pytest.raises(NotFoundError): - method(api, MagicMock(), "v1") + method(api, editor_user, MagicMock(), "v1") def test_patch_variable_invalid_file_payload(self, app: Flask, fake_db, editor_user): api = RagPipelineVariableApi() @@ -193,7 +195,6 @@ class TestRagPipelineVariableApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", @@ -201,7 +202,7 @@ class TestRagPipelineVariableApi: ), ): with pytest.raises(InvalidArgumentError): - method(api, pipeline, "v1") + method(api, editor_user, pipeline, "v1") def test_delete_variable_success(self, app: Flask, fake_db, editor_user): api = RagPipelineVariableApi() @@ -215,14 +216,13 @@ class TestRagPipelineVariableApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", return_value=srv, ), ): - result = method(api, pipeline, "v1") + result = method(api, editor_user, pipeline, "v1") assert result.status_code == 204 @@ -245,7 +245,6 @@ class TestRagPipelineVariableResetApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", @@ -260,7 +259,7 @@ class TestRagPipelineVariableResetApi: return_value={"id": "v1"}, ), ): - result = method(api, pipeline, "v1") + result = method(api, editor_user, pipeline, "v1") assert result == {"id": "v1"} @@ -281,16 +280,16 @@ class TestSystemAndEnvironmentVariablesApi: with ( app.test_request_context("/"), restx_config, - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", return_value=srv, ), ): - result = method(api, pipeline) + result = method(api, editor_user, pipeline) - assert result["items"] == [] + assert result is var_list + srv.list_system_variables.assert_called_once_with("p1", user_id="account-1") def test_environment_variables_success(self, app: Flask, editor_user): api = RagPipelineEnvironmentVariableCollectionApi() @@ -313,12 +312,11 @@ class TestSystemAndEnvironmentVariablesApi: with ( app.test_request_context("/"), - patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", return_value=rag_srv, ), ): - result = method(api, pipeline) + result = method(api, editor_user, pipeline) assert len(result["items"]) == 1 diff --git a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py index 89cbea5ddc..0121d5c424 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -1,8 +1,9 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch from flask import Flask import controllers.console.explore.recommended_app as module +from models import Account from models.model import AppMode, IconType @@ -12,6 +13,13 @@ def unwrap(func): return func +def make_account(interface_language: str | None) -> Account: + account = Account(name="Test User", email="user@example.com") + account.id = "account-1" + account.interface_language = interface_language + return account + + class TestRecommendedAppListApi: def test_get_with_language_param(self, app: Flask): api = module.RecommendedAppListApi() @@ -21,14 +29,13 @@ class TestRecommendedAppListApi: with ( app.test_request_context("/", query_string={"language": "en-US"}), - patch.object(module, "current_user", MagicMock(interface_language="fr-FR")), patch.object( module.RecommendedAppService, "get_recommended_apps_and_categories", return_value=result_data, ) as service_mock, ): - result = method(api) + result = method(api, make_account("fr-FR")) service_mock.assert_called_once_with("en-US") assert result == result_data @@ -41,14 +48,13 @@ class TestRecommendedAppListApi: with ( app.test_request_context("/", query_string={"language": "invalid"}), - patch.object(module, "current_user", MagicMock(interface_language="fr-FR")), patch.object( module.RecommendedAppService, "get_recommended_apps_and_categories", return_value=result_data, ) as service_mock, ): - result = method(api) + result = method(api, make_account("fr-FR")) service_mock.assert_called_once_with("fr-FR") assert result == result_data @@ -61,14 +67,13 @@ class TestRecommendedAppListApi: with ( app.test_request_context("/"), - patch.object(module, "current_user", MagicMock(interface_language=None)), patch.object( module.RecommendedAppService, "get_recommended_apps_and_categories", return_value=result_data, ) as service_mock, ): - result = method(api) + result = method(api, make_account(None)) service_mock.assert_called_once_with(module.languages[0]) assert result == result_data diff --git a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py index fbbeab0eb8..b885b9d601 100644 --- a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py @@ -1,4 +1,3 @@ -import importlib from types import SimpleNamespace from unittest.mock import Mock @@ -7,10 +6,9 @@ from flask import Flask from controllers.console.snippets import snippet_workflow_draft_variable as module from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from models.account import Account, AccountStatus from services.workflow_draft_variable_service import WorkflowDraftVariableList -app_workflow_draft_variable_module = importlib.import_module("controllers.console.app.workflow_draft_variable") - def _unwrap(func): while hasattr(func, "__wrapped__"): @@ -18,6 +16,16 @@ def _unwrap(func): return func +def _make_account() -> Account: + account = Account( + name="tester", + email="tester@example.com", + status=AccountStatus.ACTIVE, + ) + account.id = "user-1" # type: ignore[assignment] + return account + + @pytest.fixture(autouse=True) def _patch_snippet_service_factory(monkeypatch): def factory(): @@ -61,7 +69,7 @@ def test_conversation_variables_returns_empty_list(app): handler = _unwrap(api.get) with app.test_request_context("/"): - result = handler(api, snippet=SimpleNamespace(id="snippet-1")) + result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) assert result == WorkflowDraftVariableList(variables=[]) @@ -71,7 +79,7 @@ def test_system_variables_returns_empty_list(app): handler = _unwrap(api.get) with app.test_request_context("/"): - result = handler(api, snippet=SimpleNamespace(id="snippet-1")) + result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) assert result == WorkflowDraftVariableList(variables=[]) @@ -79,7 +87,6 @@ def test_system_variables_returns_empty_list(app): def test_delete_variable_collection_deletes_current_user_variables(app, monkeypatch): draft_var_service = SimpleNamespace(delete_user_workflow_variables=Mock()) monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service)) - monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1")) db_session = Mock() db_session.return_value = SimpleNamespace() monkeypatch.setattr(module.db, "session", db_session) @@ -87,7 +94,7 @@ def test_delete_variable_collection_deletes_current_user_variables(app, monkeypa handler = _unwrap(api.delete) with app.test_request_context("/", method="DELETE"): - response = handler(api, snippet=SimpleNamespace(id="snippet-1")) + response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) assert response.status_code == 204 draft_var_service.delete_user_workflow_variables.assert_called_once_with("snippet-1", user_id="user-1") @@ -106,7 +113,7 @@ def test_variable_collection_get_raises_when_draft_workflow_missing(app, monkeyp with app.test_request_context("/?page=1&limit=20"): with pytest.raises(module.DraftWorkflowNotExist): - handler(api, snippet=SimpleNamespace(id="snippet-1")) + handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) def test_node_variable_collection_get_lists_node_variables(app, monkeypatch): @@ -126,7 +133,6 @@ def test_node_variable_collection_get_lists_node_variables(app, monkeypatch): monkeypatch.setattr(module, "Session", SessionContext) monkeypatch.setattr(module, "db", SimpleNamespace(engine=object())) - monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1")) monkeypatch.setattr( module, "WorkflowDraftVariableService", @@ -137,7 +143,7 @@ def test_node_variable_collection_get_lists_node_variables(app, monkeypatch): handler = _unwrap(api.get) with app.test_request_context("/"): - result = handler(api, snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1") + result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1") assert result is variables list_node_variables.assert_called_once_with("snippet-1", "llm-1", user_id="user-1") @@ -147,7 +153,6 @@ def test_node_variable_collection_delete_deletes_node_variables(app, monkeypatch delete_node_variables = Mock() draft_var_service = SimpleNamespace(delete_node_variables=delete_node_variables) monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service)) - monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1")) db_session = Mock() db_session.return_value = SimpleNamespace() monkeypatch.setattr(module.db, "session", db_session) @@ -156,7 +161,7 @@ def test_node_variable_collection_delete_deletes_node_variables(app, monkeypatch handler = _unwrap(api.delete) with app.test_request_context("/", method="DELETE"): - response = handler(api, snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1") + response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1") assert response.status_code == 204 delete_node_variables.assert_called_once_with("snippet-1", "llm-1", user_id="user-1") @@ -169,15 +174,18 @@ def test_variable_patch_returns_variable_when_no_changes(app, monkeypatch): db_session = Mock() db_session.return_value = SimpleNamespace() monkeypatch.setattr(module.db, "session", db_session) - monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1")) - monkeypatch.setattr(app_workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1")) monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service)) api = module.SnippetVariableApi() handler = _unwrap(api.patch) with app.test_request_context("/", method="PATCH", json={}): - result = handler(api, snippet=SimpleNamespace(id="snippet-1", tenant_id="tenant-1"), variable_id="var-1") + result = handler( + api, + _make_account(), + snippet=SimpleNamespace(id="snippet-1", tenant_id="tenant-1"), + variable_id="var-1", + ) assert result is variable draft_var_service.update_variable.assert_not_called() @@ -191,15 +199,13 @@ def test_variable_delete_deletes_variable(app, monkeypatch): db_session = Mock() db_session.return_value = SimpleNamespace() monkeypatch.setattr(module.db, "session", db_session) - monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1")) - monkeypatch.setattr(app_workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1")) monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service)) api = module.SnippetVariableApi() handler = _unwrap(api.delete) with app.test_request_context("/", method="DELETE"): - response = handler(api, snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1") + response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1") assert response.status_code == 204 delete_variable.assert_called_once_with(variable) @@ -216,8 +222,6 @@ def test_variable_reset_returns_no_content_when_reset_result_is_none(app, monkey db_session = Mock() db_session.return_value = SimpleNamespace() monkeypatch.setattr(module.db, "session", db_session) - monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1")) - monkeypatch.setattr(app_workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1")) monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service)) monkeypatch.setattr( module, @@ -229,7 +233,7 @@ def test_variable_reset_returns_no_content_when_reset_result_is_none(app, monkey handler = _unwrap(api.put) with app.test_request_context("/", method="PUT"): - response = handler(api, snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1") + response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1") assert response.status_code == 204 draft_var_service.reset_variable.assert_called_once_with(draft_workflow, variable) @@ -259,7 +263,7 @@ def test_environment_variables_returns_workflow_environment_variables(app, monke handler = _unwrap(api.get) with app.test_request_context("/"): - result = handler(api, snippet=SimpleNamespace(id="snippet-1")) + result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) assert result == { "items": [ diff --git a/packages/contracts/generated/api/console/rag/orpc.gen.ts b/packages/contracts/generated/api/console/rag/orpc.gen.ts index 6e627be69e..a4d22839fe 100644 --- a/packages/contracts/generated/api/console/rag/orpc.gen.ts +++ b/packages/contracts/generated/api/console/rag/orpc.gen.ts @@ -67,6 +67,7 @@ import { zPatchRagPipelineCustomizedTemplatesByTemplateIdResponse, zPatchRagPipelinesByPipelineIdWorkflowsByWorkflowIdPath, zPatchRagPipelinesByPipelineIdWorkflowsByWorkflowIdResponse, + zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdBody, zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdPath, zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdResponse, zPostRagPipelineCustomizedTemplatesByTemplateIdPath, @@ -677,6 +678,8 @@ export const datasource = { } /** + * Get draft workflow + * * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. * * @deprecated @@ -690,6 +693,7 @@ export const get12 = oc method: 'GET', operationId: 'getRagPipelinesByPipelineIdWorkflowsDraftEnvironmentVariables', path: '/rag/pipelines/{pipeline_id}/workflows/draft/environment-variables', + summary: 'Get draft workflow', tags: ['console'], }) .input(z.object({ params: zGetRagPipelinesByPipelineIdWorkflowsDraftEnvironmentVariablesPath })) @@ -1084,7 +1088,10 @@ export const patch2 = oc tags: ['console'], }) .input( - z.object({ params: zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdPath }), + z.object({ + body: zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdBody, + params: zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdPath, + }), ) .output(zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdResponse) @@ -1115,6 +1122,8 @@ export const delete4 = oc .output(zDeleteRagPipelinesByPipelineIdWorkflowsDraftVariablesResponse) /** + * Get draft workflow + * * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. * * @deprecated @@ -1128,6 +1137,7 @@ export const get19 = oc method: 'GET', operationId: 'getRagPipelinesByPipelineIdWorkflowsDraftVariables', path: '/rag/pipelines/{pipeline_id}/workflows/draft/variables', + summary: 'Get draft workflow', tags: ['console'], }) .input(z.object({ params: zGetRagPipelinesByPipelineIdWorkflowsDraftVariablesPath })) diff --git a/packages/contracts/generated/api/console/rag/types.gen.ts b/packages/contracts/generated/api/console/rag/types.gen.ts index 133431974d..ad46a4b276 100644 --- a/packages/contracts/generated/api/console/rag/types.gen.ts +++ b/packages/contracts/generated/api/console/rag/types.gen.ts @@ -236,6 +236,11 @@ export type DraftWorkflowRunPayload = { start_node_id: string } +export type WorkflowDraftVariablePatchPayload = { + name?: string | null + value?: unknown +} + export type RagPipelineWorkflowPublishResponse = { created_at: number result: string @@ -1182,7 +1187,7 @@ export type GetRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdRespon = GetRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdResponses[keyof GetRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdResponses] export type PatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdData = { - body?: never + body: WorkflowDraftVariablePatchPayload path: { pipeline_id: string variable_id: string diff --git a/packages/contracts/generated/api/console/rag/zod.gen.ts b/packages/contracts/generated/api/console/rag/zod.gen.ts index ac7bea3ef7..9c0c586bc6 100644 --- a/packages/contracts/generated/api/console/rag/zod.gen.ts +++ b/packages/contracts/generated/api/console/rag/zod.gen.ts @@ -113,6 +113,14 @@ export const zDraftWorkflowRunPayload = z.object({ start_node_id: z.string(), }) +/** + * WorkflowDraftVariablePatchPayload + */ +export const zWorkflowDraftVariablePatchPayload = z.object({ + name: z.string().nullish(), + value: z.unknown().optional(), +}) + /** * RagPipelineWorkflowPublishResponse */ @@ -1006,6 +1014,9 @@ export const zGetRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdResp z.unknown(), ) +export const zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdBody + = zWorkflowDraftVariablePatchPayload + export const zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdPath = z.object({ pipeline_id: z.string(), variable_id: z.string(),