mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:13:59 +08:00
refactor(api): migrate tenant/user via DI for several endpoints (#37114)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1231c2f976
commit
6b12152ce8
@ -23,6 +23,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
with_current_user_id,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
@ -36,7 +37,7 @@ from core.helper.trace_id_helper import get_external_trace_id
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -104,7 +105,8 @@ class CompletionMessageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
@ -112,8 +114,6 @@ class CompletionMessageApi(Resource):
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account or EndUser instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
)
|
||||
@ -178,7 +178,8 @@ class ChatMessageApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
raw_payload = console_ns.payload or {}
|
||||
args_model = ChatMessagePayload.model_validate(raw_payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
@ -197,8 +198,6 @@ class ChatMessageApi(Resource):
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account or EndUser instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
)
|
||||
|
||||
@ -7,12 +7,18 @@ from pydantic import BaseModel, Field, TypeAdapter, computed_field, field_valida
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import AccountWithRole
|
||||
from libs.helper import build_avatar_url, dump_response, to_timestamp
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
@ -213,9 +219,10 @@ class WorkflowCommentListApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return WorkflowCommentBasicList.model_validate({"data": comments}).model_dump(mode="json")
|
||||
|
||||
@ -229,12 +236,14 @@ class WorkflowCommentListApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
@ -258,10 +267,11 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
tenant_id=current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return dump_response(WorkflowCommentDetail, comment)
|
||||
@ -276,12 +286,14 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
@ -302,10 +314,12 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
@ -327,10 +341,12 @@ class WorkflowCommentResolveApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
@ -353,11 +369,13 @@ class WorkflowCommentReplyApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
|
||||
@ -386,17 +404,19 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
reply_id=reply_id,
|
||||
@ -416,15 +436,17 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
reply_id=reply_id,
|
||||
@ -448,9 +470,13 @@ class WorkflowCommentMentionUsersApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
current_tenant = current_user.current_tenant # need the tenant object here
|
||||
if current_tenant is None:
|
||||
raise ValueError("current tenant is required")
|
||||
members = TenantService.get_tenant_members(current_tenant)
|
||||
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersPayload(users=users)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@ -15,7 +15,12 @@ from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
@ -27,8 +32,8 @@ from graphon.file import helpers as file_helpers
|
||||
from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode
|
||||
from libs.login import login_required
|
||||
from models import Account, App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -123,14 +128,15 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_variable_access(
|
||||
def ensure_variable_access(
|
||||
variable: WorkflowDraftVariable | None,
|
||||
app_id: str,
|
||||
variable_id: str,
|
||||
current_user_id: str,
|
||||
) -> WorkflowDraftVariable:
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_id or variable.user_id != current_user.id:
|
||||
if variable.app_id != app_id or variable.user_id != current_user_id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@ -215,7 +221,7 @@ workflow_draft_variable_list_model = console_ns.model(
|
||||
|
||||
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, P], R],
|
||||
f: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
@ -232,9 +238,10 @@ def _api_prerequisite[T, **P, R](
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@wraps(f)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, *args, **kwargs)
|
||||
def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, current_user, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -251,7 +258,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, app_model: App):
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
@ -281,7 +288,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.doc(description="Delete all draft workflow variables")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
def delete(self, current_user: Account, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -315,7 +322,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
def get(self, current_user: Account, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -329,7 +336,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
@console_ns.doc(description="Delete all variables for a specific node")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
def delete(self, current_user: Account, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(app_model.id, node_id, user_id=current_user.id)
|
||||
@ -349,15 +356,16 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, app_model: App, variable_id: UUID):
|
||||
def get(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
return variable
|
||||
|
||||
@ -368,7 +376,7 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, app_model: App, variable_id: UUID):
|
||||
def patch(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -396,10 +404,11 @@ class VariableApi(Resource):
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
|
||||
new_name = args_model.name
|
||||
@ -440,15 +449,16 @@ class VariableApi(Resource):
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: UUID):
|
||||
def delete(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
@ -464,7 +474,7 @@ class VariableResetApi(Resource):
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: UUID):
|
||||
def put(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -476,10 +486,11 @@ class VariableResetApi(Resource):
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
@ -490,20 +501,20 @@ class VariableResetApi(Resource):
|
||||
return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
def _get_variable_list(app_model: App, node_id: str, current_user_id: str) -> WorkflowDraftVariableList:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user_id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user_id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
return draft_vars
|
||||
|
||||
@ -517,7 +528,7 @@ class ConversationVariableCollectionApi(Resource):
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App):
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||
# so their IDs can be returned to the caller.
|
||||
workflow_srv = WorkflowService()
|
||||
@ -527,7 +538,7 @@ class ConversationVariableCollectionApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID, current_user.id)
|
||||
|
||||
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_conversation_variables")
|
||||
@ -539,7 +550,8 @@ class ConversationVariableCollectionApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
@ -566,8 +578,8 @@ class SystemVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID, current_user.id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
@ -578,7 +590,7 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
def get(self, _current_user: Account, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
@ -619,7 +631,8 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal, cast
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
@ -12,7 +12,12 @@ from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id
|
||||
from extensions.ext_database import db
|
||||
@ -30,8 +35,8 @@ from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
|
||||
from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from libs.login import login_required
|
||||
from models import Account, App, AppMode, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
@ -190,8 +195,8 @@ class WorkflowRunExportApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
tenant_id = str(app_model.tenant_id)
|
||||
app_id = str(app_model.id)
|
||||
tenant_id = app_model.tenant_id
|
||||
app_id = app_model.id
|
||||
run_id_str = str(run_id)
|
||||
|
||||
run_created_at = db.session.scalar(
|
||||
@ -397,18 +402,18 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, run_id: UUID):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
user = cast("Account | EndUser", current_user)
|
||||
node_executions = workflow_run_service.get_workflow_run_node_executions(
|
||||
app_model=app_model,
|
||||
run_id=run_id_str,
|
||||
user=user,
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
return WorkflowRunNodeExecutionListResponse.model_validate(
|
||||
@ -432,7 +437,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workflow_run_id: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow pause details.
|
||||
|
||||
@ -449,7 +455,7 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
if not workflow_run:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
if workflow_run.tenant_id != current_user.current_tenant_id:
|
||||
if workflow_run.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
# Check if workflow is suspended
|
||||
|
||||
@ -12,14 +12,14 @@ from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import Account, App, AppMode
|
||||
from models.model import App, AppMode
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
|
||||
from .. import console_ns
|
||||
from ..app.wraps import get_app_model
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required, with_current_tenant_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -124,18 +124,16 @@ class AppTriggersApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
|
||||
def get(self, app_model: App):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, app_model: App):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.tenant_id == current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||
@ -166,19 +164,18 @@ class AppTriggerEnableApi(Resource):
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
|
||||
def post(self, app_model: App):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
args = ParserEnable.model_validate(console_ns.payload)
|
||||
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
trigger_id = args.trigger_id
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
AppTrigger.id == trigger_id,
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.tenant_id == current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
@ -2,13 +2,12 @@ from flask_restx import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -30,13 +29,11 @@ class DataSourceContentPreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run datasource content preview
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
args = Parser.model_validate(console_ns.payload)
|
||||
|
||||
inputs = args.inputs
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, Concatenate, NoReturn
|
||||
from uuid import UUID
|
||||
|
||||
@ -21,7 +22,7 @@ from controllers.console.app.workflow_draft_variable import (
|
||||
workflow_draft_variable_model,
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
@ -29,7 +30,7 @@ from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -58,7 +59,7 @@ register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, P], R],
|
||||
f: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
@ -74,10 +75,12 @@ def _api_prerequisite[T, **P, R](
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
@with_current_user
|
||||
@wraps(f)
|
||||
def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(self, *args, **kwargs)
|
||||
return f(self, current_user, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -86,7 +89,7 @@ def _api_prerequisite[T, **P, R](
|
||||
class RagPipelineVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
@ -114,7 +117,7 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline):
|
||||
def delete(self, current_user: Account, pipeline: Pipeline):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -145,7 +148,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, pipeline: Pipeline, node_id: str):
|
||||
def get(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -156,7 +159,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, node_id: str):
|
||||
def delete(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(pipeline.id, node_id, user_id=current_user.id)
|
||||
@ -171,7 +174,7 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def get(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -186,7 +189,7 @@ class RagPipelineVariableApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||
def patch(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def patch(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -255,7 +258,7 @@ class RagPipelineVariableApi(Resource):
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def delete(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -273,7 +276,7 @@ class RagPipelineVariableApi(Resource):
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def put(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -299,17 +302,17 @@ class RagPipelineVariableResetApi(Resource):
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
|
||||
def _get_variable_list(pipeline: Pipeline, node_id: str, current_user_id: str) -> WorkflowDraftVariableList:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user_id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user_id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
@ -317,14 +320,14 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
|
||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||
def get(self, current_user: Account, pipeline: Pipeline):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID, current_user.id)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables")
|
||||
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, _current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
|
||||
@ -8,10 +8,11 @@ from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from constants.languages import languages
|
||||
from controllers.common.schema import query_params_from_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, with_current_user
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import build_icon_url
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
@ -79,13 +80,14 @@ class RecommendedAppListApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
# language args
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
language = args.language
|
||||
if language and language in languages:
|
||||
language_prefix = language
|
||||
elif current_user and current_user.interface_language:
|
||||
elif current_user.interface_language:
|
||||
language_prefix = current_user.interface_language
|
||||
else:
|
||||
language_prefix = languages[0]
|
||||
|
||||
@ -12,7 +12,7 @@ Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import Any, Concatenate
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
@ -23,22 +23,28 @@ from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
WorkflowDraftVariableListQuery,
|
||||
WorkflowDraftVariableUpdatePayload,
|
||||
_ensure_variable_access,
|
||||
_file_access_controller,
|
||||
ensure_variable_access,
|
||||
validate_node_id,
|
||||
workflow_draft_variable_list_model,
|
||||
workflow_draft_variable_list_without_value_model,
|
||||
workflow_draft_variable_model,
|
||||
)
|
||||
from controllers.console.snippets.snippet_workflow import get_snippet
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.snippet import CustomizedSnippet
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.snippet_service import SnippetService
|
||||
@ -47,6 +53,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
|
||||
_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
|
||||
{SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
|
||||
)
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
def _snippet_service() -> SnippetService:
|
||||
@ -63,7 +70,9 @@ def _ensure_snippet_draft_variable_row_allowed(
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
|
||||
def _snippet_draft_var_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
def _snippet_draft_var_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
|
||||
|
||||
@setup_required
|
||||
@ -71,9 +80,10 @@ def _snippet_draft_var_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, current_user, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -90,7 +100,7 @@ class SnippetWorkflowVariableCollectionApi(Resource):
|
||||
)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
def get(self, current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
@ -113,7 +123,7 @@ class SnippetWorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, snippet: CustomizedSnippet) -> Response:
|
||||
def delete(self, current_user: Account, snippet: CustomizedSnippet) -> Response:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
@ -127,7 +137,7 @@ class SnippetNodeVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
|
||||
def get(self, current_user: Account, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
@ -139,7 +149,7 @@ class SnippetNodeVariableCollectionApi(Resource):
|
||||
@console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response:
|
||||
def delete(self, current_user: Account, snippet: CustomizedSnippet, node_id: str) -> Response:
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
@ -155,12 +165,13 @@ class SnippetVariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
def get(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
return variable
|
||||
@ -172,14 +183,15 @@ class SnippetVariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
def patch(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
@ -193,21 +205,21 @@ class SnippetVariableApi(Resource):
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
@ -218,12 +230,13 @@ class SnippetVariableApi(Resource):
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response:
|
||||
def delete(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> Response:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
@ -239,7 +252,7 @@ class SnippetVariableResetApi(Resource):
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
|
||||
def put(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
snippet_service = _snippet_service()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
@ -247,10 +260,11 @@ class SnippetVariableResetApi(Resource):
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, snippet_id={snippet.id}",
|
||||
)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
@ -270,7 +284,7 @@ class SnippetConversationVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
def get(self, _current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
@ -283,7 +297,7 @@ class SnippetSystemVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
def get(self, _current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
@ -294,7 +308,7 @@ class SnippetEnvironmentVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
|
||||
def get(self, _current_user: Account, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
|
||||
snippet_service = _snippet_service()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if workflow is None:
|
||||
|
||||
@ -7476,6 +7476,10 @@ Set datasource variables
|
||||
### /rag/pipelines/{pipeline_id}/workflows/draft/environment-variables
|
||||
|
||||
#### GET
|
||||
##### Summary
|
||||
|
||||
Get draft workflow
|
||||
|
||||
##### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
@ -7686,6 +7690,10 @@ Run draft workflow
|
||||
| 200 | Success |
|
||||
|
||||
#### GET
|
||||
##### Summary
|
||||
|
||||
Get draft workflow
|
||||
|
||||
##### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
@ -7735,6 +7743,7 @@ Run draft workflow
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| pipeline_id | path | | Yes | string |
|
||||
| variable_id | path | | Yes | string |
|
||||
| payload | body | | Yes | [WorkflowDraftVariablePatchPayload](#workflowdraftvariablepatchpayload) |
|
||||
|
||||
##### Responses
|
||||
|
||||
|
||||
@ -59,6 +59,7 @@ from controllers.console.app.workflow_app_log import WorkflowAppLogQuery
|
||||
from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload
|
||||
from controllers.console.app.workflow_statistic import WorkflowStatisticQuery
|
||||
from controllers.console.app.workflow_trigger import Parser, ParserEnable
|
||||
from models.account import Account, AccountStatus
|
||||
from models.model import AppMode
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
@ -76,6 +77,16 @@ def _unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(
|
||||
name="tester",
|
||||
email="tester@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = "user-1" # type: ignore[assignment]
|
||||
return account
|
||||
|
||||
|
||||
class TestCompletionEndpoints:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
@ -99,13 +110,6 @@ class TestCompletionEndpoints:
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -121,7 +125,7 @@ class TestCompletionEndpoints:
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
resp = method(app_model=MagicMock(id="app-1"))
|
||||
resp = method(_make_account(), app_model=MagicMock(id="app-1"))
|
||||
|
||||
assert resp == {"result": {"text": "ok"}}
|
||||
|
||||
@ -129,13 +133,6 @@ class TestCompletionEndpoints:
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -149,19 +146,12 @@ class TestCompletionEndpoints:
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
method(_make_account(), app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -173,19 +163,12 @@ class TestCompletionEndpoints:
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
method(_make_account(), app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
class DummyAccount:
|
||||
pass
|
||||
|
||||
dummy_account = DummyAccount()
|
||||
|
||||
monkeypatch.setattr(completion_module, "current_user", dummy_account)
|
||||
monkeypatch.setattr(completion_module, "Account", DummyAccount)
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -197,7 +180,7 @@ class TestCompletionEndpoints:
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
method(_make_account(), app_model=MagicMock(id="app-1"))
|
||||
|
||||
|
||||
class TestAppEndpoints:
|
||||
@ -517,7 +500,6 @@ class TestWorkflowDraftVariableEndpoints:
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1"))
|
||||
|
||||
class DummySessionCtx:
|
||||
def __enter__(self):
|
||||
@ -550,7 +532,7 @@ class TestWorkflowDraftVariableEndpoints:
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
result = method(_make_account(), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
@ -50,7 +50,6 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app
|
||||
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(app_wraps, "_load_app_model_from_scoped_session", lambda _app_id: app_model)
|
||||
monkeypatch.setattr(workflow_comment_module, "current_user", account)
|
||||
|
||||
|
||||
def _patch_write_services(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
@ -7,40 +8,14 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import wraps as console_wraps
|
||||
from controllers.console.app import workflow_run as workflow_run_module
|
||||
from controllers.web.error import NotFoundError
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.nodes.human_input.entities import ParagraphInputConfig, UserActionConfig
|
||||
from libs import login as login_lib
|
||||
from models.account import Account, AccountStatus, TenantAccountRole
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(
|
||||
name="tester",
|
||||
email="tester@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.role = TenantAccountRole.OWNER
|
||||
account.id = "account-123" # type: ignore[assignment]
|
||||
account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined]
|
||||
account._get_current_object = lambda: account # type: ignore[attr-defined]
|
||||
return account
|
||||
|
||||
|
||||
def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account) -> None:
|
||||
monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
|
||||
monkeypatch.setattr(login_lib, "current_user", account)
|
||||
monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
|
||||
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(workflow_run_module, "current_user", account)
|
||||
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
|
||||
|
||||
|
||||
class _PauseEntity:
|
||||
def __init__(self, paused_at: datetime, reasons: list[HumanInputRequired]):
|
||||
self.paused_at = paused_at
|
||||
@ -51,8 +26,6 @@ class _PauseEntity:
|
||||
|
||||
|
||||
def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
_patch_console_guards(monkeypatch, account)
|
||||
monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
|
||||
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
@ -86,7 +59,12 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
|
||||
response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")
|
||||
handler = inspect.unwrap(workflow_run_module.ConsoleWorkflowPauseDetailsApi.get)
|
||||
response, status = handler(
|
||||
workflow_run_module.ConsoleWorkflowPauseDetailsApi(),
|
||||
"tenant-123",
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert response["paused_at"] == "2024-01-01T12:00:00Z"
|
||||
@ -100,8 +78,6 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte
|
||||
|
||||
|
||||
def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
_patch_console_guards(monkeypatch, account)
|
||||
monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
|
||||
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
@ -111,15 +87,17 @@ def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPa
|
||||
fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
|
||||
monkeypatch.setattr(workflow_run_module, "db", fake_db)
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
|
||||
response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")
|
||||
handler = inspect.unwrap(workflow_run_module.ConsoleWorkflowPauseDetailsApi.get)
|
||||
with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(
|
||||
workflow_run_module.ConsoleWorkflowPauseDetailsApi(),
|
||||
"tenant-123",
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
|
||||
def test_pause_details_returns_empty_response_for_non_paused_run(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
_patch_console_guards(monkeypatch, account)
|
||||
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
workflow_run.tenant_id = "tenant-123"
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
@ -127,7 +105,12 @@ def test_pause_details_returns_empty_response_for_non_paused_run(app: Flask, mon
|
||||
monkeypatch.setattr(workflow_run_module, "db", fake_db)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
|
||||
response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")
|
||||
handler = inspect.unwrap(workflow_run_module.ConsoleWorkflowPauseDetailsApi.get)
|
||||
response, status = handler(
|
||||
workflow_run_module.ConsoleWorkflowPauseDetailsApi(),
|
||||
"tenant-123",
|
||||
workflow_run_id="run-1",
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"paused_at": None, "paused_nodes": []}
|
||||
|
||||
@ -9,6 +9,7 @@ from flask import Flask
|
||||
from flask_restx import marshal
|
||||
|
||||
from controllers.console.app import workflow_run as workflow_run_module
|
||||
from models import Account
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
@ -32,6 +33,12 @@ def _account() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="account-1", name="Alice", email="alice@example.com")
|
||||
|
||||
|
||||
def _current_account() -> Account:
|
||||
account = Account(name="Alice", email="alice@example.com")
|
||||
account.id = "account-1"
|
||||
return account
|
||||
|
||||
|
||||
def _workflow_run_summary(**overrides) -> SimpleNamespace:
|
||||
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
|
||||
payload = {
|
||||
@ -208,13 +215,14 @@ def test_workflow_run_node_executions_return_frontend_trace_contract(
|
||||
return [_workflow_run_node_execution()]
|
||||
|
||||
monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService)
|
||||
monkeypatch.setattr(workflow_run_module, "current_user", SimpleNamespace(id="account-1"))
|
||||
|
||||
api = workflow_run_module.WorkflowRunNodeExecutionListApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/apps/app-1/workflow-runs/run-1/node-executions", method="GET"):
|
||||
payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1")
|
||||
payload = handler(
|
||||
api, _current_account(), app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1"
|
||||
)
|
||||
|
||||
response = _serialize_200_response(api.get, payload)
|
||||
|
||||
|
||||
@ -1,8 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app import workflow_trigger as workflow_trigger_module
|
||||
|
||||
|
||||
@ -52,3 +57,65 @@ def test_webhook_trigger_response_serializes_datetime():
|
||||
payload = workflow_trigger_module.WebhookTriggerResponse.model_validate(webhook).model_dump(mode="json")
|
||||
assert payload["webhook_id"] == "whk-1"
|
||||
assert payload["created_at"] == "2026-01-02T03:04:05Z"
|
||||
|
||||
|
||||
def test_app_triggers_get_uses_injected_tenant_id(app: Flask) -> None:
|
||||
trigger = SimpleNamespace(
|
||||
id="trigger-1",
|
||||
trigger_type="trigger-plugin",
|
||||
title="Trigger",
|
||||
node_id="node-1",
|
||||
provider_name="provider",
|
||||
icon="",
|
||||
status="enabled",
|
||||
created_at=None,
|
||||
updated_at=None,
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalars.return_value.all.return_value = [trigger]
|
||||
|
||||
api = workflow_trigger_module.AppTriggersApi()
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(workflow_trigger_module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()),
|
||||
patch("controllers.console.app.workflow_trigger.sessionmaker") as sessionmaker_mock,
|
||||
):
|
||||
sessionmaker_mock.return_value.begin.return_value.__enter__.return_value = session
|
||||
response = method(api, "tenant-1", SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response["data"][0]["id"] == "trigger-1"
|
||||
assert response["data"][0]["icon"].endswith("/provider/icon")
|
||||
|
||||
|
||||
def test_app_trigger_enable_uses_injected_tenant_id(app: Flask) -> None:
|
||||
trigger = SimpleNamespace(
|
||||
id="trigger-1",
|
||||
trigger_type="trigger-plugin",
|
||||
title="Trigger",
|
||||
node_id="node-1",
|
||||
provider_name="provider",
|
||||
icon="",
|
||||
status="disabled",
|
||||
created_at=None,
|
||||
updated_at=None,
|
||||
)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = trigger
|
||||
payload = {"trigger_id": "trigger-1", "enable_trigger": True}
|
||||
|
||||
api = workflow_trigger_module.AppTriggerEnableApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(type(workflow_trigger_module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()),
|
||||
patch("controllers.console.app.workflow_trigger.sessionmaker") as sessionmaker_mock,
|
||||
):
|
||||
sessionmaker_mock.return_value.begin.return_value.__enter__.return_value = session
|
||||
response = method(api, "tenant-1", SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response["id"] == "trigger-1"
|
||||
assert response["status"] == "enabled"
|
||||
|
||||
@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_content_preview import (
|
||||
@ -18,6 +17,12 @@ def unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
def make_account() -> Account:
|
||||
account = Account(name="Test User", email="user@example.com")
|
||||
account.id = "account-1"
|
||||
return account
|
||||
|
||||
|
||||
class TestDataSourceContentPreviewApi:
|
||||
def _valid_payload(self):
|
||||
return {
|
||||
@ -34,7 +39,7 @@ class TestDataSourceContentPreviewApi:
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
node_id = "node-1"
|
||||
account = MagicMock(spec=Account)
|
||||
account = make_account()
|
||||
|
||||
preview_result = {"content": "preview data"}
|
||||
|
||||
@ -44,16 +49,12 @@ class TestDataSourceContentPreviewApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, node_id)
|
||||
response, status = method(api, account, pipeline, node_id)
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once_with(
|
||||
pipeline=pipeline,
|
||||
@ -67,25 +68,6 @@ class TestDataSourceContentPreviewApi:
|
||||
assert status == 200
|
||||
assert response == preview_result
|
||||
|
||||
def test_post_forbidden_non_account_user(self, app: Flask):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
MagicMock(), # NOT Account
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_invalid_payload(self, app: Flask):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
@ -96,18 +78,14 @@ class TestDataSourceContentPreviewApi:
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
account = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node-1")
|
||||
method(api, account, pipeline, "node-1")
|
||||
|
||||
def test_post_without_credential_id(self, app: Flask):
|
||||
api = DataSourceContentPreviewApi()
|
||||
@ -120,7 +98,7 @@ class TestDataSourceContentPreviewApi:
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
account = make_account()
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = {"ok": True}
|
||||
@ -128,16 +106,12 @@ class TestDataSourceContentPreviewApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, "node-1")
|
||||
response, status = method(api, account, pipeline, "node-1")
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once()
|
||||
assert status == 200
|
||||
|
||||
@ -16,7 +16,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
|
||||
from graphon.variables.types import SegmentType
|
||||
from models.account import Account
|
||||
from models.account import Account, TenantAccountRole
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@ -34,9 +34,10 @@ def fake_db():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def editor_user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.has_edit_permission = True
|
||||
def editor_user() -> Account:
|
||||
user = Account(name="Test User", email="user@example.com")
|
||||
user.id = "account-1"
|
||||
user.role = TenantAccountRole.EDITOR
|
||||
return user
|
||||
|
||||
|
||||
@ -65,7 +66,6 @@ class TestRagPipelineVariableCollectionApi:
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=10"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
@ -76,9 +76,15 @@ class TestRagPipelineVariableCollectionApi:
|
||||
return_value=draft_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
result = method(api, editor_user, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
assert result is var_list
|
||||
draft_srv.list_variables_without_values.assert_called_once_with(
|
||||
app_id="p1",
|
||||
page=1,
|
||||
limit=10,
|
||||
user_id="account-1",
|
||||
)
|
||||
|
||||
def test_get_variables_workflow_not_exist(self, app: Flask, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
@ -91,7 +97,6 @@ class TestRagPipelineVariableCollectionApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
@ -99,7 +104,7 @@ class TestRagPipelineVariableCollectionApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
method(api, editor_user, pipeline)
|
||||
|
||||
def test_delete_variables_success(self, app: Flask, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
@ -109,11 +114,10 @@ class TestRagPipelineVariableCollectionApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
result = method(api, editor_user, pipeline)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 204
|
||||
@ -135,16 +139,16 @@ class TestRagPipelineNodeVariableCollectionApi:
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
result = method(api, editor_user, pipeline, "node1")
|
||||
|
||||
assert result["items"] == []
|
||||
assert result is var_list
|
||||
srv.list_node_variables.assert_called_once_with("p1", "node1", user_id="account-1")
|
||||
|
||||
def test_get_node_variables_invalid_node(self, app: Flask, editor_user):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
@ -152,10 +156,9 @@ class TestRagPipelineNodeVariableCollectionApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID)
|
||||
method(api, editor_user, MagicMock(), SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class TestRagPipelineVariableApi:
|
||||
@ -168,7 +171,6 @@ class TestRagPipelineVariableApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
@ -176,7 +178,7 @@ class TestRagPipelineVariableApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, MagicMock(), "v1")
|
||||
method(api, editor_user, MagicMock(), "v1")
|
||||
|
||||
def test_patch_variable_invalid_file_payload(self, app: Flask, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
@ -193,7 +195,6 @@ class TestRagPipelineVariableApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
@ -201,7 +202,7 @@ class TestRagPipelineVariableApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, pipeline, "v1")
|
||||
method(api, editor_user, pipeline, "v1")
|
||||
|
||||
def test_delete_variable_success(self, app: Flask, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
@ -215,14 +216,13 @@ class TestRagPipelineVariableApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
result = method(api, editor_user, pipeline, "v1")
|
||||
|
||||
assert result.status_code == 204
|
||||
|
||||
@ -245,7 +245,6 @@ class TestRagPipelineVariableResetApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
@ -260,7 +259,7 @@ class TestRagPipelineVariableResetApi:
|
||||
return_value={"id": "v1"},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
result = method(api, editor_user, pipeline, "v1")
|
||||
|
||||
assert result == {"id": "v1"}
|
||||
|
||||
@ -281,16 +280,16 @@ class TestSystemAndEnvironmentVariablesApi:
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
result = method(api, editor_user, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
assert result is var_list
|
||||
srv.list_system_variables.assert_called_once_with("p1", user_id="account-1")
|
||||
|
||||
def test_environment_variables_success(self, app: Flask, editor_user):
|
||||
api = RagPipelineEnvironmentVariableCollectionApi()
|
||||
@ -313,12 +312,11 @@ class TestSystemAndEnvironmentVariablesApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
result = method(api, editor_user, pipeline)
|
||||
|
||||
assert len(result["items"]) == 1
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask import Flask
|
||||
|
||||
import controllers.console.explore.recommended_app as module
|
||||
from models import Account
|
||||
from models.model import AppMode, IconType
|
||||
|
||||
|
||||
@ -12,6 +13,13 @@ def unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
def make_account(interface_language: str | None) -> Account:
|
||||
account = Account(name="Test User", email="user@example.com")
|
||||
account.id = "account-1"
|
||||
account.interface_language = interface_language
|
||||
return account
|
||||
|
||||
|
||||
class TestRecommendedAppListApi:
|
||||
def test_get_with_language_param(self, app: Flask):
|
||||
api = module.RecommendedAppListApi()
|
||||
@ -21,14 +29,13 @@ class TestRecommendedAppListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"language": "en-US"}),
|
||||
patch.object(module, "current_user", MagicMock(interface_language="fr-FR")),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommended_apps_and_categories",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, make_account("fr-FR"))
|
||||
|
||||
service_mock.assert_called_once_with("en-US")
|
||||
assert result == result_data
|
||||
@ -41,14 +48,13 @@ class TestRecommendedAppListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"language": "invalid"}),
|
||||
patch.object(module, "current_user", MagicMock(interface_language="fr-FR")),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommended_apps_and_categories",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, make_account("fr-FR"))
|
||||
|
||||
service_mock.assert_called_once_with("fr-FR")
|
||||
assert result == result_data
|
||||
@ -61,14 +67,13 @@ class TestRecommendedAppListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(module, "current_user", MagicMock(interface_language=None)),
|
||||
patch.object(
|
||||
module.RecommendedAppService,
|
||||
"get_recommended_apps_and_categories",
|
||||
return_value=result_data,
|
||||
) as service_mock,
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, make_account(None))
|
||||
|
||||
service_mock.assert_called_once_with(module.languages[0])
|
||||
assert result == result_data
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
@ -7,10 +6,9 @@ from flask import Flask
|
||||
|
||||
from controllers.console.snippets import snippet_workflow_draft_variable as module
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from models.account import Account, AccountStatus
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
app_workflow_draft_variable_module = importlib.import_module("controllers.console.app.workflow_draft_variable")
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
@ -18,6 +16,16 @@ def _unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(
|
||||
name="tester",
|
||||
email="tester@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = "user-1" # type: ignore[assignment]
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_snippet_service_factory(monkeypatch):
|
||||
def factory():
|
||||
@ -61,7 +69,7 @@ def test_conversation_variables_returns_empty_list(app):
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = handler(api, snippet=SimpleNamespace(id="snippet-1"))
|
||||
result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"))
|
||||
|
||||
assert result == WorkflowDraftVariableList(variables=[])
|
||||
|
||||
@ -71,7 +79,7 @@ def test_system_variables_returns_empty_list(app):
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = handler(api, snippet=SimpleNamespace(id="snippet-1"))
|
||||
result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"))
|
||||
|
||||
assert result == WorkflowDraftVariableList(variables=[])
|
||||
|
||||
@ -79,7 +87,6 @@ def test_system_variables_returns_empty_list(app):
|
||||
def test_delete_variable_collection_deletes_current_user_variables(app, monkeypatch):
|
||||
draft_var_service = SimpleNamespace(delete_user_workflow_variables=Mock())
|
||||
monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service))
|
||||
monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1"))
|
||||
db_session = Mock()
|
||||
db_session.return_value = SimpleNamespace()
|
||||
monkeypatch.setattr(module.db, "session", db_session)
|
||||
@ -87,7 +94,7 @@ def test_delete_variable_collection_deletes_current_user_variables(app, monkeypa
|
||||
handler = _unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/", method="DELETE"):
|
||||
response = handler(api, snippet=SimpleNamespace(id="snippet-1"))
|
||||
response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"))
|
||||
|
||||
assert response.status_code == 204
|
||||
draft_var_service.delete_user_workflow_variables.assert_called_once_with("snippet-1", user_id="user-1")
|
||||
@ -106,7 +113,7 @@ def test_variable_collection_get_raises_when_draft_workflow_missing(app, monkeyp
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
with pytest.raises(module.DraftWorkflowNotExist):
|
||||
handler(api, snippet=SimpleNamespace(id="snippet-1"))
|
||||
handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"))
|
||||
|
||||
|
||||
def test_node_variable_collection_get_lists_node_variables(app, monkeypatch):
|
||||
@ -126,7 +133,6 @@ def test_node_variable_collection_get_lists_node_variables(app, monkeypatch):
|
||||
|
||||
monkeypatch.setattr(module, "Session", SessionContext)
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1"))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"WorkflowDraftVariableService",
|
||||
@ -137,7 +143,7 @@ def test_node_variable_collection_get_lists_node_variables(app, monkeypatch):
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = handler(api, snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1")
|
||||
result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1")
|
||||
|
||||
assert result is variables
|
||||
list_node_variables.assert_called_once_with("snippet-1", "llm-1", user_id="user-1")
|
||||
@ -147,7 +153,6 @@ def test_node_variable_collection_delete_deletes_node_variables(app, monkeypatch
|
||||
delete_node_variables = Mock()
|
||||
draft_var_service = SimpleNamespace(delete_node_variables=delete_node_variables)
|
||||
monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service))
|
||||
monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1"))
|
||||
db_session = Mock()
|
||||
db_session.return_value = SimpleNamespace()
|
||||
monkeypatch.setattr(module.db, "session", db_session)
|
||||
@ -156,7 +161,7 @@ def test_node_variable_collection_delete_deletes_node_variables(app, monkeypatch
|
||||
handler = _unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/", method="DELETE"):
|
||||
response = handler(api, snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1")
|
||||
response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1")
|
||||
|
||||
assert response.status_code == 204
|
||||
delete_node_variables.assert_called_once_with("snippet-1", "llm-1", user_id="user-1")
|
||||
@ -169,15 +174,18 @@ def test_variable_patch_returns_variable_when_no_changes(app, monkeypatch):
|
||||
db_session = Mock()
|
||||
db_session.return_value = SimpleNamespace()
|
||||
monkeypatch.setattr(module.db, "session", db_session)
|
||||
monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1"))
|
||||
monkeypatch.setattr(app_workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1"))
|
||||
monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service))
|
||||
|
||||
api = module.SnippetVariableApi()
|
||||
handler = _unwrap(api.patch)
|
||||
|
||||
with app.test_request_context("/", method="PATCH", json={}):
|
||||
result = handler(api, snippet=SimpleNamespace(id="snippet-1", tenant_id="tenant-1"), variable_id="var-1")
|
||||
result = handler(
|
||||
api,
|
||||
_make_account(),
|
||||
snippet=SimpleNamespace(id="snippet-1", tenant_id="tenant-1"),
|
||||
variable_id="var-1",
|
||||
)
|
||||
|
||||
assert result is variable
|
||||
draft_var_service.update_variable.assert_not_called()
|
||||
@ -191,15 +199,13 @@ def test_variable_delete_deletes_variable(app, monkeypatch):
|
||||
db_session = Mock()
|
||||
db_session.return_value = SimpleNamespace()
|
||||
monkeypatch.setattr(module.db, "session", db_session)
|
||||
monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1"))
|
||||
monkeypatch.setattr(app_workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1"))
|
||||
monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service))
|
||||
|
||||
api = module.SnippetVariableApi()
|
||||
handler = _unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/", method="DELETE"):
|
||||
response = handler(api, snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1")
|
||||
response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1")
|
||||
|
||||
assert response.status_code == 204
|
||||
delete_variable.assert_called_once_with(variable)
|
||||
@ -216,8 +222,6 @@ def test_variable_reset_returns_no_content_when_reset_result_is_none(app, monkey
|
||||
db_session = Mock()
|
||||
db_session.return_value = SimpleNamespace()
|
||||
monkeypatch.setattr(module.db, "session", db_session)
|
||||
monkeypatch.setattr(module, "current_user", SimpleNamespace(id="user-1"))
|
||||
monkeypatch.setattr(app_workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1"))
|
||||
monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service))
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
@ -229,7 +233,7 @@ def test_variable_reset_returns_no_content_when_reset_result_is_none(app, monkey
|
||||
handler = _unwrap(api.put)
|
||||
|
||||
with app.test_request_context("/", method="PUT"):
|
||||
response = handler(api, snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1")
|
||||
response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1")
|
||||
|
||||
assert response.status_code == 204
|
||||
draft_var_service.reset_variable.assert_called_once_with(draft_workflow, variable)
|
||||
@ -259,7 +263,7 @@ def test_environment_variables_returns_workflow_environment_variables(app, monke
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = handler(api, snippet=SimpleNamespace(id="snippet-1"))
|
||||
result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"))
|
||||
|
||||
assert result == {
|
||||
"items": [
|
||||
|
||||
@ -67,6 +67,7 @@ import {
|
||||
zPatchRagPipelineCustomizedTemplatesByTemplateIdResponse,
|
||||
zPatchRagPipelinesByPipelineIdWorkflowsByWorkflowIdPath,
|
||||
zPatchRagPipelinesByPipelineIdWorkflowsByWorkflowIdResponse,
|
||||
zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdBody,
|
||||
zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdPath,
|
||||
zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdResponse,
|
||||
zPostRagPipelineCustomizedTemplatesByTemplateIdPath,
|
||||
@ -677,6 +678,8 @@ export const datasource = {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get draft workflow
|
||||
*
|
||||
* Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.
|
||||
*
|
||||
* @deprecated
|
||||
@ -690,6 +693,7 @@ export const get12 = oc
|
||||
method: 'GET',
|
||||
operationId: 'getRagPipelinesByPipelineIdWorkflowsDraftEnvironmentVariables',
|
||||
path: '/rag/pipelines/{pipeline_id}/workflows/draft/environment-variables',
|
||||
summary: 'Get draft workflow',
|
||||
tags: ['console'],
|
||||
})
|
||||
.input(z.object({ params: zGetRagPipelinesByPipelineIdWorkflowsDraftEnvironmentVariablesPath }))
|
||||
@ -1084,7 +1088,10 @@ export const patch2 = oc
|
||||
tags: ['console'],
|
||||
})
|
||||
.input(
|
||||
z.object({ params: zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdPath }),
|
||||
z.object({
|
||||
body: zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdBody,
|
||||
params: zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdPath,
|
||||
}),
|
||||
)
|
||||
.output(zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdResponse)
|
||||
|
||||
@ -1115,6 +1122,8 @@ export const delete4 = oc
|
||||
.output(zDeleteRagPipelinesByPipelineIdWorkflowsDraftVariablesResponse)
|
||||
|
||||
/**
|
||||
* Get draft workflow
|
||||
*
|
||||
* Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.
|
||||
*
|
||||
* @deprecated
|
||||
@ -1128,6 +1137,7 @@ export const get19 = oc
|
||||
method: 'GET',
|
||||
operationId: 'getRagPipelinesByPipelineIdWorkflowsDraftVariables',
|
||||
path: '/rag/pipelines/{pipeline_id}/workflows/draft/variables',
|
||||
summary: 'Get draft workflow',
|
||||
tags: ['console'],
|
||||
})
|
||||
.input(z.object({ params: zGetRagPipelinesByPipelineIdWorkflowsDraftVariablesPath }))
|
||||
|
||||
@ -236,6 +236,11 @@ export type DraftWorkflowRunPayload = {
|
||||
start_node_id: string
|
||||
}
|
||||
|
||||
export type WorkflowDraftVariablePatchPayload = {
|
||||
name?: string | null
|
||||
value?: unknown
|
||||
}
|
||||
|
||||
export type RagPipelineWorkflowPublishResponse = {
|
||||
created_at: number
|
||||
result: string
|
||||
@ -1182,7 +1187,7 @@ export type GetRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdRespon
|
||||
= GetRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdResponses[keyof GetRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdResponses]
|
||||
|
||||
export type PatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdData = {
|
||||
body?: never
|
||||
body: WorkflowDraftVariablePatchPayload
|
||||
path: {
|
||||
pipeline_id: string
|
||||
variable_id: string
|
||||
|
||||
@ -113,6 +113,14 @@ export const zDraftWorkflowRunPayload = z.object({
|
||||
start_node_id: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* WorkflowDraftVariablePatchPayload
|
||||
*/
|
||||
export const zWorkflowDraftVariablePatchPayload = z.object({
|
||||
name: z.string().nullish(),
|
||||
value: z.unknown().optional(),
|
||||
})
|
||||
|
||||
/**
|
||||
* RagPipelineWorkflowPublishResponse
|
||||
*/
|
||||
@ -1006,6 +1014,9 @@ export const zGetRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdResp
|
||||
z.unknown(),
|
||||
)
|
||||
|
||||
export const zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdBody
|
||||
= zWorkflowDraftVariablePatchPayload
|
||||
|
||||
export const zPatchRagPipelinesByPipelineIdWorkflowsDraftVariablesByVariableIdPath = z.object({
|
||||
pipeline_id: z.string(),
|
||||
variable_id: z.string(),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user