feat(api): enable all sandbox/skill controller routes and resolve dependencies (P0)
Resolve the full dependency chain to enable all previously disabled controllers.

Enabled routes:
- sandbox_files: sandbox file browser API
- sandbox_providers: sandbox provider management API
- app_asset: app asset management API
- skills: skill extraction API
- CLI API blueprint: DifyCli callback endpoints (/cli/api/*)

Dependencies extracted (64 files, ~8000 lines):
- models/sandbox.py, models/app_asset.py: DB models
- core/zip_sandbox/: zip-based sandbox execution
- core/session/: CLI API session management
- core/memory/: base memory + node token buffer
- core/helper/creators.py: helper utilities
- core/llm_generator/: context models, output models, utils
- core/workflow/nodes/command/: command node type
- core/workflow/nodes/file_upload/: file upload node type
- core/app/entities/: app_asset_entities, app_bundle_entities, llm_generation_entities
- services/: asset_content, skill, workflow_collaboration, workflow_comment
- controllers/console/app/error.py: AppAsset error classes
- core/tools/utils/system_encryption.py

Import fixes:
- dify_graph.enums -> graphon.enums in skill_service.py
- get_signed_file_url_for_plugin -> get_signed_file_url in cli_api.py

All 5 controllers verified: import OK, Flask starts successfully. 46 existing tests still pass.

Made-with: Cursor
This commit is contained in:
parent d3d9f21cdf
commit 44491e427c
@@ -22,7 +22,7 @@ from core.session.cli_api import CliContext
 from core.skill.entities import ToolInvocationRequest
 from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.tool_manager import ToolManager
-from graphon.file.helpers import get_signed_file_url_for_plugin
+from graphon.file.helpers import get_signed_file_url
 from libs.helper import length_prefixed_response
 from models.account import Account
 from models.model import EndUser, Tenant
@@ -139,11 +139,9 @@ class CliUploadFileRequestApi(Resource):
         payload: RequestRequestUploadFile,
         cli_context: CliContext,
     ):
-        url = get_signed_file_url_for_plugin(
-            filename=payload.filename,
-            mimetype=payload.mimetype,
+        url = get_signed_file_url(
+            upload_file_id=f"{tenant_model.id}_{user_model.id}_{payload.filename}",
             tenant_id=tenant_model.id,
             user_id=user_model.id,
         )
         return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
@@ -41,7 +41,7 @@ from . import (
     init_validate,
     notification,
     ping,
-    # sandbox_files, # TODO: enable after full sandbox integration
+    sandbox_files,
     setup,
     spec,
     version,
@@ -53,7 +53,7 @@ from .app import (
     agent,
     annotation,
     app,
-    # app_asset, # TODO: enable after full sandbox integration
+    app_asset,
     audio,
     completion,
     conversation,
@@ -64,7 +64,7 @@ from .app import (
     model_config,
     ops_trace,
     site,
-    # skills, # TODO: enable after full sandbox integration
+    skills,
     statistic,
     workflow,
     workflow_app_log,
@@ -133,7 +133,7 @@ from .workspace import (
     model_providers,
     models,
     plugin,
-    # sandbox_providers, # TODO: enable after full sandbox integration
+    sandbox_providers,
     tool_providers,
     trigger_providers,
     workspace,
@@ -121,3 +121,21 @@ class NeedAddIdsError(BaseHTTPException):
     error_code = "need_add_ids"
     description = "Need to add ids."
     code = 400
+
+
+class AppAssetNodeNotFoundError(BaseHTTPException):
+    error_code = "app_asset_node_not_found"
+    description = "App asset node not found."
+    code = 404
+
+
+class AppAssetFileRequiredError(BaseHTTPException):
+    error_code = "app_asset_file_required"
+    description = "File is required."
+    code = 400
+
+
+class AppAssetPathConflictError(BaseHTTPException):
+    error_code = "app_asset_path_conflict"
+    description = "Path already exists."
+    code = 409
322 api/controllers/console/app/workflow_comment.py Normal file
@@ -0,0 +1,322 @@
import logging

from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, TypeAdapter

from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.member_fields import AccountWithRole
from fields.workflow_comment_fields import (
    workflow_comment_basic_fields,
    workflow_comment_create_fields,
    workflow_comment_detail_fields,
    workflow_comment_reply_create_fields,
    workflow_comment_reply_update_fields,
    workflow_comment_resolve_fields,
    workflow_comment_update_fields,
)
from libs.login import current_user, login_required
from models import App
from services.account_service import TenantService
from services.workflow_comment_service import WorkflowCommentService

logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class WorkflowCommentCreatePayload(BaseModel):
    position_x: float = Field(..., description="Comment X position")
    position_y: float = Field(..., description="Comment Y position")
    content: str = Field(..., description="Comment content")
    mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")


class WorkflowCommentUpdatePayload(BaseModel):
    content: str = Field(..., description="Comment content")
    position_x: float | None = Field(default=None, description="Comment X position")
    position_y: float | None = Field(default=None, description="Comment Y position")
    mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")


class WorkflowCommentReplyCreatePayload(BaseModel):
    content: str = Field(..., description="Reply content")
    mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")


class WorkflowCommentReplyUpdatePayload(BaseModel):
    content: str = Field(..., description="Reply content")
    mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")


class WorkflowCommentMentionUsersResponse(BaseModel):
    users: list[AccountWithRole] = Field(description="Mentionable users")


for model in (
    WorkflowCommentCreatePayload,
    WorkflowCommentUpdatePayload,
    WorkflowCommentReplyCreatePayload,
    WorkflowCommentReplyUpdatePayload,
):
    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

for model in (AccountWithRole, WorkflowCommentMentionUsersResponse):
    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
workflow_comment_reply_create_model = console_ns.model(
    "WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
)
workflow_comment_reply_update_model = console_ns.model(
    "WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
)
workflow_comment_mention_users_model = console_ns.models[WorkflowCommentMentionUsersResponse.__name__]


@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
class WorkflowCommentListApi(Resource):
    """API for listing and creating workflow comments."""

    @console_ns.doc("list_workflow_comments")
    @console_ns.doc(description="Get all comments for a workflow")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
    @login_required
    @setup_required
    @account_initialization_required
    @get_app_model()
    @marshal_with(workflow_comment_basic_model, envelope="data")
    def get(self, app_model: App):
        """Get all comments for a workflow."""
        comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)

        return comments

    @console_ns.doc("create_workflow_comment")
    @console_ns.doc(description="Create a new workflow comment")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
    @console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
    @login_required
    @setup_required
    @account_initialization_required
    @get_app_model()
    @marshal_with(workflow_comment_create_model)
    def post(self, app_model: App):
        """Create a new workflow comment."""
        payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})

        result = WorkflowCommentService.create_comment(
            tenant_id=current_user.current_tenant_id,
            app_id=app_model.id,
            created_by=current_user.id,
            content=payload.content,
            position_x=payload.position_x,
            position_y=payload.position_y,
            mentioned_user_ids=payload.mentioned_user_ids,
        )

        return result, 201


@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
class WorkflowCommentDetailApi(Resource):
    """API for managing individual workflow comments."""

    @console_ns.doc("get_workflow_comment")
    @console_ns.doc(description="Get a specific workflow comment")
    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
    @console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
    @login_required
    @setup_required
    @account_initialization_required
    @get_app_model()
    @marshal_with(workflow_comment_detail_model)
    def get(self, app_model: App, comment_id: str):
        """Get a specific workflow comment."""
        comment = WorkflowCommentService.get_comment(
            tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
        )

        return comment

    @console_ns.doc("update_workflow_comment")
    @console_ns.doc(description="Update a workflow comment")
    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
    @console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
    @console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
    @login_required
    @setup_required
    @account_initialization_required
    @get_app_model()
    @marshal_with(workflow_comment_update_model)
    def put(self, app_model: App, comment_id: str):
        """Update a workflow comment."""
        payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})

        result = WorkflowCommentService.update_comment(
            tenant_id=current_user.current_tenant_id,
            app_id=app_model.id,
            comment_id=comment_id,
            user_id=current_user.id,
            content=payload.content,
            position_x=payload.position_x,
            position_y=payload.position_y,
            mentioned_user_ids=payload.mentioned_user_ids,
        )

        return result

    @console_ns.doc("delete_workflow_comment")
    @console_ns.doc(description="Delete a workflow comment")
    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
    @console_ns.response(204, "Comment deleted successfully")
    @login_required
    @setup_required
    @account_initialization_required
    @get_app_model()
    def delete(self, app_model: App, comment_id: str):
        """Delete a workflow comment."""
        WorkflowCommentService.delete_comment(
            tenant_id=current_user.current_tenant_id,
            app_id=app_model.id,
            comment_id=comment_id,
            user_id=current_user.id,
        )

        return {"result": "success"}, 204


@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
class WorkflowCommentResolveApi(Resource):
    """API for resolving and reopening workflow comments."""

    @console_ns.doc("resolve_workflow_comment")
    @console_ns.doc(description="Resolve a workflow comment")
    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
    @console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
    @login_required
    @setup_required
    @account_initialization_required
    @get_app_model()
    @marshal_with(workflow_comment_resolve_model)
    def post(self, app_model: App, comment_id: str):
        """Resolve a workflow comment."""
        comment = WorkflowCommentService.resolve_comment(
            tenant_id=current_user.current_tenant_id,
            app_id=app_model.id,
            comment_id=comment_id,
            user_id=current_user.id,
        )

        return comment


@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
class WorkflowCommentReplyApi(Resource):
    """API for managing comment replies."""

    @console_ns.doc("create_workflow_comment_reply")
    @console_ns.doc(description="Add a reply to a workflow comment")
    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
    @console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
    @console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
    @login_required
    @setup_required
    @account_initialization_required
    @get_app_model()
    @marshal_with(workflow_comment_reply_create_model)
    def post(self, app_model: App, comment_id: str):
        """Add a reply to a workflow comment."""
        # Validate comment access first
        WorkflowCommentService.validate_comment_access(
            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
        )

        payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})

        result = WorkflowCommentService.create_reply(
            comment_id=comment_id,
            content=payload.content,
            created_by=current_user.id,
            mentioned_user_ids=payload.mentioned_user_ids,
        )

        return result, 201


@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
class WorkflowCommentReplyDetailApi(Resource):
    """API for managing individual comment replies."""

    @console_ns.doc("update_workflow_comment_reply")
    @console_ns.doc(description="Update a comment reply")
    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
    @console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
    @console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
    @login_required
    @setup_required
    @account_initialization_required
    @get_app_model()
    @marshal_with(workflow_comment_reply_update_model)
    def put(self, app_model: App, comment_id: str, reply_id: str):
        """Update a comment reply."""
        # Validate comment access first
        WorkflowCommentService.validate_comment_access(
            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
        )

        payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})

        reply = WorkflowCommentService.update_reply(
            reply_id=reply_id,
            user_id=current_user.id,
            content=payload.content,
            mentioned_user_ids=payload.mentioned_user_ids,
        )

        return reply

    @console_ns.doc("delete_workflow_comment_reply")
    @console_ns.doc(description="Delete a comment reply")
    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
    @console_ns.response(204, "Reply deleted successfully")
    @login_required
    @setup_required
    @account_initialization_required
    @get_app_model()
    def delete(self, app_model: App, comment_id: str, reply_id: str):
        """Delete a comment reply."""
        # Validate comment access first
        WorkflowCommentService.validate_comment_access(
            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
        )

        WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)

        return {"result": "success"}, 204


@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
class WorkflowCommentMentionUsersApi(Resource):
    """API for getting mentionable users for workflow comments."""

    @console_ns.doc("workflow_comment_mention_users")
    @console_ns.doc(description="Get all users in current tenant for mentions")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
    @login_required
    @setup_required
    @account_initialization_required
    @get_app_model()
    def get(self, app_model: App):
        """Get all users in current tenant for mentions."""
        members = TenantService.get_tenant_members(current_user.current_tenant)
        member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
        response = WorkflowCommentMentionUsersResponse(users=member_models)
        return response.model_dump(mode="json"), 200
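For reference, a minimal sketch of driving these comment endpoints from a console client. The host, token, and IDs below are hypothetical, and reading the new comment's "id" from the create response is an assumption about workflow_comment_create_fields (not shown in this diff):

import httpx

APP_ID = "<app-uuid>"  # hypothetical application UUID
client = httpx.Client(
    base_url="http://localhost:5001/console/api",  # hypothetical console API root
    headers={"Authorization": "Bearer <console-access-token>"},  # hypothetical auth
)

# Pin a comment on the workflow canvas, reply to it, then resolve the thread.
created = client.post(
    f"/apps/{APP_ID}/workflow/comments",
    json={"position_x": 120.0, "position_y": 80.0, "content": "Check this branch", "mentioned_user_ids": []},
).json()
comment_id = created["id"]  # assumed field; actual shape comes from workflow_comment_create_fields
client.post(f"/apps/{APP_ID}/workflow/comments/{comment_id}/replies", json={"content": "Done.", "mentioned_user_ids": []})
client.post(f"/apps/{APP_ID}/workflow/comments/{comment_id}/resolve")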
1 api/controllers/console/socketio/__init__.py Normal file
@@ -0,0 +1 @@
119 api/controllers/console/socketio/workflow.py Normal file
@@ -0,0 +1,119 @@
import logging
from collections.abc import Callable
from typing import cast

from flask import Request as FlaskRequest

from extensions.ext_socketio import sio
from libs.passport import PassportService
from libs.token import extract_access_token
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
from services.account_service import AccountService
from services.workflow_collaboration_service import WorkflowCollaborationService

repository = WorkflowCollaborationRepository()
collaboration_service = WorkflowCollaborationService(repository, sio)


def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
    return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))


@_sio_on("connect")
def socket_connect(sid, environ, auth):
    """
    WebSocket connect event, do authentication here.
    """
    try:
        request_environ = FlaskRequest(environ)
        token = extract_access_token(request_environ)
    except Exception:
        logging.exception("Failed to extract token")
        token = None

    if not token:
        logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
        return False

    try:
        decoded = PassportService().verify(token)
        user_id = decoded.get("user_id")
        if not user_id:
            logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
            return False

        with sio.app.app_context():
            user = AccountService.load_logged_in_account(account_id=user_id)
            if not user:
                logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
                return False
            if not user.has_edit_permission:
                logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
                return False

            collaboration_service.save_session(sid, user)
            return True

    except Exception:
        logging.exception("Socket authentication failed")
        return False


@_sio_on("user_connect")
def handle_user_connect(sid, data):
    """
    Handle user connect event. Each session (tab) is treated as an independent collaborator.
    """
    workflow_id = data.get("workflow_id")
    if not workflow_id:
        return {"msg": "workflow_id is required"}, 400

    result = collaboration_service.register_session(workflow_id, sid)
    if not result:
        return {"msg": "unauthorized"}, 401

    user_id, is_leader = result
    return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}


@_sio_on("disconnect")
def handle_disconnect(sid):
    """
    Handle session disconnect event. Remove the specific session from online users.
    """
    collaboration_service.disconnect_session(sid)


@_sio_on("collaboration_event")
def handle_collaboration_event(sid, data):
    """
    Handle general collaboration events, including:
    1. mouse_move
    2. vars_and_features_update
    3. sync_request (ask leader to update graph)
    4. app_state_update
    5. mcp_server_update
    6. workflow_update
    7. comments_update
    8. node_panel_presence
    9. skill_file_active
    10. skill_sync_request
    11. skill_resync_request
    """
    return collaboration_service.relay_collaboration_event(sid, data)


@_sio_on("graph_event")
def handle_graph_event(sid, data):
    """
    Handle graph events - simple broadcast relay.
    """
    return collaboration_service.relay_graph_event(sid, data)


@_sio_on("skill_event")
def handle_skill_event(sid, data):
    """
    Handle skill events - simple broadcast relay.
    """
    return collaboration_service.relay_skill_event(sid, data)
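A sketch of the client side of this handshake, assuming the python-socketio client package and a console access token sent via the Authorization header (which extract_access_token reads on connect); the host and IDs are hypothetical:

import socketio

client = socketio.Client()
client.connect(
    "http://localhost:5001",  # hypothetical console host running ext_socketio
    headers={"Authorization": "Bearer <console-access-token>"},  # hypothetical token
)

# Register this tab as a collaborator; the ack mirrors handle_user_connect above.
ack = client.call("user_connect", {"workflow_id": "<workflow-uuid>"})
# ack -> {"msg": "connected", "user_id": ..., "sid": ..., "isLeader": ...}

# Relay a cursor update to the other collaborators (payload shape is illustrative).
client.emit("collaboration_event", {"type": "mouse_move", "x": 120, "y": 80})
client.disconnect()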
67 api/controllers/console/workspace/dsl.py Normal file
@@ -0,0 +1,67 @@
import json

import httpx
import yaml
from flask import request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.plugin.impl.exc import PluginPermissionDeniedError
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models.model import App
from models.workflow import Workflow
from services.app_dsl_service import AppDslService


class DSLPredictRequest(BaseModel):
    app_id: str
    current_node_id: str


@console_ns.route("/workspaces/current/dsl/predict")
class DSLPredictApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        user, _ = current_account_with_tenant()
        if not user.is_admin_or_owner:
            raise Forbidden()

        args = DSLPredictRequest.model_validate(request.get_json())

        app_id: str = args.app_id
        current_node_id: str = args.current_node_id

        with Session(db.engine) as session:
            app = session.query(App).filter_by(id=app_id).first()
            workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()

            if not app:
                raise ValueError("App not found")
            if not workflow:
                raise ValueError("Workflow not found")

            try:
                i = 0
                for node_id, _ in workflow.walk_nodes():
                    if node_id == current_node_id:
                        break
                    i += 1

                dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))

                response = httpx.post(
                    "http://spark-832c:8000/predict",
                    json={"graph_data": dsl, "source_node_index": i},
                )
                return {
                    "nodes": json.loads(response.json()),
                }
            except PluginPermissionDeniedError as e:
                raise ValueError(e.description) from e
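A sketch of calling this endpoint from a console client; the host and token are hypothetical, and the response shape mirrors the handler above:

import httpx

resp = httpx.post(
    "http://localhost:5001/console/api/workspaces/current/dsl/predict",  # hypothetical host
    headers={"Authorization": "Bearer <console-access-token>"},  # hypothetical auth
    json={"app_id": "<app-uuid>", "current_node_id": "<node-id>"},
)
predicted = resp.json()["nodes"]  # node suggestions decoded from the predictor service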
380 api/core/agent/agent_app_runner.py Normal file
@@ -0,0 +1,380 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any

from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentResult
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from graphon.file import file_manager
from graphon.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMUsage,
    PromptMessage,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from models.model import Message

logger = logging.getLogger(__name__)


class AgentAppRunner(BaseAgentRunner):
    def _create_tool_invoke_hook(self, message: Message):
        """
        Create a tool invoke hook that uses ToolEngine.agent_invoke.
        This hook handles file creation and returns proper meta information.
        """
        # Get trace manager from app generate entity
        trace_manager = self.application_generate_entity.trace_manager

        def tool_invoke_hook(
            tool: Tool, tool_args: dict[str, Any], tool_name: str
        ) -> tuple[str, list[str], ToolInvokeMeta]:
            """Hook that uses agent_invoke for proper file and meta handling."""
            tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                tool=tool,
                tool_parameters=tool_args,
                user_id=self.user_id,
                tenant_id=self.tenant_id,
                message=message,
                invoke_from=self.application_generate_entity.invoke_from,
                agent_tool_callback=self.agent_callback,
                trace_manager=trace_manager,
                app_id=self.application_generate_entity.app_config.app_id,
                message_id=message.id,
                conversation_id=self.conversation.id,
            )

            # Publish files and track IDs
            for message_file_id in message_files:
                self.queue_manager.publish(
                    QueueMessageFileEvent(message_file_id=message_file_id),
                    PublishFrom.APPLICATION_MANAGER,
                )
                self._current_message_file_ids.append(message_file_id)

            return tool_invoke_response, message_files, tool_invoke_meta

        return tool_invoke_hook

    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run Agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config
        assert app_config is not None, "app_config is required"
        assert app_config.agent is not None, "app_config.agent is required"

        # convert tools into ModelRuntime Tool format
        tool_instances, _ = self._init_prompt_tools()

        assert app_config.agent

        # Create tool invoke hook for agent_invoke
        tool_invoke_hook = self._create_tool_invoke_hook(message)

        # Get instruction for ReAct strategy
        instruction = self.app_config.prompt_template.simple_prompt_template or ""

        # Use factory to create appropriate strategy
        strategy = StrategyFactory.create_strategy(
            model_features=self.model_features,
            model_instance=self.model_instance,
            tools=list(tool_instances.values()),
            files=list(self.files),
            max_iterations=app_config.agent.max_iteration,
            context=self.build_execution_context(),
            agent_strategy=self.config.strategy,
            tool_invoke_hook=tool_invoke_hook,
            instruction=instruction,
        )

        # Initialize state variables
        current_agent_thought_id: str | None = None
        has_published_thought = False
        current_tool_name: str | None = None
        self._current_message_file_ids: list[str] = []

        # organize prompt messages
        prompt_messages = self._organize_prompt_messages()

        # Run strategy
        generator = strategy.run(
            prompt_messages=prompt_messages,
            model_parameters=app_generate_entity.model_conf.parameters,
            stop=app_generate_entity.model_conf.stop,
            stream=True,
        )

        # Consume generator and collect result
        result: AgentResult | None = None
        try:
            while True:
                try:
                    output = next(generator)
                except StopIteration as e:
                    # Generator finished, get the return value
                    result = e.value
                    break

                if isinstance(output, LLMResultChunk):
                    # Handle LLM chunk
                    if current_agent_thought_id and not has_published_thought:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                            PublishFrom.APPLICATION_MANAGER,
                        )
                        has_published_thought = True

                    yield output

                elif isinstance(output, AgentLog):
                    # Handle Agent Log using log_type for type-safe dispatch
                    if output.status == AgentLog.LogStatus.START:
                        if output.log_type == AgentLog.LogType.ROUND:
                            # Start of a new round
                            message_file_ids: list[str] = []
                            current_agent_thought_id = self.create_agent_thought(
                                message_id=message.id,
                                message="",
                                tool_name="",
                                tool_input="",
                                messages_ids=message_file_ids,
                            )
                            has_published_thought = False

                        elif output.log_type == AgentLog.LogType.TOOL_CALL:
                            if current_agent_thought_id is None:
                                continue

                            # Tool call start - extract data from structured fields
                            current_tool_name = output.data.get("tool_name", "")
                            tool_input = output.data.get("tool_args", {})

                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=current_tool_name,
                                tool_input=tool_input,
                                thought=None,
                                observation=None,
                                tool_invoke_meta=None,
                                answer=None,
                                messages_ids=[],
                            )
                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

                    elif output.status == AgentLog.LogStatus.SUCCESS:
                        if output.log_type == AgentLog.LogType.THOUGHT:
                            if current_agent_thought_id is None:
                                continue

                            thought_text = output.data.get("thought")
                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=None,
                                tool_input=None,
                                thought=thought_text,
                                observation=None,
                                tool_invoke_meta=None,
                                answer=None,
                                messages_ids=[],
                            )
                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

                        elif output.log_type == AgentLog.LogType.TOOL_CALL:
                            if current_agent_thought_id is None:
                                continue

                            # Tool call finished
                            tool_output = output.data.get("output")
                            # Get meta from strategy output (now properly populated)
                            tool_meta = output.data.get("meta")

                            # Wrap tool_meta with tool_name as key (required by agent_service)
                            if tool_meta and current_tool_name:
                                tool_meta = {current_tool_name: tool_meta}

                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=None,
                                tool_input=None,
                                thought=None,
                                observation=tool_output,
                                tool_invoke_meta=tool_meta,
                                answer=None,
                                messages_ids=self._current_message_file_ids,
                            )
                            # Clear message file ids after saving
                            self._current_message_file_ids = []
                            current_tool_name = None

                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

                        elif output.log_type == AgentLog.LogType.ROUND:
                            if current_agent_thought_id is None:
                                continue

                            # Round finished - save LLM usage and answer
                            llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
                            llm_result = output.data.get("llm_result")
                            final_answer = output.data.get("final_answer")

                            self.save_agent_thought(
                                agent_thought_id=current_agent_thought_id,
                                tool_name=None,
                                tool_input=None,
                                thought=llm_result,
                                observation=None,
                                tool_invoke_meta=None,
                                answer=final_answer,
                                messages_ids=[],
                                llm_usage=llm_usage,
                            )
                            self.queue_manager.publish(
                                QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
                                PublishFrom.APPLICATION_MANAGER,
                            )

        except Exception:
            # Re-raise any other exceptions
            raise

        # Process final result
        if isinstance(result, AgentResult):
            final_answer = result.text
            usage = result.usage or LLMUsage.empty_usage()

            # Publish end event
            self.queue_manager.publish(
                QueueMessageEndEvent(
                    llm_result=LLMResult(
                        model=self.model_instance.model_name,
                        prompt_messages=prompt_messages,
                        message=AssistantPromptMessage(content=final_answer),
                        usage=usage,
                        system_fingerprint="",
                    )
                ),
                PublishFrom.APPLICATION_MANAGER,
            )

    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_template:
            return prompt_messages or []

        prompt_messages = prompt_messages or []

        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
            prompt_messages[0] = SystemPromptMessage(content=prompt_template)
            return prompt_messages

        if not prompt_messages:
            return [SystemPromptMessage(content=prompt_template)]

        prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
        return prompt_messages

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As of now, GPT supports both function calling and vision at the first iteration.
        We need to remove the image messages from the prompt messages at the first iteration.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self):
        # For ReAct strategy, use the agent prompt template
        if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
            prompt_template = self.config.prompt.first_prompt
        else:
            prompt_template = self.app_config.prompt_template.simple_prompt_template or ""

        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query or "", [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # clear messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages
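The tool_invoke_hook above is just a callable of shape (tool, tool_args, tool_name) -> (response_text, message_file_ids, meta). A minimal stand-in that satisfies the shape without touching ToolEngine, purely illustrative (e.g. for a unit test):

from typing import Any

def make_echo_hook():
    def tool_invoke_hook(tool: Any, tool_args: dict[str, Any], tool_name: str):
        # A real hook calls ToolEngine.agent_invoke here; this stub only echoes its inputs.
        response = f"{tool_name} called with {tool_args}"
        message_file_ids: list[str] = []  # IDs of files the tool produced (none for the stub)
        meta = None  # a ToolInvokeMeta instance in the real implementation
        return response, message_file_ids, meta

    return tool_invoke_hook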
352 api/core/app/entities/app_asset_entities.py Normal file
@@ -0,0 +1,352 @@
from __future__ import annotations

import os
from collections import defaultdict
from collections.abc import Generator
from enum import StrEnum

from pydantic import BaseModel, Field


class AssetNodeType(StrEnum):
    FILE = "file"
    FOLDER = "folder"


class AppAssetNode(BaseModel):
    id: str = Field(description="Unique identifier for the node")
    node_type: AssetNodeType = Field(description="Type of node: file or folder")
    name: str = Field(description="Name of the file or folder")
    parent_id: str | None = Field(default=None, description="Parent folder ID, None for root level")
    order: int = Field(default=0, description="Sort order within parent folder, lower values first")
    extension: str = Field(default="", description="File extension without dot, empty for folders")
    size: int = Field(default=0, description="File size in bytes, 0 for folders")

    @classmethod
    def create_folder(cls, node_id: str, name: str, parent_id: str | None = None) -> AppAssetNode:
        return cls(id=node_id, node_type=AssetNodeType.FOLDER, name=name, parent_id=parent_id)

    @classmethod
    def create_file(cls, node_id: str, name: str, parent_id: str | None = None, size: int = 0) -> AppAssetNode:
        return cls(
            id=node_id,
            node_type=AssetNodeType.FILE,
            name=name,
            parent_id=parent_id,
            extension=name.rsplit(".", 1)[-1] if "." in name else "",
            size=size,
        )


class AppAssetNodeView(BaseModel):
    id: str = Field(description="Unique identifier for the node")
    node_type: str = Field(description="Type of node: 'file' or 'folder'")
    name: str = Field(description="Name of the file or folder")
    path: str = Field(description="Full path from root, e.g. '/folder/file.txt'")
    extension: str = Field(default="", description="File extension without dot")
    size: int = Field(default=0, description="File size in bytes")
    children: list[AppAssetNodeView] = Field(default_factory=list, description="Child nodes for folders")


class BatchUploadNode(BaseModel):
    """Structure for batch upload_url tree nodes, used for both input and output."""

    name: str
    node_type: AssetNodeType
    size: int = 0
    children: list[BatchUploadNode] = []
    id: str | None = None
    upload_url: str | None = None

    def to_app_asset_nodes(self, parent_id: str | None = None) -> list[AppAssetNode]:
        """
        Generate IDs when missing and convert to AppAssetNode list.
        Mutates self to set the id field when it is not set.
        """
        from uuid import uuid4

        self.id = self.id or str(uuid4())
        nodes: list[AppAssetNode] = []

        if self.node_type == AssetNodeType.FOLDER:
            nodes.append(AppAssetNode.create_folder(self.id, self.name, parent_id))
            for child in self.children:
                nodes.extend(child.to_app_asset_nodes(self.id))
        else:
            nodes.append(AppAssetNode.create_file(self.id, self.name, parent_id, self.size))

        return nodes


class TreeNodeNotFoundError(Exception):
    """Tree internal: node not found"""

    pass


class TreeParentNotFoundError(Exception):
    """Tree internal: parent folder not found"""

    pass


class TreePathConflictError(Exception):
    """Tree internal: path already exists"""

    pass


class AppAssetFileTree(BaseModel):
    """
    File tree structure for app assets using the adjacency list pattern.

    Design:
    - Storage: Flat list with parent_id references (adjacency list)
    - Path: Computed dynamically via get_path(), not stored
    - Order: Integer field for user-defined sorting within each folder
    - API response: transform() builds nested tree with computed paths

    Why adjacency list over nested tree or materialized path:
    - Simpler CRUD: move/rename only updates one node's parent_id
    - No path cascade: renaming a parent doesn't require updating all descendants
    - JSON-friendly: flat list serializes cleanly to a database JSON column
    - Trade-off: path lookup is O(depth), acceptable for typical file trees
    """

    nodes: list[AppAssetNode] = Field(default_factory=list, description="Flat list of all nodes in the tree")

    def ensure_unique_name(
        self,
        parent_id: str | None,
        name: str,
        *,
        is_file: bool,
        extra_taken: set[str] | None = None,
    ) -> str:
        """
        Return a sibling-unique name by appending numeric suffixes when needed.

        The suffix format is " <n>" (e.g. "report 1", "report 2"). For files,
        the suffix is inserted before the extension.
        """
        taken = extra_taken or set()
        if not self.has_child_named(parent_id, name) and name not in taken:
            return name
        suffix_index = 1
        while True:
            candidate = self._apply_name_suffix(name, suffix_index, is_file=is_file)
            if not self.has_child_named(parent_id, candidate) and candidate not in taken:
                return candidate
            suffix_index += 1

    @staticmethod
    def _apply_name_suffix(name: str, suffix_index: int, *, is_file: bool) -> str:
        if not is_file:
            return f"{name} {suffix_index}"
        stem, extension = os.path.splitext(name)
        return f"{stem} {suffix_index}{extension}"

    def get(self, node_id: str) -> AppAssetNode | None:
        return next((n for n in self.nodes if n.id == node_id), None)

    def get_children(self, parent_id: str | None) -> list[AppAssetNode]:
        return [n for n in self.nodes if n.parent_id == parent_id]

    def has_child_named(self, parent_id: str | None, name: str) -> bool:
        return any(n.name == name and n.parent_id == parent_id for n in self.nodes)

    def get_path(self, node_id: str) -> str:
        node = self.get(node_id)
        if not node:
            raise TreeNodeNotFoundError(node_id)
        parts: list[str] = []
        current: AppAssetNode | None = node
        while current:
            parts.append(current.name)
            current = self.get(current.parent_id) if current.parent_id else None
        return "/".join(reversed(parts))

    def relative_path(self, a: AppAssetNode, b: AppAssetNode) -> str:
        """
        Calculate relative path from node a to node b for Markdown references.
        The path is computed from a's parent directory (where the file resides).

        Examples:
            /foo/a.md -> /foo/b.md => ./b.md
            /foo/a.md -> /foo/sub/b.md => ./sub/b.md
            /foo/sub/a.md -> /foo/b.md => ../b.md
            /foo/sub/deep/a.md -> /foo/b.md => ../../b.md
        """

        def get_ancestor_ids(node_id: str | None) -> list[str]:
            chain: list[str] = []
            current_id = node_id
            while current_id:
                chain.append(current_id)
                node = self.get(current_id)
                current_id = node.parent_id if node else None
            return chain

        a_dir_ancestors = get_ancestor_ids(a.parent_id)
        b_ancestors = [b.id] + get_ancestor_ids(b.parent_id)
        a_dir_set = set(a_dir_ancestors)

        lca_id: str | None = None
        lca_index_in_b = -1
        for idx, ancestor_id in enumerate(b_ancestors):
            if ancestor_id in a_dir_set or (a.parent_id is None and b_ancestors[idx:] == []):
                lca_id = ancestor_id
                lca_index_in_b = idx
                break

        if a.parent_id is None:
            steps_up = 0
            lca_index_in_b = len(b_ancestors)
        elif lca_id is None:
            steps_up = len(a_dir_ancestors)
            lca_index_in_b = len(b_ancestors)
        else:
            steps_up = 0
            for ancestor_id in a_dir_ancestors:
                if ancestor_id == lca_id:
                    break
                steps_up += 1

        path_down: list[str] = []
        for i in range(lca_index_in_b - 1, -1, -1):
            node = self.get(b_ancestors[i])
            if node:
                path_down.append(node.name)

        if steps_up == 0:
            return "./" + "/".join(path_down)

        parts: list[str] = [".."] * steps_up + path_down
        return "/".join(parts)

    def get_descendant_ids(self, node_id: str) -> list[str]:
        result: list[str] = []
        stack = [node_id]
        while stack:
            current_id = stack.pop()
            for child in self.nodes:
                if child.parent_id == current_id:
                    result.append(child.id)
                    stack.append(child.id)
        return result

    def add(self, node: AppAssetNode) -> AppAssetNode:
        if self.get(node.id):
            raise TreePathConflictError(node.id)
        if self.has_child_named(node.parent_id, node.name):
            raise TreePathConflictError(node.name)
        if node.parent_id:
            parent = self.get(node.parent_id)
            if not parent or parent.node_type != AssetNodeType.FOLDER:
                raise TreeParentNotFoundError(node.parent_id)
        siblings = self.get_children(node.parent_id)
        node.order = max((s.order for s in siblings), default=-1) + 1
        self.nodes.append(node)
        return node

    def update(self, node_id: str, size: int) -> AppAssetNode:
        node = self.get(node_id)
        if not node or node.node_type != AssetNodeType.FILE:
            raise TreeNodeNotFoundError(node_id)
        node.size = size
        return node

    def rename(self, node_id: str, new_name: str) -> AppAssetNode:
        node = self.get(node_id)
        if not node:
            raise TreeNodeNotFoundError(node_id)
        if node.name != new_name and self.has_child_named(node.parent_id, new_name):
            raise TreePathConflictError(new_name)
        node.name = new_name
        if node.node_type == AssetNodeType.FILE:
            node.extension = new_name.rsplit(".", 1)[-1] if "." in new_name else ""
        return node

    def move(self, node_id: str, new_parent_id: str | None) -> AppAssetNode:
        node = self.get(node_id)
        if not node:
            raise TreeNodeNotFoundError(node_id)
        if new_parent_id:
            parent = self.get(new_parent_id)
            if not parent or parent.node_type != AssetNodeType.FOLDER:
                raise TreeParentNotFoundError(new_parent_id)
        if self.has_child_named(new_parent_id, node.name):
            raise TreePathConflictError(node.name)
        node.parent_id = new_parent_id
        siblings = self.get_children(new_parent_id)
        node.order = max((s.order for s in siblings if s.id != node_id), default=-1) + 1
        return node

    def reorder(self, node_id: str, after_node_id: str | None) -> AppAssetNode:
        node = self.get(node_id)
        if not node:
            raise TreeNodeNotFoundError(node_id)

        siblings = sorted(self.get_children(node.parent_id), key=lambda x: x.order)
        siblings = [s for s in siblings if s.id != node_id]

        if after_node_id is None:
            insert_idx = 0
        else:
            after_node = self.get(after_node_id)
            if not after_node or after_node.parent_id != node.parent_id:
                raise TreeNodeNotFoundError(after_node_id)
            insert_idx = next((i for i, s in enumerate(siblings) if s.id == after_node_id), -1) + 1

        siblings.insert(insert_idx, node)
        for idx, sibling in enumerate(siblings):
            sibling.order = idx

        return node

    def remove(self, node_id: str) -> list[str]:
        node = self.get(node_id)
        if not node:
            raise TreeNodeNotFoundError(node_id)
        ids_to_remove = [node_id] + self.get_descendant_ids(node_id)
        self.nodes = [n for n in self.nodes if n.id not in ids_to_remove]
        return ids_to_remove

    def walk_files(self) -> Generator[AppAssetNode, None, None]:
        return (n for n in self.nodes if n.node_type == AssetNodeType.FILE)

    def transform(self) -> list[AppAssetNodeView]:
        by_parent: dict[str | None, list[AppAssetNode]] = defaultdict(list)
        for n in self.nodes:
            by_parent[n.parent_id].append(n)

        for children in by_parent.values():
            children.sort(key=lambda x: x.order)

        paths: dict[str, str] = {}
        tree_views: dict[str, AppAssetNodeView] = {}

        def build_view(node: AppAssetNode, parent_path: str) -> None:
            path = f"{parent_path}/{node.name}"
            paths[node.id] = path
            child_views: list[AppAssetNodeView] = []
            for child in by_parent.get(node.id, []):
                build_view(child, path)
                child_views.append(tree_views[child.id])
            tree_views[node.id] = AppAssetNodeView(
                id=node.id,
                node_type=node.node_type.value,
                name=node.name,
                path=path,
                extension=node.extension,
                size=node.size,
                children=child_views,
            )

        for root_node in by_parent.get(None, []):
            build_view(root_node, "")

        return [tree_views[n.id] for n in by_parent.get(None, [])]

    def empty(self) -> bool:
        return len(self.nodes) == 0
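A short usage sketch of the tree above (node IDs are arbitrary; in the product they come from the upload flow):

from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode

tree = AppAssetFileTree()
tree.add(AppAssetNode.create_folder("n1", "docs"))
readme = tree.add(AppAssetNode.create_file("n2", "readme.md", parent_id="n1", size=120))
guide = tree.add(AppAssetNode.create_file("n3", "guide.md", size=80))  # root-level file

tree.get_path("n2")                # "docs/readme.md", computed on demand, never stored
tree.relative_path(readme, guide)  # "../guide.md", since readme sits one folder deeper
views = tree.transform()           # nested AppAssetNodeView list: "/docs", "/docs/readme.md", "/guide.md"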
96 api/core/app/entities/app_bundle_entities.py Normal file
@@ -0,0 +1,96 @@
from __future__ import annotations

import re
from datetime import UTC, datetime

from pydantic import BaseModel, ConfigDict, Field

from core.app.entities.app_asset_entities import AppAssetFileTree

# Constants
BUNDLE_DSL_FILENAME_PATTERN = re.compile(r"^[^/]+\.ya?ml$")
BUNDLE_MAX_SIZE = 50 * 1024 * 1024  # 50MB
MANIFEST_FILENAME = "manifest.json"
MANIFEST_SCHEMA_VERSION = "1.0"


# Exceptions
class BundleFormatError(Exception):
    """Raised when bundle format is invalid."""

    pass


class ZipSecurityError(Exception):
    """Raised when zip file contains security violations."""

    pass


# Manifest DTOs
class ManifestFileEntry(BaseModel):
    """Maps node_id to file path in the bundle."""

    model_config = ConfigDict(extra="forbid")

    node_id: str
    path: str


class ManifestIntegrity(BaseModel):
    """Basic integrity check fields."""

    model_config = ConfigDict(extra="forbid")

    file_count: int


class ManifestAppAssets(BaseModel):
    """App assets section containing the full tree."""

    model_config = ConfigDict(extra="forbid")

    tree: AppAssetFileTree


class BundleManifest(BaseModel):
    """
    Bundle manifest for app asset import/export.

    Schema version 1.0:
    - dsl_filename: DSL file name in bundle root (e.g. "my_app.yml")
    - tree: Full AppAssetFileTree (files + folders) for 100% restoration including node IDs
    - files: Explicit node_id -> path mapping for file nodes only
    - integrity: Basic file_count validation
    """

    model_config = ConfigDict(extra="forbid")

    schema_version: str = Field(default=MANIFEST_SCHEMA_VERSION)
    generated_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC))
    dsl_filename: str = Field(description="DSL file name in bundle root")
    app_assets: ManifestAppAssets
    files: list[ManifestFileEntry]
    integrity: ManifestIntegrity

    @property
    def assets_prefix(self) -> str:
        """Assets directory name (DSL filename without extension)."""
        return self.dsl_filename.rsplit(".", 1)[0]

    @classmethod
    def from_tree(cls, tree: AppAssetFileTree, dsl_filename: str) -> BundleManifest:
        """Build manifest from an AppAssetFileTree."""
        files = [ManifestFileEntry(node_id=n.id, path=tree.get_path(n.id)) for n in tree.walk_files()]
        return cls(
            dsl_filename=dsl_filename,
            app_assets=ManifestAppAssets(tree=tree),
            files=files,
            integrity=ManifestIntegrity(file_count=len(files)),
        )


# Export result
class BundleExportResult(BaseModel):
    download_url: str = Field(description="Temporary download URL for the ZIP")
    filename: str = Field(description="Suggested filename for the ZIP")
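Continuing the tree sketch above, a manifest is then one call away (values illustrative):

manifest = BundleManifest.from_tree(tree, dsl_filename="my_app.yml")
manifest.assets_prefix         # "my_app": the assets directory name inside the ZIP
manifest.integrity.file_count  # 2: one entry per file node, folders excluded
manifest.model_dump_json()     # JSON payload written as manifest.json in the bundle root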
72 api/core/app/entities/llm_generation_entities.py Normal file
@ -0,0 +1,72 @@
"""
LLM Generation Detail entities.

Defines the structure for storing and transmitting LLM generation details
including reasoning content, tool calls, and their sequence.
"""

from typing import Literal

from pydantic import BaseModel, Field


class ContentSegment(BaseModel):
    """Represents a content segment in the generation sequence."""

    type: Literal["content"] = "content"
    start: int = Field(..., description="Start position in the text")
    end: int = Field(..., description="End position in the text")


class ReasoningSegment(BaseModel):
    """Represents a reasoning segment in the generation sequence."""

    type: Literal["reasoning"] = "reasoning"
    index: int = Field(..., description="Index into reasoning_content array")


class ToolCallSegment(BaseModel):
    """Represents a tool call segment in the generation sequence."""

    type: Literal["tool_call"] = "tool_call"
    index: int = Field(..., description="Index into tool_calls array")


SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment


class ToolCallDetail(BaseModel):
    """Represents a tool call with its arguments and result."""

    id: str = Field(default="", description="Unique identifier for the tool call")
    name: str = Field(..., description="Name of the tool")
    arguments: str = Field(default="", description="JSON string of tool arguments")
    result: str = Field(default="", description="Result from the tool execution")
    elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
    icon: str | dict | None = Field(default=None, description="Icon of the tool")
    icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")


class LLMGenerationDetailData(BaseModel):
    """
    Domain model for LLM generation detail.

    Contains the structured data for reasoning content, tool calls,
    and their display sequence.
    """

    reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
    tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
    sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")

    def is_empty(self) -> bool:
        """Check if there's any meaningful generation detail."""
        return not self.reasoning_content and not self.tool_calls

    def to_response_dict(self) -> dict:
        """Convert to dictionary for API response."""
        return {
            "reasoning_content": self.reasoning_content,
            "tool_calls": [tc.model_dump() for tc in self.tool_calls],
            "sequence": [seg.model_dump() for seg in self.sequence],
        }
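For illustration, a detail record with one reasoning block and one tool call displayed before the first content span (all values made up):

detail = LLMGenerationDetailData(
    reasoning_content=["I should call the weather tool first."],
    tool_calls=[ToolCallDetail(name="get_weather", arguments='{"city": "Berlin"}', result="12C")],
    sequence=[
        ReasoningSegment(index=0),        # refers to reasoning_content[0]
        ToolCallSegment(index=0),         # refers to tool_calls[0]
        ContentSegment(start=0, end=24),  # first 24 characters of the answer text
    ],
)
assert not detail.is_empty()
payload = detail.to_response_dict()  # JSON-ready dict for the API layer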
75  api/core/helper/creators.py  Normal file
@@ -0,0 +1,75 @@
"""
Helper module for Creators Platform integration.

Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""

import logging
from urllib.parse import urlencode

import httpx
from yarl import URL

from configs import dify_config

logger = logging.getLogger(__name__)

creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))


def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
    """Upload a DSL file to the Creators Platform anonymous upload endpoint.

    Args:
        dsl_file_bytes: Raw bytes of the DSL file (YAML or ZIP).
        filename: Original filename for the upload.

    Returns:
        The claim_code string used to retrieve the DSL later.

    Raises:
        httpx.HTTPStatusError: If the upload request fails.
        ValueError: If the response does not contain a valid claim_code.
    """
    url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
    response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
    response.raise_for_status()

    data = response.json()
    claim_code = data.get("data", {}).get("claim_code")
    if not claim_code:
        raise ValueError("Creators Platform did not return a valid claim_code")

    return claim_code


def get_redirect_url(user_account_id: str, claim_code: str) -> str:
    """Generate the redirect URL to the Creators Platform frontend.

    Redirects to the Creators Platform root page with the dsl_claim_code.
    If CREATORS_PLATFORM_OAUTH_CLIENT_ID is configured (Dify Cloud),
    also signs an OAuth authorization code so the frontend can
    automatically authenticate the user via the OAuth callback.

    For self-hosted Dify without an OAuth client_id configured, only the
    dsl_claim_code is passed and the user must log in manually.

    Args:
        user_account_id: The Dify user account ID.
        claim_code: The claim_code obtained from upload_dsl().

    Returns:
        The full redirect URL string.
    """
    base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
    params: dict[str, str] = {"dsl_claim_code": claim_code}

    client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
    if client_id:
        from services.oauth_server import OAuthServerService

        oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
        params["oauth_code"] = oauth_code

    return f"{base_url}?{urlencode(params)}"
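A sketch of the intended call sequence, assuming dsl_bytes holds an exported DSL and account is the current user (both produced elsewhere):

claim_code = upload_dsl(dsl_bytes, filename="my_app.yaml")
redirect = get_redirect_url(user_account_id=account.id, claim_code=claim_code)
# Self-hosted: ...?dsl_claim_code=<code>
# Dify Cloud (client_id configured): ...?dsl_claim_code=<code>&oauth_code=<signed code>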
62  api/core/llm_generator/context_models.py  Normal file
@@ -0,0 +1,62 @@
from typing import Any

from pydantic import BaseModel, ConfigDict, Field


class VariableSelectorPayload(BaseModel):
    model_config = ConfigDict(extra="forbid")

    variable: str = Field(..., description="Variable name used in generated code")
    value_selector: list[str] = Field(..., description="Path to upstream node output, format: [node_id, output_name]")


class CodeOutputPayload(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: str = Field(..., description="Output variable type")


class CodeContextPayload(BaseModel):
    # From web/app/components/workflow/nodes/tool/components/context-generate-modal/index.tsx (code node snapshot).
    model_config = ConfigDict(extra="forbid")

    code: str = Field(..., description="Existing code in the Code node")
    outputs: dict[str, CodeOutputPayload] | None = Field(
        default=None, description="Existing output definitions for the Code node"
    )
    variables: list[VariableSelectorPayload] | None = Field(
        default=None, description="Existing variable selectors used by the Code node"
    )


class AvailableVarPayload(BaseModel):
    # From web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts (available variables).
    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    value_selector: list[str] = Field(..., description="Path to upstream node output")
    type: str = Field(..., description="Variable type, e.g. string, number, array[object]")
    description: str | None = Field(default=None, description="Optional variable description")
    node_id: str | None = Field(default=None, description="Source node ID")
    node_title: str | None = Field(default=None, description="Source node title")
    node_type: str | None = Field(default=None, description="Source node type")
    json_schema: dict[str, Any] | None = Field(
        default=None,
        alias="schema",
        description="Optional JSON schema for object variables",
    )


class ParameterInfoPayload(BaseModel):
    # From web/app/components/workflow/nodes/tool/use-config.ts (ToolParameter metadata).
    model_config = ConfigDict(extra="forbid")

    name: str = Field(..., description="Target parameter name")
    type: str = Field(default="string", description="Target parameter type")
    description: str = Field(default="", description="Parameter description")
    required: bool | None = Field(default=None, description="Whether the parameter is required")
    options: list[str] | None = Field(default=None, description="Allowed option values")
    min: float | None = Field(default=None, description="Minimum numeric value")
    max: float | None = Field(default=None, description="Maximum numeric value")
    default: str | int | float | bool | None = Field(default=None, description="Default value")
    multiple: bool | None = Field(default=None, description="Whether the parameter accepts multiple values")
    label: str | None = Field(default=None, description="Optional display label")
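The "schema" alias plus populate_by_name means the web client can send either key; a small validation sketch:

payload = AvailableVarPayload.model_validate(
    {
        "value_selector": ["node_1", "result"],
        "type": "object",
        "schema": {"type": "object", "properties": {"name": {"type": "string"}}},
    }
)
assert payload.json_schema is not None  # populated via the "schema" alias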
67  api/core/llm_generator/output_models.py  Normal file
@@ -0,0 +1,67 @@
from __future__ import annotations

from pydantic import BaseModel, ConfigDict, Field

from graphon.variables.types import SegmentType


class SuggestedQuestionsOutput(BaseModel):
    """Output model for suggested questions generation."""

    model_config = ConfigDict(extra="forbid")

    questions: list[str] = Field(
        min_length=3,
        max_length=3,
        description="Exactly 3 suggested follow-up questions for the user",
    )


class VariableSelectorOutput(BaseModel):
    """Variable selector mapping code variable to upstream node output.

    Note: Separate from VariableSelector to ensure 'additionalProperties: false'
    in JSON schema for OpenAI/Azure strict mode.
    """

    model_config = ConfigDict(extra="forbid")

    variable: str = Field(description="Variable name used in the generated code")
    value_selector: list[str] = Field(description="Path to upstream node output, format: [node_id, output_name]")


class CodeNodeOutputItem(BaseModel):
    """Single output variable definition.

    Note: OpenAI/Azure strict mode requires 'additionalProperties: false' and
    does not support dynamic object keys, so outputs use array format.
    """

    model_config = ConfigDict(extra="forbid")

    name: str = Field(description="Output variable name returned by the main function")
    type: SegmentType = Field(description="Data type of the output variable")


class CodeNodeStructuredOutput(BaseModel):
    """Structured output for code node generation."""

    model_config = ConfigDict(extra="forbid")

    variables: list[VariableSelectorOutput] = Field(
        description="Input variables mapping code variables to upstream node outputs"
    )
    code: str = Field(description="Generated code with a main function that processes inputs and returns outputs")
    outputs: list[CodeNodeOutputItem] = Field(
        description="Output variable definitions specifying name and type for each return value"
    )
    message: str = Field(description="Brief explanation of what the generated code does")


class InstructionModifyOutput(BaseModel):
    """Output model for instruction-based prompt modification."""

    model_config = ConfigDict(extra="forbid")

    modified: str = Field(description="The modified prompt content after applying the instruction")
    message: str = Field(description="Brief explanation of what changes were made")
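The extra="forbid" config is what emits "additionalProperties": false in the generated JSON schema, which OpenAI/Azure strict mode requires; a quick check:

schema = CodeNodeStructuredOutput.model_json_schema()
assert schema["additionalProperties"] is False
# The nested $defs (VariableSelectorOutput, CodeNodeOutputItem) carry the same flag.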
203  api/core/llm_generator/output_parser/file_ref.py  Normal file
@@ -0,0 +1,203 @@
"""
File path detection and conversion for structured output.

This module provides utilities to:
1. Detect sandbox file path fields in JSON Schema (format: "file-path")
2. Adapt schemas to add file-path descriptions before model invocation
3. Convert sandbox file path strings into File objects via a resolver
"""

from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast

from graphon.file import File
from graphon.variables.segments import ArrayFileSegment, FileSegment

FILE_PATH_FORMAT = "file-path"
FILE_PATH_DESCRIPTION_SUFFIX = "this field contains a file path from the Dify sandbox"


def is_file_path_property(schema: Mapping[str, Any]) -> bool:
    """Check if a schema property represents a sandbox file path."""
    if schema.get("type") != "string":
        return False
    format_value = schema.get("format")
    if not isinstance(format_value, str):
        return False
    normalized_format = format_value.lower().replace("_", "-")
    return normalized_format == FILE_PATH_FORMAT


def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
    """Recursively detect file path fields in a JSON schema."""
    file_path_fields: list[str] = []
    schema_type = schema.get("type")

    if schema_type == "object":
        properties = schema.get("properties")
        if isinstance(properties, Mapping):
            properties_mapping = cast(Mapping[str, Any], properties)
            for prop_name, prop_schema in properties_mapping.items():
                if not isinstance(prop_schema, Mapping):
                    continue
                prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
                current_path = f"{path}.{prop_name}" if path else prop_name

                if is_file_path_property(prop_schema_mapping):
                    file_path_fields.append(current_path)
                else:
                    file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))

    elif schema_type == "array":
        items_schema = schema.get("items")
        if not isinstance(items_schema, Mapping):
            return file_path_fields
        items_schema_mapping = cast(Mapping[str, Any], items_schema)
        array_path = f"{path}[*]" if path else "[*]"

        if is_file_path_property(items_schema_mapping):
            file_path_fields.append(array_path)
        else:
            file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))

    return file_path_fields


def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
    """Normalize sandbox file path fields and collect their JSON paths."""
    result = _deep_copy_value(schema)
    if not isinstance(result, dict):
        raise ValueError("structured_output_schema must be a JSON object")
    result_dict = cast(dict[str, Any], result)

    file_path_fields: list[str] = []
    _adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields)
    return result_dict, file_path_fields


def convert_sandbox_file_paths_in_output(
    output: Mapping[str, Any],
    file_path_fields: Sequence[str],
    file_resolver: Callable[[str], File],
) -> tuple[dict[str, Any], list[File]]:
    """Convert sandbox file paths into File objects using the resolver."""
    if not file_path_fields:
        return dict(output), []

    result = _deep_copy_value(output)
    if not isinstance(result, dict):
        raise ValueError("Structured output must be a JSON object")
    result_dict = cast(dict[str, Any], result)

    files: list[File] = []
    for path in file_path_fields:
        _convert_path_in_place(result_dict, path.split("."), file_resolver, files)

    return result_dict, files


def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
    schema_type = schema.get("type")

    if schema_type == "object":
        properties = schema.get("properties")
        if isinstance(properties, Mapping):
            properties_mapping = cast(Mapping[str, Any], properties)
            for prop_name, prop_schema in properties_mapping.items():
                if not isinstance(prop_schema, dict):
                    continue
                prop_schema_dict = cast(dict[str, Any], prop_schema)
                current_path = f"{path}.{prop_name}" if path else prop_name

                if is_file_path_property(prop_schema_dict):
                    _normalize_file_path_schema(prop_schema_dict)
                    file_path_fields.append(current_path)
                else:
                    _adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields)

    elif schema_type == "array":
        items_schema = schema.get("items")
        if not isinstance(items_schema, dict):
            return
        items_schema_dict = cast(dict[str, Any], items_schema)
        array_path = f"{path}[*]" if path else "[*]"

        if is_file_path_property(items_schema_dict):
            _normalize_file_path_schema(items_schema_dict)
            file_path_fields.append(array_path)
        else:
            _adapt_schema_in_place(items_schema_dict, array_path, file_path_fields)


def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
    schema["type"] = "string"
    schema["format"] = FILE_PATH_FORMAT
    description = schema.get("description", "")
    if description:
        if FILE_PATH_DESCRIPTION_SUFFIX not in description:
            schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
    else:
        schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX


def _deep_copy_value(value: Any) -> Any:
    if isinstance(value, Mapping):
        mapping = cast(Mapping[str, Any], value)
        return {key: _deep_copy_value(item) for key, item in mapping.items()}
    if isinstance(value, list):
        list_value = cast(list[Any], value)
        return [_deep_copy_value(item) for item in list_value]
    return value


def _convert_path_in_place(
    obj: dict[str, Any],
    path_parts: list[str],
    file_resolver: Callable[[str], File],
    files: list[File],
) -> None:
    if not path_parts:
        return

    current = path_parts[0]
    remaining = path_parts[1:]

    if current.endswith("[*]"):
        key = current[:-3] if current != "[*]" else ""
        target_value = obj.get(key) if key else obj

        if isinstance(target_value, list):
            target_list = cast(list[Any], target_value)
            if remaining:
                for item in target_list:
                    if isinstance(item, dict):
                        item_dict = cast(dict[str, Any], item)
                        _convert_path_in_place(item_dict, remaining, file_resolver, files)
            else:
                resolved_files: list[File] = []
                for item in target_list:
                    if not isinstance(item, str):
                        raise ValueError("File path must be a string")
                    file = file_resolver(item)
                    files.append(file)
                    resolved_files.append(file)
                if key:
                    obj[key] = ArrayFileSegment(value=resolved_files)
        return

    if not remaining:
        if current not in obj:
            return
        value = obj[current]
        if value is None:
            obj[current] = None
            return
        if not isinstance(value, str):
            raise ValueError("File path must be a string")
        file = file_resolver(value)
        files.append(file)
        obj[current] = FileSegment(value=file)
        return

    if current in obj and isinstance(obj[current], dict):
        _convert_path_in_place(obj[current], remaining, file_resolver, files)
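A round-trip sketch of the three public helpers with a toy schema; resolve_sandbox_file is a hypothetical resolver standing in for whatever maps sandbox paths to stored File objects:

schema = {
    "type": "object",
    "properties": {
        "report": {"type": "string", "format": "file_path", "description": "Generated report"},
    },
}
adapted, fields = adapt_schema_for_sandbox_file_paths(schema)
# fields == ["report"]; the format is normalized to "file-path" and the
# description now ends with FILE_PATH_DESCRIPTION_SUFFIX.

converted, files = convert_sandbox_file_paths_in_output(
    {"report": "/workspace/out/report.pdf"},
    fields,
    file_resolver=lambda path: resolve_sandbox_file(path),  # hypothetical
)
# converted["report"] is now a FileSegment wrapping the resolved File.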
45  api/core/llm_generator/utils.py  Normal file
@@ -0,0 +1,45 @@
"""Utility functions for LLM generator."""

from graphon.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageRole,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)


def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
    """
    Deserialize list of dicts to list[PromptMessage].

    Expected format:
    [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."},
    ]
    """
    result: list[PromptMessage] = []
    for msg in messages:
        role = PromptMessageRole.value_of(msg["role"])
        content = msg.get("content", "")

        match role:
            case PromptMessageRole.USER:
                result.append(UserPromptMessage(content=content))
            case PromptMessageRole.ASSISTANT:
                result.append(AssistantPromptMessage(content=content))
            case PromptMessageRole.SYSTEM:
                result.append(SystemPromptMessage(content=content))
            case PromptMessageRole.TOOL:
                result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))

    return result


def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
    """
    Serialize list[PromptMessage] to list of dicts.
    """
    return [{"role": msg.role.value, "content": msg.content} for msg in messages]
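A round-trip sketch (lossless for user/assistant/system messages; note that serialize_prompt_messages drops tool_call_id, so tool messages do not round-trip):

history = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
typed = deserialize_prompt_messages(history)
assert serialize_prompt_messages(typed) == history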
11  api/core/memory/__init__.py  Normal file
@@ -0,0 +1,11 @@
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import (
    NodeTokenBufferMemory,
)
from core.memory.token_buffer_memory import TokenBufferMemory

__all__ = [
    "BaseMemory",
    "NodeTokenBufferMemory",
    "TokenBufferMemory",
]
82  api/core/memory/base.py  Normal file
@@ -0,0 +1,82 @@
"""
Base memory interfaces and types.

This module defines the common protocol for memory implementations.
"""

from abc import ABC, abstractmethod
from collections.abc import Sequence

from graphon.model_runtime.entities import ImagePromptMessageContent, PromptMessage


class BaseMemory(ABC):
    """
    Abstract base class for memory implementations.

    Provides a common interface for both conversation-level and node-level memory.
    """

    @abstractmethod
    def get_history_prompt_messages(
        self,
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]:
        """
        Get history prompt messages.

        :param max_token_limit: Maximum tokens for history
        :param message_limit: Maximum number of messages
        :return: Sequence of PromptMessage for LLM context
        """
        pass

    def get_history_prompt_text(
        self,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> str:
        """
        Get history prompt as formatted text.

        :param human_prefix: Prefix for human messages
        :param ai_prefix: Prefix for assistant messages
        :param max_token_limit: Maximum tokens for history
        :param message_limit: Maximum number of messages
        :return: Formatted history text
        """
        from graphon.model_runtime.entities import (
            PromptMessageRole,
            TextPromptMessageContent,
        )

        prompt_messages = self.get_history_prompt_messages(
            max_token_limit=max_token_limit,
            message_limit=message_limit,
        )

        string_messages = []
        for m in prompt_messages:
            if m.role == PromptMessageRole.USER:
                role = human_prefix
            elif m.role == PromptMessageRole.ASSISTANT:
                role = ai_prefix
            else:
                continue

            if isinstance(m.content, list):
                inner_msg = ""
                for content in m.content:
                    if isinstance(content, TextPromptMessageContent):
                        inner_msg += f"{content.data}\n"
                    elif isinstance(content, ImagePromptMessageContent):
                        inner_msg += "[image]\n"
                string_messages.append(f"{role}: {inner_msg.strip()}")
            else:
                message = f"{role}: {m.content}"
                string_messages.append(message)

        return "\n".join(string_messages)
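Only get_history_prompt_messages is abstract, so a static in-memory double is enough to exercise the shared text formatting; a test-oriented sketch:

from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage

class StaticMemory(BaseMemory):  # hypothetical test double
    def get_history_prompt_messages(self, max_token_limit=2000, message_limit=None):
        return [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello!")]

text = StaticMemory().get_history_prompt_text()
# "Human: Hi\nAssistant: Hello!"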
196  api/core/memory/node_token_buffer_memory.py  Normal file
@@ -0,0 +1,196 @@
"""
Node-level Token Buffer Memory for Chatflow.

This module provides node-scoped memory within a conversation.
Each LLM node in a workflow can maintain its own independent conversation history.

Note: This is only available in Chatflow (advanced-chat mode) because it requires
both conversation_id and node_id.

Design:
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
- No separate storage needed; the context is already saved during node execution
- Thread tracking leverages the Message table's parent_message_id structure
"""

import logging
from collections.abc import Sequence
from typing import cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from graphon.file import file_manager
from graphon.model_runtime.entities import (
    AssistantPromptMessage,
    MultiModalPromptMessageContent,
    PromptMessage,
    PromptMessageRole,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from extensions.ext_database import db
from models.model import Message
from models.workflow import WorkflowNodeExecutionModel

logger = logging.getLogger(__name__)


class NodeTokenBufferMemory(BaseMemory):
    """
    Node-level Token Buffer Memory.

    Provides node-scoped memory within a conversation. Each LLM node can maintain
    its own independent conversation history.

    Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
    which is already saved during node execution. No separate storage needed.
    """

    def __init__(
        self,
        app_id: str,
        conversation_id: str,
        node_id: str,
        tenant_id: str,
        model_instance: ModelInstance,
    ):
        self.app_id = app_id
        self.conversation_id = conversation_id
        self.node_id = node_id
        self.tenant_id = tenant_id
        self.model_instance = model_instance

    def _get_thread_workflow_run_ids(self) -> list[str]:
        """
        Get workflow_run_ids for the current thread by querying the Message table.
        Returns workflow_run_ids in chronological order (oldest first).
        """
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(Message)
                .where(Message.conversation_id == self.conversation_id)
                .order_by(Message.created_at.desc())
                .limit(500)
            )
            messages = list(session.scalars(stmt).all())

        if not messages:
            return []

        # Extract thread messages using existing logic
        thread_messages = extract_thread_messages(messages)

        # For a newly created message, its answer is temporarily empty, so skip it
        if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
            thread_messages.pop(0)

        # Reverse to get chronological order, extract workflow_run_ids
        return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]

    def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
        """Deserialize a dict to PromptMessage based on role."""
        role = msg_dict.get("role")
        if role in (PromptMessageRole.USER, "user"):
            return UserPromptMessage.model_validate(msg_dict)
        elif role in (PromptMessageRole.ASSISTANT, "assistant"):
            return AssistantPromptMessage.model_validate(msg_dict)
        elif role in (PromptMessageRole.SYSTEM, "system"):
            return SystemPromptMessage.model_validate(msg_dict)
        elif role in (PromptMessageRole.TOOL, "tool"):
            return ToolPromptMessage.model_validate(msg_dict)
        else:
            return PromptMessage.model_validate(msg_dict)

    def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
        """Deserialize context data from outputs to list of PromptMessage."""
        messages = []
        for msg_dict in context_data:
            try:
                msg = self._deserialize_prompt_message(msg_dict)
                msg = self._restore_multimodal_content(msg)
                messages.append(msg)
            except Exception as e:
                logger.warning("Failed to deserialize prompt message: %s", e)
        return messages

    def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
        """
        Restore multimodal content (base64 or url) from file_ref.

        When context is saved, base64_data is cleared to save storage space.
        This method restores the content by parsing file_ref (format: "method:id_or_url").
        """
        content = message.content
        if content is None or isinstance(content, str):
            return message

        # Process list content, restoring multimodal data from file references
        restored_content: list[PromptMessageContentUnionTypes] = []
        for item in content:
            if isinstance(item, MultiModalPromptMessageContent):
                # restore_multimodal_content preserves the concrete subclass type
                restored_item = file_manager.restore_multimodal_content(item)
                restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
            else:
                restored_content.append(item)

        return message.model_copy(update={"content": restored_content})

    def get_history_prompt_messages(
        self,
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]:
        """
        Retrieve history as a PromptMessage sequence.
        History is read directly from the last completed node execution's outputs["context"].
        """
        _ = message_limit  # unused, kept for interface compatibility

        thread_workflow_run_ids = self._get_thread_workflow_run_ids()
        if not thread_workflow_run_ids:
            return []

        # Get the last completed workflow_run_id (contains accumulated context)
        last_run_id = thread_workflow_run_ids[-1]

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(WorkflowNodeExecutionModel).where(
                WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
                WorkflowNodeExecutionModel.node_id == self.node_id,
                WorkflowNodeExecutionModel.status == "succeeded",
            )
            execution = session.scalars(stmt).first()

        if not execution:
            return []

        outputs = execution.outputs_dict
        if not outputs:
            return []

        context_data = outputs.get("context")

        if not context_data or not isinstance(context_data, list):
            return []

        prompt_messages = self._deserialize_context(context_data)
        if not prompt_messages:
            return []

        # Truncate by token limit
        try:
            current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
            while current_tokens > max_token_limit and len(prompt_messages) > 1:
                prompt_messages.pop(0)
                current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
        except Exception as e:
            logger.warning("Failed to count tokens for truncation: %s", e)

        return prompt_messages
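A sketch of how an LLM node in advanced-chat mode might consult this memory; model_instance construction (via ModelManager) and the surrounding node wiring are omitted, and the app/conversation/node/tenant objects are assumed from the run context:

memory = NodeTokenBufferMemory(
    app_id=app.id,
    conversation_id=conversation.id,
    node_id=node.id,
    tenant_id=tenant.id,
    model_instance=model_instance,
)
history = memory.get_history_prompt_messages(max_token_limit=2000)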
11  api/core/session/__init__.py  Normal file
@@ -0,0 +1,11 @@
from .cli_api import CliApiSession, CliApiSessionManager
from .session import BaseSession, RedisSessionStorage, SessionManager, SessionStorage

__all__ = [
    "BaseSession",
    "CliApiSession",
    "CliApiSessionManager",
    "RedisSessionStorage",
    "SessionManager",
    "SessionStorage",
]
30  api/core/session/cli_api.py  Normal file
@@ -0,0 +1,30 @@
import secrets

from pydantic import BaseModel, Field

from configs import dify_config
from core.skill.entities import ToolAccessPolicy

from .session import BaseSession, SessionManager


class CliApiSession(BaseSession):
    secret: str = Field(default_factory=lambda: secrets.token_urlsafe(32))


class CliContext(BaseModel):
    tool_access: ToolAccessPolicy | None = Field(default=None, description="Tool access policy")


class CliApiSessionManager(SessionManager[CliApiSession]):
    def __init__(self, ttl: int | None = None):
        super().__init__(
            key_prefix="cli_api_session",
            session_class=CliApiSession,
            ttl=ttl or dify_config.WORKFLOW_MAX_EXECUTION_TIME,
        )

    def create(self, tenant_id: str, user_id: str, context: CliContext) -> CliApiSession:
        session = CliApiSession(tenant_id=tenant_id, user_id=user_id, context=context.model_dump(mode="json"))
        self.save(session)
        return session
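A sketch of minting a CLI session for a DifyCli callback, assuming tenant and account models from the request context; the session TTL tracks WORKFLOW_MAX_EXECUTION_TIME by default:

manager = CliApiSessionManager()
session = manager.create(
    tenant_id=tenant.id,
    user_id=account.id,
    context=CliContext(tool_access=None),
)
assert manager.get(session.id) is not None  # persisted in Redis
# session.secret is the per-session credential handed to the CLI.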
106  api/core/session/session.py  Normal file
@@ -0,0 +1,106 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Any, Generic, Protocol, TypeVar

from pydantic import BaseModel, Field, ValidationError

from extensions.ext_redis import redis_client

logger = logging.getLogger(__name__)


class SessionStorage(Protocol):
    """Session storage interface."""

    def get(self, key: str) -> str | None: ...
    def set(self, key: str, value: str, ttl: int) -> None: ...
    def delete(self, key: str) -> bool: ...
    def exists(self, key: str) -> bool: ...
    def refresh_ttl(self, key: str, ttl: int) -> bool: ...


class RedisSessionStorage:
    """Redis storage implementation (default)."""

    def get(self, key: str) -> str | None:
        result = redis_client.get(key)
        if result is None:
            return None
        return result.decode() if isinstance(result, bytes) else result

    def set(self, key: str, value: str, ttl: int) -> None:
        redis_client.setex(key, ttl, value)

    def delete(self, key: str) -> bool:
        return redis_client.delete(key) > 0

    def exists(self, key: str) -> bool:
        return redis_client.exists(key) > 0

    def refresh_ttl(self, key: str, ttl: int) -> bool:
        return bool(redis_client.expire(key, ttl))


class BaseSession(BaseModel):
    """Base session model."""

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    tenant_id: str
    user_id: str
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    context: dict[str, Any] = Field(default_factory=dict)

    def update_timestamp(self) -> None:
        self.updated_at = datetime.now(UTC)


T = TypeVar("T", bound=BaseSession)


class SessionManager(Generic[T]):
    """Generic session manager."""

    DEFAULT_TTL = 7200  # 2 hours

    def __init__(
        self,
        key_prefix: str,
        session_class: type[T],
        storage: SessionStorage | None = None,
        ttl: int | None = None,
    ):
        self._key_prefix = key_prefix
        self._session_class = session_class
        self._storage = storage or RedisSessionStorage()
        self._ttl = ttl or self.DEFAULT_TTL

    def _get_key(self, session_id: str) -> str:
        return f"{self._key_prefix}:{session_id}"

    def save(self, session: T) -> None:
        session.update_timestamp()
        key = self._get_key(session.id)
        self._storage.set(key, session.model_dump_json(), self._ttl)

    def get(self, session_id: str) -> T | None:
        key = self._get_key(session_id)
        data = self._storage.get(key)
        if data is None:
            return None
        try:
            return self._session_class.model_validate(json.loads(data))
        except (json.JSONDecodeError, ValidationError) as e:
            logger.warning("Failed to deserialize session %s: %s", session_id, e)
            return None

    def delete(self, session_id: str) -> bool:
        return self._storage.delete(self._get_key(session_id))

    def exists(self, session_id: str) -> bool:
        return self._storage.exists(self._get_key(session_id))

    def refresh_ttl(self, session_id: str) -> bool:
        return self._storage.refresh_ttl(self._get_key(session_id), self._ttl)
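Because SessionManager depends only on the SessionStorage protocol, the Redis backend can be swapped out, e.g. for unit tests; a minimal in-memory sketch (TTL ignored for brevity):

class InMemorySessionStorage:
    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    def get(self, key: str) -> str | None:
        return self._data.get(key)

    def set(self, key: str, value: str, ttl: int) -> None:
        self._data[key] = value

    def delete(self, key: str) -> bool:
        return self._data.pop(key, None) is not None

    def exists(self, key: str) -> bool:
        return key in self._data

    def refresh_ttl(self, key: str, ttl: int) -> bool:
        return key in self._data

manager = SessionManager(key_prefix="test", session_class=BaseSession, storage=InMemorySessionStorage())
session = BaseSession(tenant_id="t1", user_id="u1")
manager.save(session)
assert manager.get(session.id) is not None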
187  api/core/tools/utils/system_encryption.py  Normal file
@@ -0,0 +1,187 @@
import base64
import hashlib
import logging
from collections.abc import Mapping
from typing import Any

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from pydantic import TypeAdapter

from configs import dify_config

logger = logging.getLogger(__name__)


class EncryptionError(Exception):
    """Encryption/decryption specific error"""

    pass


class SystemEncrypter:
    """
    A simple parameter encrypter using AES-CBC encryption.

    This class provides methods to encrypt and decrypt parameters
    using AES-CBC mode with a key derived from the application's SECRET_KEY.
    """

    def __init__(self, secret_key: str | None = None):
        """
        Initialize the encrypter.

        Args:
            secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY

        Raises:
            ValueError: If SECRET_KEY is not configured or empty
        """
        secret_key = secret_key or dify_config.SECRET_KEY or ""

        # Generate a fixed 256-bit key using SHA-256
        self.key = hashlib.sha256(secret_key.encode()).digest()

    def encrypt_params(self, params: Mapping[str, Any]) -> str:
        """
        Encrypt parameters.

        Args:
            params: parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}

        Returns:
            Base64-encoded encrypted string

        Raises:
            EncryptionError: If encryption fails
            ValueError: If params is invalid
        """

        try:
            # Generate random IV (16 bytes)
            iv = get_random_bytes(16)

            # Create AES cipher (CBC mode)
            cipher = AES.new(self.key, AES.MODE_CBC, iv)

            # Encrypt data
            padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
            encrypted_data = cipher.encrypt(padded_data)

            # Combine IV and encrypted data
            combined = iv + encrypted_data

            # Return base64 encoded string
            return base64.b64encode(combined).decode()

        except Exception as e:
            raise EncryptionError(f"Encryption failed: {str(e)}") from e

    def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
        """
        Decrypt parameters.

        Args:
            encrypted_data: Base64-encoded encrypted string

        Returns:
            Decrypted parameters dictionary

        Raises:
            EncryptionError: If decryption fails
            ValueError: If encrypted_data is invalid
        """
        if not isinstance(encrypted_data, str):
            raise ValueError("encrypted_data must be a string")

        if not encrypted_data:
            raise ValueError("encrypted_data cannot be empty")

        try:
            # Base64 decode
            combined = base64.b64decode(encrypted_data)

            # Check minimum length (IV + at least one AES block)
            if len(combined) < 32:  # 16 bytes IV + 16 bytes minimum encrypted data
                raise ValueError("Invalid encrypted data format")

            # Separate IV and encrypted data
            iv = combined[:16]
            encrypted_data_bytes = combined[16:]

            # Create AES cipher
            cipher = AES.new(self.key, AES.MODE_CBC, iv)

            # Decrypt data
            decrypted_data = cipher.decrypt(encrypted_data_bytes)
            unpadded_data = unpad(decrypted_data, AES.block_size)

            # Parse JSON
            params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)

            if not isinstance(params, dict):
                raise ValueError("Decrypted data is not a valid dictionary")

            return params

        except Exception as e:
            raise EncryptionError(f"Decryption failed: {str(e)}") from e


# Factory function for creating encrypter instances
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
    """
    Create an encrypter instance.

    Args:
        secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY

    Returns:
        SystemEncrypter instance
    """
    return SystemEncrypter(secret_key=secret_key)


# Global encrypter instance (for backward compatibility)
_encrypter: SystemEncrypter | None = None


def get_system_encrypter() -> SystemEncrypter:
    """
    Get the global encrypter instance.

    Returns:
        SystemEncrypter instance
    """
    global _encrypter
    if _encrypter is None:
        _encrypter = SystemEncrypter()
    return _encrypter


# Convenience functions for backward compatibility
def encrypt_system_params(params: Mapping[str, Any]) -> str:
    """
    Encrypt parameters using the global encrypter.

    Args:
        params: parameters dictionary

    Returns:
        Base64-encoded encrypted string
    """
    return get_system_encrypter().encrypt_params(params)


def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
    """
    Decrypt parameters using the global encrypter.

    Args:
        encrypted_data: Base64-encoded encrypted string

    Returns:
        Decrypted parameters dictionary
    """
    return get_system_encrypter().decrypt_params(encrypted_data)
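A round-trip sketch using the module-level helpers, assuming dify_config.SECRET_KEY is configured (the AES key is derived from it):

from core.tools.utils.system_encryption import decrypt_system_params, encrypt_system_params

token = encrypt_system_params({"client_id": "abc", "client_secret": "s3cret"})
assert decrypt_system_params(token) == {"client_id": "abc", "client_secret": "s3cret"}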
3  api/core/workflow/nodes/command/__init__.py  Normal file
@@ -0,0 +1,3 @@
from .node import CommandNode

__all__ = ["CommandNode"]
10  api/core/workflow/nodes/command/entities.py  Normal file
@@ -0,0 +1,10 @@
from graphon.entities.base_node_data import BaseNodeData


class CommandNodeData(BaseNodeData):
    """
    Command Node Data.
    """

    working_directory: str = ""  # Working directory for command execution
    command: str = ""  # Command to execute
16  api/core/workflow/nodes/command/exc.py  Normal file
@@ -0,0 +1,16 @@
class CommandNodeError(ValueError):
    """Base class for command node errors."""

    pass


class CommandExecutionError(CommandNodeError):
    """Raised when command execution fails."""

    pass


class CommandTimeoutError(CommandNodeError):
    """Raised when command execution times out."""

    pass
152  api/core/workflow/nodes/command/node.py  Normal file
@@ -0,0 +1,152 @@
import logging
from collections.abc import Mapping, Sequence
from typing import Any

from core.sandbox import sandbox_debug
from core.sandbox.bash.session import SANDBOX_READY_TIMEOUT
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.helpers import submit_command, with_connection
from core.workflow.nodes.command.entities import CommandNodeData
from core.workflow.nodes.command.exc import CommandExecutionError
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.node_events import NodeRunResult
from graphon.nodes.base import variable_template_parser
from graphon.nodes.base.entities import VariableSelector
from graphon.nodes.base.node import Node
from graphon.nodes.base.variable_template_parser import VariableTemplateParser

logger = logging.getLogger(__name__)

# FIXME(Mairuis): The timeout value is currently hardcoded and should be made configurable in the future.
COMMAND_NODE_TIMEOUT_SECONDS = 60 * 10


class CommandNode(Node[CommandNodeData]):
    node_type = BuiltinNodeTypes.COMMAND

    def _render_template(self, template: str) -> str:
        parser = VariableTemplateParser(template=template)
        selectors = parser.extract_variable_selectors()
        if not selectors:
            return template

        inputs: dict[str, Any] = {}
        for selector in selectors:
            value = self.graph_runtime_state.variable_pool.get(selector.value_selector)
            inputs[selector.variable] = value.to_object() if value is not None else None

        return parser.format(inputs)

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        return {
            "type": "command",
            "config": {
                "working_directory": "",
                "command": "",
            },
        }

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> NodeRunResult:
        sandbox = self.graph_runtime_state.sandbox
        if sandbox is None:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error="Sandbox not available for CommandNode.",
                error_type="SandboxNotInitializedError",
            )

        working_directory = self._render_template((self.node_data.working_directory or "").strip())
        raw_command = self._render_template(self.node_data.command or "")

        working_directory = working_directory or None

        if not raw_command:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error="Command is required.",
                error_type="CommandNodeError",
            )

        try:
            sandbox.wait_ready(timeout=SANDBOX_READY_TIMEOUT)
            with with_connection(sandbox.vm) as conn:
                command = ["bash", "-c", raw_command]

                sandbox_debug("command_node", "command", command)

                future = submit_command(sandbox.vm, conn, command, cwd=working_directory)
                result = future.result(timeout=COMMAND_NODE_TIMEOUT_SECONDS)

                outputs: dict[str, Any] = {
                    "stdout": result.stdout.decode("utf-8", errors="replace"),
                    "stderr": result.stderr.decode("utf-8", errors="replace"),
                    "exit_code": result.exit_code,
                    "pid": result.pid,
                }
                process_data = {"command": command, "working_directory": working_directory}

                sandbox_debug("command_node", "outputs", result.debug_message)

                if result.exit_code not in (None, 0):
                    stderr_text = result.stderr.decode("utf-8", errors="replace")
                    error_message = f"{stderr_text}\n\nCommand exited with code {result.exit_code}"
                    return NodeRunResult(
                        status=WorkflowNodeExecutionStatus.FAILED,
                        outputs=outputs,
                        process_data=process_data,
                        error=error_message,
                        error_type=CommandExecutionError.__name__,
                    )

                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    outputs=outputs,
                    process_data=process_data,
                )

        except CommandTimeoutError:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"Command timed out after {COMMAND_NODE_TIMEOUT_SECONDS}s",
                error_type=CommandTimeoutError.__name__,
            )
        except CommandCancelledError:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error="Command was cancelled",
                error_type=CommandCancelledError.__name__,
            )
        except Exception as e:
            logger.exception("Command node %s failed", self.id)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                error_type=type(e).__name__,
            )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: CommandNodeData,
    ) -> Mapping[str, Sequence[str]]:
        _ = graph_config

        typed_node_data = node_data

        selectors: list[VariableSelector] = []
        selectors += list(variable_template_parser.extract_selectors_from_template(typed_node_data.command))
        selectors += list(variable_template_parser.extract_selectors_from_template(typed_node_data.working_directory))

        mapping: dict[str, Sequence[str]] = {}
        for selector in selectors:
            mapping[node_id + "." + selector.variable] = selector.value_selector

        return mapping
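For reference, the node config shape implied by get_default_config and CommandNodeData; both fields accept {{#node_id.variable#}} selectors, which _render_template resolves from the variable pool before execution (node IDs here are made up):

node_config = {
    "type": "command",
    "config": {
        "working_directory": "/workspace",
        "command": "python {{#start.script_name#}} --verbose",
    },
}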
4  api/core/workflow/nodes/file_upload/__init__.py  Normal file
@@ -0,0 +1,4 @@
from .entities import FileUploadNodeData
from .node import FileUploadNode

__all__ = ["FileUploadNode", "FileUploadNodeData"]
7  api/core/workflow/nodes/file_upload/entities.py  Normal file
@@ -0,0 +1,7 @@
from collections.abc import Sequence

from graphon.entities.base_node_data import BaseNodeData


class FileUploadNodeData(BaseNodeData):
    variable_selector: Sequence[str]
6  api/core/workflow/nodes/file_upload/exc.py  Normal file
@@ -0,0 +1,6 @@
class FileUploadNodeError(ValueError):
    """Base exception for errors related to the FileUploadNode."""


class FileUploadDownloadError(FileUploadNodeError):
    """Exception raised when preparing file download in sandbox fails."""
244  api/core/workflow/nodes/file_upload/node.py  Normal file
@@ -0,0 +1,244 @@
import logging
import os
import posixpath
from collections.abc import Mapping, Sequence
from pathlib import PurePosixPath
from typing import Any, cast

from core.sandbox.bash.session import SANDBOX_READY_TIMEOUT
from core.virtual_environment.__base.command_future import CommandCancelledError, CommandTimeoutError
from core.virtual_environment.__base.helpers import pipeline
from core.zip_sandbox import SandboxDownloadItem
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from graphon.file import File, FileTransferMethod
from graphon.node_events import NodeRunResult
from graphon.nodes.base.node import Node
from graphon.variables import ArrayFileSegment
from graphon.variables.segments import ArrayStringSegment, FileSegment

from .entities import FileUploadNodeData
from .exc import FileUploadDownloadError, FileUploadNodeError

logger = logging.getLogger(__name__)


class FileUploadNode(Node[FileUploadNodeData]):
    """Upload workflow file variables into the sandbox via presigned URLs.

    The node intentionally avoids streaming file bytes through Dify workers. For local/tool
    files, it generates storage-backed presigned URLs and lets the sandbox download directly.
    """

    node_type = BuiltinNodeTypes.FILE_UPLOAD

    @classmethod
    def version(cls) -> str:
        return "1"

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
        _ = filters
        return {
            "type": "file-upload",
            "config": {
                "variable_selector": [],
            },
        }

    def _run(self) -> NodeRunResult:
        sandbox = self.graph_runtime_state.sandbox
        variable_selector = self.node_data.variable_selector
        inputs: dict[str, Any] = {"variable_selector": variable_selector}
        if sandbox is None:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error="Sandbox not available for FileUploadNode.",
                error_type="SandboxNotInitializedError",
                inputs=inputs,
            )

        variable = self.graph_runtime_state.variable_pool.get(variable_selector)
        if variable is None:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"File variable not found for selector: {variable_selector}",
                error_type=FileUploadNodeError.__name__,
                inputs=inputs,
            )

        if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=f"Variable {variable_selector} is not a file or file array",
                error_type=FileUploadNodeError.__name__,
                inputs=inputs,
            )

        files = self._normalize_files(variable.value)
        process_data: dict[str, Any] = {
            "file_count": len(files),
            "files": [file.to_dict() for file in files],
        }
        if not files:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                error="Selected file variable is empty.",
                error_type=FileUploadNodeError.__name__,
                inputs=inputs,
                process_data=process_data,
            )

        try:
            sandbox.wait_ready(timeout=SANDBOX_READY_TIMEOUT)
            download_items: list[SandboxDownloadItem] = self._build_download_items(files)
            sandbox_paths = self._upload(sandbox.vm, download_items)
            file_names = [PurePosixPath(path).name for path in sandbox_paths]
            process_data = {
                **process_data,
                "sandbox_paths": sandbox_paths,
                "file_names": file_names,
            }

            outputs: dict[str, Any]
            if len(sandbox_paths) == 1:
                outputs = {
                    "sandbox_path": sandbox_paths[0],
                    "file_name": file_names[0],
                }
            else:
                outputs = {
                    "sandbox_path": ArrayStringSegment(value=sandbox_paths),
                    "file_name": ArrayStringSegment(value=file_names),
                }
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
            )

        except CommandTimeoutError:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error="File upload timeout",
                error_type=CommandTimeoutError.__name__,
                inputs=inputs,
                process_data=process_data,
            )
        except CommandCancelledError:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error="File upload command was cancelled",
                error_type=CommandCancelledError.__name__,
                inputs=inputs,
                process_data=process_data,
            )
        except FileUploadNodeError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                error_type=type(e).__name__,
                inputs=inputs,
                process_data=process_data,
            )
        except Exception as e:
            logger.exception("File upload node %s failed", self.id)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                error_type=type(e).__name__,
                inputs=inputs,
                process_data=process_data,
            )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: FileUploadNodeData,
    ) -> Mapping[str, Sequence[str]]:
        _ = graph_config
        typed_node_data = node_data
        return {node_id + ".files": typed_node_data.variable_selector}

    @staticmethod
    def _normalize_files(value: Any) -> list[File]:
        if isinstance(value, File):
            return [value]
        if isinstance(value, list):
            list_value = cast(list[object], value)
            files: list[File] = []
            for idx in range(len(list_value)):
                candidate = list_value[idx]
                if not isinstance(candidate, File):
                    return []
                files.append(candidate)
            return files
        return []

    def _build_download_items(self, files: Sequence[File]) -> list[SandboxDownloadItem]:
        used_paths: set[str] = set()
        items: list[SandboxDownloadItem] = []
        for index, file in enumerate(files):
            file_url = self._get_download_url(file)

            filename = (file.filename or "").strip()
            if not filename or filename in {".", ".."}:
                filename = f"file-{index + 1}{file.extension or ''}"
            filename = os.path.basename(filename)

            if filename in used_paths:
                stem = PurePosixPath(filename).stem or f"file-{index + 1}"
                suffix = PurePosixPath(filename).suffix
                dedupe = 1
                while filename in used_paths:
                    filename = f"{stem}_{dedupe}{suffix}"
                    dedupe += 1

            used_paths.add(filename)
            items.append(SandboxDownloadItem(path=filename, url=file_url))
        return items

    @staticmethod
    def _normalize_path(path: str) -> str:
        normalized = posixpath.normpath(path.strip()) if path else "."
        if normalized.startswith("/"):
            normalized = normalized.lstrip("/")
        return normalized or "."

    def _upload(self, vm: Any, items: list[SandboxDownloadItem]) -> list[str]:
        p = pipeline(vm)
        out_paths: list[str] = []
        for item in items:
            out_path = self._normalize_path(item.path)
            if out_path in ("", "."):
                raise FileUploadDownloadError("Download item path must point to a file")
            out_paths.append(out_path)
            p.add(["curl", "-fsSL", item.url, "-o", out_path], error_message="Failed to download file")

        try:
            p.execute(timeout=None, raise_on_error=True)
        except Exception as exc:
            raise FileUploadDownloadError(str(exc)) from exc

        return out_paths

    def _get_download_url(self, file: File) -> str:
        if file.transfer_method == FileTransferMethod.REMOTE_URL:
            if not file.remote_url:
                raise FileUploadDownloadError("Remote file URL is missing")
            return file.remote_url

        if file.transfer_method in (
            FileTransferMethod.LOCAL_FILE,
            FileTransferMethod.TOOL_FILE,
            FileTransferMethod.DATASOURCE_FILE,
        ):
            download_url = file.generate_url(for_external=True)
            if not download_url:
                raise FileUploadDownloadError("Unable to generate download URL for file")
            return download_url
|
||||
|
||||
raise FileUploadDownloadError(f"Unsupported file transfer method: {file.transfer_method}")
|
||||
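The naming rules in _build_download_items above are easiest to see in isolation. Below is a minimal sketch of the same collision-handling logic as a standalone function; the name dedupe_filenames is hypothetical and not part of this commit, and the extension fallback from file.extension is omitted.

import os
from pathlib import PurePosixPath

def dedupe_filenames(raw_names: list[str]) -> list[str]:
    # Replica of the node's rules: sanitize, fall back to file-N, then suffix _1, _2, ...
    used: set[str] = set()
    result: list[str] = []
    for index, raw in enumerate(raw_names):
        name = (raw or "").strip()
        if not name or name in {".", ".."}:
            name = f"file-{index + 1}"
        name = os.path.basename(name)
        if name in used:
            stem = PurePosixPath(name).stem or f"file-{index + 1}"
            suffix = PurePosixPath(name).suffix
            dedupe = 1
            while name in used:
                name = f"{stem}_{dedupe}{suffix}"
                dedupe += 1
        used.add(name)
        result.append(name)
    return result

# ["report.pdf", "report.pdf", ""] -> ["report.pdf", "report_1.pdf", "file-3"]
print(dedupe_filenames(["report.pdf", "report.pdf", ""]))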
api/core/zip_sandbox/__init__.py (new file, 23 lines)
@ -0,0 +1,23 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from .entities import SandboxDownloadItem, SandboxFile, SandboxUploadItem

if TYPE_CHECKING:
    from .zip_sandbox import ZipSandbox

__all__ = [
    "SandboxDownloadItem",
    "SandboxFile",
    "SandboxUploadItem",
    "ZipSandbox",
]


def __getattr__(name: str):
    if name == "ZipSandbox":
        from .zip_sandbox import ZipSandbox

        return ZipSandbox
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
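The module-level __getattr__ (PEP 562) means importing the package only resolves .entities; ZipSandbox itself is loaded on first attribute access. A quick sketch of the behavior, assuming the package is importable in the current environment:

import sys

import core.zip_sandbox as zs

# Importing the package pulls in .entities but not the heavy zip_sandbox module.
assert "core.zip_sandbox.zip_sandbox" not in sys.modules

_ = zs.ZipSandbox  # triggers __getattr__, which imports zip_sandbox lazily
assert "core.zip_sandbox.zip_sandbox" in sys.modules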
api/core/zip_sandbox/cli_strategy.py (new file, 81 lines)
@ -0,0 +1,81 @@
from __future__ import annotations

import posixpath
from typing import TYPE_CHECKING

from core.virtual_environment.__base.exec import CommandExecutionError
from core.virtual_environment.__base.helpers import execute, try_execute

from .strategy import ZipStrategy

if TYPE_CHECKING:
    from core.virtual_environment.__base.virtual_environment import VirtualEnvironment


class CliZipStrategy(ZipStrategy):
    """Strategy using native zip/unzip CLI commands."""

    def is_available(self, vm: VirtualEnvironment) -> bool:
        result = try_execute(vm, ["which", "zip"], timeout=10)
        has_zip = bool(result.stdout and result.stdout.strip())
        result = try_execute(vm, ["which", "unzip"], timeout=10)
        has_unzip = bool(result.stdout and result.stdout.strip())
        return has_zip and has_unzip

    def zip(
        self,
        vm: VirtualEnvironment,
        *,
        src: str,
        out_path: str,
        cwd: str | None,
        timeout: float,
    ) -> None:
        if src in (".", ""):
            result = try_execute(vm, ["zip", "-qr", out_path, "."], timeout=timeout, cwd=cwd)
            if not result.is_error:
                return
            # zip exits with 12 when there is nothing to do; create empty zip
            if result.exit_code == 12:
                self._write_empty_zip(vm, out_path)
                return
            raise CommandExecutionError("Failed to create zip archive", result)

        zip_cwd = posixpath.dirname(src) or "."
        target = posixpath.basename(src)
        result = try_execute(vm, ["zip", "-qr", out_path, target], timeout=timeout, cwd=zip_cwd)
        if not result.is_error:
            return
        if result.exit_code == 12:
            self._write_empty_zip(vm, out_path)
            return
        raise CommandExecutionError("Failed to create zip archive", result)

    def unzip(
        self,
        vm: VirtualEnvironment,
        *,
        archive_path: str,
        dest_dir: str,
        timeout: float,
    ) -> None:
        execute(
            vm,
            ["unzip", "-q", archive_path, "-d", dest_dir],
            timeout=timeout,
            error_message="Failed to unzip archive",
        )

    def _write_empty_zip(self, vm: VirtualEnvironment, out_path: str) -> None:
        """Write an empty but valid zip file."""
        script = (
            'printf "'
            "\\x50\\x4b\\x05\\x06"
            "\\x00\\x00\\x00\\x00"
            "\\x00\\x00\\x00\\x00"
            "\\x00\\x00\\x00\\x00"
            "\\x00\\x00\\x00\\x00"
            "\\x00\\x00\\x00\\x00"
            '" > "$1"'
        )
        execute(vm, ["sh", "-c", script, "sh", out_path], timeout=30, error_message="Failed to write empty zip")
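The printf fallback above hand-writes a zip end-of-central-directory record: the PK\x05\x06 signature followed by zeroed fields. The canonical minimal archive is 22 bytes (signature plus 18 zero bytes), which can be sanity-checked locally with Python's zipfile, independent of any VM:

import io
import zipfile

# Signature (4 bytes) plus 18 zeroed EOCD fields: the smallest valid, empty zip.
empty_zip = b"PK\x05\x06" + b"\x00" * 18
assert zipfile.ZipFile(io.BytesIO(empty_zip)).namelist() == []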
api/core/zip_sandbox/entities.py (new file, 39 lines)
@ -0,0 +1,39 @@
"""Data classes for ZipSandbox file operations.

Separated from ``zip_sandbox.py`` so that lightweight consumers (tests,
shell-script builders) can import the types without pulling in the full
sandbox provider chain.
"""

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(frozen=True)
class SandboxDownloadItem:
    """Unified download/inline item for sandbox file operations.

    For remote files, *url* is set and the item is fetched via ``curl``.
    For inline content, *content* is set and the bytes are written directly
    into the VM — no network round-trip.
    """

    path: str
    url: str = ""
    content: bytes | None = field(default=None, repr=False)


@dataclass(frozen=True)
class SandboxUploadItem:
    """Item for uploading: sandbox path -> URL."""

    path: str
    url: str


@dataclass(frozen=True)
class SandboxFile:
    """A handle to a file in the sandbox."""

    path: str
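Two ways to construct a download item, matching the two code paths the docstring describes (the URL and file names are placeholders):

from core.zip_sandbox.entities import SandboxDownloadItem

# Remote: fetched inside the VM with curl.
remote = SandboxDownloadItem(path="data/input.csv", url="https://example.com/input.csv")

# Inline: bytes are written into the VM directly, no network round-trip.
inline = SandboxDownloadItem(path="config.json", content=b'{"debug": false}')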
api/core/zip_sandbox/node_strategy.py (new file, 106 lines)
@ -0,0 +1,106 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from core.virtual_environment.__base.helpers import execute, try_execute

from .strategy import ZipStrategy

if TYPE_CHECKING:
    from core.virtual_environment.__base.virtual_environment import VirtualEnvironment


ZIP_SCRIPT = r"""
const fs = require('fs');
const path = require('path');
const AdmZip = require('adm-zip');

const src = process.argv[2];
const outPath = process.argv[3];

function walkAdd(zip, absPath, arcPrefix) {
  const stat = fs.statSync(absPath);
  if (stat.isDirectory()) {
    const entries = fs.readdirSync(absPath);
    if (entries.length === 0) {
      zip.addFile(arcPrefix.replace(/\\/g, '/') + '/', Buffer.alloc(0));
      return;
    }
    for (const e of entries) {
      walkAdd(zip, path.join(absPath, e), path.posix.join(arcPrefix, e));
    }
    return;
  }
  if (stat.isFile()) {
    const data = fs.readFileSync(absPath);
    zip.addFile(arcPrefix.replace(/\\/g, '/'), data);
  }
}

const zip = new AdmZip();
if (src === '.' || src === '') {
  const entries = fs.readdirSync('.');
  for (const e of entries) {
    walkAdd(zip, path.join('.', e), e);
  }
} else {
  const base = path.dirname(src) || '.';
  const prefix = path.basename(src.replace(/\/+$/, ''));
  const root = path.join(base, prefix);
  walkAdd(zip, root, prefix);
}

zip.writeZip(outPath);
"""

UNZIP_SCRIPT = r"""
const AdmZip = require('adm-zip');
const archivePath = process.argv[2];
const destDir = process.argv[3];
const zip = new AdmZip(archivePath);
zip.extractAllTo(destDir, true);
"""


class NodeZipStrategy(ZipStrategy):
    """Strategy using Node.js with adm-zip package."""

    def is_available(self, vm: VirtualEnvironment) -> bool:
        result = try_execute(vm, ["which", "node"], timeout=10)
        if not (result.stdout and result.stdout.strip()):
            return False
        # Check if adm-zip module is available
        result = try_execute(vm, ["node", "-e", "require('adm-zip')"], timeout=10)
        return not result.is_error

    def zip(
        self,
        vm: VirtualEnvironment,
        *,
        src: str,
        out_path: str,
        cwd: str | None,
        timeout: float,
    ) -> None:
        execute(
            vm,
            ["node", "-e", ZIP_SCRIPT, src, out_path],
            timeout=timeout,
            cwd=cwd,
            error_message="Failed to create zip archive",
        )

    def unzip(
        self,
        vm: VirtualEnvironment,
        *,
        archive_path: str,
        dest_dir: str,
        timeout: float,
    ) -> None:
        execute(
            vm,
            ["node", "-e", UNZIP_SCRIPT, archive_path, dest_dir],
            timeout=timeout,
            error_message="Failed to unzip archive",
        )
api/core/zip_sandbox/python_strategy.py (new file, 117 lines)
@ -0,0 +1,117 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from core.virtual_environment.__base.helpers import execute, try_execute

from .strategy import ZipStrategy

if TYPE_CHECKING:
    from core.virtual_environment.__base.virtual_environment import VirtualEnvironment


ZIP_SCRIPT = r"""
import os
import sys
import zipfile

src = sys.argv[1]
out_path = sys.argv[2]

def is_cwd(p: str) -> bool:
    return p in (".", "")

src = src.rstrip("/")

if is_cwd(src):
    base = "."
    root = "."
    prefix = ""
else:
    base = os.path.dirname(src) or "."
    prefix = os.path.basename(src)
    root = os.path.join(base, prefix)

def add_empty_dir(zf: zipfile.ZipFile, arc_dir: str) -> None:
    name = arc_dir.rstrip("/") + "/"
    if name != "/":
        zf.writestr(name, b"")

with zipfile.ZipFile(out_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    if os.path.isfile(root):
        zf.write(root, arcname=os.path.basename(root))
    else:
        for dirpath, dirnames, filenames in os.walk(root):
            rel_dir = os.path.relpath(dirpath, base)
            rel_dir = "" if rel_dir == "." else rel_dir
            if not dirnames and not filenames:
                add_empty_dir(zf, rel_dir)
            for fn in filenames:
                fp = os.path.join(dirpath, fn)
                arcname = os.path.join(rel_dir, fn) if rel_dir else fn
                zf.write(fp, arcname=arcname)
"""

UNZIP_SCRIPT = r"""
import sys
import zipfile

archive_path = sys.argv[1]
dest_dir = sys.argv[2]

with zipfile.ZipFile(archive_path, "r") as zf:
    zf.extractall(dest_dir)
"""


class PythonZipStrategy(ZipStrategy):
    """Strategy using Python's zipfile module."""

    def __init__(self) -> None:
        self._python_cmd: str | None = None

    def is_available(self, vm: VirtualEnvironment) -> bool:
        for cmd in ("python3", "python"):
            result = try_execute(vm, ["which", cmd], timeout=10)
            if result.stdout and result.stdout.strip():
                self._python_cmd = cmd
                return True
        return False

    def zip(
        self,
        vm: VirtualEnvironment,
        *,
        src: str,
        out_path: str,
        cwd: str | None,
        timeout: float,
    ) -> None:
        if self._python_cmd is None:
            raise RuntimeError("Python not available")

        execute(
            vm,
            [self._python_cmd, "-c", ZIP_SCRIPT, src, out_path],
            timeout=timeout,
            cwd=cwd,
            error_message="Failed to create zip archive",
        )

    def unzip(
        self,
        vm: VirtualEnvironment,
        *,
        archive_path: str,
        dest_dir: str,
        timeout: float,
    ) -> None:
        if self._python_cmd is None:
            raise RuntimeError("Python not available")

        execute(
            vm,
            [self._python_cmd, "-c", UNZIP_SCRIPT, archive_path, dest_dir],
            timeout=timeout,
            error_message="Failed to unzip archive",
        )
api/core/zip_sandbox/strategy.py (new file, 41 lines)
@ -0,0 +1,41 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from core.virtual_environment.__base.virtual_environment import VirtualEnvironment


class ZipStrategy(ABC):
    """Abstract base class for zip/unzip strategies."""

    @abstractmethod
    def is_available(self, vm: VirtualEnvironment) -> bool:
        """Check if this strategy is available in the given VM."""
        ...

    @abstractmethod
    def zip(
        self,
        vm: VirtualEnvironment,
        *,
        src: str,
        out_path: str,
        cwd: str | None,
        timeout: float,
    ) -> None:
        """Create a zip archive."""
        ...

    @abstractmethod
    def unzip(
        self,
        vm: VirtualEnvironment,
        *,
        archive_path: str,
        dest_dir: str,
        timeout: float,
    ) -> None:
        """Extract a zip archive."""
        ...
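Adding another backend means implementing these three methods. A hypothetical sketch using busybox (not part of this commit, extract-only), just to show the contract; the availability probe mirrors the which-based check the shipped strategies use:

from __future__ import annotations

from typing import TYPE_CHECKING

from core.virtual_environment.__base.helpers import execute, try_execute
from core.zip_sandbox.strategy import ZipStrategy

if TYPE_CHECKING:
    from core.virtual_environment.__base.virtual_environment import VirtualEnvironment


class BusyboxUnzipStrategy(ZipStrategy):
    """Hypothetical extract-only strategy via the busybox unzip applet (illustration)."""

    def is_available(self, vm: VirtualEnvironment) -> bool:
        result = try_execute(vm, ["which", "busybox"], timeout=10)
        return bool(result.stdout and result.stdout.strip())

    def zip(self, vm: VirtualEnvironment, *, src: str, out_path: str, cwd: str | None, timeout: float) -> None:
        raise RuntimeError("busybox has no zip applet")

    def unzip(self, vm: VirtualEnvironment, *, archive_path: str, dest_dir: str, timeout: float) -> None:
        execute(
            vm,
            ["busybox", "unzip", "-q", archive_path, "-d", dest_dir],
            timeout=timeout,
            error_message="Failed to unzip archive",
        )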
api/core/zip_sandbox/zip_sandbox.py (new file, 425 lines)
@ -0,0 +1,425 @@
from __future__ import annotations

import base64
import posixpath
import shlex
from io import BytesIO
from pathlib import PurePosixPath
from types import TracebackType
from typing import Any
from urllib.parse import urlparse
from uuid import uuid4

from core.sandbox.builder import SandboxBuilder
from core.sandbox.entities.sandbox_type import SandboxType
from core.sandbox.sandbox import Sandbox
from core.sandbox.storage.noop_storage import NoopSandboxStorage
from core.virtual_environment.__base.exec import CommandExecutionError, PipelineExecutionError
from core.virtual_environment.__base.helpers import execute, pipeline
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
from services.sandbox.sandbox_provider_service import SandboxProviderService

from .cli_strategy import CliZipStrategy
from .entities import SandboxDownloadItem, SandboxFile, SandboxUploadItem
from .node_strategy import NodeZipStrategy
from .python_strategy import PythonZipStrategy
from .strategy import ZipStrategy


class ZipSandbox:
    """A sandbox for archive (zip) operations.

    Usage:
        with ZipSandbox(tenant_id=..., user_id=...) as zs:
            zs.download_items(items)
            archive = zs.zip()
            zs.upload(archive, upload_url)
        # VM automatically released on exit
    """

    _DEFAULT_TIMEOUT_SECONDS = 60 * 5
    _STRATEGIES: list[ZipStrategy] = [CliZipStrategy(), PythonZipStrategy(), NodeZipStrategy()]

    def __init__(
        self,
        *,
        tenant_id: str | None = None,
        user_id: str | None = None,
        app_id: str = "zip-sandbox",
        sandbox_provider_type: str | None = None,
        sandbox_provider_options: dict[str, Any] | None = None,
        _vm: VirtualEnvironment | None = None,
    ) -> None:
        self._tenant_id = tenant_id
        self._user_id = user_id
        self._app_id = app_id
        self._sandbox_provider_type = sandbox_provider_type
        self._sandbox_provider_options = sandbox_provider_options
        self._injected_vm = _vm

        self._sandbox: Sandbox | None = None
        self._sandbox_id: str | None = None
        self._vm: VirtualEnvironment | None = None
        self._strategy: ZipStrategy | None = None

    def __enter__(self) -> ZipSandbox:
        self._start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        self._stop()

    def _start(self) -> None:
        if self._vm is not None:
            raise RuntimeError("ZipSandbox already started")

        if self._injected_vm is not None:
            self._vm = self._injected_vm
            self._sandbox_id = uuid4().hex
            return

        if not self._tenant_id:
            raise ValueError("tenant_id is required")
        if not self._user_id:
            raise ValueError("user_id is required")

        if self._sandbox_provider_type is None or self._sandbox_provider_options is None:
            provider = SandboxProviderService.get_sandbox_provider(self._tenant_id)
            provider_type = provider.provider_type
            provider_options = dict(provider.config)
        else:
            provider_type = self._sandbox_provider_type
            provider_options = dict(self._sandbox_provider_options)

        self._sandbox_id = uuid4().hex

        storage = NoopSandboxStorage()
        try:
            self._sandbox = (
                SandboxBuilder(self._tenant_id, SandboxType(provider_type))
                .options(provider_options)
                .user(self._user_id)
                .app(self._app_id)
                .storage(storage, assets_id="zip-sandbox")
                .build()
            )
            self._sandbox.wait_ready(timeout=60)
            self._vm = self._sandbox.vm
        except Exception:
            if self._sandbox is not None:
                self._sandbox.release()
            self._vm = None
            self._sandbox = None
            self._sandbox_id = None
            raise

    def _stop(self) -> None:
        if self._vm is None:
            return

        if self._sandbox is not None:
            self._sandbox.release()

        self._vm = None
        self._sandbox = None
        self._sandbox_id = None
        self._strategy = None

    @property
    def vm(self) -> VirtualEnvironment:
        if self._vm is None:
            raise RuntimeError("ZipSandbox not started. Use 'with ZipSandbox(...) as zs:'")
        return self._vm

    def _get_strategy(self) -> ZipStrategy:
        if self._strategy is not None:
            return self._strategy

        for strategy in self._STRATEGIES:
            if strategy.is_available(self.vm):
                self._strategy = strategy
                return strategy

        raise RuntimeError("No available zip backend (zip/python/node+adm-zip)")

    # ========== Path utilities ==========

    @staticmethod
    def _normalize_path(path: str | None) -> str:
        raw = (path or ".").strip()
        if raw == "":
            raw = "."

        p = PurePosixPath(raw)
        if p.is_absolute():
            raise ValueError("path must be relative")
        if any(part == ".." for part in p.parts):
            raise ValueError("path must not contain '..'")

        normalized = str(p)
        return "." if normalized in (".", "") else normalized

    @staticmethod
    def _dest_path_for_url(dest_dir: str, url: str) -> str:
        parsed = urlparse(url)
        path = parsed.path or ""
        name = posixpath.basename(path)
        if not name:
            name = "download.bin"
        return posixpath.join(dest_dir, name)

    # ========== File operations ==========

    def write_file(self, path: str, data: bytes) -> None:
        path = self._normalize_path(path)
        if path in ("", "."):
            raise ValueError("path must point to a file")

        try:
            self.vm.upload_file(path, BytesIO(data))
        except Exception as exc:
            raise RuntimeError(f"Failed to write file to sandbox: {exc}") from exc

    def read_file(self, path: str, *, max_bytes: int = 10 * 1024 * 1024) -> bytes:
        path = self._normalize_path(path)
        if max_bytes <= 0:
            raise ValueError("max_bytes must be positive")

        try:
            data = self.vm.download_file(path).getvalue()
        except Exception as exc:
            raise RuntimeError(f"Failed to read file from sandbox: {exc}") from exc

        if len(data) > max_bytes:
            raise ValueError(f"File too large: {len(data)} > {max_bytes}")
        return data

    # ========== Download operations ==========

    def download_items(self, items: list[SandboxDownloadItem], *, dest_dir: str = ".") -> list[str]:
        """Download or write items into the sandbox via a single pipeline.

        Remote items (with *url*) are fetched via ``curl``. Inline items
        (with *content*) are written via ``base64 -d`` heredoc. Both go
        through the same pipeline — no branching at the structural level.
        """
        if not items:
            return []

        dest_dir = self._normalize_path(dest_dir)
        p = pipeline(self.vm)
        p.add(["mkdir", "-p", dest_dir], error_message="Failed to create download directory")

        out_paths: list[str] = []
        for item in items:
            rel = self._normalize_path(item.path)
            if rel in ("", "."):
                raise ValueError("Download item path must point to a file")
            out_path = posixpath.join(dest_dir, rel)
            out_paths.append(out_path)
            out_dir = posixpath.dirname(out_path)
            if out_dir not in ("", "."):
                p.add(["mkdir", "-p", out_dir], error_message="Failed to create download directory")
            p.add(
                self.to_download_command(item, out_path),
                error_message=f"Failed to write {item.path}",
            )

        try:
            p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
        except Exception as exc:
            raise RuntimeError(str(exc)) from exc

        return out_paths

    @staticmethod
    def to_download_command(item: SandboxDownloadItem, out_path: str) -> list[str]:
        """Return the shell command to materialise *item* at *out_path*."""
        if item.content is not None:
            encoded = base64.b64encode(item.content).decode("ascii")
            return ["sh", "-c", f"base64 -d <<'_B64_' > {shlex.quote(out_path)}\n{encoded}\n_B64_"]
        return ["curl", "-fsSL", item.url, "-o", out_path]
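For an inline item, to_download_command produces a self-contained shell snippet that decodes the base64 payload through a heredoc. For the bytes b"hello" written to out.txt, it would look roughly like this (a local sketch; the inputs are illustrative):

from core.zip_sandbox.entities import SandboxDownloadItem
from core.zip_sandbox.zip_sandbox import ZipSandbox

cmd = ZipSandbox.to_download_command(SandboxDownloadItem(path="out.txt", content=b"hello"), "out.txt")
# ['sh', '-c', "base64 -d <<'_B64_' > out.txt\naGVsbG8=\n_B64_"]
print(cmd)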

    def download_archive(self, archive_url: str, *, path: str = "input.tar.gz") -> str:
        path = self._normalize_path(path)

        dir_path = posixpath.dirname(path)
        p = pipeline(self.vm)
        if dir_path not in ("", "."):
            p.add(["mkdir", "-p", dir_path], error_message=f"Failed to create directory {dir_path}")
        p.add(["curl", "-fsSL", archive_url, "-o", path], error_message=f"Failed to download archive to {path}")

        try:
            p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
        except Exception as exc:
            raise RuntimeError(str(exc)) from exc

        return path

    # ========== Upload operations ==========

    def upload(self, file: SandboxFile, target_url: str) -> None:
        """Upload a sandbox file to the given URL."""
        try:
            execute(
                self.vm,
                ["curl", "-fsSL", "-X", "PUT", "-T", file.path, target_url],
                timeout=self._DEFAULT_TIMEOUT_SECONDS,
                error_message="Failed to upload file from sandbox",
            )
        except CommandExecutionError as exc:
            raise RuntimeError(str(exc)) from exc

    def upload_items(self, items: list[SandboxUploadItem], *, src_dir: str = ".") -> None:
        """Upload multiple files from sandbox to target URLs.

        Args:
            items: List of SandboxUploadItem(path, url)
            src_dir: Base directory containing the files
        """
        if not items:
            return

        src_dir = self._normalize_path(src_dir)
        p = pipeline(self.vm)

        for item in items:
            rel = self._normalize_path(item.path)
            src_path = posixpath.join(src_dir, rel) if src_dir not in ("", ".") else rel
            p.add(
                ["curl", "-fsSL", "-X", "PUT", "-T", src_path, item.url],
                error_message=f"Failed to upload {item.path}",
            )

        try:
            p.execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
        except Exception as exc:
            raise RuntimeError(str(exc)) from exc

    # ========== Archive operations ==========

    def zip(self, src: str = ".", *, include_base: bool = True) -> SandboxFile:
        """Create a zip archive and return a handle to it."""
        src = self._normalize_path(src)
        out_path = f"/tmp/{uuid4().hex}.zip"

        cwd = None
        src_for_strategy = src
        if src not in (".", "") and not include_base:
            cwd = src
            src_for_strategy = "."

        try:
            self._get_strategy().zip(
                self.vm,
                src=src_for_strategy,
                out_path=out_path,
                cwd=cwd,
                timeout=self._DEFAULT_TIMEOUT_SECONDS,
            )
        except (PipelineExecutionError, CommandExecutionError) as exc:
            raise RuntimeError(str(exc)) from exc

        return SandboxFile(path=out_path)

    def unzip(self, *, archive_path: str, dest_dir: str = "unpacked") -> str:
        """Extract a zip archive to the destination directory."""
        archive_path = self._normalize_path(archive_path)
        dest_dir = self._normalize_path(dest_dir)

        if not archive_path.lower().endswith(".zip"):
            raise ValueError("archive_path must end with .zip")

        try:
            pipeline(self.vm).add(
                ["mkdir", "-p", dest_dir], error_message="Failed to create destination directory"
            ).execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)

            self._get_strategy().unzip(
                self.vm,
                archive_path=archive_path,
                dest_dir=dest_dir,
                timeout=self._DEFAULT_TIMEOUT_SECONDS,
            )
        except (PipelineExecutionError, CommandExecutionError) as exc:
            raise RuntimeError(str(exc)) from exc

        return dest_dir

    def untar(self, *, archive_path: str, dest_dir: str = "unpacked") -> str:
        """Extract a tar archive to the destination directory."""
        archive_path = self._normalize_path(archive_path)
        dest_dir = self._normalize_path(dest_dir)

        lower = archive_path.lower()
        is_gz = lower.endswith(".tar.gz") or lower.endswith(".tgz")
        extract_flag = "-xzf" if is_gz else "-xf"

        try:
            (
                pipeline(self.vm)
                .add(["mkdir", "-p", dest_dir], error_message="Failed to create destination directory")
                .add(
                    ["sh", "-c", f'tar {extract_flag} "$1" -C "$2" 2>/dev/null; exit $?', "sh", archive_path, dest_dir],
                    error_message="Failed to extract tar archive",
                )
                .execute(timeout=self._DEFAULT_TIMEOUT_SECONDS, raise_on_error=True)
            )
        except PipelineExecutionError as exc:
            raise RuntimeError(str(exc)) from exc

        return dest_dir

    def tar(self, src: str = ".", *, include_base: bool = True, compress: bool = True) -> SandboxFile:
        """Create a tar archive and return a handle to it.

        Args:
            src: Source path to archive (file or directory)
            include_base: If True, include the base directory name in the archive
            compress: If True, create a gzipped tar archive (.tar.gz)

        Returns:
            SandboxFile handle to the created archive
        """
        src = self._normalize_path(src)
        extension = ".tar.gz" if compress else ".tar"
        out_path = f"/tmp/{uuid4().hex}{extension}"

        create_flag = "-czf" if compress else "-cf"

        try:
            if src in (".", ""):
                # Archive current directory contents
                execute(
                    self.vm,
                    ["tar", create_flag, out_path, "-C", ".", "."],
                    timeout=self._DEFAULT_TIMEOUT_SECONDS,
                    error_message="Failed to create tar archive",
                )
            elif include_base:
                # Archive with base directory name included
                parent_dir = posixpath.dirname(src) or "."
                base_name = posixpath.basename(src)
                execute(
                    self.vm,
                    ["tar", create_flag, out_path, "-C", parent_dir, base_name],
                    timeout=self._DEFAULT_TIMEOUT_SECONDS,
                    error_message="Failed to create tar archive",
                )
            else:
                # Archive contents without base directory name
                execute(
                    self.vm,
                    ["tar", create_flag, out_path, "-C", src, "."],
                    timeout=self._DEFAULT_TIMEOUT_SECONDS,
                    error_message="Failed to create tar archive",
                )
        except CommandExecutionError as exc:
            raise RuntimeError(str(exc)) from exc

        return SandboxFile(path=out_path)
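Putting the pieces together, a sketch of typical usage with an injected VM, as a test might do; `vm` is assumed to be an already-provisioned VirtualEnvironment, and the upload target is assumed to be a presigned PUT URL:

from core.zip_sandbox import SandboxDownloadItem, ZipSandbox

with ZipSandbox(_vm=vm) as zs:
    zs.download_items(
        [SandboxDownloadItem(path="report.csv", url="https://example.com/report.csv")],
        dest_dir="work",
    )
    archive = zs.zip("work", include_base=False)  # archive the contents, no "work/" prefix
    zs.upload(archive, "https://example.com/upload")  # presigned PUT URL (assumption)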
api/dify_graph/entities/tool_entities.py (new file, 41 lines)
@ -0,0 +1,41 @@
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field

from graphon.file import File


class ToolResultStatus(StrEnum):
    SUCCESS = "success"
    ERROR = "error"


class ToolCall(BaseModel):
    id: str | None = Field(default=None, description="Unique identifier for this tool call")
    name: str | None = Field(default=None, description="Name of the tool being called")
    arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
    icon: str | dict | None = Field(default=None, description="Icon of the tool")
    icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")


class ToolResult(BaseModel):
    id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
    name: str | None = Field(default=None, description="Name of the tool")
    output: str | None = Field(default=None, description="Tool output text, error or success message")
    files: list[str] = Field(default_factory=list, description="Files produced by the tool")
    status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
    elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
    icon: str | dict[str, Any] | None = Field(default=None, description="Icon of the tool")
    icon_dark: str | dict[str, Any] | None = Field(default=None, description="Dark theme icon of the tool")
    provider: str | None = Field(default=None, description="Tool provider identifier")


class ToolCallResult(BaseModel):
    id: str | None = Field(default=None, description="Identifier for the tool call")
    name: str | None = Field(default=None, description="Name of the tool")
    arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
    output: str | None = Field(default=None, description="Tool output text, error or success message")
    files: list[File] = Field(default_factory=list, description="Files produced by the tool")
    status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
    elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
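A quick construction example; only fields that deviate from the defaults need to be passed (values are illustrative):

result = ToolResult(
    id="call-0001",
    name="web_search",
    output="3 results found",
    elapsed_time=0.42,
)
assert result.status == ToolResultStatus.SUCCESS  # default
assert result.files == []  # default_factory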
api/dify_graph/nodes/agent/agent_node.py (new file, 929 lines)
@ -0,0 +1,929 @@
from __future__ import annotations

import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast

from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.prompt.entities.advanced_prompt_entities import MemoryMode
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
    ToolIdentity,
    ToolInvokeMessage,
    ToolParameter,
    ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.agent.exceptions import (
    AgentInputTypeError,
    AgentInvocationError,
    AgentMessageTransformError,
    AgentNodeError,
    AgentVariableNotFoundError,
    AgentVariableTypeError,
    ToolFileNotFoundError,
)
from graphon.enums import (
    BuiltinNodeTypes,
    NodeType,
    SystemVariableKey,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from graphon.file import File, FileTransferMethod
from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from graphon.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.node_events import (
    AgentLogEvent,
    NodeEventBase,
    NodeRunResult,
    StreamChunkEvent,
    StreamCompletedEvent,
)
from graphon.nodes.base.node import Node
from graphon.nodes.base.variable_template_parser import VariableTemplateParser
from graphon.runtime import VariablePool
from graphon.variables.segments import ArrayFileSegment, StringSegment
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

if TYPE_CHECKING:
    from core.agent.strategy.plugin import PluginAgentStrategy
    from core.plugin.entities.request import InvokeCredentials


class AgentNode(Node[AgentNodeData]):
    """
    Agent Node
    """

    node_type = BuiltinNodeTypes.AGENT

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> Generator[NodeEventBase, None, None]:
        from core.plugin.impl.exc import PluginDaemonClientSideError

        dify_ctx = self.require_dify_context()

        try:
            strategy = get_plugin_agent_strategy(
                tenant_id=dify_ctx.tenant_id,
                agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
                agent_strategy_name=self.node_data.agent_strategy_name,
            )
        except Exception as e:
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs={},
                    error=f"Failed to get agent strategy: {str(e)}",
                ),
            )
            return

        agent_parameters = strategy.get_parameters()

        # get parameters
        parameters = self._generate_agent_parameters(
            agent_parameters=agent_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self.node_data,
            strategy=strategy,
        )
        parameters_for_log = self._generate_agent_parameters(
            agent_parameters=agent_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=self.node_data,
            for_log=True,
            strategy=strategy,
        )
        credentials = self._generate_credentials(parameters=parameters)

        # get conversation id
        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])

        try:
            message_stream = strategy.invoke(
                params=parameters,
                user_id=dify_ctx.user_id,
                app_id=dify_ctx.app_id,
                conversation_id=conversation_id.text if conversation_id else None,
                credentials=credentials,
            )
        except Exception as e:
            error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    error=str(error),
                )
            )
            return

        # Fetch memory for node memory saving
        memory = self._fetch_memory_for_save()

        try:
            yield from self._transform_message(
                messages=message_stream,
                tool_info={
                    "icon": self.agent_strategy_icon,
                    "agent_strategy": self.node_data.agent_strategy_name,
                },
                parameters_for_log=parameters_for_log,
                user_id=dify_ctx.user_id,
                tenant_id=dify_ctx.tenant_id,
                node_type=self.node_type,
                node_id=self._node_id,
                node_execution_id=self.id,
                memory=memory,
            )
        except PluginDaemonClientSideError as e:
            transform_error = AgentMessageTransformError(
                f"Failed to transform agent message: {str(e)}", original_error=e
            )
            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    error=str(transform_error),
                )
            )

    def _generate_agent_parameters(
        self,
        *,
        agent_parameters: Sequence[AgentStrategyParameter],
        variable_pool: VariablePool,
        node_data: AgentNodeData,
        for_log: bool = False,
        strategy: PluginAgentStrategy,
    ) -> dict[str, Any]:
        """
        Generate parameters based on the given tool parameters, variable pool, and node data.

        Args:
            agent_parameters (Sequence[AgentStrategyParameter]): The list of agent parameters.
            variable_pool (VariablePool): The variable pool containing the variables.
            node_data (AgentNodeData): The data associated with the agent node.

        Returns:
            Mapping[str, Any]: A dictionary containing the generated parameters.
        """
        agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}

        result: dict[str, Any] = {}
        for parameter_name in node_data.agent_parameters:
            parameter = agent_parameters_dictionary.get(parameter_name)
            if not parameter:
                result[parameter_name] = None
                continue
            agent_input = node_data.agent_parameters[parameter_name]
            match agent_input.type:
                case "variable":
                    variable = variable_pool.get(agent_input.value)  # type: ignore
                    if variable is None:
                        raise AgentVariableNotFoundError(str(agent_input.value))
                    parameter_value = variable.value
                case "mixed" | "constant":
                    # variable_pool.convert_template expects a string template,
                    # but if passing a dict, convert to JSON string first before rendering
                    try:
                        if not isinstance(agent_input.value, str):
                            parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
                        else:
                            parameter_value = str(agent_input.value)
                    except TypeError:
                        parameter_value = str(agent_input.value)
                    segment_group = variable_pool.convert_template(parameter_value)
                    parameter_value = segment_group.log if for_log else segment_group.text
                    # variable_pool.convert_template returns a string,
                    # so we need to convert it back to a dictionary
                    try:
                        if not isinstance(agent_input.value, str):
                            parameter_value = json.loads(parameter_value)
                    except json.JSONDecodeError:
                        parameter_value = parameter_value
                case _:
                    raise AgentInputTypeError(agent_input.type)
            value = parameter_value
            if parameter.type == "array[tools]":
                value = cast(list[dict[str, Any]], value)
                value = [tool for tool in value if tool.get("enabled", False)]
                value = self._filter_mcp_type_tool(strategy, value)
                for tool in value:
                    if "schemas" in tool:
                        tool.pop("schemas")
                    parameters = tool.get("parameters", {})
                    if all(isinstance(v, dict) for _, v in parameters.items()):
                        params = {}
                        for key, param in parameters.items():
                            if param.get("auto", ParamsAutoGenerated.OPEN) in (
                                ParamsAutoGenerated.CLOSE,
                                0,
                            ):
                                value_param = param.get("value", {})
                                if value_param and value_param.get("type", "") == "variable":
                                    variable_selector = value_param.get("value")
                                    if not variable_selector:
                                        raise ValueError("Variable selector is missing for a variable-type parameter.")

                                    variable = variable_pool.get(variable_selector)
                                    if variable is None:
                                        raise AgentVariableNotFoundError(str(variable_selector))

                                    params[key] = variable.value
                                else:
                                    params[key] = value_param.get("value", "") if value_param is not None else None
                            else:
                                params[key] = None
                        parameters = params
                    tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
                    tool["parameters"] = parameters

            if not for_log:
                if parameter.type == "array[tools]":
                    value = cast(list[dict[str, Any]], value)
                    tool_value = []
                    for tool in value:
                        provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
                        setting_params = tool.get("settings", {})
                        parameters = tool.get("parameters", {})
                        manual_input_params = [key for key, value in parameters.items() if value is not None]

                        parameters = {**parameters, **setting_params}
                        entity = AgentToolEntity(
                            provider_id=tool.get("provider_name", ""),
                            provider_type=provider_type,
                            tool_name=tool.get("tool_name", ""),
                            tool_parameters=parameters,
                            plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
                            credential_id=tool.get("credential_id", None),
                        )

                        extra = tool.get("extra", {})

                        # This is an issue that caused problems before.
                        # Logically, we shouldn't use the node_data.version field for judgment,
                        # but for backward compatibility with historical data
                        # this version field judgment is still preserved here.
                        runtime_variable_pool: VariablePool | None = None
                        if node_data.version != "1" or node_data.tool_node_version is not None:
                            runtime_variable_pool = variable_pool
                        dify_ctx = self.require_dify_context()
                        tool_runtime = ToolManager.get_agent_tool_runtime(
                            dify_ctx.tenant_id,
                            dify_ctx.app_id,
                            entity,
                            dify_ctx.invoke_from,
                            runtime_variable_pool,
                        )
                        if tool_runtime.entity.description:
                            tool_runtime.entity.description.llm = (
                                extra.get("description", "") or tool_runtime.entity.description.llm
                            )
                        for tool_runtime_params in tool_runtime.entity.parameters:
                            tool_runtime_params.form = (
                                ToolParameter.ToolParameterForm.FORM
                                if tool_runtime_params.name in manual_input_params
                                else tool_runtime_params.form
                            )
                        manual_input_value = {}
                        if tool_runtime.entity.parameters:
                            manual_input_value = {
                                key: value for key, value in parameters.items() if key in manual_input_params
                            }
                        runtime_parameters = {
                            **tool_runtime.runtime.runtime_parameters,
                            **manual_input_value,
                        }
                        tool_value.append(
                            {
                                **tool_runtime.entity.model_dump(mode="json"),
                                "runtime_parameters": runtime_parameters,
                                "credential_id": tool.get("credential_id", None),
                                "provider_type": provider_type.value,
                            }
                        )
                    value = tool_value
                if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
                    value = cast(dict[str, Any], value)
                    model_instance, model_schema = self._fetch_model(value)
                    # memory config
                    history_prompt_messages = []
                    if node_data.memory:
                        memory = self._fetch_memory(model_instance)
                        if memory:
                            prompt_messages = memory.get_history_prompt_messages(
                                message_limit=node_data.memory.window.size or None
                            )
                            history_prompt_messages = [
                                prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
                            ]
                    value["history_prompt_messages"] = history_prompt_messages
                    if model_schema:
                        # remove structured output feature to support old version agent plugin
                        model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
                        value["entity"] = model_schema.model_dump(mode="json")
                    else:
                        value["entity"] = None
            result[parameter_name] = value

        return result

    def _generate_credentials(
        self,
        parameters: dict[str, Any],
    ) -> InvokeCredentials:
        """
        Generate credentials based on the given agent parameters.
        """
        from core.plugin.entities.request import InvokeCredentials

        credentials = InvokeCredentials()

        # generate credentials for tools selector
        credentials.tool_credentials = {}
        for tool in parameters.get("tools", []):
            if tool.get("credential_id"):
                try:
                    identity = ToolIdentity.model_validate(tool.get("identity", {}))
                    credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
                except ValidationError:
                    continue
        return credentials
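The mapping built by _generate_credentials is keyed by tool provider, with uncredentialed or non-validating tools skipped. An illustrative shape (the ToolIdentity field set is simplified here; the real model likely requires more fields than shown):

# Hypothetical input, not taken from the diff:
tools = [
    {"identity": {"provider": "google_search"}, "credential_id": "cred-123"},
    {"identity": {"provider": "time"}},  # no credential_id -> skipped entirely
]
# _generate_credentials({"tools": tools}) would yield, assuming the identities validate:
#   credentials.tool_credentials == {"google_search": "cred-123"}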

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: AgentNodeData,
    ) -> Mapping[str, Sequence[str]]:
        typed_node_data = node_data

        result: dict[str, Any] = {}
        for parameter_name in typed_node_data.agent_parameters:
            input = typed_node_data.agent_parameters[parameter_name]
            match input.type:
                case "mixed" | "constant":
                    selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
                    for selector in selectors:
                        result[selector.variable] = selector.value_selector
                case "variable":
                    result[parameter_name] = input.value

        result = {node_id + "." + key: value for key, value in result.items()}

        return result

    @property
    def agent_strategy_icon(self) -> str | None:
        """
        Get agent strategy icon
        :return:
        """
        from core.plugin.impl.plugin import PluginInstaller

        manager = PluginInstaller()
        dify_ctx = self.require_dify_context()
        plugins = manager.list_plugins(dify_ctx.tenant_id)
        try:
            current_plugin = next(
                plugin
                for plugin in plugins
                if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
            )
            icon = current_plugin.declaration.icon
        except StopIteration:
            icon = None
        return icon

    def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | None:
        """
        Fetch memory based on configuration mode.

        Returns TokenBufferMemory for conversation mode (default),
        or NodeTokenBufferMemory for node mode (Chatflow only).
        """
        node_data = self.node_data
        memory_config = node_data.memory

        if not memory_config:
            return None

        # get conversation id (required for both modes in Chatflow)
        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
            ["sys", SystemVariableKey.CONVERSATION_ID]
        )
        if not isinstance(conversation_id_variable, StringSegment):
            return None
        conversation_id = conversation_id_variable.value

        dify_ctx = self.require_dify_context()
        if memory_config.mode == MemoryMode.NODE:
            return NodeTokenBufferMemory(
                app_id=dify_ctx.app_id,
                conversation_id=conversation_id,
                node_id=self._node_id,
                tenant_id=dify_ctx.tenant_id,
                model_instance=model_instance,
            )
        else:
            with Session(db.engine, expire_on_commit=False) as session:
                stmt = select(Conversation).where(
                    Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
                )
                conversation = session.scalar(stmt)
                if not conversation:
                    return None
                return TokenBufferMemory(conversation=conversation, model_instance=model_instance)

    def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
        dify_ctx = self.require_dify_context()
        provider_manager = ProviderManager()
        provider_model_bundle = provider_manager.get_provider_model_bundle(
            tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
        )
        model_name = value.get("model", "")
        model_credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=ModelType.LLM, model=model_name
        )
        provider_name = provider_model_bundle.configuration.provider.provider
        model_type_instance = provider_model_bundle.model_type_instance
        model_instance = ModelManager().get_model_instance(
            tenant_id=dify_ctx.tenant_id,
            provider=provider_name,
            model_type=ModelType(value.get("model_type", "")),
            model=model_name,
        )
        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
        return model_instance, model_schema

    def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
        if model_schema.features:
            for feature in model_schema.features[:]:  # Iterate over a copy to safely modify the list
                try:
                    AgentOldVersionModelFeatures(feature.value)  # Try to create enum member from value
                except ValueError:
                    model_schema.features.remove(feature)
        return model_schema

    def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """
        Filter out MCP-type tools for old plugin versions.
        :param strategy: plugin agent strategy
        :param tools: tool dicts
        :return: filtered tool dicts
        """
        meta_version = strategy.meta_version
        if meta_version and Version(meta_version) > Version("0.0.1"):
            return tools
        else:
            return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]

    def _fetch_memory_for_save(self) -> BaseMemory | None:
        """
        Fetch memory instance for saving node memory.
        This is a simplified version that doesn't require model_instance.
        """
        from core.model_manager import ModelManager
        from graphon.model_runtime.entities.model_entities import ModelType

        node_data = self.node_data
        if not node_data.memory:
            return None

        # Get conversation_id
        conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
        if not isinstance(conversation_id_var, StringSegment):
            return None
        conversation_id = conversation_id_var.value

        # Return appropriate memory type based on mode
        if node_data.memory.mode == MemoryMode.NODE:
            # For node memory, we need a model_instance for token counting.
            # Use a simple default model for this purpose.
            try:
                model_instance = ModelManager().get_default_model_instance(
                    tenant_id=self.tenant_id,
                    model_type=ModelType.LLM,
                )
            except Exception:
                return None

            return NodeTokenBufferMemory(
                app_id=self.app_id,
                conversation_id=conversation_id,
                node_id=self._node_id,
                tenant_id=self.tenant_id,
                model_instance=model_instance,
            )
        else:
            # Conversation-level memory doesn't need saving here
            return None

    def _build_context(
        self,
        parameters_for_log: dict[str, Any],
        user_query: str,
        assistant_response: str,
        agent_logs: list[AgentLogEvent],
    ) -> list[PromptMessage]:
        """
        Build context from user query, tool calls, and assistant response.
        Format: user -> assistant(with tool_calls) -> tool -> assistant

        The context includes:
        - Current user query (always present, may be empty)
        - Assistant message with tool_calls (if tools were called)
        - Tool results
        - Assistant's final response
        """
        context_messages: list[PromptMessage] = []

        # Always add user query (even if empty, to maintain conversation structure)
        context_messages.append(UserPromptMessage(content=user_query or ""))

        # Extract actual tool calls from agent logs.
        # Only include logs whose label starts with "CALL " - these are real tool invocations.
        tool_calls: list[AssistantPromptMessage.ToolCall] = []
        tool_results: list[tuple[str, str, str]] = []  # (tool_call_id, tool_name, result)

        for log in agent_logs:
            if log.status == "success" and log.label and log.label.startswith("CALL "):
                # Extract tool name from label (format: "CALL tool_name")
                tool_name = log.label[5:]  # Remove "CALL " prefix
                tool_call_id = log.message_id

                # Parse tool response from data
                data = log.data or {}
                tool_response = ""

                # Try to extract the actual tool response
                if "tool_response" in data:
                    tool_response = data["tool_response"]
                elif "output" in data:
                    tool_response = data["output"]
                elif "result" in data:
                    tool_response = data["result"]

                if isinstance(tool_response, dict):
                    tool_response = str(tool_response)

                # Get tool input for arguments
                tool_input = data.get("tool_call_input", {}) or data.get("input", {})
                if isinstance(tool_input, dict):
                    tool_input_str = json.dumps(tool_input, ensure_ascii=False)
                else:
                    tool_input_str = str(tool_input) if tool_input else ""

                if tool_response:
                    tool_calls.append(
                        AssistantPromptMessage.ToolCall(
                            id=tool_call_id,
                            type="function",
                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                name=tool_name,
                                arguments=tool_input_str,
                            ),
                        )
                    )
                    tool_results.append((tool_call_id, tool_name, str(tool_response)))

        # Add assistant message with tool_calls if there were tool calls
        if tool_calls:
            context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))

        # Add tool result messages
        for tool_call_id, tool_name, result in tool_results:
            context_messages.append(
                ToolPromptMessage(
                    content=result,
                    tool_call_id=tool_call_id,
                    name=tool_name,
                )
            )

        # Add final assistant response
        context_messages.append(AssistantPromptMessage(content=assistant_response))

        return context_messages
|
||||
    def _transform_message(
        self,
        messages: Generator[ToolInvokeMessage, None, None],
        tool_info: Mapping[str, Any],
        parameters_for_log: dict[str, Any],
        user_id: str,
        tenant_id: str,
        node_type: NodeType,
        node_id: str,
        node_execution_id: str,
        memory: BaseMemory | None = None,
    ) -> Generator[NodeEventBase, None, None]:
        """
        Convert ToolInvokeMessages into a stream of node events
        (text/variable chunks, agent logs, and a final completed event).
        """
        # Transform messages and handle file storage
        from core.plugin.impl.plugin import PluginInstaller

        message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=messages,
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=None,
        )

        text = ""
        files: list[File] = []
        json_list: list[dict | list] = []

        agent_logs: list[AgentLogEvent] = []
        agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
        llm_usage = LLMUsage.empty_usage()
        variables: dict[str, Any] = {}

        for message in message_stream:
            if message.type in {
                ToolInvokeMessage.MessageType.IMAGE_LINK,
                ToolInvokeMessage.MessageType.BINARY_LINK,
                ToolInvokeMessage.MessageType.IMAGE,
            }:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)

                url = message.message.text
                if message.meta:
                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
                else:
                    transfer_method = FileTransferMethod.TOOL_FILE

                tool_file_id = str(url).split("/")[-1].split(".")[0]

                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
                    if tool_file is None:
                        raise ToolFileNotFoundError(tool_file_id)

                mapping = {
                    "tool_file_id": tool_file_id,
                    "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=tenant_id,
                )
                files.append(file)
            elif message.type == ToolInvokeMessage.MessageType.BLOB:
                # Get the tool file id
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                assert message.meta

                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
                    if tool_file is None:
                        raise ToolFileNotFoundError(tool_file_id)

                mapping = {
                    "tool_file_id": tool_file_id,
                    "transfer_method": FileTransferMethod.TOOL_FILE,
                }

                files.append(
                    file_factory.build_from_mapping(
                        mapping=mapping,
                        tenant_id=tenant_id,
                    )
                )
            elif message.type == ToolInvokeMessage.MessageType.TEXT:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                text += message.message.text
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
                    chunk=message.message.text,
                    is_final=False,
                )
            elif message.type == ToolInvokeMessage.MessageType.JSON:
                assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
                if node_type == BuiltinNodeTypes.AGENT:
                    if isinstance(message.message.json_object, dict):
                        msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
                        llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
                        agent_execution_metadata = {
                            WorkflowNodeExecutionMetadataKey(key): value
                            for key, value in msg_metadata.items()
                            if key in WorkflowNodeExecutionMetadataKey.__members__.values()
                        }
                    else:
                        msg_metadata = {}
                        llm_usage = LLMUsage.empty_usage()
                        agent_execution_metadata = {}
                if message.message.json_object:
                    json_list.append(message.message.json_object)
            elif message.type == ToolInvokeMessage.MessageType.LINK:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                stream_text = f"Link: {message.message.text}\n"
                text += stream_text
                yield StreamChunkEvent(
                    selector=[node_id, "text"],
                    chunk=stream_text,
                    is_final=False,
                )
            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
                variable_name = message.message.variable_name
                variable_value = message.message.variable_value
                if message.message.stream:
                    if not isinstance(variable_value, str):
                        raise AgentVariableTypeError(
                            "When 'stream' is True, 'variable_value' must be a string.",
                            variable_name=variable_name,
                            expected_type="str",
                            actual_type=type(variable_value).__name__,
                        )
                    if variable_name not in variables:
                        variables[variable_name] = ""
                    variables[variable_name] += variable_value

                    yield StreamChunkEvent(
                        selector=[node_id, variable_name],
                        chunk=variable_value,
                        is_final=False,
                    )
                else:
                    variables[variable_name] = variable_value
            elif message.type == ToolInvokeMessage.MessageType.FILE:
                assert message.meta is not None
                assert isinstance(message.meta, dict)
                # Validate that meta contains a 'file' key
                if "file" not in message.meta:
                    raise AgentNodeError("File message is missing 'file' key in meta")

                # Validate that the file is an instance of File
                if not isinstance(message.meta["file"], File):
                    raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
                files.append(message.meta["file"])
            elif message.type == ToolInvokeMessage.MessageType.LOG:
                assert isinstance(message.message, ToolInvokeMessage.LogMessage)
                if message.message.metadata:
                    icon = tool_info.get("icon", "")
                    dict_metadata = dict(message.message.metadata)
                    if dict_metadata.get("provider"):
                        manager = PluginInstaller()
                        plugins = manager.list_plugins(tenant_id)
                        try:
                            current_plugin = next(
                                plugin
                                for plugin in plugins
                                if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
                            )
                            icon = current_plugin.declaration.icon
                        except StopIteration:
                            pass
                        icon_dark = None
                        try:
                            builtin_tool = next(
                                provider
                                for provider in BuiltinToolManageService.list_builtin_tools(
                                    user_id,
                                    tenant_id,
                                )
                                if provider.name == dict_metadata["provider"]
                            )
                            icon = builtin_tool.icon
                            icon_dark = builtin_tool.icon_dark
                        except StopIteration:
                            pass

                        dict_metadata["icon"] = icon
                        dict_metadata["icon_dark"] = icon_dark
                        message.message.metadata = dict_metadata
                agent_log = AgentLogEvent(
                    message_id=message.message.id,
                    node_execution_id=node_execution_id,
                    parent_id=message.message.parent_id,
                    error=message.message.error,
                    status=message.message.status.value,
                    data=message.message.data,
                    label=message.message.label,
                    metadata=message.message.metadata,
                    node_id=node_id,
                )

                # Check whether the agent log is already in the list
                for log in agent_logs:
                    if log.message_id == agent_log.message_id:
                        # Update the existing log in place
                        log.data = agent_log.data
                        log.status = agent_log.status
                        log.error = agent_log.error
                        log.label = agent_log.label
                        log.metadata = agent_log.metadata
                        break
                else:
                    agent_logs.append(agent_log)

                yield agent_log

        # Add agent_logs to outputs['json'] so the frontend can access the thinking process
        json_output: list[dict[str, Any] | list[Any]] = []

        # Step 1: append each agent log as its own dict.
        if agent_logs:
            for log in agent_logs:
                json_output.append(
                    {
                        "id": log.message_id,
                        "parent_id": log.parent_id,
                        "error": log.error,
                        "status": log.status,
                        "data": log.data,
                        "label": log.label,
                        "metadata": log.metadata,
                        "node_id": log.node_id,
                    }
                )
        # Step 2: append the collected JSON messages, or a default {"data": []} placeholder.
        if json_list:
            json_output.extend(json_list)
        else:
            json_output.append({"data": []})

        # Send final chunk events for all streamed outputs.
        # Final chunk for the text stream:
        yield StreamChunkEvent(
            selector=[node_id, "text"],
            chunk="",
            is_final=True,
        )

        # Final chunks for any streamed variables
        for var_name in variables:
            yield StreamChunkEvent(
                selector=[node_id, var_name],
                chunk="",
                is_final=True,
            )

        # Get the user query from the parameters for building context
        user_query = parameters_for_log.get("query", "")

        # Build context from the user query, tool calls, and assistant response
        context = self._build_context(parameters_for_log, user_query, text, agent_logs)

        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={
                    "text": text,
                    "usage": jsonable_encoder(llm_usage),
                    "files": ArrayFileSegment(value=files),
                    "json": json_output,
                    "context": context,
                    **variables,
                },
                metadata={
                    **agent_execution_metadata,
                    WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
                    WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
                },
                inputs=parameters_for_log,
                llm_usage=llm_usage,
            )
        )

@ -90,9 +90,9 @@ def init_app(app: DifyApp):
    app.register_blueprint(inner_api_bp)
    app.register_blueprint(mcp_bp)

    # TODO: enable after full sandbox integration
    # from controllers.cli_api import bp as cli_api_bp
    # app.register_blueprint(cli_api_bp)
    from controllers.cli_api import bp as cli_api_bp

    app.register_blueprint(cli_api_bp)

    # Register trigger blueprint with CORS for webhook calls
    _apply_cors_once(

5
api/extensions/ext_socketio.py
Normal file
@ -0,0 +1,5 @@
import socketio  # type: ignore[reportMissingTypeStubs]

from configs import dify_config

sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)
17
api/fields/online_user_fields.py
Normal file
@ -0,0 +1,17 @@
from flask_restx import fields

online_user_partial_fields = {
    "user_id": fields.String,
    "username": fields.String,
    "avatar": fields.String,
    "sid": fields.String,
}

workflow_online_users_fields = {
    "workflow_id": fields.String,
    "users": fields.List(fields.Nested(online_user_partial_fields)),
}

online_user_list_fields = {
    "data": fields.List(fields.Nested(workflow_online_users_fields)),
}
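# Illustrative payload produced by marshalling with online_user_list_fields
# (hypothetical values, inferred from the field definitions above):
#   {"data": [{"workflow_id": "wf-1",
#              "users": [{"user_id": "u-1", "username": "alice",
#                         "avatar": "https://...", "sid": "sio-abc"}]}]}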
96
api/fields/workflow_comment_fields.py
Normal file
@ -0,0 +1,96 @@
from flask_restx import fields

from libs.helper import AvatarUrlField, TimestampField

# Basic account fields for comments
account_fields = {
    "id": fields.String,
    "name": fields.String,
    "email": fields.String,
    "avatar_url": AvatarUrlField,
}

# Comment mention fields
workflow_comment_mention_fields = {
    "mentioned_user_id": fields.String,
    "mentioned_user_account": fields.Nested(account_fields, allow_null=True),
    "reply_id": fields.String,
}

# Comment reply fields
workflow_comment_reply_fields = {
    "id": fields.String,
    "content": fields.String,
    "created_by": fields.String,
    "created_by_account": fields.Nested(account_fields, allow_null=True),
    "created_at": TimestampField,
}

# Basic comment fields (for list views)
workflow_comment_basic_fields = {
    "id": fields.String,
    "position_x": fields.Float,
    "position_y": fields.Float,
    "content": fields.String,
    "created_by": fields.String,
    "created_by_account": fields.Nested(account_fields, allow_null=True),
    "created_at": TimestampField,
    "updated_at": TimestampField,
    "resolved": fields.Boolean,
    "resolved_at": TimestampField,
    "resolved_by": fields.String,
    "resolved_by_account": fields.Nested(account_fields, allow_null=True),
    "reply_count": fields.Integer,
    "mention_count": fields.Integer,
    "participants": fields.List(fields.Nested(account_fields)),
}

# Detailed comment fields (for single comment view)
workflow_comment_detail_fields = {
    "id": fields.String,
    "position_x": fields.Float,
    "position_y": fields.Float,
    "content": fields.String,
    "created_by": fields.String,
    "created_by_account": fields.Nested(account_fields, allow_null=True),
    "created_at": TimestampField,
    "updated_at": TimestampField,
    "resolved": fields.Boolean,
    "resolved_at": TimestampField,
    "resolved_by": fields.String,
    "resolved_by_account": fields.Nested(account_fields, allow_null=True),
    "replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
    "mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
}

# Comment creation response fields (simplified)
workflow_comment_create_fields = {
    "id": fields.String,
    "created_at": TimestampField,
}

# Comment update response fields (simplified)
workflow_comment_update_fields = {
    "id": fields.String,
    "updated_at": TimestampField,
}

# Comment resolve response fields
workflow_comment_resolve_fields = {
    "id": fields.String,
    "resolved": fields.Boolean,
    "resolved_at": TimestampField,
    "resolved_by": fields.String,
}

# Reply creation response fields (simplified)
workflow_comment_reply_create_fields = {
    "id": fields.String,
    "created_at": TimestampField,
}

# Reply update response fields
workflow_comment_reply_update_fields = {
    "id": fields.String,
    "updated_at": TimestampField,
}
@ -0,0 +1,143 @@
"""Add sandbox providers, app assets, and LLM detail tables.

Revision ID: aab323465866
Revises: c3df22613c99
Create Date: 2026-02-09 10:31:05.062722

"""

import os
from uuid import uuid4

import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = "aab323465866"
down_revision = "c3df22613c99"
branch_labels = None
depends_on = None


def _get_ssh_config_from_env() -> dict[str, str]:
    """Build SSH sandbox config from environment variables.

    Defaults are chosen so that:
    - All-in-one Docker Compose (api inside the network): agentbox:22
    - Middleware / local dev (api on the host): 127.0.0.1:2222

    The env vars (SSH_SANDBOX_*) are documented in api/.env.example.
    """
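    # Example override for local dev outside Docker (assumed values; the
    # variables themselves are documented in api/.env.example):
    #   export SSH_SANDBOX_HOST=127.0.0.1 SSH_SANDBOX_PORT=2222
    #   flask db upgrade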
    return {
        "ssh_host": os.environ.get("SSH_SANDBOX_HOST", "agentbox"),
        "ssh_port": os.environ.get("SSH_SANDBOX_PORT", "22"),
        "ssh_username": os.environ.get("SSH_SANDBOX_USERNAME", "agentbox"),
        "ssh_password": os.environ.get("SSH_SANDBOX_PASSWORD", "agentbox"),
        "base_working_path": os.environ.get("SSH_SANDBOX_BASE_WORKING_PATH", "/workspace/sandboxes"),
    }


def upgrade():
    from core.tools.utils.system_encryption import encrypt_system_params

    op.create_table(
        "sandbox_provider_system_config",
        sa.Column("id", models.types.StringUUID(), nullable=False),
        sa.Column("provider_type", sa.String(length=50), nullable=False, comment="e2b, docker, local, ssh"),
        sa.Column("encrypted_config", models.types.LongText(), nullable=False, comment="Encrypted config JSON"),
        sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.PrimaryKeyConstraint("id", name="sandbox_provider_system_config_pkey"),
        sa.UniqueConstraint("provider_type", name="unique_sandbox_provider_system_config_type"),
    )
    op.create_table(
        "sandbox_providers",
        sa.Column("id", models.types.StringUUID(), nullable=False),
        sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
        sa.Column("provider_type", sa.String(length=50), nullable=False, comment="e2b, docker, local, ssh"),
        sa.Column("configure_type", sa.String(length=20), server_default="user", nullable=False),
        sa.Column("encrypted_config", models.types.LongText(), nullable=False, comment="Encrypted config JSON"),
        sa.Column("is_active", sa.Boolean(), server_default=sa.text("false"), nullable=False),
        sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.PrimaryKeyConstraint("id", name="sandbox_provider_pkey"),
        sa.UniqueConstraint("tenant_id", "provider_type", "configure_type", name="unique_sandbox_provider_tenant_type"),
    )
    with op.batch_alter_table("sandbox_providers", schema=None) as batch_op:
        batch_op.create_index("idx_sandbox_providers_tenant_active", ["tenant_id", "is_active"], unique=False)
        batch_op.create_index("idx_sandbox_providers_tenant_id", ["tenant_id"], unique=False)

    op.create_table(
        "llm_generation_details",
        sa.Column("id", models.types.StringUUID(), nullable=False),
        sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
        sa.Column("app_id", models.types.StringUUID(), nullable=False),
        sa.Column("message_id", models.types.StringUUID(), nullable=True),
        sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True),
        sa.Column("node_id", sa.String(length=255), nullable=True),
        sa.Column("reasoning_content", models.types.LongText(), nullable=True),
        sa.Column("tool_calls", models.types.LongText(), nullable=True),
        sa.Column("sequence", models.types.LongText(), nullable=True),
        sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.CheckConstraint(
            "(message_id IS NOT NULL AND workflow_run_id IS NULL AND node_id IS NULL) OR (message_id IS NULL AND workflow_run_id IS NOT NULL AND node_id IS NOT NULL)",
            name=op.f("llm_generation_details_ck_llm_generation_detail_assoc_mode_check"),
        ),
        sa.PrimaryKeyConstraint("id", name="llm_generation_detail_pkey"),
        sa.UniqueConstraint("message_id", name=op.f("llm_generation_details_message_id_key")),
    )
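    # The check constraint enforces exactly one association mode per row:
    # message-scoped (message_id set, workflow columns NULL) or
    # workflow-node-scoped (workflow_run_id and node_id set, message_id NULL).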
with op.batch_alter_table("llm_generation_details", schema=None) as batch_op:
|
||||
batch_op.create_index("idx_llm_generation_detail_message", ["message_id"], unique=False)
|
||||
batch_op.create_index("idx_llm_generation_detail_workflow", ["workflow_run_id", "node_id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"app_assets",
|
||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("version", sa.String(length=255), nullable=False),
|
||||
sa.Column("asset_tree", models.types.LongText(), nullable=False),
|
||||
sa.Column("created_by", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name="app_assets_pkey"),
|
||||
)
|
||||
with op.batch_alter_table("app_assets", schema=None) as batch_op:
|
||||
batch_op.create_index("app_assets_version_idx", ["tenant_id", "app_id", "version"], unique=False)
|
||||
|
||||
# Only seed a default SSH system provider for self-hosted deployments.
|
||||
# CLOUD editions manage sandbox providers through admin tooling.
|
||||
edition = os.environ.get("EDITION", "SELF_HOSTED")
|
||||
if edition == "SELF_HOSTED":
|
||||
ssh_config = _get_ssh_config_from_env()
|
||||
encrypted_config = encrypt_system_params(ssh_config)
|
||||
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO sandbox_provider_system_config
|
||||
(id, provider_type, encrypted_config, created_at, updated_at)
|
||||
VALUES (:id, :provider_type, :encrypted_config, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT (provider_type) DO NOTHING
|
||||
"""
|
||||
).bindparams(
|
||||
id=str(uuid4()),
|
||||
provider_type="ssh",
|
||||
encrypted_config=encrypted_config,
|
||||
)
|
||||
)
|
||||
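        # Note: ON CONFLICT ... DO NOTHING is PostgreSQL-style upsert syntax; it
        # makes the seed idempotent, so re-running against an existing row
        # leaves the stored config untouched.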


def downgrade():
    op.drop_table("app_assets")
    op.drop_table("llm_generation_details")

    with op.batch_alter_table("sandbox_providers", schema=None) as batch_op:
        batch_op.drop_index("idx_sandbox_providers_tenant_id")
        batch_op.drop_index("idx_sandbox_providers_tenant_active")

    op.drop_table("sandbox_providers")
    op.drop_table("sandbox_provider_system_config")
@ -0,0 +1,109 @@
"""Add workflow comments table

Revision ID: 227822d22895
Revises: aab323465866
Create Date: 2026-02-09 17:26:15.255980

"""

from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "227822d22895"
down_revision = "aab323465866"
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "workflow_comments",
        sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
        sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
        sa.Column("app_id", models.types.StringUUID(), nullable=False),
        sa.Column("position_x", sa.Float(), nullable=False),
        sa.Column("position_y", sa.Float(), nullable=False),
        sa.Column("content", sa.Text(), nullable=False),
        sa.Column("created_by", models.types.StringUUID(), nullable=False),
        sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("resolved", sa.Boolean(), server_default=sa.text("false"), nullable=False),
        sa.Column("resolved_at", sa.DateTime(), nullable=True),
        sa.Column("resolved_by", models.types.StringUUID(), nullable=True),
        sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
    )
    with op.batch_alter_table("workflow_comments", schema=None) as batch_op:
        batch_op.create_index("workflow_comments_app_idx", ["tenant_id", "app_id"], unique=False)
        batch_op.create_index("workflow_comments_created_at_idx", ["created_at"], unique=False)

    op.create_table(
        "workflow_comment_replies",
        sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
        sa.Column("comment_id", models.types.StringUUID(), nullable=False),
        sa.Column("content", sa.Text(), nullable=False),
        sa.Column("created_by", models.types.StringUUID(), nullable=False),
        sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.ForeignKeyConstraint(
            ["comment_id"],
            ["workflow_comments.id"],
            name=op.f("workflow_comment_replies_comment_id_fkey"),
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
    )
    with op.batch_alter_table("workflow_comment_replies", schema=None) as batch_op:
        batch_op.create_index("comment_replies_comment_idx", ["comment_id"], unique=False)
        batch_op.create_index("comment_replies_created_at_idx", ["created_at"], unique=False)

    op.create_table(
        "workflow_comment_mentions",
        sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
        sa.Column("comment_id", models.types.StringUUID(), nullable=False),
        sa.Column("reply_id", models.types.StringUUID(), nullable=True),
        sa.Column("mentioned_user_id", models.types.StringUUID(), nullable=False),
        sa.ForeignKeyConstraint(
            ["comment_id"],
            ["workflow_comments.id"],
            name=op.f("workflow_comment_mentions_comment_id_fkey"),
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["reply_id"],
            ["workflow_comment_replies.id"],
            name=op.f("workflow_comment_mentions_reply_id_fkey"),
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
    )
    with op.batch_alter_table("workflow_comment_mentions", schema=None) as batch_op:
        batch_op.create_index("comment_mentions_comment_idx", ["comment_id"], unique=False)
        batch_op.create_index("comment_mentions_reply_idx", ["reply_id"], unique=False)
        batch_op.create_index("comment_mentions_user_idx", ["mentioned_user_id"], unique=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("workflow_comment_mentions", schema=None) as batch_op:
        batch_op.drop_index("comment_mentions_user_idx")
        batch_op.drop_index("comment_mentions_reply_idx")
        batch_op.drop_index("comment_mentions_comment_idx")

    op.drop_table("workflow_comment_mentions")
    with op.batch_alter_table("workflow_comment_replies", schema=None) as batch_op:
        batch_op.drop_index("comment_replies_created_at_idx")
        batch_op.drop_index("comment_replies_comment_idx")

    op.drop_table("workflow_comment_replies")
    with op.batch_alter_table("workflow_comments", schema=None) as batch_op:
        batch_op.drop_index("workflow_comments_created_at_idx")
        batch_op.drop_index("workflow_comments_app_idx")

    op.drop_table("workflow_comments")
    # ### end Alembic commands ###
@ -0,0 +1,40 @@
"""Add app_asset_contents table for inline content caching.

Revision ID: 5ee0aa981887
Revises: 6b5f9f8b1a2c
Create Date: 2026-03-09 12:00:00.000000

"""

import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = "5ee0aa981887"
down_revision = "6b5f9f8b1a2c"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "app_asset_contents",
        sa.Column("id", models.types.StringUUID(), nullable=False),
        sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
        sa.Column("app_id", models.types.StringUUID(), nullable=False),
        sa.Column("node_id", models.types.StringUUID(), nullable=False),
        sa.Column("content", sa.Text(), nullable=False, server_default=""),
        sa.Column("size", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
        sa.Column("updated_at", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
        sa.PrimaryKeyConstraint("id", name="app_asset_contents_pkey"),
        sa.UniqueConstraint("tenant_id", "app_id", "node_id", name="uq_asset_content_node"),
    )
    op.create_index("idx_asset_content_app", "app_asset_contents", ["tenant_id", "app_id"])


def downgrade() -> None:
    op.drop_index("idx_asset_content_app", table_name="app_asset_contents")
    op.drop_table("app_asset_contents")
89
api/models/app_asset.py
Normal file
@ -0,0 +1,89 @@
from datetime import datetime
from uuid import uuid4

import sqlalchemy as sa
from sqlalchemy import DateTime, Integer, String, func
from sqlalchemy.orm import Mapped, mapped_column

from core.app.entities.app_asset_entities import AppAssetFileTree

from .base import Base
from .types import LongText, StringUUID


class AppAssets(Base):
    __tablename__ = "app_assets"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="app_assets_pkey"),
        sa.Index("app_assets_version_idx", "tenant_id", "app_id", "version"),
    )

    VERSION_DRAFT = "draft"
    VERSION_PUBLISHED = "published"

    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    version: Mapped[str] = mapped_column(String(255), nullable=False)
    _asset_tree: Mapped[str] = mapped_column("asset_tree", LongText, nullable=False, default='{"nodes":[]}')
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by: Mapped[str | None] = mapped_column(StringUUID)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=func.current_timestamp(),
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )

    @property
    def asset_tree(self) -> AppAssetFileTree:
        if not self._asset_tree:
            return AppAssetFileTree()
        return AppAssetFileTree.model_validate_json(self._asset_tree)

    @asset_tree.setter
    def asset_tree(self, value: AppAssetFileTree) -> None:
        self._asset_tree = value.model_dump_json()
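    # Illustrative round-trip (hypothetical tree): assigning serialises the
    # model to the "asset_tree" JSON column; reading parses it back:
    #   assets.asset_tree = AppAssetFileTree(nodes=[...])
    #   assets.asset_tree.nodes  # -> [...]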

    def __repr__(self) -> str:
        return f"<AppAssets(id={self.id}, app_id={self.app_id}, version={self.version})>"


class AppAssetContent(Base):
    """Inline content cache for app asset draft files.

    Acts as a read-through cache for S3: text-like asset content is dual-written
    here on save and read from the DB first, falling back to S3 on a miss with a
    synchronous backfill. Keyed by (tenant_id, app_id, node_id); stores only the
    current draft content, not published snapshots.

    See core/app_assets/content_accessor.py for the accessor abstraction that
    manages the DB/S3 read-through and dual-write logic.
    """
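    # Read path sketch (restating the docstring above, not a separate API):
    #   1. look up content by (tenant_id, app_id, node_id); on a hit, return it
    #   2. on a miss, fetch from S3, backfill this row, then return the content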

    __tablename__ = "app_asset_contents"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="app_asset_contents_pkey"),
        sa.UniqueConstraint("tenant_id", "app_id", "node_id", name="uq_asset_content_node"),
        sa.Index("idx_asset_content_app", "tenant_id", "app_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    node_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    content: Mapped[str] = mapped_column(LongText, nullable=False, default="")
    size: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=func.current_timestamp(),
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )

    def __repr__(self) -> str:
        return f"<AppAssetContent(id={self.id}, node_id={self.node_id})>"
210
api/models/comment.py
Normal file
@ -0,0 +1,210 @@
"""Workflow comment models."""

from datetime import datetime
from typing import Optional

from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship

from .account import Account
from .base import Base
from .engine import db
from .types import StringUUID


class WorkflowComment(Base):
    """Workflow comment model for canvas commenting functionality.

    Comments are associated with apps rather than specific workflow versions,
    since an app has only one draft workflow at a time and comments should
    persist across workflow version changes.

    Attributes:
        id: Comment ID
        tenant_id: Workspace ID
        app_id: App ID (primary association, comments belong to apps)
        position_x: X coordinate on canvas
        position_y: Y coordinate on canvas
        content: Comment content
        created_by: Creator account ID
        created_at: Creation time
        updated_at: Last update time
        resolved: Whether comment is resolved
        resolved_at: Resolution time
        resolved_by: Resolver account ID
    """

    __tablename__ = "workflow_comments"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
        Index("workflow_comments_app_idx", "tenant_id", "app_id"),
        Index("workflow_comments_created_at_idx", "created_at"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    position_x: Mapped[float] = mapped_column(db.Float)
    position_y: Mapped[float] = mapped_column(db.Float)
    content: Mapped[str] = mapped_column(db.Text, nullable=False)
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )
    resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
    resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
    resolved_by: Mapped[str | None] = mapped_column(StringUUID)

    # Relationships
    replies: Mapped[list["WorkflowCommentReply"]] = relationship(
        "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
    )
    mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
        "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
    )

    @property
    def created_by_account(self):
        """Get creator account."""
        if hasattr(self, "_created_by_account_cache"):
            return self._created_by_account_cache
        return db.session.get(Account, self.created_by)

    def cache_created_by_account(self, account: Account | None) -> None:
        """Cache creator account to avoid extra queries."""
        self._created_by_account_cache = account

    @property
    def resolved_by_account(self):
        """Get resolver account."""
        if hasattr(self, "_resolved_by_account_cache"):
            return self._resolved_by_account_cache
        if self.resolved_by:
            return db.session.get(Account, self.resolved_by)
        return None

    def cache_resolved_by_account(self, account: Account | None) -> None:
        """Cache resolver account to avoid extra queries."""
        self._resolved_by_account_cache = account

    @property
    def reply_count(self):
        """Get reply count."""
        return len(self.replies)

    @property
    def mention_count(self):
        """Get mention count."""
        return len(self.mentions)

    @property
    def participants(self):
        """Get all participants (creator + repliers + mentioned users)."""
        participant_ids = set()

        # Add the comment creator
        participant_ids.add(self.created_by)

        # Add reply creators
        participant_ids.update(reply.created_by for reply in self.replies)

        # Add mentioned users
        participant_ids.update(mention.mentioned_user_id for mention in self.mentions)

        # Resolve account objects (one lookup per participant)
        participants = []
        for user_id in participant_ids:
            account = db.session.get(Account, user_id)
            if account:
                participants.append(account)

        return participants


class WorkflowCommentReply(Base):
    """Workflow comment reply model.

    Attributes:
        id: Reply ID
        comment_id: Parent comment ID
        content: Reply content
        created_by: Creator account ID
        created_at: Creation time
    """

    __tablename__ = "workflow_comment_replies"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
        Index("comment_replies_comment_idx", "comment_id"),
        Index("comment_replies_created_at_idx", "created_at"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
    comment_id: Mapped[str] = mapped_column(
        StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
    )
    content: Mapped[str] = mapped_column(db.Text, nullable=False)
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    # Relationships
    comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")

    @property
    def created_by_account(self):
        """Get creator account."""
        if hasattr(self, "_created_by_account_cache"):
            return self._created_by_account_cache
        return db.session.get(Account, self.created_by)

    def cache_created_by_account(self, account: Account | None) -> None:
        """Cache creator account to avoid extra queries."""
        self._created_by_account_cache = account


class WorkflowCommentMention(Base):
    """Workflow comment mention model.

    Mentions are only for internal accounts, since end users
    cannot access the workflow canvas and commenting features.

    Attributes:
        id: Mention ID
        comment_id: Parent comment ID
        mentioned_user_id: Mentioned account ID
    """

    __tablename__ = "workflow_comment_mentions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
        Index("comment_mentions_comment_idx", "comment_id"),
        Index("comment_mentions_reply_idx", "reply_id"),
        Index("comment_mentions_user_idx", "mentioned_user_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
    comment_id: Mapped[str] = mapped_column(
        StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
    )
    reply_id: Mapped[str | None] = mapped_column(
        StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
    )
    mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    # Relationships
    comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
    reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")

    @property
    def mentioned_user_account(self):
        """Get mentioned account."""
        if hasattr(self, "_mentioned_user_account_cache"):
            return self._mentioned_user_account_cache
        return db.session.get(Account, self.mentioned_user_id)

    def cache_mentioned_user_account(self, account: Account | None) -> None:
        """Cache mentioned account to avoid extra queries."""
        self._mentioned_user_account_cache = account
80
api/models/sandbox.py
Normal file
@ -0,0 +1,80 @@
import json
from collections.abc import Mapping
from datetime import datetime
from typing import Any, cast
from uuid import uuid4

import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column

from .base import TypeBase
from .types import LongText, StringUUID


class SandboxProviderSystemConfig(TypeBase):
    """
    System-level sandbox provider configuration.
    Stores the default configuration for each provider type.
    """

    __tablename__ = "sandbox_provider_system_config"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="sandbox_provider_system_config_pkey"),
        sa.UniqueConstraint("provider_type", name="unique_sandbox_provider_system_config_type"),
    )

    id: Mapped[str] = mapped_column(
        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
    )
    provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="e2b, docker, local, ssh")
    encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False, comment="Encrypted config JSON")
    created_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        server_default=func.current_timestamp(),
        server_onupdate=func.current_timestamp(),
        init=False,
    )


class SandboxProvider(TypeBase):
    """
    Tenant-level sandbox provider configuration.
    Each tenant can have one configuration per provider type.
    Only one provider can be active at a time per tenant.
    """

    __tablename__ = "sandbox_providers"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="sandbox_provider_pkey"),
        sa.UniqueConstraint("tenant_id", "provider_type", "configure_type", name="unique_sandbox_provider_tenant_type"),
        sa.Index("idx_sandbox_providers_tenant_id", "tenant_id"),
        sa.Index("idx_sandbox_providers_tenant_active", "tenant_id", "is_active"),
    )

    id: Mapped[str] = mapped_column(
        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
    )
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    provider_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="e2b, docker, local, ssh")
    encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False, comment="Encrypted config JSON")
    configure_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default="user", default="user")
    is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        server_default=func.current_timestamp(),
        server_onupdate=func.current_timestamp(),
        init=False,
    )

    @property
    def config(self) -> Mapping[str, Any]:
        return cast(Mapping[str, Any], json.loads(self.encrypted_config or "{}"))
0
api/models/workflow_comment.py
Normal file
26
api/models/workflow_features.py
Normal file
@ -0,0 +1,26 @@
from collections.abc import Mapping
from dataclasses import dataclass
from enum import StrEnum
from typing import Any


class WorkflowFeatures(StrEnum):
    SANDBOX = "sandbox"
    SPEECH_TO_TEXT = "speech_to_text"
    TEXT_TO_SPEECH = "text_to_speech"
    RETRIEVER_RESOURCE = "retriever_resource"
    SENSITIVE_WORD_AVOIDANCE = "sensitive_word_avoidance"
    FILE_UPLOAD = "file_upload"
    SUGGESTED_QUESTIONS_AFTER_ANSWER = "suggested_questions_after_answer"


@dataclass(frozen=True)
class WorkflowFeature:
    enabled: bool
    config: Mapping[str, Any]

    @classmethod
    def from_dict(cls, data: Mapping[str, Any] | None) -> "WorkflowFeature":
        if data is None or not isinstance(data, dict):
            return cls(enabled=False, config={})
        return cls(enabled=bool(data.get("enabled", False)), config=data)
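# Illustrative use (hypothetical payload keys):
#   WorkflowFeature.from_dict({"enabled": True, "number_limits": 3})
#   -> WorkflowFeature(enabled=True, config={"enabled": True, "number_limits": 3})
#   WorkflowFeature.from_dict(None) -> WorkflowFeature(enabled=False, config={})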
226
api/repositories/workflow_collaboration_repository.py
Normal file
@ -0,0 +1,226 @@
from __future__ import annotations

import json
from typing import TypedDict

from extensions.ext_redis import redis_client

SESSION_STATE_TTL_SECONDS = 3600
WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:"
WORKFLOW_LEADER_PREFIX = "workflow_leader:"
WORKFLOW_SKILL_LEADER_PREFIX = "workflow_skill_leader:"
WS_SID_MAP_PREFIX = "ws_sid_map:"
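# Resulting key layout (illustrative IDs, inferred from the helpers below):
#   workflow_online_users:<workflow_id>        -> hash {sid: session_info_json}
#   workflow_leader:<workflow_id>              -> sid of the current leader
#   workflow_skill_leader:<workflow_id>:<file> -> sid of the skill-file leader
#   ws_sid_map:<sid>                           -> {"workflow_id": ..., "user_id": ...}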


class WorkflowSessionInfo(TypedDict):
    user_id: str
    username: str
    avatar: str | None
    sid: str
    connected_at: int
    graph_active: bool
    active_skill_file_id: str | None


class SidMapping(TypedDict):
    workflow_id: str
    user_id: str


class WorkflowCollaborationRepository:
    def __init__(self) -> None:
        self._redis = redis_client

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(redis_client={self._redis})"

    @staticmethod
    def workflow_key(workflow_id: str) -> str:
        return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}"

    @staticmethod
    def leader_key(workflow_id: str) -> str:
        return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}"

    @staticmethod
    def skill_leader_key(workflow_id: str, file_id: str) -> str:
        return f"{WORKFLOW_SKILL_LEADER_PREFIX}{workflow_id}:{file_id}"

    @staticmethod
    def sid_key(sid: str) -> str:
        return f"{WS_SID_MAP_PREFIX}{sid}"

    @staticmethod
    def _decode(value: str | bytes | None) -> str | None:
        if value is None:
            return None
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value

    def refresh_session_state(self, workflow_id: str, sid: str) -> None:
        workflow_key = self.workflow_key(workflow_id)
        sid_key = self.sid_key(sid)
        if self._redis.exists(workflow_key):
            self._redis.expire(workflow_key, SESSION_STATE_TTL_SECONDS)
        if self._redis.exists(sid_key):
            self._redis.expire(sid_key, SESSION_STATE_TTL_SECONDS)

    def set_session_info(self, workflow_id: str, session_info: WorkflowSessionInfo) -> None:
        workflow_key = self.workflow_key(workflow_id)
        self._redis.hset(workflow_key, session_info["sid"], json.dumps(session_info))
        self._redis.set(
            self.sid_key(session_info["sid"]),
            json.dumps({"workflow_id": workflow_id, "user_id": session_info["user_id"]}),
            ex=SESSION_STATE_TTL_SECONDS,
        )
        self.refresh_session_state(workflow_id, session_info["sid"])

    def get_session_info(self, workflow_id: str, sid: str) -> WorkflowSessionInfo | None:
        raw = self._redis.hget(self.workflow_key(workflow_id), sid)
        value = self._decode(raw)
        if not value:
            return None
        try:
            session_info = json.loads(value)
        except (TypeError, json.JSONDecodeError):
            return None

        if not isinstance(session_info, dict):
            return None
        if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info:
            return None

        return {
            "user_id": str(session_info["user_id"]),
            "username": str(session_info["username"]),
            "avatar": session_info.get("avatar"),
            "sid": str(session_info["sid"]),
            "connected_at": int(session_info.get("connected_at") or 0),
            "graph_active": bool(session_info.get("graph_active")),
            "active_skill_file_id": session_info.get("active_skill_file_id"),
        }

    def set_graph_active(self, workflow_id: str, sid: str, active: bool) -> None:
        session_info = self.get_session_info(workflow_id, sid)
        if not session_info:
            return
        session_info["graph_active"] = bool(active)
        self._redis.hset(self.workflow_key(workflow_id), sid, json.dumps(session_info))
        self.refresh_session_state(workflow_id, sid)

    def is_graph_active(self, workflow_id: str, sid: str) -> bool:
        session_info = self.get_session_info(workflow_id, sid)
        if not session_info:
            return False
        return bool(session_info.get("graph_active") or False)

    def set_active_skill_file(self, workflow_id: str, sid: str, file_id: str | None) -> None:
        session_info = self.get_session_info(workflow_id, sid)
        if not session_info:
            return
        session_info["active_skill_file_id"] = file_id
        self._redis.hset(self.workflow_key(workflow_id), sid, json.dumps(session_info))
        self.refresh_session_state(workflow_id, sid)

    def get_active_skill_file_id(self, workflow_id: str, sid: str) -> str | None:
        session_info = self.get_session_info(workflow_id, sid)
        if not session_info:
            return None
        return session_info.get("active_skill_file_id")

    def get_sid_mapping(self, sid: str) -> SidMapping | None:
        raw = self._redis.get(self.sid_key(sid))
        if not raw:
            return None
        value = self._decode(raw)
        if not value:
            return None
        try:
            return json.loads(value)
        except (TypeError, json.JSONDecodeError):
            return None

    def delete_session(self, workflow_id: str, sid: str) -> None:
        self._redis.hdel(self.workflow_key(workflow_id), sid)
        self._redis.delete(self.sid_key(sid))

    def session_exists(self, workflow_id: str, sid: str) -> bool:
        return bool(self._redis.hexists(self.workflow_key(workflow_id), sid))

    def sid_mapping_exists(self, sid: str) -> bool:
        return bool(self._redis.exists(self.sid_key(sid)))

    def get_session_sids(self, workflow_id: str) -> list[str]:
        raw_sids = self._redis.hkeys(self.workflow_key(workflow_id))
        decoded_sids: list[str] = []
        for sid in raw_sids:
            decoded = self._decode(sid)
            if decoded:
                decoded_sids.append(decoded)
        return decoded_sids

    def list_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]:
        sessions_json = self._redis.hgetall(self.workflow_key(workflow_id))
        users: list[WorkflowSessionInfo] = []

        for session_info_json in sessions_json.values():
            value = self._decode(session_info_json)
            if not value:
                continue
            try:
                session_info = json.loads(value)
            except (TypeError, json.JSONDecodeError):
                continue

            if not isinstance(session_info, dict):
                continue
            if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info:
                continue

            users.append(
                {
                    "user_id": str(session_info["user_id"]),
                    "username": str(session_info["username"]),
                    "avatar": session_info.get("avatar"),
                    "sid": str(session_info["sid"]),
                    "connected_at": int(session_info.get("connected_at") or 0),
                    "graph_active": bool(session_info.get("graph_active")),
                    "active_skill_file_id": session_info.get("active_skill_file_id"),
                }
            )

        return users

    def get_current_leader(self, workflow_id: str) -> str | None:
        raw = self._redis.get(self.leader_key(workflow_id))
        return self._decode(raw)

    def get_skill_leader(self, workflow_id: str, file_id: str) -> str | None:
        raw = self._redis.get(self.skill_leader_key(workflow_id, file_id))
        return self._decode(raw)

    def set_leader_if_absent(self, workflow_id: str, sid: str) -> bool:
        return bool(self._redis.set(self.leader_key(workflow_id), sid, nx=True, ex=SESSION_STATE_TTL_SECONDS))
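    # Leader election sketch: Redis SET NX succeeds for exactly one contender,
    # so the first session to call this becomes the leader; the TTL lets
    # leadership lapse if the leader disconnects without cleanup (it is
    # refreshed via expire_leader below).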
|
||||
def set_leader(self, workflow_id: str, sid: str) -> None:
|
||||
self._redis.set(self.leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def set_skill_leader(self, workflow_id: str, file_id: str, sid: str) -> None:
|
||||
self._redis.set(self.skill_leader_key(workflow_id, file_id), sid, ex=SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def delete_leader(self, workflow_id: str) -> None:
|
||||
self._redis.delete(self.leader_key(workflow_id))
|
||||
|
||||
def delete_skill_leader(self, workflow_id: str, file_id: str) -> None:
|
||||
self._redis.delete(self.skill_leader_key(workflow_id, file_id))
|
||||
|
||||
def expire_leader(self, workflow_id: str) -> None:
|
||||
self._redis.expire(self.leader_key(workflow_id), SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def expire_skill_leader(self, workflow_id: str, file_id: str) -> None:
|
||||
self._redis.expire(self.skill_leader_key(workflow_id, file_id), SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def get_active_skill_session_sids(self, workflow_id: str, file_id: str) -> list[str]:
|
||||
sessions = self.list_sessions(workflow_id)
|
||||
return [session["sid"] for session in sessions if session.get("active_skill_file_id") == file_id]
|
||||
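The leader-election methods above lean on Redis's atomic SET NX EX: only the first session to claim the key becomes leader, and the claim lapses after SESSION_STATE_TTL_SECONDS unless refreshed. A minimal standalone sketch of the pattern (the demo key name and TTL are illustrative, not the repository's real ones):

import redis

r = redis.Redis()

def try_become_leader(workflow_id: str, sid: str, ttl: int = 60) -> bool:
    # SET key value NX EX ttl is atomic: it succeeds only when the key
    # does not exist yet, so exactly one contender wins the claim.
    return bool(r.set(f"demo:leader:{workflow_id}", sid, nx=True, ex=ttl))

assert try_become_leader("wf-1", "sid-a") is True   # first session wins
assert try_become_leader("wf-1", "sid-b") is False  # claim already held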
api/services/app_asset_package_service.py (new file)
@@ -0,0 +1,195 @@
"""Service for packaging and publishing app assets.

This service handles operations that require core.zip_sandbox,
separated from AppAssetService to avoid circular imports.

Dependency flow:
    core/* -> AppAssetPackageService -> AppAssetService
    (core modules can import this service without circular dependency)

Inline content optimisation:
    ``AssetItem`` objects returned by the build pipeline may carry an
    in-process *content* field (e.g. resolved ``.md`` skill documents).
    ``AppAssetService.to_download_items()`` converts these into unified
    ``SandboxDownloadItem`` instances, and ``ZipSandbox.download_items()``
    handles both inline and remote items natively.
"""

import logging
from uuid import uuid4

from sqlalchemy.orm import Session

from core.app.entities.app_asset_entities import AppAssetFileTree
from core.app_assets.builder import AssetBuildPipeline, BuildContext
from core.app_assets.builder.file_builder import FileBuilder
from core.app_assets.builder.skill_builder import SkillBuilder
from core.app_assets.entities.assets import AssetItem
from core.app_assets.storage import AssetPaths
from core.zip_sandbox import ZipSandbox
from models.app_asset import AppAssets
from models.model import App

logger = logging.getLogger(__name__)


class AppAssetPackageService:
    """Service for packaging and publishing app assets.

    This service is designed to be imported by core/* modules without
    causing circular imports. It depends on AppAssetService for basic
    asset operations but provides the packaging/publishing functionality
    that requires core.zip_sandbox.
    """

    @staticmethod
    def get_tenant_app_assets(tenant_id: str, assets_id: str) -> AppAssets:
        """Get app assets by tenant_id and assets_id.

        This is a read-only operation that doesn't require AppAssetService.
        """
        from extensions.ext_database import db

        with Session(db.engine, expire_on_commit=False) as session:
            app_assets = (
                session.query(AppAssets)
                .filter(
                    AppAssets.tenant_id == tenant_id,
                    AppAssets.id == assets_id,
                )
                .first()
            )
            if not app_assets:
                raise ValueError(f"App assets not found for tenant_id={tenant_id}, assets_id={assets_id}")

            return app_assets

    @staticmethod
    def get_draft_asset_items(tenant_id: str, app_id: str, file_tree: AppAssetFileTree) -> list[AssetItem]:
        """Convert file tree to asset items for packaging."""
        files = file_tree.walk_files()
        return [
            AssetItem(
                asset_id=f.id,
                path=file_tree.get_path(f.id),
                file_name=f.name,
                extension=f.extension,
                storage_key=AssetPaths.draft(tenant_id, app_id, f.id),
            )
            for f in files
        ]

    @staticmethod
    def package_and_upload(
        *,
        assets: list[AssetItem],
        upload_url: str,
        tenant_id: str,
        app_id: str,
        user_id: str,
        storage_key: str = "",
    ) -> None:
        """Package assets into a ZIP and upload directly to the given URL.

        Uses ``AppAssetService.to_download_items()`` to convert assets
        into unified download items, then ``ZipSandbox.download_items()``
        handles both inline content and remote presigned URLs natively.

        When *assets* is empty an empty ZIP is written directly to storage
        using *storage_key*, bypassing the HTTP ticket URL.
        """
        from services.app_asset_service import AppAssetService

        if not assets:
            import io
            import zipfile

            buf = io.BytesIO()
            with zipfile.ZipFile(buf, "w"):
                pass
            buf.seek(0)

            # Write directly to storage instead of going through the HTTP
            # ticket URL. The ticket URL (FILES_API_URL) is designed for
            # sandbox containers (agentbox) and is not routable from the api
            # container in standard Docker Compose deployments.
            if storage_key:
                from extensions.ext_storage import storage

                storage.save(storage_key, buf.getvalue())
            else:
                import requests

                requests.put(upload_url, data=buf.getvalue(), timeout=30)
            return

        download_items = AppAssetService.to_download_items(assets)

        with ZipSandbox(tenant_id=tenant_id, user_id=user_id, app_id="asset-packager") as zs:
            zs.download_items(download_items)
            archive = zs.zip()
            zs.upload(archive, upload_url)

    @staticmethod
    def publish(session: Session, app_model: App, account_id: str, workflow_id: str) -> AppAssets:
        """Publish app assets for a workflow.

        Creates a versioned copy of draft assets and packages them for
        runtime use. The build ZIP contains resolved ``.md`` content
        (inline from ``SkillBuilder``) and raw draft content for all
        other files. A separate source ZIP snapshots the raw drafts for
        later export.
        """
        from services.app_asset_service import AppAssetService

        tenant_id = app_model.tenant_id
        app_id = app_model.id

        assets = AppAssetService.get_or_create_assets(session, app_model, account_id)
        tree = assets.asset_tree

        publish_id = str(uuid4())

        published = AppAssets(
            id=publish_id,
            tenant_id=tenant_id,
            app_id=app_id,
            version=workflow_id,
            created_by=account_id,
        )
        published.asset_tree = tree
        session.add(published)
        session.flush()

        asset_storage = AppAssetService.get_storage()
        accessor = AppAssetService.get_accessor(tenant_id, app_id)
        build_pipeline = AssetBuildPipeline([SkillBuilder(accessor=accessor), FileBuilder()])
        ctx = BuildContext(tenant_id=tenant_id, app_id=app_id, build_id=publish_id)
        built_assets = build_pipeline.build_all(tree, ctx)

        # Runtime ZIP: resolved .md (inline) + raw draft (remote).
        runtime_zip_key = AssetPaths.build_zip(tenant_id, app_id, publish_id)
        runtime_upload_url = asset_storage.get_upload_url(runtime_zip_key)
        AppAssetPackageService.package_and_upload(
            assets=built_assets,
            upload_url=runtime_upload_url,
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=account_id,
            storage_key=runtime_zip_key,
        )

        # Source ZIP: all raw draft content (for export/restore).
        source_items = AppAssetService.get_draft_assets(tenant_id, app_id)
        source_key = AssetPaths.source_zip(tenant_id, app_id, workflow_id)
        source_upload_url = asset_storage.get_upload_url(source_key)
        AppAssetPackageService.package_and_upload(
            assets=source_items,
            upload_url=source_upload_url,
            tenant_id=tenant_id,
            app_id=app_id,
            user_id=account_id,
            storage_key=source_key,
        )

        return published
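A quick standalone check of the empty-ZIP fallback used by `package_and_upload`: a `zipfile.ZipFile` opened for writing and closed with no entries still emits a well-formed archive (just the 22-byte end-of-central-directory record), so downstream consumers can unzip it without special-casing:

import io
import zipfile

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w"):
    pass  # no entries written
buf.seek(0)

assert zipfile.is_zipfile(buf)
assert len(buf.getvalue()) == 22  # bare end-of-central-directory record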
api/services/app_runtime_upgrade_service.py (new file)
@@ -0,0 +1,443 @@
"""Service for upgrading Classic runtime apps to Sandboxed runtime via clone-and-convert.

The upgrade flow:
1. Clone the source app via DSL export/import
2. On the cloned app's draft workflow, convert Agent nodes to LLM nodes
3. Rewrite variable references for all LLM nodes (old output names → new generation-based names)
4. Enable sandbox feature flag

The original app is never modified; the user gets a new sandboxed copy.
"""

import json
import logging
import re
import uuid
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from models import App, Workflow
from models.workflow_features import WorkflowFeatures
from services.app_dsl_service import AppDslService, ImportMode

logger = logging.getLogger(__name__)

_VAR_REWRITES: dict[str, list[str]] = {
    "text": ["generation", "content"],
    "reasoning_content": ["generation", "reasoning_content"],
}

_PASSTHROUGH_KEYS = (
    "version",
    "error_strategy",
    "default_value",
    "retry_config",
    "parent_node_id",
    "isInLoop",
    "loop_id",
    "isInIteration",
    "iteration_id",
)


class AppRuntimeUpgradeService:
    """Upgrades a Classic-runtime app to Sandboxed runtime by cloning and converting.

    Holds an active SQLAlchemy session; the caller is responsible for commit/rollback.
    """

    session: Session

    def __init__(self, session: Session) -> None:
        self.session = session

    def upgrade(self, app_model: App, account: Any) -> dict[str, Any]:
        """Clone *app_model* and upgrade the clone to sandboxed runtime.

        Returns:
            dict with keys: result, new_app_id, converted_agents, skipped_agents.
        """
        workflow = self._get_draft_workflow(app_model)
        if not workflow:
            return {"result": "no_draft"}

        if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
            return {"result": "already_sandboxed"}

        new_app = self._clone_app(app_model, account)
        new_workflow = self._get_draft_workflow(new_app)
        if not new_workflow:
            return {"result": "no_draft"}

        graph = json.loads(new_workflow.graph) if new_workflow.graph else {}
        nodes = graph.get("nodes", [])

        converted, skipped = _convert_agent_nodes(nodes)
        _enable_computer_use_for_existing_llm_nodes(nodes)

        llm_node_ids = {n["id"] for n in nodes if n.get("data", {}).get("type") == "llm"}
        _rewrite_variable_references(nodes, llm_node_ids)

        new_workflow.graph = json.dumps(graph)

        features = json.loads(new_workflow.features) if new_workflow.features else {}
        features.setdefault("sandbox", {})["enabled"] = True
        new_workflow.features = json.dumps(features)

        return {
            "result": "success",
            "new_app_id": str(new_app.id),
            "converted_agents": converted,
            "skipped_agents": skipped,
        }

    def _get_draft_workflow(self, app_model: App) -> Workflow | None:
        stmt = select(Workflow).where(
            Workflow.tenant_id == app_model.tenant_id,
            Workflow.app_id == app_model.id,
            Workflow.version == "draft",
        )
        return self.session.scalar(stmt)

    def _clone_app(self, app_model: App, account: Any) -> App:
        dsl_service = AppDslService(self.session)
        yaml_content = dsl_service.export_dsl(app_model=app_model, include_secret=True)
        result = dsl_service.import_app(
            account=account,
            import_mode=ImportMode.YAML_CONTENT,
            yaml_content=yaml_content,
            name=f"{app_model.name} (Sandboxed)",
        )
        stmt = select(App).where(App.id == result.app_id)
        new_app = self.session.scalar(stmt)
        if not new_app:
            raise RuntimeError(f"Cloned app not found: {result.app_id}")
        return new_app


# ---------------------------------------------------------------------------
# Pure conversion functions (no DB access)
# ---------------------------------------------------------------------------


def _convert_agent_nodes(nodes: list[dict[str, Any]]) -> tuple[int, int]:
    """Convert Agent nodes to LLM nodes in-place. Returns (converted_count, skipped_count)."""
    converted = 0

    for node in nodes:
        data = node.get("data", {})
        if data.get("type") != "agent":
            continue

        node_id = node.get("id", "?")
        node["data"] = _agent_data_to_llm_data(data)
        logger.info("Converted agent node %s to LLM", node_id)
        converted += 1

    return converted, 0


def _agent_data_to_llm_data(agent_data: dict[str, Any]) -> dict[str, Any]:
    """Map an Agent node's data dict to an LLM node's data dict.

    Always returns a valid LLM data dict. If the agent has no model selected,
    produces an empty LLM node with agent mode (computer_use) enabled.
    """
    params = agent_data.get("agent_parameters") or {}

    model_param = params.get("model", {}) if isinstance(params, dict) else {}
    model_value = model_param.get("value") if isinstance(model_param, dict) else None

    if isinstance(model_value, dict) and model_value.get("provider") and model_value.get("model"):
        model_config = {
            "provider": model_value["provider"],
            "name": model_value["model"],
            "mode": model_value.get("mode", "chat"),
            "completion_params": model_value.get("completion_params", {}),
        }
    else:
        model_config = {"provider": "", "name": "", "mode": "chat", "completion_params": {}}

    tools_param = params.get("tools", {})
    tools_value = tools_param.get("value", []) if isinstance(tools_param, dict) else []
    tools_meta, tool_settings = _convert_tools(tools_value if isinstance(tools_value, list) else [])

    instruction_param = params.get("instruction", {})
    instruction = instruction_param.get("value", "") if isinstance(instruction_param, dict) else ""

    query_param = params.get("query", {})
    query_value = query_param.get("value", "") if isinstance(query_param, dict) else ""

    has_tools = bool(tools_meta)
    prompt_template = _build_prompt_template(
        instruction,
        query_value,
        skill=has_tools,
        tools=tools_value if has_tools else None,
    )

    max_iter_param = params.get("maximum_iterations", {})
    max_iterations = max_iter_param.get("value", 100) if isinstance(max_iter_param, dict) else 100

    context_config = _extract_context(params)
    vision_config = _extract_vision(params)

    llm_data: dict[str, Any] = {
        "type": "llm",
        "title": agent_data.get("title", "LLM"),
        "desc": agent_data.get("desc", ""),
        "model": model_config,
        "prompt_template": prompt_template,
        "prompt_config": {"jinja2_variables": []},
        "memory": agent_data.get("memory"),
        "context": context_config,
        "vision": vision_config,
        "computer_use": True,
        "structured_output_switch_on": False,
        "reasoning_format": "separated",
        "tools": tools_meta,
        "tool_settings": tool_settings,
        "max_iterations": max_iterations,
    }

    for key in _PASSTHROUGH_KEYS:
        if key in agent_data:
            llm_data[key] = agent_data[key]

    return llm_data


def _extract_context(params: dict[str, Any]) -> dict[str, Any]:
    """Extract context config from agent_parameters for LLM node format.

    Agent stores context as a variable selector in agent_parameters.context.value,
    e.g. ["knowledge_retrieval_node_id", "result"]. Maps to LLM ContextConfig.
    """
    if not isinstance(params, dict):
        return {"enabled": False}

    ctx_param = params.get("context", {})
    ctx_value = ctx_param.get("value") if isinstance(ctx_param, dict) else None

    if isinstance(ctx_value, list) and len(ctx_value) >= 2 and all(isinstance(s, str) for s in ctx_value):
        return {"enabled": True, "variable_selector": ctx_value}

    return {"enabled": False}


def _extract_vision(params: dict[str, Any]) -> dict[str, Any]:
    """Extract vision config from agent_parameters for LLM node format."""
    if not isinstance(params, dict):
        return {"enabled": False}

    vision_param = params.get("vision", {})
    vision_value = vision_param.get("value") if isinstance(vision_param, dict) else None

    if isinstance(vision_value, dict) and vision_value.get("enabled"):
        return vision_value

    if isinstance(vision_value, bool) and vision_value:
        return {"enabled": True}

    return {"enabled": False}


def _enable_computer_use_for_existing_llm_nodes(nodes: list[dict[str, Any]]) -> None:
    """Enable computer_use for existing LLM nodes that have tools configured.

    After upgrade, the sandbox runtime requires computer_use=true for tool calling.
    Existing LLM nodes from classic mode may have tools but computer_use=false.
    """
    for node in nodes:
        data = node.get("data", {})
        if data.get("type") != "llm":
            continue

        tools = data.get("tools", [])
        if tools and not data.get("computer_use"):
            data["computer_use"] = True
            logger.info("Enabled computer_use for LLM node %s with %d tools", node.get("id", "?"), len(tools))


def _convert_tools(
    tools_input: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Convert agent tool dicts to (ToolMetadata[], ToolSetting[]).

    Agent tools in graph JSON already use provider_name/settings/parameters —
    the same field names as LLM ToolMetadata. We pass them through with defaults
    for any missing fields.
    """
    tools_meta: list[dict[str, Any]] = []
    tool_settings: list[dict[str, Any]] = []

    for ts in tools_input:
        if not isinstance(ts, dict):
            continue

        provider_name = ts.get("provider_name", "")
        tool_name = ts.get("tool_name", "")
        tool_type = ts.get("type", "builtin")

        tools_meta.append(
            {
                "enabled": True,
                "type": tool_type,
                "provider_name": provider_name,
                "tool_name": tool_name,
                "plugin_unique_identifier": ts.get("plugin_unique_identifier"),
                "credential_id": ts.get("credential_id"),
                "parameters": ts.get("parameters", {}),
                "settings": ts.get("settings", {}) or ts.get("tool_configuration", {}),
                "extra": ts.get("extra", {}),
            }
        )

        tool_settings.append(
            {
                "type": tool_type,
                "provider": provider_name,
                "tool_name": tool_name,
                "enabled": True,
            }
        )

    return tools_meta, tool_settings


def _build_prompt_template(
    instruction: Any,
    query: Any,
    *,
    skill: bool = False,
    tools: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
    """Build LLM prompt_template from Agent instruction and query values.

    When *skill* is True each message gets ``"skill": True`` so the sandbox
    engine treats the prompt as a skill document.

    When *tools* is provided, tool reference placeholders
    (``§[tool].[provider].[name].[uuid]§``) are appended to the system
    message and the corresponding ``ToolReference`` entries are placed in the
    message's ``metadata.tools`` dict so the skill assembler can resolve them.
    Tools from the same provider are grouped into a single token list.
    """
    messages: list[dict[str, Any]] = []

    system_text = instruction if isinstance(instruction, str) else (str(instruction) if instruction else "")
    metadata: dict[str, Any] | None = None

    if tools:
        tool_refs: dict[str, dict[str, Any]] = {}
        provider_groups: dict[str, list[str]] = {}
        for ts in tools:
            if not isinstance(ts, dict):
                continue
            tool_uuid = str(uuid.uuid4())
            provider_id = ts.get("provider_name", "")
            tool_name = ts.get("tool_name", "")
            tool_type = ts.get("type", "builtin")

            token = f"§[tool].[{provider_id}].[{tool_name}].[{tool_uuid}]§"
            provider_groups.setdefault(provider_id, []).append(token)
            tool_refs[tool_uuid] = {
                "type": tool_type,
                "configuration": {"fields": []},
                "enabled": True,
                **({"credential_id": ts.get("credential_id")} if ts.get("credential_id") else {}),
            }

        if provider_groups:
            group_texts: list[str] = []
            for tokens in provider_groups.values():
                if len(tokens) == 1:
                    group_texts.append(tokens[0])
                else:
                    group_texts.append("[" + ",".join(tokens) + "]")
            all_tools_text = " ".join(group_texts)
            system_text = f"{system_text}\n\n{all_tools_text}" if system_text else all_tools_text
            metadata = {"tools": tool_refs, "files": []}

    if system_text:
        msg: dict[str, Any] = {"role": "system", "text": system_text, "skill": skill}
        if metadata:
            msg["metadata"] = metadata
        messages.append(msg)

    if isinstance(query, list) and len(query) >= 2:
        template_ref = "{{#" + ".".join(str(s) for s in query) + "#}}"
        messages.append({"role": "user", "text": template_ref, "skill": skill})
    elif query:
        messages.append({"role": "user", "text": str(query), "skill": skill})

    if not messages:
        messages.append({"role": "user", "text": "", "skill": skill})

    return messages


def _rewrite_variable_references(nodes: list[dict[str, Any]], llm_ids: set[str]) -> None:
    """Recursively walk all node data and rewrite variable references for LLM nodes.

    Handles two forms:
    - Structured selectors: [node_id, "text"] → [node_id, "generation", "content"]
    - Template strings: {{#node_id.text#}} → {{#node_id.generation.content#}}
    """
    if not llm_ids:
        return

    escaped_ids = [re.escape(nid) for nid in llm_ids]
    patterns: list[tuple[re.Pattern[str], str]] = []
    for old_name, new_path in _VAR_REWRITES.items():
        pattern = re.compile(r"\{\{#(" + "|".join(escaped_ids) + r")\." + re.escape(old_name) + r"#\}\}")
        replacement = r"{{#\1." + ".".join(new_path) + r"#}}"
        patterns.append((pattern, replacement))

    for node in nodes:
        data = node.get("data", {})
        _walk_and_rewrite(data, llm_ids, patterns)


def _walk_and_rewrite(
    obj: Any,
    llm_ids: set[str],
    template_patterns: list[tuple[re.Pattern[str], str]],
) -> Any:
    """Recursively rewrite variable references in a nested data structure."""
    if isinstance(obj, dict):
        for key, value in obj.items():
            obj[key] = _walk_and_rewrite(value, llm_ids, template_patterns)
        return obj

    if isinstance(obj, list):
        if _is_variable_selector(obj, llm_ids):
            return _rewrite_selector(obj)
        for i, item in enumerate(obj):
            obj[i] = _walk_and_rewrite(item, llm_ids, template_patterns)
        return obj

    if isinstance(obj, str):
        for pattern, replacement in template_patterns:
            obj = pattern.sub(replacement, obj)
        return obj

    return obj


def _is_variable_selector(lst: list, llm_ids: set[str]) -> bool:
    """Check if a list is a structured variable selector pointing to an LLM node output."""
    if len(lst) < 2:
        return False
    if not all(isinstance(s, str) for s in lst):
        return False
    return lst[0] in llm_ids and lst[1] in _VAR_REWRITES


def _rewrite_selector(selector: list[str]) -> list[str]:
    """Rewrite [node_id, "text"] → [node_id, "generation", "content"]."""
    old_field = selector[1]
    new_path = _VAR_REWRITES[old_field]
    return [selector[0]] + new_path + selector[2:]
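To make the rewrite rules concrete, here is a standalone reproduction of the template-string branch of `_rewrite_variable_references` (the node id is hypothetical):

import re

llm_ids = {"llm_node_1"}
escaped = "|".join(re.escape(nid) for nid in llm_ids)

# Same shape as the pattern built above for the "text" rewrite rule.
pattern = re.compile(r"\{\{#(" + escaped + r")\.text#\}\}")
before = "Answer was: {{#llm_node_1.text#}}"
after = pattern.sub(r"{{#\1.generation.content#}}", before)

assert after == "Answer was: {{#llm_node_1.generation.content#}}"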
api/services/asset_content_service.py (new file)
@@ -0,0 +1,103 @@
"""Service for the app_asset_contents table.

Provides single-node and batch DB operations for the inline content cache.
All methods are static and open their own short-lived sessions.

Collaborators:
- models.app_asset.AppAssetContent (SQLAlchemy model)
- core.app_assets.accessor (accessor abstraction that calls this service)
"""

import logging

from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.app_asset import AppAssetContent

logger = logging.getLogger(__name__)


class AssetContentService:
    """DB operations for the inline asset content cache.

    All methods are static. All queries are scoped by tenant_id + app_id.
    """

    @staticmethod
    def get(tenant_id: str, app_id: str, node_id: str) -> str | None:
        """Get cached content for a single node. Returns None on miss."""
        with Session(db.engine) as session:
            return session.execute(
                select(AppAssetContent.content).where(
                    AppAssetContent.tenant_id == tenant_id,
                    AppAssetContent.app_id == app_id,
                    AppAssetContent.node_id == node_id,
                )
            ).scalar_one_or_none()

    @staticmethod
    def get_many(tenant_id: str, app_id: str, node_ids: list[str]) -> dict[str, str]:
        """Batch get. Returns {node_id: content} for hits only."""
        if not node_ids:
            return {}
        with Session(db.engine) as session:
            rows = session.execute(
                select(AppAssetContent.node_id, AppAssetContent.content).where(
                    AppAssetContent.tenant_id == tenant_id,
                    AppAssetContent.app_id == app_id,
                    AppAssetContent.node_id.in_(node_ids),
                )
            ).all()
            return {row.node_id: row.content for row in rows}

    @staticmethod
    def upsert(tenant_id: str, app_id: str, node_id: str, content: str, size: int) -> None:
        """Insert or update inline content for a single node."""
        with Session(db.engine) as session:
            stmt = pg_insert(AppAssetContent).values(
                tenant_id=tenant_id,
                app_id=app_id,
                node_id=node_id,
                content=content,
                size=size,
            )
            stmt = stmt.on_conflict_do_update(
                constraint="uq_asset_content_node",
                set_={
                    "content": stmt.excluded.content,
                    "size": stmt.excluded.size,
                },
            )
            session.execute(stmt)
            session.commit()

    @staticmethod
    def delete(tenant_id: str, app_id: str, node_id: str) -> None:
        """Delete cached content for a single node."""
        with Session(db.engine) as session:
            session.execute(
                delete(AppAssetContent).where(
                    AppAssetContent.tenant_id == tenant_id,
                    AppAssetContent.app_id == app_id,
                    AppAssetContent.node_id == node_id,
                )
            )
            session.commit()

    @staticmethod
    def delete_many(tenant_id: str, app_id: str, node_ids: list[str]) -> None:
        """Delete cached content for multiple nodes."""
        if not node_ids:
            return
        with Session(db.engine) as session:
            session.execute(
                delete(AppAssetContent).where(
                    AppAssetContent.tenant_id == tenant_id,
                    AppAssetContent.app_id == app_id,
                    AppAssetContent.node_id.in_(node_ids),
                )
            )
            session.commit()
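The `upsert` above relies on PostgreSQL's INSERT ... ON CONFLICT ON CONSTRAINT ... DO UPDATE. A standalone sketch that prints the generated SQL (the stand-in table definition is illustrative; the real model is models.app_asset.AppAssetContent):

from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert as pg_insert

t = Table(
    "app_asset_contents", MetaData(),
    Column("tenant_id", String), Column("app_id", String),
    Column("node_id", String), Column("content", String), Column("size", Integer),
)

stmt = pg_insert(t).values(tenant_id="t1", app_id="a1", node_id="n1", content="hi", size=2)
stmt = stmt.on_conflict_do_update(
    constraint="uq_asset_content_node",
    set_={"content": stmt.excluded.content, "size": stmt.excluded.size},
)
# Emits roughly: INSERT INTO app_asset_contents (...) VALUES (...)
#   ON CONFLICT ON CONSTRAINT uq_asset_content_node
#   DO UPDATE SET content = excluded.content, size = excluded.size
print(stmt.compile(dialect=postgresql.dialect()))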
api/services/errors/app_asset.py (new file)
@@ -0,0 +1,17 @@
from .base import BaseServiceError


class AppAssetNodeNotFoundError(BaseServiceError):
    pass


class AppAssetParentNotFoundError(BaseServiceError):
    pass


class AppAssetPathConflictError(BaseServiceError):
    pass


class AppAssetNodeTooLargeError(BaseServiceError):
    pass
api/services/llm_generation_service.py (new file)
@@ -0,0 +1,37 @@
"""
LLM Generation Detail Service.

Provides methods to query and attach generation details to workflow node executions
and messages, avoiding N+1 query problems.
"""

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.entities.llm_generation_entities import LLMGenerationDetailData
from models import LLMGenerationDetail


class LLMGenerationService:
    """Service for handling LLM generation details."""

    def __init__(self, session: Session):
        self._session = session

    def get_generation_detail_for_message(self, message_id: str) -> LLMGenerationDetailData | None:
        """Query generation detail for a specific message."""
        stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id == message_id)
        detail = self._session.scalars(stmt).first()
        return detail.to_domain_model() if detail else None

    def get_generation_details_for_messages(
        self,
        message_ids: list[str],
    ) -> dict[str, LLMGenerationDetailData]:
        """Batch query generation details for multiple messages."""
        if not message_ids:
            return {}

        stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id.in_(message_ids))
        details = self._session.scalars(stmt).all()
        return {detail.message_id: detail.to_domain_model() for detail in details if detail.message_id}
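Typical usage is one batch round-trip rather than one query per message, roughly like this (a sketch, assuming an open session and ORM message rows in scope):

service = LLMGenerationService(session)
details = service.get_generation_details_for_messages([m.id for m in messages])
for m in messages:
    detail = details.get(m.id)  # dict lookup; no per-message query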
api/services/skill_service.py (new file)
@@ -0,0 +1,204 @@
"""Service for extracting tool dependencies from LLM node skill prompts.

Two public entry points:

- ``extract_tool_dependencies`` — takes raw node data from the client,
  real-time builds a ``SkillBundle`` from current draft ``.md`` assets,
  and resolves transitive tool dependencies. Used by the per-node POST
  endpoint.
- ``get_workflow_skills`` — scans all LLM nodes in a persisted draft
  workflow and returns per-node skill info. Uses a cached bundle.
"""

from __future__ import annotations

import json
import logging
from collections.abc import Mapping
from functools import reduce
from typing import Any, cast

from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
from core.sandbox.entities.config import AppAssets
from core.skill.assembler import SkillBundleAssembler, SkillDocumentAssembler
from core.skill.entities.skill_bundle import SkillBundle
from core.skill.entities.skill_document import SkillDocument
from core.skill.entities.skill_metadata import SkillMetadata
from core.skill.entities.tool_dependencies import ToolDependencies, ToolDependency
from core.skill.skill_manager import SkillManager
from graphon.enums import BuiltinNodeTypes
from models.model import App
from services.app_asset_service import AppAssetService

logger = logging.getLogger(__name__)


class SkillService:
    """Service for managing and retrieving skill information from workflows."""

    # ------------------------------------------------------------------
    # Per-node: client sends node data, server builds bundle in real-time
    # ------------------------------------------------------------------

    @staticmethod
    def extract_tool_dependencies(
        app: App,
        node_data: Mapping[str, Any],
        user_id: str,
    ) -> list[ToolDependency]:
        """Extract tool dependencies from an LLM node's skill prompts.

        Builds a fresh ``SkillBundle`` from current draft ``.md`` assets
        every time — no cached bundle is used. The caller supplies the
        full node ``data`` dict directly (not a ``node_id``).

        Returns an empty list when the node has no skill prompts or when
        no draft assets exist.
        """
        if node_data.get("type", "") != BuiltinNodeTypes.LLM:
            return []

        if not SkillService._has_skill(node_data):
            return []

        bundle = SkillService._build_bundle(app, user_id)
        if bundle is None:
            return []

        return SkillService._resolve_prompt_dependencies(node_data, bundle)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _has_skill(node_data: Mapping[str, Any]) -> bool:
        """Check if node has any skill prompts."""
        prompt_template_raw = node_data.get("prompt_template", [])
        if isinstance(prompt_template_raw, list):
            for prompt_item in cast(list[object], prompt_template_raw):
                if isinstance(prompt_item, dict) and prompt_item.get("skill", False):
                    return True
        return False

    @staticmethod
    def _build_bundle(app: App, user_id: str) -> SkillBundle | None:
        """Real-time build a SkillBundle from current draft .md assets.

        Reads all ``.md`` nodes from the draft file tree, bulk-loads
        their content from the DB cache, parses into ``SkillDocument``
        objects, and assembles a full bundle with transitive dependency
        resolution.

        The bundle is **not** persisted — it is built fresh for each
        request so the response always reflects the latest draft state.
        """
        assets = AppAssetService.get_assets(
            tenant_id=app.tenant_id,
            app_id=app.id,
            user_id=user_id,
            is_draft=True,
        )
        if not assets:
            return None

        file_tree: AppAssetFileTree = assets.asset_tree
        if file_tree.empty():
            return SkillBundle(assets_id=assets.id, asset_tree=file_tree)

        # Collect all .md file nodes from the tree.
        md_nodes: list[AppAssetNode] = [n for n in file_tree.walk_files() if n.extension == "md"]
        if not md_nodes:
            return SkillBundle(assets_id=assets.id, asset_tree=file_tree)

        # Bulk-load content from DB (with S3 fallback).
        accessor = AppAssetService.get_accessor(app.tenant_id, app.id)
        raw_contents = accessor.bulk_load(md_nodes)

        # Parse into SkillDocuments.
        documents: dict[str, SkillDocument] = {}
        for node in md_nodes:
            raw = raw_contents.get(node.id)
            if not raw:
                continue
            try:
                data = {"skill_id": node.id, **json.loads(raw)}
                documents[node.id] = SkillDocument.model_validate(data)
            except (json.JSONDecodeError, TypeError, ValueError):
                logger.warning("Skipping unparseable skill document node_id=%s", node.id)
                continue

        return SkillBundleAssembler(file_tree).assemble_bundle(documents, assets.id)

    @staticmethod
    def _resolve_prompt_dependencies(
        node_data: Mapping[str, Any],
        bundle: SkillBundle,
    ) -> list[ToolDependency]:
        """Resolve tool dependencies from skill prompts against a bundle."""
        assembler = SkillDocumentAssembler(bundle)
        tool_deps_list: list[ToolDependencies] = []

        prompt_template_raw = node_data.get("prompt_template", [])
        if not isinstance(prompt_template_raw, list):
            return []

        for prompt_item in cast(list[object], prompt_template_raw):
            if not isinstance(prompt_item, dict):
                continue
            prompt = cast(dict[str, Any], prompt_item)
            if not prompt.get("skill", False):
                continue

            text_raw = prompt.get("text", "")
            text = text_raw if isinstance(text_raw, str) else str(text_raw)

            metadata_obj: object = prompt.get("metadata")
            metadata = cast(dict[str, Any], metadata_obj) if isinstance(metadata_obj, dict) else {}

            skill_entry = assembler.assemble_document(
                document=SkillDocument(
                    skill_id="anonymous",
                    content=text,
                    metadata=SkillMetadata.model_validate(metadata),
                ),
                base_path=AppAssets.PATH,
            )
            tool_deps_list.append(skill_entry.dependance.tools)

        if not tool_deps_list:
            return []

        merged = reduce(lambda x, y: x.merge(y), tool_deps_list)
        return merged.dependencies

    @staticmethod
    def _extract_tool_dependencies_cached(
        app: App,
        node_data: Mapping[str, Any],
        user_id: str,
    ) -> list[ToolDependency]:
        """Extract tool dependencies using a cached SkillBundle.

        Used by ``get_workflow_skills`` for the whole-workflow endpoint.
        """
        assets = AppAssetService.get_assets(
            tenant_id=app.tenant_id,
            app_id=app.id,
            user_id=user_id,
            is_draft=True,
        )
        if not assets:
            return []

        try:
            bundle = SkillManager.load_bundle(
                tenant_id=app.tenant_id,
                app_id=app.id,
                assets_id=assets.id,
            )
        except Exception:
            logger.debug("Failed to load cached skill bundle for app_id=%s", app.id, exc_info=True)
            return []

        return SkillService._resolve_prompt_dependencies(node_data, bundle)
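For reference, the node-data shape that `_has_skill` and `_resolve_prompt_dependencies` consume looks roughly like this (field names follow the code above; the token and ids are illustrative):

node_data = {
    "type": "llm",
    "prompt_template": [
        {
            "role": "system",
            "skill": True,  # marks this message as a skill document
            "text": "Search the web: §[tool].[search_provider].[web_search].[uuid-1]§",
            "metadata": {"tools": {"uuid-1": {"type": "builtin", "enabled": True}}, "files": []},
        },
        {"role": "user", "text": "{{#start.query#}}", "skill": False},
    ],
}
# SkillService.extract_tool_dependencies(app, node_data, user_id) builds a fresh
# draft bundle and resolves the §[tool]...§ token into ToolDependency entries.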
api/services/storage_ticket_service.py (new file)
@@ -0,0 +1,153 @@
"""Storage ticket service for generating opaque download/upload URLs.

This service provides a ticket-based approach for file access. Instead of exposing
the real storage key in URLs, it generates a random UUID token and stores the mapping
in Redis with a TTL.

Usage:
    from services.storage_ticket_service import StorageTicketService

    # Generate a download ticket
    url = StorageTicketService.create_download_url("path/to/file.txt", expires_in=300)

    # Generate an upload ticket
    url = StorageTicketService.create_upload_url("path/to/file.txt", expires_in=300, max_bytes=10*1024*1024)

URL format:
    {FILES_API_URL}/files/storage-files/{token}

The token is validated by looking up the Redis key, which contains:
- op: "download" or "upload"
- storage_key: the real storage path
- max_bytes: (upload only) maximum allowed upload size
- filename: suggested filename for Content-Disposition header
"""

import logging
from typing import Literal
from uuid import uuid4

from pydantic import BaseModel

from configs import dify_config
from extensions.ext_redis import redis_client

logger = logging.getLogger(__name__)

TICKET_KEY_PREFIX = "storage_files"
DEFAULT_DOWNLOAD_TTL = 300  # 5 minutes
DEFAULT_UPLOAD_TTL = 300  # 5 minutes
DEFAULT_MAX_UPLOAD_BYTES = 100 * 1024 * 1024  # 100MB


class StorageTicket(BaseModel):
    """Represents a storage access ticket."""

    op: Literal["download", "upload"]
    storage_key: str
    max_bytes: int | None = None  # upload only
    filename: str | None = None  # suggested filename for download


class StorageTicketService:
    """Service for creating and validating storage access tickets."""

    @classmethod
    def create_download_url(
        cls,
        storage_key: str,
        *,
        expires_in: int = DEFAULT_DOWNLOAD_TTL,
        filename: str | None = None,
    ) -> str:
        """Create a download ticket and return the URL.

        Args:
            storage_key: The real storage path
            expires_in: TTL in seconds (default 300)
            filename: Suggested filename for Content-Disposition header

        Returns:
            Full URL with token
        """
        if filename is None:
            filename = storage_key.rsplit("/", 1)[-1]

        ticket = StorageTicket(op="download", storage_key=storage_key, filename=filename)
        token = cls._store_ticket(ticket, expires_in)
        return cls._build_url(token)

    @classmethod
    def create_upload_url(
        cls,
        storage_key: str,
        *,
        expires_in: int = DEFAULT_UPLOAD_TTL,
        max_bytes: int = DEFAULT_MAX_UPLOAD_BYTES,
    ) -> str:
        """Create an upload ticket and return the URL.

        Args:
            storage_key: The real storage path
            expires_in: TTL in seconds (default 300)
            max_bytes: Maximum allowed upload size in bytes

        Returns:
            Full URL with token
        """
        ticket = StorageTicket(op="upload", storage_key=storage_key, max_bytes=max_bytes)
        token = cls._store_ticket(ticket, expires_in)
        return cls._build_url(token)

    @classmethod
    def get_ticket(cls, token: str) -> StorageTicket | None:
        """Retrieve a ticket by token.

        Args:
            token: The UUID token from the URL

        Returns:
            StorageTicket if found and valid, None otherwise
        """
        key = cls._ticket_key(token)
        try:
            data = redis_client.get(key)
            if data is None:
                return None
            if isinstance(data, bytes):
                data = data.decode("utf-8")
            return StorageTicket.model_validate_json(data)
        except Exception:
            logger.warning("Failed to retrieve storage ticket: %s", token, exc_info=True)
            return None

    @classmethod
    def _store_ticket(cls, ticket: StorageTicket, ttl: int) -> str:
        """Store a ticket in Redis and return the token."""
        token = str(uuid4())
        key = cls._ticket_key(token)
        value = ticket.model_dump_json()
        redis_client.setex(key, ttl, value)
        return token

    @classmethod
    def _ticket_key(cls, token: str) -> str:
        """Generate Redis key for a token."""
        return f"{TICKET_KEY_PREFIX}:{token}"

    @classmethod
    def _build_url(cls, token: str) -> str:
        """Build the full URL for a token.

        FILES_API_URL is dedicated to sandbox runtime file access (agentbox/e2b/etc.).
        This endpoint must be routable from the runtime environment.
        """
        base_url = dify_config.FILES_API_URL.strip()
        if not base_url:
            raise ValueError(
                "FILES_API_URL is required for sandbox runtime file access. "
                "Set FILES_API_URL to a URL reachable by your sandbox runtime. "
                "For public sandbox environments (e.g. e2b), use a public domain or IP."
            )
        base_url = base_url.rstrip("/")
        return f"{base_url}/files/storage-files/{token}"
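On the consuming side, a handler would resolve the token back to a ticket before touching storage. A minimal sketch of that lookup (hypothetical Flask route; the real files API handles the actual streaming and upload size checks):

from flask import Flask, abort

app = Flask(__name__)

@app.get("/files/storage-files/<token>")
def serve_storage_file(token: str):
    ticket = StorageTicketService.get_ticket(token)
    if ticket is None or ticket.op != "download":
        abort(404)  # unknown, expired, or wrong-direction ticket
    # ...stream the object behind ticket.storage_key, setting
    # Content-Disposition from ticket.filename...
    return f"would stream {ticket.storage_key}"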
api/services/workflow/nested_node_graph_service.py (new file)
@@ -0,0 +1,157 @@
"""
Service for generating Nested Node LLM graph structures.

This service creates graph structures containing LLM nodes configured for
extracting values from list[PromptMessage] variables.
"""

from typing import Any

from sqlalchemy.orm import Session

from graphon.enums import BuiltinNodeTypes
from graphon.model_runtime.entities import LLMMode
from services.model_provider_service import ModelProviderService
from services.workflow.entities import NestedNodeGraphRequest, NestedNodeGraphResponse, NestedNodeParameterSchema


class NestedNodeGraphService:
    """Service for generating Nested Node LLM graph structures."""

    def __init__(self, session: Session):
        self._session = session

    def generate_nested_node_id(self, node_id: str, parameter_name: str) -> str:
        """Generate nested node ID following the naming convention.

        Format: {node_id}_ext_{parameter_name}
        """
        return f"{node_id}_ext_{parameter_name}"

    def generate_nested_node_graph(self, tenant_id: str, request: NestedNodeGraphRequest) -> NestedNodeGraphResponse:
        """Generate a complete graph structure containing a Nested Node LLM node.

        Args:
            tenant_id: The tenant ID for fetching default model config
            request: The nested node graph generation request

        Returns:
            Complete graph structure with nodes, edges, and viewport
        """
        node_id = self.generate_nested_node_id(request.parent_node_id, request.parameter_key)
        model_config = self._get_default_model_config(tenant_id)
        node = self._build_nested_node_llm_node(
            node_id=node_id,
            parent_node_id=request.parent_node_id,
            context_source=request.context_source,
            parameter_schema=request.parameter_schema,
            model_config=model_config,
        )

        graph = {
            "nodes": [node],
            "edges": [],
            "viewport": {},
        }

        return NestedNodeGraphResponse(graph=graph)

    def _get_default_model_config(self, tenant_id: str) -> dict[str, Any]:
        """Get the default LLM model configuration for the tenant."""
        model_provider_service = ModelProviderService()
        default_model = model_provider_service.get_default_model_of_model_type(
            tenant_id=tenant_id,
            model_type="llm",
        )

        if default_model:
            return {
                "provider": default_model.provider.provider,
                "name": default_model.model,
                "mode": LLMMode.CHAT.value,
                "completion_params": {},
            }

        # Fallback to empty config if no default model is configured
        return {
            "provider": "",
            "name": "",
            "mode": LLMMode.CHAT.value,
            "completion_params": {},
        }

    def _build_nested_node_llm_node(
        self,
        *,
        node_id: str,
        parent_node_id: str,
        context_source: list[str],
        parameter_schema: NestedNodeParameterSchema,
        model_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the Nested Node LLM node structure.

        The node uses:
        - $context in prompt_template to reference the PromptMessage list
        - structured_output for extracting the specific parameter
        - parent_node_id to associate with the parent node
        """
        prompt_template = [
            {
                "role": "system",
                "text": "Extract the required parameter value from the conversation context above.",
                "skill": False,
            },
            {"$context": context_source},
            {"role": "user", "text": "", "skill": False},
        ]

        structured_output = {
            "schema": {
                "type": "object",
                "properties": {
                    parameter_schema.name: {
                        "type": parameter_schema.type,
                        "description": parameter_schema.description,
                    }
                },
                "required": [parameter_schema.name],
                "additionalProperties": False,
            }
        }

        return {
            "id": node_id,
            "position": {"x": 0, "y": 0},
            "data": {
                "type": BuiltinNodeTypes.LLM,
                # BaseNodeData fields
                "title": f"NestedNode: {parameter_schema.name}",
                "desc": f"Extract {parameter_schema.name} from conversation context",
                "version": "1",
                "error_strategy": None,
                "default_value": None,
                "retry_config": {"max_retries": 0},
                "parent_node_id": parent_node_id,
                # LLMNodeData fields
                "model": model_config,
                "prompt_template": prompt_template,
                "prompt_config": {"jinja2_variables": []},
                "memory": None,
                "context": {
                    "enabled": False,
                    "variable_selector": None,
                },
                "vision": {
                    "enabled": False,
                    "configs": {
                        "variable_selector": ["sys", "files"],
                        "detail": "high",
                    },
                },
                "structured_output_enabled": True,
                "structured_output": structured_output,
                "computer_use": False,
                "tool_settings": [],
            },
        }
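A usage sketch of the service (request/schema field names are inferred from the code above and should be treated as illustrative):

request = NestedNodeGraphRequest(
    parent_node_id="tool_node_1",
    parameter_key="city",
    context_source=["tool_node_1", "messages"],
    parameter_schema=NestedNodeParameterSchema(
        name="city", type="string", description="City to look up"
    ),
)
response = NestedNodeGraphService(session).generate_nested_node_graph(tenant_id, request)
node = response.graph["nodes"][0]
assert node["id"] == "tool_node_1_ext_city"  # {node_id}_ext_{parameter_name}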
api/services/workflow_collaboration_service.py (new file)
@@ -0,0 +1,391 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
|
||||
from models.account import Account
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository, WorkflowSessionInfo
|
||||
|
||||
|
||||
class WorkflowCollaborationService:
|
||||
def __init__(self, repository: WorkflowCollaborationRepository, socketio) -> None:
|
||||
self._repository = repository
|
||||
self._socketio = socketio
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(repository={self._repository})"
|
||||
|
||||
def save_session(self, sid: str, user: Account) -> None:
|
||||
self._socketio.save_session(
|
||||
sid,
|
||||
{
|
||||
"user_id": user.id,
|
||||
"username": user.name,
|
||||
"avatar": user.avatar,
|
||||
},
|
||||
)
|
||||
|
||||
def register_session(self, workflow_id: str, sid: str) -> tuple[str, bool] | None:
|
||||
session = self._socketio.get_session(sid)
|
||||
user_id = session.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
session_info: WorkflowSessionInfo = {
|
||||
"user_id": str(user_id),
|
||||
"username": str(session.get("username", "Unknown")),
|
||||
"avatar": session.get("avatar"),
|
||||
"sid": sid,
|
||||
"connected_at": int(time.time()),
|
||||
"graph_active": True,
|
||||
"active_skill_file_id": None,
|
||||
}
|
||||
|
||||
self._repository.set_session_info(workflow_id, session_info)
|
||||
|
||||
leader_sid = self.get_or_set_leader(workflow_id, sid)
|
||||
is_leader = leader_sid == sid if leader_sid else False
|
||||
|
||||
self._socketio.enter_room(sid, workflow_id)
|
||||
self.broadcast_online_users(workflow_id)
|
||||
|
||||
self._socketio.emit("status", {"isLeader": is_leader}, room=sid)
|
||||
|
||||
return str(user_id), is_leader
|
||||
|
||||
def disconnect_session(self, sid: str) -> None:
|
||||
mapping = self._repository.get_sid_mapping(sid)
|
||||
if not mapping:
|
||||
return
|
||||
|
||||
workflow_id = mapping["workflow_id"]
|
||||
active_skill_file_id = self._repository.get_active_skill_file_id(workflow_id, sid)
|
||||
self._repository.delete_session(workflow_id, sid)
|
||||
|
||||
self.handle_leader_disconnect(workflow_id, sid)
|
||||
if active_skill_file_id:
|
||||
self.handle_skill_leader_disconnect(workflow_id, active_skill_file_id, sid)
|
||||
self.broadcast_online_users(workflow_id)
|
||||
|
||||
def relay_collaboration_event(self, sid: str, data: Mapping[str, object]) -> tuple[dict[str, str], int]:
|
||||
mapping = self._repository.get_sid_mapping(sid)
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
workflow_id = mapping["workflow_id"]
|
||||
user_id = mapping["user_id"]
|
||||
self.refresh_session_state(workflow_id, sid)
|
||||
|
||||
event_type = data.get("type")
|
||||
event_data = data.get("data")
|
||||
timestamp = data.get("timestamp", int(time.time()))
|
||||
|
||||
if not event_type:
|
||||
return {"msg": "invalid event type"}, 400
|
||||
|
||||
if event_type == "graph_view_active":
|
||||
is_active = False
|
||||
if isinstance(event_data, dict):
|
||||
is_active = bool(event_data.get("active") or False)
|
||||
self._repository.set_graph_active(workflow_id, sid, is_active)
|
||||
self.refresh_session_state(workflow_id, sid)
|
||||
self.broadcast_online_users(workflow_id)
|
||||
return {"msg": "graph_view_active_updated"}, 200
|
||||
|
||||
if event_type == "skill_file_active":
|
||||
file_id = None
|
||||
is_active = False
|
||||
if isinstance(event_data, dict):
|
||||
file_id = event_data.get("file_id")
|
||||
is_active = bool(event_data.get("active") or False)
|
||||
|
||||
if not file_id or not isinstance(file_id, str):
|
||||
return {"msg": "invalid skill_file_active payload"}, 400
|
||||
|
||||
previous_file_id = self._repository.get_active_skill_file_id(workflow_id, sid)
|
||||
next_file_id = file_id if is_active else None
|
||||
|
||||
if previous_file_id == next_file_id:
|
||||
                self.refresh_session_state(workflow_id, sid)
                return {"msg": "skill_file_active_unchanged"}, 200

            self._repository.set_active_skill_file(workflow_id, sid, next_file_id)
            self.refresh_session_state(workflow_id, sid)

            if previous_file_id:
                self._ensure_skill_leader(workflow_id, previous_file_id)
            if next_file_id:
                self._ensure_skill_leader(workflow_id, next_file_id, preferred_sid=sid)

            return {"msg": "skill_file_active_updated"}, 200

        if event_type == "sync_request":
            leader_sid = self._repository.get_current_leader(workflow_id)
            if leader_sid and (
                self.is_session_active(workflow_id, leader_sid)
                and self._repository.is_graph_active(workflow_id, leader_sid)
            ):
                target_sid = leader_sid
            else:
                if leader_sid:
                    self._repository.delete_leader(workflow_id)
                target_sid = self._select_graph_leader(workflow_id, preferred_sid=sid)
                if target_sid:
                    self._repository.set_leader(workflow_id, target_sid)
                    self.broadcast_leader_change(workflow_id, target_sid)
            if not target_sid:
                return {"msg": "no_active_leader"}, 200

            self._socketio.emit(
                "collaboration_update",
                {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
                room=target_sid,
            )
            return {"msg": "sync_request_forwarded"}, 200

        self._socketio.emit(
            "collaboration_update",
            {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
            room=workflow_id,
            skip_sid=sid,
        )

        return {"msg": "event_broadcasted"}, 200

    def relay_graph_event(self, sid: str, data: object) -> tuple[dict[str, str], int]:
        mapping = self._repository.get_sid_mapping(sid)
        if not mapping:
            return {"msg": "unauthorized"}, 401

        workflow_id = mapping["workflow_id"]
        self.refresh_session_state(workflow_id, sid)

        self._socketio.emit("graph_update", data, room=workflow_id, skip_sid=sid)

        return {"msg": "graph_update_broadcasted"}, 200

    def relay_skill_event(self, sid: str, data: object) -> tuple[dict[str, str], int]:
        mapping = self._repository.get_sid_mapping(sid)
        if not mapping:
            return {"msg": "unauthorized"}, 401

        workflow_id = mapping["workflow_id"]
        self.refresh_session_state(workflow_id, sid)

        self._socketio.emit("skill_update", data, room=workflow_id, skip_sid=sid)

        return {"msg": "skill_update_broadcasted"}, 200

    def get_or_set_leader(self, workflow_id: str, sid: str) -> str | None:
        current_leader = self._repository.get_current_leader(workflow_id)

        if current_leader:
            if self.is_session_active(workflow_id, current_leader) and self._repository.is_graph_active(
                workflow_id, current_leader
            ):
                return current_leader
            self._repository.delete_session(workflow_id, current_leader)
            self._repository.delete_leader(workflow_id)

        new_leader_sid = self._select_graph_leader(workflow_id, preferred_sid=sid)
        if not new_leader_sid:
            return None

        was_set = self._repository.set_leader_if_absent(workflow_id, new_leader_sid)

        if was_set:
            if current_leader:
                self.broadcast_leader_change(workflow_id, new_leader_sid)
            return new_leader_sid

        current_leader = self._repository.get_current_leader(workflow_id)
        if current_leader:
            return current_leader

        return new_leader_sid
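# --- Detached sketch (not part of the diff): the claim-if-absent step above is
# what keeps concurrent get_or_set_leader calls safe. At most one caller's set
# succeeds; a loser falls back to re-reading the current leader. A plain dict
# stands in for the repository here (assumed SETNX-style atomicity).
leaders: dict[str, str] = {}

def set_leader_if_absent(workflow_id: str, sid: str) -> bool:
    # Atomic in the real repository; plain Python here for illustration.
    if workflow_id in leaders:
        return False
    leaders[workflow_id] = sid
    return True

assert set_leader_if_absent("wf-1", "sid-a") is True   # first claim wins
assert set_leader_if_absent("wf-1", "sid-b") is False  # loser re-reads the winner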
    def handle_leader_disconnect(self, workflow_id: str, disconnected_sid: str) -> None:
        current_leader = self._repository.get_current_leader(workflow_id)
        if not current_leader:
            return

        if current_leader != disconnected_sid:
            return

        new_leader_sid = self._select_graph_leader(workflow_id)
        if new_leader_sid:
            self._repository.set_leader(workflow_id, new_leader_sid)
            self.broadcast_leader_change(workflow_id, new_leader_sid)
        else:
            self._repository.delete_leader(workflow_id)
            self.broadcast_leader_change(workflow_id, None)

    def handle_skill_leader_disconnect(self, workflow_id: str, file_id: str, disconnected_sid: str) -> None:
        current_leader = self._repository.get_skill_leader(workflow_id, file_id)
        if not current_leader:
            return

        if current_leader != disconnected_sid:
            return

        new_leader_sid = self._select_skill_leader(workflow_id, file_id)
        if new_leader_sid:
            self._repository.set_skill_leader(workflow_id, file_id, new_leader_sid)
            self.broadcast_skill_leader_change(workflow_id, file_id, new_leader_sid)
        else:
            self._repository.delete_skill_leader(workflow_id, file_id)
            self.broadcast_skill_leader_change(workflow_id, file_id, None)

    def broadcast_leader_change(self, workflow_id: str, new_leader_sid: str | None) -> None:
        for sid in self._repository.get_session_sids(workflow_id):
            try:
                is_leader = new_leader_sid is not None and sid == new_leader_sid
                self._socketio.emit("status", {"isLeader": is_leader}, room=sid)
            except Exception:
                logging.exception("Failed to emit leader status to session %s", sid)

    def broadcast_skill_leader_change(self, workflow_id: str, file_id: str, new_leader_sid: str | None) -> None:
        for sid in self._repository.get_session_sids(workflow_id):
            try:
                is_leader = new_leader_sid is not None and sid == new_leader_sid
                self._socketio.emit("skill_status", {"file_id": file_id, "isLeader": is_leader}, room=sid)
            except Exception:
                logging.exception("Failed to emit skill leader status to session %s", sid)

    def get_current_leader(self, workflow_id: str) -> str | None:
        return self._repository.get_current_leader(workflow_id)

    def _prune_inactive_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]:
        """Remove inactive sessions from storage and return active sessions only."""
        sessions = self._repository.list_sessions(workflow_id)
        if not sessions:
            return []

        active_sessions: list[WorkflowSessionInfo] = []
        stale_sids: list[str] = []
        for session in sessions:
            sid = session["sid"]
            if self.is_session_active(workflow_id, sid):
                active_sessions.append(session)
            else:
                stale_sids.append(sid)

        for sid in stale_sids:
            self._repository.delete_session(workflow_id, sid)

        return active_sessions

    def broadcast_online_users(self, workflow_id: str) -> None:
        users = self._prune_inactive_sessions(workflow_id)
        users.sort(key=lambda x: x.get("connected_at") or 0)

        leader_sid = self.get_current_leader(workflow_id)
        previous_leader = leader_sid
        active_sids = {user["sid"] for user in users}
        if leader_sid and leader_sid not in active_sids:
            self._repository.delete_leader(workflow_id)
            leader_sid = None

        if not leader_sid and users:
            leader_sid = self._select_graph_leader(workflow_id)
            if leader_sid:
                self._repository.set_leader(workflow_id, leader_sid)

        if leader_sid != previous_leader:
            self.broadcast_leader_change(workflow_id, leader_sid)

        self._socketio.emit(
            "online_users",
            {"workflow_id": workflow_id, "users": users, "leader": leader_sid},
            room=workflow_id,
        )

    def refresh_session_state(self, workflow_id: str, sid: str) -> None:
        self._repository.refresh_session_state(workflow_id, sid)
        self._ensure_leader(workflow_id, sid)
        active_skill_file_id = self._repository.get_active_skill_file_id(workflow_id, sid)
        if active_skill_file_id:
            self._ensure_skill_leader(workflow_id, active_skill_file_id, preferred_sid=sid)

    def _ensure_leader(self, workflow_id: str, sid: str) -> None:
        current_leader = self._repository.get_current_leader(workflow_id)
        if (
            current_leader
            and self.is_session_active(workflow_id, current_leader)
            and self._repository.is_graph_active(workflow_id, current_leader)
        ):
            self._repository.expire_leader(workflow_id)
            return

        if current_leader:
            self._repository.delete_leader(workflow_id)

        new_leader_sid = self._select_graph_leader(workflow_id, preferred_sid=sid)
        if not new_leader_sid:
            self.broadcast_leader_change(workflow_id, None)
            return

        self._repository.set_leader(workflow_id, new_leader_sid)
        self.broadcast_leader_change(workflow_id, new_leader_sid)

    def _ensure_skill_leader(self, workflow_id: str, file_id: str, preferred_sid: str | None = None) -> None:
        current_leader = self._repository.get_skill_leader(workflow_id, file_id)
        active_sids = self._repository.get_active_skill_session_sids(workflow_id, file_id)
        if current_leader and self.is_session_active(workflow_id, current_leader):
            if current_leader in active_sids or not active_sids:
                self._repository.expire_skill_leader(workflow_id, file_id)
                return

        if current_leader:
            self._repository.delete_skill_leader(workflow_id, file_id)

        new_leader_sid = self._select_skill_leader(workflow_id, file_id, preferred_sid=preferred_sid)
        if not new_leader_sid:
            self.broadcast_skill_leader_change(workflow_id, file_id, None)
            return

        self._repository.set_skill_leader(workflow_id, file_id, new_leader_sid)
        self.broadcast_skill_leader_change(workflow_id, file_id, new_leader_sid)

    def _select_graph_leader(self, workflow_id: str, preferred_sid: str | None = None) -> str | None:
        session_sids = [
            session["sid"]
            for session in self._repository.list_sessions(workflow_id)
            if session.get("graph_active") and self.is_session_active(workflow_id, session["sid"])
        ]
        if not session_sids:
            return None
        if preferred_sid and preferred_sid in session_sids:
            return preferred_sid
        return session_sids[0]

    def _select_skill_leader(self, workflow_id: str, file_id: str, preferred_sid: str | None = None) -> str | None:
        session_sids = [
            sid
            for sid in self._repository.get_active_skill_session_sids(workflow_id, file_id)
            if self.is_session_active(workflow_id, sid)
        ]
        if not session_sids:
            return None
        if preferred_sid and preferred_sid in session_sids:
            return preferred_sid
        return session_sids[0]

    def is_session_active(self, workflow_id: str, sid: str) -> bool:
        if not sid:
            return False

        try:
            if not self._socketio.manager.is_connected(sid, "/"):
                return False
        except AttributeError:
            return False

        if not self._repository.session_exists(workflow_id, sid):
            return False

        if not self._repository.sid_mapping_exists(sid):
            return False

        return True
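Both _select_graph_leader and _select_skill_leader above apply the same deterministic rule: among eligible sessions, prefer the requesting sid, otherwise take the first listed. A minimal, self-contained sketch of that rule (illustrative names, not the service's API):

def select_leader(eligible_sids: list[str], preferred_sid: str | None = None) -> str | None:
    if not eligible_sids:
        return None
    if preferred_sid and preferred_sid in eligible_sids:
        return preferred_sid
    return eligible_sids[0]

assert select_leader([]) is None
assert select_leader(["a", "b"], preferred_sid="b") == "b"
assert select_leader(["a", "b"], preferred_sid="z") == "a"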
468 api/services/workflow_comment_service.py Normal file
@ -0,0 +1,468 @@
import logging
from collections.abc import Sequence

from sqlalchemy import desc, select
from sqlalchemy.orm import Session, selectinload
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models import App, TenantAccountJoin, WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
from models.account import Account
from tasks.mail_workflow_comment_task import send_workflow_comment_mention_email_task

logger = logging.getLogger(__name__)


class WorkflowCommentService:
    """Service for managing workflow comments."""

    @staticmethod
    def _validate_content(content: str) -> None:
        if len(content.strip()) == 0:
            raise ValueError("Comment content cannot be empty")

        if len(content) > 1000:
            raise ValueError("Comment content cannot exceed 1000 characters")

    @staticmethod
    def _filter_valid_mentioned_user_ids(mentioned_user_ids: Sequence[str]) -> list[str]:
        """Return deduplicated UUID user IDs in the order provided."""
        unique_user_ids: list[str] = []
        seen: set[str] = set()
        for user_id in mentioned_user_ids:
            if not isinstance(user_id, str):
                continue
            if not uuid_value(user_id):
                continue
            if user_id in seen:
                continue
            seen.add(user_id)
            unique_user_ids.append(user_id)
        return unique_user_ids

    @staticmethod
    def _format_comment_excerpt(content: str, max_length: int = 200) -> str:
        """Trim comment content for email display."""
        trimmed = content.strip()
        if len(trimmed) <= max_length:
            return trimmed
        if max_length <= 3:
            return trimmed[:max_length]
        return f"{trimmed[: max_length - 3].rstrip()}..."
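# Behavior sketch for the excerpt helper above (detached, illustrative): short
# content passes through untouched; long content is trimmed so the result,
# including the "..." suffix, never exceeds max_length.
assert WorkflowCommentService._format_comment_excerpt("short") == "short"
excerpt = WorkflowCommentService._format_comment_excerpt("x" * 250)
assert len(excerpt) <= 200 and excerpt.endswith("...")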
    @staticmethod
    def _build_mention_email_payloads(
        session: Session,
        tenant_id: str,
        app_id: str,
        mentioner_id: str,
        mentioned_user_ids: Sequence[str],
        content: str,
    ) -> list[dict[str, str]]:
        """Prepare email payloads for mentioned users, including the workflow app link."""
        if not mentioned_user_ids:
            return []

        candidate_user_ids = [user_id for user_id in mentioned_user_ids if user_id != mentioner_id]
        if not candidate_user_ids:
            return []

        app_name = session.scalar(select(App.name).where(App.id == app_id, App.tenant_id == tenant_id)) or "Dify app"
        commenter_name = session.scalar(select(Account.name).where(Account.id == mentioner_id)) or "Dify user"
        comment_excerpt = WorkflowCommentService._format_comment_excerpt(content)
        base_url = dify_config.CONSOLE_WEB_URL.rstrip("/")
        app_url = f"{base_url}/app/{app_id}/workflow"

        accounts = session.scalars(
            select(Account)
            .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
            .where(TenantAccountJoin.tenant_id == tenant_id, Account.id.in_(candidate_user_ids))
        ).all()

        payloads: list[dict[str, str]] = []
        for account in accounts:
            payloads.append(
                {
                    "language": account.interface_language or "en-US",
                    "to": account.email,
                    "mentioned_name": account.name or account.email,
                    "commenter_name": commenter_name,
                    "app_name": app_name,
                    "comment_content": comment_excerpt,
                    "app_url": app_url,
                }
            )
        return payloads
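# A representative payload produced above (all values are placeholders); the
# keys match the keyword arguments of send_workflow_comment_mention_email_task.
example_payload = {
    "language": "en-US",
    "to": "alice@example.com",
    "mentioned_name": "Alice",
    "commenter_name": "Bob",
    "app_name": "Support Bot",
    "comment_content": "Can you take a look at this node?",
    "app_url": "https://console.example.com/app/app-1/workflow",
}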
    @staticmethod
    def _dispatch_mention_emails(payloads: Sequence[dict[str, str]]) -> None:
        """Enqueue mention notification emails."""
        for payload in payloads:
            send_workflow_comment_mention_email_task.delay(**payload)

    @staticmethod
    def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]:
        """Get all comments for a workflow."""
        with Session(db.engine) as session:
            # Get all comments with eager loading
            stmt = (
                select(WorkflowComment)
                .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
                .where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
                .order_by(desc(WorkflowComment.created_at))
            )

            comments = session.scalars(stmt).all()

            # Batch preload all Account objects to avoid N+1 queries
            WorkflowCommentService._preload_accounts(session, comments)

            return comments

    @staticmethod
    def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None:
        """Batch preload Account objects for comments, replies, and mentions."""
        # Collect all user IDs
        user_ids: set[str] = set()
        for comment in comments:
            user_ids.add(comment.created_by)
            if comment.resolved_by:
                user_ids.add(comment.resolved_by)
            user_ids.update(reply.created_by for reply in comment.replies)
            user_ids.update(mention.mentioned_user_id for mention in comment.mentions)

        if not user_ids:
            return

        # Batch query all accounts
        accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
        account_map = {str(account.id): account for account in accounts}

        # Cache accounts on objects
        for comment in comments:
            comment.cache_created_by_account(account_map.get(comment.created_by))
            comment.cache_resolved_by_account(account_map.get(comment.resolved_by) if comment.resolved_by else None)
            for reply in comment.replies:
                reply.cache_created_by_account(account_map.get(reply.created_by))
            for mention in comment.mentions:
                mention.cache_mentioned_user_account(account_map.get(mention.mentioned_user_id))
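# The preload is the classic collect-then-batch pattern: gather every user ID
# in one pass, issue a single IN query, then resolve from an in-memory map
# instead of one query per row. Detached sketch with fake data (not the
# service's API):
def batch_resolve(user_ids: set[str], rows: dict[str, str]) -> dict[str, str]:
    # rows stands in for the single SELECT ... WHERE id IN (...) result.
    return {uid: rows[uid] for uid in user_ids if uid in rows}

assert batch_resolve({"u1", "u2", "u404"}, {"u1": "Alice", "u2": "Bob"}) == {"u1": "Alice", "u2": "Bob"}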
    @staticmethod
    def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment:
        """Get a specific comment."""

        def _get_comment(session: Session) -> WorkflowComment:
            stmt = (
                select(WorkflowComment)
                .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
                .where(
                    WorkflowComment.id == comment_id,
                    WorkflowComment.tenant_id == tenant_id,
                    WorkflowComment.app_id == app_id,
                )
            )
            comment = session.scalar(stmt)

            if not comment:
                raise NotFound("Comment not found")

            # Preload accounts to avoid N+1 queries
            WorkflowCommentService._preload_accounts(session, [comment])

            return comment

        if session is not None:
            return _get_comment(session)
        else:
            with Session(db.engine, expire_on_commit=False) as session:
                return _get_comment(session)

    @staticmethod
    def create_comment(
        tenant_id: str,
        app_id: str,
        created_by: str,
        content: str,
        position_x: float,
        position_y: float,
        mentioned_user_ids: list[str] | None = None,
    ) -> dict:
        """Create a new workflow comment and send mention notification emails."""
        WorkflowCommentService._validate_content(content)

        with Session(db.engine) as session:
            comment = WorkflowComment(
                tenant_id=tenant_id,
                app_id=app_id,
                position_x=position_x,
                position_y=position_y,
                content=content,
                created_by=created_by,
            )

            session.add(comment)
            session.flush()  # Get the comment ID for mentions

            # Create mentions if specified
            mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
            for user_id in mentioned_user_ids:
                mention = WorkflowCommentMention(
                    comment_id=comment.id,
                    reply_id=None,  # This is a comment mention, not a reply mention
                    mentioned_user_id=user_id,
                )
                session.add(mention)

            mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
                session=session,
                tenant_id=tenant_id,
                app_id=app_id,
                mentioner_id=created_by,
                mentioned_user_ids=mentioned_user_ids,
                content=content,
            )

            session.commit()
            WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)

            # Return only what we need - id and created_at
            return {"id": comment.id, "created_at": comment.created_at}
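# Hedged usage sketch (IDs and coordinates are placeholders): the comment is
# pinned to canvas coordinates, mention rows are flushed with it, and mention
# emails are enqueued only after the commit succeeds.
result = WorkflowCommentService.create_comment(
    tenant_id="tenant-1",
    app_id="app-1",
    created_by="account-1",
    content="Shall we rename this node?",
    position_x=120.0,
    position_y=48.5,
    mentioned_user_ids=["account-2"],
)
# result -> {"id": "<comment-id>", "created_at": <datetime>}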
    @staticmethod
    def update_comment(
        tenant_id: str,
        app_id: str,
        comment_id: str,
        user_id: str,
        content: str,
        position_x: float | None = None,
        position_y: float | None = None,
        mentioned_user_ids: list[str] | None = None,
    ) -> dict:
        """Update a workflow comment and notify newly mentioned users."""
        WorkflowCommentService._validate_content(content)

        with Session(db.engine, expire_on_commit=False) as session:
            # Get comment with validation
            stmt = select(WorkflowComment).where(
                WorkflowComment.id == comment_id,
                WorkflowComment.tenant_id == tenant_id,
                WorkflowComment.app_id == app_id,
            )
            comment = session.scalar(stmt)

            if not comment:
                raise NotFound("Comment not found")

            # Only the creator can update the comment
            if comment.created_by != user_id:
                raise Forbidden("Only the comment creator can update it")

            # Update comment fields
            comment.content = content
            if position_x is not None:
                comment.position_x = position_x
            if position_y is not None:
                comment.position_y = position_y

            # Update mentions - first remove existing mentions for this comment only (not replies)
            existing_mentions = session.scalars(
                select(WorkflowCommentMention).where(
                    WorkflowCommentMention.comment_id == comment.id,
                    WorkflowCommentMention.reply_id.is_(None),  # Only comment mentions, not reply mentions
                )
            ).all()
            existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions}
            for mention in existing_mentions:
                session.delete(mention)

            # Add new mentions
            mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
            new_mentioned_user_ids = [
                user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids
            ]
            for user_id_str in mentioned_user_ids:
                mention = WorkflowCommentMention(
                    comment_id=comment.id,
                    reply_id=None,  # This is a comment mention
                    mentioned_user_id=user_id_str,
                )
                session.add(mention)

            mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
                session=session,
                tenant_id=tenant_id,
                app_id=app_id,
                mentioner_id=user_id,
                mentioned_user_ids=new_mentioned_user_ids,
                content=content,
            )

            session.commit()
            WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)

            return {"id": comment.id, "updated_at": comment.updated_at}
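# update_comment rewrites the full mention set but emails only users who were
# not already mentioned. The diff rule in isolation (illustrative values):
existing_mentioned = {"u1"}
updated_mentions = ["u1", "u2", "u3"]
newly_mentioned = [uid for uid in updated_mentions if uid not in existing_mentioned]
assert newly_mentioned == ["u2", "u3"]  # only these receive an email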
    @staticmethod
    def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
        """Delete a workflow comment."""
        with Session(db.engine, expire_on_commit=False) as session:
            comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)

            # Only the creator can delete the comment
            if comment.created_by != user_id:
                raise Forbidden("Only the comment creator can delete it")

            # Delete associated mentions (both comment and reply mentions)
            mentions = session.scalars(
                select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
            ).all()
            for mention in mentions:
                session.delete(mention)

            # Delete associated replies
            replies = session.scalars(
                select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
            ).all()
            for reply in replies:
                session.delete(reply)

            session.delete(comment)
            session.commit()

    @staticmethod
    def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
        """Resolve a workflow comment."""
        with Session(db.engine, expire_on_commit=False) as session:
            comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
            if comment.resolved:
                return comment

            comment.resolved = True
            comment.resolved_at = naive_utc_now()
            comment.resolved_by = user_id
            session.commit()

            return comment

    @staticmethod
    def create_reply(
        comment_id: str, content: str, created_by: str, mentioned_user_ids: list[str] | None = None
    ) -> dict:
        """Add a reply to a workflow comment and notify mentioned users."""
        WorkflowCommentService._validate_content(content)

        with Session(db.engine, expire_on_commit=False) as session:
            # Check if comment exists
            comment = session.get(WorkflowComment, comment_id)
            if not comment:
                raise NotFound("Comment not found")

            reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)

            session.add(reply)
            session.flush()  # Get the reply ID for mentions

            # Create mentions if specified
            mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
            for user_id in mentioned_user_ids:
                # Create mention linking to specific reply
                mention = WorkflowCommentMention(comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id)
                session.add(mention)

            mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
                session=session,
                tenant_id=comment.tenant_id,
                app_id=comment.app_id,
                mentioner_id=created_by,
                mentioned_user_ids=mentioned_user_ids,
                content=content,
            )

            session.commit()
            WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)

            return {"id": reply.id, "created_at": reply.created_at}

    @staticmethod
    def update_reply(reply_id: str, user_id: str, content: str, mentioned_user_ids: list[str] | None = None) -> dict:
        """Update a comment reply and notify newly mentioned users."""
        WorkflowCommentService._validate_content(content)

        with Session(db.engine, expire_on_commit=False) as session:
            reply = session.get(WorkflowCommentReply, reply_id)
            if not reply:
                raise NotFound("Reply not found")

            # Only the creator can update the reply
            if reply.created_by != user_id:
                raise Forbidden("Only the reply creator can update it")

            reply.content = content

            # Update mentions - first remove existing mentions for this reply
            existing_mentions = session.scalars(
                select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
            ).all()
            existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions}
            for mention in existing_mentions:
                session.delete(mention)

            # Add new mentions
            mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
            new_mentioned_user_ids = [
                user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids
            ]
            for user_id_str in mentioned_user_ids:
                mention = WorkflowCommentMention(
                    comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str
                )
                session.add(mention)

            mention_email_payloads: list[dict[str, str]] = []
            comment = session.get(WorkflowComment, reply.comment_id)
            if comment:
                mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
                    session=session,
                    tenant_id=comment.tenant_id,
                    app_id=comment.app_id,
                    mentioner_id=user_id,
                    mentioned_user_ids=new_mentioned_user_ids,
                    content=content,
                )

            session.commit()
            session.refresh(reply)  # Refresh to get updated timestamp
            WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)

            return {"id": reply.id, "updated_at": reply.updated_at}

    @staticmethod
    def delete_reply(reply_id: str, user_id: str) -> None:
        """Delete a comment reply."""
        with Session(db.engine, expire_on_commit=False) as session:
            reply = session.get(WorkflowCommentReply, reply_id)
            if not reply:
                raise NotFound("Reply not found")

            # Only the creator can delete the reply
            if reply.created_by != user_id:
                raise Forbidden("Only the reply creator can delete it")

            # Delete associated mentions first
            mentions = session.scalars(
                select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
            ).all()
            for mention in mentions:
                session.delete(mention)

            session.delete(reply)
            session.commit()

    @staticmethod
    def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
        """Validate that a comment belongs to the specified tenant and app."""
        return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)
65 api/tasks/mail_workflow_comment_task.py Normal file
@ -0,0 +1,65 @@
import logging
import time

import click
from celery import shared_task

from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service

logger = logging.getLogger(__name__)


@shared_task(queue="mail")
def send_workflow_comment_mention_email_task(
    language: str,
    to: str,
    mentioned_name: str,
    commenter_name: str,
    app_name: str,
    comment_content: str,
    app_url: str,
):
    """
    Send workflow comment mention email with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
        mentioned_name: Name of the mentioned user
        commenter_name: Name of the comment author
        app_name: Name of the app where the comment was made
        comment_content: Comment content excerpt
        app_url: Link to the app workflow page
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start workflow comment mention mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.WORKFLOW_COMMENT_MENTION,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "mentioned_name": mentioned_name,
                "commenter_name": commenter_name,
                "app_name": app_name,
                "comment_content": comment_content,
                "app_url": app_url,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Send workflow comment mention mail to {to} succeeded: latency: {end_at - start_at}",
                fg="green",
            )
        )
    except Exception:
        logger.exception("workflow comment mention email to %s failed", to)
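The comment service enqueues this task with .delay(**payload) after committing the mention rows; a direct dispatch sketch with placeholder values:

send_workflow_comment_mention_email_task.delay(
    language="en-US",
    to="alice@example.com",
    mentioned_name="Alice",
    commenter_name="Bob",
    app_name="Support Bot",
    comment_content="Can you take a look at this node?",
    app_url="https://console.example.com/app/app-1/workflow",
)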