feat(api): enable all sandbox/skill controller routes and resolve dependencies (P0)

Resolve the full dependency chain to enable all previously disabled controllers:

Enabled routes:
- sandbox_files: sandbox file browser API
- sandbox_providers: sandbox provider management API
- app_asset: app asset management API
- skills: skill extraction API
- CLI API blueprint: DifyCli callback endpoints (/cli/api/*)

Dependencies extracted (64 files, ~8000 lines):
- models/sandbox.py, models/app_asset.py: DB models
- core/zip_sandbox/: zip-based sandbox execution
- core/session/: CLI API session management
- core/memory/: base memory + node token buffer
- core/helper/creators.py: Creators Platform upload/redirect helpers
- core/llm_generator/: context models, output models, utils
- core/workflow/nodes/command/: command node type
- core/workflow/nodes/file_upload/: file upload node type
- core/app/entities/: app_asset_entities, app_bundle_entities, llm_generation_entities
- services/: asset_content, skill, workflow_collaboration, workflow_comment
- controllers/console/app/error.py: AppAsset error classes
- core/tools/utils/system_encryption.py

Import fixes:
- dify_graph.enums -> graphon.enums in skill_service.py
- get_signed_file_url_for_plugin -> get_signed_file_url in cli_api.py

All 5 controllers verified: modules import cleanly and the Flask app starts successfully.
46 existing tests still pass.
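
A minimal smoke check along these lines reproduces the verification (hedged sketch: it assumes
Dify's usual create_app factory in app_factory.py and the unauthenticated /console/api/ping route):

    # Hypothetical verification sketch, not part of this commit.
    from app_factory import create_app  # assumed entry point; adjust to the real factory

    app = create_app()  # blueprint registration fails fast if a newly enabled controller cannot import
    with app.test_client() as client:
        assert client.get("/console/api/ping").status_code == 200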

Made-with: Cursor
Yansong Zhang 2026-04-09 09:36:16 +08:00
parent d3d9f21cdf
commit 44491e427c
64 changed files with 8030 additions and 12 deletions

View File

@@ -22,7 +22,7 @@ from core.session.cli_api import CliContext
from core.skill.entities import ToolInvocationRequest
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from graphon.file.helpers import get_signed_file_url_for_plugin
from graphon.file.helpers import get_signed_file_url
from libs.helper import length_prefixed_response
from models.account import Account
from models.model import EndUser, Tenant
@@ -139,11 +139,9 @@ class CliUploadFileRequestApi(Resource):
payload: RequestRequestUploadFile,
cli_context: CliContext,
):
url = get_signed_file_url_for_plugin(
filename=payload.filename,
mimetype=payload.mimetype,
url = get_signed_file_url(
upload_file_id=f"{tenant_model.id}_{user_model.id}_{payload.filename}",
tenant_id=tenant_model.id,
user_id=user_model.id,
)
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()

View File

@@ -41,7 +41,7 @@ from . import (
init_validate,
notification,
ping,
# sandbox_files, # TODO: enable after full sandbox integration
sandbox_files,
setup,
spec,
version,
@@ -53,7 +53,7 @@ from .app import (
agent,
annotation,
app,
# app_asset, # TODO: enable after full sandbox integration
app_asset,
audio,
completion,
conversation,
@@ -64,7 +64,7 @@ from .app import (
model_config,
ops_trace,
site,
# skills, # TODO: enable after full sandbox integration
skills,
statistic,
workflow,
workflow_app_log,
@@ -133,7 +133,7 @@ from .workspace import (
model_providers,
models,
plugin,
# sandbox_providers, # TODO: enable after full sandbox integration
sandbox_providers,
tool_providers,
trigger_providers,
workspace,

View File

@@ -121,3 +121,21 @@ class NeedAddIdsError(BaseHTTPException):
error_code = "need_add_ids"
description = "Need to add ids."
code = 400
class AppAssetNodeNotFoundError(BaseHTTPException):
error_code = "app_asset_node_not_found"
description = "App asset node not found."
code = 404
class AppAssetFileRequiredError(BaseHTTPException):
error_code = "app_asset_file_required"
description = "File is required."
code = 400
class AppAssetPathConflictError(BaseHTTPException):
error_code = "app_asset_path_conflict"
description = "Path already exists."
code = 409

View File

@@ -0,0 +1,322 @@
import logging
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, TypeAdapter
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.member_fields import AccountWithRole
from fields.workflow_comment_fields import (
workflow_comment_basic_fields,
workflow_comment_create_fields,
workflow_comment_detail_fields,
workflow_comment_reply_create_fields,
workflow_comment_reply_update_fields,
workflow_comment_resolve_fields,
workflow_comment_update_fields,
)
from libs.login import current_user, login_required
from models import App
from services.account_service import TenantService
from services.workflow_comment_service import WorkflowCommentService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowCommentCreatePayload(BaseModel):
position_x: float = Field(..., description="Comment X position")
position_y: float = Field(..., description="Comment Y position")
content: str = Field(..., description="Comment content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentUpdatePayload(BaseModel):
content: str = Field(..., description="Comment content")
position_x: float | None = Field(default=None, description="Comment X position")
position_y: float | None = Field(default=None, description="Comment Y position")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentReplyCreatePayload(BaseModel):
content: str = Field(..., description="Reply content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentReplyUpdatePayload(BaseModel):
content: str = Field(..., description="Reply content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentMentionUsersResponse(BaseModel):
users: list[AccountWithRole] = Field(description="Mentionable users")
for model in (
WorkflowCommentCreatePayload,
WorkflowCommentUpdatePayload,
WorkflowCommentReplyCreatePayload,
WorkflowCommentReplyUpdatePayload,
):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
for model in (AccountWithRole, WorkflowCommentMentionUsersResponse):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
workflow_comment_reply_create_model = console_ns.model(
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
)
workflow_comment_reply_update_model = console_ns.model(
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
)
workflow_comment_mention_users_model = console_ns.models[WorkflowCommentMentionUsersResponse.__name__]
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
class WorkflowCommentListApi(Resource):
"""API for listing and creating workflow comments."""
@console_ns.doc("list_workflow_comments")
@console_ns.doc(description="Get all comments for a workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_basic_model, envelope="data")
def get(self, app_model: App):
"""Get all comments for a workflow."""
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
return comments
@console_ns.doc("create_workflow_comment")
@console_ns.doc(description="Create a new workflow comment")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_create_model)
def post(self, app_model: App):
"""Create a new workflow comment."""
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.create_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
created_by=current_user.id,
content=payload.content,
position_x=payload.position_x,
position_y=payload.position_y,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result, 201
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
class WorkflowCommentDetailApi(Resource):
"""API for managing individual workflow comments."""
@console_ns.doc("get_workflow_comment")
@console_ns.doc(description="Get a specific workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_detail_model)
def get(self, app_model: App, comment_id: str):
"""Get a specific workflow comment."""
comment = WorkflowCommentService.get_comment(
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
)
return comment
@console_ns.doc("update_workflow_comment")
@console_ns.doc(description="Update a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_update_model)
def put(self, app_model: App, comment_id: str):
"""Update a workflow comment."""
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.update_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
content=payload.content,
position_x=payload.position_x,
position_y=payload.position_y,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result
@console_ns.doc("delete_workflow_comment")
@console_ns.doc(description="Delete a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(204, "Comment deleted successfully")
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def delete(self, app_model: App, comment_id: str):
"""Delete a workflow comment."""
WorkflowCommentService.delete_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
class WorkflowCommentResolveApi(Resource):
"""API for resolving and reopening workflow comments."""
@console_ns.doc("resolve_workflow_comment")
@console_ns.doc(description="Resolve a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_resolve_model)
def post(self, app_model: App, comment_id: str):
"""Resolve a workflow comment."""
comment = WorkflowCommentService.resolve_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return comment
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
class WorkflowCommentReplyApi(Resource):
"""API for managing comment replies."""
@console_ns.doc("create_workflow_comment_reply")
@console_ns.doc(description="Add a reply to a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_reply_create_model)
def post(self, app_model: App, comment_id: str):
"""Add a reply to a workflow comment."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.create_reply(
comment_id=comment_id,
content=payload.content,
created_by=current_user.id,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result, 201
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
class WorkflowCommentReplyDetailApi(Resource):
"""API for managing individual comment replies."""
@console_ns.doc("update_workflow_comment_reply")
@console_ns.doc(description="Update a comment reply")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_reply_update_model)
def put(self, app_model: App, comment_id: str, reply_id: str):
"""Update a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
reply = WorkflowCommentService.update_reply(
reply_id=reply_id,
user_id=current_user.id,
content=payload.content,
mentioned_user_ids=payload.mentioned_user_ids,
)
return reply
@console_ns.doc("delete_workflow_comment_reply")
@console_ns.doc(description="Delete a comment reply")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
@console_ns.response(204, "Reply deleted successfully")
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def delete(self, app_model: App, comment_id: str, reply_id: str):
"""Delete a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
class WorkflowCommentMentionUsersApi(Resource):
"""API for getting mentionable users for workflow comments."""
@console_ns.doc("workflow_comment_mention_users")
@console_ns.doc(description="Get all users in current tenant for mentions")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App):
"""Get all users in current tenant for mentions."""
members = TenantService.get_tenant_members(current_user.current_tenant)
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = WorkflowCommentMentionUsersResponse(users=member_models)
return response.model_dump(mode="json"), 200

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,119 @@
import logging
from collections.abc import Callable
from typing import cast
from flask import Request as FlaskRequest
from extensions.ext_socketio import sio
from libs.passport import PassportService
from libs.token import extract_access_token
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
from services.account_service import AccountService
from services.workflow_collaboration_service import WorkflowCollaborationService
repository = WorkflowCollaborationRepository()
collaboration_service = WorkflowCollaborationService(repository, sio)
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
@_sio_on("connect")
def socket_connect(sid, environ, auth):
"""
WebSocket connect event, do authentication here.
"""
try:
request_environ = FlaskRequest(environ)
token = extract_access_token(request_environ)
except Exception:
logging.exception("Failed to extract token")
token = None
if not token:
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
return False
try:
decoded = PassportService().verify(token)
user_id = decoded.get("user_id")
if not user_id:
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
return False
with sio.app.app_context():
user = AccountService.load_logged_in_account(account_id=user_id)
if not user:
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
return False
if not user.has_edit_permission:
logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
return False
collaboration_service.save_session(sid, user)
return True
except Exception:
logging.exception("Socket authentication failed")
return False
@_sio_on("user_connect")
def handle_user_connect(sid, data):
"""
Handle user connect event. Each session (tab) is treated as an independent collaborator.
"""
workflow_id = data.get("workflow_id")
if not workflow_id:
return {"msg": "workflow_id is required"}, 400
result = collaboration_service.register_session(workflow_id, sid)
if not result:
return {"msg": "unauthorized"}, 401
user_id, is_leader = result
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
@_sio_on("disconnect")
def handle_disconnect(sid):
"""
Handle session disconnect event. Remove the specific session from online users.
"""
collaboration_service.disconnect_session(sid)
@_sio_on("collaboration_event")
def handle_collaboration_event(sid, data):
"""
Handle general collaboration events, including:
1. mouse_move
2. vars_and_features_update
3. sync_request (ask leader to update graph)
4. app_state_update
5. mcp_server_update
6. workflow_update
7. comments_update
8. node_panel_presence
9. skill_file_active
10. skill_sync_request
11. skill_resync_request
"""
return collaboration_service.relay_collaboration_event(sid, data)
@_sio_on("graph_event")
def handle_graph_event(sid, data):
"""
Handle graph events - simple broadcast relay.
"""
return collaboration_service.relay_graph_event(sid, data)
@_sio_on("skill_event")
def handle_skill_event(sid, data):
"""
Handle skill events - simple broadcast relay.
"""
return collaboration_service.relay_skill_event(sid, data)
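
For reference, a client-side handshake against these handlers could look like the sketch below
(illustrative assumptions: the python-socketio client, the default /socket.io path on the API
origin, and a bearer token in the Authorization header, which extract_access_token presumably reads):

    # Hypothetical client sketch; ACCESS_TOKEN, WORKFLOW_ID and the origin are placeholders.
    import socketio

    ACCESS_TOKEN = "<console access token>"
    WORKFLOW_ID = "<workflow id>"

    client = socketio.Client()
    client.connect(
        "http://localhost:5001",
        headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
        transports=["websocket"],
    )
    ack = client.call("user_connect", {"workflow_id": WORKFLOW_ID})
    # ack -> {"msg": "connected", "user_id": ..., "sid": ..., "isLeader": ...}
    client.emit("collaboration_event", {"type": "mouse_move"})  # payload shape is illustrative only
    client.disconnect()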

View File

@@ -0,0 +1,67 @@
import json
import httpx
import yaml
from flask import request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.exc import PluginPermissionDeniedError
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models.model import App
from models.workflow import Workflow
from services.app_dsl_service import AppDslService
class DSLPredictRequest(BaseModel):
app_id: str
current_node_id: str
@console_ns.route("/workspaces/current/dsl/predict")
class DSLPredictApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user, _ = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
args = DSLPredictRequest.model_validate(request.get_json())
app_id: str = args.app_id
current_node_id: str = args.current_node_id
with Session(db.engine) as session:
app = session.query(App).filter_by(id=app_id).first()
workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()
if not app:
raise ValueError("App not found")
if not workflow:
raise ValueError("Workflow not found")
try:
i = 0
for node_id, _ in workflow.walk_nodes():
if node_id == current_node_id:
break
i += 1
dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))
response = httpx.post(
"http://spark-832c:8000/predict",
json={"graph_data": dsl, "source_node_index": i},
)
return {
"nodes": json.loads(response.json()),
}
except PluginPermissionDeniedError as e:
raise ValueError(e.description) from e

View File

@@ -0,0 +1,380 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentResult
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from graphon.file import file_manager
from graphon.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMUsage,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from models.model import Message
logger = logging.getLogger(__name__)
class AgentAppRunner(BaseAgentRunner):
def _create_tool_invoke_hook(self, message: Message):
"""
Create a tool invoke hook that uses ToolEngine.agent_invoke.
This hook handles file creation and returns proper meta information.
"""
# Get trace manager from app generate entity
trace_manager = self.application_generate_entity.trace_manager
def tool_invoke_hook(
tool: Tool, tool_args: dict[str, Any], tool_name: str
) -> tuple[str, list[str], ToolInvokeMeta]:
"""Hook that uses agent_invoke for proper file and meta handling."""
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool,
tool_parameters=tool_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=message.id,
conversation_id=self.conversation.id,
)
# Publish files and track IDs
for message_file_id in message_files:
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id),
PublishFrom.APPLICATION_MANAGER,
)
self._current_message_file_ids.append(message_file_id)
return tool_invoke_response, message_files, tool_invoke_meta
return tool_invoke_hook
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run Agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, _ = self._init_prompt_tools()
assert app_config.agent
# Create tool invoke hook for agent_invoke
tool_invoke_hook = self._create_tool_invoke_hook(message)
# Get instruction for ReAct strategy
instruction = self.app_config.prompt_template.simple_prompt_template or ""
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
model_features=self.model_features,
model_instance=self.model_instance,
tools=list(tool_instances.values()),
files=list(self.files),
max_iterations=app_config.agent.max_iteration,
context=self.build_execution_context(),
agent_strategy=self.config.strategy,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Initialize state variables
current_agent_thought_id: str | None = None
has_published_thought = False
current_tool_name: str | None = None
self._current_message_file_ids: list[str] = []
# organize prompt messages
prompt_messages = self._organize_prompt_messages()
# Run strategy
generator = strategy.run(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
stream=True,
)
# Consume generator and collect result
result: AgentResult | None = None
try:
while True:
try:
output = next(generator)
except StopIteration as e:
# Generator finished, get the return value
result = e.value
break
if isinstance(output, LLMResultChunk):
# Handle LLM chunk
if current_agent_thought_id and not has_published_thought:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
has_published_thought = True
yield output
elif isinstance(output, AgentLog):
# Handle Agent Log using log_type for type-safe dispatch
if output.status == AgentLog.LogStatus.START:
if output.log_type == AgentLog.LogType.ROUND:
# Start of a new round
message_file_ids: list[str] = []
current_agent_thought_id = self.create_agent_thought(
message_id=message.id,
message="",
tool_name="",
tool_input="",
messages_ids=message_file_ids,
)
has_published_thought = False
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call start - extract data from structured fields
current_tool_name = output.data.get("tool_name", "")
tool_input = output.data.get("tool_args", {})
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=current_tool_name,
tool_input=tool_input,
thought=None,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.status == AgentLog.LogStatus.SUCCESS:
if output.log_type == AgentLog.LogType.THOUGHT:
if current_agent_thought_id is None:
continue
thought_text = output.data.get("thought")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=thought_text,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call finished
tool_output = output.data.get("output")
# Get meta from strategy output (now properly populated)
tool_meta = output.data.get("meta")
# Wrap tool_meta with tool_name as key (required by agent_service)
if tool_meta and current_tool_name:
tool_meta = {current_tool_name: tool_meta}
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=None,
observation=tool_output,
tool_invoke_meta=tool_meta,
answer=None,
messages_ids=self._current_message_file_ids,
)
# Clear message file ids after saving
self._current_message_file_ids = []
current_tool_name = None
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.ROUND:
if current_agent_thought_id is None:
continue
# Round finished - save LLM usage and answer
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
llm_result = output.data.get("llm_result")
final_answer = output.data.get("final_answer")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=llm_result,
observation=None,
tool_invoke_meta=None,
answer=final_answer,
messages_ids=[],
llm_usage=llm_usage,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
except Exception:
# Re-raise any other exceptions
raise
# Process final result
if isinstance(result, AgentResult):
final_answer = result.text
usage = result.usage or LLMUsage.empty_usage()
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model_name,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""
if not prompt_template:
return prompt_messages or []
prompt_messages = prompt_messages or []
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
return prompt_messages
if not prompt_messages:
return [SystemPromptMessage(content=prompt_template)]
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
GPT currently handles both function calling and vision only on the first iteration,
so image content is stripped from user prompt messages on subsequent iterations.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
# For ReAct strategy, use the agent prompt template
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
prompt_template = self.config.prompt.first_prompt
else:
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages

View File

@@ -0,0 +1,352 @@
from __future__ import annotations
import os
from collections import defaultdict
from collections.abc import Generator
from enum import StrEnum
from pydantic import BaseModel, Field
class AssetNodeType(StrEnum):
FILE = "file"
FOLDER = "folder"
class AppAssetNode(BaseModel):
id: str = Field(description="Unique identifier for the node")
node_type: AssetNodeType = Field(description="Type of node: file or folder")
name: str = Field(description="Name of the file or folder")
parent_id: str | None = Field(default=None, description="Parent folder ID, None for root level")
order: int = Field(default=0, description="Sort order within parent folder, lower values first")
extension: str = Field(default="", description="File extension without dot, empty for folders")
size: int = Field(default=0, description="File size in bytes, 0 for folders")
@classmethod
def create_folder(cls, node_id: str, name: str, parent_id: str | None = None) -> AppAssetNode:
return cls(id=node_id, node_type=AssetNodeType.FOLDER, name=name, parent_id=parent_id)
@classmethod
def create_file(cls, node_id: str, name: str, parent_id: str | None = None, size: int = 0) -> AppAssetNode:
return cls(
id=node_id,
node_type=AssetNodeType.FILE,
name=name,
parent_id=parent_id,
extension=name.rsplit(".", 1)[-1] if "." in name else "",
size=size,
)
class AppAssetNodeView(BaseModel):
id: str = Field(description="Unique identifier for the node")
node_type: str = Field(description="Type of node: 'file' or 'folder'")
name: str = Field(description="Name of the file or folder")
path: str = Field(description="Full path from root, e.g. '/folder/file.txt'")
extension: str = Field(default="", description="File extension without dot")
size: int = Field(default=0, description="File size in bytes")
children: list[AppAssetNodeView] = Field(default_factory=list, description="Child nodes for folders")
class BatchUploadNode(BaseModel):
"""Structure for batch upload_url tree nodes, used for both input and output."""
name: str
node_type: AssetNodeType
size: int = 0
children: list[BatchUploadNode] = []
id: str | None = None
upload_url: str | None = None
def to_app_asset_nodes(self, parent_id: str | None = None) -> list[AppAssetNode]:
"""
Generate IDs when missing and convert to AppAssetNode list.
Mutates self to set id field when it is not set.
"""
from uuid import uuid4
self.id = self.id or str(uuid4())
nodes: list[AppAssetNode] = []
if self.node_type == AssetNodeType.FOLDER:
nodes.append(AppAssetNode.create_folder(self.id, self.name, parent_id))
for child in self.children:
nodes.extend(child.to_app_asset_nodes(self.id))
else:
nodes.append(AppAssetNode.create_file(self.id, self.name, parent_id, self.size))
return nodes
class TreeNodeNotFoundError(Exception):
"""Tree internal: node not found"""
pass
class TreeParentNotFoundError(Exception):
"""Tree internal: parent folder not found"""
pass
class TreePathConflictError(Exception):
"""Tree internal: path already exists"""
pass
class AppAssetFileTree(BaseModel):
"""
File tree structure for app assets using adjacency list pattern.
Design:
- Storage: Flat list with parent_id references (adjacency list)
- Path: Computed dynamically via get_path(), not stored
- Order: Integer field for user-defined sorting within each folder
- API response: transform() builds nested tree with computed paths
Why adjacency list over nested tree or materialized path:
- Simpler CRUD: move/rename only updates one node's parent_id
- No path cascade: renaming parent doesn't require updating all descendants
- JSON-friendly: flat list serializes cleanly to database JSON column
- Trade-off: path lookup is O(depth), acceptable for typical file trees
"""
nodes: list[AppAssetNode] = Field(default_factory=list, description="Flat list of all nodes in the tree")
def ensure_unique_name(
self,
parent_id: str | None,
name: str,
*,
is_file: bool,
extra_taken: set[str] | None = None,
) -> str:
"""
Return a sibling-unique name by appending numeric suffixes when needed.
The suffix format is " <n>" (e.g. "report 1", "report 2"). For files,
the suffix is inserted before the extension.
"""
taken = extra_taken or set()
if not self.has_child_named(parent_id, name) and name not in taken:
return name
suffix_index = 1
while True:
candidate = self._apply_name_suffix(name, suffix_index, is_file=is_file)
if not self.has_child_named(parent_id, candidate) and candidate not in taken:
return candidate
suffix_index += 1
@staticmethod
def _apply_name_suffix(name: str, suffix_index: int, *, is_file: bool) -> str:
if not is_file:
return f"{name} {suffix_index}"
stem, extension = os.path.splitext(name)
return f"{stem} {suffix_index}{extension}"
def get(self, node_id: str) -> AppAssetNode | None:
return next((n for n in self.nodes if n.id == node_id), None)
def get_children(self, parent_id: str | None) -> list[AppAssetNode]:
return [n for n in self.nodes if n.parent_id == parent_id]
def has_child_named(self, parent_id: str | None, name: str) -> bool:
return any(n.name == name and n.parent_id == parent_id for n in self.nodes)
def get_path(self, node_id: str) -> str:
node = self.get(node_id)
if not node:
raise TreeNodeNotFoundError(node_id)
parts: list[str] = []
current: AppAssetNode | None = node
while current:
parts.append(current.name)
current = self.get(current.parent_id) if current.parent_id else None
return "/".join(reversed(parts))
def relative_path(self, a: AppAssetNode, b: AppAssetNode) -> str:
"""
Calculate relative path from node a to node b for Markdown references.
Path is computed from a's parent directory (where the file resides).
Examples:
/foo/a.md -> /foo/b.md => ./b.md
/foo/a.md -> /foo/sub/b.md => ./sub/b.md
/foo/sub/a.md -> /foo/b.md => ../b.md
/foo/sub/deep/a.md -> /foo/b.md => ../../b.md
"""
def get_ancestor_ids(node_id: str | None) -> list[str]:
chain: list[str] = []
current_id = node_id
while current_id:
chain.append(current_id)
node = self.get(current_id)
current_id = node.parent_id if node else None
return chain
a_dir_ancestors = get_ancestor_ids(a.parent_id)
b_ancestors = [b.id] + get_ancestor_ids(b.parent_id)
a_dir_set = set(a_dir_ancestors)
lca_id: str | None = None
lca_index_in_b = -1
for idx, ancestor_id in enumerate(b_ancestors):
if ancestor_id in a_dir_set or (a.parent_id is None and b_ancestors[idx:] == []):
lca_id = ancestor_id
lca_index_in_b = idx
break
if a.parent_id is None:
steps_up = 0
lca_index_in_b = len(b_ancestors)
elif lca_id is None:
steps_up = len(a_dir_ancestors)
lca_index_in_b = len(b_ancestors)
else:
steps_up = 0
for ancestor_id in a_dir_ancestors:
if ancestor_id == lca_id:
break
steps_up += 1
path_down: list[str] = []
for i in range(lca_index_in_b - 1, -1, -1):
node = self.get(b_ancestors[i])
if node:
path_down.append(node.name)
if steps_up == 0:
return "./" + "/".join(path_down)
parts: list[str] = [".."] * steps_up + path_down
return "/".join(parts)
def get_descendant_ids(self, node_id: str) -> list[str]:
result: list[str] = []
stack = [node_id]
while stack:
current_id = stack.pop()
for child in self.nodes:
if child.parent_id == current_id:
result.append(child.id)
stack.append(child.id)
return result
def add(self, node: AppAssetNode) -> AppAssetNode:
if self.get(node.id):
raise TreePathConflictError(node.id)
if self.has_child_named(node.parent_id, node.name):
raise TreePathConflictError(node.name)
if node.parent_id:
parent = self.get(node.parent_id)
if not parent or parent.node_type != AssetNodeType.FOLDER:
raise TreeParentNotFoundError(node.parent_id)
siblings = self.get_children(node.parent_id)
node.order = max((s.order for s in siblings), default=-1) + 1
self.nodes.append(node)
return node
def update(self, node_id: str, size: int) -> AppAssetNode:
node = self.get(node_id)
if not node or node.node_type != AssetNodeType.FILE:
raise TreeNodeNotFoundError(node_id)
node.size = size
return node
def rename(self, node_id: str, new_name: str) -> AppAssetNode:
node = self.get(node_id)
if not node:
raise TreeNodeNotFoundError(node_id)
if node.name != new_name and self.has_child_named(node.parent_id, new_name):
raise TreePathConflictError(new_name)
node.name = new_name
if node.node_type == AssetNodeType.FILE:
node.extension = new_name.rsplit(".", 1)[-1] if "." in new_name else ""
return node
def move(self, node_id: str, new_parent_id: str | None) -> AppAssetNode:
node = self.get(node_id)
if not node:
raise TreeNodeNotFoundError(node_id)
if new_parent_id:
parent = self.get(new_parent_id)
if not parent or parent.node_type != AssetNodeType.FOLDER:
raise TreeParentNotFoundError(new_parent_id)
if self.has_child_named(new_parent_id, node.name):
raise TreePathConflictError(node.name)
node.parent_id = new_parent_id
siblings = self.get_children(new_parent_id)
node.order = max((s.order for s in siblings if s.id != node_id), default=-1) + 1
return node
def reorder(self, node_id: str, after_node_id: str | None) -> AppAssetNode:
node = self.get(node_id)
if not node:
raise TreeNodeNotFoundError(node_id)
siblings = sorted(self.get_children(node.parent_id), key=lambda x: x.order)
siblings = [s for s in siblings if s.id != node_id]
if after_node_id is None:
insert_idx = 0
else:
after_node = self.get(after_node_id)
if not after_node or after_node.parent_id != node.parent_id:
raise TreeNodeNotFoundError(after_node_id)
insert_idx = next((i for i, s in enumerate(siblings) if s.id == after_node_id), -1) + 1
siblings.insert(insert_idx, node)
for idx, sibling in enumerate(siblings):
sibling.order = idx
return node
def remove(self, node_id: str) -> list[str]:
node = self.get(node_id)
if not node:
raise TreeNodeNotFoundError(node_id)
ids_to_remove = [node_id] + self.get_descendant_ids(node_id)
self.nodes = [n for n in self.nodes if n.id not in ids_to_remove]
return ids_to_remove
def walk_files(self) -> Generator[AppAssetNode, None, None]:
return (n for n in self.nodes if n.node_type == AssetNodeType.FILE)
def transform(self) -> list[AppAssetNodeView]:
by_parent: dict[str | None, list[AppAssetNode]] = defaultdict(list)
for n in self.nodes:
by_parent[n.parent_id].append(n)
for children in by_parent.values():
children.sort(key=lambda x: x.order)
paths: dict[str, str] = {}
tree_views: dict[str, AppAssetNodeView] = {}
def build_view(node: AppAssetNode, parent_path: str) -> None:
path = f"{parent_path}/{node.name}"
paths[node.id] = path
child_views: list[AppAssetNodeView] = []
for child in by_parent.get(node.id, []):
build_view(child, path)
child_views.append(tree_views[child.id])
tree_views[node.id] = AppAssetNodeView(
id=node.id,
node_type=node.node_type.value,
name=node.name,
path=path,
extension=node.extension,
size=node.size,
children=child_views,
)
for root_node in by_parent.get(None, []):
build_view(root_node, "")
return [tree_views[n.id] for n in by_parent.get(None, [])]
def empty(self) -> bool:
return len(self.nodes) == 0
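
The adjacency-list design above can be exercised end to end with a small sketch (illustrative only;
the string IDs are arbitrary here, whereas the services layer presumably supplies real UUIDs):

    tree = AppAssetFileTree()
    tree.add(AppAssetNode.create_folder("n-docs", "docs"))
    guide = tree.add(AppAssetNode.create_file("n-guide", "guide.md", parent_id="n-docs", size=120))
    readme = tree.add(AppAssetNode.create_file("n-readme", "README.md", size=40))

    tree.get_path("n-guide")           # "docs/guide.md", computed on demand rather than stored
    tree.relative_path(guide, readme)  # "../README.md", a Markdown-style reference
    tree.move("n-readme", "n-docs")    # re-parenting touches only parent_id and order
    tree.transform()                   # nested AppAssetNodeView list with "/docs/..." paths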

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
import re
from datetime import UTC, datetime
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.app_asset_entities import AppAssetFileTree
# Constants
BUNDLE_DSL_FILENAME_PATTERN = re.compile(r"^[^/]+\.ya?ml$")
BUNDLE_MAX_SIZE = 50 * 1024 * 1024 # 50MB
MANIFEST_FILENAME = "manifest.json"
MANIFEST_SCHEMA_VERSION = "1.0"
# Exceptions
class BundleFormatError(Exception):
"""Raised when bundle format is invalid."""
pass
class ZipSecurityError(Exception):
"""Raised when zip file contains security violations."""
pass
# Manifest DTOs
class ManifestFileEntry(BaseModel):
"""Maps node_id to file path in the bundle."""
model_config = ConfigDict(extra="forbid")
node_id: str
path: str
class ManifestIntegrity(BaseModel):
"""Basic integrity check fields."""
model_config = ConfigDict(extra="forbid")
file_count: int
class ManifestAppAssets(BaseModel):
"""App assets section containing the full tree."""
model_config = ConfigDict(extra="forbid")
tree: AppAssetFileTree
class BundleManifest(BaseModel):
"""
Bundle manifest for app asset import/export.
Schema version 1.0:
- dsl_filename: DSL file name in bundle root (e.g. "my_app.yml")
- tree: Full AppAssetFileTree (files + folders) for 100% restoration including node IDs
- files: Explicit node_id -> path mapping for file nodes only
- integrity: Basic file_count validation
"""
model_config = ConfigDict(extra="forbid")
schema_version: str = Field(default=MANIFEST_SCHEMA_VERSION)
generated_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC))
dsl_filename: str = Field(description="DSL file name in bundle root")
app_assets: ManifestAppAssets
files: list[ManifestFileEntry]
integrity: ManifestIntegrity
@property
def assets_prefix(self) -> str:
"""Assets directory name (DSL filename without extension)."""
return self.dsl_filename.rsplit(".", 1)[0]
@classmethod
def from_tree(cls, tree: AppAssetFileTree, dsl_filename: str) -> BundleManifest:
"""Build manifest from an AppAssetFileTree."""
files = [ManifestFileEntry(node_id=n.id, path=tree.get_path(n.id)) for n in tree.walk_files()]
return cls(
dsl_filename=dsl_filename,
app_assets=ManifestAppAssets(tree=tree),
files=files,
integrity=ManifestIntegrity(file_count=len(files)),
)
# Export result
class BundleExportResult(BaseModel):
download_url: str = Field(description="Temporary download URL for the ZIP")
filename: str = Field(description="Suggested filename for the ZIP")
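Continuing the file-tree sketch from app_asset_entities above, building and serializing a manifest
is short (hedged; the real bundle writer presumably places this JSON next to the DSL inside the ZIP):

    manifest = BundleManifest.from_tree(tree, dsl_filename="my_app.yml")
    manifest.assets_prefix              # "my_app", the assets directory name inside the bundle
    manifest.integrity.file_count       # number of file nodes collected via walk_files()
    manifest.model_dump_json(indent=2)  # candidate content for manifest.json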

View File

@@ -0,0 +1,72 @@
"""
LLM Generation Detail entities.
Defines the structure for storing and transmitting LLM generation details
including reasoning content, tool calls, and their sequence.
"""
from typing import Literal
from pydantic import BaseModel, Field
class ContentSegment(BaseModel):
"""Represents a content segment in the generation sequence."""
type: Literal["content"] = "content"
start: int = Field(..., description="Start position in the text")
end: int = Field(..., description="End position in the text")
class ReasoningSegment(BaseModel):
"""Represents a reasoning segment in the generation sequence."""
type: Literal["reasoning"] = "reasoning"
index: int = Field(..., description="Index into reasoning_content array")
class ToolCallSegment(BaseModel):
"""Represents a tool call segment in the generation sequence."""
type: Literal["tool_call"] = "tool_call"
index: int = Field(..., description="Index into tool_calls array")
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
class ToolCallDetail(BaseModel):
"""Represents a tool call with its arguments and result."""
id: str = Field(default="", description="Unique identifier for the tool call")
name: str = Field(..., description="Name of the tool")
arguments: str = Field(default="", description="JSON string of tool arguments")
result: str = Field(default="", description="Result from the tool execution")
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
icon: str | dict | None = Field(default=None, description="Icon of the tool")
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
class LLMGenerationDetailData(BaseModel):
"""
Domain model for LLM generation detail.
Contains the structured data for reasoning content, tool calls,
and their display sequence.
"""
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
def is_empty(self) -> bool:
"""Check if there's any meaningful generation detail."""
return not self.reasoning_content and not self.tool_calls
def to_response_dict(self) -> dict:
"""Convert to dictionary for API response."""
return {
"reasoning_content": self.reasoning_content,
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
"sequence": [seg.model_dump() for seg in self.sequence],
}
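
To illustrate how sequence entries index into the two arrays (a hedged example, not taken from the diff):

    detail = LLMGenerationDetailData(
        reasoning_content=["Decide which tool can answer the question."],
        tool_calls=[ToolCallDetail(name="web_search", arguments='{"query": "dify sandbox"}', result="...")],
        sequence=[ReasoningSegment(index=0), ToolCallSegment(index=0), ContentSegment(start=0, end=42)],
    )
    detail.is_empty()          # False
    detail.to_response_dict()  # JSON-ready dict for the API layer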

View File

@@ -0,0 +1,75 @@
"""
Helper module for Creators Platform integration.
Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""
import logging
from urllib.parse import urlencode
import httpx
from yarl import URL
from configs import dify_config
logger = logging.getLogger(__name__)
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
"""Upload a DSL file to the Creators Platform anonymous upload endpoint.
Args:
dsl_file_bytes: Raw bytes of the DSL file (YAML or ZIP).
filename: Original filename for the upload.
Returns:
The claim_code string used to retrieve the DSL later.
Raises:
httpx.HTTPStatusError: If the upload request fails.
ValueError: If the response does not contain a valid claim_code.
"""
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
response.raise_for_status()
data = response.json()
claim_code = data.get("data", {}).get("claim_code")
if not claim_code:
raise ValueError("Creators Platform did not return a valid claim_code")
return claim_code
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
"""Generate the redirect URL to the Creators Platform frontend.
Redirects to the Creators Platform root page with the dsl_claim_code.
If CREATORS_PLATFORM_OAUTH_CLIENT_ID is configured (Dify Cloud),
also signs an OAuth authorization code so the frontend can
automatically authenticate the user via the OAuth callback.
For self-hosted Dify without OAuth client_id configured, only the
dsl_claim_code is passed and the user must log in manually.
Args:
user_account_id: The Dify user account ID.
claim_code: The claim_code obtained from upload_dsl().
Returns:
The full redirect URL string.
"""
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
params: dict[str, str] = {"dsl_claim_code": claim_code}
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
if client_id:
from services.oauth_server import OAuthServerService
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
params["oauth_code"] = oauth_code
return f"{base_url}?{urlencode(params)}"

View File

@@ -0,0 +1,62 @@
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class VariableSelectorPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
variable: str = Field(..., description="Variable name used in generated code")
value_selector: list[str] = Field(..., description="Path to upstream node output, format: [node_id, output_name]")
class CodeOutputPayload(BaseModel):
model_config = ConfigDict(extra="forbid")
type: str = Field(..., description="Output variable type")
class CodeContextPayload(BaseModel):
# From web/app/components/workflow/nodes/tool/components/context-generate-modal/index.tsx (code node snapshot).
model_config = ConfigDict(extra="forbid")
code: str = Field(..., description="Existing code in the Code node")
outputs: dict[str, CodeOutputPayload] | None = Field(
default=None, description="Existing output definitions for the Code node"
)
variables: list[VariableSelectorPayload] | None = Field(
default=None, description="Existing variable selectors used by the Code node"
)
class AvailableVarPayload(BaseModel):
# From web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts (available variables).
model_config = ConfigDict(extra="forbid", populate_by_name=True)
value_selector: list[str] = Field(..., description="Path to upstream node output")
type: str = Field(..., description="Variable type, e.g. string, number, array[object]")
description: str | None = Field(default=None, description="Optional variable description")
node_id: str | None = Field(default=None, description="Source node ID")
node_title: str | None = Field(default=None, description="Source node title")
node_type: str | None = Field(default=None, description="Source node type")
json_schema: dict[str, Any] | None = Field(
default=None,
alias="schema",
description="Optional JSON schema for object variables",
)
class ParameterInfoPayload(BaseModel):
# From web/app/components/workflow/nodes/tool/use-config.ts (ToolParameter metadata).
model_config = ConfigDict(extra="forbid")
name: str = Field(..., description="Target parameter name")
type: str = Field(default="string", description="Target parameter type")
description: str = Field(default="", description="Parameter description")
required: bool | None = Field(default=None, description="Whether the parameter is required")
options: list[str] | None = Field(default=None, description="Allowed option values")
min: float | None = Field(default=None, description="Minimum numeric value")
max: float | None = Field(default=None, description="Maximum numeric value")
default: str | int | float | bool | None = Field(default=None, description="Default value")
multiple: bool | None = Field(default=None, description="Whether the parameter accepts multiple values")
label: str | None = Field(default=None, description="Optional display label")

View File

@@ -0,0 +1,67 @@
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field
from graphon.variables.types import SegmentType
class SuggestedQuestionsOutput(BaseModel):
"""Output model for suggested questions generation."""
model_config = ConfigDict(extra="forbid")
questions: list[str] = Field(
min_length=3,
max_length=3,
description="Exactly 3 suggested follow-up questions for the user",
)
class VariableSelectorOutput(BaseModel):
"""Variable selector mapping code variable to upstream node output.
Note: Separate from VariableSelector to ensure 'additionalProperties: false'
in JSON schema for OpenAI/Azure strict mode.
"""
model_config = ConfigDict(extra="forbid")
variable: str = Field(description="Variable name used in the generated code")
value_selector: list[str] = Field(description="Path to upstream node output, format: [node_id, output_name]")
class CodeNodeOutputItem(BaseModel):
"""Single output variable definition.
Note: OpenAI/Azure strict mode requires 'additionalProperties: false' and
does not support dynamic object keys, so outputs use array format.
"""
model_config = ConfigDict(extra="forbid")
name: str = Field(description="Output variable name returned by the main function")
type: SegmentType = Field(description="Data type of the output variable")
class CodeNodeStructuredOutput(BaseModel):
"""Structured output for code node generation."""
model_config = ConfigDict(extra="forbid")
variables: list[VariableSelectorOutput] = Field(
description="Input variables mapping code variables to upstream node outputs"
)
code: str = Field(description="Generated code with a main function that processes inputs and returns outputs")
outputs: list[CodeNodeOutputItem] = Field(
description="Output variable definitions specifying name and type for each return value"
)
message: str = Field(description="Brief explanation of what the generated code does")
class InstructionModifyOutput(BaseModel):
"""Output model for instruction-based prompt modification."""
model_config = ConfigDict(extra="forbid")
modified: str = Field(description="The modified prompt content after applying the instruction")
message: str = Field(description="Brief explanation of what changes were made")

View File

@@ -0,0 +1,203 @@
"""
File path detection and conversion for structured output.
This module provides utilities to:
1. Detect sandbox file path fields in JSON Schema (format: "file-path")
2. Adapt schemas to add file-path descriptions before model invocation
3. Convert sandbox file path strings into File objects via a resolver
"""
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast
from graphon.file import File
from graphon.variables.segments import ArrayFileSegment, FileSegment
FILE_PATH_FORMAT = "file-path"
FILE_PATH_DESCRIPTION_SUFFIX = "this field contains a file path from the Dify sandbox"
def is_file_path_property(schema: Mapping[str, Any]) -> bool:
"""Check if a schema property represents a sandbox file path."""
if schema.get("type") != "string":
return False
format_value = schema.get("format")
if not isinstance(format_value, str):
return False
normalized_format = format_value.lower().replace("_", "-")
return normalized_format == FILE_PATH_FORMAT
def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
"""Recursively detect file path fields in a JSON schema."""
file_path_fields: list[str] = []
schema_type = schema.get("type")
if schema_type == "object":
properties = schema.get("properties")
if isinstance(properties, Mapping):
properties_mapping = cast(Mapping[str, Any], properties)
for prop_name, prop_schema in properties_mapping.items():
if not isinstance(prop_schema, Mapping):
continue
prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
current_path = f"{path}.{prop_name}" if path else prop_name
if is_file_path_property(prop_schema_mapping):
file_path_fields.append(current_path)
else:
file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
elif schema_type == "array":
items_schema = schema.get("items")
if not isinstance(items_schema, Mapping):
return file_path_fields
items_schema_mapping = cast(Mapping[str, Any], items_schema)
array_path = f"{path}[*]" if path else "[*]"
if is_file_path_property(items_schema_mapping):
file_path_fields.append(array_path)
else:
file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
return file_path_fields
def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""Normalize sandbox file path fields and collect their JSON paths."""
result = _deep_copy_value(schema)
if not isinstance(result, dict):
raise ValueError("structured_output_schema must be a JSON object")
result_dict = cast(dict[str, Any], result)
file_path_fields: list[str] = []
_adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields)
return result_dict, file_path_fields
def convert_sandbox_file_paths_in_output(
output: Mapping[str, Any],
file_path_fields: Sequence[str],
file_resolver: Callable[[str], File],
) -> tuple[dict[str, Any], list[File]]:
"""Convert sandbox file paths into File objects using the resolver."""
if not file_path_fields:
return dict(output), []
result = _deep_copy_value(output)
if not isinstance(result, dict):
raise ValueError("Structured output must be a JSON object")
result_dict = cast(dict[str, Any], result)
files: list[File] = []
for path in file_path_fields:
_convert_path_in_place(result_dict, path.split("."), file_resolver, files)
return result_dict, files
def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
schema_type = schema.get("type")
if schema_type == "object":
properties = schema.get("properties")
if isinstance(properties, Mapping):
properties_mapping = cast(Mapping[str, Any], properties)
for prop_name, prop_schema in properties_mapping.items():
if not isinstance(prop_schema, dict):
continue
prop_schema_dict = cast(dict[str, Any], prop_schema)
current_path = f"{path}.{prop_name}" if path else prop_name
if is_file_path_property(prop_schema_dict):
_normalize_file_path_schema(prop_schema_dict)
file_path_fields.append(current_path)
else:
_adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields)
elif schema_type == "array":
items_schema = schema.get("items")
if not isinstance(items_schema, dict):
return
items_schema_dict = cast(dict[str, Any], items_schema)
array_path = f"{path}[*]" if path else "[*]"
if is_file_path_property(items_schema_dict):
_normalize_file_path_schema(items_schema_dict)
file_path_fields.append(array_path)
else:
_adapt_schema_in_place(items_schema_dict, array_path, file_path_fields)
def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
schema["type"] = "string"
schema["format"] = FILE_PATH_FORMAT
description = schema.get("description", "")
if description:
if FILE_PATH_DESCRIPTION_SUFFIX not in description:
schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
else:
schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX
def _deep_copy_value(value: Any) -> Any:
if isinstance(value, Mapping):
mapping = cast(Mapping[str, Any], value)
return {key: _deep_copy_value(item) for key, item in mapping.items()}
if isinstance(value, list):
list_value = cast(list[Any], value)
return [_deep_copy_value(item) for item in list_value]
return value
def _convert_path_in_place(
obj: dict[str, Any],
path_parts: list[str],
file_resolver: Callable[[str], File],
files: list[File],
) -> None:
if not path_parts:
return
current = path_parts[0]
remaining = path_parts[1:]
if current.endswith("[*]"):
key = current[:-3] if current != "[*]" else ""
target_value = obj.get(key) if key else obj
if isinstance(target_value, list):
target_list = cast(list[Any], target_value)
if remaining:
for item in target_list:
if isinstance(item, dict):
item_dict = cast(dict[str, Any], item)
_convert_path_in_place(item_dict, remaining, file_resolver, files)
else:
resolved_files: list[File] = []
for item in target_list:
if not isinstance(item, str):
raise ValueError("File path must be a string")
file = file_resolver(item)
files.append(file)
resolved_files.append(file)
if key:
obj[key] = ArrayFileSegment(value=resolved_files)
return
if not remaining:
if current not in obj:
return
value = obj[current]
if value is None:
obj[current] = None
return
if not isinstance(value, str):
raise ValueError("File path must be a string")
file = file_resolver(value)
files.append(file)
obj[current] = FileSegment(value=file)
return
if current in obj and isinstance(obj[current], dict):
_convert_path_in_place(obj[current], remaining, file_resolver, files)
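# A small usage sketch with the functions above in scope (no import shown, since
# this module's path is not visible here); the schema and field names are illustrative:
schema = {
    "type": "object",
    "properties": {
        "report": {"type": "string", "format": "file-path", "description": "Summary file"},
        "charts": {"type": "array", "items": {"type": "string", "format": "FILE_PATH"}},
    },
}
assert detect_file_path_fields(schema) == ["report", "charts[*]"]
adapted, fields = adapt_schema_for_sandbox_file_paths(schema)
# The adapted copy normalizes the format value and appends the sandbox-path hint to
# each matching description; the original schema is left untouched.
assert fields == ["report", "charts[*]"]
assert FILE_PATH_DESCRIPTION_SUFFIX in adapted["properties"]["report"]["description"]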

View File

@ -0,0 +1,45 @@
"""Utility functions for LLM generator."""
from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
"""
Deserialize list of dicts to list[PromptMessage].
Expected format:
[
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."},
]
"""
result: list[PromptMessage] = []
for msg in messages:
role = PromptMessageRole.value_of(msg["role"])
content = msg.get("content", "")
match role:
case PromptMessageRole.USER:
result.append(UserPromptMessage(content=content))
case PromptMessageRole.ASSISTANT:
result.append(AssistantPromptMessage(content=content))
case PromptMessageRole.SYSTEM:
result.append(SystemPromptMessage(content=content))
case PromptMessageRole.TOOL:
result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))
return result
def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
"""
Serialize list[PromptMessage] to list of dicts.
"""
return [{"role": msg.role.value, "content": msg.content} for msg in messages]
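# Round-trip sketch with the two helpers above in scope (module path omitted);
# assumes the role enum serializes to the plain strings "user"/"assistant":
history = [
    {"role": "user", "content": "What does this workflow do?"},
    {"role": "assistant", "content": "It extracts skills from the app."},
]
messages = deserialize_prompt_messages(history)
assert serialize_prompt_messages(messages) == history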

View File

@ -0,0 +1,11 @@
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import (
NodeTokenBufferMemory,
)
from core.memory.token_buffer_memory import TokenBufferMemory
__all__ = [
"BaseMemory",
"NodeTokenBufferMemory",
"TokenBufferMemory",
]

api/core/memory/base.py (new file, 82 lines)
View File

@ -0,0 +1,82 @@
"""
Base memory interfaces and types.
This module defines the common protocol for memory implementations.
"""
from abc import ABC, abstractmethod
from collections.abc import Sequence
from graphon.model_runtime.entities import ImagePromptMessageContent, PromptMessage
class BaseMemory(ABC):
"""
Abstract base class for memory implementations.
Provides a common interface for both conversation-level and node-level memory.
"""
@abstractmethod
def get_history_prompt_messages(
self,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: Maximum tokens for history
:param message_limit: Maximum number of messages
:return: Sequence of PromptMessage for LLM context
"""
pass
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> str:
"""
Get history prompt as formatted text.
:param human_prefix: Prefix for human messages
:param ai_prefix: Prefix for assistant messages
:param max_token_limit: Maximum tokens for history
:param message_limit: Maximum number of messages
:return: Formatted history text
"""
from graphon.model_runtime.entities import (
PromptMessageRole,
TextPromptMessageContent,
)
prompt_messages = self.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=message_limit,
)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)

View File

@ -0,0 +1,196 @@
"""
Node-level Token Buffer Memory for Chatflow.
This module provides node-scoped memory within a conversation.
Each LLM node in a workflow can maintain its own independent conversation history.
Note: This is only available in Chatflow (advanced-chat mode) because it requires
both conversation_id and node_id.
Design:
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
- No separate storage needed - the context is already saved during node execution
- Thread tracking leverages Message table's parent_message_id structure
"""
import logging
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from graphon.file import file_manager
from graphon.model_runtime.entities import (
AssistantPromptMessage,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from extensions.ext_database import db
from models.model import Message
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
class NodeTokenBufferMemory(BaseMemory):
"""
Node-level Token Buffer Memory.
Provides node-scoped memory within a conversation. Each LLM node can maintain
its own independent conversation history.
Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
which is already saved during node execution. No separate storage needed.
"""
def __init__(
self,
app_id: str,
conversation_id: str,
node_id: str,
tenant_id: str,
model_instance: ModelInstance,
):
self.app_id = app_id
self.conversation_id = conversation_id
self.node_id = node_id
self.tenant_id = tenant_id
self.model_instance = model_instance
def _get_thread_workflow_run_ids(self) -> list[str]:
"""
Get workflow_run_ids for the current thread by querying Message table.
Returns workflow_run_ids in chronological order (oldest first).
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(Message)
.where(Message.conversation_id == self.conversation_id)
.order_by(Message.created_at.desc())
.limit(500)
)
messages = list(session.scalars(stmt).all())
if not messages:
return []
# Extract thread messages using existing logic
thread_messages = extract_thread_messages(messages)
# For a newly created message, the answer is temporarily empty; skip it
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0)
# Reverse to get chronological order, extract workflow_run_ids
return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]
def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
"""Deserialize a dict to PromptMessage based on role."""
role = msg_dict.get("role")
if role in (PromptMessageRole.USER, "user"):
return UserPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
return AssistantPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.SYSTEM, "system"):
return SystemPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.TOOL, "tool"):
return ToolPromptMessage.model_validate(msg_dict)
else:
return PromptMessage.model_validate(msg_dict)
def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
"""Deserialize context data from outputs to list of PromptMessage."""
messages = []
for msg_dict in context_data:
try:
msg = self._deserialize_prompt_message(msg_dict)
msg = self._restore_multimodal_content(msg)
messages.append(msg)
except Exception as e:
logger.warning("Failed to deserialize prompt message: %s", e)
return messages
def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
"""
Restore multimodal content (base64 or url) from file_ref.
When context is saved, base64_data is cleared to save storage space.
This method restores the content by parsing file_ref (format: "method:id_or_url").
"""
content = message.content
if content is None or isinstance(content, str):
return message
# Process list content, restoring multimodal data from file references
restored_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
# restore_multimodal_content preserves the concrete subclass type
restored_item = file_manager.restore_multimodal_content(item)
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
else:
restored_content.append(item)
return message.model_copy(update={"content": restored_content})
def get_history_prompt_messages(
self,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
History is read directly from the last completed node execution's outputs["context"].
"""
_ = message_limit # unused, kept for interface compatibility
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
if not thread_workflow_run_ids:
return []
# Get the last completed workflow_run_id (contains accumulated context)
last_run_id = thread_workflow_run_ids[-1]
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
WorkflowNodeExecutionModel.node_id == self.node_id,
WorkflowNodeExecutionModel.status == "succeeded",
)
execution = session.scalars(stmt).first()
if not execution:
return []
outputs = execution.outputs_dict
if not outputs:
return []
context_data = outputs.get("context")
if not context_data or not isinstance(context_data, list):
return []
prompt_messages = self._deserialize_context(context_data)
if not prompt_messages:
return []
# Truncate by token limit
try:
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
while current_tokens > max_token_limit and len(prompt_messages) > 1:
prompt_messages.pop(0)
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
except Exception as e:
logger.warning("Failed to count tokens for truncation: %s", e)
return prompt_messages

View File

@ -0,0 +1,11 @@
from .cli_api import CliApiSession, CliApiSessionManager
from .session import BaseSession, RedisSessionStorage, SessionManager, SessionStorage
__all__ = [
"BaseSession",
"CliApiSession",
"CliApiSessionManager",
"RedisSessionStorage",
"SessionManager",
"SessionStorage",
]

View File

@ -0,0 +1,30 @@
import secrets
from pydantic import BaseModel, Field
from configs import dify_config
from core.skill.entities import ToolAccessPolicy
from .session import BaseSession, SessionManager
class CliApiSession(BaseSession):
secret: str = Field(default_factory=lambda: secrets.token_urlsafe(32))
class CliContext(BaseModel):
tool_access: ToolAccessPolicy | None = Field(default=None, description="Tool access policy")
class CliApiSessionManager(SessionManager[CliApiSession]):
def __init__(self, ttl: int | None = None):
super().__init__(
key_prefix="cli_api_session",
session_class=CliApiSession,
ttl=ttl or dify_config.WORKFLOW_MAX_EXECUTION_TIME,
)
def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession:
session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json"))
self.save(session)
return session
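# Usage sketch for the manager above; requires the Redis client configured via
# extensions.ext_redis, and the tenant/user ids are placeholders:
from core.session.cli_api import CliApiSessionManager, CliContext

manager = CliApiSessionManager(ttl=3600)
session = manager.create(tenant_id="tenant-1", user_id="user-1", context=CliContext())
restored = manager.get(session.id)
assert restored is not None and restored.secret == session.secret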

api/core/session/session.py (new file, 106 lines)
View File

@ -0,0 +1,106 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Any, Generic, Protocol, TypeVar
from pydantic import BaseModel, Field, ValidationError
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class SessionStorage(Protocol):
"""Session storage interface."""
def get(self, key: str) -> str | None: ...
def set(self, key: str, value: str, ttl: int) -> None: ...
def delete(self, key: str) -> bool: ...
def exists(self, key: str) -> bool: ...
def refresh_ttl(self, key: str, ttl: int) -> bool: ...
class RedisSessionStorage:
"""Redis storage implementation (default)."""
def get(self, key: str) -> str | None:
result = redis_client.get(key)
if result is None:
return None
return result.decode() if isinstance(result, bytes) else result
def set(self, key: str, value: str, ttl: int) -> None:
redis_client.setex(key, ttl, value)
def delete(self, key: str) -> bool:
return redis_client.delete(key) > 0
def exists(self, key: str) -> bool:
return redis_client.exists(key) > 0
def refresh_ttl(self, key: str, ttl: int) -> bool:
return bool(redis_client.expire(key, ttl))
class BaseSession(BaseModel):
"""Base session model."""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
tenant_id: str
user_id: str
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
context: dict[str, Any] = Field(default_factory=dict)
def update_timestamp(self) -> None:
self.updated_at = datetime.now(UTC)
T = TypeVar("T", bound=BaseSession)
class SessionManager(Generic[T]):
"""Generic session manager."""
DEFAULT_TTL = 7200 # 2 hours
def __init__(
self,
key_prefix: str,
session_class: type[T],
storage: SessionStorage | None = None,
ttl: int | None = None,
):
self._key_prefix = key_prefix
self._session_class = session_class
self._storage = storage or RedisSessionStorage()
self._ttl = ttl or self.DEFAULT_TTL
def _get_key(self, session_id: str) -> str:
return f"{self._key_prefix}:{session_id}"
def save(self, session: T) -> None:
session.update_timestamp()
key = self._get_key(session.id)
self._storage.set(key, session.model_dump_json(), self._ttl)
def get(self, session_id: str) -> T | None:
key = self._get_key(session_id)
data = self._storage.get(key)
if data is None:
return None
try:
return self._session_class.model_validate(json.loads(data))
except (json.JSONDecodeError, ValidationError) as e:
logger.warning("Failed to deserialize session %s: %s", session_id, e)
return None
def delete(self, session_id: str) -> bool:
return self._storage.delete(self._get_key(session_id))
def exists(self, session_id: str) -> bool:
return self._storage.exists(self._get_key(session_id))
def refresh_ttl(self, session_id: str) -> bool:
return self._storage.refresh_ttl(self._get_key(session_id), self._ttl)
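# Sketch of defining a domain-specific session on top of the generic manager
# (UploadSession and the "upload_session" prefix are illustrative names; Redis must
# be reachable since RedisSessionStorage is the default backend):
from core.session import BaseSession, SessionManager


class UploadSession(BaseSession):
    filename: str = ""


manager = SessionManager(key_prefix="upload_session", session_class=UploadSession, ttl=600)
session = UploadSession(tenant_id="tenant-1", user_id="user-1", filename="report.pdf")
manager.save(session)
assert manager.exists(session.id)
restored = manager.get(session.id)
assert restored is not None and restored.filename == "report.pdf"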

View File

@ -0,0 +1,187 @@
import base64
import hashlib
import logging
from collections.abc import Mapping
from typing import Any
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from pydantic import TypeAdapter
from configs import dify_config
logger = logging.getLogger(__name__)
class EncryptionError(Exception):
"""Encryption/decryption specific error"""
pass
class SystemEncrypter:
"""
A simple parameter encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: str | None = None):
"""
Initialize the encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Raises:
ValueError: If SECRET_KEY is not configured or empty
"""
secret_key = secret_key or dify_config.SECRET_KEY or ""
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_params(self, params: Mapping[str, Any]) -> str:
"""
Encrypt parameters.
Args:
params: parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
EncryptionError: If encryption fails
ValueError: If params is invalid
"""
try:
# Generate random IV (16 bytes)
iv = get_random_bytes(16)
# Create AES cipher (CBC mode)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
combined = iv + encrypted_data
# Return base64 encoded string
return base64.b64encode(combined).decode()
except Exception as e:
raise EncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted parameters dictionary
Raises:
EncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
raise ValueError("encrypted_data must be a string")
if not encrypted_data:
raise ValueError("encrypted_data cannot be empty")
try:
# Base64 decode
combined = base64.b64decode(encrypted_data)
# Check minimum length (IV + at least one AES block)
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
raise ValueError("Invalid encrypted data format")
# Separate IV and encrypted data
iv = combined[:16]
encrypted_data_bytes = combined[16:]
# Create AES cipher
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Decrypt data
decrypted_data = cipher.decrypt(encrypted_data_bytes)
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return params
except Exception as e:
raise EncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
"""
Create an encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemEncrypter instance
"""
return SystemEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_encrypter: SystemEncrypter | None = None
def get_system_encrypter() -> SystemEncrypter:
"""
Get the global encrypter instance.
Returns:
SystemEncrypter instance
"""
global _encrypter
if _encrypter is None:
_encrypter = SystemEncrypter()
return _encrypter
# Convenience functions for backward compatibility
def encrypt_system_params(params: Mapping[str, Any]) -> str:
"""
Encrypt parameters using the global encrypter.
Args:
params: parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_encrypter().encrypt_params(params)
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted parameters dictionary
"""
return get_system_encrypter().decrypt_params(encrypted_data)
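# Round-trip sketch for the convenience functions above; the derived AES key comes
# from dify_config.SECRET_KEY, so SECRET_KEY must be configured:
from core.tools.utils.system_encryption import decrypt_system_params, encrypt_system_params

token = encrypt_system_params({"client_id": "abc", "client_secret": "xyz"})
assert decrypt_system_params(token) == {"client_id": "abc", "client_secret": "xyz"}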

View File

@ -0,0 +1,3 @@
from .node import CommandNode
__all__ = ["CommandNode"]

View File

@ -0,0 +1,10 @@
from graphon.entities.base_node_data import BaseNodeData
class CommandNodeData(BaseNodeData):
"""
Command Node Data.
"""
working_directory: str = "" # Working directory for command execution
command: str = "" # Command to execute

View File

@ -0,0 +1,16 @@
class CommandNodeError(ValueError):
"""Base class for command node errors."""
pass
class CommandExecutionError(CommandNodeError):
"""Raised when command execution fails."""
pass
class CommandTimeoutError(CommandNodeError):
"""Raised when command execution times out."""
pass

View File

@ -0,0 +1,152 @@
import logging
from collections.abc import Mapping, Sequence
from typing import Any
from core.sandbox import sandbox_debug
from core.sandbox.bash.session import SANDBOX_READY_TIMEOUT
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.helpers import submit_command, with_connection
from core.workflow.nodes.command.entities import CommandNodeData
from core.workflow.nodes.command.exc import CommandExecutionError
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult
from graphon.nodes.base import variable_template_parser
from graphon.nodes.base.entities import VariableSelector
from graphon.nodes.base.node import Node
from graphon.nodes.base.variable_template_parser import VariableTemplateParser
logger = logging.getLogger(__name__)
# FIXME(Mairuis): The timeout value is currently hardcoded and should be made configurable in the future.
COMMAND_NODE_TIMEOUT_SECONDS = 60 * 10
class CommandNode(Node[CommandNodeData]):
node_type = BuiltinNodeTypes.COMMAND
def _render_template(self, template: str) -> str:
parser = VariableTemplateParser(template=template)
selectors = parser.extract_variable_selectors()
if not selectors:
return template
inputs: dict[str, Any] = {}
for selector in selectors:
value = self.graph_runtime_state.variable_pool.get(selector.value_selector)
inputs[selector.variable] = value.to_object() if value is not None else None
return parser.format(inputs)
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
"type": "command",
"config": {
"working_directory": "",
"command": "",
},
}
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
sandbox = self.graph_runtime_state.sandbox
if sandbox is None:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="Sandbox not available for CommandNode.",
error_type="SandboxNotInitializedError",
)
working_directory = self._render_template((self.node_data.working_directory or "").strip())
raw_command = self._render_template(self.node_data.command or "")
working_directory = working_directory or None
if not raw_command:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="Command is required.",
error_type="CommandNodeError",
)
try:
sandbox.wait_ready(timeout=SANDBOX_READY_TIMEOUT)
with with_connection(sandbox.vm) as conn:
command = ["bash", "-c", raw_command]
sandbox_debug("command_node", "command", command)
future = submit_command(sandbox.vm, conn, command, cwd=working_directory)
result = future.result(timeout=COMMAND_NODE_TIMEOUT_SECONDS)
outputs: dict[str, Any] = {
"stdout": result.stdout.decode("utf-8", errors="replace"),
"stderr": result.stderr.decode("utf-8", errors="replace"),
"exit_code": result.exit_code,
"pid": result.pid,
}
process_data = {"command": command, "working_directory": working_directory}
sandbox_debug("command_node", "outputs", result.debug_message)
if result.exit_code not in (None, 0):
stderr_text = result.stderr.decode("utf-8", errors="replace")
error_message = f"{stderr_text}\n\nCommand exited with code {result.exit_code}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs=outputs,
process_data=process_data,
error=error_message,
error_type=CommandExecutionError.__name__,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
process_data=process_data,
)
except CommandTimeoutError:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Command timed out after {COMMAND_NODE_TIMEOUT_SECONDS}s",
error_type=CommandTimeoutError.__name__,
)
except CommandCancelledError:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="Command was cancelled",
error_type=CommandCancelledError.__name__,
)
except Exception as e:
logger.exception("Command node %s failed", self.id)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
error_type=type(e).__name__,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CommandNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config
typed_node_data = node_data
selectors: list[VariableSelector] = []
selectors += list(variable_template_parser.extract_selectors_from_template(typed_node_data.command))
selectors += list(variable_template_parser.extract_selectors_from_template(typed_node_data.working_directory))
mapping: dict[str, Sequence[str]] = {}
for selector in selectors:
mapping[node_id + "." + selector.variable] = selector.value_selector
return mapping
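# A configuration sketch mirroring get_default_config above (values are illustrative);
# command and working_directory may contain variable templates that _render_template
# resolves from the variable pool before execution:
command_node_config = {
    "type": "command",
    "config": {
        "working_directory": "/workspace",
        "command": "ls -la",
    },
}
# On success the node exposes stdout, stderr, exit_code and pid as outputs; a
# non-zero exit code is surfaced as a failed run carrying the stderr text.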

View File

@ -0,0 +1,4 @@
from .entities import FileUploadNodeData
from .node import FileUploadNode
__all__ = ["FileUploadNode", "FileUploadNodeData"]

View File

@ -0,0 +1,7 @@
from collections.abc import Sequence
from graphon.entities.base_node_data import BaseNodeData
class FileUploadNodeData(BaseNodeData):
variable_selector: Sequence[str]

View File

@ -0,0 +1,6 @@
class FileUploadNodeError(ValueError):
"""Base exception for errors related to the FileUploadNode."""
class FileUploadDownloadError(FileUploadNodeError):
"""Exception raised when preparing file download in sandbox fails."""

View File

@ -0,0 +1,244 @@
import logging
import os
import posixpath
from collections.abc import Mapping, Sequence
from pathlib import PurePosixPath
from typing import Any, cast
from core.sandbox.bash.session import SANDBOX_READY_TIMEOUT
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.helpers import pipeline
from core.zip_sandbox import SandboxDownloadItem
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.file import File, FileTransferMethod
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
from graphon.variables import ArrayFileSegment
from graphon.variables.segments import ArrayStringSegment, FileSegment
from .entities import FileUploadNodeData
from .exc import FileUploadDownloadError, FileUploadNodeError
logger = logging.getLogger(__name__)
class FileUploadNode(Node[FileUploadNodeData]):
"""Upload workflow file variables into sandbox via presigned URLs.
The node intentionally avoids streaming file bytes through Dify workers. For local/tool
files, it generates storage-backed presigned URLs and lets the sandbox download them directly.
"""
node_type = BuiltinNodeTypes.FILE_UPLOAD
@classmethod
def version(cls) -> str:
return "1"
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
_ = filters
return {
"type": "file-upload",
"config": {
"variable_selector": [],
},
}
def _run(self) -> NodeRunResult:
sandbox = self.graph_runtime_state.sandbox
variable_selector = self.node_data.variable_selector
inputs: dict[str, Any] = {"variable_selector": variable_selector}
if sandbox is None:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="Sandbox not available for FileUploadNode.",
error_type="SandboxNotInitializedError",
inputs=inputs,
)
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"File variable not found for selector: {variable_selector}",
error_type=FileUploadNodeError.__name__,
inputs=inputs,
)
if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Variable {variable_selector} is not a file or file array",
error_type=FileUploadNodeError.__name__,
inputs=inputs,
)
files = self._normalize_files(variable.value)
process_data: dict[str, Any] = {
"file_count": len(files),
"files": [file.to_dict() for file in files],
}
if not files:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error="Selected file variable is empty.",
error_type=FileUploadNodeError.__name__,
inputs=inputs,
process_data=process_data,
)
try:
sandbox.wait_ready(timeout=SANDBOX_READY_TIMEOUT)
download_items: list[SandboxDownloadItem] = self._build_download_items(files)
sandbox_paths = self._upload(sandbox.vm, download_items)
file_names = [PurePosixPath(path).name for path in sandbox_paths]
process_data = {
**process_data,
"sandbox_paths": sandbox_paths,
"file_names": file_names,
}
outputs: dict[str, Any]
if len(sandbox_paths) == 1:
outputs = {
"sandbox_path": sandbox_paths[0],
"file_name": file_names[0],
}
else:
outputs = {
"sandbox_path": ArrayStringSegment(value=sandbox_paths),
"file_name": ArrayStringSegment(value=file_names),
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
except CommandTimeoutError:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="File upload timeout",
error_type=CommandTimeoutError.__name__,
inputs=inputs,
process_data=process_data,
)
except CommandCancelledError:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error="File upload command was cancelled",
error_type=CommandCancelledError.__name__,
inputs=inputs,
process_data=process_data,
)
except FileUploadNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
error_type=type(e).__name__,
inputs=inputs,
process_data=process_data,
)
except Exception as e:
logger.exception("File upload node %s failed", self.id)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
error_type=type(e).__name__,
inputs=inputs,
process_data=process_data,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: FileUploadNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config
typed_node_data = node_data
return {node_id + ".files": typed_node_data.variable_selector}
@staticmethod
def _normalize_files(value: Any) -> list[File]:
if isinstance(value, File):
return [value]
if isinstance(value, list):
list_value = cast(list[object], value)
files: list[File] = []
for idx in range(len(list_value)):
candidate = list_value[idx]
if not isinstance(candidate, File):
return []
files.append(candidate)
return files
return []
def _build_download_items(self, files: Sequence[File]) -> list[SandboxDownloadItem]:
used_paths: set[str] = set()
items: list[SandboxDownloadItem] = []
for index, file in enumerate(files):
file_url = self._get_download_url(file)
filename = (file.filename or "").strip()
if not filename or filename in {".", ".."}:
filename = f"file-{index + 1}{file.extension or ''}"
filename = os.path.basename(filename)
if filename in used_paths:
stem = PurePosixPath(filename).stem or f"file-{index + 1}"
suffix = PurePosixPath(filename).suffix
dedupe = 1
while filename in used_paths:
filename = f"{stem}_{dedupe}{suffix}"
dedupe += 1
used_paths.add(filename)
items.append(SandboxDownloadItem(path=filename, url=file_url))
return items
@staticmethod
def _normalize_path(path: str) -> str:
normalized = posixpath.normpath(path.strip()) if path else "."
if normalized.startswith("/"):
normalized = normalized.lstrip("/")
return normalized or "."
def _upload(self, vm: Any, items: list[SandboxDownloadItem]) -> list[str]:
p = pipeline(vm)
out_paths: list[str] = []
for item in items:
out_path = self._normalize_path(item.path)
if out_path in ("", "."):
raise FileUploadDownloadError("Download item path must point to a file")
out_paths.append(out_path)
p.add(["curl", "-fsSL", item.url, "-o", out_path], error_message="Failed to download file")
try:
p.execute(timeout=None, raise_on_error=True)
except Exception as exc:
raise FileUploadDownloadError(str(exc)) from exc
return out_paths
def _get_download_url(self, file: File) -> str:
if file.transfer_method == FileTransferMethod.REMOTE_URL:
if not file.remote_url:
raise FileUploadDownloadError("Remote file URL is missing")
return file.remote_url
if file.transfer_method in (
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.TOOL_FILE,
FileTransferMethod.DATASOURCE_FILE,
):
download_url = file.generate_url(for_external=True)
if not download_url:
raise FileUploadDownloadError("Unable to generate download URL for file")
return download_url
raise FileUploadDownloadError(f"Unsupported file transfer method: {file.transfer_method}")
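# A configuration sketch mirroring get_default_config above; variable_selector points
# at an upstream file (or file array) variable, here the illustrative selector
# ["start", "files"]:
file_upload_config = {
    "type": "file-upload",
    "config": {
        "variable_selector": ["start", "files"],
    },
}
# Outputs are sandbox_path / file_name: plain strings for a single file, array
# segments when several files are uploaded.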

View File

@ -0,0 +1,23 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from .entities import SandboxDownloadItem, SandboxFile, SandboxUploadItem
if TYPE_CHECKING:
from .zip_sandbox import ZipSandbox
__all__ = [
"SandboxDownloadItem",
"SandboxFile",
"SandboxUploadItem",
"ZipSandbox",
]
def __getattr__(name: str):
if name == "ZipSandbox":
from .zip_sandbox import ZipSandbox
return ZipSandbox
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -0,0 +1,81 @@
from __future__ import annotations
import posixpath
from typing import TYPE_CHECKING
from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.helpers import execute, try_execute
from .strategy import ZipStrategy
if TYPE_CHECKING:
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
class CliZipStrategy(ZipStrategy):
"""Strategy using native zip/unzip CLI commands."""
def is_available(self, vm: VirtualEnvironment) -> bool:
result = try_execute(vm, ["which", "zip"], timeout=10)
has_zip = bool(result.stdout and result.stdout.strip())
result = try_execute(vm, ["which", "unzip"], timeout=10)
has_unzip = bool(result.stdout and result.stdout.strip())
return has_zip and has_unzip
def zip(
self,
vm: VirtualEnvironment,
*,
src: str,
out_path: str,
cwd: str | None,
timeout: float,
) -> None:
if src in (".", ""):
result = try_execute(vm, ["zip", "-qr", out_path, "."], timeout=timeout, cwd=cwd)
if not result.is_error:
return
# zip exits with 12 when there is nothing to do; create empty zip
if result.exit_code == 12:
self._write_empty_zip(vm, out_path)
return
raise CommandExecutionError("Failed to create zip archive", result)
zip_cwd = posixpath.dirname(src) or "."
target = posixpath.basename(src)
result = try_execute(vm, ["zip", "-qr", out_path, target], timeout=timeout, cwd=zip_cwd)
if not result.is_error:
return
if result.exit_code == 12:
self._write_empty_zip(vm, out_path)
return
raise CommandExecutionError("Failed to create zip archive", result)
def unzip(
self,
vm: VirtualEnvironment,
*,
archive_path: str,
dest_dir: str,
timeout: float,
) -> None:
execute(
vm,
["unzip", "-q", archive_path, "-d", dest_dir],
timeout=timeout,
error_message="Failed to unzip archive",
)
def _write_empty_zip(self, vm: VirtualEnvironment, out_path: str) -> None:
"""Write an empty but valid zip file (a bare 22-byte end-of-central-directory record)."""
script = (
'printf "'
"\\x50\\x4b\\x05\\x06"  # EOCD signature "PK\x05\x06"
"\\x00\\x00\\x00\\x00"
"\\x00\\x00\\x00\\x00"
"\\x00\\x00\\x00\\x00"
"\\x00\\x00\\x00\\x00"
"\\x00\\x00"  # 18 zero bytes in total: entry counts, sizes, offset, comment length
'" > "$1"'
)
execute(vm, ["sh", "-c", script, "sh", out_path], timeout=30, error_message="Failed to write empty zip")
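# Sanity-check sketch that the bytes written above form a valid empty archive: the
# 22-byte end-of-central-directory record is all a zip needs when it has no entries.
import io
import zipfile

empty_zip = b"PK\x05\x06" + b"\x00" * 18
with zipfile.ZipFile(io.BytesIO(empty_zip)) as zf:
    assert zf.namelist() == []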

View File

@ -0,0 +1,39 @@
"""Data classes for ZipSandbox file operations.
Separated from ``zip_sandbox.py`` so that lightweight consumers (tests,
shell-script builders) can import the types without pulling in the full
sandbox provider chain.
"""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass(frozen=True)
class SandboxDownloadItem:
"""Unified download/inline item for sandbox file operations.
For remote files, *url* is set and the item is fetched via ``curl``.
For inline content, *content* is set and the bytes are written directly
into the VM via ``upload_file``, with no network round-trip.
"""
path: str
url: str = ""
content: bytes | None = field(default=None, repr=False)
@dataclass(frozen=True)
class SandboxUploadItem:
"""Item for uploading: sandbox path -> URL."""
path: str
url: str
@dataclass(frozen=True)
class SandboxFile:
"""A handle to a file in the sandbox."""
path: str
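# Construction sketch for the item types above: a remote item fetched with curl and
# an inline item whose bytes are written straight into the VM (URL and paths are
# illustrative):
remote = SandboxDownloadItem(path="data/input.csv", url="https://example.com/input.csv")
inline = SandboxDownloadItem(path="config.json", content=b'{"debug": true}')
archive = SandboxFile(path="/tmp/output.zip")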

View File

@ -0,0 +1,106 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from core.virtual_environment.__base.helpers import execute, try_execute
from .strategy import ZipStrategy
if TYPE_CHECKING:
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
ZIP_SCRIPT = r"""
const fs = require('fs');
const path = require('path');
const AdmZip = require('adm-zip');
const src = process.argv[2];
const outPath = process.argv[3];
function walkAdd(zip, absPath, arcPrefix) {
const stat = fs.statSync(absPath);
if (stat.isDirectory()) {
const entries = fs.readdirSync(absPath);
if (entries.length === 0) {
zip.addFile(arcPrefix.replace(/\\/g, '/') + '/', Buffer.alloc(0));
return;
}
for (const e of entries) {
walkAdd(zip, path.join(absPath, e), path.posix.join(arcPrefix, e));
}
return;
}
if (stat.isFile()) {
const data = fs.readFileSync(absPath);
zip.addFile(arcPrefix.replace(/\\/g, '/'), data);
}
}
const zip = new AdmZip();
if (src === '.' || src === '') {
const entries = fs.readdirSync('.');
for (const e of entries) {
walkAdd(zip, path.join('.', e), e);
}
} else {
const base = path.dirname(src) || '.';
const prefix = path.basename(src.replace(/\/+$/, ''));
const root = path.join(base, prefix);
walkAdd(zip, root, prefix);
}
zip.writeZip(outPath);
"""
UNZIP_SCRIPT = r"""
const AdmZip = require('adm-zip');
const archivePath = process.argv[2];
const destDir = process.argv[3];
const zip = new AdmZip(archivePath);
zip.extractAllTo(destDir, true);
"""
class NodeZipStrategy(ZipStrategy):
"""Strategy using Node.js with adm-zip package."""
def is_available(self, vm: VirtualEnvironment) -> bool:
result = try_execute(vm, ["which", "node"], timeout=10)
if not (result.stdout and result.stdout.strip()):
return False
# Check if adm-zip module is available
result = try_execute(vm, ["node", "-e", "require('adm-zip')"], timeout=10)
return not result.is_error
def zip(
self,
vm: VirtualEnvironment,
*,
src: str,
out_path: str,
cwd: str | None,
timeout: float,
) -> None:
execute(
vm,
["node", "-e", ZIP_SCRIPT, src, out_path],
timeout=timeout,
cwd=cwd,
error_message="Failed to create zip archive",
)
def unzip(
self,
vm: VirtualEnvironment,
*,
archive_path: str,
dest_dir: str,
timeout: float,
) -> None:
execute(
vm,
["node", "-e", UNZIP_SCRIPT, archive_path, dest_dir],
timeout=timeout,
error_message="Failed to unzip archive",
)

View File

@ -0,0 +1,117 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from core.virtual_environment.__base.helpers import execute, try_execute
from .strategy import ZipStrategy
if TYPE_CHECKING:
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
ZIP_SCRIPT = r"""
import os
import sys
import zipfile
src = sys.argv[1]
out_path = sys.argv[2]
def is_cwd(p: str) -> bool:
return p in (".", "")
src = src.rstrip("/")
if is_cwd(src):
base = "."
root = "."
prefix = ""
else:
base = os.path.dirname(src) or "."
prefix = os.path.basename(src)
root = os.path.join(base, prefix)
def add_empty_dir(zf: zipfile.ZipFile, arc_dir: str) -> None:
name = arc_dir.rstrip("/") + "/"
if name != "/":
zf.writestr(name, b"")
with zipfile.ZipFile(out_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
if os.path.isfile(root):
zf.write(root, arcname=os.path.basename(root))
else:
for dirpath, dirnames, filenames in os.walk(root):
rel_dir = os.path.relpath(dirpath, base)
rel_dir = "" if rel_dir == "." else rel_dir
if not dirnames and not filenames:
add_empty_dir(zf, rel_dir)
for fn in filenames:
fp = os.path.join(dirpath, fn)
arcname = os.path.join(rel_dir, fn) if rel_dir else fn
zf.write(fp, arcname=arcname)
"""
UNZIP_SCRIPT = r"""
import sys
import zipfile
archive_path = sys.argv[1]
dest_dir = sys.argv[2]
with zipfile.ZipFile(archive_path, "r") as zf:
zf.extractall(dest_dir)
"""
class PythonZipStrategy(ZipStrategy):
"""Strategy using Python's zipfile module."""
def __init__(self) -> None:
self._python_cmd: str | None = None
def is_available(self, vm: VirtualEnvironment) -> bool:
for cmd in ("python3", "python"):
result = try_execute(vm, ["which", cmd], timeout=10)
if result.stdout and result.stdout.strip():
self._python_cmd = cmd
return True
return False
def zip(
self,
vm: VirtualEnvironment,
*,
src: str,
out_path: str,
cwd: str | None,
timeout: float,
) -> None:
if self._python_cmd is None:
raise RuntimeError("Python not available")
execute(
vm,
[self._python_cmd, "-c", ZIP_SCRIPT, src, out_path],
timeout=timeout,
cwd=cwd,
error_message="Failed to create zip archive",
)
def unzip(
self,
vm: VirtualEnvironment,
*,
archive_path: str,
dest_dir: str,
timeout: float,
) -> None:
if self._python_cmd is None:
raise RuntimeError("Python not available")
execute(
vm,
[self._python_cmd, "-c", UNZIP_SCRIPT, archive_path, dest_dir],
timeout=timeout,
error_message="Failed to unzip archive",
)

View File

@ -0,0 +1,41 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
class ZipStrategy(ABC):
"""Abstract base class for zip/unzip strategies."""
@abstractmethod
def is_available(self, vm: VirtualEnvironment) -> bool:
"""Check if this strategy is available in the given VM."""
...
@abstractmethod
def zip(
self,
vm: VirtualEnvironment,
*,
src: str,
out_path: str,
cwd: str | None,
timeout: float,
) -> None:
"""Create a zip archive."""
...
@abstractmethod
def unzip(
self,
vm: VirtualEnvironment,
*,
archive_path: str,
dest_dir: str,
timeout: float,
) -> None:
"""Extract a zip archive."""
...

View File

@ -0,0 +1,425 @@
from __future__ import annotations
import base64
import posixpath
import shlex
from io import BytesIO
from pathlib import PurePosixPath
from types import TracebackType
from typing import Any
from urllib.parse import urlparse
from uuid import uuid4
from core.sandbox.builder import SandboxBuilder
from core.sandbox.entities.sandbox_type import SandboxType
from core.sandbox.sandbox import Sandbox
from core.sandbox.storage.noop_storage import NoopSandboxStorage
from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError
from core.virtual_environment.__base.helpers import execute, pipeline
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from services.sandbox.sandbox_provider_service import SandboxProviderService
from .cli_strategy import CliZipStrategy
from .entities import SandboxDownloadItem, SandboxFile, SandboxUploadItem
from .node_strategy import NodeZipStrategy
from .python_strategy import PythonZipStrategy
from .strategy import ZipStrategy
class ZipSandbox:
"""A sandbox for archive (zip) operations.
Usage:
with ZipSandbox(tenant_id=..., user_id=...) as zs:
zs.download_items(items)
archive = zs.zip()
zs.upload(archive, upload_url)
# VM automatically released on exit
"""
_DEFAULT_TIMEOUT_SECONDS = 60 * 5
_STRATEGIES: list[ZipStrategy] = [CliZipStrategy(), PythonZipStrategy(), NodeZipStrategy()]
def __init__(
self,
*,
tenant_id: str | None = None,
user_id: str | None = None,
app_id: str = "zip-sandbox",
sandbox_provider_type: str | None = None,
sandbox_provider_options: dict[str, Any] | None = None,
_vm: VirtualEnvironment | None = None,
) -> None:
self._tenant_id = tenant_id
self._user_id = user_id
self._app_id = app_id
self._sandbox_provider_type = sandbox_provider_type
self._sandbox_provider_options = sandbox_provider_options
self._injected_vm = _vm
self._sandbox: Sandbox | None = None
self._sandbox_id: str | None = None
self._vm: VirtualEnvironment | None = None
self._strategy: ZipStrategy | None = None
def __enter__(self) -> ZipSandbox:
self._start()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None:
self._stop()
def _start(self) -> None:
if self._vm is not None:
raise RuntimeError("ZipSandbox already started")
if self._injected_vm is not None:
self._vm = self._injected_vm
self._sandbox_id = uuid4().hex
return
if not self._tenant_id:
raise ValueError("tenant_id is required")
if not self._user_id:
raise ValueError("user_id is required")
if self._sandbox_provider_type is None or self._sandbox_provider_options is None:
provider = SandboxProviderService.get_sandbox_provider(self._tenant_id)
provider_type = provider.provider_type
provider_options = dict(provider.config)
else:
provider_type = self._sandbox_provider_type
provider_options = dict(self._sandbox_provider_options)
self._sandbox_id = uuid4().hex
storage = NoopSandboxStorage()
try:
self._sandbox = (
SandboxBuilder(self._tenant_id, SandboxType(provider_type))
.options(provider_options)
.user(self._user_id)
.app(self._app_id)
.storage(storage, assets_id="zip-sandbox")
.build()
)
self._sandbox.wait_ready(timeout=60)
self._vm = self._sandbox.vm
except Exception:
if self._sandbox is not None:
self._sandbox.release()
self._vm = None
self._sandbox = None
self._sandbox_id = None
raise
def _stop(self) -> None:
if self._vm is None:
return
if self._sandbox is not None:
self._sandbox.release()
self._vm = None
self._sandbox = None
self._sandbox_id = None
self._strategy = None
@property
def vm(self) -> VirtualEnvironment:
if self._vm is None:
raise RuntimeError("ZipSandbox not started. Use 'with ZipSandbox(...) as zs:'")
return self._vm
def _get_strategy(self) -> ZipStrategy:
if self._strategy is not None:
return self._strategy
for strategy in self._STRATEGIES:
if strategy.is_available(self.vm):
self._strategy = strategy
return strategy
raise RuntimeError("No available zip backend (zip/python/node+adm-zip)")
# ========== Path utilities ==========
@staticmethod
def _normalize_path(path: str | None) -> str:
raw = (path or ".").strip()
if raw == "":
raw = "."
p = PurePosixPath(raw)
if p.is_absolute():
raise ValueError("path must be relative")
if any(part == ".." for part in p.parts):
raise ValueError("path must not contain '..'")
normalized = str(p)
return "." if normalized in (".", "") else normalized
@staticmethod
def _dest_path_for_url(dest_dir: str, url: str) -> str:
parsed = urlparse(url)
path = parsed.path or ""
name = posixpath.basename(path)
if not name:
name = "download.bin"
return posixpath.join(dest_dir, name)
# ========== File operations ==========
def write_file(self, path: str, data: bytes) -> None:
path = self._normalize_path(path)
if path in ("", "."):
raise ValueError("path must point to a file")
try:
self.vm.upload_file(path, BytesIO(data))
except Exception as exc:
raise RuntimeError(f"Failed to write file to sandbox: {exc}") from exc
def read_file(self, path: str, *, max_bytes: int = 10 * 1024 * 1024) -> bytes:
path = self._normalize_path(path)
if max_bytes <= 0:
raise ValueError("max_bytes must be positive")
try:
data = self.vm.download_file(path).getvalue()
except Exception as exc:
raise RuntimeError(f"Failed to read file from sandbox: {exc}") from exc
if len(data) > max_bytes:
raise ValueError(f"File too large: {len(data)} > {max_bytes}")
return data
# ========== Download operations ==========
def download_items(self, items: list[SandboxDownloadItem], *, dest_dir: str = ".") -> list[str]:
"""Download or write items into the sandbox via a single pipeline.
Remote items (with *url*) are fetched via ``curl``. Inline items
(with *content*) are written via a ``base64 -d`` heredoc. Both go
through the same pipeline, with no branching at the structural level.
"""
if not items:
return []
dest_dir = self._normalize_path(dest_dir)
p = pipeline(self.vm)
p.add(["mkdir", "-p", dest_dir], error_message="Failed to create download directory")
out_paths: list[str] = []
for item in items:
rel = self._normalize_path(item.path)
if rel in ("", "."):
raise ValueError("Download item path must point to a file")
out_path = posixpath.join(dest_dir, rel)
out_paths.append(out_path)
out_dir = posixpath.dirname(out_path)
if out_dir not in ("", "."):
p.add(["mkdir", "-p", out_dir], error_message="Failed to create download directory")
p.add(
self.to_download_command(item, out_path),
error_message=f"Failed to write {item.path}",
)
try:
p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
except Exception as exc:
raise RuntimeError(str(exc)) from exc
return out_paths
@staticmethod
def to_download_command(item: SandboxDownloadItem, out_path: str) -> list[str]:
"""Return the shell command to materialise *item* at *out_path*."""
if item.content is not None:
encoded = base64.b64encode(item.content).decode("ascii")
return ["sh", "-c", f"base64 -d <<'_B64_' > {shlex.quote(out_path)}\n{encoded}\n_B64_"]
return ["curl", "-fsSL", item.url, "-o", out_path]
def download_archive(self, archive_url: str, *, path: str = "input.tar.gz") -> str:
path = self._normalize_path(path)
dir_path = posixpath.dirname(path)
p = pipeline(self.vm)
if dir_path not in ("", "."):
p.add(["mkdir", "-p", dir_path], error_message=f"Failed to create directory {dir_path}")
p.add(["curl", "-fsSL", archive_url, "-o", path], error_message=f"Failed to download archive to {path}")
try:
p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
except Exception as exc:
raise RuntimeError(str(exc)) from exc
return path
# ========== Upload operations ==========
def upload(self, file: SandboxFile, target_url: str) -> None:
"""Upload a sandbox file to the given URL."""
try:
execute(
self.vm,
["curl", "-fsSL", "-X", "PUT", "-T", file.path, target_url],
timeout=self._DEFAULT_TIMEOUT_SECONDS,
error_message="Failed to upload file from sandbox",
)
except CommandExecutionError as exc:
raise RuntimeError(str(exc)) from exc
def upload_items(self, items: list[SandboxUploadItem], *, src_dir: str = ".") -> None:
"""Upload multiple files from sandbox to target URLs.
Args:
items: List of SandboxUploadItem(path, url)
src_dir: Base directory containing the files
"""
if not items:
return
src_dir = self._normalize_path(src_dir)
p = pipeline(self.vm)
for item in items:
rel = self._normalize_path(item.path)
src_path = posixpath.join(src_dir, rel) if src_dir not in ("", ".") else rel
p.add(
["curl", "-fsSL", "-X", "PUT", "-T", src_path, item.url],
error_message=f"Failed to upload {item.path}",
)
try:
p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
except Exception as exc:
raise RuntimeError(str(exc)) from exc
# ========== Archive operations ==========
def zip(self, src: str = ".", *, include_base: bool = True) -> SandboxFile:
"""Create a zip archive and return a handle to it."""
src = self._normalize_path(src)
out_path = f"/tmp/{uuid4().hex}.zip"
cwd = None
src_for_strategy = src
if src not in (".", "") and not include_base:
cwd = src
src_for_strategy = "."
try:
self._get_strategy().zip(
self.vm,
src=src_for_strategy,
out_path=out_path,
cwd=cwd,
timeout=self._DEFAULT_TIMEOUT_SECONDS,
)
except (PipelineExecutionError, CommandExecutionError) as exc:
raise RuntimeError(str(exc)) from exc
return SandboxFile(path=out_path)
def unzip(self, *, archive_path: str, dest_dir: str = "unpacked") -> str:
"""Extract a zip archive to the destination directory."""
archive_path = self._normalize_path(archive_path)
dest_dir = self._normalize_path(dest_dir)
if not archive_path.lower().endswith(".zip"):
raise ValueError("archive_path must end with .zip")
try:
pipeline(self.vm).add(
["mkdir", "-p", dest_dir], error_message="Failed to create destination directory"
).execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
self._get_strategy().unzip(
self.vm,
archive_path=archive_path,
dest_dir=dest_dir,
timeout=self._DEFAULT_TIMEOUT_SECONDS,
)
except (PipelineExecutionError, CommandExecutionError) as exc:
raise RuntimeError(str(exc)) from exc
return dest_dir
def untar(self, *, archive_path: str, dest_dir: str = "unpacked") -> str:
"""Extract a tar archive to the destination directory."""
archive_path = self._normalize_path(archive_path)
dest_dir = self._normalize_path(dest_dir)
lower = archive_path.lower()
is_gz = lower.endswith(".tar.gz") or lower.endswith(".tgz")
extract_flag = "-xzf" if is_gz else "-xf"
try:
(
pipeline(self.vm)
.add(["mkdir", "-p", dest_dir], error_message="Failed to create destination directory")
.add(
["sh", "-c", f'tar {extract_flag} "$1" -C "$2" 2>/dev/null; exit $?', "sh", archive_path, dest_dir],
error_message="Failed to extract tar archive",
)
.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
)
except PipelineExecutionError as exc:
raise RuntimeError(str(exc)) from exc
return dest_dir
def tar(self, src: str = ".", *, include_base: bool = True, compress: bool = True) -> SandboxFile:
"""Create a tar archive and return a handle to it.
Args:
src: Source path to archive (file or directory)
include_base: If True, include the base directory name in the archive
compress: If True, create a gzipped tar archive (.tar.gz)
Returns:
SandboxFile handle to the created archive
"""
src = self._normalize_path(src)
extension = ".tar.gz" if compress else ".tar"
out_path = f"/tmp/{uuid4().hex}{extension}"
create_flag = "-czf" if compress else "-cf"
try:
if src in (".", ""):
# Archive current directory contents
execute(
self.vm,
["tar", create_flag, out_path, "-C", ".", "."],
timeout=self._DEFAULT_TIMEOUT_SECONDS,
error_message="Failed to create tar archive",
)
elif include_base:
# Archive with base directory name included
parent_dir = posixpath.dirname(src) or "."
base_name = posixpath.basename(src)
execute(
self.vm,
["tar", create_flag, out_path, "-C", parent_dir, base_name],
timeout=self._DEFAULT_TIMEOUT_SECONDS,
error_message="Failed to create tar archive",
)
else:
# Archive contents without base directory name
execute(
self.vm,
["tar", create_flag, out_path, "-C", src, "."],
timeout=self._DEFAULT_TIMEOUT_SECONDS,
error_message="Failed to create tar archive",
)
except CommandExecutionError as exc:
raise RuntimeError(str(exc)) from exc
return SandboxFile(path=out_path)
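For orientation, a minimal usage sketch of the archive and upload helpers above; `fs` stands for an instance of this file-operations class bound to a running sandbox, and the presigned URLs are placeholders rather than real endpoints:

# Hypothetical usage sketch; `fs` is an instance of the class above, URLs are placeholders.
archive = fs.tar("project", include_base=False, compress=True)  # SandboxFile at /tmp/<uuid>.tar.gz
fs.upload(archive, "https://storage.example.com/presigned-put-url")

# Batch upload: each SandboxUploadItem pairs a sandbox-relative path with a presigned PUT URL.
fs.upload_items(
    [
        SandboxUploadItem(path="report.html", url="https://storage.example.com/report"),
        SandboxUploadItem(path="logs/run.log", url="https://storage.example.com/run-log"),
    ],
    src_dir="project/output",
)

# Round-trip: unpack an archive previously produced by zip().
dest = fs.unzip(archive_path="bundle.zip", dest_dir="unpacked")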

View File

@ -0,0 +1,41 @@
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from graphon.file import File
class ToolResultStatus(StrEnum):
SUCCESS = "success"
ERROR = "error"
class ToolCall(BaseModel):
id: str | None = Field(default=None, description="Unique identifier for this tool call")
name: str | None = Field(default=None, description="Name of the tool being called")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
icon: str | dict | None = Field(default=None, description="Icon of the tool")
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
class ToolResult(BaseModel):
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
name: str | None = Field(default=None, description="Name of the tool")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[str] = Field(default_factory=list, description="Files produced by the tool")
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
icon: str | dict[str, Any] | None = Field(default=None, description="Icon of the tool")
icon_dark: str | dict[str, Any] | None = Field(default=None, description="Dark theme icon of the tool")
provider: str | None = Field(default=None, description="Tool provider identifier")
class ToolCallResult(BaseModel):
id: str | None = Field(default=None, description="Identifier for the tool call")
name: str | None = Field(default=None, description="Name of the tool")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[File] = Field(default_factory=list, description="Files produced by the tool")
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
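A short, illustrative construction of the models defined above; the field names follow the definitions, while the values are made up:

# Illustrative values only.
call = ToolCall(id="call_1", name="web_search", arguments='{"query": "sandbox providers"}')
result = ToolResult(id=call.id, name=call.name, output="3 results found", elapsed_time=0.42)
combined = ToolCallResult(
    id=call.id,
    name=call.name,
    arguments=call.arguments,
    output=result.output,
    status=ToolResultStatus.SUCCESS,
)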

View File

@ -0,0 +1,929 @@
from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.prompt.entities.advanced_prompt_entities import MemoryMode
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.agent.exceptions import (
AgentInputTypeError,
AgentInvocationError,
AgentMessageTransformError,
AgentNodeError,
AgentVariableNotFoundError,
AgentVariableTypeError,
ToolFileNotFoundError,
)
from graphon.enums import (
BuiltinNodeTypes,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from graphon.file import File, FileTransferMethod
from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from graphon.nodes.base.node import Node
from graphon.nodes.base.variable_template_parser import VariableTemplateParser
from graphon.runtime import VariablePool
from graphon.variables.segments import ArrayFileSegment, StringSegment
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
if TYPE_CHECKING:
from core.agent.strategy.plugin import PluginAgentStrategy
from core.plugin.entities.request import InvokeCredentials
class AgentNode(Node[AgentNodeData]):
"""
Agent Node
"""
node_type = BuiltinNodeTypes.AGENT
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
try:
strategy = get_plugin_agent_strategy(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
except Exception as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
),
)
return
agent_parameters = strategy.get_parameters()
# get parameters
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
for_log=True,
strategy=strategy,
)
credentials = self._generate_credentials(parameters=parameters)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = strategy.invoke(
params=parameters,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
)
)
return
# Fetch memory for node memory saving
memory = self._fetch_memory_for_save()
try:
yield from self._transform_message(
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
memory=memory,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
)
)
def _generate_agent_parameters(
self,
*,
agent_parameters: Sequence[AgentStrategyParameter],
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
agent_parameters (Sequence[AgentStrategyParameter]): The list of agent strategy parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (AgentNodeData): The data associated with the agent node.
for_log (bool): Whether values are rendered for logging rather than execution.
strategy (PluginAgentStrategy): The resolved plugin agent strategy.
Returns:
dict[str, Any]: A dictionary containing the generated parameters.
"""
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
# variable_pool.convert_template expects a string template,
# so if the value is a dict, convert it to a JSON string before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
# not valid JSON; keep the rendered string as-is
pass
case _:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
# Note: branching on node_data.version has caused problems before.
# Logically we shouldn't use the version field for this judgment,
# but it is preserved here for backward compatibility with historical data.
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
runtime_variable_pool = variable_pool
dify_ctx = self.require_dify_context()
tool_runtime = ToolManager.get_agent_tool_runtime(
dify_ctx.tenant_id,
dify_ctx.app_id,
entity,
dify_ctx.invoke_from,
runtime_variable_pool,
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
ToolParameter.ToolParameterForm.FORM
if tool_runtime_params.name in manual_input_params
else tool_runtime_params.form
)
manual_input_value = {}
if tool_runtime.entity.parameters:
manual_input_value = {
key: value for key, value in parameters.items() if key in manual_input_params
}
runtime_parameters = {
**tool_runtime.runtime.runtime_parameters,
**manual_input_value,
}
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self._fetch_model(value)
# memory config
history_prompt_messages = []
if node_data.memory:
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
value["history_prompt_messages"] = history_prompt_messages
if model_schema:
# remove structured output feature to support old version agent plugin
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
value["entity"] = model_schema.model_dump(mode="json")
else:
value["entity"] = None
result[parameter_name] = value
return result
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
from core.plugin.entities.request import InvokeCredentials
credentials = InvokeCredentials()
# generate credentials for tools selector
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if tool.get("credential_id"):
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
except ValidationError:
continue
return credentials
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AgentNodeData,
) -> Mapping[str, Sequence[str]]:
typed_node_data = node_data
result: dict[str, Any] = {}
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}
return result
@property
def agent_strategy_icon(self) -> str | None:
"""
Get agent strategy icon
:return:
"""
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
dify_ctx = self.require_dify_context()
plugins = manager.list_plugins(dify_ctx.tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | None:
"""
Fetch memory based on configuration mode.
Returns TokenBufferMemory for conversation mode (default),
or NodeTokenBufferMemory for node mode (Chatflow only).
"""
node_data = self.node_data
memory_config = node_data.memory
if not memory_config:
return None
# get conversation id (required for both modes in Chatflow)
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
dify_ctx = self.require_dify_context()
if memory_config.mode == MemoryMode.NODE:
return NodeTokenBufferMemory(
app_id=dify_ctx.app_id,
conversation_id=conversation_id,
node_id=self._node_id,
tenant_id=dify_ctx.tenant_id,
model_instance=model_instance,
)
else:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
dify_ctx = self.require_dify_context()
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM, model=model_name
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=dify_ctx.tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
if model_schema.features:
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
try:
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
except ValueError:
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy
:param tools: list of tool configuration dicts
:return: filtered list of tool dicts
"""
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _fetch_memory_for_save(self) -> BaseMemory | None:
"""
Fetch memory instance for saving node memory.
This is a simplified version that doesn't require model_instance.
"""
from core.model_manager import ModelManager
from graphon.model_runtime.entities.model_entities import ModelType
node_data = self.node_data
if not node_data.memory:
return None
# Get conversation_id
conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_var, StringSegment):
return None
conversation_id = conversation_id_var.value
# Return appropriate memory type based on mode
if node_data.memory.mode == MemoryMode.NODE:
# For node memory, we need a model_instance for token counting
# Use a simple default model for this purpose
try:
model_instance = ModelManager().get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
)
except Exception:
return None
return NodeTokenBufferMemory(
app_id=self.app_id,
conversation_id=conversation_id,
node_id=self._node_id,
tenant_id=self.tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory doesn't need saving here
return None
def _build_context(
self,
parameters_for_log: dict[str, Any],
user_query: str,
assistant_response: str,
agent_logs: list[AgentLogEvent],
) -> list[PromptMessage]:
"""
Build context from user query, tool calls, and assistant response.
Format: user -> assistant(with tool_calls) -> tool -> assistant
The context includes:
- Current user query (always present, may be empty)
- Assistant message with tool_calls (if tools were called)
- Tool results
- Assistant's final response
"""
context_messages: list[PromptMessage] = []
# Always add user query (even if empty, to maintain conversation structure)
context_messages.append(UserPromptMessage(content=user_query or ""))
# Extract actual tool calls from agent logs
# Only include logs with label starting with "CALL " - these are real tool invocations
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_results: list[tuple[str, str, str]] = [] # (tool_call_id, tool_name, result)
for log in agent_logs:
if log.status == "success" and log.label and log.label.startswith("CALL "):
# Extract tool name from label (format: "CALL tool_name")
tool_name = log.label[5:] # Remove "CALL " prefix
tool_call_id = log.message_id
# Parse tool response from data
data = log.data or {}
tool_response = ""
# Try to extract the actual tool response
if "tool_response" in data:
tool_response = data["tool_response"]
elif "output" in data:
tool_response = data["output"]
elif "result" in data:
tool_response = data["result"]
if isinstance(tool_response, dict):
tool_response = str(tool_response)
# Get tool input for arguments
tool_input = data.get("tool_call_input", {}) or data.get("input", {})
if isinstance(tool_input, dict):
tool_input_str = json.dumps(tool_input, ensure_ascii=False)
else:
tool_input_str = str(tool_input) if tool_input else ""
if tool_response:
tool_calls.append(
AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_name,
arguments=tool_input_str,
),
)
)
tool_results.append((tool_call_id, tool_name, str(tool_response)))
# Add assistant message with tool_calls if there were tool calls
if tool_calls:
context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))
# Add tool result messages
for tool_call_id, tool_name, result in tool_results:
context_messages.append(
ToolPromptMessage(
content=result,
tool_call_id=tool_call_id,
name=tool_name,
)
)
# Add final assistant response
context_messages.append(AssistantPromptMessage(content=assistant_response))
return context_messages
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
memory: BaseMemory | None = None,
) -> Generator[NodeEventBase, None, None]:
"""
Transform ToolInvokeMessages into node events, accumulating text, files, JSON output, and agent logs.
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage = LLMUsage.empty_usage()
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == BuiltinNodeTypes.AGENT:
if isinstance(message.message.json_object, dict):
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
else:
msg_metadata = {}
llm_usage = LLMUsage.empty_usage()
agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
# Validate that meta contains a 'file' key
if "file" not in message.meta:
raise AgentNodeError("File message is missing 'file' key in meta")
# Validate that the file is an instance of File
if not isinstance(message.meta["file"], File):
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
message_id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.message_id == agent_log.message_id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
# Add agent_logs to outputs['json'] so the frontend can access the thinking process
json_output: list[dict[str, Any] | list[Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: append the accumulated JSON messages; if there are none, fall back to an empty data list.
if json_list:
json_output.extend(json_list)
else:
json_output.append({"data": []})
# Send final chunk events for all streamed outputs
# Final chunk for text stream
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk="",
is_final=True,
)
# Final chunks for any streamed variables
for var_name in variables:
yield StreamChunkEvent(
selector=[node_id, var_name],
chunk="",
is_final=True,
)
# Get user query from parameters for building context
user_query = parameters_for_log.get("query", "")
# Build context from history, user query, tool calls and assistant response
context = self._build_context(parameters_for_log, user_query, text, agent_logs)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": text,
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": json_output,
"context": context,
**variables,
},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)
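For reference, a minimal sketch of the message sequence `_build_context` assembles when a single tool was called; the constructors match those used in the method above, and the content values are illustrative:

# user -> assistant(with tool_calls) -> tool -> assistant
context = [
    UserPromptMessage(content="What's the weather in Berlin?"),
    AssistantPromptMessage(
        content="",
        tool_calls=[
            AssistantPromptMessage.ToolCall(
                id="msg-1",
                type="function",
                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name="get_weather", arguments='{"city": "Berlin"}'
                ),
            )
        ],
    ),
    ToolPromptMessage(content="12°C, cloudy", tool_call_id="msg-1", name="get_weather"),
    AssistantPromptMessage(content="It is currently 12°C and cloudy in Berlin."),
]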

View File

@ -90,9 +90,9 @@ def init_app(app: DifyApp):
app.register_blueprint(inner_api_bp)
app.register_blueprint(mcp_bp)
# TODO: enable after full sandbox integration
# from controllers.cli_api import bp as cli_api_bp
# app.register_blueprint(cli_api_bp)
from controllers.cli_api import bp as cli_api_bp
app.register_blueprint(cli_api_bp)
# Register trigger blueprint with CORS for webhook calls
_apply_cors_once(

View File

@ -0,0 +1,5 @@
import socketio # type: ignore[reportMissingTypeStubs]
from configs import dify_config
sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)
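A minimal sketch of attaching handlers to this server with standard python-socketio decorators; the handlers and event names below are placeholders, not part of this commit:

@sio.event
def connect(sid, environ):
    # Called when a collaboration client opens a Socket.IO connection.
    print(f"client connected: {sid}")


@sio.event
def disconnect(sid):
    print(f"client disconnected: {sid}")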

View File

@ -0,0 +1,17 @@
from flask_restx import fields
online_user_partial_fields = {
"user_id": fields.String,
"username": fields.String,
"avatar": fields.String,
"sid": fields.String,
}
workflow_online_users_fields = {
"workflow_id": fields.String,
"users": fields.List(fields.Nested(online_user_partial_fields)),
}
online_user_list_fields = {
"data": fields.List(fields.Nested(workflow_online_users_fields)),
}

View File

@ -0,0 +1,96 @@
from flask_restx import fields
from libs.helper import AvatarUrlField, TimestampField
# basic account fields for comments
account_fields = {
"id": fields.String,
"name": fields.String,
"email": fields.String,
"avatar_url": AvatarUrlField,
}
# Comment mention fields
workflow_comment_mention_fields = {
"mentioned_user_id": fields.String,
"mentioned_user_account": fields.Nested(account_fields, allow_null=True),
"reply_id": fields.String,
}
# Comment reply fields
workflow_comment_reply_fields = {
"id": fields.String,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
}
# Basic comment fields (for list views)
workflow_comment_basic_fields = {
"id": fields.String,
"position_x": fields.Float,
"position_y": fields.Float,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
"updated_at": TimestampField,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
"reply_count": fields.Integer,
"mention_count": fields.Integer,
"participants": fields.List(fields.Nested(account_fields)),
}
# Detailed comment fields (for single comment view)
workflow_comment_detail_fields = {
"id": fields.String,
"position_x": fields.Float,
"position_y": fields.Float,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
"updated_at": TimestampField,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
"replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
"mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
}
# Comment creation response fields (simplified)
workflow_comment_create_fields = {
"id": fields.String,
"created_at": TimestampField,
}
# Comment update response fields (simplified)
workflow_comment_update_fields = {
"id": fields.String,
"updated_at": TimestampField,
}
# Comment resolve response fields
workflow_comment_resolve_fields = {
"id": fields.String,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
}
# Reply creation response fields (simplified)
workflow_comment_reply_create_fields = {
"id": fields.String,
"created_at": TimestampField,
}
# Reply update response fields
workflow_comment_reply_update_fields = {
"id": fields.String,
"updated_at": TimestampField,
}

View File

@ -0,0 +1,143 @@
"""Add sandbox providers, app assets, and LLM detail tables.
Revision ID: aab323465866
Revises: c3df22613c99
Create Date: 2026-02-09 10:31:05.062722
"""
import os
from uuid import uuid4
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "aab323465866"
down_revision = "c3df22613c99"
branch_labels = None
depends_on = None
def _get_ssh_config_from_env() -> dict[str, str]:
"""Build SSH sandbox config from environment variables.
Defaults are chosen so that:
- All-in-one Docker Compose (api inside the network): agentbox:22
- Middleware / local dev (api on the host): 127.0.0.1:2222
The env vars (SSH_SANDBOX_*) are documented in api/.env.example.
"""
return {
"ssh_host": os.environ.get("SSH_SANDBOX_HOST", "agentbox"),
"ssh_port": os.environ.get("SSH_SANDBOX_PORT", "22"),
"ssh_username": os.environ.get("SSH_SANDBOX_USERNAME", "agentbox"),
"ssh_password": os.environ.get("SSH_SANDBOX_PASSWORD", "agentbox"),
"base_working_path": os.environ.get("SSH_SANDBOX_BASE_WORKING_PATH", "/workspace/sandboxes"),
}
def upgrade():
from core.tools.utils.system_encryption import encrypt_system_params
op.create_table(
"sandbox_provider_system_config",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("provider_type", sa.String(length=50), nullable=False, comment="e2b, docker, local, ssh"),
sa.Column("encrypted_config", models.types.LongText(), nullable=False, comment="Encrypted config JSON"),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint("id", name="sandbox_provider_system_config_pkey"),
sa.UniqueConstraint("provider_type", name="unique_sandbox_provider_system_config_type"),
)
op.create_table(
"sandbox_providers",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("provider_type", sa.String(length=50), nullable=False, comment="e2b, docker, local, ssh"),
sa.Column("configure_type", sa.String(length=20), server_default="user", nullable=False),
sa.Column("encrypted_config", models.types.LongText(), nullable=False, comment="Encrypted config JSON"),
sa.Column("is_active", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint("id", name="sandbox_provider_pkey"),
sa.UniqueConstraint("tenant_id", "provider_type", "configure_type", name="unique_sandbox_provider_tenant_type"),
)
with op.batch_alter_table("sandbox_providers", schema=None) as batch_op:
batch_op.create_index("idx_sandbox_providers_tenant_active", ["tenant_id", "is_active"], unique=False)
batch_op.create_index("idx_sandbox_providers_tenant_id", ["tenant_id"], unique=False)
op.create_table(
"llm_generation_details",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("message_id", models.types.StringUUID(), nullable=True),
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True),
sa.Column("node_id", sa.String(length=255), nullable=True),
sa.Column("reasoning_content", models.types.LongText(), nullable=True),
sa.Column("tool_calls", models.types.LongText(), nullable=True),
sa.Column("sequence", models.types.LongText(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.CheckConstraint(
"(message_id IS NOT NULL AND workflow_run_id IS NULL AND node_id IS NULL) OR (message_id IS NULL AND workflow_run_id IS NOT NULL AND node_id IS NOT NULL)",
name=op.f("llm_generation_details_ck_llm_generation_detail_assoc_mode_check"),
),
sa.PrimaryKeyConstraint("id", name="llm_generation_detail_pkey"),
sa.UniqueConstraint("message_id", name=op.f("llm_generation_details_message_id_key")),
)
with op.batch_alter_table("llm_generation_details", schema=None) as batch_op:
batch_op.create_index("idx_llm_generation_detail_message", ["message_id"], unique=False)
batch_op.create_index("idx_llm_generation_detail_workflow", ["workflow_run_id", "node_id"], unique=False)
op.create_table(
"app_assets",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("version", sa.String(length=255), nullable=False),
sa.Column("asset_tree", models.types.LongText(), nullable=False),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint("id", name="app_assets_pkey"),
)
with op.batch_alter_table("app_assets", schema=None) as batch_op:
batch_op.create_index("app_assets_version_idx", ["tenant_id", "app_id", "version"], unique=False)
# Only seed a default SSH system provider for self-hosted deployments.
# CLOUD editions manage sandbox providers through admin tooling.
edition = os.environ.get("EDITION", "SELF_HOSTED")
if edition == "SELF_HOSTED":
ssh_config = _get_ssh_config_from_env()
encrypted_config = encrypt_system_params(ssh_config)
op.execute(
sa.text(
"""
INSERT INTO sandbox_provider_system_config
(id, provider_type, encrypted_config, created_at, updated_at)
VALUES (:id, :provider_type, :encrypted_config, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (provider_type) DO NOTHING
"""
).bindparams(
id=str(uuid4()),
provider_type="ssh",
encrypted_config=encrypted_config,
)
)
def downgrade():
op.drop_table("app_assets")
op.drop_table("llm_generation_details")
with op.batch_alter_table("sandbox_providers", schema=None) as batch_op:
batch_op.drop_index("idx_sandbox_providers_tenant_id")
batch_op.drop_index("idx_sandbox_providers_tenant_active")
op.drop_table("sandbox_providers")
op.drop_table("sandbox_provider_system_config")

View File

@ -0,0 +1,109 @@
"""Add workflow comments table
Revision ID: 227822d22895
Revises: aab323465866
Create Date: 2026-02-09 17:26:15.255980
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "227822d22895"
down_revision = "aab323465866"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"workflow_comments",
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("position_x", sa.Float(), nullable=False),
sa.Column("position_y", sa.Float(), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("resolved", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("resolved_at", sa.DateTime(), nullable=True),
sa.Column("resolved_by", models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
)
with op.batch_alter_table("workflow_comments", schema=None) as batch_op:
batch_op.create_index("workflow_comments_app_idx", ["tenant_id", "app_id"], unique=False)
batch_op.create_index("workflow_comments_created_at_idx", ["created_at"], unique=False)
op.create_table(
"workflow_comment_replies",
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
sa.Column("comment_id", models.types.StringUUID(), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("created_by", models.types.StringUUID(), nullable=False),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.ForeignKeyConstraint(
["comment_id"],
["workflow_comments.id"],
name=op.f("workflow_comment_replies_comment_id_fkey"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
)
with op.batch_alter_table("workflow_comment_replies", schema=None) as batch_op:
batch_op.create_index("comment_replies_comment_idx", ["comment_id"], unique=False)
batch_op.create_index("comment_replies_created_at_idx", ["created_at"], unique=False)
op.create_table(
"workflow_comment_mentions",
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
sa.Column("comment_id", models.types.StringUUID(), nullable=False),
sa.Column("reply_id", models.types.StringUUID(), nullable=True),
sa.Column("mentioned_user_id", models.types.StringUUID(), nullable=False),
sa.ForeignKeyConstraint(
["comment_id"],
["workflow_comments.id"],
name=op.f("workflow_comment_mentions_comment_id_fkey"),
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["reply_id"],
["workflow_comment_replies.id"],
name=op.f("workflow_comment_mentions_reply_id_fkey"),
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
)
with op.batch_alter_table("workflow_comment_mentions", schema=None) as batch_op:
batch_op.create_index("comment_mentions_comment_idx", ["comment_id"], unique=False)
batch_op.create_index("comment_mentions_reply_idx", ["reply_id"], unique=False)
batch_op.create_index("comment_mentions_user_idx", ["mentioned_user_id"], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("workflow_comment_mentions", schema=None) as batch_op:
batch_op.drop_index("comment_mentions_user_idx")
batch_op.drop_index("comment_mentions_reply_idx")
batch_op.drop_index("comment_mentions_comment_idx")
op.drop_table("workflow_comment_mentions")
with op.batch_alter_table("workflow_comment_replies", schema=None) as batch_op:
batch_op.drop_index("comment_replies_created_at_idx")
batch_op.drop_index("comment_replies_comment_idx")
op.drop_table("workflow_comment_replies")
with op.batch_alter_table("workflow_comments", schema=None) as batch_op:
batch_op.drop_index("workflow_comments_created_at_idx")
batch_op.drop_index("workflow_comments_app_idx")
op.drop_table("workflow_comments")
# ### end Alembic commands ###

View File

@ -0,0 +1,40 @@
"""Add app_asset_contents table for inline content caching.
Revision ID: 5ee0aa981887
Revises: 6b5f9f8b1a2c
Create Date: 2026-03-09 12:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = "5ee0aa981887"
down_revision = "6b5f9f8b1a2c"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"app_asset_contents",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("app_id", models.types.StringUUID(), nullable=False),
sa.Column("node_id", models.types.StringUUID(), nullable=False),
sa.Column("content", sa.Text(), nullable=False, server_default=""),
sa.Column("size", sa.Integer(), nullable=False, server_default="0"),
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
sa.Column("updated_at", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
sa.PrimaryKeyConstraint("id", name="app_asset_contents_pkey"),
sa.UniqueConstraint("tenant_id", "app_id", "node_id", name="uq_asset_content_node"),
)
op.create_index("idx_asset_content_app", "app_asset_contents", ["tenant_id", "app_id"])
def downgrade() -> None:
op.drop_index("idx_asset_content_app", table_name="app_asset_contents")
op.drop_table("app_asset_contents")

89
api/models/app_asset.py Normal file
View File

@ -0,0 +1,89 @@
from datetime import datetime
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, Integer, String, func
from sqlalchemy.orm import Mapped, mapped_column
from core.app.entities.app_asset_entities import AppAssetFileTree
from .base import Base
from .types import LongText, StringUUID
class AppAssets(Base):
__tablename__ = "app_assets"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="app_assets_pkey"),
sa.Index("app_assets_version_idx", "tenant_id", "app_id", "version"),
)
VERSION_DRAFT = "draft"
VERSION_PUBLISHED = "published"
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
version: Mapped[str] = mapped_column(String(255), nullable=False)
_asset_tree: Mapped[str] = mapped_column("asset_tree", LongText, nullable=False, default='{"nodes":[]}')
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by: Mapped[str | None] = mapped_column(StringUUID)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default=func.current_timestamp(),
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
@property
def asset_tree(self) -> AppAssetFileTree:
if not self._asset_tree:
return AppAssetFileTree()
return AppAssetFileTree.model_validate_json(self._asset_tree)
@asset_tree.setter
def asset_tree(self, value: AppAssetFileTree) -> None:
self._asset_tree = value.model_dump_json()
def __repr__(self) -> str:
return f"<AppAssets(id={self.id}, app_id={self.app_id}, version={self.version})>"
class AppAssetContent(Base):
"""Inline content cache for app asset draft files.
Acts as a read-through cache for S3: text-like asset content is dual-written
here on save and read from DB first (falling back to S3 on miss with sync backfill).
Keyed by (tenant_id, app_id, node_id); only the current draft content is stored,
not published snapshots.
See core/app_assets/content_accessor.py for the accessor abstraction that
manages the DB/S3 read-through and dual-write logic.
"""
__tablename__ = "app_asset_contents"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="app_asset_contents_pkey"),
sa.UniqueConstraint("tenant_id", "app_id", "node_id", name="uq_asset_content_node"),
sa.Index("idx_asset_content_app", "tenant_id", "app_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
node_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False, default="")
size: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default=func.current_timestamp(),
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
)
def __repr__(self) -> str:
return f"<AppAssetContent(id={self.id}, node_id={self.node_id})>"

210
api/models/comment.py Normal file
View File

@ -0,0 +1,210 @@
"""Workflow comment models."""
from datetime import datetime
from typing import Optional
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .account import Account
from .base import Base
from .engine import db
from .types import StringUUID
class WorkflowComment(Base):
"""Workflow comment model for canvas commenting functionality.
Comments are associated with apps rather than specific workflow versions,
since an app has only one draft workflow at a time and comments should persist
across workflow version changes.
Attributes:
id: Comment ID
tenant_id: Workspace ID
app_id: App ID (primary association, comments belong to apps)
position_x: X coordinate on canvas
position_y: Y coordinate on canvas
content: Comment content
created_by: Creator account ID
created_at: Creation time
updated_at: Last update time
resolved: Whether comment is resolved
resolved_at: Resolution time
resolved_by: Resolver account ID
"""
__tablename__ = "workflow_comments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(db.Float)
position_y: Mapped[float] = mapped_column(db.Float)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
# Relationships
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
)
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
)
@property
def created_by_account(self):
"""Get creator account."""
if hasattr(self, "_created_by_account_cache"):
return self._created_by_account_cache
return db.session.get(Account, self.created_by)
def cache_created_by_account(self, account: Account | None) -> None:
"""Cache creator account to avoid extra queries."""
self._created_by_account_cache = account
@property
def resolved_by_account(self):
"""Get resolver account."""
if hasattr(self, "_resolved_by_account_cache"):
return self._resolved_by_account_cache
if self.resolved_by:
return db.session.get(Account, self.resolved_by)
return None
def cache_resolved_by_account(self, account: Account | None) -> None:
"""Cache resolver account to avoid extra queries."""
self._resolved_by_account_cache = account
@property
def reply_count(self):
"""Get reply count."""
return len(self.replies)
@property
def mention_count(self):
"""Get mention count."""
return len(self.mentions)
@property
def participants(self):
"""Get all participants (creator + repliers + mentioned users)."""
participant_ids = set()
# Add comment creator
participant_ids.add(self.created_by)
# Add reply creators
participant_ids.update(reply.created_by for reply in self.replies)
# Add mentioned users
participant_ids.update(mention.mentioned_user_id for mention in self.mentions)
# Get account objects
participants = []
for user_id in participant_ids:
account = db.session.get(Account, user_id)
if account:
participants.append(account)
return participants
class WorkflowCommentReply(Base):
"""Workflow comment reply model.
Attributes:
id: Reply ID
comment_id: Parent comment ID
content: Reply content
created_by: Creator account ID
created_at: Creation time
"""
__tablename__ = "workflow_comment_replies"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
Index("comment_replies_comment_idx", "comment_id"),
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
@property
def created_by_account(self):
"""Get creator account."""
if hasattr(self, "_created_by_account_cache"):
return self._created_by_account_cache
return db.session.get(Account, self.created_by)
def cache_created_by_account(self, account: Account | None) -> None:
"""Cache creator account to avoid extra queries."""
self._created_by_account_cache = account
class WorkflowCommentMention(Base):
"""Workflow comment mention model.
Mentions are only for internal accounts since end users
cannot access workflow canvas and commenting features.
Attributes:
id: Mention ID
comment_id: Parent comment ID
mentioned_user_id: Mentioned account ID
"""
__tablename__ = "workflow_comment_mentions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
Index("comment_mentions_comment_idx", "comment_id"),
Index("comment_mentions_reply_idx", "reply_id"),
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[str | None] = mapped_column(
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
@property
def mentioned_user_account(self):
"""Get mentioned account."""
if hasattr(self, "_mentioned_user_account_cache"):
return self._mentioned_user_account_cache
return db.session.get(Account, self.mentioned_user_id)
def cache_mentioned_user_account(self, account: Account | None) -> None:
"""Cache mentioned account to avoid extra queries."""
self._mentioned_user_account_cache = account
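A minimal sketch of how a listing endpoint might use the cache_* helpers above to avoid one Account query per comment; the comments list and the query shape are assumptions, not part of this diff.
# Sketch only: prefetch creator/resolver accounts once, then prime the per-row caches.
account_ids = {c.created_by for c in comments} | {c.resolved_by for c in comments if c.resolved_by}
accounts = {a.id: a for a in db.session.query(Account).filter(Account.id.in_(account_ids)).all()}
for c in comments:
    c.cache_created_by_account(accounts.get(c.created_by))
    c.cache_resolved_by_account(accounts.get(c.resolved_by) if c.resolved_by else None)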

80
api/models/sandbox.py Normal file
View File

@ -0,0 +1,80 @@
import json
from collections.abc import Mapping
from datetime import datetime
from typing import Any, cast
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
from .base import TypeBase
from .types import LongText, StringUUID
class SandboxProviderSystemConfig(TypeBase):
"""
System-level sandbox provider configuration.
Stores default configuration for each provider type.
"""
__tablename__ = "sandbox_provider_system_config"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="sandbox_provider_system_config_pkey"),
sa.UniqueConstraint("provider_type", name="unique_sandbox_provider_system_config_type"),
)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="e2b, docker, local, ssh")
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False, comment="Encrypted config JSON")
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
server_default=func.current_timestamp(),
server_onupdate=func.current_timestamp(),
init=False,
)
class SandboxProvider(TypeBase):
"""
Tenant-level sandbox provider configuration.
Each tenant can have one configuration per provider type.
Only one provider can be active at a time per tenant.
"""
__tablename__ = "sandbox_providers"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="sandbox_provider_pkey"),
sa.UniqueConstraint("tenant_id", "provider_type", "configure_type", name="unique_sandbox_provider_tenant_type"),
sa.Index("idx_sandbox_providers_tenant_id", "tenant_id"),
sa.Index("idx_sandbox_providers_tenant_active", "tenant_id", "is_active"),
)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="e2b, docker, local, ssh")
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False, comment="Encrypted config JSON")
configure_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default="user", default="user")
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
server_default=func.current_timestamp(),
server_onupdate=func.current_timestamp(),
init=False,
)
@property
def config(self) -> Mapping[str, Any]:
return cast(Mapping[str, Any], json.loads(self.encrypted_config or "{}"))
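A minimal sketch of reading the active provider's parsed config for a tenant; the Session/db imports come from elsewhere in the code base and tenant_id is assumed to be in scope.
# Sketch only: fetch the tenant's active sandbox provider and read its parsed config.
with Session(db.engine) as session:
    provider = (
        session.query(SandboxProvider)
        .filter(SandboxProvider.tenant_id == tenant_id, SandboxProvider.is_active.is_(True))
        .first()
    )
    provider_config = provider.config if provider else {}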

View File

@ -0,0 +1,26 @@
from collections.abc import Mapping
from dataclasses import dataclass
from enum import StrEnum
from typing import Any
class WorkflowFeatures(StrEnum):
SANDBOX = "sandbox"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
RETRIEVER_RESOURCE = "retriever_resource"
SENSITIVE_WORD_AVOIDANCE = "sensitive_word_avoidance"
FILE_UPLOAD = "file_upload"
SUGGESTED_QUESTIONS_AFTER_ANSWER = "suggested_questions_after_answer"
@dataclass(frozen=True)
class WorkflowFeature:
enabled: bool
config: Mapping[str, Any]
@classmethod
def from_dict(cls, data: Mapping[str, Any] | None) -> "WorkflowFeature":
if data is None or not isinstance(data, dict):
return cls(enabled=False, config={})
return cls(enabled=bool(data.get("enabled", False)), config=data)
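A minimal sketch of reading one feature flag through the helper above; workflow.features is assumed to be the JSON string column used elsewhere in this commit.
import json

features = json.loads(workflow.features or "{}")
sandbox = WorkflowFeature.from_dict(features.get(WorkflowFeatures.SANDBOX))
if sandbox.enabled:
    sandbox_config = sandbox.config  # full raw dict, including provider-specific keys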

View File

@ -0,0 +1,226 @@
from __future__ import annotations
import json
from typing import TypedDict
from extensions.ext_redis import redis_client
SESSION_STATE_TTL_SECONDS = 3600
WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:"
WORKFLOW_LEADER_PREFIX = "workflow_leader:"
WORKFLOW_SKILL_LEADER_PREFIX = "workflow_skill_leader:"
WS_SID_MAP_PREFIX = "ws_sid_map:"
class WorkflowSessionInfo(TypedDict):
user_id: str
username: str
avatar: str | None
sid: str
connected_at: int
graph_active: bool
active_skill_file_id: str | None
class SidMapping(TypedDict):
workflow_id: str
user_id: str
class WorkflowCollaborationRepository:
def __init__(self) -> None:
self._redis = redis_client
def __repr__(self) -> str:
return f"{self.__class__.__name__}(redis_client={self._redis})"
@staticmethod
def workflow_key(workflow_id: str) -> str:
return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}"
@staticmethod
def leader_key(workflow_id: str) -> str:
return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}"
@staticmethod
def skill_leader_key(workflow_id: str, file_id: str) -> str:
return f"{WORKFLOW_SKILL_LEADER_PREFIX}{workflow_id}:{file_id}"
@staticmethod
def sid_key(sid: str) -> str:
return f"{WS_SID_MAP_PREFIX}{sid}"
@staticmethod
def _decode(value: str | bytes | None) -> str | None:
if value is None:
return None
if isinstance(value, bytes):
return value.decode("utf-8")
return value
def refresh_session_state(self, workflow_id: str, sid: str) -> None:
workflow_key = self.workflow_key(workflow_id)
sid_key = self.sid_key(sid)
if self._redis.exists(workflow_key):
self._redis.expire(workflow_key, SESSION_STATE_TTL_SECONDS)
if self._redis.exists(sid_key):
self._redis.expire(sid_key, SESSION_STATE_TTL_SECONDS)
def set_session_info(self, workflow_id: str, session_info: WorkflowSessionInfo) -> None:
workflow_key = self.workflow_key(workflow_id)
self._redis.hset(workflow_key, session_info["sid"], json.dumps(session_info))
self._redis.set(
self.sid_key(session_info["sid"]),
json.dumps({"workflow_id": workflow_id, "user_id": session_info["user_id"]}),
ex=SESSION_STATE_TTL_SECONDS,
)
self.refresh_session_state(workflow_id, session_info["sid"])
def get_session_info(self, workflow_id: str, sid: str) -> WorkflowSessionInfo | None:
raw = self._redis.hget(self.workflow_key(workflow_id), sid)
value = self._decode(raw)
if not value:
return None
try:
session_info = json.loads(value)
except (TypeError, json.JSONDecodeError):
return None
if not isinstance(session_info, dict):
return None
if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info:
return None
return {
"user_id": str(session_info["user_id"]),
"username": str(session_info["username"]),
"avatar": session_info.get("avatar"),
"sid": str(session_info["sid"]),
"connected_at": int(session_info.get("connected_at") or 0),
"graph_active": bool(session_info.get("graph_active")),
"active_skill_file_id": session_info.get("active_skill_file_id"),
}
def set_graph_active(self, workflow_id: str, sid: str, active: bool) -> None:
session_info = self.get_session_info(workflow_id, sid)
if not session_info:
return
session_info["graph_active"] = bool(active)
self._redis.hset(self.workflow_key(workflow_id), sid, json.dumps(session_info))
self.refresh_session_state(workflow_id, sid)
def is_graph_active(self, workflow_id: str, sid: str) -> bool:
session_info = self.get_session_info(workflow_id, sid)
if not session_info:
return False
return bool(session_info.get("graph_active") or False)
def set_active_skill_file(self, workflow_id: str, sid: str, file_id: str | None) -> None:
session_info = self.get_session_info(workflow_id, sid)
if not session_info:
return
session_info["active_skill_file_id"] = file_id
self._redis.hset(self.workflow_key(workflow_id), sid, json.dumps(session_info))
self.refresh_session_state(workflow_id, sid)
def get_active_skill_file_id(self, workflow_id: str, sid: str) -> str | None:
session_info = self.get_session_info(workflow_id, sid)
if not session_info:
return None
return session_info.get("active_skill_file_id")
def get_sid_mapping(self, sid: str) -> SidMapping | None:
raw = self._redis.get(self.sid_key(sid))
if not raw:
return None
value = self._decode(raw)
if not value:
return None
try:
return json.loads(value)
except (TypeError, json.JSONDecodeError):
return None
def delete_session(self, workflow_id: str, sid: str) -> None:
self._redis.hdel(self.workflow_key(workflow_id), sid)
self._redis.delete(self.sid_key(sid))
def session_exists(self, workflow_id: str, sid: str) -> bool:
return bool(self._redis.hexists(self.workflow_key(workflow_id), sid))
def sid_mapping_exists(self, sid: str) -> bool:
return bool(self._redis.exists(self.sid_key(sid)))
def get_session_sids(self, workflow_id: str) -> list[str]:
raw_sids = self._redis.hkeys(self.workflow_key(workflow_id))
decoded_sids: list[str] = []
for sid in raw_sids:
decoded = self._decode(sid)
if decoded:
decoded_sids.append(decoded)
return decoded_sids
def list_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]:
sessions_json = self._redis.hgetall(self.workflow_key(workflow_id))
users: list[WorkflowSessionInfo] = []
for session_info_json in sessions_json.values():
value = self._decode(session_info_json)
if not value:
continue
try:
session_info = json.loads(value)
except (TypeError, json.JSONDecodeError):
continue
if not isinstance(session_info, dict):
continue
if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info:
continue
users.append(
{
"user_id": str(session_info["user_id"]),
"username": str(session_info["username"]),
"avatar": session_info.get("avatar"),
"sid": str(session_info["sid"]),
"connected_at": int(session_info.get("connected_at") or 0),
"graph_active": bool(session_info.get("graph_active")),
"active_skill_file_id": session_info.get("active_skill_file_id"),
}
)
return users
def get_current_leader(self, workflow_id: str) -> str | None:
raw = self._redis.get(self.leader_key(workflow_id))
return self._decode(raw)
def get_skill_leader(self, workflow_id: str, file_id: str) -> str | None:
raw = self._redis.get(self.skill_leader_key(workflow_id, file_id))
return self._decode(raw)
def set_leader_if_absent(self, workflow_id: str, sid: str) -> bool:
return bool(self._redis.set(self.leader_key(workflow_id), sid, nx=True, ex=SESSION_STATE_TTL_SECONDS))
def set_leader(self, workflow_id: str, sid: str) -> None:
self._redis.set(self.leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS)
def set_skill_leader(self, workflow_id: str, file_id: str, sid: str) -> None:
self._redis.set(self.skill_leader_key(workflow_id, file_id), sid, ex=SESSION_STATE_TTL_SECONDS)
def delete_leader(self, workflow_id: str) -> None:
self._redis.delete(self.leader_key(workflow_id))
def delete_skill_leader(self, workflow_id: str, file_id: str) -> None:
self._redis.delete(self.skill_leader_key(workflow_id, file_id))
def expire_leader(self, workflow_id: str) -> None:
self._redis.expire(self.leader_key(workflow_id), SESSION_STATE_TTL_SECONDS)
def expire_skill_leader(self, workflow_id: str, file_id: str) -> None:
self._redis.expire(self.skill_leader_key(workflow_id, file_id), SESSION_STATE_TTL_SECONDS)
def get_active_skill_session_sids(self, workflow_id: str, file_id: str) -> list[str]:
sessions = self.list_sessions(workflow_id)
return [session["sid"] for session in sessions if session.get("active_skill_file_id") == file_id]
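A minimal sketch of the expected call sequence with illustrative ids; in this commit the collaboration service (further below) is the real caller.
repo = WorkflowCollaborationRepository()
repo.set_session_info("wf-1", {
    "user_id": "u-1", "username": "Alice", "avatar": None,
    "sid": "sid-1", "connected_at": 1700000000,
    "graph_active": True, "active_skill_file_id": None,
})
repo.set_leader_if_absent("wf-1", "sid-1")   # first active session becomes the graph leader
online = repo.list_sessions("wf-1")          # -> list[WorkflowSessionInfo]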

View File

@ -0,0 +1,195 @@
"""Service for packaging and publishing app assets.
This service handles operations that require core.zip_sandbox,
separated from AppAssetService to avoid circular imports.
Dependency flow:
core/* -> AppAssetPackageService -> AppAssetService
(core modules can import this service without circular dependency)
Inline content optimisation:
``AssetItem`` objects returned by the build pipeline may carry an
in-process *content* field (e.g. resolved ``.md`` skill documents).
``AppAssetService.to_download_items()`` converts these into unified
``SandboxDownloadItem`` instances, and ``ZipSandbox.download_items()``
handles both inline and remote items natively.
"""
import logging
from uuid import uuid4
from sqlalchemy.orm import Session
from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app_assets.builder import AssetBuildPipeline, BuildContext
from core.app_assets.builder.file_builder import FileBuilder
from core.app_assets.builder.skill_builder import SkillBuilder
from core.app_assets.entities.assets import AssetItem
from core.app_assets.storage import AssetPaths
from core.zip_sandbox import ZipSandbox
from models.app_asset import AppAssets
from models.model import App
logger = logging.getLogger(__name__)
class AppAssetPackageService:
"""Service for packaging and publishing app assets.
This service is designed to be imported by core/* modules without
causing circular imports. It depends on AppAssetService for basic
asset operations but provides the packaging/publishing functionality
that requires core.zip_sandbox.
"""
@staticmethod
def get_tenant_app_assets(tenant_id: str, assets_id: str) -> AppAssets:
"""Get app assets by tenant_id and assets_id.
This is a read-only operation that doesn't require AppAssetService.
"""
from extensions.ext_database import db
with Session(db.engine, expire_on_commit=False) as session:
app_assets = (
session.query(AppAssets)
.filter(
AppAssets.tenant_id == tenant_id,
AppAssets.id == assets_id,
)
.first()
)
if not app_assets:
raise ValueError(f"App assets not found for tenant_id={tenant_id}, assets_id={assets_id}")
return app_assets
@staticmethod
def get_draft_asset_items(tenant_id: str, app_id: str, file_tree: AppAssetFileTree) -> list[AssetItem]:
"""Convert file tree to asset items for packaging."""
files = file_tree.walk_files()
return [
AssetItem(
asset_id=f.id,
path=file_tree.get_path(f.id),
file_name=f.name,
extension=f.extension,
storage_key=AssetPaths.draft(tenant_id, app_id, f.id),
)
for f in files
]
@staticmethod
def package_and_upload(
*,
assets: list[AssetItem],
upload_url: str,
tenant_id: str,
app_id: str,
user_id: str,
storage_key: str = "",
) -> None:
"""Package assets into a ZIP and upload directly to the given URL.
Uses ``AppAssetService.to_download_items()`` to convert assets
into unified download items, then ``ZipSandbox.download_items()``
handles both inline content and remote presigned URLs natively.
When *assets* is empty an empty ZIP is written directly to storage
using *storage_key*, bypassing the HTTP ticket URL.
"""
from services.app_asset_service import AppAssetService
if not assets:
import io
import zipfile
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w"):
pass
buf.seek(0)
# Write directly to storage instead of going through the HTTP
# ticket URL. The ticket URL (FILES_API_URL) is designed for
# sandbox containers (agentbox) and is not routable from the api
# container in standard Docker Compose deployments.
if storage_key:
from extensions.ext_storage import storage
storage.save(storage_key, buf.getvalue())
else:
import requests
requests.put(upload_url, data=buf.getvalue(), timeout=30)
return
download_items = AppAssetService.to_download_items(assets)
with ZipSandbox(tenant_id=tenant_id, user_id=user_id, app_id="asset-packager") as zs:
zs.download_items(download_items)
archive = zs.zip()
zs.upload(archive, upload_url)
@staticmethod
def publish(session: Session, app_model: App, account_id: str, workflow_id: str) -> AppAssets:
"""Publish app assets for a workflow.
Creates a versioned copy of draft assets and packages them for
runtime use. The build ZIP contains resolved ``.md`` content
(inline from ``SkillBuilder``) and raw draft content for all
other files. A separate source ZIP snapshots the raw drafts for
later export.
"""
from services.app_asset_service import AppAssetService
tenant_id = app_model.tenant_id
app_id = app_model.id
assets = AppAssetService.get_or_create_assets(session, app_model, account_id)
tree = assets.asset_tree
publish_id = str(uuid4())
published = AppAssets(
id=publish_id,
tenant_id=tenant_id,
app_id=app_id,
version=workflow_id,
created_by=account_id,
)
published.asset_tree = tree
session.add(published)
session.flush()
asset_storage = AppAssetService.get_storage()
accessor = AppAssetService.get_accessor(tenant_id, app_id)
build_pipeline = AssetBuildPipeline([SkillBuilder(accessor=accessor), FileBuilder()])
ctx = BuildContext(tenant_id=tenant_id, app_id=app_id, build_id=publish_id)
built_assets = build_pipeline.build_all(tree, ctx)
# Runtime ZIP: resolved .md (inline) + raw draft (remote).
runtime_zip_key = AssetPaths.build_zip(tenant_id, app_id, publish_id)
runtime_upload_url = asset_storage.get_upload_url(runtime_zip_key)
AppAssetPackageService.package_and_upload(
assets=built_assets,
upload_url=runtime_upload_url,
tenant_id=tenant_id,
app_id=app_id,
user_id=account_id,
storage_key=runtime_zip_key,
)
# Source ZIP: all raw draft content (for export/restore).
source_items = AppAssetService.get_draft_assets(tenant_id, app_id)
source_key = AssetPaths.source_zip(tenant_id, app_id, workflow_id)
source_upload_url = asset_storage.get_upload_url(source_key)
AppAssetPackageService.package_and_upload(
assets=source_items,
upload_url=source_upload_url,
tenant_id=tenant_id,
app_id=app_id,
user_id=account_id,
storage_key=source_key,
)
return published
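A minimal sketch of the expected publish call; the surrounding session, app_model, workflow and current_user names are assumptions about the caller.
with Session(db.engine, expire_on_commit=False) as session:
    published = AppAssetPackageService.publish(
        session, app_model, account_id=current_user.id, workflow_id=workflow.id
    )
    session.commit()  # publish() only flushes; the caller owns the transaction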

View File

@ -0,0 +1,443 @@
"""Service for upgrading Classic runtime apps to Sandboxed runtime via clone-and-convert.
The upgrade flow:
1. Clone the source app via DSL export/import
2. On the cloned app's draft workflow, convert Agent nodes to LLM nodes
3. Rewrite variable references for all LLM nodes (old output names → new generation-based names)
4. Enable sandbox feature flag
The original app is never modified; the user gets a new sandboxed copy.
"""
import json
import logging
import re
import uuid
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import App, Workflow
from models.workflow_features import WorkflowFeatures
from services.app_dsl_service import AppDslService, ImportMode
logger = logging.getLogger(__name__)
_VAR_REWRITES: dict[str, list[str]] = {
"text": ["generation", "content"],
"reasoning_content": ["generation", "reasoning_content"],
}
_PASSTHROUGH_KEYS = (
"version",
"error_strategy",
"default_value",
"retry_config",
"parent_node_id",
"isInLoop",
"loop_id",
"isInIteration",
"iteration_id",
)
class AppRuntimeUpgradeService:
"""Upgrades a Classic-runtime app to Sandboxed runtime by cloning and converting.
Holds an active SQLAlchemy session; the caller is responsible for commit/rollback.
"""
session: Session
def __init__(self, session: Session) -> None:
self.session = session
def upgrade(self, app_model: App, account: Any) -> dict[str, Any]:
"""Clone *app_model* and upgrade the clone to sandboxed runtime.
Returns:
dict with keys: result, new_app_id, converted_agents, skipped_agents.
"""
workflow = self._get_draft_workflow(app_model)
if not workflow:
return {"result": "no_draft"}
if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
return {"result": "already_sandboxed"}
new_app = self._clone_app(app_model, account)
new_workflow = self._get_draft_workflow(new_app)
if not new_workflow:
return {"result": "no_draft"}
graph = json.loads(new_workflow.graph) if new_workflow.graph else {}
nodes = graph.get("nodes", [])
converted, skipped = _convert_agent_nodes(nodes)
_enable_computer_use_for_existing_llm_nodes(nodes)
llm_node_ids = {n["id"] for n in nodes if n.get("data", {}).get("type") == "llm"}
_rewrite_variable_references(nodes, llm_node_ids)
new_workflow.graph = json.dumps(graph)
features = json.loads(new_workflow.features) if new_workflow.features else {}
features.setdefault("sandbox", {})["enabled"] = True
new_workflow.features = json.dumps(features)
return {
"result": "success",
"new_app_id": str(new_app.id),
"converted_agents": converted,
"skipped_agents": skipped,
}
def _get_draft_workflow(self, app_model: App) -> Workflow | None:
stmt = select(Workflow).where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == "draft",
)
return self.session.scalar(stmt)
def _clone_app(self, app_model: App, account: Any) -> App:
dsl_service = AppDslService(self.session)
yaml_content = dsl_service.export_dsl(app_model=app_model, include_secret=True)
result = dsl_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=f"{app_model.name} (Sandboxed)",
)
stmt = select(App).where(App.id == result.app_id)
new_app = self.session.scalar(stmt)
if not new_app:
raise RuntimeError(f"Cloned app not found: {result.app_id}")
return new_app
# ---------------------------------------------------------------------------
# Pure conversion functions (no DB access)
# ---------------------------------------------------------------------------
def _convert_agent_nodes(nodes: list[dict[str, Any]]) -> tuple[int, int]:
"""Convert Agent nodes to LLM nodes in-place. Returns (converted_count, skipped_count)."""
converted = 0
for node in nodes:
data = node.get("data", {})
if data.get("type") != "agent":
continue
node_id = node.get("id", "?")
node["data"] = _agent_data_to_llm_data(data)
logger.info("Converted agent node %s to LLM", node_id)
converted += 1
return converted, 0
def _agent_data_to_llm_data(agent_data: dict[str, Any]) -> dict[str, Any]:
"""Map an Agent node's data dict to an LLM node's data dict.
Always returns a valid LLM data dict. If the agent has no model selected,
produces an empty LLM node with agent mode (computer_use) enabled.
"""
params = agent_data.get("agent_parameters") or {}
model_param = params.get("model", {}) if isinstance(params, dict) else {}
model_value = model_param.get("value") if isinstance(model_param, dict) else None
if isinstance(model_value, dict) and model_value.get("provider") and model_value.get("model"):
model_config = {
"provider": model_value["provider"],
"name": model_value["model"],
"mode": model_value.get("mode", "chat"),
"completion_params": model_value.get("completion_params", {}),
}
else:
model_config = {"provider": "", "name": "", "mode": "chat", "completion_params": {}}
tools_param = params.get("tools", {})
tools_value = tools_param.get("value", []) if isinstance(tools_param, dict) else []
tools_meta, tool_settings = _convert_tools(tools_value if isinstance(tools_value, list) else [])
instruction_param = params.get("instruction", {})
instruction = instruction_param.get("value", "") if isinstance(instruction_param, dict) else ""
query_param = params.get("query", {})
query_value = query_param.get("value", "") if isinstance(query_param, dict) else ""
has_tools = bool(tools_meta)
prompt_template = _build_prompt_template(
instruction,
query_value,
skill=has_tools,
tools=tools_value if has_tools else None,
)
max_iter_param = params.get("maximum_iterations", {})
max_iterations = max_iter_param.get("value", 100) if isinstance(max_iter_param, dict) else 100
context_config = _extract_context(params)
vision_config = _extract_vision(params)
llm_data: dict[str, Any] = {
"type": "llm",
"title": agent_data.get("title", "LLM"),
"desc": agent_data.get("desc", ""),
"model": model_config,
"prompt_template": prompt_template,
"prompt_config": {"jinja2_variables": []},
"memory": agent_data.get("memory"),
"context": context_config,
"vision": vision_config,
"computer_use": True,
"structured_output_switch_on": False,
"reasoning_format": "separated",
"tools": tools_meta,
"tool_settings": tool_settings,
"max_iterations": max_iterations,
}
for key in _PASSTHROUGH_KEYS:
if key in agent_data:
llm_data[key] = agent_data[key]
return llm_data
def _extract_context(params: dict[str, Any]) -> dict[str, Any]:
"""Extract context config from agent_parameters for LLM node format.
Agent stores context as a variable selector in agent_parameters.context.value,
e.g. ["knowledge_retrieval_node_id", "result"]. Maps to LLM ContextConfig.
"""
if not isinstance(params, dict):
return {"enabled": False}
ctx_param = params.get("context", {})
ctx_value = ctx_param.get("value") if isinstance(ctx_param, dict) else None
if isinstance(ctx_value, list) and len(ctx_value) >= 2 and all(isinstance(s, str) for s in ctx_value):
return {"enabled": True, "variable_selector": ctx_value}
return {"enabled": False}
def _extract_vision(params: dict[str, Any]) -> dict[str, Any]:
"""Extract vision config from agent_parameters for LLM node format."""
if not isinstance(params, dict):
return {"enabled": False}
vision_param = params.get("vision", {})
vision_value = vision_param.get("value") if isinstance(vision_param, dict) else None
if isinstance(vision_value, dict) and vision_value.get("enabled"):
return vision_value
if isinstance(vision_value, bool) and vision_value:
return {"enabled": True}
return {"enabled": False}
def _enable_computer_use_for_existing_llm_nodes(nodes: list[dict[str, Any]]) -> None:
"""Enable computer_use for existing LLM nodes that have tools configured.
After upgrade, the sandbox runtime requires computer_use=true for tool calling.
Existing LLM nodes from classic mode may have tools but computer_use=false.
"""
for node in nodes:
data = node.get("data", {})
if data.get("type") != "llm":
continue
tools = data.get("tools", [])
if tools and not data.get("computer_use"):
data["computer_use"] = True
logger.info("Enabled computer_use for LLM node %s with %d tools", node.get("id", "?"), len(tools))
def _convert_tools(
tools_input: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""Convert agent tool dicts to (ToolMetadata[], ToolSetting[]).
Agent tools in graph JSON already use provider_name/settings/parameters,
the same field names as LLM ToolMetadata. We pass them through with defaults
for any missing fields.
"""
tools_meta: list[dict[str, Any]] = []
tool_settings: list[dict[str, Any]] = []
for ts in tools_input:
if not isinstance(ts, dict):
continue
provider_name = ts.get("provider_name", "")
tool_name = ts.get("tool_name", "")
tool_type = ts.get("type", "builtin")
tools_meta.append(
{
"enabled": True,
"type": tool_type,
"provider_name": provider_name,
"tool_name": tool_name,
"plugin_unique_identifier": ts.get("plugin_unique_identifier"),
"credential_id": ts.get("credential_id"),
"parameters": ts.get("parameters", {}),
"settings": ts.get("settings", {}) or ts.get("tool_configuration", {}),
"extra": ts.get("extra", {}),
}
)
tool_settings.append(
{
"type": tool_type,
"provider": provider_name,
"tool_name": tool_name,
"enabled": True,
}
)
return tools_meta, tool_settings
def _build_prompt_template(
instruction: Any,
query: Any,
*,
skill: bool = False,
tools: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
"""Build LLM prompt_template from Agent instruction and query values.
When *skill* is True each message gets ``"skill": True`` so the sandbox
engine treats the prompt as a skill document.
When *tools* is provided, tool reference placeholders
(``§[tool].[provider].[name].[uuid]§``) are appended to the system
message and the corresponding ``ToolReference`` entries are placed in the
message's ``metadata.tools`` dict so the skill assembler can resolve them.
Tools from the same provider are grouped into a single token list.
"""
messages: list[dict[str, Any]] = []
system_text = instruction if isinstance(instruction, str) else (str(instruction) if instruction else "")
metadata: dict[str, Any] | None = None
if tools:
tool_refs: dict[str, dict[str, Any]] = {}
provider_groups: dict[str, list[str]] = {}
for ts in tools:
if not isinstance(ts, dict):
continue
tool_uuid = str(uuid.uuid4())
provider_id = ts.get("provider_name", "")
tool_name = ts.get("tool_name", "")
tool_type = ts.get("type", "builtin")
token = f"§[tool].[{provider_id}].[{tool_name}].[{tool_uuid}"
provider_groups.setdefault(provider_id, []).append(token)
tool_refs[tool_uuid] = {
"type": tool_type,
"configuration": {"fields": []},
"enabled": True,
**({"credential_id": ts.get("credential_id")} if ts.get("credential_id") else {}),
}
if provider_groups:
group_texts: list[str] = []
for tokens in provider_groups.values():
if len(tokens) == 1:
group_texts.append(tokens[0])
else:
group_texts.append("[" + ",".join(tokens) + "]")
all_tools_text = " ".join(group_texts)
system_text = f"{system_text}\n\n{all_tools_text}" if system_text else all_tools_text
metadata = {"tools": tool_refs, "files": []}
if system_text:
msg: dict[str, Any] = {"role": "system", "text": system_text, "skill": skill}
if metadata:
msg["metadata"] = metadata
messages.append(msg)
if isinstance(query, list) and len(query) >= 2:
template_ref = "{{#" + ".".join(str(s) for s in query) + "#}}"
messages.append({"role": "user", "text": template_ref, "skill": skill})
elif query:
messages.append({"role": "user", "text": str(query), "skill": skill})
if not messages:
messages.append({"role": "user", "text": "", "skill": skill})
return messages
def _rewrite_variable_references(nodes: list[dict[str, Any]], llm_ids: set[str]) -> None:
"""Recursively walk all node data and rewrite variable references for LLM nodes.
Handles two forms:
- Structured selectors: [node_id, "text"] → [node_id, "generation", "content"]
- Template strings: {{#node_id.text#}} → {{#node_id.generation.content#}}
"""
if not llm_ids:
return
escaped_ids = [re.escape(nid) for nid in llm_ids]
patterns: list[tuple[re.Pattern[str], str]] = []
for old_name, new_path in _VAR_REWRITES.items():
pattern = re.compile(r"\{\{#(" + "|".join(escaped_ids) + r")\." + re.escape(old_name) + r"#\}\}")
replacement = r"{{#\1." + ".".join(new_path) + r"#}}"
patterns.append((pattern, replacement))
for node in nodes:
data = node.get("data", {})
_walk_and_rewrite(data, llm_ids, patterns)
def _walk_and_rewrite(
obj: Any,
llm_ids: set[str],
template_patterns: list[tuple[re.Pattern[str], str]],
) -> Any:
"""Recursively rewrite variable references in a nested data structure."""
if isinstance(obj, dict):
for key, value in obj.items():
obj[key] = _walk_and_rewrite(value, llm_ids, template_patterns)
return obj
if isinstance(obj, list):
if _is_variable_selector(obj, llm_ids):
return _rewrite_selector(obj)
for i, item in enumerate(obj):
obj[i] = _walk_and_rewrite(item, llm_ids, template_patterns)
return obj
if isinstance(obj, str):
for pattern, replacement in template_patterns:
obj = pattern.sub(replacement, obj)
return obj
return obj
def _is_variable_selector(lst: list, llm_ids: set[str]) -> bool:
"""Check if a list is a structured variable selector pointing to an LLM node output."""
if len(lst) < 2:
return False
if not all(isinstance(s, str) for s in lst):
return False
return lst[0] in llm_ids and lst[1] in _VAR_REWRITES
def _rewrite_selector(selector: list[str]) -> list[str]:
"""Rewrite [node_id, "text"] → [node_id, "generation", "content"]."""
old_field = selector[1]
new_path = _VAR_REWRITES[old_field]
return [selector[0]] + new_path + selector[2:]
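A worked illustration of the rewrite rules in _VAR_REWRITES; node ids and the surrounding node data are made up for the example.
nodes = [{"id": "n2", "data": {"type": "code",
                               "template": "{{#llm_1.text#}}",
                               "selector": ["llm_1", "reasoning_content"]}}]
_rewrite_variable_references(nodes, {"llm_1"})
# nodes[0]["data"]["template"] -> "{{#llm_1.generation.content#}}"
# nodes[0]["data"]["selector"] -> ["llm_1", "generation", "reasoning_content"]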

View File

@ -0,0 +1,103 @@
"""Service for the app_asset_contents table.
Provides single-node and batch DB operations for the inline content cache.
All methods are static and open their own short-lived sessions.
Collaborators:
- models.app_asset.AppAssetContent (SQLAlchemy model)
- core.app_assets.accessor (accessor abstraction that calls this service)
"""
import logging
from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.app_asset import AppAssetContent
logger = logging.getLogger(__name__)
class AssetContentService:
"""DB operations for the inline asset content cache.
All methods are static. All queries are scoped by tenant_id + app_id.
"""
@staticmethod
def get(tenant_id: str, app_id: str, node_id: str) -> str | None:
"""Get cached content for a single node. Returns None on miss."""
with Session(db.engine) as session:
return session.execute(
select(AppAssetContent.content).where(
AppAssetContent.tenant_id == tenant_id,
AppAssetContent.app_id == app_id,
AppAssetContent.node_id == node_id,
)
).scalar_one_or_none()
@staticmethod
def get_many(tenant_id: str, app_id: str, node_ids: list[str]) -> dict[str, str]:
"""Batch get. Returns {node_id: content} for hits only."""
if not node_ids:
return {}
with Session(db.engine) as session:
rows = session.execute(
select(AppAssetContent.node_id, AppAssetContent.content).where(
AppAssetContent.tenant_id == tenant_id,
AppAssetContent.app_id == app_id,
AppAssetContent.node_id.in_(node_ids),
)
).all()
return {row.node_id: row.content for row in rows}
@staticmethod
def upsert(tenant_id: str, app_id: str, node_id: str, content: str, size: int) -> None:
"""Insert or update inline content for a single node."""
with Session(db.engine) as session:
stmt = pg_insert(AppAssetContent).values(
tenant_id=tenant_id,
app_id=app_id,
node_id=node_id,
content=content,
size=size,
)
stmt = stmt.on_conflict_do_update(
constraint="uq_asset_content_node",
set_={
"content": stmt.excluded.content,
"size": stmt.excluded.size,
},
)
session.execute(stmt)
session.commit()
@staticmethod
def delete(tenant_id: str, app_id: str, node_id: str) -> None:
"""Delete cached content for a single node."""
with Session(db.engine) as session:
session.execute(
delete(AppAssetContent).where(
AppAssetContent.tenant_id == tenant_id,
AppAssetContent.app_id == app_id,
AppAssetContent.node_id == node_id,
)
)
session.commit()
@staticmethod
def delete_many(tenant_id: str, app_id: str, node_ids: list[str]) -> None:
"""Delete cached content for multiple nodes."""
if not node_ids:
return
with Session(db.engine) as session:
session.execute(
delete(AppAssetContent).where(
AppAssetContent.tenant_id == tenant_id,
AppAssetContent.app_id == app_id,
AppAssetContent.node_id.in_(node_ids),
)
)
session.commit()
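A minimal sketch of the write/read round trip; tenant_id and app_id are assumed to come from the request context.
AssetContentService.upsert(tenant_id, app_id, node_id="n1", content="# Skill A", size=9)
AssetContentService.upsert(tenant_id, app_id, node_id="n2", content="# Skill B", size=9)
hits = AssetContentService.get_many(tenant_id, app_id, ["n1", "n2", "missing"])
# hits == {"n1": "# Skill A", "n2": "# Skill B"}; cache misses are simply absent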

View File

@ -0,0 +1,17 @@
from .base import BaseServiceError
class AppAssetNodeNotFoundError(BaseServiceError):
pass
class AppAssetParentNotFoundError(BaseServiceError):
pass
class AppAssetPathConflictError(BaseServiceError):
pass
class AppAssetNodeTooLargeError(BaseServiceError):
pass

View File

@ -0,0 +1,37 @@
"""
LLM Generation Detail Service.
Provides methods to query and attach generation details to workflow node executions
and messages, avoiding N+1 query problems.
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.llm_generation_entities import LLMGenerationDetailData
from models import LLMGenerationDetail
class LLMGenerationService:
"""Service for handling LLM generation details."""
def __init__(self, session: Session):
self._session = session
def get_generation_detail_for_message(self, message_id: str) -> LLMGenerationDetailData | None:
"""Query generation detail for a specific message."""
stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id == message_id)
detail = self._session.scalars(stmt).first()
return detail.to_domain_model() if detail else None
def get_generation_details_for_messages(
self,
message_ids: list[str],
) -> dict[str, LLMGenerationDetailData]:
"""Batch query generation details for multiple messages."""
if not message_ids:
return {}
stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id.in_(message_ids))
details = self._session.scalars(stmt).all()
return {detail.message_id: detail.to_domain_model() for detail in details if detail.message_id}
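A minimal sketch of the batch path that avoids the N+1 pattern mentioned above; the messages list and how the result is attached are illustrative.
service = LLMGenerationService(session)
details = service.get_generation_details_for_messages([m.id for m in messages])
for m in messages:
    detail = details.get(m.id)  # None when the message has no generation detail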

View File

@ -0,0 +1,204 @@
"""Service for extracting tool dependencies from LLM node skill prompts.
Two public entry points:
- ``extract_tool_dependencies`` takes raw node data from the client,
builds a ``SkillBundle`` in real time from current draft ``.md`` assets,
and resolves transitive tool dependencies. Used by the per-node POST
endpoint.
- ``get_workflow_skills`` scans all LLM nodes in a persisted draft
workflow and returns per-node skill info. Uses a cached bundle.
"""
from __future__ import annotations
import json
import logging
from collections.abc import Mapping
from functools import reduce
from typing import Any, cast
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.sandbox.entities.config import AppAssets
from core.skill.assembler import SkillBundleAssembler, SkillDocumentAssembler
from core.skill.entities.skill_bundle import SkillBundle
from core.skill.entities.skill_document import SkillDocument
from core.skill.entities.skill_metadata import SkillMetadata
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.skill.skill_manager import SkillManager
from graphon.enums import BuiltinNodeTypes
from models.model import App
from services.app_asset_service import AppAssetService
logger = logging.getLogger(__name__)
class SkillService:
"""Service for managing and retrieving skill information from workflows."""
# ------------------------------------------------------------------
# Per-node: client sends node data, server builds bundle in real-time
# ------------------------------------------------------------------
@staticmethod
def extract_tool_dependencies(
app: App,
node_data: Mapping[str, Any],
user_id: str,
) -> list[ToolDependency]:
"""Extract tool dependencies from an LLM node's skill prompts.
Builds a fresh ``SkillBundle`` from current draft ``.md`` assets
every time; no cached bundle is used. The caller supplies the
full node ``data`` dict directly (not a ``node_id``).
Returns an empty list when the node has no skill prompts or when
no draft assets exist.
"""
if node_data.get("type", "") != BuiltinNodeTypes.LLM:
return []
if not SkillService._has_skill(node_data):
return []
bundle = SkillService._build_bundle(app, user_id)
if bundle is None:
return []
return SkillService._resolve_prompt_dependencies(node_data, bundle)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _has_skill(node_data: Mapping[str, Any]) -> bool:
"""Check if node has any skill prompts."""
prompt_template_raw = node_data.get("prompt_template", [])
if isinstance(prompt_template_raw, list):
for prompt_item in cast(list[object], prompt_template_raw):
if isinstance(prompt_item, dict) and prompt_item.get("skill", False):
return True
return False
@staticmethod
def _build_bundle(app: App, user_id: str) -> SkillBundle | None:
"""Real-time build a SkillBundle from current draft .md assets.
Reads all ``.md`` nodes from the draft file tree, bulk-loads
their content from the DB cache, parses into ``SkillDocument``
objects, and assembles a full bundle with transitive dependency
resolution.
The bundle is **not** persisted; it is built fresh for each
request so the response always reflects the latest draft state.
"""
assets = AppAssetService.get_assets(
tenant_id=app.tenant_id,
app_id=app.id,
user_id=user_id,
is_draft=True,
)
if not assets:
return None
file_tree: AppAssetFileTree = assets.asset_tree
if file_tree.empty():
return SkillBundle(assets_id=assets.id, asset_tree=file_tree)
# Collect all .md file nodes from the tree.
md_nodes: list[AppAssetNode] = [n for n in file_tree.walk_files() if n.extension == "md"]
if not md_nodes:
return SkillBundle(assets_id=assets.id, asset_tree=file_tree)
# Bulk-load content from DB (with S3 fallback).
accessor = AppAssetService.get_accessor(app.tenant_id, app.id)
raw_contents = accessor.bulk_load(md_nodes)
# Parse into SkillDocuments.
documents: dict[str, SkillDocument] = {}
for node in md_nodes:
raw = raw_contents.get(node.id)
if not raw:
continue
try:
data = {"skill_id": node.id, **json.loads(raw)}
documents[node.id] = SkillDocument.model_validate(data)
except (json.JSONDecodeError, TypeError, ValueError):
logger.warning("Skipping unparseable skill document node_id=%s", node.id)
continue
return SkillBundleAssembler(file_tree).assemble_bundle(documents, assets.id)
@staticmethod
def _resolve_prompt_dependencies(
node_data: Mapping[str, Any],
bundle: SkillBundle,
) -> list[ToolDependency]:
"""Resolve tool dependencies from skill prompts against a bundle."""
assembler = SkillDocumentAssembler(bundle)
tool_deps_list: list[ToolDependencies] = []
prompt_template_raw = node_data.get("prompt_template", [])
if not isinstance(prompt_template_raw, list):
return []
for prompt_item in cast(list[object], prompt_template_raw):
if not isinstance(prompt_item, dict):
continue
prompt = cast(dict[str, Any], prompt_item)
if not prompt.get("skill", False):
continue
text_raw = prompt.get("text", "")
text = text_raw if isinstance(text_raw, str) else str(text_raw)
metadata_obj: object = prompt.get("metadata")
metadata = cast(dict[str, Any], metadata_obj) if isinstance(metadata_obj, dict) else {}
skill_entry = assembler.assemble_document(
document=SkillDocument(
skill_id="anonymous",
content=text,
metadata=SkillMetadata.model_validate(metadata),
),
base_path=AppAssets.PATH,
)
tool_deps_list.append(skill_entry.dependance.tools)
if not tool_deps_list:
return []
merged = reduce(lambda x, y: x.merge(y), tool_deps_list)
return merged.dependencies
@staticmethod
def _extract_tool_dependencies_cached(
app: App,
node_data: Mapping[str, Any],
user_id: str,
) -> list[ToolDependency]:
"""Extract tool dependencies using a cached SkillBundle.
Used by ``get_workflow_skills`` for the whole-workflow endpoint.
"""
assets = AppAssetService.get_assets(
tenant_id=app.tenant_id,
app_id=app.id,
user_id=user_id,
is_draft=True,
)
if not assets:
return []
try:
bundle = SkillManager.load_bundle(
tenant_id=app.tenant_id,
app_id=app.id,
assets_id=assets.id,
)
except Exception:
logger.debug("Failed to load cached skill bundle for app_id=%s", app.id, exc_info=True)
return []
return SkillService._resolve_prompt_dependencies(node_data, bundle)
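A minimal sketch of the per-node call; the payload shape follows the checks in _has_skill, while the prompt text and metadata values are illustrative only.
node_data = {
    "type": "llm",
    "prompt_template": [
        {"role": "system", "text": "Use the search skill when relevant.", "skill": True,
         "metadata": {"tools": {}, "files": []}},
    ],
}
deps = SkillService.extract_tool_dependencies(app=app_model, node_data=node_data, user_id=current_user.id)
# deps == [] when there are no draft .md assets or the prompts reference no tools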

View File

@ -0,0 +1,153 @@
"""Storage ticket service for generating opaque download/upload URLs.
This service provides a ticket-based approach for file access. Instead of exposing
the real storage key in URLs, it generates a random UUID token and stores the mapping
in Redis with a TTL.
Usage:
from services.storage_ticket_service import StorageTicketService
# Generate a download ticket
url = StorageTicketService.create_download_url("path/to/file.txt", expires_in=300)
# Generate an upload ticket
url = StorageTicketService.create_upload_url("path/to/file.txt", expires_in=300, max_bytes=10*1024*1024)
URL format:
{FILES_API_URL}/files/storage-files/{token}
The token is validated by looking up the Redis key, which contains:
- op: "download" or "upload"
- storage_key: the real storage path
- max_bytes: (upload only) maximum allowed upload size
- filename: suggested filename for Content-Disposition header
"""
import logging
from typing import Literal
from uuid import uuid4
from pydantic import BaseModel
from configs import dify_config
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
TICKET_KEY_PREFIX = "storage_files"
DEFAULT_DOWNLOAD_TTL = 300 # 5 minutes
DEFAULT_UPLOAD_TTL = 300 # 5 minutes
DEFAULT_MAX_UPLOAD_BYTES = 100 * 1024 * 1024 # 100MB
class StorageTicket(BaseModel):
"""Represents a storage access ticket."""
op: Literal["download", "upload"]
storage_key: str
max_bytes: int | None = None # upload only
filename: str | None = None # suggested filename for download
class StorageTicketService:
"""Service for creating and validating storage access tickets."""
@classmethod
def create_download_url(
cls,
storage_key: str,
*,
expires_in: int = DEFAULT_DOWNLOAD_TTL,
filename: str | None = None,
) -> str:
"""Create a download ticket and return the URL.
Args:
storage_key: The real storage path
expires_in: TTL in seconds (default 300)
filename: Suggested filename for Content-Disposition header
Returns:
Full URL with token
"""
if filename is None:
filename = storage_key.rsplit("/", 1)[-1]
ticket = StorageTicket(op="download", storage_key=storage_key, filename=filename)
token = cls._store_ticket(ticket, expires_in)
return cls._build_url(token)
@classmethod
def create_upload_url(
cls,
storage_key: str,
*,
expires_in: int = DEFAULT_UPLOAD_TTL,
max_bytes: int = DEFAULT_MAX_UPLOAD_BYTES,
) -> str:
"""Create an upload ticket and return the URL.
Args:
storage_key: The real storage path
expires_in: TTL in seconds (default 300)
max_bytes: Maximum allowed upload size in bytes
Returns:
Full URL with token
"""
ticket = StorageTicket(op="upload", storage_key=storage_key, max_bytes=max_bytes)
token = cls._store_ticket(ticket, expires_in)
return cls._build_url(token)
@classmethod
def get_ticket(cls, token: str) -> StorageTicket | None:
"""Retrieve a ticket by token.
Args:
token: The UUID token from the URL
Returns:
StorageTicket if found and valid, None otherwise
"""
key = cls._ticket_key(token)
try:
data = redis_client.get(key)
if data is None:
return None
if isinstance(data, bytes):
data = data.decode("utf-8")
return StorageTicket.model_validate_json(data)
except Exception:
logger.warning("Failed to retrieve storage ticket: %s", token, exc_info=True)
return None
@classmethod
def _store_ticket(cls, ticket: StorageTicket, ttl: int) -> str:
"""Store a ticket in Redis and return the token."""
token = str(uuid4())
key = cls._ticket_key(token)
value = ticket.model_dump_json()
redis_client.setex(key, ttl, value)
return token
@classmethod
def _ticket_key(cls, token: str) -> str:
"""Generate Redis key for a token."""
return f"{TICKET_KEY_PREFIX}:{token}"
@classmethod
def _build_url(cls, token: str) -> str:
"""Build the full URL for a token.
FILES_API_URL is dedicated to sandbox runtime file access (agentbox/e2b/etc.).
This endpoint must be routable from the runtime environment.
"""
base_url = dify_config.FILES_API_URL.strip()
if not base_url:
raise ValueError(
"FILES_API_URL is required for sandbox runtime file access. "
"Set FILES_API_URL to a URL reachable by your sandbox runtime. "
"For public sandbox environments (e.g. e2b), use a public domain or IP."
)
base_url = base_url.rstrip("/")
return f"{base_url}/files/storage-files/{token}"

View File

@ -0,0 +1,157 @@
"""
Service for generating Nested Node LLM graph structures.
This service creates graph structures containing LLM nodes configured for
extracting values from list[PromptMessage] variables.
"""
from typing import Any
from sqlalchemy.orm import Session
from graphon.enums import BuiltinNodeTypes
from graphon.model_runtime.entities import LLMMode
from services.model_provider_service import ModelProviderService
from services.workflow.entities import NestedNodeGraphRequest, NestedNodeGraphResponse, NestedNodeParameterSchema
class NestedNodeGraphService:
"""Service for generating Nested Node LLM graph structures."""
def __init__(self, session: Session):
self._session = session
def generate_nested_node_id(self, node_id: str, parameter_name: str) -> str:
"""Generate nested node ID following the naming convention.
Format: {node_id}_ext_{parameter_name}
"""
return f"{node_id}_ext_{parameter_name}"
def generate_nested_node_graph(self, tenant_id: str, request: NestedNodeGraphRequest) -> NestedNodeGraphResponse:
"""Generate a complete graph structure containing a Nested Node LLM node.
Args:
tenant_id: The tenant ID for fetching default model config
request: The nested node graph generation request
Returns:
Complete graph structure with nodes, edges, and viewport
"""
node_id = self.generate_nested_node_id(request.parent_node_id, request.parameter_key)
model_config = self._get_default_model_config(tenant_id)
node = self._build_nested_node_llm_node(
node_id=node_id,
parent_node_id=request.parent_node_id,
context_source=request.context_source,
parameter_schema=request.parameter_schema,
model_config=model_config,
)
graph = {
"nodes": [node],
"edges": [],
"viewport": {},
}
return NestedNodeGraphResponse(graph=graph)
def _get_default_model_config(self, tenant_id: str) -> dict[str, Any]:
"""Get the default LLM model configuration for the tenant."""
model_provider_service = ModelProviderService()
default_model = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id,
model_type="llm",
)
if default_model:
return {
"provider": default_model.provider.provider,
"name": default_model.model,
"mode": LLMMode.CHAT.value,
"completion_params": {},
}
# Fallback to empty config if no default model is configured
return {
"provider": "",
"name": "",
"mode": LLMMode.CHAT.value,
"completion_params": {},
}
def _build_nested_node_llm_node(
self,
*,
node_id: str,
parent_node_id: str,
context_source: list[str],
parameter_schema: NestedNodeParameterSchema,
model_config: dict[str, Any],
) -> dict[str, Any]:
"""Build the Nested Node LLM node structure.
The node uses:
- $context in prompt_template to reference the PromptMessage list
- structured_output for extracting the specific parameter
- parent_node_id to associate with the parent node
"""
prompt_template = [
{
"role": "system",
"text": "Extract the required parameter value from the conversation context above.",
"skill": False,
},
{"$context": context_source},
{"role": "user", "text": "", "skill": False},
]
structured_output = {
"schema": {
"type": "object",
"properties": {
parameter_schema.name: {
"type": parameter_schema.type,
"description": parameter_schema.description,
}
},
"required": [parameter_schema.name],
"additionalProperties": False,
}
}
return {
"id": node_id,
"position": {"x": 0, "y": 0},
"data": {
"type": BuiltinNodeTypes.LLM,
# BaseNodeData fields
"title": f"NestedNode: {parameter_schema.name}",
"desc": f"Extract {parameter_schema.name} from conversation context",
"version": "1",
"error_strategy": None,
"default_value": None,
"retry_config": {"max_retries": 0},
"parent_node_id": parent_node_id,
# LLMNodeData fields
"model": model_config,
"prompt_template": prompt_template,
"prompt_config": {"jinja2_variables": []},
"memory": None,
"context": {
"enabled": False,
"variable_selector": None,
},
"vision": {
"enabled": False,
"configs": {
"variable_selector": ["sys", "files"],
"detail": "high",
},
},
"structured_output_enabled": True,
"structured_output": structured_output,
"computer_use": False,
"tool_settings": [],
},
}
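A minimal sketch of a request/response round trip; the entity constructors and field values are inferred from the attribute access above and are illustrative.
svc = NestedNodeGraphService(session)
resp = svc.generate_nested_node_graph(
    tenant_id,
    NestedNodeGraphRequest(
        parent_node_id="tool_node_1",
        parameter_key="city",
        context_source=["tool_node_1", "messages"],
        parameter_schema=NestedNodeParameterSchema(name="city", type="string", description="Target city"),
    ),
)
# resp.graph["nodes"][0]["id"] == "tool_node_1_ext_city"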

View File

@ -0,0 +1,391 @@
from __future__ import annotations
import logging
import time
from collections.abc import Mapping
from models.account import Account
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository, WorkflowSessionInfo
class WorkflowCollaborationService:
def __init__(self, repository: WorkflowCollaborationRepository, socketio) -> None:
self._repository = repository
self._socketio = socketio
def __repr__(self) -> str:
return f"{self.__class__.__name__}(repository={self._repository})"
def save_session(self, sid: str, user: Account) -> None:
self._socketio.save_session(
sid,
{
"user_id": user.id,
"username": user.name,
"avatar": user.avatar,
},
)
def register_session(self, workflow_id: str, sid: str) -> tuple[str, bool] | None:
session = self._socketio.get_session(sid)
user_id = session.get("user_id")
if not user_id:
return None
session_info: WorkflowSessionInfo = {
"user_id": str(user_id),
"username": str(session.get("username", "Unknown")),
"avatar": session.get("avatar"),
"sid": sid,
"connected_at": int(time.time()),
"graph_active": True,
"active_skill_file_id": None,
}
self._repository.set_session_info(workflow_id, session_info)
leader_sid = self.get_or_set_leader(workflow_id, sid)
is_leader = leader_sid == sid if leader_sid else False
self._socketio.enter_room(sid, workflow_id)
self.broadcast_online_users(workflow_id)
self._socketio.emit("status", {"isLeader": is_leader}, room=sid)
return str(user_id), is_leader
def disconnect_session(self, sid: str) -> None:
mapping = self._repository.get_sid_mapping(sid)
if not mapping:
return
workflow_id = mapping["workflow_id"]
active_skill_file_id = self._repository.get_active_skill_file_id(workflow_id, sid)
self._repository.delete_session(workflow_id, sid)
self.handle_leader_disconnect(workflow_id, sid)
if active_skill_file_id:
self.handle_skill_leader_disconnect(workflow_id, active_skill_file_id, sid)
self.broadcast_online_users(workflow_id)
def relay_collaboration_event(self, sid: str, data: Mapping[str, object]) -> tuple[dict[str, str], int]:
mapping = self._repository.get_sid_mapping(sid)
if not mapping:
return {"msg": "unauthorized"}, 401
workflow_id = mapping["workflow_id"]
user_id = mapping["user_id"]
self.refresh_session_state(workflow_id, sid)
event_type = data.get("type")
event_data = data.get("data")
timestamp = data.get("timestamp", int(time.time()))
if not event_type:
return {"msg": "invalid event type"}, 400
if event_type == "graph_view_active":
is_active = False
if isinstance(event_data, dict):
is_active = bool(event_data.get("active") or False)
self._repository.set_graph_active(workflow_id, sid, is_active)
self.refresh_session_state(workflow_id, sid)
self.broadcast_online_users(workflow_id)
return {"msg": "graph_view_active_updated"}, 200
if event_type == "skill_file_active":
file_id = None
is_active = False
if isinstance(event_data, dict):
file_id = event_data.get("file_id")
is_active = bool(event_data.get("active") or False)
if not file_id or not isinstance(file_id, str):
return {"msg": "invalid skill_file_active payload"}, 400
previous_file_id = self._repository.get_active_skill_file_id(workflow_id, sid)
next_file_id = file_id if is_active else None
if previous_file_id == next_file_id:
self.refresh_session_state(workflow_id, sid)
return {"msg": "skill_file_active_unchanged"}, 200
self._repository.set_active_skill_file(workflow_id, sid, next_file_id)
self.refresh_session_state(workflow_id, sid)
if previous_file_id:
self._ensure_skill_leader(workflow_id, previous_file_id)
if next_file_id:
self._ensure_skill_leader(workflow_id, next_file_id, preferred_sid=sid)
return {"msg": "skill_file_active_updated"}, 200
if event_type == "sync_request":
leader_sid = self._repository.get_current_leader(workflow_id)
if leader_sid and (
self.is_session_active(workflow_id, leader_sid)
and self._repository.is_graph_active(workflow_id, leader_sid)
):
target_sid = leader_sid
else:
if leader_sid:
self._repository.delete_leader(workflow_id)
target_sid = self._select_graph_leader(workflow_id, preferred_sid=sid)
if target_sid:
self._repository.set_leader(workflow_id, target_sid)
self.broadcast_leader_change(workflow_id, target_sid)
if not target_sid:
return {"msg": "no_active_leader"}, 200
self._socketio.emit(
"collaboration_update",
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
room=target_sid,
)
return {"msg": "sync_request_forwarded"}, 200
self._socketio.emit(
"collaboration_update",
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
room=workflow_id,
skip_sid=sid,
)
return {"msg": "event_broadcasted"}, 200
def relay_graph_event(self, sid: str, data: object) -> tuple[dict[str, str], int]:
mapping = self._repository.get_sid_mapping(sid)
if not mapping:
return {"msg": "unauthorized"}, 401
workflow_id = mapping["workflow_id"]
self.refresh_session_state(workflow_id, sid)
self._socketio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
return {"msg": "graph_update_broadcasted"}, 200
def relay_skill_event(self, sid: str, data: object) -> tuple[dict[str, str], int]:
mapping = self._repository.get_sid_mapping(sid)
if not mapping:
return {"msg": "unauthorized"}, 401
workflow_id = mapping["workflow_id"]
self.refresh_session_state(workflow_id, sid)
self._socketio.emit("skill_update", data, room=workflow_id, skip_sid=sid)
return {"msg": "skill_update_broadcasted"}, 200
def get_or_set_leader(self, workflow_id: str, sid: str) -> str | None:
current_leader = self._repository.get_current_leader(workflow_id)
if current_leader:
if self.is_session_active(workflow_id, current_leader) and self._repository.is_graph_active(
workflow_id, current_leader
):
return current_leader
self._repository.delete_session(workflow_id, current_leader)
self._repository.delete_leader(workflow_id)
new_leader_sid = self._select_graph_leader(workflow_id, preferred_sid=sid)
if not new_leader_sid:
return None
was_set = self._repository.set_leader_if_absent(workflow_id, new_leader_sid)
if was_set:
if current_leader:
self.broadcast_leader_change(workflow_id, new_leader_sid)
return new_leader_sid
current_leader = self._repository.get_current_leader(workflow_id)
if current_leader:
return current_leader
return new_leader_sid
def handle_leader_disconnect(self, workflow_id: str, disconnected_sid: str) -> None:
current_leader = self._repository.get_current_leader(workflow_id)
if not current_leader:
return
if current_leader != disconnected_sid:
return
new_leader_sid = self._select_graph_leader(workflow_id)
if new_leader_sid:
self._repository.set_leader(workflow_id, new_leader_sid)
self.broadcast_leader_change(workflow_id, new_leader_sid)
else:
self._repository.delete_leader(workflow_id)
self.broadcast_leader_change(workflow_id, None)
def handle_skill_leader_disconnect(self, workflow_id: str, file_id: str, disconnected_sid: str) -> None:
current_leader = self._repository.get_skill_leader(workflow_id, file_id)
if not current_leader:
return
if current_leader != disconnected_sid:
return
new_leader_sid = self._select_skill_leader(workflow_id, file_id)
if new_leader_sid:
self._repository.set_skill_leader(workflow_id, file_id, new_leader_sid)
self.broadcast_skill_leader_change(workflow_id, file_id, new_leader_sid)
else:
self._repository.delete_skill_leader(workflow_id, file_id)
self.broadcast_skill_leader_change(workflow_id, file_id, None)
def broadcast_leader_change(self, workflow_id: str, new_leader_sid: str | None) -> None:
for sid in self._repository.get_session_sids(workflow_id):
try:
is_leader = new_leader_sid is not None and sid == new_leader_sid
self._socketio.emit("status", {"isLeader": is_leader}, room=sid)
except Exception:
logging.exception("Failed to emit leader status to session %s", sid)
def broadcast_skill_leader_change(self, workflow_id: str, file_id: str, new_leader_sid: str | None) -> None:
for sid in self._repository.get_session_sids(workflow_id):
try:
is_leader = new_leader_sid is not None and sid == new_leader_sid
self._socketio.emit("skill_status", {"file_id": file_id, "isLeader": is_leader}, room=sid)
except Exception:
logging.exception("Failed to emit skill leader status to session %s", sid)
def get_current_leader(self, workflow_id: str) -> str | None:
return self._repository.get_current_leader(workflow_id)
def _prune_inactive_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]:
"""Remove inactive sessions from storage and return active sessions only."""
sessions = self._repository.list_sessions(workflow_id)
if not sessions:
return []
active_sessions: list[WorkflowSessionInfo] = []
stale_sids: list[str] = []
for session in sessions:
sid = session["sid"]
if self.is_session_active(workflow_id, sid):
active_sessions.append(session)
else:
stale_sids.append(sid)
for sid in stale_sids:
self._repository.delete_session(workflow_id, sid)
return active_sessions
def broadcast_online_users(self, workflow_id: str) -> None:
users = self._prune_inactive_sessions(workflow_id)
users.sort(key=lambda x: x.get("connected_at") or 0)
leader_sid = self.get_current_leader(workflow_id)
previous_leader = leader_sid
active_sids = {user["sid"] for user in users}
if leader_sid and leader_sid not in active_sids:
self._repository.delete_leader(workflow_id)
leader_sid = None
if not leader_sid and users:
leader_sid = self._select_graph_leader(workflow_id)
if leader_sid:
self._repository.set_leader(workflow_id, leader_sid)
if leader_sid != previous_leader:
self.broadcast_leader_change(workflow_id, leader_sid)
self._socketio.emit(
"online_users",
{"workflow_id": workflow_id, "users": users, "leader": leader_sid},
room=workflow_id,
)
def refresh_session_state(self, workflow_id: str, sid: str) -> None:
self._repository.refresh_session_state(workflow_id, sid)
self._ensure_leader(workflow_id, sid)
active_skill_file_id = self._repository.get_active_skill_file_id(workflow_id, sid)
if active_skill_file_id:
self._ensure_skill_leader(workflow_id, active_skill_file_id, preferred_sid=sid)
def _ensure_leader(self, workflow_id: str, sid: str) -> None:
current_leader = self._repository.get_current_leader(workflow_id)
if (
current_leader
and self.is_session_active(workflow_id, current_leader)
and self._repository.is_graph_active(workflow_id, current_leader)
):
self._repository.expire_leader(workflow_id)
return
if current_leader:
self._repository.delete_leader(workflow_id)
new_leader_sid = self._select_graph_leader(workflow_id, preferred_sid=sid)
if not new_leader_sid:
self.broadcast_leader_change(workflow_id, None)
return
self._repository.set_leader(workflow_id, new_leader_sid)
self.broadcast_leader_change(workflow_id, new_leader_sid)
def _ensure_skill_leader(self, workflow_id: str, file_id: str, preferred_sid: str | None = None) -> None:
current_leader = self._repository.get_skill_leader(workflow_id, file_id)
active_sids = self._repository.get_active_skill_session_sids(workflow_id, file_id)
if current_leader and self.is_session_active(workflow_id, current_leader):
if current_leader in active_sids or not active_sids:
self._repository.expire_skill_leader(workflow_id, file_id)
return
if current_leader:
self._repository.delete_skill_leader(workflow_id, file_id)
new_leader_sid = self._select_skill_leader(workflow_id, file_id, preferred_sid=preferred_sid)
if not new_leader_sid:
self.broadcast_skill_leader_change(workflow_id, file_id, None)
return
self._repository.set_skill_leader(workflow_id, file_id, new_leader_sid)
self.broadcast_skill_leader_change(workflow_id, file_id, new_leader_sid)
def _select_graph_leader(self, workflow_id: str, preferred_sid: str | None = None) -> str | None:
session_sids = [
session["sid"]
for session in self._repository.list_sessions(workflow_id)
if session.get("graph_active") and self.is_session_active(workflow_id, session["sid"])
]
if not session_sids:
return None
if preferred_sid and preferred_sid in session_sids:
return preferred_sid
return session_sids[0]
def _select_skill_leader(self, workflow_id: str, file_id: str, preferred_sid: str | None = None) -> str | None:
session_sids = [
sid
for sid in self._repository.get_active_skill_session_sids(workflow_id, file_id)
if self.is_session_active(workflow_id, sid)
]
if not session_sids:
return None
if preferred_sid and preferred_sid in session_sids:
return preferred_sid
return session_sids[0]
def is_session_active(self, workflow_id: str, sid: str) -> bool:
if not sid:
return False
try:
if not self._socketio.manager.is_connected(sid, "/"):
return False
except AttributeError:
return False
if not self._repository.session_exists(workflow_id, sid):
return False
if not self._repository.sid_mapping_exists(sid):
return False
return True
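
The relay and leader-election helpers above are meant to be driven from Socket.IO event handlers. Below is a minimal wiring sketch, assuming flask-socketio; the client-facing event names, the module-level socketio instance, and the way collaboration_service is obtained are illustrative placeholders and not part of this change.

# Hypothetical wiring of the collaboration service to flask-socketio handlers.
from flask import request
from flask_socketio import SocketIO

socketio = SocketIO()              # assumed application-wide instance
collaboration_service = ...        # assumed: an instance of the service shown above

@socketio.on("graph_update")
def on_graph_update(data):
    # Relay a client's graph edit to every other session in the workflow room.
    ack, _status = collaboration_service.relay_graph_event(request.sid, data)
    return ack

@socketio.on("skill_update")
def on_skill_update(data):
    # Same pattern for skill editor updates.
    ack, _status = collaboration_service.relay_skill_event(request.sid, data)
    return ack

@socketio.on("disconnect")
def on_disconnect():
    # If the departing socket held graph leadership, elect a replacement and
    # push the refreshed online-user list to the room. Direct repository access
    # here is only for the sketch; the service keeps it private.
    sid = request.sid
    mapping = collaboration_service._repository.get_sid_mapping(sid)
    if mapping:
        collaboration_service.handle_leader_disconnect(mapping["workflow_id"], sid)
        collaboration_service.broadcast_online_users(mapping["workflow_id"])
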

View File

@@ -0,0 +1,468 @@
import logging
from collections.abc import Sequence
from sqlalchemy import desc, select
from sqlalchemy.orm import Session, selectinload
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models import App, TenantAccountJoin, WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
from models.account import Account
from tasks.mail_workflow_comment_task import send_workflow_comment_mention_email_task
logger = logging.getLogger(__name__)
class WorkflowCommentService:
"""Service for managing workflow comments."""
@staticmethod
def _validate_content(content: str) -> None:
if len(content.strip()) == 0:
raise ValueError("Comment content cannot be empty")
if len(content) > 1000:
raise ValueError("Comment content cannot exceed 1000 characters")
@staticmethod
def _filter_valid_mentioned_user_ids(mentioned_user_ids: Sequence[str]) -> list[str]:
"""Return deduplicated UUID user IDs in the order provided."""
unique_user_ids: list[str] = []
seen: set[str] = set()
for user_id in mentioned_user_ids:
if not isinstance(user_id, str):
continue
if not uuid_value(user_id):
continue
if user_id in seen:
continue
seen.add(user_id)
unique_user_ids.append(user_id)
return unique_user_ids
@staticmethod
def _format_comment_excerpt(content: str, max_length: int = 200) -> str:
"""Trim comment content for email display."""
trimmed = content.strip()
if len(trimmed) <= max_length:
return trimmed
if max_length <= 3:
return trimmed[:max_length]
return f"{trimmed[: max_length - 3].rstrip()}..."
@staticmethod
def _build_mention_email_payloads(
session: Session,
tenant_id: str,
app_id: str,
mentioner_id: str,
mentioned_user_ids: Sequence[str],
content: str,
) -> list[dict[str, str]]:
"""Prepare email payloads for mentioned users, including the workflow app link."""
if not mentioned_user_ids:
return []
candidate_user_ids = [user_id for user_id in mentioned_user_ids if user_id != mentioner_id]
if not candidate_user_ids:
return []
app_name = session.scalar(select(App.name).where(App.id == app_id, App.tenant_id == tenant_id)) or "Dify app"
commenter_name = session.scalar(select(Account.name).where(Account.id == mentioner_id)) or "Dify user"
comment_excerpt = WorkflowCommentService._format_comment_excerpt(content)
base_url = dify_config.CONSOLE_WEB_URL.rstrip("/")
app_url = f"{base_url}/app/{app_id}/workflow"
accounts = session.scalars(
select(Account)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == tenant_id, Account.id.in_(candidate_user_ids))
).all()
payloads: list[dict[str, str]] = []
for account in accounts:
payloads.append(
{
"language": account.interface_language or "en-US",
"to": account.email,
"mentioned_name": account.name or account.email,
"commenter_name": commenter_name,
"app_name": app_name,
"comment_content": comment_excerpt,
"app_url": app_url,
}
)
return payloads
@staticmethod
def _dispatch_mention_emails(payloads: Sequence[dict[str, str]]) -> None:
"""Enqueue mention notification emails."""
for payload in payloads:
send_workflow_comment_mention_email_task.delay(**payload)
@staticmethod
def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]:
"""Get all comments for a workflow."""
with Session(db.engine) as session:
# Get all comments with eager loading
stmt = (
select(WorkflowComment)
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
.where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
.order_by(desc(WorkflowComment.created_at))
)
comments = session.scalars(stmt).all()
# Batch preload all Account objects to avoid N+1 queries
WorkflowCommentService._preload_accounts(session, comments)
return comments
@staticmethod
def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None:
"""Batch preload Account objects for comments, replies, and mentions."""
# Collect all user IDs
user_ids: set[str] = set()
for comment in comments:
user_ids.add(comment.created_by)
if comment.resolved_by:
user_ids.add(comment.resolved_by)
user_ids.update(reply.created_by for reply in comment.replies)
user_ids.update(mention.mentioned_user_id for mention in comment.mentions)
if not user_ids:
return
# Batch query all accounts
accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
account_map = {str(account.id): account for account in accounts}
# Cache accounts on objects
for comment in comments:
comment.cache_created_by_account(account_map.get(comment.created_by))
comment.cache_resolved_by_account(account_map.get(comment.resolved_by) if comment.resolved_by else None)
for reply in comment.replies:
reply.cache_created_by_account(account_map.get(reply.created_by))
for mention in comment.mentions:
mention.cache_mentioned_user_account(account_map.get(mention.mentioned_user_id))
@staticmethod
def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment:
"""Get a specific comment."""
def _get_comment(session: Session) -> WorkflowComment:
stmt = (
select(WorkflowComment)
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
.where(
WorkflowComment.id == comment_id,
WorkflowComment.tenant_id == tenant_id,
WorkflowComment.app_id == app_id,
)
)
comment = session.scalar(stmt)
if not comment:
raise NotFound("Comment not found")
# Preload accounts to avoid N+1 queries
WorkflowCommentService._preload_accounts(session, [comment])
return comment
if session is not None:
return _get_comment(session)
else:
with Session(db.engine, expire_on_commit=False) as session:
return _get_comment(session)
@staticmethod
def create_comment(
tenant_id: str,
app_id: str,
created_by: str,
content: str,
position_x: float,
position_y: float,
mentioned_user_ids: list[str] | None = None,
) -> dict:
"""Create a new workflow comment and send mention notification emails."""
WorkflowCommentService._validate_content(content)
with Session(db.engine) as session:
comment = WorkflowComment(
tenant_id=tenant_id,
app_id=app_id,
position_x=position_x,
position_y=position_y,
content=content,
created_by=created_by,
)
session.add(comment)
session.flush() # Get the comment ID for mentions
# Create mentions if specified
mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
for user_id in mentioned_user_ids:
mention = WorkflowCommentMention(
comment_id=comment.id,
reply_id=None, # This is a comment mention, not reply mention
mentioned_user_id=user_id,
)
session.add(mention)
mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
session=session,
tenant_id=tenant_id,
app_id=app_id,
mentioner_id=created_by,
mentioned_user_ids=mentioned_user_ids,
content=content,
)
session.commit()
WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)
# Return only what we need - id and created_at
return {"id": comment.id, "created_at": comment.created_at}
@staticmethod
def update_comment(
tenant_id: str,
app_id: str,
comment_id: str,
user_id: str,
content: str,
position_x: float | None = None,
position_y: float | None = None,
mentioned_user_ids: list[str] | None = None,
) -> dict:
"""Update a workflow comment and notify newly mentioned users."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
# Get comment with validation
stmt = select(WorkflowComment).where(
WorkflowComment.id == comment_id,
WorkflowComment.tenant_id == tenant_id,
WorkflowComment.app_id == app_id,
)
comment = session.scalar(stmt)
if not comment:
raise NotFound("Comment not found")
# Only the creator can update the comment
if comment.created_by != user_id:
raise Forbidden("Only the comment creator can update it")
# Update comment fields
comment.content = content
if position_x is not None:
comment.position_x = position_x
if position_y is not None:
comment.position_y = position_y
# Update mentions - first remove existing mentions for this comment only (not replies)
existing_mentions = session.scalars(
select(WorkflowCommentMention).where(
WorkflowCommentMention.comment_id == comment.id,
WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions
)
).all()
existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions}
for mention in existing_mentions:
session.delete(mention)
# Add new mentions
mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
new_mentioned_user_ids = [
user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids
]
for user_id_str in mentioned_user_ids:
mention = WorkflowCommentMention(
comment_id=comment.id,
reply_id=None, # This is a comment mention
mentioned_user_id=user_id_str,
)
session.add(mention)
mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
session=session,
tenant_id=tenant_id,
app_id=app_id,
mentioner_id=user_id,
mentioned_user_ids=new_mentioned_user_ids,
content=content,
)
session.commit()
WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)
return {"id": comment.id, "updated_at": comment.updated_at}
@staticmethod
def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
"""Delete a workflow comment."""
with Session(db.engine, expire_on_commit=False) as session:
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
# Only the creator can delete the comment
if comment.created_by != user_id:
raise Forbidden("Only the comment creator can delete it")
# Delete associated mentions (both comment and reply mentions)
mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
).all()
for mention in mentions:
session.delete(mention)
# Delete associated replies
replies = session.scalars(
select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
).all()
for reply in replies:
session.delete(reply)
session.delete(comment)
session.commit()
@staticmethod
def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
"""Resolve a workflow comment."""
with Session(db.engine, expire_on_commit=False) as session:
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
if comment.resolved:
return comment
comment.resolved = True
comment.resolved_at = naive_utc_now()
comment.resolved_by = user_id
session.commit()
return comment
@staticmethod
def create_reply(
comment_id: str, content: str, created_by: str, mentioned_user_ids: list[str] | None = None
) -> dict:
"""Add a reply to a workflow comment and notify mentioned users."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
# Check if comment exists
comment = session.get(WorkflowComment, comment_id)
if not comment:
raise NotFound("Comment not found")
reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)
session.add(reply)
session.flush() # Get the reply ID for mentions
# Create mentions if specified
mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
for user_id in mentioned_user_ids:
# Create mention linking to specific reply
mention = WorkflowCommentMention(comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id)
session.add(mention)
mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
session=session,
tenant_id=comment.tenant_id,
app_id=comment.app_id,
mentioner_id=created_by,
mentioned_user_ids=mentioned_user_ids,
content=content,
)
session.commit()
WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)
return {"id": reply.id, "created_at": reply.created_at}
@staticmethod
def update_reply(reply_id: str, user_id: str, content: str, mentioned_user_ids: list[str] | None = None) -> dict:
"""Update a comment reply and notify newly mentioned users."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
reply = session.get(WorkflowCommentReply, reply_id)
if not reply:
raise NotFound("Reply not found")
# Only the creator can update the reply
if reply.created_by != user_id:
raise Forbidden("Only the reply creator can update it")
reply.content = content
# Update mentions - first remove existing mentions for this reply
existing_mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
).all()
existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions}
for mention in existing_mentions:
session.delete(mention)
# Add mentions
mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
new_mentioned_user_ids = [
user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids
]
for user_id_str in mentioned_user_ids:
mention = WorkflowCommentMention(
comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str
)
session.add(mention)
mention_email_payloads: list[dict[str, str]] = []
comment = session.get(WorkflowComment, reply.comment_id)
if comment:
mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
session=session,
tenant_id=comment.tenant_id,
app_id=comment.app_id,
mentioner_id=user_id,
mentioned_user_ids=new_mentioned_user_ids,
content=content,
)
session.commit()
session.refresh(reply) # Refresh to get updated timestamp
WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)
return {"id": reply.id, "updated_at": reply.updated_at}
@staticmethod
def delete_reply(reply_id: str, user_id: str) -> None:
"""Delete a comment reply."""
with Session(db.engine, expire_on_commit=False) as session:
reply = session.get(WorkflowCommentReply, reply_id)
if not reply:
raise NotFound("Reply not found")
# Only the creator can delete the reply
if reply.created_by != user_id:
raise Forbidden("Only the reply creator can delete it")
# Delete associated mentions first
mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
).all()
for mention in mentions:
session.delete(mention)
session.delete(reply)
session.commit()
@staticmethod
def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
"""Validate that a comment belongs to the specified tenant and app."""
return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)
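
A minimal usage sketch for the comment service above: the module path and the surrounding helper function are assumptions, and the IDs and content are placeholders. It needs an application context with the database configured.

# Hypothetical caller of WorkflowCommentService; identifiers are placeholders.
from services.workflow_comment_service import WorkflowCommentService  # assumed path

def add_comment_with_reply(tenant_id: str, app_id: str, author_id: str, colleague_id: str) -> str:
    # create_comment validates content (non-empty, <= 1000 chars), stores the
    # mention rows, and enqueues mention emails only after the commit succeeds.
    created = WorkflowCommentService.create_comment(
        tenant_id=tenant_id,
        app_id=app_id,
        created_by=author_id,
        content="Can we rename this node?",
        position_x=120.0,
        position_y=64.0,
        mentioned_user_ids=[colleague_id],  # IDs that fail UUID validation or repeat are dropped
    )
    # Replies reuse the same validation and mention/email flow.
    WorkflowCommentService.create_reply(
        comment_id=created["id"],
        content="Sure, updating it now.",
        created_by=colleague_id,
    )
    return created["id"]

Note that mention emails go only to mentioned users other than the commenter, and only after the transaction commits, so a failed insert never triggers notifications.
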

View File

@@ -0,0 +1,65 @@
import logging
import time
import click
from celery import shared_task
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_workflow_comment_mention_email_task(
language: str,
to: str,
mentioned_name: str,
commenter_name: str,
app_name: str,
comment_content: str,
app_url: str,
):
"""
Send workflow comment mention email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
mentioned_name: Name of the mentioned user
commenter_name: Name of the comment author
app_name: Name of the app where the comment was made
comment_content: Comment content excerpt
app_url: Link to the app workflow page
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start workflow comment mention mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.WORKFLOW_COMMENT_MENTION,
language_code=language,
to=to,
template_context={
"to": to,
"mentioned_name": mentioned_name,
"commenter_name": commenter_name,
"app_name": app_name,
"comment_content": comment_content,
"app_url": app_url,
},
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Send workflow comment mention mail to {to} succeeded: latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("workflow comment mention email to %s failed", to)
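
The comment service enqueues this task on the "mail" Celery queue via .delay(**payload). For reference, a sketch of an equivalent direct dispatch; the addresses, names, and URL are illustrative only.

# Illustrative dispatch of the mention email task; values are placeholders.
from tasks.mail_workflow_comment_task import send_workflow_comment_mention_email_task

send_workflow_comment_mention_email_task.delay(
    language="en-US",
    to="reviewer@example.com",
    mentioned_name="Ada",
    commenter_name="Grace",
    app_name="Order triage workflow",
    comment_content="Can we rename this node?",
    app_url="https://console.example.com/app/<app-id>/workflow",
)

The task returns early when the mail extension is not initialized, so environments without SMTP configured simply skip the send.
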