From 44491e427c36250d5f7ac5256d57d15c35719db0 Mon Sep 17 00:00:00 2001
From: Yansong Zhang <916125788@qq.com>
Date: Thu, 9 Apr 2026 09:36:16 +0800
Subject: [PATCH] feat(api): enable all sandbox/skill controller routes and resolve dependencies (P0)

Resolve the full dependency chain to enable all previously disabled
controllers:

Enabled routes:
- sandbox_files: sandbox file browser API
- sandbox_providers: sandbox provider management API
- app_asset: app asset management API
- skills: skill extraction API
- CLI API blueprint: DifyCli callback endpoints (/cli/api/*)

Dependencies extracted (64 files, ~8000 lines):
- models/sandbox.py, models/app_asset.py: DB models
- core/zip_sandbox/: zip-based sandbox execution
- core/session/: CLI API session management
- core/memory/: base memory + node token buffer
- core/helper/creators.py: helper utilities
- core/llm_generator/: context models, output models, utils
- core/workflow/nodes/command/: command node type
- core/workflow/nodes/file_upload/: file upload node type
- core/app/entities/: app_asset_entities, app_bundle_entities, llm_generation_entities
- services/: asset_content, skill, workflow_collaboration, workflow_comment
- controllers/console/app/error.py: AppAsset error classes
- core/tools/utils/system_encryption.py

Import fixes:
- dify_graph.enums -> graphon.enums in skill_service.py
- get_signed_file_url_for_plugin -> get_signed_file_url in cli_api.py

All 5 controllers verified: import OK, Flask starts successfully.
46 existing tests still pass.

Made-with: Cursor
---
 api/controllers/cli_api/dify_cli/cli_api.py | 8 +-
 api/controllers/console/__init__.py | 8 +-
 api/controllers/console/app/error.py | 18 +
 .../console/app/workflow_comment.py | 322 ++
 api/controllers/console/socketio/__init__.py | 1 +
 api/controllers/console/socketio/workflow.py | 119 +++
 api/controllers/console/workspace/dsl.py | 67 ++
 api/core/agent/agent_app_runner.py | 380 +++++++
 api/core/app/entities/app_asset_entities.py | 352 +++++++
 api/core/app/entities/app_bundle_entities.py | 96 ++
 .../app/entities/llm_generation_entities.py | 72 ++
 api/core/helper/creators.py | 75 ++
 api/core/llm_generator/context_models.py | 62 ++
 api/core/llm_generator/output_models.py | 67 ++
 .../llm_generator/output_parser/file_ref.py | 203 ++++
 api/core/llm_generator/utils.py | 45 +
 api/core/memory/__init__.py | 11 +
 api/core/memory/base.py | 82 ++
 api/core/memory/node_token_buffer_memory.py | 196 ++++
 api/core/session/__init__.py | 11 +
 api/core/session/cli_api.py | 30 +
 api/core/session/session.py | 106 ++
 api/core/tools/utils/system_encryption.py | 187 ++++
 api/core/workflow/nodes/command/__init__.py | 3 +
 api/core/workflow/nodes/command/entities.py | 10 +
 api/core/workflow/nodes/command/exc.py | 16 +
 api/core/workflow/nodes/command/node.py | 152 +++
 .../workflow/nodes/file_upload/__init__.py | 4 +
 .../workflow/nodes/file_upload/entities.py | 7 +
 api/core/workflow/nodes/file_upload/exc.py | 6 +
 api/core/workflow/nodes/file_upload/node.py | 244 +++++
 api/core/zip_sandbox/__init__.py | 23 +
 api/core/zip_sandbox/cli_strategy.py | 81 ++
 api/core/zip_sandbox/entities.py | 39 +
 api/core/zip_sandbox/node_strategy.py | 106 ++
 api/core/zip_sandbox/python_strategy.py | 117 +++
 api/core/zip_sandbox/strategy.py | 41 +
 api/core/zip_sandbox/zip_sandbox.py | 425 ++++++++
 api/dify_graph/entities/tool_entities.py | 41 +
 api/dify_graph/nodes/agent/agent_node.py | 929 ++++++++++++++++++
 api/extensions/ext_blueprints.py | 6 +-
 api/extensions/ext_socketio.py | 5 +
 api/fields/online_user_fields.py | 17 +
api/fields/workflow_comment_fields.py | 96 ++ ...1031-aab323465866_agent_sandbox_support.py | 143 +++ ...27822d22895_add_workflow_comments_table.py | 109 ++ ...e0aa981887_add_app_asset_contents_table.py | 40 + api/models/app_asset.py | 89 ++ api/models/comment.py | 210 ++++ api/models/sandbox.py | 80 ++ api/models/workflow_comment.py | 0 api/models/workflow_features.py | 26 + .../workflow_collaboration_repository.py | 226 +++++ api/services/app_asset_package_service.py | 195 ++++ api/services/app_runtime_upgrade_service.py | 443 +++++++++ api/services/asset_content_service.py | 103 ++ api/services/errors/app_asset.py | 17 + api/services/llm_generation_service.py | 37 + api/services/skill_service.py | 204 ++++ api/services/storage_ticket_service.py | 153 +++ .../workflow/nested_node_graph_service.py | 157 +++ .../workflow_collaboration_service.py | 391 ++++++++ api/services/workflow_comment_service.py | 468 +++++++++ api/tasks/mail_workflow_comment_task.py | 65 ++ 64 files changed, 8030 insertions(+), 12 deletions(-) create mode 100644 api/controllers/console/app/workflow_comment.py create mode 100644 api/controllers/console/socketio/__init__.py create mode 100644 api/controllers/console/socketio/workflow.py create mode 100644 api/controllers/console/workspace/dsl.py create mode 100644 api/core/agent/agent_app_runner.py create mode 100644 api/core/app/entities/app_asset_entities.py create mode 100644 api/core/app/entities/app_bundle_entities.py create mode 100644 api/core/app/entities/llm_generation_entities.py create mode 100644 api/core/helper/creators.py create mode 100644 api/core/llm_generator/context_models.py create mode 100644 api/core/llm_generator/output_models.py create mode 100644 api/core/llm_generator/output_parser/file_ref.py create mode 100644 api/core/llm_generator/utils.py create mode 100644 api/core/memory/__init__.py create mode 100644 api/core/memory/base.py create mode 100644 api/core/memory/node_token_buffer_memory.py create mode 100644 api/core/session/__init__.py create mode 100644 api/core/session/cli_api.py create mode 100644 api/core/session/session.py create mode 100644 api/core/tools/utils/system_encryption.py create mode 100644 api/core/workflow/nodes/command/__init__.py create mode 100644 api/core/workflow/nodes/command/entities.py create mode 100644 api/core/workflow/nodes/command/exc.py create mode 100644 api/core/workflow/nodes/command/node.py create mode 100644 api/core/workflow/nodes/file_upload/__init__.py create mode 100644 api/core/workflow/nodes/file_upload/entities.py create mode 100644 api/core/workflow/nodes/file_upload/exc.py create mode 100644 api/core/workflow/nodes/file_upload/node.py create mode 100644 api/core/zip_sandbox/__init__.py create mode 100644 api/core/zip_sandbox/cli_strategy.py create mode 100644 api/core/zip_sandbox/entities.py create mode 100644 api/core/zip_sandbox/node_strategy.py create mode 100644 api/core/zip_sandbox/python_strategy.py create mode 100644 api/core/zip_sandbox/strategy.py create mode 100644 api/core/zip_sandbox/zip_sandbox.py create mode 100644 api/dify_graph/entities/tool_entities.py create mode 100644 api/dify_graph/nodes/agent/agent_node.py create mode 100644 api/extensions/ext_socketio.py create mode 100644 api/fields/online_user_fields.py create mode 100644 api/fields/workflow_comment_fields.py create mode 100644 api/migrations/versions/2026_02_09_1031-aab323465866_agent_sandbox_support.py create mode 100644 api/migrations/versions/2026_02_09_1726-227822d22895_add_workflow_comments_table.py create mode 
100644 api/migrations/versions/2026_03_09_1200-5ee0aa981887_add_app_asset_contents_table.py create mode 100644 api/models/app_asset.py create mode 100644 api/models/comment.py create mode 100644 api/models/sandbox.py create mode 100644 api/models/workflow_comment.py create mode 100644 api/models/workflow_features.py create mode 100644 api/repositories/workflow_collaboration_repository.py create mode 100644 api/services/app_asset_package_service.py create mode 100644 api/services/app_runtime_upgrade_service.py create mode 100644 api/services/asset_content_service.py create mode 100644 api/services/errors/app_asset.py create mode 100644 api/services/llm_generation_service.py create mode 100644 api/services/skill_service.py create mode 100644 api/services/storage_ticket_service.py create mode 100644 api/services/workflow/nested_node_graph_service.py create mode 100644 api/services/workflow_collaboration_service.py create mode 100644 api/services/workflow_comment_service.py create mode 100644 api/tasks/mail_workflow_comment_task.py diff --git a/api/controllers/cli_api/dify_cli/cli_api.py b/api/controllers/cli_api/dify_cli/cli_api.py index efcaaf0bf6..e5c400f293 100644 --- a/api/controllers/cli_api/dify_cli/cli_api.py +++ b/api/controllers/cli_api/dify_cli/cli_api.py @@ -22,7 +22,7 @@ from core.session.cli_api import CliContext from core.skill.entities import ToolInvocationRequest from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager -from graphon.file.helpers import get_signed_file_url_for_plugin +from graphon.file.helpers import get_signed_file_url from libs.helper import length_prefixed_response from models.account import Account from models.model import EndUser, Tenant @@ -139,11 +139,9 @@ class CliUploadFileRequestApi(Resource): payload: RequestRequestUploadFile, cli_context: CliContext, ): - url = get_signed_file_url_for_plugin( - filename=payload.filename, - mimetype=payload.mimetype, + url = get_signed_file_url( + upload_file_id=f"{tenant_model.id}_{user_model.id}_{payload.filename}", tenant_id=tenant_model.id, - user_id=user_model.id, ) return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index c26631574d..925e4c2337 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -41,7 +41,7 @@ from . import ( init_validate, notification, ping, - # sandbox_files, # TODO: enable after full sandbox integration + sandbox_files, setup, spec, version, @@ -53,7 +53,7 @@ from .app import ( agent, annotation, app, - # app_asset, # TODO: enable after full sandbox integration + app_asset, audio, completion, conversation, @@ -64,7 +64,7 @@ from .app import ( model_config, ops_trace, site, - # skills, # TODO: enable after full sandbox integration + skills, statistic, workflow, workflow_app_log, @@ -133,7 +133,7 @@ from .workspace import ( model_providers, models, plugin, - # sandbox_providers, # TODO: enable after full sandbox integration + sandbox_providers, tool_providers, trigger_providers, workspace, diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index 3fa15d6d6d..51546fcd5e 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -121,3 +121,21 @@ class NeedAddIdsError(BaseHTTPException): error_code = "need_add_ids" description = "Need to add ids." 
code = 400 + + +class AppAssetNodeNotFoundError(BaseHTTPException): + error_code = "app_asset_node_not_found" + description = "App asset node not found." + code = 404 + + +class AppAssetFileRequiredError(BaseHTTPException): + error_code = "app_asset_file_required" + description = "File is required." + code = 400 + + +class AppAssetPathConflictError(BaseHTTPException): + error_code = "app_asset_path_conflict" + description = "Path already exists." + code = 409 diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py new file mode 100644 index 0000000000..191ba8fb34 --- /dev/null +++ b/api/controllers/console/app/workflow_comment.py @@ -0,0 +1,322 @@ +import logging + +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field, TypeAdapter + +from controllers.console import console_ns +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from fields.member_fields import AccountWithRole +from fields.workflow_comment_fields import ( + workflow_comment_basic_fields, + workflow_comment_create_fields, + workflow_comment_detail_fields, + workflow_comment_reply_create_fields, + workflow_comment_reply_update_fields, + workflow_comment_resolve_fields, + workflow_comment_update_fields, +) +from libs.login import current_user, login_required +from models import App +from services.account_service import TenantService +from services.workflow_comment_service import WorkflowCommentService + +logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowCommentCreatePayload(BaseModel): + position_x: float = Field(..., description="Comment X position") + position_y: float = Field(..., description="Comment Y position") + content: str = Field(..., description="Comment content") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentUpdatePayload(BaseModel): + content: str = Field(..., description="Comment content") + position_x: float | None = Field(default=None, description="Comment X position") + position_y: float | None = Field(default=None, description="Comment Y position") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentReplyCreatePayload(BaseModel): + content: str = Field(..., description="Reply content") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentReplyUpdatePayload(BaseModel): + content: str = Field(..., description="Reply content") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentMentionUsersResponse(BaseModel): + users: list[AccountWithRole] = Field(description="Mentionable users") + + +for model in ( + WorkflowCommentCreatePayload, + WorkflowCommentUpdatePayload, + WorkflowCommentReplyCreatePayload, + WorkflowCommentReplyUpdatePayload, +): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + +for model in (AccountWithRole, WorkflowCommentMentionUsersResponse): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) + +workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields) +workflow_comment_detail_model = 
console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
+workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
+workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
+workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
+workflow_comment_reply_create_model = console_ns.model(
+    "WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
+)
+workflow_comment_reply_update_model = console_ns.model(
+    "WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
+)
+workflow_comment_mention_users_model = console_ns.models[WorkflowCommentMentionUsersResponse.__name__]
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
+class WorkflowCommentListApi(Resource):
+    """API for listing and creating workflow comments."""
+
+    @console_ns.doc("list_workflow_comments")
+    @console_ns.doc(description="Get all comments for a workflow")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_basic_model, envelope="data")
+    def get(self, app_model: App):
+        """Get all comments for a workflow."""
+        comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
+
+        return comments
+
+    @console_ns.doc("create_workflow_comment")
+    @console_ns.doc(description="Create a new workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
+    @console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_create_model)
+    def post(self, app_model: App):
+        """Create a new workflow comment."""
+        payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
+
+        result = WorkflowCommentService.create_comment(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            created_by=current_user.id,
+            content=payload.content,
+            position_x=payload.position_x,
+            position_y=payload.position_y,
+            mentioned_user_ids=payload.mentioned_user_ids,
+        )
+
+        return result, 201
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
+class WorkflowCommentDetailApi(Resource):
+    """API for managing individual workflow comments."""
+
+    @console_ns.doc("get_workflow_comment")
+    @console_ns.doc(description="Get a specific workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
+    @console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_detail_model)
+    def get(self, app_model: App, comment_id: str):
+        """Get a specific workflow comment."""
+        comment = WorkflowCommentService.get_comment(
+            tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
+        )
+
+        return comment
+
+    @console_ns.doc("update_workflow_comment")
+    @console_ns.doc(description="Update a workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
+    @console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
+    @console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_update_model)
+    def put(self, app_model: App, comment_id: str):
+        """Update a workflow comment."""
+        payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
+
+        result = WorkflowCommentService.update_comment(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            comment_id=comment_id,
+            user_id=current_user.id,
+            content=payload.content,
+            position_x=payload.position_x,
+            position_y=payload.position_y,
+            mentioned_user_ids=payload.mentioned_user_ids,
+        )
+
+        return result
+
+    @console_ns.doc("delete_workflow_comment")
+    @console_ns.doc(description="Delete a workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
+    @console_ns.response(204, "Comment deleted successfully")
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    def delete(self, app_model: App, comment_id: str):
+        """Delete a workflow comment."""
+        WorkflowCommentService.delete_comment(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            comment_id=comment_id,
+            user_id=current_user.id,
+        )
+
+        return {"result": "success"}, 204
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
+class WorkflowCommentResolveApi(Resource):
+    """API for resolving and reopening workflow comments."""
+
+    @console_ns.doc("resolve_workflow_comment")
+    @console_ns.doc(description="Resolve a workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
+    @console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_resolve_model)
+    def post(self, app_model: App, comment_id: str):
+        """Resolve a workflow comment."""
+        comment = WorkflowCommentService.resolve_comment(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            comment_id=comment_id,
+            user_id=current_user.id,
+        )
+
+        return comment
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
+class WorkflowCommentReplyApi(Resource):
+    """API for managing comment replies."""
+
+    @console_ns.doc("create_workflow_comment_reply")
+    @console_ns.doc(description="Add a reply to a workflow comment")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
+    @console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
+    @console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_reply_create_model)
+    def post(self, app_model: App, comment_id: str):
+        """Add a reply to a workflow comment."""
+        # Validate comment access first
+        WorkflowCommentService.validate_comment_access(
+            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
+        )
+
+        payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
+
+        result = WorkflowCommentService.create_reply(
+            comment_id=comment_id,
+            content=payload.content,
+            created_by=current_user.id,
+            mentioned_user_ids=payload.mentioned_user_ids,
+        )
+
+        return result, 201
+
+
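For orientation, a minimal usage sketch, separate from the patch itself, of how a console client might drive the comment endpoints introduced above. The paths and payload fields mirror this patch; the base URL, app id, and bearer token are placeholders, and the assumption that the create response exposes an "id" field is not shown in the excerpt:

import httpx

BASE_URL = "http://localhost:5001/console/api"  # placeholder console API prefix
APP_ID = "00000000-0000-0000-0000-000000000000"  # placeholder app id
headers = {"Authorization": "Bearer <console-access-token>"}  # placeholder token

with httpx.Client(base_url=BASE_URL, headers=headers) as client:
    # Create a comment anchored to a canvas position (fields follow WorkflowCommentCreatePayload).
    created = client.post(
        f"/apps/{APP_ID}/workflow/comments",
        json={
            "position_x": 120.0,
            "position_y": 80.0,
            "content": "Consider adding a retry branch here",
            "mentioned_user_ids": [],
        },
    ).json()

    # Reply to it (fields follow WorkflowCommentReplyCreatePayload); assumes the create
    # response includes the new comment's id.
    client.post(
        f"/apps/{APP_ID}/workflow/comments/{created['id']}/replies",
        json={"content": "Agreed.", "mentioned_user_ids": []},
    )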
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
+class WorkflowCommentReplyDetailApi(Resource):
+    """API for managing individual comment replies."""
+
+    @console_ns.doc("update_workflow_comment_reply")
+    @console_ns.doc(description="Update a comment reply")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
+    @console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
+    @console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_reply_update_model)
+    def put(self, app_model: App, comment_id: str, reply_id: str):
+        """Update a comment reply."""
+        # Validate comment access first
+        WorkflowCommentService.validate_comment_access(
+            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
+        )
+
+        payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
+
+        reply = WorkflowCommentService.update_reply(
+            reply_id=reply_id,
+            user_id=current_user.id,
+            content=payload.content,
+            mentioned_user_ids=payload.mentioned_user_ids,
+        )
+
+        return reply
+
+    @console_ns.doc("delete_workflow_comment_reply")
+    @console_ns.doc(description="Delete a comment reply")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
+    @console_ns.response(204, "Reply deleted successfully")
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    def delete(self, app_model: App, comment_id: str, reply_id: str):
+        """Delete a comment reply."""
+        # Validate comment access first
+        WorkflowCommentService.validate_comment_access(
+            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
+        )
+
+        WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
+
+        return {"result": "success"}, 204
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
+class WorkflowCommentMentionUsersApi(Resource):
+    """API for getting mentionable users for workflow comments."""
+
+    @console_ns.doc("workflow_comment_mention_users")
+    @console_ns.doc(description="Get all users in current tenant for mentions")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    def get(self, app_model: App):
+        """Get all users in current tenant for mentions."""
+        members = TenantService.get_tenant_members(current_user.current_tenant)
+        member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
+        response = WorkflowCommentMentionUsersResponse(users=member_models)
+        return response.model_dump(mode="json"), 200
diff --git a/api/controllers/console/socketio/__init__.py b/api/controllers/console/socketio/__init__.py
new file mode 100644
index 0000000000..8b13789179
--- /dev/null
+++ b/api/controllers/console/socketio/__init__.py
@@ -0,0 +1 @@
+
diff --git a/api/controllers/console/socketio/workflow.py b/api/controllers/console/socketio/workflow.py
new file mode 100644
index 0000000000..ebc990b64e
--- /dev/null
+++ b/api/controllers/console/socketio/workflow.py
@@ -0,0 +1,119 @@
+import logging
+from collections.abc import Callable
+from typing import cast
+
+from flask import Request as FlaskRequest
+ +from extensions.ext_socketio import sio +from libs.passport import PassportService +from libs.token import extract_access_token +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository +from services.account_service import AccountService +from services.workflow_collaboration_service import WorkflowCollaborationService + +repository = WorkflowCollaborationRepository() +collaboration_service = WorkflowCollaborationService(repository, sio) + + +def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]: + return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event)) + + +@_sio_on("connect") +def socket_connect(sid, environ, auth): + """ + WebSocket connect event, do authentication here. + """ + try: + request_environ = FlaskRequest(environ) + token = extract_access_token(request_environ) + except Exception: + logging.exception("Failed to extract token") + token = None + + if not token: + logging.warning("Socket connect rejected: missing token (sid=%s)", sid) + return False + + try: + decoded = PassportService().verify(token) + user_id = decoded.get("user_id") + if not user_id: + logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid) + return False + + with sio.app.app_context(): + user = AccountService.load_logged_in_account(account_id=user_id) + if not user: + logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid) + return False + if not user.has_edit_permission: + logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid) + return False + + collaboration_service.save_session(sid, user) + return True + + except Exception: + logging.exception("Socket authentication failed") + return False + + +@_sio_on("user_connect") +def handle_user_connect(sid, data): + """ + Handle user connect event. Each session (tab) is treated as an independent collaborator. + """ + workflow_id = data.get("workflow_id") + if not workflow_id: + return {"msg": "workflow_id is required"}, 400 + + result = collaboration_service.register_session(workflow_id, sid) + if not result: + return {"msg": "unauthorized"}, 401 + + user_id, is_leader = result + return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader} + + +@_sio_on("disconnect") +def handle_disconnect(sid): + """ + Handle session disconnect event. Remove the specific session from online users. + """ + collaboration_service.disconnect_session(sid) + + +@_sio_on("collaboration_event") +def handle_collaboration_event(sid, data): + """ + Handle general collaboration events, include: + 1. mouse_move + 2. vars_and_features_update + 3. sync_request (ask leader to update graph) + 4. app_state_update + 5. mcp_server_update + 6. workflow_update + 7. comments_update + 8. node_panel_presence + 9. skill_file_active + 10. skill_sync_request + 11. skill_resync_request + """ + return collaboration_service.relay_collaboration_event(sid, data) + + +@_sio_on("graph_event") +def handle_graph_event(sid, data): + """ + Handle graph events - simple broadcast relay. + """ + return collaboration_service.relay_graph_event(sid, data) + + +@_sio_on("skill_event") +def handle_skill_event(sid, data): + """ + Handle skill events - simple broadcast relay. 
+ """ + return collaboration_service.relay_skill_event(sid, data) diff --git a/api/controllers/console/workspace/dsl.py b/api/controllers/console/workspace/dsl.py new file mode 100644 index 0000000000..91873d026c --- /dev/null +++ b/api/controllers/console/workspace/dsl.py @@ -0,0 +1,67 @@ +import json + +import httpx +import yaml +from flask import request +from flask_restx import Resource +from pydantic import BaseModel +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.console import console_ns +from controllers.console.wraps import account_initialization_required, setup_required +from core.plugin.impl.exc import PluginPermissionDeniedError +from extensions.ext_database import db +from libs.login import current_account_with_tenant, login_required +from models.model import App +from models.workflow import Workflow +from services.app_dsl_service import AppDslService + + +class DSLPredictRequest(BaseModel): + app_id: str + current_node_id: str + + +@console_ns.route("/workspaces/current/dsl/predict") +class DSLPredictApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + user, _ = current_account_with_tenant() + if not user.is_admin_or_owner: + raise Forbidden() + + args = DSLPredictRequest.model_validate(request.get_json()) + + app_id: str = args.app_id + current_node_id: str = args.current_node_id + + with Session(db.engine) as session: + app = session.query(App).filter_by(id=app_id).first() + workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first() + + if not app: + raise ValueError("App not found") + if not workflow: + raise ValueError("Workflow not found") + + try: + i = 0 + for node_id, _ in workflow.walk_nodes(): + if node_id == current_node_id: + break + i += 1 + + dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app)) + + response = httpx.post( + "http://spark-832c:8000/predict", + json={"graph_data": dsl, "source_node_index": i}, + ) + return { + "nodes": json.loads(response.json()), + } + except PluginPermissionDeniedError as e: + raise ValueError(e.description) from e diff --git a/api/core/agent/agent_app_runner.py b/api/core/agent/agent_app_runner.py new file mode 100644 index 0000000000..2e7884eaba --- /dev/null +++ b/api/core/agent/agent_app_runner.py @@ -0,0 +1,380 @@ +import logging +from collections.abc import Generator +from copy import deepcopy +from typing import Any + +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.entities import AgentEntity, AgentLog, AgentResult +from core.agent.patterns.strategy_factory import StrategyFactory +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine +from graphon.file import file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMUsage, + PromptMessage, + PromptMessageContentType, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from models.model import Message + +logger = logging.getLogger(__name__) + + 
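The AgentAppRunner.run method below streams LLM chunks while also recovering the strategy's final AgentResult from the generator's StopIteration value. A minimal, self-contained sketch of that consumption pattern; the toy strategy here is a stand-in, not the real StrategyFactory output:

from collections.abc import Generator


def toy_strategy() -> Generator[str, None, dict]:
    # Stand-in for the strategy generator: yields streamed chunks, then returns a final result.
    yield "chunk-1"
    yield "chunk-2"
    return {"final_answer": "done"}  # surfaces to the caller as StopIteration.value


def drain(gen: Generator[str, None, dict]) -> dict:
    result: dict = {}
    while True:
        try:
            chunk = next(gen)
        except StopIteration as e:
            result = e.value or {}  # the generator's return value
            break
        print(chunk)  # the runner yields chunks downstream at this point
    return result


print(drain(toy_strategy()))  # -> {'final_answer': 'done'}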
+class AgentAppRunner(BaseAgentRunner): + def _create_tool_invoke_hook(self, message: Message): + """ + Create a tool invoke hook that uses ToolEngine.agent_invoke. + This hook handles file creation and returns proper meta information. + """ + # Get trace manager from app generate entity + trace_manager = self.application_generate_entity.trace_manager + + def tool_invoke_hook( + tool: Tool, tool_args: dict[str, Any], tool_name: str + ) -> tuple[str, list[str], ToolInvokeMeta]: + """Hook that uses agent_invoke for proper file and meta handling.""" + tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters=tool_args, + user_id=self.user_id, + tenant_id=self.tenant_id, + message=message, + invoke_from=self.application_generate_entity.invoke_from, + agent_tool_callback=self.agent_callback, + trace_manager=trace_manager, + app_id=self.application_generate_entity.app_config.app_id, + message_id=message.id, + conversation_id=self.conversation.id, + ) + + # Publish files and track IDs + for message_file_id in message_files: + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), + PublishFrom.APPLICATION_MANAGER, + ) + self._current_message_file_ids.append(message_file_id) + + return tool_invoke_response, message_files, tool_invoke_meta + + return tool_invoke_hook + + def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: + """ + Run Agent application + """ + self.query = query + app_generate_entity = self.application_generate_entity + + app_config = self.app_config + assert app_config is not None, "app_config is required" + assert app_config.agent is not None, "app_config.agent is required" + + # convert tools into ModelRuntime Tool format + tool_instances, _ = self._init_prompt_tools() + + assert app_config.agent + + # Create tool invoke hook for agent_invoke + tool_invoke_hook = self._create_tool_invoke_hook(message) + + # Get instruction for ReAct strategy + instruction = self.app_config.prompt_template.simple_prompt_template or "" + + # Use factory to create appropriate strategy + strategy = StrategyFactory.create_strategy( + model_features=self.model_features, + model_instance=self.model_instance, + tools=list(tool_instances.values()), + files=list(self.files), + max_iterations=app_config.agent.max_iteration, + context=self.build_execution_context(), + agent_strategy=self.config.strategy, + tool_invoke_hook=tool_invoke_hook, + instruction=instruction, + ) + + # Initialize state variables + current_agent_thought_id: str | None = None + has_published_thought = False + current_tool_name: str | None = None + self._current_message_file_ids: list[str] = [] + + # organize prompt messages + prompt_messages = self._organize_prompt_messages() + + # Run strategy + generator = strategy.run( + prompt_messages=prompt_messages, + model_parameters=app_generate_entity.model_conf.parameters, + stop=app_generate_entity.model_conf.stop, + stream=True, + ) + + # Consume generator and collect result + result: AgentResult | None = None + try: + while True: + try: + output = next(generator) + except StopIteration as e: + # Generator finished, get the return value + result = e.value + break + + if isinstance(output, LLMResultChunk): + # Handle LLM chunk + if current_agent_thought_id and not has_published_thought: + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + has_published_thought = True + + 
yield output + + elif isinstance(output, AgentLog): + # Handle Agent Log using log_type for type-safe dispatch + if output.status == AgentLog.LogStatus.START: + if output.log_type == AgentLog.LogType.ROUND: + # Start of a new round + message_file_ids: list[str] = [] + current_agent_thought_id = self.create_agent_thought( + message_id=message.id, + message="", + tool_name="", + tool_input="", + messages_ids=message_file_ids, + ) + has_published_thought = False + + elif output.log_type == AgentLog.LogType.TOOL_CALL: + if current_agent_thought_id is None: + continue + + # Tool call start - extract data from structured fields + current_tool_name = output.data.get("tool_name", "") + tool_input = output.data.get("tool_args", {}) + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=current_tool_name, + tool_input=tool_input, + thought=None, + observation=None, + tool_invoke_meta=None, + answer=None, + messages_ids=[], + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + elif output.status == AgentLog.LogStatus.SUCCESS: + if output.log_type == AgentLog.LogType.THOUGHT: + if current_agent_thought_id is None: + continue + + thought_text = output.data.get("thought") + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=thought_text, + observation=None, + tool_invoke_meta=None, + answer=None, + messages_ids=[], + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + elif output.log_type == AgentLog.LogType.TOOL_CALL: + if current_agent_thought_id is None: + continue + + # Tool call finished + tool_output = output.data.get("output") + # Get meta from strategy output (now properly populated) + tool_meta = output.data.get("meta") + + # Wrap tool_meta with tool_name as key (required by agent_service) + if tool_meta and current_tool_name: + tool_meta = {current_tool_name: tool_meta} + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=None, + observation=tool_output, + tool_invoke_meta=tool_meta, + answer=None, + messages_ids=self._current_message_file_ids, + ) + # Clear message file ids after saving + self._current_message_file_ids = [] + current_tool_name = None + + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + elif output.log_type == AgentLog.LogType.ROUND: + if current_agent_thought_id is None: + continue + + # Round finished - save LLM usage and answer + llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE) + llm_result = output.data.get("llm_result") + final_answer = output.data.get("final_answer") + + self.save_agent_thought( + agent_thought_id=current_agent_thought_id, + tool_name=None, + tool_input=None, + thought=llm_result, + observation=None, + tool_invoke_meta=None, + answer=final_answer, + messages_ids=[], + llm_usage=llm_usage, + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id), + PublishFrom.APPLICATION_MANAGER, + ) + + except Exception: + # Re-raise any other exceptions + raise + + # Process final result + if isinstance(result, AgentResult): + final_answer = result.text + usage = result.usage or LLMUsage.empty_usage() + + # Publish end event + self.queue_manager.publish( + QueueMessageEndEvent( 
+ llm_result=LLMResult( + model=self.model_instance.model_name, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=usage, + system_fingerprint="", + ) + ), + PublishFrom.APPLICATION_MANAGER, + ) + + def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Initialize system message + """ + if not prompt_template: + return prompt_messages or [] + + prompt_messages = prompt_messages or [] + + if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage): + prompt_messages[0] = SystemPromptMessage(content=prompt_template) + return prompt_messages + + if not prompt_messages: + return [SystemPromptMessage(content=prompt_template)] + + prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) + return prompt_messages + + def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Organize user query + """ + if self.files: + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + prompt_message_contents: list[PromptMessageContentUnionTypes] = [] + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) + prompt_message_contents.append(TextPromptMessageContent(data=query)) + + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: + prompt_messages.append(UserPromptMessage(content=query)) + + return prompt_messages + + def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + As for now, gpt supports both fc and vision at the first iteration. + We need to remove the image messages from the prompt messages at the first iteration. 
+ """ + prompt_messages = deepcopy(prompt_messages) + + for prompt_message in prompt_messages: + if isinstance(prompt_message, UserPromptMessage): + if isinstance(prompt_message.content, list): + prompt_message.content = "\n".join( + [ + content.data + if content.type == PromptMessageContentType.TEXT + else "[image]" + if content.type == PromptMessageContentType.IMAGE + else "[file]" + for content in prompt_message.content + ] + ) + + return prompt_messages + + def _organize_prompt_messages(self): + # For ReAct strategy, use the agent prompt template + if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt: + prompt_template = self.config.prompt.first_prompt + else: + prompt_template = self.app_config.prompt_template.simple_prompt_template or "" + + self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) + query_prompt_messages = self._organize_user_query(self.query or "", []) + + self.history_prompt_messages = AgentHistoryPromptTransform( + model_config=self.model_config, + prompt_messages=[*query_prompt_messages, *self._current_thoughts], + history_messages=self.history_prompt_messages, + memory=self.memory, + ).get_prompt() + + prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] + if len(self._current_thoughts) != 0: + # clear messages after the first iteration + prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) + return prompt_messages diff --git a/api/core/app/entities/app_asset_entities.py b/api/core/app/entities/app_asset_entities.py new file mode 100644 index 0000000000..aaf6dffdce --- /dev/null +++ b/api/core/app/entities/app_asset_entities.py @@ -0,0 +1,352 @@ +from __future__ import annotations + +import os +from collections import defaultdict +from collections.abc import Generator +from enum import StrEnum + +from pydantic import BaseModel, Field + + +class AssetNodeType(StrEnum): + FILE = "file" + FOLDER = "folder" + + +class AppAssetNode(BaseModel): + id: str = Field(description="Unique identifier for the node") + node_type: AssetNodeType = Field(description="Type of node: file or folder") + name: str = Field(description="Name of the file or folder") + parent_id: str | None = Field(default=None, description="Parent folder ID, None for root level") + order: int = Field(default=0, description="Sort order within parent folder, lower values first") + extension: str = Field(default="", description="File extension without dot, empty for folders") + size: int = Field(default=0, description="File size in bytes, 0 for folders") + + @classmethod + def create_folder(cls, node_id: str, name: str, parent_id: str | None = None) -> AppAssetNode: + return cls(id=node_id, node_type=AssetNodeType.FOLDER, name=name, parent_id=parent_id) + + @classmethod + def create_file(cls, node_id: str, name: str, parent_id: str | None = None, size: int = 0) -> AppAssetNode: + return cls( + id=node_id, + node_type=AssetNodeType.FILE, + name=name, + parent_id=parent_id, + extension=name.rsplit(".", 1)[-1] if "." in name else "", + size=size, + ) + + +class AppAssetNodeView(BaseModel): + id: str = Field(description="Unique identifier for the node") + node_type: str = Field(description="Type of node: 'file' or 'folder'") + name: str = Field(description="Name of the file or folder") + path: str = Field(description="Full path from root, e.g. 
'/folder/file.txt'")
+    extension: str = Field(default="", description="File extension without dot")
+    size: int = Field(default=0, description="File size in bytes")
+    children: list[AppAssetNodeView] = Field(default_factory=list, description="Child nodes for folders")
+
+
+class BatchUploadNode(BaseModel):
+    """Structure for batch upload_url tree nodes, used for both input and output."""
+
+    name: str
+    node_type: AssetNodeType
+    size: int = 0
+    children: list[BatchUploadNode] = []
+    id: str | None = None
+    upload_url: str | None = None
+
+    def to_app_asset_nodes(self, parent_id: str | None = None) -> list[AppAssetNode]:
+        """
+        Generate IDs when missing and convert to AppAssetNode list.
+        Mutates self to set id field when it is not set.
+        """
+        from uuid import uuid4
+
+        self.id = self.id or str(uuid4())
+        nodes: list[AppAssetNode] = []
+
+        if self.node_type == AssetNodeType.FOLDER:
+            nodes.append(AppAssetNode.create_folder(self.id, self.name, parent_id))
+            for child in self.children:
+                nodes.extend(child.to_app_asset_nodes(self.id))
+        else:
+            nodes.append(AppAssetNode.create_file(self.id, self.name, parent_id, self.size))
+
+        return nodes
+
+
+class TreeNodeNotFoundError(Exception):
+    """Tree internal: node not found"""
+
+    pass
+
+
+class TreeParentNotFoundError(Exception):
+    """Tree internal: parent folder not found"""
+
+    pass
+
+
+class TreePathConflictError(Exception):
+    """Tree internal: path already exists"""
+
+    pass
+
+
+class AppAssetFileTree(BaseModel):
+    """
+    File tree structure for app assets using adjacency list pattern.
+
+    Design:
+    - Storage: Flat list with parent_id references (adjacency list)
+    - Path: Computed dynamically via get_path(), not stored
+    - Order: Integer field for user-defined sorting within each folder
+    - API response: transform() builds nested tree with computed paths
+
+    Why adjacency list over nested tree or materialized path:
+    - Simpler CRUD: move/rename only updates one node's parent_id
+    - No path cascade: renaming parent doesn't require updating all descendants
+    - JSON-friendly: flat list serializes cleanly to database JSON column
+    - Trade-off: path lookup is O(depth), acceptable for typical file trees
+    """
+
+    nodes: list[AppAssetNode] = Field(default_factory=list, description="Flat list of all nodes in the tree")
+
+    def ensure_unique_name(
+        self,
+        parent_id: str | None,
+        name: str,
+        *,
+        is_file: bool,
+        extra_taken: set[str] | None = None,
+    ) -> str:
+        """
+        Return a sibling-unique name by appending numeric suffixes when needed.
+
+        The suffix format is "<name> <index>" (e.g. "report 1", "report 2"). For files,
+        the suffix is inserted before the extension.
+ """ + taken = extra_taken or set() + if not self.has_child_named(parent_id, name) and name not in taken: + return name + suffix_index = 1 + while True: + candidate = self._apply_name_suffix(name, suffix_index, is_file=is_file) + if not self.has_child_named(parent_id, candidate) and candidate not in taken: + return candidate + suffix_index += 1 + + @staticmethod + def _apply_name_suffix(name: str, suffix_index: int, *, is_file: bool) -> str: + if not is_file: + return f"{name} {suffix_index}" + stem, extension = os.path.splitext(name) + return f"{stem} {suffix_index}{extension}" + + def get(self, node_id: str) -> AppAssetNode | None: + return next((n for n in self.nodes if n.id == node_id), None) + + def get_children(self, parent_id: str | None) -> list[AppAssetNode]: + return [n for n in self.nodes if n.parent_id == parent_id] + + def has_child_named(self, parent_id: str | None, name: str) -> bool: + return any(n.name == name and n.parent_id == parent_id for n in self.nodes) + + def get_path(self, node_id: str) -> str: + node = self.get(node_id) + if not node: + raise TreeNodeNotFoundError(node_id) + parts: list[str] = [] + current: AppAssetNode | None = node + while current: + parts.append(current.name) + current = self.get(current.parent_id) if current.parent_id else None + return "/".join(reversed(parts)) + + def relative_path(self, a: AppAssetNode, b: AppAssetNode) -> str: + """ + Calculate relative path from node a to node b for Markdown references. + Path is computed from a's parent directory (where the file resides). + + Examples: + /foo/a.md -> /foo/b.md => ./b.md + /foo/a.md -> /foo/sub/b.md => ./sub/b.md + /foo/sub/a.md -> /foo/b.md => ../b.md + /foo/sub/deep/a.md -> /foo/b.md => ../../b.md + """ + + def get_ancestor_ids(node_id: str | None) -> list[str]: + chain: list[str] = [] + current_id = node_id + while current_id: + chain.append(current_id) + node = self.get(current_id) + current_id = node.parent_id if node else None + return chain + + a_dir_ancestors = get_ancestor_ids(a.parent_id) + b_ancestors = [b.id] + get_ancestor_ids(b.parent_id) + a_dir_set = set(a_dir_ancestors) + + lca_id: str | None = None + lca_index_in_b = -1 + for idx, ancestor_id in enumerate(b_ancestors): + if ancestor_id in a_dir_set or (a.parent_id is None and b_ancestors[idx:] == []): + lca_id = ancestor_id + lca_index_in_b = idx + break + + if a.parent_id is None: + steps_up = 0 + lca_index_in_b = len(b_ancestors) + elif lca_id is None: + steps_up = len(a_dir_ancestors) + lca_index_in_b = len(b_ancestors) + else: + steps_up = 0 + for ancestor_id in a_dir_ancestors: + if ancestor_id == lca_id: + break + steps_up += 1 + + path_down: list[str] = [] + for i in range(lca_index_in_b - 1, -1, -1): + node = self.get(b_ancestors[i]) + if node: + path_down.append(node.name) + + if steps_up == 0: + return "./" + "/".join(path_down) + + parts: list[str] = [".."] * steps_up + path_down + return "/".join(parts) + + def get_descendant_ids(self, node_id: str) -> list[str]: + result: list[str] = [] + stack = [node_id] + while stack: + current_id = stack.pop() + for child in self.nodes: + if child.parent_id == current_id: + result.append(child.id) + stack.append(child.id) + return result + + def add(self, node: AppAssetNode) -> AppAssetNode: + if self.get(node.id): + raise TreePathConflictError(node.id) + if self.has_child_named(node.parent_id, node.name): + raise TreePathConflictError(node.name) + if node.parent_id: + parent = self.get(node.parent_id) + if not parent or parent.node_type != AssetNodeType.FOLDER: + 
raise TreeParentNotFoundError(node.parent_id) + siblings = self.get_children(node.parent_id) + node.order = max((s.order for s in siblings), default=-1) + 1 + self.nodes.append(node) + return node + + def update(self, node_id: str, size: int) -> AppAssetNode: + node = self.get(node_id) + if not node or node.node_type != AssetNodeType.FILE: + raise TreeNodeNotFoundError(node_id) + node.size = size + return node + + def rename(self, node_id: str, new_name: str) -> AppAssetNode: + node = self.get(node_id) + if not node: + raise TreeNodeNotFoundError(node_id) + if node.name != new_name and self.has_child_named(node.parent_id, new_name): + raise TreePathConflictError(new_name) + node.name = new_name + if node.node_type == AssetNodeType.FILE: + node.extension = new_name.rsplit(".", 1)[-1] if "." in new_name else "" + return node + + def move(self, node_id: str, new_parent_id: str | None) -> AppAssetNode: + node = self.get(node_id) + if not node: + raise TreeNodeNotFoundError(node_id) + if new_parent_id: + parent = self.get(new_parent_id) + if not parent or parent.node_type != AssetNodeType.FOLDER: + raise TreeParentNotFoundError(new_parent_id) + if self.has_child_named(new_parent_id, node.name): + raise TreePathConflictError(node.name) + node.parent_id = new_parent_id + siblings = self.get_children(new_parent_id) + node.order = max((s.order for s in siblings if s.id != node_id), default=-1) + 1 + return node + + def reorder(self, node_id: str, after_node_id: str | None) -> AppAssetNode: + node = self.get(node_id) + if not node: + raise TreeNodeNotFoundError(node_id) + + siblings = sorted(self.get_children(node.parent_id), key=lambda x: x.order) + siblings = [s for s in siblings if s.id != node_id] + + if after_node_id is None: + insert_idx = 0 + else: + after_node = self.get(after_node_id) + if not after_node or after_node.parent_id != node.parent_id: + raise TreeNodeNotFoundError(after_node_id) + insert_idx = next((i for i, s in enumerate(siblings) if s.id == after_node_id), -1) + 1 + + siblings.insert(insert_idx, node) + for idx, sibling in enumerate(siblings): + sibling.order = idx + + return node + + def remove(self, node_id: str) -> list[str]: + node = self.get(node_id) + if not node: + raise TreeNodeNotFoundError(node_id) + ids_to_remove = [node_id] + self.get_descendant_ids(node_id) + self.nodes = [n for n in self.nodes if n.id not in ids_to_remove] + return ids_to_remove + + def walk_files(self) -> Generator[AppAssetNode, None, None]: + return (n for n in self.nodes if n.node_type == AssetNodeType.FILE) + + def transform(self) -> list[AppAssetNodeView]: + by_parent: dict[str | None, list[AppAssetNode]] = defaultdict(list) + for n in self.nodes: + by_parent[n.parent_id].append(n) + + for children in by_parent.values(): + children.sort(key=lambda x: x.order) + + paths: dict[str, str] = {} + tree_views: dict[str, AppAssetNodeView] = {} + + def build_view(node: AppAssetNode, parent_path: str) -> None: + path = f"{parent_path}/{node.name}" + paths[node.id] = path + child_views: list[AppAssetNodeView] = [] + for child in by_parent.get(node.id, []): + build_view(child, path) + child_views.append(tree_views[child.id]) + tree_views[node.id] = AppAssetNodeView( + id=node.id, + node_type=node.node_type.value, + name=node.name, + path=path, + extension=node.extension, + size=node.size, + children=child_views, + ) + + for root_node in by_parent.get(None, []): + build_view(root_node, "") + + return [tree_views[n.id] for n in by_parent.get(None, [])] + + def empty(self) -> bool: + return 
len(self.nodes) == 0 diff --git a/api/core/app/entities/app_bundle_entities.py b/api/core/app/entities/app_bundle_entities.py new file mode 100644 index 0000000000..8566fd2bb1 --- /dev/null +++ b/api/core/app/entities/app_bundle_entities.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import re +from datetime import UTC, datetime + +from pydantic import BaseModel, ConfigDict, Field + +from core.app.entities.app_asset_entities import AppAssetFileTree + +# Constants +BUNDLE_DSL_FILENAME_PATTERN = re.compile(r"^[^/]+\.ya?ml$") +BUNDLE_MAX_SIZE = 50 * 1024 * 1024 # 50MB +MANIFEST_FILENAME = "manifest.json" +MANIFEST_SCHEMA_VERSION = "1.0" + + +# Exceptions +class BundleFormatError(Exception): + """Raised when bundle format is invalid.""" + + pass + + +class ZipSecurityError(Exception): + """Raised when zip file contains security violations.""" + + pass + + +# Manifest DTOs +class ManifestFileEntry(BaseModel): + """Maps node_id to file path in the bundle.""" + + model_config = ConfigDict(extra="forbid") + + node_id: str + path: str + + +class ManifestIntegrity(BaseModel): + """Basic integrity check fields.""" + + model_config = ConfigDict(extra="forbid") + + file_count: int + + +class ManifestAppAssets(BaseModel): + """App assets section containing the full tree.""" + + model_config = ConfigDict(extra="forbid") + + tree: AppAssetFileTree + + +class BundleManifest(BaseModel): + """ + Bundle manifest for app asset import/export. + + Schema version 1.0: + - dsl_filename: DSL file name in bundle root (e.g. "my_app.yml") + - tree: Full AppAssetFileTree (files + folders) for 100% restoration including node IDs + - files: Explicit node_id -> path mapping for file nodes only + - integrity: Basic file_count validation + """ + + model_config = ConfigDict(extra="forbid") + + schema_version: str = Field(default=MANIFEST_SCHEMA_VERSION) + generated_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC)) + dsl_filename: str = Field(description="DSL file name in bundle root") + app_assets: ManifestAppAssets + files: list[ManifestFileEntry] + integrity: ManifestIntegrity + + @property + def assets_prefix(self) -> str: + """Assets directory name (DSL filename without extension).""" + return self.dsl_filename.rsplit(".", 1)[0] + + @classmethod + def from_tree(cls, tree: AppAssetFileTree, dsl_filename: str) -> BundleManifest: + """Build manifest from an AppAssetFileTree.""" + files = [ManifestFileEntry(node_id=n.id, path=tree.get_path(n.id)) for n in tree.walk_files()] + return cls( + dsl_filename=dsl_filename, + app_assets=ManifestAppAssets(tree=tree), + files=files, + integrity=ManifestIntegrity(file_count=len(files)), + ) + + +# Export result +class BundleExportResult(BaseModel): + download_url: str = Field(description="Temporary download URL for the ZIP") + filename: str = Field(description="Suggested filename for the ZIP") diff --git a/api/core/app/entities/llm_generation_entities.py b/api/core/app/entities/llm_generation_entities.py new file mode 100644 index 0000000000..a65fe5de1f --- /dev/null +++ b/api/core/app/entities/llm_generation_entities.py @@ -0,0 +1,72 @@ +""" +LLM Generation Detail entities. + +Defines the structure for storing and transmitting LLM generation details +including reasoning content, tool calls, and their sequence. 
+""" + +from typing import Literal + +from pydantic import BaseModel, Field + + +class ContentSegment(BaseModel): + """Represents a content segment in the generation sequence.""" + + type: Literal["content"] = "content" + start: int = Field(..., description="Start position in the text") + end: int = Field(..., description="End position in the text") + + +class ReasoningSegment(BaseModel): + """Represents a reasoning segment in the generation sequence.""" + + type: Literal["reasoning"] = "reasoning" + index: int = Field(..., description="Index into reasoning_content array") + + +class ToolCallSegment(BaseModel): + """Represents a tool call segment in the generation sequence.""" + + type: Literal["tool_call"] = "tool_call" + index: int = Field(..., description="Index into tool_calls array") + + +SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment + + +class ToolCallDetail(BaseModel): + """Represents a tool call with its arguments and result.""" + + id: str = Field(default="", description="Unique identifier for the tool call") + name: str = Field(..., description="Name of the tool") + arguments: str = Field(default="", description="JSON string of tool arguments") + result: str = Field(default="", description="Result from the tool execution") + elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds") + icon: str | dict | None = Field(default=None, description="Icon of the tool") + icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool") + + +class LLMGenerationDetailData(BaseModel): + """ + Domain model for LLM generation detail. + + Contains the structured data for reasoning content, tool calls, + and their display sequence. + """ + + reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments") + tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details") + sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments") + + def is_empty(self) -> bool: + """Check if there's any meaningful generation detail.""" + return not self.reasoning_content and not self.tool_calls + + def to_response_dict(self) -> dict: + """Convert to dictionary for API response.""" + return { + "reasoning_content": self.reasoning_content, + "tool_calls": [tc.model_dump() for tc in self.tool_calls], + "sequence": [seg.model_dump() for seg in self.sequence], + } diff --git a/api/core/helper/creators.py b/api/core/helper/creators.py new file mode 100644 index 0000000000..530e5d4886 --- /dev/null +++ b/api/core/helper/creators.py @@ -0,0 +1,75 @@ +""" +Helper module for Creators Platform integration. + +Provides functionality to upload DSL files to the Creators Platform +and generate redirect URLs with OAuth authorization codes. +""" + +import logging +from urllib.parse import urlencode + +import httpx +from yarl import URL + +from configs import dify_config + +logger = logging.getLogger(__name__) + +creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL)) + + +def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str: + """Upload a DSL file to the Creators Platform anonymous upload endpoint. + + Args: + dsl_file_bytes: Raw bytes of the DSL file (YAML or ZIP). + filename: Original filename for the upload. + + Returns: + The claim_code string used to retrieve the DSL later. + + Raises: + httpx.HTTPStatusError: If the upload request fails. 
+ ValueError: If the response does not contain a valid claim_code. + """ + url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload") + response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30) + response.raise_for_status() + + data = response.json() + claim_code = data.get("data", {}).get("claim_code") + if not claim_code: + raise ValueError("Creators Platform did not return a valid claim_code") + + return claim_code + + +def get_redirect_url(user_account_id: str, claim_code: str) -> str: + """Generate the redirect URL to the Creators Platform frontend. + + Redirects to the Creators Platform root page with the dsl_claim_code. + If CREATORS_PLATFORM_OAUTH_CLIENT_ID is configured (Dify Cloud), + also signs an OAuth authorization code so the frontend can + automatically authenticate the user via the OAuth callback. + + For self-hosted Dify without OAuth client_id configured, only the + dsl_claim_code is passed and the user must log in manually. + + Args: + user_account_id: The Dify user account ID. + claim_code: The claim_code obtained from upload_dsl(). + + Returns: + The full redirect URL string. + """ + base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/") + params: dict[str, str] = {"dsl_claim_code": claim_code} + + client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "") + if client_id: + from services.oauth_server import OAuthServerService + + oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id) + params["oauth_code"] = oauth_code + + return f"{base_url}?{urlencode(params)}" diff --git a/api/core/llm_generator/context_models.py b/api/core/llm_generator/context_models.py new file mode 100644 index 0000000000..66db0ac64b --- /dev/null +++ b/api/core/llm_generator/context_models.py @@ -0,0 +1,62 @@ +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class VariableSelectorPayload(BaseModel): + model_config = ConfigDict(extra="forbid") + + variable: str = Field(..., description="Variable name used in generated code") + value_selector: list[str] = Field(..., description="Path to upstream node output, format: [node_id, output_name]") + + +class CodeOutputPayload(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: str = Field(..., description="Output variable type") + + +class CodeContextPayload(BaseModel): + # From web/app/components/workflow/nodes/tool/components/context-generate-modal/index.tsx (code node snapshot). + model_config = ConfigDict(extra="forbid") + + code: str = Field(..., description="Existing code in the Code node") + outputs: dict[str, CodeOutputPayload] | None = Field( + default=None, description="Existing output definitions for the Code node" + ) + variables: list[VariableSelectorPayload] | None = Field( + default=None, description="Existing variable selectors used by the Code node" + ) + + +class AvailableVarPayload(BaseModel): + # From web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts (available variables). + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + value_selector: list[str] = Field(..., description="Path to upstream node output") + type: str = Field(..., description="Variable type, e.g. 
string, number, array[object]") + description: str | None = Field(default=None, description="Optional variable description") + node_id: str | None = Field(default=None, description="Source node ID") + node_title: str | None = Field(default=None, description="Source node title") + node_type: str | None = Field(default=None, description="Source node type") + json_schema: dict[str, Any] | None = Field( + default=None, + alias="schema", + description="Optional JSON schema for object variables", + ) + + +class ParameterInfoPayload(BaseModel): + # From web/app/components/workflow/nodes/tool/use-config.ts (ToolParameter metadata). + model_config = ConfigDict(extra="forbid") + + name: str = Field(..., description="Target parameter name") + type: str = Field(default="string", description="Target parameter type") + description: str = Field(default="", description="Parameter description") + required: bool | None = Field(default=None, description="Whether the parameter is required") + options: list[str] | None = Field(default=None, description="Allowed option values") + min: float | None = Field(default=None, description="Minimum numeric value") + max: float | None = Field(default=None, description="Maximum numeric value") + default: str | int | float | bool | None = Field(default=None, description="Default value") + multiple: bool | None = Field(default=None, description="Whether the parameter accepts multiple values") + label: str | None = Field(default=None, description="Optional display label") diff --git a/api/core/llm_generator/output_models.py b/api/core/llm_generator/output_models.py new file mode 100644 index 0000000000..a73158e25c --- /dev/null +++ b/api/core/llm_generator/output_models.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field + +from graphon.variables.types import SegmentType + + +class SuggestedQuestionsOutput(BaseModel): + """Output model for suggested questions generation.""" + + model_config = ConfigDict(extra="forbid") + + questions: list[str] = Field( + min_length=3, + max_length=3, + description="Exactly 3 suggested follow-up questions for the user", + ) + + +class VariableSelectorOutput(BaseModel): + """Variable selector mapping code variable to upstream node output. + + Note: Separate from VariableSelector to ensure 'additionalProperties: false' + in JSON schema for OpenAI/Azure strict mode. + """ + + model_config = ConfigDict(extra="forbid") + + variable: str = Field(description="Variable name used in the generated code") + value_selector: list[str] = Field(description="Path to upstream node output, format: [node_id, output_name]") + + +class CodeNodeOutputItem(BaseModel): + """Single output variable definition. + + Note: OpenAI/Azure strict mode requires 'additionalProperties: false' and + does not support dynamic object keys, so outputs use array format. 
+ """ + + model_config = ConfigDict(extra="forbid") + + name: str = Field(description="Output variable name returned by the main function") + type: SegmentType = Field(description="Data type of the output variable") + + +class CodeNodeStructuredOutput(BaseModel): + """Structured output for code node generation.""" + + model_config = ConfigDict(extra="forbid") + + variables: list[VariableSelectorOutput] = Field( + description="Input variables mapping code variables to upstream node outputs" + ) + code: str = Field(description="Generated code with a main function that processes inputs and returns outputs") + outputs: list[CodeNodeOutputItem] = Field( + description="Output variable definitions specifying name and type for each return value" + ) + message: str = Field(description="Brief explanation of what the generated code does") + + +class InstructionModifyOutput(BaseModel): + """Output model for instruction-based prompt modification.""" + + model_config = ConfigDict(extra="forbid") + + modified: str = Field(description="The modified prompt content after applying the instruction") + message: str = Field(description="Brief explanation of what changes were made") diff --git a/api/core/llm_generator/output_parser/file_ref.py b/api/core/llm_generator/output_parser/file_ref.py new file mode 100644 index 0000000000..c43d213ca0 --- /dev/null +++ b/api/core/llm_generator/output_parser/file_ref.py @@ -0,0 +1,203 @@ +""" +File path detection and conversion for structured output. + +This module provides utilities to: +1. Detect sandbox file path fields in JSON Schema (format: "file-path") +2. Adapt schemas to add file-path descriptions before model invocation +3. Convert sandbox file path strings into File objects via a resolver +""" + +from collections.abc import Callable, Mapping, Sequence +from typing import Any, cast + +from graphon.file import File +from graphon.variables.segments import ArrayFileSegment, FileSegment + +FILE_PATH_FORMAT = "file-path" +FILE_PATH_DESCRIPTION_SUFFIX = "this field contains a file path from the Dify sandbox" + + +def is_file_path_property(schema: Mapping[str, Any]) -> bool: + """Check if a schema property represents a sandbox file path.""" + if schema.get("type") != "string": + return False + format_value = schema.get("format") + if not isinstance(format_value, str): + return False + normalized_format = format_value.lower().replace("_", "-") + return normalized_format == FILE_PATH_FORMAT + + +def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]: + """Recursively detect file path fields in a JSON schema.""" + file_path_fields: list[str] = [] + schema_type = schema.get("type") + + if schema_type == "object": + properties = schema.get("properties") + if isinstance(properties, Mapping): + properties_mapping = cast(Mapping[str, Any], properties) + for prop_name, prop_schema in properties_mapping.items(): + if not isinstance(prop_schema, Mapping): + continue + prop_schema_mapping = cast(Mapping[str, Any], prop_schema) + current_path = f"{path}.{prop_name}" if path else prop_name + + if is_file_path_property(prop_schema_mapping): + file_path_fields.append(current_path) + else: + file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path)) + + elif schema_type == "array": + items_schema = schema.get("items") + if not isinstance(items_schema, Mapping): + return file_path_fields + items_schema_mapping = cast(Mapping[str, Any], items_schema) + array_path = f"{path}[*]" if path else "[*]" + + if 
is_file_path_property(items_schema_mapping): + file_path_fields.append(array_path) + else: + file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path)) + + return file_path_fields + + +def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]: + """Normalize sandbox file path fields and collect their JSON paths.""" + result = _deep_copy_value(schema) + if not isinstance(result, dict): + raise ValueError("structured_output_schema must be a JSON object") + result_dict = cast(dict[str, Any], result) + + file_path_fields: list[str] = [] + _adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields) + return result_dict, file_path_fields + + +def convert_sandbox_file_paths_in_output( + output: Mapping[str, Any], + file_path_fields: Sequence[str], + file_resolver: Callable[[str], File], +) -> tuple[dict[str, Any], list[File]]: + """Convert sandbox file paths into File objects using the resolver.""" + if not file_path_fields: + return dict(output), [] + + result = _deep_copy_value(output) + if not isinstance(result, dict): + raise ValueError("Structured output must be a JSON object") + result_dict = cast(dict[str, Any], result) + + files: list[File] = [] + for path in file_path_fields: + _convert_path_in_place(result_dict, path.split("."), file_resolver, files) + + return result_dict, files + + +def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None: + schema_type = schema.get("type") + + if schema_type == "object": + properties = schema.get("properties") + if isinstance(properties, Mapping): + properties_mapping = cast(Mapping[str, Any], properties) + for prop_name, prop_schema in properties_mapping.items(): + if not isinstance(prop_schema, dict): + continue + prop_schema_dict = cast(dict[str, Any], prop_schema) + current_path = f"{path}.{prop_name}" if path else prop_name + + if is_file_path_property(prop_schema_dict): + _normalize_file_path_schema(prop_schema_dict) + file_path_fields.append(current_path) + else: + _adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields) + + elif schema_type == "array": + items_schema = schema.get("items") + if not isinstance(items_schema, dict): + return + items_schema_dict = cast(dict[str, Any], items_schema) + array_path = f"{path}[*]" if path else "[*]" + + if is_file_path_property(items_schema_dict): + _normalize_file_path_schema(items_schema_dict) + file_path_fields.append(array_path) + else: + _adapt_schema_in_place(items_schema_dict, array_path, file_path_fields) + + +def _normalize_file_path_schema(schema: dict[str, Any]) -> None: + schema["type"] = "string" + schema["format"] = FILE_PATH_FORMAT + description = schema.get("description", "") + if description: + if FILE_PATH_DESCRIPTION_SUFFIX not in description: + schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}" + else: + schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX + + +def _deep_copy_value(value: Any) -> Any: + if isinstance(value, Mapping): + mapping = cast(Mapping[str, Any], value) + return {key: _deep_copy_value(item) for key, item in mapping.items()} + if isinstance(value, list): + list_value = cast(list[Any], value) + return [_deep_copy_value(item) for item in list_value] + return value + + +def _convert_path_in_place( + obj: dict[str, Any], + path_parts: list[str], + file_resolver: Callable[[str], File], + files: list[File], +) -> None: + if not path_parts: + return + + current = path_parts[0] + remaining = 
path_parts[1:] + + if current.endswith("[*]"): + key = current[:-3] if current != "[*]" else "" + target_value = obj.get(key) if key else obj + + if isinstance(target_value, list): + target_list = cast(list[Any], target_value) + if remaining: + for item in target_list: + if isinstance(item, dict): + item_dict = cast(dict[str, Any], item) + _convert_path_in_place(item_dict, remaining, file_resolver, files) + else: + resolved_files: list[File] = [] + for item in target_list: + if not isinstance(item, str): + raise ValueError("File path must be a string") + file = file_resolver(item) + files.append(file) + resolved_files.append(file) + if key: + obj[key] = ArrayFileSegment(value=resolved_files) + return + + if not remaining: + if current not in obj: + return + value = obj[current] + if value is None: + obj[current] = None + return + if not isinstance(value, str): + raise ValueError("File path must be a string") + file = file_resolver(value) + files.append(file) + obj[current] = FileSegment(value=file) + return + + if current in obj and isinstance(obj[current], dict): + _convert_path_in_place(obj[current], remaining, file_resolver, files) diff --git a/api/core/llm_generator/utils.py b/api/core/llm_generator/utils.py new file mode 100644 index 0000000000..492b23ca92 --- /dev/null +++ b/api/core/llm_generator/utils.py @@ -0,0 +1,45 @@ +"""Utility functions for LLM generator.""" + +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) + + +def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]: + """ + Deserialize list of dicts to list[PromptMessage]. + + Expected format: + [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."}, + ] + """ + result: list[PromptMessage] = [] + for msg in messages: + role = PromptMessageRole.value_of(msg["role"]) + content = msg.get("content", "") + + match role: + case PromptMessageRole.USER: + result.append(UserPromptMessage(content=content)) + case PromptMessageRole.ASSISTANT: + result.append(AssistantPromptMessage(content=content)) + case PromptMessageRole.SYSTEM: + result.append(SystemPromptMessage(content=content)) + case PromptMessageRole.TOOL: + result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", ""))) + + return result + + +def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]: + """ + Serialize list[PromptMessage] to list of dicts. + """ + return [{"role": msg.role.value, "content": msg.content} for msg in messages] diff --git a/api/core/memory/__init__.py b/api/core/memory/__init__.py new file mode 100644 index 0000000000..d0e2babde2 --- /dev/null +++ b/api/core/memory/__init__.py @@ -0,0 +1,11 @@ +from core.memory.base import BaseMemory +from core.memory.node_token_buffer_memory import ( + NodeTokenBufferMemory, +) +from core.memory.token_buffer_memory import TokenBufferMemory + +__all__ = [ + "BaseMemory", + "NodeTokenBufferMemory", + "TokenBufferMemory", +] diff --git a/api/core/memory/base.py b/api/core/memory/base.py new file mode 100644 index 0000000000..1cfa1873c5 --- /dev/null +++ b/api/core/memory/base.py @@ -0,0 +1,82 @@ +""" +Base memory interfaces and types. + +This module defines the common protocol for memory implementations. 
+""" + +from abc import ABC, abstractmethod +from collections.abc import Sequence + +from graphon.model_runtime.entities import ImagePromptMessageContent, PromptMessage + + +class BaseMemory(ABC): + """ + Abstract base class for memory implementations. + + Provides a common interface for both conversation-level and node-level memory. + """ + + @abstractmethod + def get_history_prompt_messages( + self, + max_token_limit: int = 2000, + message_limit: int | None = None, + ) -> Sequence[PromptMessage]: + """ + Get history prompt messages. + + :param max_token_limit: Maximum tokens for history + :param message_limit: Maximum number of messages + :return: Sequence of PromptMessage for LLM context + """ + pass + + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: int | None = None, + ) -> str: + """ + Get history prompt as formatted text. + + :param human_prefix: Prefix for human messages + :param ai_prefix: Prefix for assistant messages + :param max_token_limit: Maximum tokens for history + :param message_limit: Maximum number of messages + :return: Formatted history text + """ + from graphon.model_runtime.entities import ( + PromptMessageRole, + TextPromptMessageContent, + ) + + prompt_messages = self.get_history_prompt_messages( + max_token_limit=max_token_limit, + message_limit=message_limit, + ) + + string_messages = [] + for m in prompt_messages: + if m.role == PromptMessageRole.USER: + role = human_prefix + elif m.role == PromptMessageRole.ASSISTANT: + role = ai_prefix + else: + continue + + if isinstance(m.content, list): + inner_msg = "" + for content in m.content: + if isinstance(content, TextPromptMessageContent): + inner_msg += f"{content.data}\n" + elif isinstance(content, ImagePromptMessageContent): + inner_msg += "[image]\n" + string_messages.append(f"{role}: {inner_msg.strip()}") + else: + message = f"{role}: {m.content}" + string_messages.append(message) + + return "\n".join(string_messages) diff --git a/api/core/memory/node_token_buffer_memory.py b/api/core/memory/node_token_buffer_memory.py new file mode 100644 index 0000000000..d9ad333bd0 --- /dev/null +++ b/api/core/memory/node_token_buffer_memory.py @@ -0,0 +1,196 @@ +""" +Node-level Token Buffer Memory for Chatflow. + +This module provides node-scoped memory within a conversation. +Each LLM node in a workflow can maintain its own independent conversation history. + +Note: This is only available in Chatflow (advanced-chat mode) because it requires +both conversation_id and node_id. 
+ +Design: +- History is read directly from WorkflowNodeExecutionModel.outputs["context"] +- No separate storage needed - the context is already saved during node execution +- Thread tracking leverages Message table's parent_message_id structure +""" + +import logging +from collections.abc import Sequence +from typing import cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.memory.base import BaseMemory +from core.model_manager import ModelInstance +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from graphon.file import file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + MultiModalPromptMessageContent, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from extensions.ext_database import db +from models.model import Message +from models.workflow import WorkflowNodeExecutionModel + +logger = logging.getLogger(__name__) + + +class NodeTokenBufferMemory(BaseMemory): + """ + Node-level Token Buffer Memory. + + Provides node-scoped memory within a conversation. Each LLM node can maintain + its own independent conversation history. + + Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"], + which is already saved during node execution. No separate storage needed. + """ + + def __init__( + self, + app_id: str, + conversation_id: str, + node_id: str, + tenant_id: str, + model_instance: ModelInstance, + ): + self.app_id = app_id + self.conversation_id = conversation_id + self.node_id = node_id + self.tenant_id = tenant_id + self.model_instance = model_instance + + def _get_thread_workflow_run_ids(self) -> list[str]: + """ + Get workflow_run_ids for the current thread by querying Message table. + Returns workflow_run_ids in chronological order (oldest first). 
+ """ + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select(Message) + .where(Message.conversation_id == self.conversation_id) + .order_by(Message.created_at.desc()) + .limit(500) + ) + messages = list(session.scalars(stmt).all()) + + if not messages: + return [] + + # Extract thread messages using existing logic + thread_messages = extract_thread_messages(messages) + + # For newly created message, its answer is temporarily empty, skip it + if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0: + thread_messages.pop(0) + + # Reverse to get chronological order, extract workflow_run_ids + return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id] + + def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage: + """Deserialize a dict to PromptMessage based on role.""" + role = msg_dict.get("role") + if role in (PromptMessageRole.USER, "user"): + return UserPromptMessage.model_validate(msg_dict) + elif role in (PromptMessageRole.ASSISTANT, "assistant"): + return AssistantPromptMessage.model_validate(msg_dict) + elif role in (PromptMessageRole.SYSTEM, "system"): + return SystemPromptMessage.model_validate(msg_dict) + elif role in (PromptMessageRole.TOOL, "tool"): + return ToolPromptMessage.model_validate(msg_dict) + else: + return PromptMessage.model_validate(msg_dict) + + def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]: + """Deserialize context data from outputs to list of PromptMessage.""" + messages = [] + for msg_dict in context_data: + try: + msg = self._deserialize_prompt_message(msg_dict) + msg = self._restore_multimodal_content(msg) + messages.append(msg) + except Exception as e: + logger.warning("Failed to deserialize prompt message: %s", e) + return messages + + def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage: + """ + Restore multimodal content (base64 or url) from file_ref. + + When context is saved, base64_data is cleared to save storage space. + This method restores the content by parsing file_ref (format: "method:id_or_url"). + """ + content = message.content + if content is None or isinstance(content, str): + return message + + # Process list content, restoring multimodal data from file references + restored_content: list[PromptMessageContentUnionTypes] = [] + for item in content: + if isinstance(item, MultiModalPromptMessageContent): + # restore_multimodal_content preserves the concrete subclass type + restored_item = file_manager.restore_multimodal_content(item) + restored_content.append(cast(PromptMessageContentUnionTypes, restored_item)) + else: + restored_content.append(item) + + return message.model_copy(update={"content": restored_content}) + + def get_history_prompt_messages( + self, + max_token_limit: int = 2000, + message_limit: int | None = None, + ) -> Sequence[PromptMessage]: + """ + Retrieve history as PromptMessage sequence. + History is read directly from the last completed node execution's outputs["context"]. 
+ """ + _ = message_limit # unused, kept for interface compatibility + + thread_workflow_run_ids = self._get_thread_workflow_run_ids() + if not thread_workflow_run_ids: + return [] + + # Get the last completed workflow_run_id (contains accumulated context) + last_run_id = thread_workflow_run_ids[-1] + + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.workflow_run_id == last_run_id, + WorkflowNodeExecutionModel.node_id == self.node_id, + WorkflowNodeExecutionModel.status == "succeeded", + ) + execution = session.scalars(stmt).first() + + if not execution: + return [] + + outputs = execution.outputs_dict + if not outputs: + return [] + + context_data = outputs.get("context") + + if not context_data or not isinstance(context_data, list): + return [] + + prompt_messages = self._deserialize_context(context_data) + if not prompt_messages: + return [] + + # Truncate by token limit + try: + current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) + while current_tokens > max_token_limit and len(prompt_messages) > 1: + prompt_messages.pop(0) + current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) + except Exception as e: + logger.warning("Failed to count tokens for truncation: %s", e) + + return prompt_messages diff --git a/api/core/session/__init__.py b/api/core/session/__init__.py new file mode 100644 index 0000000000..3c669a52ca --- /dev/null +++ b/api/core/session/__init__.py @@ -0,0 +1,11 @@ +from .cli_api import CliApiSession, CliApiSessionManager +from .session import BaseSession, RedisSessionStorage, SessionManager, SessionStorage + +__all__ = [ + "BaseSession", + "CliApiSession", + "CliApiSessionManager", + "RedisSessionStorage", + "SessionManager", + "SessionStorage", +] diff --git a/api/core/session/cli_api.py b/api/core/session/cli_api.py new file mode 100644 index 0000000000..a8c42a5ede --- /dev/null +++ b/api/core/session/cli_api.py @@ -0,0 +1,30 @@ +import secrets + +from pydantic import BaseModel, Field + +from configs import dify_config +from core.skill.entities import ToolAccessPolicy + +from .session import BaseSession, SessionManager + + +class CliApiSession(BaseSession): + secret: str = Field(default_factory=lambda: secrets.token_urlsafe(32)) + + +class CliContext(BaseModel): + tool_access: ToolAccessPolicy | None = Field(default=None, description="Tool access policy") + + +class CliApiSessionManager(SessionManager[CliApiSession]): + def __init__(self, ttl: int | None = None): + super().__init__( + key_prefix="cli_api_session", + session_class=CliApiSession, + ttl=ttl or dify_config.WORKFLOW_MAX_EXECUTION_TIME, + ) + + def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession: + session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json")) + self.save(session) + return session diff --git a/api/core/session/session.py b/api/core/session/session.py new file mode 100644 index 0000000000..620ea39b3a --- /dev/null +++ b/api/core/session/session.py @@ -0,0 +1,106 @@ +import json +import logging +import uuid +from datetime import UTC, datetime +from typing import Any, Generic, Protocol, TypeVar + +from pydantic import BaseModel, Field, ValidationError + +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + + +class SessionStorage(Protocol): + """Session storage interface.""" + + def get(self, key: str) -> str | None: ... 
+ def set(self, key: str, value: str, ttl: int) -> None: ... + def delete(self, key: str) -> bool: ... + def exists(self, key: str) -> bool: ... + def refresh_ttl(self, key: str, ttl: int) -> bool: ... + + +class RedisSessionStorage: + """Redis storage implementation (default).""" + + def get(self, key: str) -> str | None: + result = redis_client.get(key) + if result is None: + return None + return result.decode() if isinstance(result, bytes) else result + + def set(self, key: str, value: str, ttl: int) -> None: + redis_client.setex(key, ttl, value) + + def delete(self, key: str) -> bool: + return redis_client.delete(key) > 0 + + def exists(self, key: str) -> bool: + return redis_client.exists(key) > 0 + + def refresh_ttl(self, key: str, ttl: int) -> bool: + return bool(redis_client.expire(key, ttl)) + + +class BaseSession(BaseModel): + """Base session model.""" + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + tenant_id: str + user_id: str + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + context: dict[str, Any] = Field(default_factory=dict) + + def update_timestamp(self) -> None: + self.updated_at = datetime.now(UTC) + + +T = TypeVar("T", bound=BaseSession) + + +class SessionManager(Generic[T]): + """Generic session manager.""" + + DEFAULT_TTL = 7200 # 2 hours + + def __init__( + self, + key_prefix: str, + session_class: type[T], + storage: SessionStorage | None = None, + ttl: int | None = None, + ): + self._key_prefix = key_prefix + self._session_class = session_class + self._storage = storage or RedisSessionStorage() + self._ttl = ttl or self.DEFAULT_TTL + + def _get_key(self, session_id: str) -> str: + return f"{self._key_prefix}:{session_id}" + + def save(self, session: T) -> None: + session.update_timestamp() + key = self._get_key(session.id) + self._storage.set(key, session.model_dump_json(), self._ttl) + + def get(self, session_id: str) -> T | None: + key = self._get_key(session_id) + data = self._storage.get(key) + if data is None: + return None + try: + return self._session_class.model_validate(json.loads(data)) + except (json.JSONDecodeError, ValidationError) as e: + logger.warning("Failed to deserialize session %s: %s", session_id, e) + return None + + def delete(self, session_id: str) -> bool: + return self._storage.delete(self._get_key(session_id)) + + def exists(self, session_id: str) -> bool: + return self._storage.exists(self._get_key(session_id)) + + def refresh_ttl(self, session_id: str) -> bool: + return self._storage.refresh_ttl(self._get_key(session_id), self._ttl) diff --git a/api/core/tools/utils/system_encryption.py b/api/core/tools/utils/system_encryption.py new file mode 100644 index 0000000000..fa4625608b --- /dev/null +++ b/api/core/tools/utils/system_encryption.py @@ -0,0 +1,187 @@ +import base64 +import hashlib +import logging +from collections.abc import Mapping +from typing import Any + +from Crypto.Cipher import AES +from Crypto.Random import get_random_bytes +from Crypto.Util.Padding import pad, unpad +from pydantic import TypeAdapter + +from configs import dify_config + +logger = logging.getLogger(__name__) + + +class EncryptionError(Exception): + """Encryption/decryption specific error""" + + pass + + +class SystemEncrypter: + """ + A simple parameters encrypter using AES-CBC encryption. + + This class provides methods to encrypt and decrypt parameters + using AES-CBC mode with a key derived from the application's SECRET_KEY. 
+ """ + + def __init__(self, secret_key: str | None = None): + """ + Initialize the encrypter. + + Args: + secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY + + Raises: + ValueError: If SECRET_KEY is not configured or empty + """ + secret_key = secret_key or dify_config.SECRET_KEY or "" + + # Generate a fixed 256-bit key using SHA-256 + self.key = hashlib.sha256(secret_key.encode()).digest() + + def encrypt_params(self, params: Mapping[str, Any]) -> str: + """ + Encrypt parameters. + + Args: + params: parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} + + Returns: + Base64-encoded encrypted string + + Raises: + EncryptionError: If encryption fails + ValueError: If params is invalid + """ + + try: + # Generate random IV (16 bytes) + iv = get_random_bytes(16) + + # Create AES cipher (CBC mode) + cipher = AES.new(self.key, AES.MODE_CBC, iv) + + # Encrypt data + padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size) + encrypted_data = cipher.encrypt(padded_data) + + # Combine IV and encrypted data + combined = iv + encrypted_data + + # Return base64 encoded string + return base64.b64encode(combined).decode() + + except Exception as e: + raise EncryptionError(f"Encryption failed: {str(e)}") from e + + def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]: + """ + Decrypt parameters. + + Args: + encrypted_data: Base64-encoded encrypted string + + Returns: + Decrypted parameters dictionary + + Raises: + EncryptionError: If decryption fails + ValueError: If encrypted_data is invalid + """ + if not isinstance(encrypted_data, str): + raise ValueError("encrypted_data must be a string") + + if not encrypted_data: + raise ValueError("encrypted_data cannot be empty") + + try: + # Base64 decode + combined = base64.b64decode(encrypted_data) + + # Check minimum length (IV + at least one AES block) + if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data + raise ValueError("Invalid encrypted data format") + + # Separate IV and encrypted data + iv = combined[:16] + encrypted_data_bytes = combined[16:] + + # Create AES cipher + cipher = AES.new(self.key, AES.MODE_CBC, iv) + + # Decrypt data + decrypted_data = cipher.decrypt(encrypted_data_bytes) + unpadded_data = unpad(decrypted_data, AES.block_size) + + # Parse JSON + params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) + + if not isinstance(params, dict): + raise ValueError("Decrypted data is not a valid dictionary") + + return params + + except Exception as e: + raise EncryptionError(f"Decryption failed: {str(e)}") from e + + +# Factory function for creating encrypter instances +def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter: + """ + Create an encrypter instance. + + Args: + secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY + + Returns: + SystemEncrypter instance + """ + return SystemEncrypter(secret_key=secret_key) + + +# Global encrypter instance (for backward compatibility) +_encrypter: SystemEncrypter | None = None + + +def get_system_encrypter() -> SystemEncrypter: + """ + Get the global encrypter instance. + + Returns: + SystemEncrypter instance + """ + global _encrypter + if _encrypter is None: + _encrypter = SystemEncrypter() + return _encrypter + + +# Convenience functions for backward compatibility +def encrypt_system_params(params: Mapping[str, Any]) -> str: + """ + Encrypt parameters using the global encrypter. 
+ + Args: + params: parameters dictionary + + Returns: + Base64-encoded encrypted string + """ + return get_system_encrypter().encrypt_params(params) + + +def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]: + """ + Decrypt parameters using the global encrypter. + + Args: + encrypted_data: Base64-encoded encrypted string + + Returns: + Decrypted parameters dictionary + """ + return get_system_encrypter().decrypt_params(encrypted_data) diff --git a/api/core/workflow/nodes/command/__init__.py b/api/core/workflow/nodes/command/__init__.py new file mode 100644 index 0000000000..78ca8ca06d --- /dev/null +++ b/api/core/workflow/nodes/command/__init__.py @@ -0,0 +1,3 @@ +from .node import CommandNode + +__all__ = ["CommandNode"] diff --git a/api/core/workflow/nodes/command/entities.py b/api/core/workflow/nodes/command/entities.py new file mode 100644 index 0000000000..bdaf7b724f --- /dev/null +++ b/api/core/workflow/nodes/command/entities.py @@ -0,0 +1,10 @@ +from graphon.entities.base_node_data import BaseNodeData + + +class CommandNodeData(BaseNodeData): + """ + Command Node Data. + """ + + working_directory: str = "" # Working directory for command execution + command: str = "" # Command to execute diff --git a/api/core/workflow/nodes/command/exc.py b/api/core/workflow/nodes/command/exc.py new file mode 100644 index 0000000000..c6349d5630 --- /dev/null +++ b/api/core/workflow/nodes/command/exc.py @@ -0,0 +1,16 @@ +class CommandNodeError(ValueError): + """Base class for command node errors.""" + + pass + + +class CommandExecutionError(CommandNodeError): + """Raised when command execution fails.""" + + pass + + +class CommandTimeoutError(CommandNodeError): + """Raised when command execution times out.""" + + pass diff --git a/api/core/workflow/nodes/command/node.py b/api/core/workflow/nodes/command/node.py new file mode 100644 index 0000000000..a56c4870b5 --- /dev/null +++ b/api/core/workflow/nodes/command/node.py @@ -0,0 +1,152 @@ +import logging +from collections.abc import Mapping, Sequence +from typing import Any + +from core.sandbox import sandbox_debug +from core.sandbox.bash.session import SANDBOX_READY_TIMEOUT +from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError +from core.virtual_environment.__base.helpers import submit_command, with_connection +from core.workflow.nodes.command.entities import CommandNodeData +from core.workflow.nodes.command.exc import CommandExecutionError +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base import variable_template_parser +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + +logger = logging.getLogger(__name__) + +# FIXME(Mairuis): The timeout value is currently hardcoded and should be made configurable in the future. 
+COMMAND_NODE_TIMEOUT_SECONDS = 60 * 10 + + +class CommandNode(Node[CommandNodeData]): + node_type = BuiltinNodeTypes.COMMAND + + def _render_template(self, template: str) -> str: + parser = VariableTemplateParser(template=template) + selectors = parser.extract_variable_selectors() + if not selectors: + return template + + inputs: dict[str, Any] = {} + for selector in selectors: + value = self.graph_runtime_state.variable_pool.get(selector.value_selector) + inputs[selector.variable] = value.to_object() if value is not None else None + + return parser.format(inputs) + + @classmethod + def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: + return { + "type": "command", + "config": { + "working_directory": "", + "command": "", + }, + } + + @classmethod + def version(cls) -> str: + return "1" + + def _run(self) -> NodeRunResult: + sandbox = self.graph_runtime_state.sandbox + if sandbox is None: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="Sandbox not available for CommandNode.", + error_type="SandboxNotInitializedError", + ) + + working_directory = self._render_template((self.node_data.working_directory or "").strip()) + raw_command = self._render_template(self.node_data.command or "") + + working_directory = working_directory or None + + if not raw_command: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="Command is required.", + error_type="CommandNodeError", + ) + + try: + sandbox.wait_ready(timeout=SANDBOX_READY_TIMEOUT) + with with_connection(sandbox.vm) as conn: + command = ["bash", "-c", raw_command] + + sandbox_debug("command_node", "command", command) + + future = submit_command(sandbox.vm, conn, command, cwd=working_directory) + result = future.result(timeout=COMMAND_NODE_TIMEOUT_SECONDS) + + outputs: dict[str, Any] = { + "stdout": result.stdout.decode("utf-8", errors="replace"), + "stderr": result.stderr.decode("utf-8", errors="replace"), + "exit_code": result.exit_code, + "pid": result.pid, + } + process_data = {"command": command, "working_directory": working_directory} + + sandbox_debug("command_node", "outputs", result.debug_message) + + if result.exit_code not in (None, 0): + stderr_text = result.stderr.decode("utf-8", errors="replace") + error_message = f"{stderr_text}\n\nCommand exited with code {result.exit_code}" + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + outputs=outputs, + process_data=process_data, + error=error_message, + error_type=CommandExecutionError.__name__, + ) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs=outputs, + process_data=process_data, + ) + + except CommandTimeoutError: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=f"Command timed out after {COMMAND_NODE_TIMEOUT_SECONDS}s", + error_type=CommandTimeoutError.__name__, + ) + except CommandCancelledError: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="Command was cancelled", + error_type=CommandCancelledError.__name__, + ) + except Exception as e: + logger.exception("Command node %s failed", self.id) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + error_type=type(e).__name__, + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: CommandNodeData, + ) -> Mapping[str, Sequence[str]]: + _ = graph_config + + typed_node_data = node_data + + selectors: 
list[VariableSelector] = [] + selectors += list(variable_template_parser.extract_selectors_from_template(typed_node_data.command)) + selectors += list(variable_template_parser.extract_selectors_from_template(typed_node_data.working_directory)) + + mapping: dict[str, Sequence[str]] = {} + for selector in selectors: + mapping[node_id + "." + selector.variable] = selector.value_selector + + return mapping diff --git a/api/core/workflow/nodes/file_upload/__init__.py b/api/core/workflow/nodes/file_upload/__init__.py new file mode 100644 index 0000000000..89ddd56cea --- /dev/null +++ b/api/core/workflow/nodes/file_upload/__init__.py @@ -0,0 +1,4 @@ +from .entities import FileUploadNodeData +from .node import FileUploadNode + +__all__ = ["FileUploadNode", "FileUploadNodeData"] diff --git a/api/core/workflow/nodes/file_upload/entities.py b/api/core/workflow/nodes/file_upload/entities.py new file mode 100644 index 0000000000..3ecc3c0f01 --- /dev/null +++ b/api/core/workflow/nodes/file_upload/entities.py @@ -0,0 +1,7 @@ +from collections.abc import Sequence + +from graphon.entities.base_node_data import BaseNodeData + + +class FileUploadNodeData(BaseNodeData): + variable_selector: Sequence[str] diff --git a/api/core/workflow/nodes/file_upload/exc.py b/api/core/workflow/nodes/file_upload/exc.py new file mode 100644 index 0000000000..60bf5f33df --- /dev/null +++ b/api/core/workflow/nodes/file_upload/exc.py @@ -0,0 +1,6 @@ +class FileUploadNodeError(ValueError): + """Base exception for errors related to the FileUploadNode.""" + + +class FileUploadDownloadError(FileUploadNodeError): + """Exception raised when preparing file download in sandbox fails.""" diff --git a/api/core/workflow/nodes/file_upload/node.py b/api/core/workflow/nodes/file_upload/node.py new file mode 100644 index 0000000000..7742f1eb41 --- /dev/null +++ b/api/core/workflow/nodes/file_upload/node.py @@ -0,0 +1,244 @@ +import logging +import os +import posixpath +from collections.abc import Mapping, Sequence +from pathlib import PurePosixPath +from typing import Any, cast + +from core.sandbox.bash.session import SANDBOX_READY_TIMEOUT +from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError +from core.virtual_environment.__base.helpers import pipeline +from core.zip_sandbox import SandboxDownloadItem +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.variables import ArrayFileSegment +from graphon.variables.segments import ArrayStringSegment, FileSegment + +from .entities import FileUploadNodeData +from .exc import FileUploadDownloadError, FileUploadNodeError + +logger = logging.getLogger(__name__) + + +class FileUploadNode(Node[FileUploadNodeData]): + """Upload workflow file variables into sandbox via presigned URLs. + + The node intentionally avoids streaming file bytes through Dify workers. For local/tool + files, it generates storage-backed presigned URLs and lets sandbox download directly. 
+ """ + + node_type = BuiltinNodeTypes.FILE_UPLOAD + + @classmethod + def version(cls) -> str: + return "1" + + @classmethod + def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: + _ = filters + return { + "type": "file-upload", + "config": { + "variable_selector": [], + }, + } + + def _run(self) -> NodeRunResult: + sandbox = self.graph_runtime_state.sandbox + variable_selector = self.node_data.variable_selector + inputs: dict[str, Any] = {"variable_selector": variable_selector} + if sandbox is None: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="Sandbox not available for FileUploadNode.", + error_type="SandboxNotInitializedError", + inputs=inputs, + ) + + variable = self.graph_runtime_state.variable_pool.get(variable_selector) + if variable is None: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=f"File variable not found for selector: {variable_selector}", + error_type=FileUploadNodeError.__name__, + inputs=inputs, + ) + + if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=f"Variable {variable_selector} is not a file or file array", + error_type=FileUploadNodeError.__name__, + inputs=inputs, + ) + + files = self._normalize_files(variable.value) + process_data: dict[str, Any] = { + "file_count": len(files), + "files": [file.to_dict() for file in files], + } + if not files: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + error="Selected file variable is empty.", + error_type=FileUploadNodeError.__name__, + inputs=inputs, + process_data=process_data, + ) + + try: + sandbox.wait_ready(timeout=SANDBOX_READY_TIMEOUT) + download_items: list[SandboxDownloadItem] = self._build_download_items(files) + sandbox_paths = self._upload(sandbox.vm, download_items) + file_names = [PurePosixPath(path).name for path in sandbox_paths] + process_data = { + **process_data, + "sandbox_paths": sandbox_paths, + "file_names": file_names, + } + + outputs: dict[str, Any] + if len(sandbox_paths) == 1: + outputs = { + "sandbox_path": sandbox_paths[0], + "file_name": file_names[0], + } + else: + outputs = { + "sandbox_path": ArrayStringSegment(value=sandbox_paths), + "file_name": ArrayStringSegment(value=file_names), + } + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + + except CommandTimeoutError: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="File upload timeout", + error_type=CommandTimeoutError.__name__, + inputs=inputs, + process_data=process_data, + ) + except CommandCancelledError: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="File upload command was cancelled", + error_type=CommandCancelledError.__name__, + inputs=inputs, + process_data=process_data, + ) + except FileUploadNodeError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + error_type=type(e).__name__, + inputs=inputs, + process_data=process_data, + ) + except Exception as e: + logger.exception("File upload node %s failed", self.id) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + error_type=type(e).__name__, + inputs=inputs, + process_data=process_data, + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: 
str, + node_data: FileUploadNodeData, + ) -> Mapping[str, Sequence[str]]: + _ = graph_config + typed_node_data = node_data + return {node_id + ".files": typed_node_data.variable_selector} + + @staticmethod + def _normalize_files(value: Any) -> list[File]: + if isinstance(value, File): + return [value] + if isinstance(value, list): + list_value = cast(list[object], value) + files: list[File] = [] + for idx in range(len(list_value)): + candidate = list_value[idx] + if not isinstance(candidate, File): + return [] + files.append(candidate) + return files + return [] + + def _build_download_items(self, files: Sequence[File]) -> list[SandboxDownloadItem]: + used_paths: set[str] = set() + items: list[SandboxDownloadItem] = [] + for index, file in enumerate(files): + file_url = self._get_download_url(file) + + filename = (file.filename or "").strip() + if not filename or filename in {".", ".."}: + filename = f"file-{index + 1}{file.extension or ''}" + filename = os.path.basename(filename) + + if filename in used_paths: + stem = PurePosixPath(filename).stem or f"file-{index + 1}" + suffix = PurePosixPath(filename).suffix + dedupe = 1 + while filename in used_paths: + filename = f"{stem}_{dedupe}{suffix}" + dedupe += 1 + + used_paths.add(filename) + items.append(SandboxDownloadItem(path=filename, url=file_url)) + return items + + @staticmethod + def _normalize_path(path: str) -> str: + normalized = posixpath.normpath(path.strip()) if path else "." + if normalized.startswith("/"): + normalized = normalized.lstrip("/") + return normalized or "." + + def _upload(self, vm: Any, items: list[SandboxDownloadItem]) -> list[str]: + p = pipeline(vm) + out_paths: list[str] = [] + for item in items: + out_path = self._normalize_path(item.path) + if out_path in ("", "."): + raise FileUploadDownloadError("Download item path must point to a file") + out_paths.append(out_path) + p.add(["curl", "-fsSL", item.url, "-o", out_path], error_message="Failed to download file") + + try: + p.execute(timeout=None, raise_on_error=True) + except Exception as exc: + raise FileUploadDownloadError(str(exc)) from exc + + return out_paths + + def _get_download_url(self, file: File) -> str: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + if not file.remote_url: + raise FileUploadDownloadError("Remote file URL is missing") + return file.remote_url + + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.TOOL_FILE, + FileTransferMethod.DATASOURCE_FILE, + ): + download_url = file.generate_url(for_external=True) + if not download_url: + raise FileUploadDownloadError("Unable to generate download URL for file") + return download_url + + raise FileUploadDownloadError(f"Unsupported file transfer method: {file.transfer_method}") diff --git a/api/core/zip_sandbox/__init__.py b/api/core/zip_sandbox/__init__.py new file mode 100644 index 0000000000..266e6c7dc2 --- /dev/null +++ b/api/core/zip_sandbox/__init__.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .entities import SandboxDownloadItem, SandboxFile, SandboxUploadItem + +if TYPE_CHECKING: + from .zip_sandbox import ZipSandbox + +__all__ = [ + "SandboxDownloadItem", + "SandboxFile", + "SandboxUploadItem", + "ZipSandbox", +] + + +def __getattr__(name: str): + if name == "ZipSandbox": + from .zip_sandbox import ZipSandbox + + return ZipSandbox + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/api/core/zip_sandbox/cli_strategy.py 
b/api/core/zip_sandbox/cli_strategy.py new file mode 100644 index 0000000000..e41365dc6a --- /dev/null +++ b/api/core/zip_sandbox/cli_strategy.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import posixpath +from typing import TYPE_CHECKING + +from core.virtual_environment.__base.exec import CommandExecutionError +from core.virtual_environment.__base.helpers import execute, try_execute + +from .strategy import ZipStrategy + +if TYPE_CHECKING: + from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + + +class CliZipStrategy(ZipStrategy): + """Strategy using native zip/unzip CLI commands.""" + + def is_available(self, vm: VirtualEnvironment) -> bool: + result = try_execute(vm, ["which", "zip"], timeout=10) + has_zip = bool(result.stdout and result.stdout.strip()) + result = try_execute(vm, ["which", "unzip"], timeout=10) + has_unzip = bool(result.stdout and result.stdout.strip()) + return has_zip and has_unzip + + def zip( + self, + vm: VirtualEnvironment, + *, + src: str, + out_path: str, + cwd: str | None, + timeout: float, + ) -> None: + if src in (".", ""): + result = try_execute(vm, ["zip", "-qr", out_path, "."], timeout=timeout, cwd=cwd) + if not result.is_error: + return + # zip exits with 12 when there is nothing to do; create empty zip + if result.exit_code == 12: + self._write_empty_zip(vm, out_path) + return + raise CommandExecutionError("Failed to create zip archive", result) + + zip_cwd = posixpath.dirname(src) or "." + target = posixpath.basename(src) + result = try_execute(vm, ["zip", "-qr", out_path, target], timeout=timeout, cwd=zip_cwd) + if not result.is_error: + return + if result.exit_code == 12: + self._write_empty_zip(vm, out_path) + return + raise CommandExecutionError("Failed to create zip archive", result) + + def unzip( + self, + vm: VirtualEnvironment, + *, + archive_path: str, + dest_dir: str, + timeout: float, + ) -> None: + execute( + vm, + ["unzip", "-q", archive_path, "-d", dest_dir], + timeout=timeout, + error_message="Failed to unzip archive", + ) + + def _write_empty_zip(self, vm: VirtualEnvironment, out_path: str) -> None: + """Write an empty but valid zip file.""" + script = ( + 'printf "' + "\\x50\\x4b\\x05\\x06" + "\\x00\\x00\\x00\\x00" + "\\x00\\x00\\x00\\x00" + "\\x00\\x00\\x00\\x00" + "\\x00\\x00\\x00\\x00" + "\\x00\\x00\\x00\\x00" + '" > "$1"' + ) + execute(vm, ["sh", "-c", script, "sh", out_path], timeout=30, error_message="Failed to write empty zip") diff --git a/api/core/zip_sandbox/entities.py b/api/core/zip_sandbox/entities.py new file mode 100644 index 0000000000..fc350899cf --- /dev/null +++ b/api/core/zip_sandbox/entities.py @@ -0,0 +1,39 @@ +"""Data classes for ZipSandbox file operations. + +Separated from ``zip_sandbox.py`` so that lightweight consumers (tests, +shell-script builders) can import the types without pulling in the full +sandbox provider chain. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class SandboxDownloadItem: + """Unified download/inline item for sandbox file operations. + + For remote files, *url* is set and the item is fetched via ``curl``. + For inline content, *content* is set and the bytes are written directly + into the VM via ``upload_file`` — no network round-trip. 
+ """ + + path: str + url: str = "" + content: bytes | None = field(default=None, repr=False) + + +@dataclass(frozen=True) +class SandboxUploadItem: + """Item for uploading: sandbox path -> URL.""" + + path: str + url: str + + +@dataclass(frozen=True) +class SandboxFile: + """A handle to a file in the sandbox.""" + + path: str diff --git a/api/core/zip_sandbox/node_strategy.py b/api/core/zip_sandbox/node_strategy.py new file mode 100644 index 0000000000..d77a236a6d --- /dev/null +++ b/api/core/zip_sandbox/node_strategy.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.virtual_environment.__base.helpers import execute, try_execute + +from .strategy import ZipStrategy + +if TYPE_CHECKING: + from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + + +ZIP_SCRIPT = r""" +const fs = require('fs'); +const path = require('path'); +const AdmZip = require('adm-zip'); + +const src = process.argv[2]; +const outPath = process.argv[3]; + +function walkAdd(zip, absPath, arcPrefix) { + const stat = fs.statSync(absPath); + if (stat.isDirectory()) { + const entries = fs.readdirSync(absPath); + if (entries.length === 0) { + zip.addFile(arcPrefix.replace(/\\/g, '/') + '/', Buffer.alloc(0)); + return; + } + for (const e of entries) { + walkAdd(zip, path.join(absPath, e), path.posix.join(arcPrefix, e)); + } + return; + } + if (stat.isFile()) { + const data = fs.readFileSync(absPath); + zip.addFile(arcPrefix.replace(/\\/g, '/'), data); + } +} + +const zip = new AdmZip(); +if (src === '.' || src === '') { + const entries = fs.readdirSync('.'); + for (const e of entries) { + walkAdd(zip, path.join('.', e), e); + } +} else { + const base = path.dirname(src) || '.'; + const prefix = path.basename(src.replace(/\/+$/, '')); + const root = path.join(base, prefix); + walkAdd(zip, root, prefix); +} + +zip.writeZip(outPath); +""" + +UNZIP_SCRIPT = r""" +const AdmZip = require('adm-zip'); +const archivePath = process.argv[2]; +const destDir = process.argv[3]; +const zip = new AdmZip(archivePath); +zip.extractAllTo(destDir, true); +""" + + +class NodeZipStrategy(ZipStrategy): + """Strategy using Node.js with adm-zip package.""" + + def is_available(self, vm: VirtualEnvironment) -> bool: + result = try_execute(vm, ["which", "node"], timeout=10) + if not (result.stdout and result.stdout.strip()): + return False + # Check if adm-zip module is available + result = try_execute(vm, ["node", "-e", "require('adm-zip')"], timeout=10) + return not result.is_error + + def zip( + self, + vm: VirtualEnvironment, + *, + src: str, + out_path: str, + cwd: str | None, + timeout: float, + ) -> None: + execute( + vm, + ["node", "-e", ZIP_SCRIPT, src, out_path], + timeout=timeout, + cwd=cwd, + error_message="Failed to create zip archive", + ) + + def unzip( + self, + vm: VirtualEnvironment, + *, + archive_path: str, + dest_dir: str, + timeout: float, + ) -> None: + execute( + vm, + ["node", "-e", UNZIP_SCRIPT, archive_path, dest_dir], + timeout=timeout, + error_message="Failed to unzip archive", + ) diff --git a/api/core/zip_sandbox/python_strategy.py b/api/core/zip_sandbox/python_strategy.py new file mode 100644 index 0000000000..5eae7efdac --- /dev/null +++ b/api/core/zip_sandbox/python_strategy.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.virtual_environment.__base.helpers import execute, try_execute + +from .strategy import ZipStrategy + +if TYPE_CHECKING: + from 
core.virtual_environment.__base.virtual_environment import VirtualEnvironment + + +ZIP_SCRIPT = r""" +import os +import sys +import zipfile + +src = sys.argv[1] +out_path = sys.argv[2] + +def is_cwd(p: str) -> bool: + return p in (".", "") + +src = src.rstrip("/") + +if is_cwd(src): + base = "." + root = "." + prefix = "" +else: + base = os.path.dirname(src) or "." + prefix = os.path.basename(src) + root = os.path.join(base, prefix) + +def add_empty_dir(zf: zipfile.ZipFile, arc_dir: str) -> None: + name = arc_dir.rstrip("/") + "/" + if name != "/": + zf.writestr(name, b"") + +with zipfile.ZipFile(out_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: + if os.path.isfile(root): + zf.write(root, arcname=os.path.basename(root)) + else: + for dirpath, dirnames, filenames in os.walk(root): + rel_dir = os.path.relpath(dirpath, base) + rel_dir = "" if rel_dir == "." else rel_dir + if not dirnames and not filenames: + add_empty_dir(zf, rel_dir) + for fn in filenames: + fp = os.path.join(dirpath, fn) + arcname = os.path.join(rel_dir, fn) if rel_dir else fn + zf.write(fp, arcname=arcname) +""" + +UNZIP_SCRIPT = r""" +import sys +import zipfile + +archive_path = sys.argv[1] +dest_dir = sys.argv[2] + +with zipfile.ZipFile(archive_path, "r") as zf: + zf.extractall(dest_dir) +""" + + +class PythonZipStrategy(ZipStrategy): + """Strategy using Python's zipfile module.""" + + def __init__(self) -> None: + self._python_cmd: str | None = None + + def is_available(self, vm: VirtualEnvironment) -> bool: + for cmd in ("python3", "python"): + result = try_execute(vm, ["which", cmd], timeout=10) + if result.stdout and result.stdout.strip(): + self._python_cmd = cmd + return True + return False + + def zip( + self, + vm: VirtualEnvironment, + *, + src: str, + out_path: str, + cwd: str | None, + timeout: float, + ) -> None: + if self._python_cmd is None: + raise RuntimeError("Python not available") + + execute( + vm, + [self._python_cmd, "-c", ZIP_SCRIPT, src, out_path], + timeout=timeout, + cwd=cwd, + error_message="Failed to create zip archive", + ) + + def unzip( + self, + vm: VirtualEnvironment, + *, + archive_path: str, + dest_dir: str, + timeout: float, + ) -> None: + if self._python_cmd is None: + raise RuntimeError("Python not available") + + execute( + vm, + [self._python_cmd, "-c", UNZIP_SCRIPT, archive_path, dest_dir], + timeout=timeout, + error_message="Failed to unzip archive", + ) diff --git a/api/core/zip_sandbox/strategy.py b/api/core/zip_sandbox/strategy.py new file mode 100644 index 0000000000..f7356c4f79 --- /dev/null +++ b/api/core/zip_sandbox/strategy.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from core.virtual_environment.__base.virtual_environment import VirtualEnvironment + + +class ZipStrategy(ABC): + """Abstract base class for zip/unzip strategies.""" + + @abstractmethod + def is_available(self, vm: VirtualEnvironment) -> bool: + """Check if this strategy is available in the given VM.""" + ... + + @abstractmethod + def zip( + self, + vm: VirtualEnvironment, + *, + src: str, + out_path: str, + cwd: str | None, + timeout: float, + ) -> None: + """Create a zip archive.""" + ... + + @abstractmethod + def unzip( + self, + vm: VirtualEnvironment, + *, + archive_path: str, + dest_dir: str, + timeout: float, + ) -> None: + """Extract a zip archive.""" + ... 
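The concrete strategies above (CLI zip/unzip, Python zipfile, Node.js adm-zip) all implement this interface, and the ZipSandbox in the next file probes them in order and keeps the first one whose is_available() check passes. A minimal sketch of that fallback selection, assuming an already provisioned VirtualEnvironment `vm`; the standalone helper name `pick_strategy` is illustrative and not part of the patch:

    from core.zip_sandbox.cli_strategy import CliZipStrategy
    from core.zip_sandbox.node_strategy import NodeZipStrategy
    from core.zip_sandbox.python_strategy import PythonZipStrategy
    from core.zip_sandbox.strategy import ZipStrategy


    def pick_strategy(vm) -> ZipStrategy:
        # Same probe order as ZipSandbox._STRATEGIES: native zip/unzip first,
        # then Python's zipfile module, then Node.js with the adm-zip package.
        for strategy in (CliZipStrategy(), PythonZipStrategy(), NodeZipStrategy()):
            if strategy.is_available(vm):
                return strategy
        raise RuntimeError("No available zip backend (zip/python/node+adm-zip)")

ZipSandbox._get_strategy() below performs the same probing and additionally caches the chosen strategy on the instance.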
diff --git a/api/core/zip_sandbox/zip_sandbox.py b/api/core/zip_sandbox/zip_sandbox.py new file mode 100644 index 0000000000..0b8f60e01b --- /dev/null +++ b/api/core/zip_sandbox/zip_sandbox.py @@ -0,0 +1,425 @@ +from __future__ import annotations + +import base64 +import posixpath +import shlex +from io import BytesIO +from pathlib import PurePosixPath +from types import TracebackType +from typing import Any +from urllib.parse import urlparse +from uuid import uuid4 + +from core.sandbox.builder import SandboxBuilder +from core.sandbox.entities.sandbox_type import SandboxType +from core.sandbox.sandbox import Sandbox +from core.sandbox.storage.noop_storage import NoopSandboxStorage +from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError +from core.virtual_environment.__base.helpers import execute, pipeline +from core.virtual_environment.__base.virtual_environment import VirtualEnvironment +from services.sandbox.sandbox_provider_service import SandboxProviderService + +from .cli_strategy import CliZipStrategy +from .entities import SandboxDownloadItem, SandboxFile, SandboxUploadItem +from .node_strategy import NodeZipStrategy +from .python_strategy import PythonZipStrategy +from .strategy import ZipStrategy + + +class ZipSandbox: + """A sandbox for archive (zip) operations. + + Usage: + with ZipSandbox(tenant_id=..., user_id=...) as zs: + zs.download_items(items) + archive = zs.zip() + zs.upload(archive, upload_url) + # VM automatically released on exit + """ + + _DEFAULT_TIMEOUT_SECONDS = 60 * 5 + _STRATEGIES: list[ZipStrategy] = [CliZipStrategy(), PythonZipStrategy(), NodeZipStrategy()] + + def __init__( + self, + *, + tenant_id: str | None = None, + user_id: str | None = None, + app_id: str = "zip-sandbox", + sandbox_provider_type: str | None = None, + sandbox_provider_options: dict[str, Any] | None = None, + _vm: VirtualEnvironment | None = None, + ) -> None: + self._tenant_id = tenant_id + self._user_id = user_id + self._app_id = app_id + self._sandbox_provider_type = sandbox_provider_type + self._sandbox_provider_options = sandbox_provider_options + self._injected_vm = _vm + + self._sandbox: Sandbox | None = None + self._sandbox_id: str | None = None + self._vm: VirtualEnvironment | None = None + self._strategy: ZipStrategy | None = None + + def __enter__(self) -> ZipSandbox: + self._start() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self._stop() + + def _start(self) -> None: + if self._vm is not None: + raise RuntimeError("ZipSandbox already started") + + if self._injected_vm is not None: + self._vm = self._injected_vm + self._sandbox_id = uuid4().hex + return + + if not self._tenant_id: + raise ValueError("tenant_id is required") + if not self._user_id: + raise ValueError("user_id is required") + + if self._sandbox_provider_type is None or self._sandbox_provider_options is None: + provider = SandboxProviderService.get_sandbox_provider(self._tenant_id) + provider_type = provider.provider_type + provider_options = dict(provider.config) + else: + provider_type = self._sandbox_provider_type + provider_options = dict(self._sandbox_provider_options) + + self._sandbox_id = uuid4().hex + + storage = NoopSandboxStorage() + try: + self._sandbox = ( + SandboxBuilder(self._tenant_id, SandboxType(provider_type)) + .options(provider_options) + .user(self._user_id) + .app(self._app_id) + .storage(storage, assets_id="zip-sandbox") + .build() + ) + 
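+            # wait_ready() below blocks for up to 60 seconds until the sandbox VM is usable; if anything fails, the except branch releases the sandbox and resets local state so a half-started VM is not leaked.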
self._sandbox.wait_ready(timeout=60) + self._vm = self._sandbox.vm + except Exception: + if self._sandbox is not None: + self._sandbox.release() + self._vm = None + self._sandbox = None + self._sandbox_id = None + raise + + def _stop(self) -> None: + if self._vm is None: + return + + if self._sandbox is not None: + self._sandbox.release() + + self._vm = None + self._sandbox = None + self._sandbox_id = None + self._strategy = None + + @property + def vm(self) -> VirtualEnvironment: + if self._vm is None: + raise RuntimeError("ZipSandbox not started. Use 'with ZipSandbox(...) as zs:'") + return self._vm + + def _get_strategy(self) -> ZipStrategy: + if self._strategy is not None: + return self._strategy + + for strategy in self._STRATEGIES: + if strategy.is_available(self.vm): + self._strategy = strategy + return strategy + + raise RuntimeError("No available zip backend (zip/python/node+adm-zip)") + + # ========== Path utilities ========== + + @staticmethod + def _normalize_path(path: str | None) -> str: + raw = (path or ".").strip() + if raw == "": + raw = "." + + p = PurePosixPath(raw) + if p.is_absolute(): + raise ValueError("path must be relative") + if any(part == ".." for part in p.parts): + raise ValueError("path must not contain '..'") + + normalized = str(p) + return "." if normalized in (".", "") else normalized + + @staticmethod + def _dest_path_for_url(dest_dir: str, url: str) -> str: + parsed = urlparse(url) + path = parsed.path or "" + name = posixpath.basename(path) + if not name: + name = "download.bin" + return posixpath.join(dest_dir, name) + + # ========== File operations ========== + + def write_file(self, path: str, data: bytes) -> None: + path = self._normalize_path(path) + if path in ("", "."): + raise ValueError("path must point to a file") + + try: + self.vm.upload_file(path, BytesIO(data)) + except Exception as exc: + raise RuntimeError(f"Failed to write file to sandbox: {exc}") from exc + + def read_file(self, path: str, *, max_bytes: int = 10 * 1024 * 1024) -> bytes: + path = self._normalize_path(path) + if max_bytes <= 0: + raise ValueError("max_bytes must be positive") + + try: + data = self.vm.download_file(path).getvalue() + except Exception as exc: + raise RuntimeError(f"Failed to read file from sandbox: {exc}") from exc + + if len(data) > max_bytes: + raise ValueError(f"File too large: {len(data)} > {max_bytes}") + return data + + # ========== Download operations ========== + + def download_items(self, items: list[SandboxDownloadItem], *, dest_dir: str = ".") -> list[str]: + """Download or write items into the sandbox via a single pipeline. + + Remote items (with *url*) are fetched via ``curl``. Inline items + (with *content*) are written via ``base64 -d`` heredoc. Both go + through the same pipeline — no branching at the structural level. 
+ """ + if not items: + return [] + + dest_dir = self._normalize_path(dest_dir) + p = pipeline(self.vm) + p.add(["mkdir", "-p", dest_dir], error_message="Failed to create download directory") + + out_paths: list[str] = [] + for item in items: + rel = self._normalize_path(item.path) + if rel in ("", "."): + raise ValueError("Download item path must point to a file") + out_path = posixpath.join(dest_dir, rel) + out_paths.append(out_path) + out_dir = posixpath.dirname(out_path) + if out_dir not in ("", "."): + p.add(["mkdir", "-p", out_dir], error_message="Failed to create download directory") + p.add( + self.to_download_command(item, out_path), + error_message=f"Failed to write {item.path}", + ) + + try: + p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True) + except Exception as exc: + raise RuntimeError(str(exc)) from exc + + return out_paths + + @staticmethod + def to_download_command(item: SandboxDownloadItem, out_path: str) -> list[str]: + """Return the shell command to materialise *item* at *out_path*.""" + if item.content is not None: + encoded = base64.b64encode(item.content).decode("ascii") + return ["sh", "-c", f"base64 -d <<'_B64_' > {shlex.quote(out_path)}\n{encoded}\n_B64_"] + return ["curl", "-fsSL", item.url, "-o", out_path] + + def download_archive(self, archive_url: str, *, path: str = "input.tar.gz") -> str: + path = self._normalize_path(path) + + dir_path = posixpath.dirname(path) + p = pipeline(self.vm) + if dir_path not in ("", "."): + p.add(["mkdir", "-p", dir_path], error_message=f"Failed to create directory {dir_path}") + p.add(["curl", "-fsSL", archive_url, "-o", path], error_message=f"Failed to download archive to {path}") + + try: + p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True) + except Exception as exc: + raise RuntimeError(str(exc)) from exc + + return path + + # ========== Upload operations ========== + + def upload(self, file: SandboxFile, target_url: str) -> None: + """Upload a sandbox file to the given URL.""" + try: + execute( + self.vm, + ["curl", "-fsSL", "-X", "PUT", "-T", file.path, target_url], + timeout=self._DEFAULT_TIMEOUT_SECONDS, + error_message="Failed to upload file from sandbox", + ) + except CommandExecutionError as exc: + raise RuntimeError(str(exc)) from exc + + def upload_items(self, items: list[SandboxUploadItem], *, src_dir: str = ".") -> None: + """Upload multiple files from sandbox to target URLs. + + Args: + items: List of SandboxUploadItem(path, url) + src_dir: Base directory containing the files + """ + if not items: + return + + src_dir = self._normalize_path(src_dir) + p = pipeline(self.vm) + + for item in items: + rel = self._normalize_path(item.path) + src_path = posixpath.join(src_dir, rel) if src_dir not in ("", ".") else rel + p.add( + ["curl", "-fsSL", "-X", "PUT", "-T", src_path, item.url], + error_message=f"Failed to upload {item.path}", + ) + + try: + p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True) + except Exception as exc: + raise RuntimeError(str(exc)) from exc + + # ========== Archive operations ========== + + def zip(self, src: str = ".", *, include_base: bool = True) -> SandboxFile: + """Create a zip archive and return a handle to it.""" + src = self._normalize_path(src) + out_path = f"/tmp/{uuid4().hex}.zip" + + cwd = None + src_for_strategy = src + if src not in (".", "") and not include_base: + cwd = src + src_for_strategy = "." 
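+        # When include_base is False for a directory, the strategy runs with cwd set to `src` and archives ".", so entries land at the archive root; otherwise it archives `src` itself and the base directory name is kept.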
+ + try: + self._get_strategy().zip( + self.vm, + src=src_for_strategy, + out_path=out_path, + cwd=cwd, + timeout=self._DEFAULT_TIMEOUT_SECONDS, + ) + except (PipelineExecutionError, CommandExecutionError) as exc: + raise RuntimeError(str(exc)) from exc + + return SandboxFile(path=out_path) + + def unzip(self, *, archive_path: str, dest_dir: str = "unpacked") -> str: + """Extract a zip archive to the destination directory.""" + archive_path = self._normalize_path(archive_path) + dest_dir = self._normalize_path(dest_dir) + + if not archive_path.lower().endswith(".zip"): + raise ValueError("archive_path must end with .zip") + + try: + pipeline(self.vm).add( + ["mkdir", "-p", dest_dir], error_message="Failed to create destination directory" + ).execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True) + + self._get_strategy().unzip( + self.vm, + archive_path=archive_path, + dest_dir=dest_dir, + timeout=self._DEFAULT_TIMEOUT_SECONDS, + ) + except (PipelineExecutionError, CommandExecutionError) as exc: + raise RuntimeError(str(exc)) from exc + + return dest_dir + + def untar(self, *, archive_path: str, dest_dir: str = "unpacked") -> str: + """Extract a tar archive to the destination directory.""" + archive_path = self._normalize_path(archive_path) + dest_dir = self._normalize_path(dest_dir) + + lower = archive_path.lower() + is_gz = lower.endswith(".tar.gz") or lower.endswith(".tgz") + extract_flag = "-xzf" if is_gz else "-xf" + + try: + ( + pipeline(self.vm) + .add(["mkdir", "-p", dest_dir], error_message="Failed to create destination directory") + .add( + ["sh", "-c", f'tar {extract_flag} "$1" -C "$2" 2>/dev/null; exit $?', "sh", archive_path, dest_dir], + error_message="Failed to extract tar archive", + ) + .execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True) + ) + except PipelineExecutionError as exc: + raise RuntimeError(str(exc)) from exc + + return dest_dir + + def tar(self, src: str = ".", *, include_base: bool = True, compress: bool = True) -> SandboxFile: + """Create a tar archive and return a handle to it. + + Args: + src: Source path to archive (file or directory) + include_base: If True, include the base directory name in the archive + compress: If True, create a gzipped tar archive (.tar.gz) + + Returns: + SandboxFile handle to the created archive + """ + src = self._normalize_path(src) + extension = ".tar.gz" if compress else ".tar" + out_path = f"/tmp/{uuid4().hex}{extension}" + + create_flag = "-czf" if compress else "-cf" + + try: + if src in (".", ""): + # Archive current directory contents + execute( + self.vm, + ["tar", create_flag, out_path, "-C", ".", "."], + timeout=self._DEFAULT_TIMEOUT_SECONDS, + error_message="Failed to create tar archive", + ) + elif include_base: + # Archive with base directory name included + parent_dir = posixpath.dirname(src) or "." 
+ base_name = posixpath.basename(src) + execute( + self.vm, + ["tar", create_flag, out_path, "-C", parent_dir, base_name], + timeout=self._DEFAULT_TIMEOUT_SECONDS, + error_message="Failed to create tar archive", + ) + else: + # Archive contents without base directory name + execute( + self.vm, + ["tar", create_flag, out_path, "-C", src, "."], + timeout=self._DEFAULT_TIMEOUT_SECONDS, + error_message="Failed to create tar archive", + ) + except CommandExecutionError as exc: + raise RuntimeError(str(exc)) from exc + + return SandboxFile(path=out_path) diff --git a/api/dify_graph/entities/tool_entities.py b/api/dify_graph/entities/tool_entities.py new file mode 100644 index 0000000000..113751b362 --- /dev/null +++ b/api/dify_graph/entities/tool_entities.py @@ -0,0 +1,41 @@ +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field + +from graphon.file import File + + +class ToolResultStatus(StrEnum): + SUCCESS = "success" + ERROR = "error" + + +class ToolCall(BaseModel): + id: str | None = Field(default=None, description="Unique identifier for this tool call") + name: str | None = Field(default=None, description="Name of the tool being called") + arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON") + icon: str | dict | None = Field(default=None, description="Icon of the tool") + icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool") + + +class ToolResult(BaseModel): + id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to") + name: str | None = Field(default=None, description="Name of the tool") + output: str | None = Field(default=None, description="Tool output text, error or success message") + files: list[str] = Field(default_factory=list, description="File produced by tool") + status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status") + elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool") + icon: str | dict[str, Any] | None = Field(default=None, description="Icon of the tool") + icon_dark: str | dict[str, Any] | None = Field(default=None, description="Dark theme icon of the tool") + provider: str | None = Field(default=None, description="Tool provider identifier") + + +class ToolCallResult(BaseModel): + id: str | None = Field(default=None, description="Identifier for the tool call") + name: str | None = Field(default=None, description="Name of the tool") + arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON") + output: str | None = Field(default=None, description="Tool output text, error or success message") + files: list[File] = Field(default_factory=list, description="File produced by tool") + status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status") + elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool") diff --git a/api/dify_graph/nodes/agent/agent_node.py b/api/dify_graph/nodes/agent/agent_node.py new file mode 100644 index 0000000000..d9f439036e --- /dev/null +++ b/api/dify_graph/nodes/agent/agent_node.py @@ -0,0 +1,929 @@ +from __future__ import annotations + +import json +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, cast + +from packaging.version import Version +from pydantic import ValidationError +from sqlalchemy import select +from 
sqlalchemy.orm import Session + +from core.agent.entities import AgentToolEntity +from core.agent.plugin_entities import AgentStrategyParameter +from core.memory.base import BaseMemory +from core.memory.node_token_buffer_memory import NodeTokenBufferMemory +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.prompt.entities.advanced_prompt_entities import MemoryMode +from core.provider_manager import ProviderManager +from core.tools.entities.tool_entities import ( + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated +from core.workflow.nodes.agent.exceptions import ( + AgentInputTypeError, + AgentInvocationError, + AgentMessageTransformError, + AgentNodeError, + AgentVariableNotFoundError, + AgentVariableTypeError, + ToolFileNotFoundError, +) +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + SystemVariableKey, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import ( + AgentLogEvent, + NodeEventBase, + NodeRunResult, + StreamChunkEvent, + StreamCompletedEvent, +) +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.runtime import VariablePool +from graphon.variables.segments import ArrayFileSegment, StringSegment +from extensions.ext_database import db +from factories import file_factory +from factories.agent_factory import get_plugin_agent_strategy +from models import ToolFile +from models.model import Conversation +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +if TYPE_CHECKING: + from core.agent.strategy.plugin import PluginAgentStrategy + from core.plugin.entities.request import InvokeCredentials + + +class AgentNode(Node[AgentNodeData]): + """ + Agent Node + """ + + node_type = BuiltinNodeTypes.AGENT + + @classmethod + def version(cls) -> str: + return "1" + + def _run(self) -> Generator[NodeEventBase, None, None]: + from core.plugin.impl.exc import PluginDaemonClientSideError + + dify_ctx = self.require_dify_context() + + try: + strategy = get_plugin_agent_strategy( + tenant_id=dify_ctx.tenant_id, + agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, + agent_strategy_name=self.node_data.agent_strategy_name, + ) + except Exception as e: + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error=f"Failed to get agent strategy: {str(e)}", + ), + ) + return + + agent_parameters = strategy.get_parameters() + + # get parameters + parameters = self._generate_agent_parameters( + agent_parameters=agent_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self.node_data, + strategy=strategy, + ) + parameters_for_log = 
self._generate_agent_parameters( + agent_parameters=agent_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self.node_data, + for_log=True, + strategy=strategy, + ) + credentials = self._generate_credentials(parameters=parameters) + + # get conversation id + conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + + try: + message_stream = strategy.invoke( + params=parameters, + user_id=dify_ctx.user_id, + app_id=dify_ctx.app_id, + conversation_id=conversation_id.text if conversation_id else None, + credentials=credentials, + ) + except Exception as e: + error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + error=str(error), + ) + ) + return + + # Fetch memory for node memory saving + memory = self._fetch_memory_for_save() + + try: + yield from self._transform_message( + messages=message_stream, + tool_info={ + "icon": self.agent_strategy_icon, + "agent_strategy": self.node_data.agent_strategy_name, + }, + parameters_for_log=parameters_for_log, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, + node_type=self.node_type, + node_id=self._node_id, + node_execution_id=self.id, + memory=memory, + ) + except PluginDaemonClientSideError as e: + transform_error = AgentMessageTransformError( + f"Failed to transform agent message: {str(e)}", original_error=e + ) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=parameters_for_log, + error=str(transform_error), + ) + ) + + def _generate_agent_parameters( + self, + *, + agent_parameters: Sequence[AgentStrategyParameter], + variable_pool: VariablePool, + node_data: AgentNodeData, + for_log: bool = False, + strategy: PluginAgentStrategy, + ) -> dict[str, Any]: + """ + Generate parameters based on the given tool parameters, variable pool, and node data. + + Args: + agent_parameters (Sequence[AgentParameter]): The list of agent parameters. + variable_pool (VariablePool): The variable pool containing the variables. + node_data (AgentNodeData): The data associated with the agent node. + + Returns: + Mapping[str, Any]: A dictionary containing the generated parameters. 
+ + """ + agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} + + result: dict[str, Any] = {} + for parameter_name in node_data.agent_parameters: + parameter = agent_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + agent_input = node_data.agent_parameters[parameter_name] + match agent_input.type: + case "variable": + variable = variable_pool.get(agent_input.value) # type: ignore + if variable is None: + raise AgentVariableNotFoundError(str(agent_input.value)) + parameter_value = variable.value + case "mixed" | "constant": + # variable_pool.convert_template expects a string template, + # but if passing a dict, convert to JSON string first before rendering + try: + if not isinstance(agent_input.value, str): + parameter_value = json.dumps(agent_input.value, ensure_ascii=False) + else: + parameter_value = str(agent_input.value) + except TypeError: + parameter_value = str(agent_input.value) + segment_group = variable_pool.convert_template(parameter_value) + parameter_value = segment_group.log if for_log else segment_group.text + # variable_pool.convert_template returns a string, + # so we need to convert it back to a dictionary + try: + if not isinstance(agent_input.value, str): + parameter_value = json.loads(parameter_value) + except json.JSONDecodeError: + parameter_value = parameter_value + case _: + raise AgentInputTypeError(agent_input.type) + value = parameter_value + if parameter.type == "array[tools]": + value = cast(list[dict[str, Any]], value) + value = [tool for tool in value if tool.get("enabled", False)] + value = self._filter_mcp_type_tool(strategy, value) + for tool in value: + if "schemas" in tool: + tool.pop("schemas") + parameters = tool.get("parameters", {}) + if all(isinstance(v, dict) for _, v in parameters.items()): + params = {} + for key, param in parameters.items(): + if param.get("auto", ParamsAutoGenerated.OPEN) in ( + ParamsAutoGenerated.CLOSE, + 0, + ): + value_param = param.get("value", {}) + if value_param and value_param.get("type", "") == "variable": + variable_selector = value_param.get("value") + if not variable_selector: + raise ValueError("Variable selector is missing for a variable-type parameter.") + + variable = variable_pool.get(variable_selector) + if variable is None: + raise AgentVariableNotFoundError(str(variable_selector)) + + params[key] = variable.value + else: + params[key] = value_param.get("value", "") if value_param is not None else None + else: + params[key] = None + parameters = params + tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()} + tool["parameters"] = parameters + + if not for_log: + if parameter.type == "array[tools]": + value = cast(list[dict[str, Any]], value) + tool_value = [] + for tool in value: + provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) + setting_params = tool.get("settings", {}) + parameters = tool.get("parameters", {}) + manual_input_params = [key for key, value in parameters.items() if value is not None] + + parameters = {**parameters, **setting_params} + entity = AgentToolEntity( + provider_id=tool.get("provider_name", ""), + provider_type=provider_type, + tool_name=tool.get("tool_name", ""), + tool_parameters=parameters, + plugin_unique_identifier=tool.get("plugin_unique_identifier", None), + credential_id=tool.get("credential_id", None), + ) + + extra = tool.get("extra", {}) + + # This is an issue that caused problems before. 
+ # Logically, we shouldn't use the node_data.version field for judgment + # But for backward compatibility with historical data + # this version field judgment is still preserved here. + runtime_variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version is not None: + runtime_variable_pool = variable_pool + dify_ctx = self.require_dify_context() + tool_runtime = ToolManager.get_agent_tool_runtime( + dify_ctx.tenant_id, + dify_ctx.app_id, + entity, + dify_ctx.invoke_from, + runtime_variable_pool, + ) + if tool_runtime.entity.description: + tool_runtime.entity.description.llm = ( + extra.get("description", "") or tool_runtime.entity.description.llm + ) + for tool_runtime_params in tool_runtime.entity.parameters: + tool_runtime_params.form = ( + ToolParameter.ToolParameterForm.FORM + if tool_runtime_params.name in manual_input_params + else tool_runtime_params.form + ) + manual_input_value = {} + if tool_runtime.entity.parameters: + manual_input_value = { + key: value for key, value in parameters.items() if key in manual_input_params + } + runtime_parameters = { + **tool_runtime.runtime.runtime_parameters, + **manual_input_value, + } + tool_value.append( + { + **tool_runtime.entity.model_dump(mode="json"), + "runtime_parameters": runtime_parameters, + "credential_id": tool.get("credential_id", None), + "provider_type": provider_type.value, + } + ) + value = tool_value + if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: + value = cast(dict[str, Any], value) + model_instance, model_schema = self._fetch_model(value) + # memory config + history_prompt_messages = [] + if node_data.memory: + memory = self._fetch_memory(model_instance) + if memory: + prompt_messages = memory.get_history_prompt_messages( + message_limit=node_data.memory.window.size or None + ) + history_prompt_messages = [ + prompt_message.model_dump(mode="json") for prompt_message in prompt_messages + ] + value["history_prompt_messages"] = history_prompt_messages + if model_schema: + # remove structured output feature to support old version agent plugin + model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) + value["entity"] = model_schema.model_dump(mode="json") + else: + value["entity"] = None + result[parameter_name] = value + + return result + + def _generate_credentials( + self, + parameters: dict[str, Any], + ) -> InvokeCredentials: + """ + Generate credentials based on the given agent parameters. 
+ """ + from core.plugin.entities.request import InvokeCredentials + + credentials = InvokeCredentials() + + # generate credentials for tools selector + credentials.tool_credentials = {} + for tool in parameters.get("tools", []): + if tool.get("credential_id"): + try: + identity = ToolIdentity.model_validate(tool.get("identity", {})) + credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) + except ValidationError: + continue + return credentials + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: AgentNodeData, + ) -> Mapping[str, Sequence[str]]: + typed_node_data = node_data + + result: dict[str, Any] = {} + for parameter_name in typed_node_data.agent_parameters: + input = typed_node_data.agent_parameters[parameter_name] + match input.type: + case "mixed" | "constant": + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value + + result = {node_id + "." + key: value for key, value in result.items()} + + return result + + @property + def agent_strategy_icon(self) -> str | None: + """ + Get agent strategy icon + :return: + """ + from core.plugin.impl.plugin import PluginInstaller + + manager = PluginInstaller() + dify_ctx = self.require_dify_context() + plugins = manager.list_plugins(dify_ctx.tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name + ) + icon = current_plugin.declaration.icon + except StopIteration: + icon = None + return icon + + def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | None: + """ + Fetch memory based on configuration mode. + + Returns TokenBufferMemory for conversation mode (default), + or NodeTokenBufferMemory for node mode (Chatflow only). 
+ """ + node_data = self.node_data + memory_config = node_data.memory + + if not memory_config: + return None + + # get conversation id (required for both modes in Chatflow) + conversation_id_variable = self.graph_runtime_state.variable_pool.get( + ["sys", SystemVariableKey.CONVERSATION_ID] + ) + if not isinstance(conversation_id_variable, StringSegment): + return None + conversation_id = conversation_id_variable.value + + dify_ctx = self.require_dify_context() + if memory_config.mode == MemoryMode.NODE: + return NodeTokenBufferMemory( + app_id=dify_ctx.app_id, + conversation_id=conversation_id, + node_id=self._node_id, + tenant_id=dify_ctx.tenant_id, + model_instance=model_instance, + ) + else: + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Conversation).where( + Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id + ) + conversation = session.scalar(stmt) + if not conversation: + return None + return TokenBufferMemory(conversation=conversation, model_instance=model_instance) + + def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: + dify_ctx = self.require_dify_context() + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM + ) + model_name = value.get("model", "") + model_credentials = provider_model_bundle.configuration.get_current_credentials( + model_type=ModelType.LLM, model=model_name + ) + provider_name = provider_model_bundle.configuration.provider.provider + model_type_instance = provider_model_bundle.model_type_instance + model_instance = ModelManager().get_model_instance( + tenant_id=dify_ctx.tenant_id, + provider=provider_name, + model_type=ModelType(value.get("model_type", "")), + model=model_name, + ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + return model_instance, model_schema + + def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity: + if model_schema.features: + for feature in model_schema.features[:]: # Create a copy to safely modify during iteration + try: + AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value + except ValueError: + model_schema.features.remove(feature) + return model_schema + + def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Filter MCP type tool + :param strategy: plugin agent strategy + :param tool: tool + :return: filtered tool dict + """ + meta_version = strategy.meta_version + if meta_version and Version(meta_version) > Version("0.0.1"): + return tools + else: + return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] + + def _fetch_memory_for_save(self) -> BaseMemory | None: + """ + Fetch memory instance for saving node memory. + This is a simplified version that doesn't require model_instance. 
+ """ + from core.model_manager import ModelManager + from graphon.model_runtime.entities.model_entities import ModelType + + node_data = self.node_data + if not node_data.memory: + return None + + # Get conversation_id + conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + if not isinstance(conversation_id_var, StringSegment): + return None + conversation_id = conversation_id_var.value + + # Return appropriate memory type based on mode + if node_data.memory.mode == MemoryMode.NODE: + # For node memory, we need a model_instance for token counting + # Use a simple default model for this purpose + try: + model_instance = ModelManager().get_default_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + ) + except Exception: + return None + + return NodeTokenBufferMemory( + app_id=self.app_id, + conversation_id=conversation_id, + node_id=self._node_id, + tenant_id=self.tenant_id, + model_instance=model_instance, + ) + else: + # Conversation-level memory doesn't need saving here + return None + + def _build_context( + self, + parameters_for_log: dict[str, Any], + user_query: str, + assistant_response: str, + agent_logs: list[AgentLogEvent], + ) -> list[PromptMessage]: + """ + Build context from user query, tool calls, and assistant response. + Format: user -> assistant(with tool_calls) -> tool -> assistant + + The context includes: + - Current user query (always present, may be empty) + - Assistant message with tool_calls (if tools were called) + - Tool results + - Assistant's final response + """ + context_messages: list[PromptMessage] = [] + + # Always add user query (even if empty, to maintain conversation structure) + context_messages.append(UserPromptMessage(content=user_query or "")) + + # Extract actual tool calls from agent logs + # Only include logs with label starting with "CALL " - these are real tool invocations + tool_calls: list[AssistantPromptMessage.ToolCall] = [] + tool_results: list[tuple[str, str, str]] = [] # (tool_call_id, tool_name, result) + + for log in agent_logs: + if log.status == "success" and log.label and log.label.startswith("CALL "): + # Extract tool name from label (format: "CALL tool_name") + tool_name = log.label[5:] # Remove "CALL " prefix + tool_call_id = log.message_id + + # Parse tool response from data + data = log.data or {} + tool_response = "" + + # Try to extract the actual tool response + if "tool_response" in data: + tool_response = data["tool_response"] + elif "output" in data: + tool_response = data["output"] + elif "result" in data: + tool_response = data["result"] + + if isinstance(tool_response, dict): + tool_response = str(tool_response) + + # Get tool input for arguments + tool_input = data.get("tool_call_input", {}) or data.get("input", {}) + if isinstance(tool_input, dict): + import json + + tool_input_str = json.dumps(tool_input, ensure_ascii=False) + else: + tool_input_str = str(tool_input) if tool_input else "" + + if tool_response: + tool_calls.append( + AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool_name, + arguments=tool_input_str, + ), + ) + ) + tool_results.append((tool_call_id, tool_name, str(tool_response))) + + # Add assistant message with tool_calls if there were tool calls + if tool_calls: + context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls)) + + # Add tool result messages + for tool_call_id, tool_name, result in tool_results: 
+ context_messages.append( + ToolPromptMessage( + content=result, + tool_call_id=tool_call_id, + name=tool_name, + ) + ) + + # Add final assistant response + context_messages.append(AssistantPromptMessage(content=assistant_response)) + + return context_messages + + def _transform_message( + self, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + user_id: str, + tenant_id: str, + node_type: NodeType, + node_id: str, + node_execution_id: str, + memory: BaseMemory | None = None, + ) -> Generator[NodeEventBase, None, None]: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + from core.plugin.impl.plugin import PluginInstaller + + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json_list: list[dict | list] = [] + + agent_logs: list[AgentLogEvent] = [] + agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} + llm_usage = LLMUsage.empty_usage() + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + files.append(file) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + ) + elif message.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + text += message.message.text + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=message.message.text, + is_final=False, + ) + elif message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + if node_type == BuiltinNodeTypes.AGENT: + if isinstance(message.message.json_object, dict): + msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) + llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) + agent_execution_metadata = { + 
WorkflowNodeExecutionMetadataKey(key): value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + else: + msg_metadata = {} + llm_usage = LLMUsage.empty_usage() + agent_execution_metadata = {} + if message.message.json_object: + json_list.append(message.message.json_object) + elif message.type == ToolInvokeMessage.MessageType.LINK: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=stream_text, + is_final=False, + ) + elif message.type == ToolInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise AgentVariableTypeError( + "When 'stream' is True, 'variable_value' must be a string.", + variable_name=variable_name, + expected_type="str", + actual_type=type(variable_value).__name__, + ) + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + assert isinstance(message.meta, dict) + # Validate that meta contains a 'file' key + if "file" not in message.meta: + raise AgentNodeError("File message is missing 'file' key in meta") + + # Validate that the file is an instance of File + if not isinstance(message.meta["file"], File): + raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") + files.append(message.meta["file"]) + elif message.type == ToolInvokeMessage.MessageType.LOG: + assert isinstance(message.message, ToolInvokeMessage.LogMessage) + if message.message.metadata: + icon = tool_info.get("icon", "") + dict_metadata = dict(message.message.metadata) + if dict_metadata.get("provider"): + manager = PluginInstaller() + plugins = manager.list_plugins(tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] + ) + icon = current_plugin.declaration.icon + except StopIteration: + pass + icon_dark = None + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + if provider.name == dict_metadata["provider"] + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + dict_metadata["icon"] = icon + dict_metadata["icon_dark"] = icon_dark + message.message.metadata = dict_metadata + agent_log = AgentLogEvent( + message_id=message.message.id, + node_execution_id=node_execution_id, + parent_id=message.message.parent_id, + error=message.message.error, + status=message.message.status.value, + data=message.message.data, + label=message.message.label, + metadata=message.message.metadata, + node_id=node_id, + ) + + # check if the agent log is already in the list + for log in agent_logs: + if log.message_id == agent_log.message_id: + # update the log + log.data = agent_log.data + log.status = agent_log.status + log.error = agent_log.error + log.label = agent_log.label + log.metadata = agent_log.metadata + break + else: + 
agent_logs.append(agent_log) + + yield agent_log + + # Add agent_logs to outputs['json'] to ensure frontend can access thinking process + json_output: list[dict[str, Any] | list[Any]] = [] + + # Step 1: append each agent log as its own dict. + if agent_logs: + for log in agent_logs: + json_output.append( + { + "id": log.message_id, + "parent_id": log.parent_id, + "error": log.error, + "status": log.status, + "data": log.data, + "label": log.label, + "metadata": log.metadata, + "node_id": log.node_id, + } + ) + # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] + if json_list: + json_output.extend(json_list) + else: + json_output.append({"data": []}) + + # Send final chunk events for all streamed outputs + # Final chunk for text stream + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk="", + is_final=True, + ) + + # Final chunks for any streamed variables + for var_name in variables: + yield StreamChunkEvent( + selector=[node_id, var_name], + chunk="", + is_final=True, + ) + + # Get user query from parameters for building context + user_query = parameters_for_log.get("query", "") + + # Build context from history, user query, tool calls and assistant response + context = self._build_context(parameters_for_log, user_query, text, agent_logs) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "text": text, + "usage": jsonable_encoder(llm_usage), + "files": ArrayFileSegment(value=files), + "json": json_output, + "context": context, + **variables, + }, + metadata={ + **agent_execution_metadata, + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, + }, + inputs=parameters_for_log, + llm_usage=llm_usage, + ) + ) diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 6a1bebffcd..097aea76be 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -90,9 +90,9 @@ def init_app(app: DifyApp): app.register_blueprint(inner_api_bp) app.register_blueprint(mcp_bp) - # TODO: enable after full sandbox integration - # from controllers.cli_api import bp as cli_api_bp - # app.register_blueprint(cli_api_bp) + from controllers.cli_api import bp as cli_api_bp + + app.register_blueprint(cli_api_bp) # Register trigger blueprint with CORS for webhook calls _apply_cors_once( diff --git a/api/extensions/ext_socketio.py b/api/extensions/ext_socketio.py new file mode 100644 index 0000000000..5ed82bac8d --- /dev/null +++ b/api/extensions/ext_socketio.py @@ -0,0 +1,5 @@ +import socketio # type: ignore[reportMissingTypeStubs] + +from configs import dify_config + +sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS) diff --git a/api/fields/online_user_fields.py b/api/fields/online_user_fields.py new file mode 100644 index 0000000000..8fe0dc6a64 --- /dev/null +++ b/api/fields/online_user_fields.py @@ -0,0 +1,17 @@ +from flask_restx import fields + +online_user_partial_fields = { + "user_id": fields.String, + "username": fields.String, + "avatar": fields.String, + "sid": fields.String, +} + +workflow_online_users_fields = { + "workflow_id": fields.String, + "users": fields.List(fields.Nested(online_user_partial_fields)), +} + +online_user_list_fields = { + "data": fields.List(fields.Nested(workflow_online_users_fields)), +} diff --git a/api/fields/workflow_comment_fields.py b/api/fields/workflow_comment_fields.py new file mode 100644 index 
0000000000..c708dd3460 --- /dev/null +++ b/api/fields/workflow_comment_fields.py @@ -0,0 +1,96 @@ +from flask_restx import fields + +from libs.helper import AvatarUrlField, TimestampField + +# basic account fields for comments +account_fields = { + "id": fields.String, + "name": fields.String, + "email": fields.String, + "avatar_url": AvatarUrlField, +} + +# Comment mention fields +workflow_comment_mention_fields = { + "mentioned_user_id": fields.String, + "mentioned_user_account": fields.Nested(account_fields, allow_null=True), + "reply_id": fields.String, +} + +# Comment reply fields +workflow_comment_reply_fields = { + "id": fields.String, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, +} + +# Basic comment fields (for list views) +workflow_comment_basic_fields = { + "id": fields.String, + "position_x": fields.Float, + "position_y": fields.Float, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, + "updated_at": TimestampField, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, + "resolved_by_account": fields.Nested(account_fields, allow_null=True), + "reply_count": fields.Integer, + "mention_count": fields.Integer, + "participants": fields.List(fields.Nested(account_fields)), +} + +# Detailed comment fields (for single comment view) +workflow_comment_detail_fields = { + "id": fields.String, + "position_x": fields.Float, + "position_y": fields.Float, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, + "updated_at": TimestampField, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, + "resolved_by_account": fields.Nested(account_fields, allow_null=True), + "replies": fields.List(fields.Nested(workflow_comment_reply_fields)), + "mentions": fields.List(fields.Nested(workflow_comment_mention_fields)), +} + +# Comment creation response fields (simplified) +workflow_comment_create_fields = { + "id": fields.String, + "created_at": TimestampField, +} + +# Comment update response fields (simplified) +workflow_comment_update_fields = { + "id": fields.String, + "updated_at": TimestampField, +} + +# Comment resolve response fields +workflow_comment_resolve_fields = { + "id": fields.String, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, +} + +# Reply creation response fields (simplified) +workflow_comment_reply_create_fields = { + "id": fields.String, + "created_at": TimestampField, +} + +# Reply update response fields +workflow_comment_reply_update_fields = { + "id": fields.String, + "updated_at": TimestampField, +} diff --git a/api/migrations/versions/2026_02_09_1031-aab323465866_agent_sandbox_support.py b/api/migrations/versions/2026_02_09_1031-aab323465866_agent_sandbox_support.py new file mode 100644 index 0000000000..fc91bdfba7 --- /dev/null +++ b/api/migrations/versions/2026_02_09_1031-aab323465866_agent_sandbox_support.py @@ -0,0 +1,143 @@ +"""Add sandbox providers, app assets, and LLM detail tables. 
+ +Revision ID: aab323465866 +Revises: f55813ffe2c8 +Create Date: 2026-02-09 10:31:05.062722 + +""" + +import os +from uuid import uuid4 + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "aab323465866" +down_revision = "c3df22613c99" +branch_labels = None +depends_on = None + + +def _get_ssh_config_from_env() -> dict[str, str]: + """Build SSH sandbox config from environment variables. + + Defaults are chosen so that: + - All-in-one Docker Compose (api inside the network): agentbox:22 + - Middleware / local dev (api on the host): 127.0.0.1:2222 + + The env vars (SSH_SANDBOX_*) are documented in api/.env.example. + """ + return { + "ssh_host": os.environ.get("SSH_SANDBOX_HOST", "agentbox"), + "ssh_port": os.environ.get("SSH_SANDBOX_PORT", "22"), + "ssh_username": os.environ.get("SSH_SANDBOX_USERNAME", "agentbox"), + "ssh_password": os.environ.get("SSH_SANDBOX_PASSWORD", "agentbox"), + "base_working_path": os.environ.get("SSH_SANDBOX_BASE_WORKING_PATH", "/workspace/sandboxes"), + } + + +def upgrade(): + from core.tools.utils.system_encryption import encrypt_system_params + + op.create_table( + "sandbox_provider_system_config", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("provider_type", sa.String(length=50), nullable=False, comment="e2b, docker, local, ssh"), + sa.Column("encrypted_config", models.types.LongText(), nullable=False, comment="Encrypted config JSON"), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint("id", name="sandbox_provider_system_config_pkey"), + sa.UniqueConstraint("provider_type", name="unique_sandbox_provider_system_config_type"), + ) + op.create_table( + "sandbox_providers", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("provider_type", sa.String(length=50), nullable=False, comment="e2b, docker, local, ssh"), + sa.Column("configure_type", sa.String(length=20), server_default="user", nullable=False), + sa.Column("encrypted_config", models.types.LongText(), nullable=False, comment="Encrypted config JSON"), + sa.Column("is_active", sa.Boolean(), server_default=sa.text("false"), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint("id", name="sandbox_provider_pkey"), + sa.UniqueConstraint("tenant_id", "provider_type", "configure_type", name="unique_sandbox_provider_tenant_type"), + ) + with op.batch_alter_table("sandbox_providers", schema=None) as batch_op: + batch_op.create_index("idx_sandbox_providers_tenant_active", ["tenant_id", "is_active"], unique=False) + batch_op.create_index("idx_sandbox_providers_tenant_id", ["tenant_id"], unique=False) + + op.create_table( + "llm_generation_details", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("app_id", models.types.StringUUID(), nullable=False), + sa.Column("message_id", models.types.StringUUID(), nullable=True), + sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True), + sa.Column("node_id", sa.String(length=255), nullable=True), + 
sa.Column("reasoning_content", models.types.LongText(), nullable=True), + sa.Column("tool_calls", models.types.LongText(), nullable=True), + sa.Column("sequence", models.types.LongText(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.CheckConstraint( + "(message_id IS NOT NULL AND workflow_run_id IS NULL AND node_id IS NULL) OR (message_id IS NULL AND workflow_run_id IS NOT NULL AND node_id IS NOT NULL)", + name=op.f("llm_generation_details_ck_llm_generation_detail_assoc_mode_check"), + ), + sa.PrimaryKeyConstraint("id", name="llm_generation_detail_pkey"), + sa.UniqueConstraint("message_id", name=op.f("llm_generation_details_message_id_key")), + ) + with op.batch_alter_table("llm_generation_details", schema=None) as batch_op: + batch_op.create_index("idx_llm_generation_detail_message", ["message_id"], unique=False) + batch_op.create_index("idx_llm_generation_detail_workflow", ["workflow_run_id", "node_id"], unique=False) + + op.create_table( + "app_assets", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("app_id", models.types.StringUUID(), nullable=False), + sa.Column("version", sa.String(length=255), nullable=False), + sa.Column("asset_tree", models.types.LongText(), nullable=False), + sa.Column("created_by", models.types.StringUUID(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_by", models.types.StringUUID(), nullable=True), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint("id", name="app_assets_pkey"), + ) + with op.batch_alter_table("app_assets", schema=None) as batch_op: + batch_op.create_index("app_assets_version_idx", ["tenant_id", "app_id", "version"], unique=False) + + # Only seed a default SSH system provider for self-hosted deployments. + # CLOUD editions manage sandbox providers through admin tooling. 
+ edition = os.environ.get("EDITION", "SELF_HOSTED") + if edition == "SELF_HOSTED": + ssh_config = _get_ssh_config_from_env() + encrypted_config = encrypt_system_params(ssh_config) + + op.execute( + sa.text( + """ + INSERT INTO sandbox_provider_system_config + (id, provider_type, encrypted_config, created_at, updated_at) + VALUES (:id, :provider_type, :encrypted_config, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + ON CONFLICT (provider_type) DO NOTHING + """ + ).bindparams( + id=str(uuid4()), + provider_type="ssh", + encrypted_config=encrypted_config, + ) + ) + + +def downgrade(): + op.drop_table("app_assets") + op.drop_table("llm_generation_details") + + with op.batch_alter_table("sandbox_providers", schema=None) as batch_op: + batch_op.drop_index("idx_sandbox_providers_tenant_id") + batch_op.drop_index("idx_sandbox_providers_tenant_active") + + op.drop_table("sandbox_providers") + op.drop_table("sandbox_provider_system_config") diff --git a/api/migrations/versions/2026_02_09_1726-227822d22895_add_workflow_comments_table.py b/api/migrations/versions/2026_02_09_1726-227822d22895_add_workflow_comments_table.py new file mode 100644 index 0000000000..af5e04a0e8 --- /dev/null +++ b/api/migrations/versions/2026_02_09_1726-227822d22895_add_workflow_comments_table.py @@ -0,0 +1,109 @@ +"""Add workflow comments table + +Revision ID: 227822d22895 +Revises: aab323465866 +Create Date: 2026-02-09 17:26:15.255980 + +""" + +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "227822d22895" +down_revision = "aab323465866" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "workflow_comments", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("app_id", models.types.StringUUID(), nullable=False), + sa.Column("position_x", sa.Float(), nullable=False), + sa.Column("position_y", sa.Float(), nullable=False), + sa.Column("content", sa.Text(), nullable=False), + sa.Column("created_by", models.types.StringUUID(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("resolved", sa.Boolean(), server_default=sa.text("false"), nullable=False), + sa.Column("resolved_at", sa.DateTime(), nullable=True), + sa.Column("resolved_by", models.types.StringUUID(), nullable=True), + sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), + ) + with op.batch_alter_table("workflow_comments", schema=None) as batch_op: + batch_op.create_index("workflow_comments_app_idx", ["tenant_id", "app_id"], unique=False) + batch_op.create_index("workflow_comments_created_at_idx", ["created_at"], unique=False) + + op.create_table( + "workflow_comment_replies", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False), + sa.Column("comment_id", models.types.StringUUID(), nullable=False), + sa.Column("content", sa.Text(), nullable=False), + sa.Column("created_by", models.types.StringUUID(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + 
sa.ForeignKeyConstraint( + ["comment_id"], + ["workflow_comments.id"], + name=op.f("workflow_comment_replies_comment_id_fkey"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), + ) + with op.batch_alter_table("workflow_comment_replies", schema=None) as batch_op: + batch_op.create_index("comment_replies_comment_idx", ["comment_id"], unique=False) + batch_op.create_index("comment_replies_created_at_idx", ["created_at"], unique=False) + + op.create_table( + "workflow_comment_mentions", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False), + sa.Column("comment_id", models.types.StringUUID(), nullable=False), + sa.Column("reply_id", models.types.StringUUID(), nullable=True), + sa.Column("mentioned_user_id", models.types.StringUUID(), nullable=False), + sa.ForeignKeyConstraint( + ["comment_id"], + ["workflow_comments.id"], + name=op.f("workflow_comment_mentions_comment_id_fkey"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["reply_id"], + ["workflow_comment_replies.id"], + name=op.f("workflow_comment_mentions_reply_id_fkey"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), + ) + with op.batch_alter_table("workflow_comment_mentions", schema=None) as batch_op: + batch_op.create_index("comment_mentions_comment_idx", ["comment_id"], unique=False) + batch_op.create_index("comment_mentions_reply_idx", ["reply_id"], unique=False) + batch_op.create_index("comment_mentions_user_idx", ["mentioned_user_id"], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("workflow_comment_mentions", schema=None) as batch_op: + batch_op.drop_index("comment_mentions_user_idx") + batch_op.drop_index("comment_mentions_reply_idx") + batch_op.drop_index("comment_mentions_comment_idx") + + op.drop_table("workflow_comment_mentions") + with op.batch_alter_table("workflow_comment_replies", schema=None) as batch_op: + batch_op.drop_index("comment_replies_created_at_idx") + batch_op.drop_index("comment_replies_comment_idx") + + op.drop_table("workflow_comment_replies") + with op.batch_alter_table("workflow_comments", schema=None) as batch_op: + batch_op.drop_index("workflow_comments_created_at_idx") + batch_op.drop_index("workflow_comments_app_idx") + + op.drop_table("workflow_comments") + # ### end Alembic commands ### diff --git a/api/migrations/versions/2026_03_09_1200-5ee0aa981887_add_app_asset_contents_table.py b/api/migrations/versions/2026_03_09_1200-5ee0aa981887_add_app_asset_contents_table.py new file mode 100644 index 0000000000..0a2dcae98c --- /dev/null +++ b/api/migrations/versions/2026_03_09_1200-5ee0aa981887_add_app_asset_contents_table.py @@ -0,0 +1,40 @@ +"""Add app_asset_contents table for inline content caching. + +Revision ID: 5ee0aa981887 +Revises: aab323465866 +Create Date: 2026-03-09 12:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. 
+revision = "5ee0aa981887" +down_revision = "6b5f9f8b1a2c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "app_asset_contents", + sa.Column("id", models.types.StringUUID(), nullable=False), + sa.Column("tenant_id", models.types.StringUUID(), nullable=False), + sa.Column("app_id", models.types.StringUUID(), nullable=False), + sa.Column("node_id", models.types.StringUUID(), nullable=False), + sa.Column("content", sa.Text(), nullable=False, server_default=""), + sa.Column("size", sa.Integer(), nullable=False, server_default="0"), + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), + sa.Column("updated_at", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), + sa.PrimaryKeyConstraint("id", name="app_asset_contents_pkey"), + sa.UniqueConstraint("tenant_id", "app_id", "node_id", name="uq_asset_content_node"), + ) + op.create_index("idx_asset_content_app", "app_asset_contents", ["tenant_id", "app_id"]) + + +def downgrade() -> None: + op.drop_index("idx_asset_content_app", table_name="app_asset_contents") + op.drop_table("app_asset_contents") diff --git a/api/models/app_asset.py b/api/models/app_asset.py new file mode 100644 index 0000000000..d2c1efc10a --- /dev/null +++ b/api/models/app_asset.py @@ -0,0 +1,89 @@ +from datetime import datetime +from uuid import uuid4 + +import sqlalchemy as sa +from sqlalchemy import DateTime, Integer, String, func +from sqlalchemy.orm import Mapped, mapped_column + +from core.app.entities.app_asset_entities import AppAssetFileTree + +from .base import Base +from .types import LongText, StringUUID + + +class AppAssets(Base): + __tablename__ = "app_assets" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="app_assets_pkey"), + sa.Index("app_assets_version_idx", "tenant_id", "app_id", "version"), + ) + + VERSION_DRAFT = "draft" + VERSION_PUBLISHED = "published" + + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + version: Mapped[str] = mapped_column(String(255), nullable=False) + _asset_tree: Mapped[str] = mapped_column("asset_tree", LongText, nullable=False, default='{"nodes":[]}') + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_by: Mapped[str | None] = mapped_column(StringUUID) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + default=func.current_timestamp(), + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) + + @property + def asset_tree(self) -> AppAssetFileTree: + if not self._asset_tree: + return AppAssetFileTree() + return AppAssetFileTree.model_validate_json(self._asset_tree) + + @asset_tree.setter + def asset_tree(self, value: AppAssetFileTree) -> None: + self._asset_tree = value.model_dump_json() + + def __repr__(self) -> str: + return f"" + + +class AppAssetContent(Base): + """Inline content cache for app asset draft files. + + Acts as a read-through cache for S3: text-like asset content is dual-written + here on save and read from DB first (falling back to S3 on miss with sync backfill). + Keyed by (tenant_id, app_id, node_id) — stores only the current draft content, + not published snapshots. 
+ + See core/app_assets/content_accessor.py for the accessor abstraction that + manages the DB/S3 read-through and dual-write logic. + """ + + __tablename__ = "app_asset_contents" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="app_asset_contents_pkey"), + sa.UniqueConstraint("tenant_id", "app_id", "node_id", name="uq_asset_content_node"), + sa.Index("idx_asset_content_app", "tenant_id", "app_id"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + node_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + content: Mapped[str] = mapped_column(LongText, nullable=False, default="") + size: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + default=func.current_timestamp(), + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) + + def __repr__(self) -> str: + return f"" diff --git a/api/models/comment.py b/api/models/comment.py new file mode 100644 index 0000000000..21ccfa13db --- /dev/null +++ b/api/models/comment.py @@ -0,0 +1,210 @@ +"""Workflow comment models.""" + +from datetime import datetime +from typing import Optional + +from sqlalchemy import Index, func +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .account import Account +from .base import Base +from .engine import db +from .types import StringUUID + + +class WorkflowComment(Base): + """Workflow comment model for canvas commenting functionality. + + Comments are associated with apps rather than specific workflow versions, + since an app has only one draft workflow at a time and comments should persist + across workflow version changes. 
+ + Attributes: + id: Comment ID + tenant_id: Workspace ID + app_id: App ID (primary association, comments belong to apps) + position_x: X coordinate on canvas + position_y: Y coordinate on canvas + content: Comment content + created_by: Creator account ID + created_at: Creation time + updated_at: Last update time + resolved: Whether comment is resolved + resolved_at: Resolution time + resolved_by: Resolver account ID + """ + + __tablename__ = "workflow_comments" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), + Index("workflow_comments_app_idx", "tenant_id", "app_id"), + Index("workflow_comments_created_at_idx", "created_at"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + position_x: Mapped[float] = mapped_column(db.Float) + position_y: Mapped[float] = mapped_column(db.Float) + content: Mapped[str] = mapped_column(db.Text, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime) + resolved_by: Mapped[str | None] = mapped_column(StringUUID) + + # Relationships + replies: Mapped[list["WorkflowCommentReply"]] = relationship( + "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan" + ) + mentions: Mapped[list["WorkflowCommentMention"]] = relationship( + "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan" + ) + + @property + def created_by_account(self): + """Get creator account.""" + if hasattr(self, "_created_by_account_cache"): + return self._created_by_account_cache + return db.session.get(Account, self.created_by) + + def cache_created_by_account(self, account: Account | None) -> None: + """Cache creator account to avoid extra queries.""" + self._created_by_account_cache = account + + @property + def resolved_by_account(self): + """Get resolver account.""" + if hasattr(self, "_resolved_by_account_cache"): + return self._resolved_by_account_cache + if self.resolved_by: + return db.session.get(Account, self.resolved_by) + return None + + def cache_resolved_by_account(self, account: Account | None) -> None: + """Cache resolver account to avoid extra queries.""" + self._resolved_by_account_cache = account + + @property + def reply_count(self): + """Get reply count.""" + return len(self.replies) + + @property + def mention_count(self): + """Get mention count.""" + return len(self.mentions) + + @property + def participants(self): + """Get all participants (creator + repliers + mentioned users).""" + participant_ids = set() + + # Add comment creator + participant_ids.add(self.created_by) + + # Add reply creators + participant_ids.update(reply.created_by for reply in self.replies) + + # Add mentioned users + participant_ids.update(mention.mentioned_user_id for mention in self.mentions) + + # Get account objects + participants = [] + for user_id in participant_ids: + account = db.session.get(Account, user_id) + if account: + participants.append(account) + + return 
participants + + +class WorkflowCommentReply(Base): + """Workflow comment reply model. + + Attributes: + id: Reply ID + comment_id: Parent comment ID + content: Reply content + created_by: Creator account ID + created_at: Creation time + """ + + __tablename__ = "workflow_comment_replies" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), + Index("comment_replies_comment_idx", "comment_id"), + Index("comment_replies_created_at_idx", "created_at"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + comment_id: Mapped[str] = mapped_column( + StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + ) + content: Mapped[str] = mapped_column(db.Text, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + # Relationships + comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies") + + @property + def created_by_account(self): + """Get creator account.""" + if hasattr(self, "_created_by_account_cache"): + return self._created_by_account_cache + return db.session.get(Account, self.created_by) + + def cache_created_by_account(self, account: Account | None) -> None: + """Cache creator account to avoid extra queries.""" + self._created_by_account_cache = account + + +class WorkflowCommentMention(Base): + """Workflow comment mention model. + + Mentions are only for internal accounts since end users + cannot access workflow canvas and commenting features. 
+ + Attributes: + id: Mention ID + comment_id: Parent comment ID + mentioned_user_id: Mentioned account ID + """ + + __tablename__ = "workflow_comment_mentions" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), + Index("comment_mentions_comment_idx", "comment_id"), + Index("comment_mentions_reply_idx", "reply_id"), + Index("comment_mentions_user_idx", "mentioned_user_id"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()")) + comment_id: Mapped[str] = mapped_column( + StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + ) + reply_id: Mapped[str | None] = mapped_column( + StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True + ) + mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # Relationships + comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions") + reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply") + + @property + def mentioned_user_account(self): + """Get mentioned account.""" + if hasattr(self, "_mentioned_user_account_cache"): + return self._mentioned_user_account_cache + return db.session.get(Account, self.mentioned_user_id) + + def cache_mentioned_user_account(self, account: Account | None) -> None: + """Cache mentioned account to avoid extra queries.""" + self._mentioned_user_account_cache = account diff --git a/api/models/sandbox.py b/api/models/sandbox.py new file mode 100644 index 0000000000..e384ab8853 --- /dev/null +++ b/api/models/sandbox.py @@ -0,0 +1,80 @@ +import json +from collections.abc import Mapping +from datetime import datetime +from typing import Any, cast +from uuid import uuid4 + +import sqlalchemy as sa +from sqlalchemy import DateTime, String, func +from sqlalchemy.orm import Mapped, mapped_column + +from .base import TypeBase +from .types import LongText, StringUUID + + +class SandboxProviderSystemConfig(TypeBase): + """ + System-level sandbox provider configuration. + Stores default configuration for each provider type. + """ + + __tablename__ = "sandbox_provider_system_config" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="sandbox_provider_system_config_pkey"), + sa.UniqueConstraint("provider_type", name="unique_sandbox_provider_system_config_type"), + ) + + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) + provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="e2b, docker, local, ssh") + encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False, comment="Encrypted config JSON") + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + server_default=func.current_timestamp(), + server_onupdate=func.current_timestamp(), + init=False, + ) + + +class SandboxProvider(TypeBase): + """ + Tenant-level sandbox provider configuration. + Each tenant can have one configuration per provider type. + Only one provider can be active at a time per tenant. 
+ """ + + __tablename__ = "sandbox_providers" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="sandbox_provider_pkey"), + sa.UniqueConstraint("tenant_id", "provider_type", "configure_type", name="unique_sandbox_provider_tenant_type"), + sa.Index("idx_sandbox_providers_tenant_id", "tenant_id"), + sa.Index("idx_sandbox_providers_tenant_active", "tenant_id", "is_active"), + ) + + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="e2b, docker, local, ssh") + encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False, comment="Encrypted config JSON") + configure_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default="user", default="user") + is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + server_default=func.current_timestamp(), + server_onupdate=func.current_timestamp(), + init=False, + ) + + @property + def config(self) -> Mapping[str, Any]: + return cast(Mapping[str, Any], json.loads(self.encrypted_config or "{}")) diff --git a/api/models/workflow_comment.py b/api/models/workflow_comment.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/models/workflow_features.py b/api/models/workflow_features.py new file mode 100644 index 0000000000..81fc0fa11a --- /dev/null +++ b/api/models/workflow_features.py @@ -0,0 +1,26 @@ +from collections.abc import Mapping +from dataclasses import dataclass +from enum import StrEnum +from typing import Any + + +class WorkflowFeatures(StrEnum): + SANDBOX = "sandbox" + SPEECH_TO_TEXT = "speech_to_text" + TEXT_TO_SPEECH = "text_to_speech" + RETRIEVER_RESOURCE = "retriever_resource" + SENSITIVE_WORD_AVOIDANCE = "sensitive_word_avoidance" + FILE_UPLOAD = "file_upload" + SUGGESTED_QUESTIONS_AFTER_ANSWER = "suggested_questions_after_answer" + + +@dataclass(frozen=True) +class WorkflowFeature: + enabled: bool + config: Mapping[str, Any] + + @classmethod + def from_dict(cls, data: Mapping[str, Any] | None) -> "WorkflowFeature": + if data is None or not isinstance(data, dict): + return cls(enabled=False, config={}) + return cls(enabled=bool(data.get("enabled", False)), config=data) diff --git a/api/repositories/workflow_collaboration_repository.py b/api/repositories/workflow_collaboration_repository.py new file mode 100644 index 0000000000..75a483e156 --- /dev/null +++ b/api/repositories/workflow_collaboration_repository.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import json +from typing import TypedDict + +from extensions.ext_redis import redis_client + +SESSION_STATE_TTL_SECONDS = 3600 +WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:" +WORKFLOW_LEADER_PREFIX = "workflow_leader:" +WORKFLOW_SKILL_LEADER_PREFIX = "workflow_skill_leader:" +WS_SID_MAP_PREFIX = "ws_sid_map:" + + +class WorkflowSessionInfo(TypedDict): + user_id: str + username: str + avatar: str | None + sid: str + connected_at: int + graph_active: bool + active_skill_file_id: str | None + + +class SidMapping(TypedDict): + workflow_id: str + user_id: str + + +class WorkflowCollaborationRepository: + def __init__(self) 
-> None: + self._redis = redis_client + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(redis_client={self._redis})" + + @staticmethod + def workflow_key(workflow_id: str) -> str: + return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}" + + @staticmethod + def leader_key(workflow_id: str) -> str: + return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}" + + @staticmethod + def skill_leader_key(workflow_id: str, file_id: str) -> str: + return f"{WORKFLOW_SKILL_LEADER_PREFIX}{workflow_id}:{file_id}" + + @staticmethod + def sid_key(sid: str) -> str: + return f"{WS_SID_MAP_PREFIX}{sid}" + + @staticmethod + def _decode(value: str | bytes | None) -> str | None: + if value is None: + return None + if isinstance(value, bytes): + return value.decode("utf-8") + return value + + def refresh_session_state(self, workflow_id: str, sid: str) -> None: + workflow_key = self.workflow_key(workflow_id) + sid_key = self.sid_key(sid) + if self._redis.exists(workflow_key): + self._redis.expire(workflow_key, SESSION_STATE_TTL_SECONDS) + if self._redis.exists(sid_key): + self._redis.expire(sid_key, SESSION_STATE_TTL_SECONDS) + + def set_session_info(self, workflow_id: str, session_info: WorkflowSessionInfo) -> None: + workflow_key = self.workflow_key(workflow_id) + self._redis.hset(workflow_key, session_info["sid"], json.dumps(session_info)) + self._redis.set( + self.sid_key(session_info["sid"]), + json.dumps({"workflow_id": workflow_id, "user_id": session_info["user_id"]}), + ex=SESSION_STATE_TTL_SECONDS, + ) + self.refresh_session_state(workflow_id, session_info["sid"]) + + def get_session_info(self, workflow_id: str, sid: str) -> WorkflowSessionInfo | None: + raw = self._redis.hget(self.workflow_key(workflow_id), sid) + value = self._decode(raw) + if not value: + return None + try: + session_info = json.loads(value) + except (TypeError, json.JSONDecodeError): + return None + + if not isinstance(session_info, dict): + return None + if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info: + return None + + return { + "user_id": str(session_info["user_id"]), + "username": str(session_info["username"]), + "avatar": session_info.get("avatar"), + "sid": str(session_info["sid"]), + "connected_at": int(session_info.get("connected_at") or 0), + "graph_active": bool(session_info.get("graph_active")), + "active_skill_file_id": session_info.get("active_skill_file_id"), + } + + def set_graph_active(self, workflow_id: str, sid: str, active: bool) -> None: + session_info = self.get_session_info(workflow_id, sid) + if not session_info: + return + session_info["graph_active"] = bool(active) + self._redis.hset(self.workflow_key(workflow_id), sid, json.dumps(session_info)) + self.refresh_session_state(workflow_id, sid) + + def is_graph_active(self, workflow_id: str, sid: str) -> bool: + session_info = self.get_session_info(workflow_id, sid) + if not session_info: + return False + return bool(session_info.get("graph_active") or False) + + def set_active_skill_file(self, workflow_id: str, sid: str, file_id: str | None) -> None: + session_info = self.get_session_info(workflow_id, sid) + if not session_info: + return + session_info["active_skill_file_id"] = file_id + self._redis.hset(self.workflow_key(workflow_id), sid, json.dumps(session_info)) + self.refresh_session_state(workflow_id, sid) + + def get_active_skill_file_id(self, workflow_id: str, sid: str) -> str | None: + session_info = self.get_session_info(workflow_id, sid) + if not session_info: + return None + return 
session_info.get("active_skill_file_id") + + def get_sid_mapping(self, sid: str) -> SidMapping | None: + raw = self._redis.get(self.sid_key(sid)) + if not raw: + return None + value = self._decode(raw) + if not value: + return None + try: + return json.loads(value) + except (TypeError, json.JSONDecodeError): + return None + + def delete_session(self, workflow_id: str, sid: str) -> None: + self._redis.hdel(self.workflow_key(workflow_id), sid) + self._redis.delete(self.sid_key(sid)) + + def session_exists(self, workflow_id: str, sid: str) -> bool: + return bool(self._redis.hexists(self.workflow_key(workflow_id), sid)) + + def sid_mapping_exists(self, sid: str) -> bool: + return bool(self._redis.exists(self.sid_key(sid))) + + def get_session_sids(self, workflow_id: str) -> list[str]: + raw_sids = self._redis.hkeys(self.workflow_key(workflow_id)) + decoded_sids: list[str] = [] + for sid in raw_sids: + decoded = self._decode(sid) + if decoded: + decoded_sids.append(decoded) + return decoded_sids + + def list_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]: + sessions_json = self._redis.hgetall(self.workflow_key(workflow_id)) + users: list[WorkflowSessionInfo] = [] + + for session_info_json in sessions_json.values(): + value = self._decode(session_info_json) + if not value: + continue + try: + session_info = json.loads(value) + except (TypeError, json.JSONDecodeError): + continue + + if not isinstance(session_info, dict): + continue + if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info: + continue + + users.append( + { + "user_id": str(session_info["user_id"]), + "username": str(session_info["username"]), + "avatar": session_info.get("avatar"), + "sid": str(session_info["sid"]), + "connected_at": int(session_info.get("connected_at") or 0), + "graph_active": bool(session_info.get("graph_active")), + "active_skill_file_id": session_info.get("active_skill_file_id"), + } + ) + + return users + + def get_current_leader(self, workflow_id: str) -> str | None: + raw = self._redis.get(self.leader_key(workflow_id)) + return self._decode(raw) + + def get_skill_leader(self, workflow_id: str, file_id: str) -> str | None: + raw = self._redis.get(self.skill_leader_key(workflow_id, file_id)) + return self._decode(raw) + + def set_leader_if_absent(self, workflow_id: str, sid: str) -> bool: + return bool(self._redis.set(self.leader_key(workflow_id), sid, nx=True, ex=SESSION_STATE_TTL_SECONDS)) + + def set_leader(self, workflow_id: str, sid: str) -> None: + self._redis.set(self.leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS) + + def set_skill_leader(self, workflow_id: str, file_id: str, sid: str) -> None: + self._redis.set(self.skill_leader_key(workflow_id, file_id), sid, ex=SESSION_STATE_TTL_SECONDS) + + def delete_leader(self, workflow_id: str) -> None: + self._redis.delete(self.leader_key(workflow_id)) + + def delete_skill_leader(self, workflow_id: str, file_id: str) -> None: + self._redis.delete(self.skill_leader_key(workflow_id, file_id)) + + def expire_leader(self, workflow_id: str) -> None: + self._redis.expire(self.leader_key(workflow_id), SESSION_STATE_TTL_SECONDS) + + def expire_skill_leader(self, workflow_id: str, file_id: str) -> None: + self._redis.expire(self.skill_leader_key(workflow_id, file_id), SESSION_STATE_TTL_SECONDS) + + def get_active_skill_session_sids(self, workflow_id: str, file_id: str) -> list[str]: + sessions = self.list_sessions(workflow_id) + return [session["sid"] for session in sessions if 
session.get("active_skill_file_id") == file_id] diff --git a/api/services/app_asset_package_service.py b/api/services/app_asset_package_service.py new file mode 100644 index 0000000000..9df4874880 --- /dev/null +++ b/api/services/app_asset_package_service.py @@ -0,0 +1,195 @@ +"""Service for packaging and publishing app assets. + +This service handles operations that require core.zip_sandbox, +separated from AppAssetService to avoid circular imports. + +Dependency flow: + core/* -> AppAssetPackageService -> AppAssetService + (core modules can import this service without circular dependency) + +Inline content optimisation: + ``AssetItem`` objects returned by the build pipeline may carry an + in-process *content* field (e.g. resolved ``.md`` skill documents). + ``AppAssetService.to_download_items()`` converts these into unified + ``SandboxDownloadItem`` instances, and ``ZipSandbox.download_items()`` + handles both inline and remote items natively. +""" + +import logging +from uuid import uuid4 + +from sqlalchemy.orm import Session + +from core.app.entities.app_asset_entities import AppAssetFileTree +from core.app_assets.builder import AssetBuildPipeline, BuildContext +from core.app_assets.builder.file_builder import FileBuilder +from core.app_assets.builder.skill_builder import SkillBuilder +from core.app_assets.entities.assets import AssetItem +from core.app_assets.storage import AssetPaths +from core.zip_sandbox import ZipSandbox +from models.app_asset import AppAssets +from models.model import App + +logger = logging.getLogger(__name__) + + +class AppAssetPackageService: + """Service for packaging and publishing app assets. + + This service is designed to be imported by core/* modules without + causing circular imports. It depends on AppAssetService for basic + asset operations but provides the packaging/publishing functionality + that requires core.zip_sandbox. + """ + + @staticmethod + def get_tenant_app_assets(tenant_id: str, assets_id: str) -> AppAssets: + """Get app assets by tenant_id and assets_id. + + This is a read-only operation that doesn't require AppAssetService. + """ + from extensions.ext_database import db + + with Session(db.engine, expire_on_commit=False) as session: + app_assets = ( + session.query(AppAssets) + .filter( + AppAssets.tenant_id == tenant_id, + AppAssets.id == assets_id, + ) + .first() + ) + if not app_assets: + raise ValueError(f"App assets not found for tenant_id={tenant_id}, assets_id={assets_id}") + + return app_assets + + @staticmethod + def get_draft_asset_items(tenant_id: str, app_id: str, file_tree: AppAssetFileTree) -> list[AssetItem]: + """Convert file tree to asset items for packaging.""" + files = file_tree.walk_files() + return [ + AssetItem( + asset_id=f.id, + path=file_tree.get_path(f.id), + file_name=f.name, + extension=f.extension, + storage_key=AssetPaths.draft(tenant_id, app_id, f.id), + ) + for f in files + ] + + @staticmethod + def package_and_upload( + *, + assets: list[AssetItem], + upload_url: str, + tenant_id: str, + app_id: str, + user_id: str, + storage_key: str = "", + ) -> None: + """Package assets into a ZIP and upload directly to the given URL. + + Uses ``AppAssetService.to_download_items()`` to convert assets + into unified download items, then ``ZipSandbox.download_items()`` + handles both inline content and remote presigned URLs natively. + + When *assets* is empty an empty ZIP is written directly to storage + using *storage_key*, bypassing the HTTP ticket URL. 
+ """ + from services.app_asset_service import AppAssetService + + if not assets: + import io + import zipfile + + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w"): + pass + buf.seek(0) + + # Write directly to storage instead of going through the HTTP + # ticket URL. The ticket URL (FILES_API_URL) is designed for + # sandbox containers (agentbox) and is not routable from the api + # container in standard Docker Compose deployments. + if storage_key: + from extensions.ext_storage import storage + + storage.save(storage_key, buf.getvalue()) + else: + import requests + + requests.put(upload_url, data=buf.getvalue(), timeout=30) + return + + download_items = AppAssetService.to_download_items(assets) + + with ZipSandbox(tenant_id=tenant_id, user_id=user_id, app_id="asset-packager") as zs: + zs.download_items(download_items) + archive = zs.zip() + zs.upload(archive, upload_url) + + @staticmethod + def publish(session: Session, app_model: App, account_id: str, workflow_id: str) -> AppAssets: + """Publish app assets for a workflow. + + Creates a versioned copy of draft assets and packages them for + runtime use. The build ZIP contains resolved ``.md`` content + (inline from ``SkillBuilder``) and raw draft content for all + other files. A separate source ZIP snapshots the raw drafts for + later export. + """ + from services.app_asset_service import AppAssetService + + tenant_id = app_model.tenant_id + app_id = app_model.id + + assets = AppAssetService.get_or_create_assets(session, app_model, account_id) + tree = assets.asset_tree + + publish_id = str(uuid4()) + + published = AppAssets( + id=publish_id, + tenant_id=tenant_id, + app_id=app_id, + version=workflow_id, + created_by=account_id, + ) + published.asset_tree = tree + session.add(published) + session.flush() + + asset_storage = AppAssetService.get_storage() + accessor = AppAssetService.get_accessor(tenant_id, app_id) + build_pipeline = AssetBuildPipeline([SkillBuilder(accessor=accessor), FileBuilder()]) + ctx = BuildContext(tenant_id=tenant_id, app_id=app_id, build_id=publish_id) + built_assets = build_pipeline.build_all(tree, ctx) + + # Runtime ZIP: resolved .md (inline) + raw draft (remote). + runtime_zip_key = AssetPaths.build_zip(tenant_id, app_id, publish_id) + runtime_upload_url = asset_storage.get_upload_url(runtime_zip_key) + AppAssetPackageService.package_and_upload( + assets=built_assets, + upload_url=runtime_upload_url, + tenant_id=tenant_id, + app_id=app_id, + user_id=account_id, + storage_key=runtime_zip_key, + ) + + # Source ZIP: all raw draft content (for export/restore). + source_items = AppAssetService.get_draft_assets(tenant_id, app_id) + source_key = AssetPaths.source_zip(tenant_id, app_id, workflow_id) + source_upload_url = asset_storage.get_upload_url(source_key) + AppAssetPackageService.package_and_upload( + assets=source_items, + upload_url=source_upload_url, + tenant_id=tenant_id, + app_id=app_id, + user_id=account_id, + storage_key=source_key, + ) + + return published diff --git a/api/services/app_runtime_upgrade_service.py b/api/services/app_runtime_upgrade_service.py new file mode 100644 index 0000000000..ff3a3814d2 --- /dev/null +++ b/api/services/app_runtime_upgrade_service.py @@ -0,0 +1,443 @@ +"""Service for upgrading Classic runtime apps to Sandboxed runtime via clone-and-convert. + +The upgrade flow: +1. Clone the source app via DSL export/import +2. On the cloned app's draft workflow, convert Agent nodes to LLM nodes +3. 
Rewrite variable references for all LLM nodes (old output names → new generation-based names) +4. Enable sandbox feature flag + +The original app is never modified; the user gets a new sandboxed copy. +""" + +import json +import logging +import re +import uuid +from typing import Any + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models import App, Workflow +from models.workflow_features import WorkflowFeatures +from services.app_dsl_service import AppDslService, ImportMode + +logger = logging.getLogger(__name__) + +_VAR_REWRITES: dict[str, list[str]] = { + "text": ["generation", "content"], + "reasoning_content": ["generation", "reasoning_content"], +} + +_PASSTHROUGH_KEYS = ( + "version", + "error_strategy", + "default_value", + "retry_config", + "parent_node_id", + "isInLoop", + "loop_id", + "isInIteration", + "iteration_id", +) + + +class AppRuntimeUpgradeService: + """Upgrades a Classic-runtime app to Sandboxed runtime by cloning and converting. + + Holds an active SQLAlchemy session; the caller is responsible for commit/rollback. + """ + + session: Session + + def __init__(self, session: Session) -> None: + self.session = session + + def upgrade(self, app_model: App, account: Any) -> dict[str, Any]: + """Clone *app_model* and upgrade the clone to sandboxed runtime. + + Returns: + dict with keys: result, new_app_id, converted_agents, skipped_agents. + """ + workflow = self._get_draft_workflow(app_model) + if not workflow: + return {"result": "no_draft"} + + if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled: + return {"result": "already_sandboxed"} + + new_app = self._clone_app(app_model, account) + new_workflow = self._get_draft_workflow(new_app) + if not new_workflow: + return {"result": "no_draft"} + + graph = json.loads(new_workflow.graph) if new_workflow.graph else {} + nodes = graph.get("nodes", []) + + converted, skipped = _convert_agent_nodes(nodes) + _enable_computer_use_for_existing_llm_nodes(nodes) + + llm_node_ids = {n["id"] for n in nodes if n.get("data", {}).get("type") == "llm"} + _rewrite_variable_references(nodes, llm_node_ids) + + new_workflow.graph = json.dumps(graph) + + features = json.loads(new_workflow.features) if new_workflow.features else {} + features.setdefault("sandbox", {})["enabled"] = True + new_workflow.features = json.dumps(features) + + return { + "result": "success", + "new_app_id": str(new_app.id), + "converted_agents": converted, + "skipped_agents": skipped, + } + + def _get_draft_workflow(self, app_model: App) -> Workflow | None: + stmt = select(Workflow).where( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.version == "draft", + ) + return self.session.scalar(stmt) + + def _clone_app(self, app_model: App, account: Any) -> App: + dsl_service = AppDslService(self.session) + yaml_content = dsl_service.export_dsl(app_model=app_model, include_secret=True) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=yaml_content, + name=f"{app_model.name} (Sandboxed)", + ) + stmt = select(App).where(App.id == result.app_id) + new_app = self.session.scalar(stmt) + if not new_app: + raise RuntimeError(f"Cloned app not found: {result.app_id}") + return new_app + + +# --------------------------------------------------------------------------- +# Pure conversion functions (no DB access) +# --------------------------------------------------------------------------- + + +def _convert_agent_nodes(nodes: list[dict[str, Any]]) -> 
tuple[int, int]: + """Convert Agent nodes to LLM nodes in-place. Returns (converted_count, skipped_count).""" + converted = 0 + + for node in nodes: + data = node.get("data", {}) + if data.get("type") != "agent": + continue + + node_id = node.get("id", "?") + node["data"] = _agent_data_to_llm_data(data) + logger.info("Converted agent node %s to LLM", node_id) + converted += 1 + + return converted, 0 + + +def _agent_data_to_llm_data(agent_data: dict[str, Any]) -> dict[str, Any]: + """Map an Agent node's data dict to an LLM node's data dict. + + Always returns a valid LLM data dict. If the agent has no model selected, + produces an empty LLM node with agent mode (computer_use) enabled. + """ + params = agent_data.get("agent_parameters") or {} + + model_param = params.get("model", {}) if isinstance(params, dict) else {} + model_value = model_param.get("value") if isinstance(model_param, dict) else None + + if isinstance(model_value, dict) and model_value.get("provider") and model_value.get("model"): + model_config = { + "provider": model_value["provider"], + "name": model_value["model"], + "mode": model_value.get("mode", "chat"), + "completion_params": model_value.get("completion_params", {}), + } + else: + model_config = {"provider": "", "name": "", "mode": "chat", "completion_params": {}} + + tools_param = params.get("tools", {}) + tools_value = tools_param.get("value", []) if isinstance(tools_param, dict) else [] + tools_meta, tool_settings = _convert_tools(tools_value if isinstance(tools_value, list) else []) + + instruction_param = params.get("instruction", {}) + instruction = instruction_param.get("value", "") if isinstance(instruction_param, dict) else "" + + query_param = params.get("query", {}) + query_value = query_param.get("value", "") if isinstance(query_param, dict) else "" + + has_tools = bool(tools_meta) + prompt_template = _build_prompt_template( + instruction, + query_value, + skill=has_tools, + tools=tools_value if has_tools else None, + ) + + max_iter_param = params.get("maximum_iterations", {}) + max_iterations = max_iter_param.get("value", 100) if isinstance(max_iter_param, dict) else 100 + + context_config = _extract_context(params) + vision_config = _extract_vision(params) + + llm_data: dict[str, Any] = { + "type": "llm", + "title": agent_data.get("title", "LLM"), + "desc": agent_data.get("desc", ""), + "model": model_config, + "prompt_template": prompt_template, + "prompt_config": {"jinja2_variables": []}, + "memory": agent_data.get("memory"), + "context": context_config, + "vision": vision_config, + "computer_use": True, + "structured_output_switch_on": False, + "reasoning_format": "separated", + "tools": tools_meta, + "tool_settings": tool_settings, + "max_iterations": max_iterations, + } + + for key in _PASSTHROUGH_KEYS: + if key in agent_data: + llm_data[key] = agent_data[key] + + return llm_data + + +def _extract_context(params: dict[str, Any]) -> dict[str, Any]: + """Extract context config from agent_parameters for LLM node format. + + Agent stores context as a variable selector in agent_parameters.context.value, + e.g. ["knowledge_retrieval_node_id", "result"]. Maps to LLM ContextConfig. 
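+    A sketch of the mapping (node id is illustrative):
+        {"context": {"value": ["kr_node_id", "result"]}} becomes
+        {"enabled": True, "variable_selector": ["kr_node_id", "result"]};
+        any other shape falls back to {"enabled": False}.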
+ """ + if not isinstance(params, dict): + return {"enabled": False} + + ctx_param = params.get("context", {}) + ctx_value = ctx_param.get("value") if isinstance(ctx_param, dict) else None + + if isinstance(ctx_value, list) and len(ctx_value) >= 2 and all(isinstance(s, str) for s in ctx_value): + return {"enabled": True, "variable_selector": ctx_value} + + return {"enabled": False} + + +def _extract_vision(params: dict[str, Any]) -> dict[str, Any]: + """Extract vision config from agent_parameters for LLM node format.""" + if not isinstance(params, dict): + return {"enabled": False} + + vision_param = params.get("vision", {}) + vision_value = vision_param.get("value") if isinstance(vision_param, dict) else None + + if isinstance(vision_value, dict) and vision_value.get("enabled"): + return vision_value + + if isinstance(vision_value, bool) and vision_value: + return {"enabled": True} + + return {"enabled": False} + + +def _enable_computer_use_for_existing_llm_nodes(nodes: list[dict[str, Any]]) -> None: + """Enable computer_use for existing LLM nodes that have tools configured. + + After upgrade, the sandbox runtime requires computer_use=true for tool calling. + Existing LLM nodes from classic mode may have tools but computer_use=false. + """ + for node in nodes: + data = node.get("data", {}) + if data.get("type") != "llm": + continue + + tools = data.get("tools", []) + if tools and not data.get("computer_use"): + data["computer_use"] = True + logger.info("Enabled computer_use for LLM node %s with %d tools", node.get("id", "?"), len(tools)) + + +def _convert_tools( + tools_input: list[dict[str, Any]], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """Convert agent tool dicts to (ToolMetadata[], ToolSetting[]). + + Agent tools in graph JSON already use provider_name/settings/parameters — + the same field names as LLM ToolMetadata. We pass them through with defaults + for any missing fields. + """ + tools_meta: list[dict[str, Any]] = [] + tool_settings: list[dict[str, Any]] = [] + + for ts in tools_input: + if not isinstance(ts, dict): + continue + + provider_name = ts.get("provider_name", "") + tool_name = ts.get("tool_name", "") + tool_type = ts.get("type", "builtin") + + tools_meta.append( + { + "enabled": True, + "type": tool_type, + "provider_name": provider_name, + "tool_name": tool_name, + "plugin_unique_identifier": ts.get("plugin_unique_identifier"), + "credential_id": ts.get("credential_id"), + "parameters": ts.get("parameters", {}), + "settings": ts.get("settings", {}) or ts.get("tool_configuration", {}), + "extra": ts.get("extra", {}), + } + ) + + tool_settings.append( + { + "type": tool_type, + "provider": provider_name, + "tool_name": tool_name, + "enabled": True, + } + ) + + return tools_meta, tool_settings + + +def _build_prompt_template( + instruction: Any, + query: Any, + *, + skill: bool = False, + tools: list[dict[str, Any]] | None = None, +) -> list[dict[str, Any]]: + """Build LLM prompt_template from Agent instruction and query values. + + When *skill* is True each message gets ``"skill": True`` so the sandbox + engine treats the prompt as a skill document. + + When *tools* is provided, tool reference placeholders + (``§[tool].[provider].[name].[uuid]§``) are appended to the system + message and the corresponding ``ToolReference`` entries are placed in the + message's ``metadata.tools`` dict so the skill assembler can resolve them. + Tools from the same provider are grouped into a single token list. 
+ """ + messages: list[dict[str, Any]] = [] + + system_text = instruction if isinstance(instruction, str) else (str(instruction) if instruction else "") + metadata: dict[str, Any] | None = None + + if tools: + tool_refs: dict[str, dict[str, Any]] = {} + provider_groups: dict[str, list[str]] = {} + for ts in tools: + if not isinstance(ts, dict): + continue + tool_uuid = str(uuid.uuid4()) + provider_id = ts.get("provider_name", "") + tool_name = ts.get("tool_name", "") + tool_type = ts.get("type", "builtin") + + token = f"§[tool].[{provider_id}].[{tool_name}].[{tool_uuid}]§" + provider_groups.setdefault(provider_id, []).append(token) + tool_refs[tool_uuid] = { + "type": tool_type, + "configuration": {"fields": []}, + "enabled": True, + **({"credential_id": ts.get("credential_id")} if ts.get("credential_id") else {}), + } + + if provider_groups: + group_texts: list[str] = [] + for tokens in provider_groups.values(): + if len(tokens) == 1: + group_texts.append(tokens[0]) + else: + group_texts.append("[" + ",".join(tokens) + "]") + all_tools_text = " ".join(group_texts) + system_text = f"{system_text}\n\n{all_tools_text}" if system_text else all_tools_text + metadata = {"tools": tool_refs, "files": []} + + if system_text: + msg: dict[str, Any] = {"role": "system", "text": system_text, "skill": skill} + if metadata: + msg["metadata"] = metadata + messages.append(msg) + + if isinstance(query, list) and len(query) >= 2: + template_ref = "{{#" + ".".join(str(s) for s in query) + "#}}" + messages.append({"role": "user", "text": template_ref, "skill": skill}) + elif query: + messages.append({"role": "user", "text": str(query), "skill": skill}) + + if not messages: + messages.append({"role": "user", "text": "", "skill": skill}) + + return messages + + +def _rewrite_variable_references(nodes: list[dict[str, Any]], llm_ids: set[str]) -> None: + """Recursively walk all node data and rewrite variable references for LLM nodes. + + Handles two forms: + - Structured selectors: [node_id, "text"] → [node_id, "generation", "content"] + - Template strings: {{#node_id.text#}} → {{#node_id.generation.content#}} + """ + if not llm_ids: + return + + escaped_ids = [re.escape(nid) for nid in llm_ids] + patterns: list[tuple[re.Pattern[str], str]] = [] + for old_name, new_path in _VAR_REWRITES.items(): + pattern = re.compile(r"\{\{#(" + "|".join(escaped_ids) + r")\." + re.escape(old_name) + r"#\}\}") + replacement = r"{{#\1." 
+ ".".join(new_path) + r"#}}" + patterns.append((pattern, replacement)) + + for node in nodes: + data = node.get("data", {}) + _walk_and_rewrite(data, llm_ids, patterns) + + +def _walk_and_rewrite( + obj: Any, + llm_ids: set[str], + template_patterns: list[tuple[re.Pattern[str], str]], +) -> Any: + """Recursively rewrite variable references in a nested data structure.""" + if isinstance(obj, dict): + for key, value in obj.items(): + obj[key] = _walk_and_rewrite(value, llm_ids, template_patterns) + return obj + + if isinstance(obj, list): + if _is_variable_selector(obj, llm_ids): + return _rewrite_selector(obj) + for i, item in enumerate(obj): + obj[i] = _walk_and_rewrite(item, llm_ids, template_patterns) + return obj + + if isinstance(obj, str): + for pattern, replacement in template_patterns: + obj = pattern.sub(replacement, obj) + return obj + + return obj + + +def _is_variable_selector(lst: list, llm_ids: set[str]) -> bool: + """Check if a list is a structured variable selector pointing to an LLM node output.""" + if len(lst) < 2: + return False + if not all(isinstance(s, str) for s in lst): + return False + return lst[0] in llm_ids and lst[1] in _VAR_REWRITES + + +def _rewrite_selector(selector: list[str]) -> list[str]: + """Rewrite [node_id, "text"] → [node_id, "generation", "content"].""" + old_field = selector[1] + new_path = _VAR_REWRITES[old_field] + return [selector[0]] + new_path + selector[2:] diff --git a/api/services/asset_content_service.py b/api/services/asset_content_service.py new file mode 100644 index 0000000000..bf67e489fa --- /dev/null +++ b/api/services/asset_content_service.py @@ -0,0 +1,103 @@ +"""Service for the app_asset_contents table. + +Provides single-node and batch DB operations for the inline content cache. +All methods are static and open their own short-lived sessions. + +Collaborators: + - models.app_asset.AppAssetContent (SQLAlchemy model) + - core.app_assets.accessor (accessor abstraction that calls this service) +""" + +import logging + +from sqlalchemy import delete, select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.app_asset import AppAssetContent + +logger = logging.getLogger(__name__) + + +class AssetContentService: + """DB operations for the inline asset content cache. + + All methods are static. All queries are scoped by tenant_id + app_id. + """ + + @staticmethod + def get(tenant_id: str, app_id: str, node_id: str) -> str | None: + """Get cached content for a single node. Returns None on miss.""" + with Session(db.engine) as session: + return session.execute( + select(AppAssetContent.content).where( + AppAssetContent.tenant_id == tenant_id, + AppAssetContent.app_id == app_id, + AppAssetContent.node_id == node_id, + ) + ).scalar_one_or_none() + + @staticmethod + def get_many(tenant_id: str, app_id: str, node_ids: list[str]) -> dict[str, str]: + """Batch get. 
Returns {node_id: content} for hits only.""" + if not node_ids: + return {} + with Session(db.engine) as session: + rows = session.execute( + select(AppAssetContent.node_id, AppAssetContent.content).where( + AppAssetContent.tenant_id == tenant_id, + AppAssetContent.app_id == app_id, + AppAssetContent.node_id.in_(node_ids), + ) + ).all() + return {row.node_id: row.content for row in rows} + + @staticmethod + def upsert(tenant_id: str, app_id: str, node_id: str, content: str, size: int) -> None: + """Insert or update inline content for a single node.""" + with Session(db.engine) as session: + stmt = pg_insert(AppAssetContent).values( + tenant_id=tenant_id, + app_id=app_id, + node_id=node_id, + content=content, + size=size, + ) + stmt = stmt.on_conflict_do_update( + constraint="uq_asset_content_node", + set_={ + "content": stmt.excluded.content, + "size": stmt.excluded.size, + }, + ) + session.execute(stmt) + session.commit() + + @staticmethod + def delete(tenant_id: str, app_id: str, node_id: str) -> None: + """Delete cached content for a single node.""" + with Session(db.engine) as session: + session.execute( + delete(AppAssetContent).where( + AppAssetContent.tenant_id == tenant_id, + AppAssetContent.app_id == app_id, + AppAssetContent.node_id == node_id, + ) + ) + session.commit() + + @staticmethod + def delete_many(tenant_id: str, app_id: str, node_ids: list[str]) -> None: + """Delete cached content for multiple nodes.""" + if not node_ids: + return + with Session(db.engine) as session: + session.execute( + delete(AppAssetContent).where( + AppAssetContent.tenant_id == tenant_id, + AppAssetContent.app_id == app_id, + AppAssetContent.node_id.in_(node_ids), + ) + ) + session.commit() diff --git a/api/services/errors/app_asset.py b/api/services/errors/app_asset.py new file mode 100644 index 0000000000..7fe1c1a30f --- /dev/null +++ b/api/services/errors/app_asset.py @@ -0,0 +1,17 @@ +from .base import BaseServiceError + + +class AppAssetNodeNotFoundError(BaseServiceError): + pass + + +class AppAssetParentNotFoundError(BaseServiceError): + pass + + +class AppAssetPathConflictError(BaseServiceError): + pass + + +class AppAssetNodeTooLargeError(BaseServiceError): + pass diff --git a/api/services/llm_generation_service.py b/api/services/llm_generation_service.py new file mode 100644 index 0000000000..eb8327537e --- /dev/null +++ b/api/services/llm_generation_service.py @@ -0,0 +1,37 @@ +""" +LLM Generation Detail Service. + +Provides methods to query and attach generation details to workflow node executions +and messages, avoiding N+1 query problems. 
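+
+A minimal usage sketch (illustrative only; assumes a SQLAlchemy session bound to
+``db.engine`` from ``extensions.ext_database``, as done elsewhere in this codebase,
+and that ``message_ids`` is a list of message IDs):
+
+    with Session(db.engine) as session:
+        service = LLMGenerationService(session)
+        # maps message_id -> LLMGenerationDetailData for messages that have a detail row
+        details = service.get_generation_details_for_messages(message_ids)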
+""" + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.llm_generation_entities import LLMGenerationDetailData +from models import LLMGenerationDetail + + +class LLMGenerationService: + """Service for handling LLM generation details.""" + + def __init__(self, session: Session): + self._session = session + + def get_generation_detail_for_message(self, message_id: str) -> LLMGenerationDetailData | None: + """Query generation detail for a specific message.""" + stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id == message_id) + detail = self._session.scalars(stmt).first() + return detail.to_domain_model() if detail else None + + def get_generation_details_for_messages( + self, + message_ids: list[str], + ) -> dict[str, LLMGenerationDetailData]: + """Batch query generation details for multiple messages.""" + if not message_ids: + return {} + + stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id.in_(message_ids)) + details = self._session.scalars(stmt).all() + return {detail.message_id: detail.to_domain_model() for detail in details if detail.message_id} diff --git a/api/services/skill_service.py b/api/services/skill_service.py new file mode 100644 index 0000000000..77ae907ea3 --- /dev/null +++ b/api/services/skill_service.py @@ -0,0 +1,204 @@ +"""Service for extracting tool dependencies from LLM node skill prompts. + +Two public entry points: + +- ``extract_tool_dependencies`` — takes raw node data from the client, + real-time builds a ``SkillBundle`` from current draft ``.md`` assets, + and resolves transitive tool dependencies. Used by the per-node POST + endpoint. +- ``get_workflow_skills`` — scans all LLM nodes in a persisted draft + workflow and returns per-node skill info. Uses a cached bundle. +""" + +from __future__ import annotations + +import json +import logging +from collections.abc import Mapping +from functools import reduce +from typing import Any, cast + +from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode +from core.sandbox.entities.config import AppAssets +from core.skill.assembler import SkillBundleAssembler, SkillDocumentAssembler +from core.skill.entities.skill_bundle import SkillBundle +from core.skill.entities.skill_document import SkillDocument +from core.skill.entities.skill_metadata import SkillMetadata +from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency +from core.skill.skill_manager import SkillManager +from graphon.enums import BuiltinNodeTypes +from models.model import App +from services.app_asset_service import AppAssetService + +logger = logging.getLogger(__name__) + + +class SkillService: + """Service for managing and retrieving skill information from workflows.""" + + # ------------------------------------------------------------------ + # Per-node: client sends node data, server builds bundle in real-time + # ------------------------------------------------------------------ + + @staticmethod + def extract_tool_dependencies( + app: App, + node_data: Mapping[str, Any], + user_id: str, + ) -> list[ToolDependency]: + """Extract tool dependencies from an LLM node's skill prompts. + + Builds a fresh ``SkillBundle`` from current draft ``.md`` assets + every time — no cached bundle is used. The caller supplies the + full node ``data`` dict directly (not a ``node_id``). + + Returns an empty list when the node has no skill prompts or when + no draft assets exist. 
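+
+        Illustrative ``node_data`` shape (field names mirror the parsing below;
+        the literal ``"llm"`` stands in for whatever ``BuiltinNodeTypes.LLM``
+        evaluates to):
+
+            {
+                "type": "llm",
+                "prompt_template": [
+                    {"role": "system", "text": "...", "skill": True, "metadata": {...}},
+                ],
+            }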
+ """ + if node_data.get("type", "") != BuiltinNodeTypes.LLM: + return [] + + if not SkillService._has_skill(node_data): + return [] + + bundle = SkillService._build_bundle(app, user_id) + if bundle is None: + return [] + + return SkillService._resolve_prompt_dependencies(node_data, bundle) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _has_skill(node_data: Mapping[str, Any]) -> bool: + """Check if node has any skill prompts.""" + prompt_template_raw = node_data.get("prompt_template", []) + if isinstance(prompt_template_raw, list): + for prompt_item in cast(list[object], prompt_template_raw): + if isinstance(prompt_item, dict) and prompt_item.get("skill", False): + return True + return False + + @staticmethod + def _build_bundle(app: App, user_id: str) -> SkillBundle | None: + """Real-time build a SkillBundle from current draft .md assets. + + Reads all ``.md`` nodes from the draft file tree, bulk-loads + their content from the DB cache, parses into ``SkillDocument`` + objects, and assembles a full bundle with transitive dependency + resolution. + + The bundle is **not** persisted — it is built fresh for each + request so the response always reflects the latest draft state. + """ + assets = AppAssetService.get_assets( + tenant_id=app.tenant_id, + app_id=app.id, + user_id=user_id, + is_draft=True, + ) + if not assets: + return None + + file_tree: AppAssetFileTree = assets.asset_tree + if file_tree.empty(): + return SkillBundle(assets_id=assets.id, asset_tree=file_tree) + + # Collect all .md file nodes from the tree. + md_nodes: list[AppAssetNode] = [n for n in file_tree.walk_files() if n.extension == "md"] + if not md_nodes: + return SkillBundle(assets_id=assets.id, asset_tree=file_tree) + + # Bulk-load content from DB (with S3 fallback). + accessor = AppAssetService.get_accessor(app.tenant_id, app.id) + raw_contents = accessor.bulk_load(md_nodes) + + # Parse into SkillDocuments. 
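+        # Each raw payload is expected to be a JSON object whose fields match
+        # SkillDocument; entries that fail to parse are logged and skipped below.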
+ documents: dict[str, SkillDocument] = {} + for node in md_nodes: + raw = raw_contents.get(node.id) + if not raw: + continue + try: + data = {"skill_id": node.id, **json.loads(raw)} + documents[node.id] = SkillDocument.model_validate(data) + except (json.JSONDecodeError, TypeError, ValueError): + logger.warning("Skipping unparseable skill document node_id=%s", node.id) + continue + + return SkillBundleAssembler(file_tree).assemble_bundle(documents, assets.id) + + @staticmethod + def _resolve_prompt_dependencies( + node_data: Mapping[str, Any], + bundle: SkillBundle, + ) -> list[ToolDependency]: + """Resolve tool dependencies from skill prompts against a bundle.""" + assembler = SkillDocumentAssembler(bundle) + tool_deps_list: list[ToolDependencies] = [] + + prompt_template_raw = node_data.get("prompt_template", []) + if not isinstance(prompt_template_raw, list): + return [] + + for prompt_item in cast(list[object], prompt_template_raw): + if not isinstance(prompt_item, dict): + continue + prompt = cast(dict[str, Any], prompt_item) + if not prompt.get("skill", False): + continue + + text_raw = prompt.get("text", "") + text = text_raw if isinstance(text_raw, str) else str(text_raw) + + metadata_obj: object = prompt.get("metadata") + metadata = cast(dict[str, Any], metadata_obj) if isinstance(metadata_obj, dict) else {} + + skill_entry = assembler.assemble_document( + document=SkillDocument( + skill_id="anonymous", + content=text, + metadata=SkillMetadata.model_validate(metadata), + ), + base_path=AppAssets.PATH, + ) + tool_deps_list.append(skill_entry.dependance.tools) + + if not tool_deps_list: + return [] + + merged = reduce(lambda x, y: x.merge(y), tool_deps_list) + return merged.dependencies + + @staticmethod + def _extract_tool_dependencies_cached( + app: App, + node_data: Mapping[str, Any], + user_id: str, + ) -> list[ToolDependency]: + """Extract tool dependencies using a cached SkillBundle. + + Used by ``get_workflow_skills`` for the whole-workflow endpoint. + """ + assets = AppAssetService.get_assets( + tenant_id=app.tenant_id, + app_id=app.id, + user_id=user_id, + is_draft=True, + ) + if not assets: + return [] + + try: + bundle = SkillManager.load_bundle( + tenant_id=app.tenant_id, + app_id=app.id, + assets_id=assets.id, + ) + except Exception: + logger.debug("Failed to load cached skill bundle for app_id=%s", app.id, exc_info=True) + return [] + + return SkillService._resolve_prompt_dependencies(node_data, bundle) diff --git a/api/services/storage_ticket_service.py b/api/services/storage_ticket_service.py new file mode 100644 index 0000000000..793d387191 --- /dev/null +++ b/api/services/storage_ticket_service.py @@ -0,0 +1,153 @@ +"""Storage ticket service for generating opaque download/upload URLs. + +This service provides a ticket-based approach for file access. Instead of exposing +the real storage key in URLs, it generates a random UUID token and stores the mapping +in Redis with a TTL. 
+ +Usage: + from services.storage_ticket_service import StorageTicketService + + # Generate a download ticket + url = StorageTicketService.create_download_url("path/to/file.txt", expires_in=300) + + # Generate an upload ticket + url = StorageTicketService.create_upload_url("path/to/file.txt", expires_in=300, max_bytes=10*1024*1024) + +URL format: + {FILES_API_URL}/files/storage-files/{token} + +The token is validated by looking up the Redis key, which contains: + - op: "download" or "upload" + - storage_key: the real storage path + - max_bytes: (upload only) maximum allowed upload size + - filename: suggested filename for Content-Disposition header +""" + +import logging +from typing import Literal +from uuid import uuid4 + +from pydantic import BaseModel + +from configs import dify_config +from extensions.ext_redis import redis_client + +logger = logging.getLogger(__name__) + +TICKET_KEY_PREFIX = "storage_files" +DEFAULT_DOWNLOAD_TTL = 300 # 5 minutes +DEFAULT_UPLOAD_TTL = 300 # 5 minutes +DEFAULT_MAX_UPLOAD_BYTES = 100 * 1024 * 1024 # 100MB + + +class StorageTicket(BaseModel): + """Represents a storage access ticket.""" + + op: Literal["download", "upload"] + storage_key: str + max_bytes: int | None = None # upload only + filename: str | None = None # suggested filename for download + + +class StorageTicketService: + """Service for creating and validating storage access tickets.""" + + @classmethod + def create_download_url( + cls, + storage_key: str, + *, + expires_in: int = DEFAULT_DOWNLOAD_TTL, + filename: str | None = None, + ) -> str: + """Create a download ticket and return the URL. + + Args: + storage_key: The real storage path + expires_in: TTL in seconds (default 300) + filename: Suggested filename for Content-Disposition header + + Returns: + Full URL with token + """ + if filename is None: + filename = storage_key.rsplit("/", 1)[-1] + + ticket = StorageTicket(op="download", storage_key=storage_key, filename=filename) + token = cls._store_ticket(ticket, expires_in) + return cls._build_url(token) + + @classmethod + def create_upload_url( + cls, + storage_key: str, + *, + expires_in: int = DEFAULT_UPLOAD_TTL, + max_bytes: int = DEFAULT_MAX_UPLOAD_BYTES, + ) -> str: + """Create an upload ticket and return the URL. + + Args: + storage_key: The real storage path + expires_in: TTL in seconds (default 300) + max_bytes: Maximum allowed upload size in bytes + + Returns: + Full URL with token + """ + ticket = StorageTicket(op="upload", storage_key=storage_key, max_bytes=max_bytes) + token = cls._store_ticket(ticket, expires_in) + return cls._build_url(token) + + @classmethod + def get_ticket(cls, token: str) -> StorageTicket | None: + """Retrieve a ticket by token. 
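+
+        Reading a ticket does not delete the underlying Redis key; the ticket
+        remains retrievable until its TTL expires.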
+ + Args: + token: The UUID token from the URL + + Returns: + StorageTicket if found and valid, None otherwise + """ + key = cls._ticket_key(token) + try: + data = redis_client.get(key) + if data is None: + return None + if isinstance(data, bytes): + data = data.decode("utf-8") + return StorageTicket.model_validate_json(data) + except Exception: + logger.warning("Failed to retrieve storage ticket: %s", token, exc_info=True) + return None + + @classmethod + def _store_ticket(cls, ticket: StorageTicket, ttl: int) -> str: + """Store a ticket in Redis and return the token.""" + token = str(uuid4()) + key = cls._ticket_key(token) + value = ticket.model_dump_json() + redis_client.setex(key, ttl, value) + return token + + @classmethod + def _ticket_key(cls, token: str) -> str: + """Generate Redis key for a token.""" + return f"{TICKET_KEY_PREFIX}:{token}" + + @classmethod + def _build_url(cls, token: str) -> str: + """Build the full URL for a token. + + FILES_API_URL is dedicated to sandbox runtime file access (agentbox/e2b/etc.). + This endpoint must be routable from the runtime environment. + """ + base_url = dify_config.FILES_API_URL.strip() + if not base_url: + raise ValueError( + "FILES_API_URL is required for sandbox runtime file access. " + "Set FILES_API_URL to a URL reachable by your sandbox runtime. " + "For public sandbox environments (e.g. e2b), use a public domain or IP." + ) + base_url = base_url.rstrip("/") + return f"{base_url}/files/storage-files/{token}" diff --git a/api/services/workflow/nested_node_graph_service.py b/api/services/workflow/nested_node_graph_service.py new file mode 100644 index 0000000000..ccb44fe03b --- /dev/null +++ b/api/services/workflow/nested_node_graph_service.py @@ -0,0 +1,157 @@ +""" +Service for generating Nested Node LLM graph structures. + +This service creates graph structures containing LLM nodes configured for +extracting values from list[PromptMessage] variables. +""" + +from typing import Any + +from sqlalchemy.orm import Session + +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.entities import LLMMode +from services.model_provider_service import ModelProviderService +from services.workflow.entities import NestedNodeGraphRequest, NestedNodeGraphResponse, NestedNodeParameterSchema + + +class NestedNodeGraphService: + """Service for generating Nested Node LLM graph structures.""" + + def __init__(self, session: Session): + self._session = session + + def generate_nested_node_id(self, node_id: str, parameter_name: str) -> str: + """Generate nested node ID following the naming convention. + + Format: {node_id}_ext_{parameter_name} + """ + return f"{node_id}_ext_{parameter_name}" + + def generate_nested_node_graph(self, tenant_id: str, request: NestedNodeGraphRequest) -> NestedNodeGraphResponse: + """Generate a complete graph structure containing a Nested Node LLM node. 
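+
+        The generated graph holds a single LLM node whose id follows
+        ``{parent_node_id}_ext_{parameter_key}`` and contains no edges.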
+ + Args: + tenant_id: The tenant ID for fetching default model config + request: The nested node graph generation request + + Returns: + Complete graph structure with nodes, edges, and viewport + """ + node_id = self.generate_nested_node_id(request.parent_node_id, request.parameter_key) + model_config = self._get_default_model_config(tenant_id) + node = self._build_nested_node_llm_node( + node_id=node_id, + parent_node_id=request.parent_node_id, + context_source=request.context_source, + parameter_schema=request.parameter_schema, + model_config=model_config, + ) + + graph = { + "nodes": [node], + "edges": [], + "viewport": {}, + } + + return NestedNodeGraphResponse(graph=graph) + + def _get_default_model_config(self, tenant_id: str) -> dict[str, Any]: + """Get the default LLM model configuration for the tenant.""" + model_provider_service = ModelProviderService() + default_model = model_provider_service.get_default_model_of_model_type( + tenant_id=tenant_id, + model_type="llm", + ) + + if default_model: + return { + "provider": default_model.provider.provider, + "name": default_model.model, + "mode": LLMMode.CHAT.value, + "completion_params": {}, + } + + # Fallback to empty config if no default model is configured + return { + "provider": "", + "name": "", + "mode": LLMMode.CHAT.value, + "completion_params": {}, + } + + def _build_nested_node_llm_node( + self, + *, + node_id: str, + parent_node_id: str, + context_source: list[str], + parameter_schema: NestedNodeParameterSchema, + model_config: dict[str, Any], + ) -> dict[str, Any]: + """Build the Nested Node LLM node structure. + + The node uses: + - $context in prompt_template to reference the PromptMessage list + - structured_output for extracting the specific parameter + - parent_node_id to associate with the parent node + """ + prompt_template = [ + { + "role": "system", + "text": "Extract the required parameter value from the conversation context above.", + "skill": False, + }, + {"$context": context_source}, + {"role": "user", "text": "", "skill": False}, + ] + + structured_output = { + "schema": { + "type": "object", + "properties": { + parameter_schema.name: { + "type": parameter_schema.type, + "description": parameter_schema.description, + } + }, + "required": [parameter_schema.name], + "additionalProperties": False, + } + } + + return { + "id": node_id, + "position": {"x": 0, "y": 0}, + "data": { + "type": BuiltinNodeTypes.LLM, + # BaseNodeData fields + "title": f"NestedNode: {parameter_schema.name}", + "desc": f"Extract {parameter_schema.name} from conversation context", + "version": "1", + "error_strategy": None, + "default_value": None, + "retry_config": {"max_retries": 0}, + "parent_node_id": parent_node_id, + # LLMNodeData fields + "model": model_config, + "prompt_template": prompt_template, + "prompt_config": {"jinja2_variables": []}, + "memory": None, + "context": { + "enabled": False, + "variable_selector": None, + }, + "vision": { + "enabled": False, + "configs": { + "variable_selector": ["sys", "files"], + "detail": "high", + }, + }, + "structured_output_enabled": True, + "structured_output": structured_output, + "computer_use": False, + "tool_settings": [], + }, + } diff --git a/api/services/workflow_collaboration_service.py b/api/services/workflow_collaboration_service.py new file mode 100644 index 0000000000..8968a624ff --- /dev/null +++ b/api/services/workflow_collaboration_service.py @@ -0,0 +1,391 @@ +from __future__ import annotations + +import logging +import time +from collections.abc import Mapping + +from 
models.account import Account +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository, WorkflowSessionInfo + + +class WorkflowCollaborationService: + def __init__(self, repository: WorkflowCollaborationRepository, socketio) -> None: + self._repository = repository + self._socketio = socketio + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(repository={self._repository})" + + def save_session(self, sid: str, user: Account) -> None: + self._socketio.save_session( + sid, + { + "user_id": user.id, + "username": user.name, + "avatar": user.avatar, + }, + ) + + def register_session(self, workflow_id: str, sid: str) -> tuple[str, bool] | None: + session = self._socketio.get_session(sid) + user_id = session.get("user_id") + if not user_id: + return None + + session_info: WorkflowSessionInfo = { + "user_id": str(user_id), + "username": str(session.get("username", "Unknown")), + "avatar": session.get("avatar"), + "sid": sid, + "connected_at": int(time.time()), + "graph_active": True, + "active_skill_file_id": None, + } + + self._repository.set_session_info(workflow_id, session_info) + + leader_sid = self.get_or_set_leader(workflow_id, sid) + is_leader = leader_sid == sid if leader_sid else False + + self._socketio.enter_room(sid, workflow_id) + self.broadcast_online_users(workflow_id) + + self._socketio.emit("status", {"isLeader": is_leader}, room=sid) + + return str(user_id), is_leader + + def disconnect_session(self, sid: str) -> None: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return + + workflow_id = mapping["workflow_id"] + active_skill_file_id = self._repository.get_active_skill_file_id(workflow_id, sid) + self._repository.delete_session(workflow_id, sid) + + self.handle_leader_disconnect(workflow_id, sid) + if active_skill_file_id: + self.handle_skill_leader_disconnect(workflow_id, active_skill_file_id, sid) + self.broadcast_online_users(workflow_id) + + def relay_collaboration_event(self, sid: str, data: Mapping[str, object]) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + user_id = mapping["user_id"] + self.refresh_session_state(workflow_id, sid) + + event_type = data.get("type") + event_data = data.get("data") + timestamp = data.get("timestamp", int(time.time())) + + if not event_type: + return {"msg": "invalid event type"}, 400 + + if event_type == "graph_view_active": + is_active = False + if isinstance(event_data, dict): + is_active = bool(event_data.get("active") or False) + self._repository.set_graph_active(workflow_id, sid, is_active) + self.refresh_session_state(workflow_id, sid) + self.broadcast_online_users(workflow_id) + return {"msg": "graph_view_active_updated"}, 200 + + if event_type == "skill_file_active": + file_id = None + is_active = False + if isinstance(event_data, dict): + file_id = event_data.get("file_id") + is_active = bool(event_data.get("active") or False) + + if not file_id or not isinstance(file_id, str): + return {"msg": "invalid skill_file_active payload"}, 400 + + previous_file_id = self._repository.get_active_skill_file_id(workflow_id, sid) + next_file_id = file_id if is_active else None + + if previous_file_id == next_file_id: + self.refresh_session_state(workflow_id, sid) + return {"msg": "skill_file_active_unchanged"}, 200 + + self._repository.set_active_skill_file(workflow_id, sid, next_file_id) + self.refresh_session_state(workflow_id, sid) + + 
if previous_file_id: + self._ensure_skill_leader(workflow_id, previous_file_id) + if next_file_id: + self._ensure_skill_leader(workflow_id, next_file_id, preferred_sid=sid) + + return {"msg": "skill_file_active_updated"}, 200 + + if event_type == "sync_request": + leader_sid = self._repository.get_current_leader(workflow_id) + if leader_sid and ( + self.is_session_active(workflow_id, leader_sid) + and self._repository.is_graph_active(workflow_id, leader_sid) + ): + target_sid = leader_sid + else: + if leader_sid: + self._repository.delete_leader(workflow_id) + target_sid = self._select_graph_leader(workflow_id, preferred_sid=sid) + if target_sid: + self._repository.set_leader(workflow_id, target_sid) + self.broadcast_leader_change(workflow_id, target_sid) + if not target_sid: + return {"msg": "no_active_leader"}, 200 + + self._socketio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=target_sid, + ) + return {"msg": "sync_request_forwarded"}, 200 + + self._socketio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=workflow_id, + skip_sid=sid, + ) + + return {"msg": "event_broadcasted"}, 200 + + def relay_graph_event(self, sid: str, data: object) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + self.refresh_session_state(workflow_id, sid) + + self._socketio.emit("graph_update", data, room=workflow_id, skip_sid=sid) + + return {"msg": "graph_update_broadcasted"}, 200 + + def relay_skill_event(self, sid: str, data: object) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + self.refresh_session_state(workflow_id, sid) + + self._socketio.emit("skill_update", data, room=workflow_id, skip_sid=sid) + + return {"msg": "skill_update_broadcasted"}, 200 + + def get_or_set_leader(self, workflow_id: str, sid: str) -> str | None: + current_leader = self._repository.get_current_leader(workflow_id) + + if current_leader: + if self.is_session_active(workflow_id, current_leader) and self._repository.is_graph_active( + workflow_id, current_leader + ): + return current_leader + self._repository.delete_session(workflow_id, current_leader) + self._repository.delete_leader(workflow_id) + + new_leader_sid = self._select_graph_leader(workflow_id, preferred_sid=sid) + if not new_leader_sid: + return None + + was_set = self._repository.set_leader_if_absent(workflow_id, new_leader_sid) + + if was_set: + if current_leader: + self.broadcast_leader_change(workflow_id, new_leader_sid) + return new_leader_sid + + current_leader = self._repository.get_current_leader(workflow_id) + if current_leader: + return current_leader + + return new_leader_sid + + def handle_leader_disconnect(self, workflow_id: str, disconnected_sid: str) -> None: + current_leader = self._repository.get_current_leader(workflow_id) + if not current_leader: + return + + if current_leader != disconnected_sid: + return + + new_leader_sid = self._select_graph_leader(workflow_id) + if new_leader_sid: + self._repository.set_leader(workflow_id, new_leader_sid) + self.broadcast_leader_change(workflow_id, new_leader_sid) + else: + self._repository.delete_leader(workflow_id) + self.broadcast_leader_change(workflow_id, None) + + def 
handle_skill_leader_disconnect(self, workflow_id: str, file_id: str, disconnected_sid: str) -> None: + current_leader = self._repository.get_skill_leader(workflow_id, file_id) + if not current_leader: + return + + if current_leader != disconnected_sid: + return + + new_leader_sid = self._select_skill_leader(workflow_id, file_id) + if new_leader_sid: + self._repository.set_skill_leader(workflow_id, file_id, new_leader_sid) + self.broadcast_skill_leader_change(workflow_id, file_id, new_leader_sid) + else: + self._repository.delete_skill_leader(workflow_id, file_id) + self.broadcast_skill_leader_change(workflow_id, file_id, None) + + def broadcast_leader_change(self, workflow_id: str, new_leader_sid: str | None) -> None: + for sid in self._repository.get_session_sids(workflow_id): + try: + is_leader = new_leader_sid is not None and sid == new_leader_sid + self._socketio.emit("status", {"isLeader": is_leader}, room=sid) + except Exception: + logging.exception("Failed to emit leader status to session %s", sid) + + def broadcast_skill_leader_change(self, workflow_id: str, file_id: str, new_leader_sid: str | None) -> None: + for sid in self._repository.get_session_sids(workflow_id): + try: + is_leader = new_leader_sid is not None and sid == new_leader_sid + self._socketio.emit("skill_status", {"file_id": file_id, "isLeader": is_leader}, room=sid) + except Exception: + logging.exception("Failed to emit skill leader status to session %s", sid) + + def get_current_leader(self, workflow_id: str) -> str | None: + return self._repository.get_current_leader(workflow_id) + + def _prune_inactive_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]: + """Remove inactive sessions from storage and return active sessions only.""" + sessions = self._repository.list_sessions(workflow_id) + if not sessions: + return [] + + active_sessions: list[WorkflowSessionInfo] = [] + stale_sids: list[str] = [] + for session in sessions: + sid = session["sid"] + if self.is_session_active(workflow_id, sid): + active_sessions.append(session) + else: + stale_sids.append(sid) + + for sid in stale_sids: + self._repository.delete_session(workflow_id, sid) + + return active_sessions + + def broadcast_online_users(self, workflow_id: str) -> None: + users = self._prune_inactive_sessions(workflow_id) + users.sort(key=lambda x: x.get("connected_at") or 0) + + leader_sid = self.get_current_leader(workflow_id) + previous_leader = leader_sid + active_sids = {user["sid"] for user in users} + if leader_sid and leader_sid not in active_sids: + self._repository.delete_leader(workflow_id) + leader_sid = None + + if not leader_sid and users: + leader_sid = self._select_graph_leader(workflow_id) + if leader_sid: + self._repository.set_leader(workflow_id, leader_sid) + + if leader_sid != previous_leader: + self.broadcast_leader_change(workflow_id, leader_sid) + + self._socketio.emit( + "online_users", + {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, + room=workflow_id, + ) + + def refresh_session_state(self, workflow_id: str, sid: str) -> None: + self._repository.refresh_session_state(workflow_id, sid) + self._ensure_leader(workflow_id, sid) + active_skill_file_id = self._repository.get_active_skill_file_id(workflow_id, sid) + if active_skill_file_id: + self._ensure_skill_leader(workflow_id, active_skill_file_id, preferred_sid=sid) + + def _ensure_leader(self, workflow_id: str, sid: str) -> None: + current_leader = self._repository.get_current_leader(workflow_id) + if ( + current_leader + and 
self.is_session_active(workflow_id, current_leader) + and self._repository.is_graph_active(workflow_id, current_leader) + ): + self._repository.expire_leader(workflow_id) + return + + if current_leader: + self._repository.delete_leader(workflow_id) + + new_leader_sid = self._select_graph_leader(workflow_id, preferred_sid=sid) + if not new_leader_sid: + self.broadcast_leader_change(workflow_id, None) + return + + self._repository.set_leader(workflow_id, new_leader_sid) + self.broadcast_leader_change(workflow_id, new_leader_sid) + + def _ensure_skill_leader(self, workflow_id: str, file_id: str, preferred_sid: str | None = None) -> None: + current_leader = self._repository.get_skill_leader(workflow_id, file_id) + active_sids = self._repository.get_active_skill_session_sids(workflow_id, file_id) + if current_leader and self.is_session_active(workflow_id, current_leader): + if current_leader in active_sids or not active_sids: + self._repository.expire_skill_leader(workflow_id, file_id) + return + + if current_leader: + self._repository.delete_skill_leader(workflow_id, file_id) + + new_leader_sid = self._select_skill_leader(workflow_id, file_id, preferred_sid=preferred_sid) + if not new_leader_sid: + self.broadcast_skill_leader_change(workflow_id, file_id, None) + return + + self._repository.set_skill_leader(workflow_id, file_id, new_leader_sid) + self.broadcast_skill_leader_change(workflow_id, file_id, new_leader_sid) + + def _select_graph_leader(self, workflow_id: str, preferred_sid: str | None = None) -> str | None: + session_sids = [ + session["sid"] + for session in self._repository.list_sessions(workflow_id) + if session.get("graph_active") and self.is_session_active(workflow_id, session["sid"]) + ] + if not session_sids: + return None + if preferred_sid and preferred_sid in session_sids: + return preferred_sid + return session_sids[0] + + def _select_skill_leader(self, workflow_id: str, file_id: str, preferred_sid: str | None = None) -> str | None: + session_sids = [ + sid + for sid in self._repository.get_active_skill_session_sids(workflow_id, file_id) + if self.is_session_active(workflow_id, sid) + ] + if not session_sids: + return None + if preferred_sid and preferred_sid in session_sids: + return preferred_sid + return session_sids[0] + + def is_session_active(self, workflow_id: str, sid: str) -> bool: + if not sid: + return False + + try: + if not self._socketio.manager.is_connected(sid, "/"): + return False + except AttributeError: + return False + + if not self._repository.session_exists(workflow_id, sid): + return False + + if not self._repository.sid_mapping_exists(sid): + return False + + return True diff --git a/api/services/workflow_comment_service.py b/api/services/workflow_comment_service.py new file mode 100644 index 0000000000..93b2d2f01e --- /dev/null +++ b/api/services/workflow_comment_service.py @@ -0,0 +1,468 @@ +import logging +from collections.abc import Sequence + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session, selectinload +from werkzeug.exceptions import Forbidden, NotFound + +from configs import dify_config +from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now +from libs.helper import uuid_value +from models import App, TenantAccountJoin, WorkflowComment, WorkflowCommentMention, WorkflowCommentReply +from models.account import Account +from tasks.mail_workflow_comment_task import send_workflow_comment_mention_email_task + +logger = logging.getLogger(__name__) + + +class WorkflowCommentService: + 
"""Service for managing workflow comments.""" + + @staticmethod + def _validate_content(content: str) -> None: + if len(content.strip()) == 0: + raise ValueError("Comment content cannot be empty") + + if len(content) > 1000: + raise ValueError("Comment content cannot exceed 1000 characters") + + @staticmethod + def _filter_valid_mentioned_user_ids(mentioned_user_ids: Sequence[str]) -> list[str]: + """Return deduplicated UUID user IDs in the order provided.""" + unique_user_ids: list[str] = [] + seen: set[str] = set() + for user_id in mentioned_user_ids: + if not isinstance(user_id, str): + continue + if not uuid_value(user_id): + continue + if user_id in seen: + continue + seen.add(user_id) + unique_user_ids.append(user_id) + return unique_user_ids + + @staticmethod + def _format_comment_excerpt(content: str, max_length: int = 200) -> str: + """Trim comment content for email display.""" + trimmed = content.strip() + if len(trimmed) <= max_length: + return trimmed + if max_length <= 3: + return trimmed[:max_length] + return f"{trimmed[: max_length - 3].rstrip()}..." + + @staticmethod + def _build_mention_email_payloads( + session: Session, + tenant_id: str, + app_id: str, + mentioner_id: str, + mentioned_user_ids: Sequence[str], + content: str, + ) -> list[dict[str, str]]: + """Prepare email payloads for mentioned users, including the workflow app link.""" + if not mentioned_user_ids: + return [] + + candidate_user_ids = [user_id for user_id in mentioned_user_ids if user_id != mentioner_id] + if not candidate_user_ids: + return [] + + app_name = session.scalar(select(App.name).where(App.id == app_id, App.tenant_id == tenant_id)) or "Dify app" + commenter_name = session.scalar(select(Account.name).where(Account.id == mentioner_id)) or "Dify user" + comment_excerpt = WorkflowCommentService._format_comment_excerpt(content) + base_url = dify_config.CONSOLE_WEB_URL.rstrip("/") + app_url = f"{base_url}/app/{app_id}/workflow" + + accounts = session.scalars( + select(Account) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where(TenantAccountJoin.tenant_id == tenant_id, Account.id.in_(candidate_user_ids)) + ).all() + + payloads: list[dict[str, str]] = [] + for account in accounts: + payloads.append( + { + "language": account.interface_language or "en-US", + "to": account.email, + "mentioned_name": account.name or account.email, + "commenter_name": commenter_name, + "app_name": app_name, + "comment_content": comment_excerpt, + "app_url": app_url, + } + ) + return payloads + + @staticmethod + def _dispatch_mention_emails(payloads: Sequence[dict[str, str]]) -> None: + """Enqueue mention notification emails.""" + for payload in payloads: + send_workflow_comment_mention_email_task.delay(**payload) + + @staticmethod + def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]: + """Get all comments for a workflow.""" + with Session(db.engine) as session: + # Get all comments with eager loading + stmt = ( + select(WorkflowComment) + .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions)) + .where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id) + .order_by(desc(WorkflowComment.created_at)) + ) + + comments = session.scalars(stmt).all() + + # Batch preload all Account objects to avoid N+1 queries + WorkflowCommentService._preload_accounts(session, comments) + + return comments + + @staticmethod + def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None: + """Batch preload Account 
objects for comments, replies, and mentions.""" + # Collect all user IDs + user_ids: set[str] = set() + for comment in comments: + user_ids.add(comment.created_by) + if comment.resolved_by: + user_ids.add(comment.resolved_by) + user_ids.update(reply.created_by for reply in comment.replies) + user_ids.update(mention.mentioned_user_id for mention in comment.mentions) + + if not user_ids: + return + + # Batch query all accounts + accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all() + account_map = {str(account.id): account for account in accounts} + + # Cache accounts on objects + for comment in comments: + comment.cache_created_by_account(account_map.get(comment.created_by)) + comment.cache_resolved_by_account(account_map.get(comment.resolved_by) if comment.resolved_by else None) + for reply in comment.replies: + reply.cache_created_by_account(account_map.get(reply.created_by)) + for mention in comment.mentions: + mention.cache_mentioned_user_account(account_map.get(mention.mentioned_user_id)) + + @staticmethod + def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment: + """Get a specific comment.""" + + def _get_comment(session: Session) -> WorkflowComment: + stmt = ( + select(WorkflowComment) + .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions)) + .where( + WorkflowComment.id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + ) + comment = session.scalar(stmt) + + if not comment: + raise NotFound("Comment not found") + + # Preload accounts to avoid N+1 queries + WorkflowCommentService._preload_accounts(session, [comment]) + + return comment + + if session is not None: + return _get_comment(session) + else: + with Session(db.engine, expire_on_commit=False) as session: + return _get_comment(session) + + @staticmethod + def create_comment( + tenant_id: str, + app_id: str, + created_by: str, + content: str, + position_x: float, + position_y: float, + mentioned_user_ids: list[str] | None = None, + ) -> dict: + """Create a new workflow comment and send mention notification emails.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine) as session: + comment = WorkflowComment( + tenant_id=tenant_id, + app_id=app_id, + position_x=position_x, + position_y=position_y, + content=content, + created_by=created_by, + ) + + session.add(comment) + session.flush() # Get the comment ID for mentions + + # Create mentions if specified + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or []) + for user_id in mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=comment.id, + reply_id=None, # This is a comment mention, not reply mention + mentioned_user_id=user_id, + ) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=tenant_id, + app_id=app_id, + mentioner_id=created_by, + mentioned_user_ids=mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + # Return only what we need - id and created_at + return {"id": comment.id, "created_at": comment.created_at} + + @staticmethod + def update_comment( + tenant_id: str, + app_id: str, + comment_id: str, + user_id: str, + content: str, + position_x: float | None = None, + position_y: float | None = None, + mentioned_user_ids: list[str] | None = 
None, + ) -> dict: + """Update a workflow comment and notify newly mentioned users.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + # Get comment with validation + stmt = select(WorkflowComment).where( + WorkflowComment.id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + comment = session.scalar(stmt) + + if not comment: + raise NotFound("Comment not found") + + # Only the creator can update the comment + if comment.created_by != user_id: + raise Forbidden("Only the comment creator can update it") + + # Update comment fields + comment.content = content + if position_x is not None: + comment.position_x = position_x + if position_y is not None: + comment.position_y = position_y + + # Update mentions - first remove existing mentions for this comment only (not replies) + existing_mentions = session.scalars( + select(WorkflowCommentMention).where( + WorkflowCommentMention.comment_id == comment.id, + WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions + ) + ).all() + existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions} + for mention in existing_mentions: + session.delete(mention) + + # Add new mentions + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or []) + new_mentioned_user_ids = [ + user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids + ] + for user_id_str in mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=comment.id, + reply_id=None, # This is a comment mention + mentioned_user_id=user_id_str, + ) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=tenant_id, + app_id=app_id, + mentioner_id=user_id, + mentioned_user_ids=new_mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": comment.id, "updated_at": comment.updated_at} + + @staticmethod + def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None: + """Delete a workflow comment.""" + with Session(db.engine, expire_on_commit=False) as session: + comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session) + + # Only the creator can delete the comment + if comment.created_by != user_id: + raise Forbidden("Only the comment creator can delete it") + + # Delete associated mentions (both comment and reply mentions) + mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id) + ).all() + for mention in mentions: + session.delete(mention) + + # Delete associated replies + replies = session.scalars( + select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id) + ).all() + for reply in replies: + session.delete(reply) + + session.delete(comment) + session.commit() + + @staticmethod + def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment: + """Resolve a workflow comment.""" + with Session(db.engine, expire_on_commit=False) as session: + comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session) + if comment.resolved: + return comment + + comment.resolved = True + comment.resolved_at = naive_utc_now() + comment.resolved_by = user_id + session.commit() + + return comment + 
+ @staticmethod + def create_reply( + comment_id: str, content: str, created_by: str, mentioned_user_ids: list[str] | None = None + ) -> dict: + """Add a reply to a workflow comment and notify mentioned users.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + # Check if comment exists + comment = session.get(WorkflowComment, comment_id) + if not comment: + raise NotFound("Comment not found") + + reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by) + + session.add(reply) + session.flush() # Get the reply ID for mentions + + # Create mentions if specified + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or []) + for user_id in mentioned_user_ids: + # Create mention linking to specific reply + mention = WorkflowCommentMention(comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=comment.tenant_id, + app_id=comment.app_id, + mentioner_id=created_by, + mentioned_user_ids=mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": reply.id, "created_at": reply.created_at} + + @staticmethod + def update_reply(reply_id: str, user_id: str, content: str, mentioned_user_ids: list[str] | None = None) -> dict: + """Update a comment reply and notify newly mentioned users.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + reply = session.get(WorkflowCommentReply, reply_id) + if not reply: + raise NotFound("Reply not found") + + # Only the creator can update the reply + if reply.created_by != user_id: + raise Forbidden("Only the reply creator can update it") + + reply.content = content + + # Update mentions - first remove existing mentions for this reply + existing_mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id) + ).all() + existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions} + for mention in existing_mentions: + session.delete(mention) + + # Add mentions + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or []) + new_mentioned_user_ids = [ + user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids + ] + for user_id_str in mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str + ) + session.add(mention) + + mention_email_payloads: list[dict[str, str]] = [] + comment = session.get(WorkflowComment, reply.comment_id) + if comment: + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=comment.tenant_id, + app_id=comment.app_id, + mentioner_id=user_id, + mentioned_user_ids=new_mentioned_user_ids, + content=content, + ) + + session.commit() + session.refresh(reply) # Refresh to get updated timestamp + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": reply.id, "updated_at": reply.updated_at} + + @staticmethod + def delete_reply(reply_id: str, user_id: str) -> None: + """Delete a comment reply.""" + with Session(db.engine, expire_on_commit=False) as session: + reply = 
session.get(WorkflowCommentReply, reply_id) + if not reply: + raise NotFound("Reply not found") + + # Only the creator can delete the reply + if reply.created_by != user_id: + raise Forbidden("Only the reply creator can delete it") + + # Delete associated mentions first + mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id) + ).all() + for mention in mentions: + session.delete(mention) + + session.delete(reply) + session.commit() + + @staticmethod + def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment: + """Validate that a comment belongs to the specified tenant and app.""" + return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id) diff --git a/api/tasks/mail_workflow_comment_task.py b/api/tasks/mail_workflow_comment_task.py new file mode 100644 index 0000000000..36d51f0514 --- /dev/null +++ b/api/tasks/mail_workflow_comment_task.py @@ -0,0 +1,65 @@ +import logging +import time + +import click +from celery import shared_task + +from extensions.ext_mail import mail +from libs.email_i18n import EmailType, get_email_i18n_service + +logger = logging.getLogger(__name__) + + +@shared_task(queue="mail") +def send_workflow_comment_mention_email_task( + language: str, + to: str, + mentioned_name: str, + commenter_name: str, + app_name: str, + comment_content: str, + app_url: str, +): + """ + Send workflow comment mention email with internationalization support. + + Args: + language: Language code for email localization + to: Recipient email address + mentioned_name: Name of the mentioned user + commenter_name: Name of the comment author + app_name: Name of the app where the comment was made + comment_content: Comment content excerpt + app_url: Link to the app workflow page + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start workflow comment mention mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.WORKFLOW_COMMENT_MENTION, + language_code=language, + to=to, + template_context={ + "to": to, + "mentioned_name": mentioned_name, + "commenter_name": commenter_name, + "app_name": app_name, + "comment_content": comment_content, + "app_url": app_url, + }, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Send workflow comment mention mail to {to} succeeded: latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception: + logger.exception("workflow comment mention email to %s failed", to)