diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index c3e2c50c52..2611b75c6c 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -2,21 +2,10 @@ "version": "0.2.0", "configurations": [ { - "name": "Python: Flask API", + "name": "Python: API (gevent)", "type": "debugpy", "request": "launch", - "module": "flask", - "env": { - "FLASK_APP": "app.py", - "FLASK_ENV": "development" - }, - "args": [ - "run", - "--host=0.0.0.0", - "--port=5001", - "--no-debugger", - "--no-reload" - ], + "program": "${workspaceFolder}/api/app.py", "jinja": true, "justMyCode": true, "cwd": "${workspaceFolder}/api", diff --git a/api/.env.example b/api/.env.example index beb820e797..7455d4a0e9 100644 --- a/api/.env.example +++ b/api/.env.example @@ -33,6 +33,9 @@ TRIGGER_URL=http://localhost:5001 # The time in seconds after the signature is rejected FILES_ACCESS_TIMEOUT=300 +# Collaboration mode toggle +ENABLE_COLLABORATION_MODE=false + # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 diff --git a/api/.vscode/launch.json.example b/api/.vscode/launch.json.example index 6bdfa2c039..1001559176 100644 --- a/api/.vscode/launch.json.example +++ b/api/.vscode/launch.json.example @@ -3,29 +3,21 @@ "compounds": [ { "name": "Launch Flask and Celery", - "configurations": ["Python: Flask", "Python: Celery"] + "configurations": ["Python: API (gevent)", "Python: Celery"] } ], "configurations": [ { - "name": "Python: Flask", - "consoleName": "Flask", + "name": "Python: API (gevent)", + "consoleName": "API", "type": "debugpy", "request": "launch", "python": "${workspaceFolder}/.venv/bin/python", "cwd": "${workspaceFolder}", "envFile": ".env", - "module": "flask", + "program": "${workspaceFolder}/app.py", "justMyCode": true, - "jinja": true, - "env": { - "FLASK_APP": "app.py", - "GEVENT_SUPPORT": "True" - }, - "args": [ - "run", - "--port=5001" - ] + "jinja": true }, { "name": "Python: Celery", diff --git a/api/app.py b/api/app.py index 
c018c8a045..e53b037be5 100644 --- a/api/app.py +++ b/api/app.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import sys from typing import TYPE_CHECKING, cast @@ -9,17 +10,35 @@ if TYPE_CHECKING: celery: Celery +HOST = "0.0.0.0" +PORT = 5001 +logger = logging.getLogger(__name__) + + def is_db_command() -> bool: if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db": return True return False +def log_startup_banner(host: str, port: int) -> None: + debugger_attached = sys.gettrace() is not None + logger.info("Serving Dify API via gevent WebSocket server") + logger.info("Bound to http://%s:%s", host, port) + logger.info("Debugger attached: %s", "on" if debugger_attached else "off") + logger.info("Press CTRL+C to quit") + + # create app +flask_app = None +socketio_app = None + if is_db_command(): from app_factory import create_migrations_app app = create_migrations_app() + socketio_app = app + flask_app = app else: # Gunicorn and Celery handle monkey patching automatically in production by # specifying the `gevent` worker class. Manual monkey patching is not required here. 
@@ -30,8 +49,14 @@ else: from app_factory import create_app - app = create_app() + socketio_app, flask_app = create_app() + app = flask_app celery = cast("Celery", app.extensions["celery"]) if __name__ == "__main__": - app.run(host="0.0.0.0", port=5001) + from gevent import pywsgi + from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs] + + log_startup_banner(HOST, PORT) + server = pywsgi.WSGIServer((HOST, PORT), socketio_app, handler_class=WebSocketHandler) + server.serve_forever() diff --git a/api/app_factory.py b/api/app_factory.py index 76838f9925..48e50ceae9 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,6 +1,7 @@ import logging import time +import socketio # type: ignore[reportMissingTypeStubs] from flask import request from opentelemetry.trace import get_current_span from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID @@ -10,6 +11,7 @@ from contexts.wrapper import RecyclableContextVar from controllers.console.error import UnauthorizedAndForceLogout from core.logging.context import init_request_context from dify_app import DifyApp +from extensions.ext_socketio import sio from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import LicenseStatus @@ -122,14 +124,18 @@ def create_flask_app_with_configs() -> DifyApp: return dify_app -def create_app() -> DifyApp: +def create_app() -> tuple[socketio.WSGIApp, DifyApp]: start_time = time.perf_counter() app = create_flask_app_with_configs() initialize_extensions(app) + + sio.app = app + socketio_app = socketio.WSGIApp(sio, app) + end_time = time.perf_counter() if dify_config.DEBUG: logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2)) - return app + return socketio_app, app def initialize_extensions(app: DifyApp): diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d37cff63e9..ae49ae47d0 100644 --- a/api/configs/feature/__init__.py +++ 
b/api/configs/feature/__init__.py @@ -1274,6 +1274,13 @@ class PositionConfig(BaseSettings): return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} +class CollaborationConfig(BaseSettings): + ENABLE_COLLABORATION_MODE: bool = Field( + description="Whether to enable collaboration mode features across the workspace", + default=False, + ) + + class LoginConfig(BaseSettings): ENABLE_EMAIL_CODE_LOGIN: bool = Field( description="whether to enable email code login", @@ -1399,6 +1406,7 @@ class FeatureConfig( WorkflowConfig, WorkflowNodeExecutionConfig, WorkspaceConfig, + CollaborationConfig, LoginConfig, AccountConfig, SwaggerUIConfig, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index d624b10b22..980e828945 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -65,6 +65,7 @@ from .app import ( statistic, workflow, workflow_app_log, + workflow_comment, workflow_draft_variable, workflow_run, workflow_statistic, @@ -116,6 +117,7 @@ from .explore import ( saved_message, trial, ) +from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport] # Import tag controllers from .tag import tags @@ -201,6 +203,7 @@ __all__ = [ "saved_message", "setup", "site", + "socketio_workflow", "spec", "statistic", "tags", @@ -211,6 +214,7 @@ __all__ = [ "website", "workflow", "workflow_app_log", + "workflow_comment", "workflow_draft_variable", "workflow_run", "workflow_statistic", diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 5e6ff87d62..cca4bbee1e 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -7,6 +7,7 @@ from flask import abort, request from flask_restx import Resource, fields, marshal, marshal_with from graphon.enums import NodeType from graphon.file import File +from graphon.file import helpers as file_helpers from 
graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, ValidationError, field_validator @@ -39,6 +40,7 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields +from fields.online_user_fields import online_user_list_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields from libs import helper from libs.datetime_utils import naive_utc_now @@ -47,6 +49,7 @@ from libs.login import current_account_with_tenant, login_required from models import App from models.model import AppMode from models.workflow import Workflow +from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -57,6 +60,7 @@ _file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" +MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50 # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -150,6 +154,14 @@ class ConvertToWorkflowPayload(BaseModel): icon_background: str | None = None +class WorkflowFeaturesPayload(BaseModel): + features: dict[str, Any] = Field(..., description="Workflow feature configuration") + + +class WorkflowOnlineUsersQuery(BaseModel): + app_ids: str = Field(..., description="Comma-separated app IDs") + + class DraftWorkflowTriggerRunPayload(BaseModel): node_id: str @@ -173,6 +185,8 @@ reg(DefaultBlockConfigQuery) 
reg(ConvertToWorkflowPayload) reg(WorkflowListQuery) reg(WorkflowUpdatePayload) +reg(WorkflowFeaturesPayload) +reg(WorkflowOnlineUsersQuery) reg(DraftWorkflowTriggerRunPayload) reg(DraftWorkflowTriggerRunAllPayload) @@ -931,6 +945,32 @@ class ConvertToWorkflowApi(Resource): } +@console_ns.route("/apps//workflows/draft/features") +class WorkflowFeaturesApi(Resource): + """Update draft workflow features.""" + + @console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__]) + @console_ns.doc("update_workflow_features") + @console_ns.doc(description="Update draft workflow features") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Workflow features updated successfully") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App): + current_user, _ = current_account_with_tenant() + + args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {}) + features = args.features + + workflow_service = WorkflowService() + workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user) + + return {"result": "success"} + + @console_ns.route("/apps//workflows") class PublishedAllWorkflowApi(Resource): @console_ns.expect(console_ns.models[WorkflowListQuery.__name__]) @@ -1340,3 +1380,62 @@ class DraftWorkflowTriggerRunAllApi(Resource): "status": "error", } ), 400 + + +@console_ns.route("/apps/workflows/online-users") +class WorkflowOnlineUsersApi(Resource): + @console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__]) + @console_ns.doc("get_workflow_online_users") + @console_ns.doc(description="Get workflow online users") + @setup_required + @login_required + @account_initialization_required + @marshal_with(online_user_list_fields) + def get(self): + args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: 
ignore + + app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip())) + if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS: + raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.") + + if not app_ids: + return {"data": []} + + _, current_tenant_id = current_account_with_tenant() + workflow_service = WorkflowService() + accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id) + + results = [] + for app_id in app_ids: + if app_id not in accessible_app_ids: + continue + + users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}") + + users = [] + for _, user_info_json in users_json.items(): + try: + user_info = json.loads(user_info_json) + except Exception: + continue + + if not isinstance(user_info, dict): + continue + + avatar = user_info.get("avatar") + if isinstance(avatar, str) and avatar and not avatar.startswith(("http://", "https://")): + try: + user_info["avatar"] = file_helpers.get_signed_file_url(avatar) + except Exception as exc: + logger.warning( + "Failed to sign workflow online user avatar; using original value. 
" + "app_id=%s avatar=%s error=%s", + app_id, + avatar, + exc, + ) + + users.append(user_info) + results.append({"app_id": app_id, "users": users}) + + return {"data": results} diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py new file mode 100644 index 0000000000..e7c3e982a6 --- /dev/null +++ b/api/controllers/console/app/workflow_comment.py @@ -0,0 +1,335 @@ +import logging + +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field, TypeAdapter + +from controllers.common.schema import register_schema_models +from controllers.console import console_ns +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from fields.member_fields import AccountWithRole +from fields.workflow_comment_fields import ( + workflow_comment_basic_fields, + workflow_comment_create_fields, + workflow_comment_detail_fields, + workflow_comment_reply_create_fields, + workflow_comment_reply_update_fields, + workflow_comment_resolve_fields, + workflow_comment_update_fields, +) +from libs.login import current_user, login_required +from models import App +from services.account_service import TenantService +from services.workflow_comment_service import WorkflowCommentService + +logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowCommentCreatePayload(BaseModel): + content: str = Field(..., description="Comment content") + position_x: float = Field(..., description="Comment X position") + position_y: float = Field(..., description="Comment Y position") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentUpdatePayload(BaseModel): + content: str = Field(..., description="Comment content") + position_x: float | None = Field(default=None, description="Comment X 
position") + position_y: float | None = Field(default=None, description="Comment Y position") + mentioned_user_ids: list[str] | None = Field( + default=None, + description="Mentioned user IDs. Omit to keep existing mentions.", + ) + + +class WorkflowCommentReplyPayload(BaseModel): + content: str = Field(..., description="Reply content") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentMentionUsersPayload(BaseModel): + users: list[AccountWithRole] + + +for model in ( + WorkflowCommentCreatePayload, + WorkflowCommentUpdatePayload, + WorkflowCommentReplyPayload, +): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload) + +workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields) +workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields) +workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields) +workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields) +workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields) +workflow_comment_reply_create_model = console_ns.model( + "WorkflowCommentReplyCreate", workflow_comment_reply_create_fields +) +workflow_comment_reply_update_model = console_ns.model( + "WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields +) + + +@console_ns.route("/apps//workflow/comments") +class WorkflowCommentListApi(Resource): + """API for listing and creating workflow comments.""" + + @console_ns.doc("list_workflow_comments") + @console_ns.doc(description="Get all comments for a workflow") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Comments 
retrieved successfully", workflow_comment_basic_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_basic_model, envelope="data") + def get(self, app_model: App): + """Get all comments for a workflow.""" + comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id) + + return comments + + @console_ns.doc("create_workflow_comment") + @console_ns.doc(description="Create a new workflow comment") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__]) + @console_ns.response(201, "Comment created successfully", workflow_comment_create_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_create_model) + @edit_permission_required + def post(self, app_model: App): + """Create a new workflow comment.""" + payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {}) + + result = WorkflowCommentService.create_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + created_by=current_user.id, + content=payload.content, + position_x=payload.position_x, + position_y=payload.position_y, + mentioned_user_ids=payload.mentioned_user_ids, + ) + + return result, 201 + + +@console_ns.route("/apps//workflow/comments/") +class WorkflowCommentDetailApi(Resource): + """API for managing individual workflow comments.""" + + @console_ns.doc("get_workflow_comment") + @console_ns.doc(description="Get a specific workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_detail_model) + def get(self, app_model: App, 
comment_id: str): + """Get a specific workflow comment.""" + comment = WorkflowCommentService.get_comment( + tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id + ) + + return comment + + @console_ns.doc("update_workflow_comment") + @console_ns.doc(description="Update a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__]) + @console_ns.response(200, "Comment updated successfully", workflow_comment_update_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_update_model) + @edit_permission_required + def put(self, app_model: App, comment_id: str): + """Update a workflow comment.""" + payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {}) + + result = WorkflowCommentService.update_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + content=payload.content, + position_x=payload.position_x, + position_y=payload.position_y, + mentioned_user_ids=payload.mentioned_user_ids, + ) + + return result + + @console_ns.doc("delete_workflow_comment") + @console_ns.doc(description="Delete a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.response(204, "Comment deleted successfully") + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @edit_permission_required + def delete(self, app_model: App, comment_id: str): + """Delete a workflow comment.""" + WorkflowCommentService.delete_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + ) + + return {"result": "success"}, 204 + + +@console_ns.route("/apps//workflow/comments//resolve") +class 
WorkflowCommentResolveApi(Resource): + """API for resolving and reopening workflow comments.""" + + @console_ns.doc("resolve_workflow_comment") + @console_ns.doc(description="Resolve a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_resolve_model) + @edit_permission_required + def post(self, app_model: App, comment_id: str): + """Resolve a workflow comment.""" + comment = WorkflowCommentService.resolve_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + ) + + return comment + + +@console_ns.route("/apps//workflow/comments//replies") +class WorkflowCommentReplyApi(Resource): + """API for managing comment replies.""" + + @console_ns.doc("create_workflow_comment_reply") + @console_ns.doc(description="Add a reply to a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__]) + @console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_reply_create_model) + @edit_permission_required + def post(self, app_model: App, comment_id: str): + """Add a reply to a workflow comment.""" + # Validate comment access first + WorkflowCommentService.validate_comment_access( + comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + ) + + payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {}) + + result = WorkflowCommentService.create_reply( + comment_id=comment_id, + content=payload.content, + 
created_by=current_user.id, + mentioned_user_ids=payload.mentioned_user_ids, + ) + + return result, 201 + + +@console_ns.route("/apps//workflow/comments//replies/") +class WorkflowCommentReplyDetailApi(Resource): + """API for managing individual comment replies.""" + + @console_ns.doc("update_workflow_comment_reply") + @console_ns.doc(description="Update a comment reply") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__]) + @console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_reply_update_model) + @edit_permission_required + def put(self, app_model: App, comment_id: str, reply_id: str): + """Update a comment reply.""" + # Validate comment access first + WorkflowCommentService.validate_comment_access( + comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + ) + + payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {}) + + reply = WorkflowCommentService.update_reply( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + reply_id=reply_id, + user_id=current_user.id, + content=payload.content, + mentioned_user_ids=payload.mentioned_user_ids, + ) + + return reply + + @console_ns.doc("delete_workflow_comment_reply") + @console_ns.doc(description="Delete a comment reply") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"}) + @console_ns.response(204, "Reply deleted successfully") + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @edit_permission_required + def delete(self, app_model: App, comment_id: str, reply_id: str): + """Delete a comment reply.""" + # Validate comment access first + 
WorkflowCommentService.validate_comment_access( + comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + ) + + WorkflowCommentService.delete_reply( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + reply_id=reply_id, + user_id=current_user.id, + ) + + return {"result": "success"}, 204 + + +@console_ns.route("/apps//workflow/comments/mention-users") +class WorkflowCommentMentionUsersApi(Resource): + """API for getting mentionable users for workflow comments.""" + + @console_ns.doc("workflow_comment_mention_users") + @console_ns.doc(description="Get all users in current tenant for mentions") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response( + 200, "Mentionable users retrieved successfully", console_ns.models[WorkflowCommentMentionUsersPayload.__name__] + ) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + def get(self, app_model: App): + """Get all users in current tenant for mentions.""" + members = TenantService.get_tenant_members(current_user.current_tenant) + users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) + response = WorkflowCommentMentionUsersPayload(users=users) + return response.model_dump(mode="json"), 200 diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 657e794ac4..640189b070 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -22,6 +22,7 @@ from controllers.web.error import InvalidArgumentError, NotFoundError from core.app.file_access import DatabaseFileAccessController from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db +from factories import variable_factory from factories.file_factory import build_from_mapping, 
build_from_mappings from factories.variable_factory import build_segment_with_type from libs.login import current_user, login_required @@ -45,6 +46,16 @@ class WorkflowDraftVariableUpdatePayload(BaseModel): value: Any | None = Field(default=None, description="Variable value") +class ConversationVariableUpdatePayload(BaseModel): + conversation_variables: list[dict[str, Any]] = Field( + ..., description="Conversation variables for the draft workflow" + ) + + +class EnvironmentVariableUpdatePayload(BaseModel): + environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow") + + console_ns.schema_model( WorkflowDraftVariableListQuery.__name__, WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), @@ -53,6 +64,14 @@ console_ns.schema_model( WorkflowDraftVariableUpdatePayload.__name__, WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), ) +console_ns.schema_model( + ConversationVariableUpdatePayload.__name__, + ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) +console_ns.schema_model( + EnvironmentVariableUpdatePayload.__name__, + EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) def _convert_values_to_json_serializable_object(value: Segment): @@ -510,6 +529,34 @@ class ConversationVariableCollectionApi(Resource): db.session.commit() return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) + @console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__]) + @console_ns.doc("update_conversation_variables") + @console_ns.doc(description="Update conversation variables for workflow draft") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Conversation variables updated successfully") + @setup_required + @login_required + @account_initialization_required + 
@edit_permission_required + @get_app_model(mode=AppMode.ADVANCED_CHAT) + def post(self, app_model: App): + payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {}) + + workflow_service = WorkflowService() + + conversation_variables_list = payload.conversation_variables + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + + workflow_service.update_draft_workflow_conversation_variables( + app_model=app_model, + account=current_user, + conversation_variables=conversation_variables, + ) + + return {"result": "success"} + @console_ns.route("/apps//workflows/draft/system-variables") class SystemVariableCollectionApi(Resource): @@ -561,3 +608,31 @@ class EnvironmentVariableCollectionApi(Resource): ) return {"items": env_vars_list} + + @console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__]) + @console_ns.doc("update_environment_variables") + @console_ns.doc(description="Update environment variables for workflow draft") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Environment variables updated successfully") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {}) + + workflow_service = WorkflowService() + + environment_variables_list = payload.environment_variables + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + + workflow_service.update_draft_workflow_environment_variables( + app_model=app_model, + account=current_user, + environment_variables=environment_variables, + ) + + return {"result": "success"} diff --git a/api/controllers/console/socketio/__init__.py 
b/api/controllers/console/socketio/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/controllers/console/socketio/__init__.py @@ -0,0 +1 @@ + diff --git a/api/controllers/console/socketio/workflow.py b/api/controllers/console/socketio/workflow.py new file mode 100644 index 0000000000..b4f03593fd --- /dev/null +++ b/api/controllers/console/socketio/workflow.py @@ -0,0 +1,108 @@ +import logging +from collections.abc import Callable +from typing import cast + +from flask import Request as FlaskRequest + +from extensions.ext_socketio import sio +from libs.passport import PassportService +from libs.token import extract_access_token +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository +from services.account_service import AccountService +from services.workflow_collaboration_service import WorkflowCollaborationService + +repository = WorkflowCollaborationRepository() +collaboration_service = WorkflowCollaborationService(repository, sio) + + +def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]: + return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event)) + + +@_sio_on("connect") +def socket_connect(sid, environ, auth): + """ + WebSocket connect event, do authentication here. 
+ """ + try: + request_environ = FlaskRequest(environ) + token = extract_access_token(request_environ) + except Exception: + logging.exception("Failed to extract token") + token = None + + if not token: + logging.warning("Socket connect rejected: missing token (sid=%s)", sid) + return False + + try: + decoded = PassportService().verify(token) + user_id = decoded.get("user_id") + if not user_id: + logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid) + return False + + with sio.app.app_context(): + user = AccountService.load_logged_in_account(account_id=user_id) + if not user: + logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid) + return False + if not user.has_edit_permission: + logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid) + return False + + collaboration_service.save_socket_identity(sid, user) + return True + + except Exception: + logging.exception("Socket authentication failed") + return False + + +@_sio_on("user_connect") +def handle_user_connect(sid, data): + """ + Handle user connect event. Each session (tab) is treated as an independent collaborator. + """ + workflow_id = data.get("workflow_id") + if not workflow_id: + return {"msg": "workflow_id is required"}, 400 + + result = collaboration_service.authorize_and_join_workflow_room(workflow_id, sid) + if not result: + return {"msg": "unauthorized"}, 401 + + user_id, is_leader = result + return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader} + + +@_sio_on("disconnect") +def handle_disconnect(sid): + """ + Handle session disconnect event. Remove the specific session from online users. + """ + collaboration_service.disconnect_session(sid) + + +@_sio_on("collaboration_event") +def handle_collaboration_event(sid, data): + """ + Handle general collaboration events, include: + 1. mouse_move + 2. vars_and_features_update + 3. sync_request (ask leader to update graph) + 4. 
app_state_update + 5. mcp_server_update + 6. workflow_update + 7. comments_update + 8. node_panel_presence + """ + return collaboration_service.relay_collaboration_event(sid, data) + + +@_sio_on("graph_event") +def handle_graph_event(sid, data): + """ + Handle graph events - simple broadcast relay. + """ + return collaboration_service.relay_graph_event(sid, data) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 582c38052e..9de56acc4d 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -6,6 +6,7 @@ from typing import Any, Literal import pytz from flask import request from flask_restx import Resource +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select @@ -75,6 +76,10 @@ class AccountAvatarPayload(BaseModel): avatar: str +class AccountAvatarQuery(BaseModel): + avatar: str = Field(..., description="Avatar file ID") + + class AccountInterfaceLanguagePayload(BaseModel): interface_language: str @@ -160,6 +165,7 @@ def reg(cls: type[BaseModel]): reg(AccountInitPayload) reg(AccountNamePayload) reg(AccountAvatarPayload) +reg(AccountAvatarQuery) reg(AccountInterfaceLanguagePayload) reg(AccountInterfaceThemePayload) reg(AccountTimezonePayload) @@ -309,6 +315,18 @@ class AccountNameApi(Resource): @console_ns.route("/account/avatar") class AccountAvatarApi(Resource): + @console_ns.expect(console_ns.models[AccountAvatarQuery.__name__]) + @console_ns.doc("get_account_avatar") + @console_ns.doc(description="Get account avatar url") + @setup_required + @login_required + @account_initialization_required + def get(self): + args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + + avatar_url = file_helpers.get_signed_file_url(args.avatar) + return {"avatar_url": avatar_url} + 
from flask_restx import fields

# Serialized shape of a single online collaborator.
online_user_partial_fields = {
    "user_id": fields.String,
    "username": fields.String,
    "avatar": fields.String,
}

# Collaborators grouped under the workflow app they are editing.
workflow_online_users_fields = {
    "app_id": fields.String,
    "users": fields.List(fields.Nested(online_user_partial_fields)),
}

# Top-level envelope returned by the online-users listing endpoint.
online_user_list_fields = {
    "data": fields.List(fields.Nested(workflow_online_users_fields)),
}
from flask_restx import fields

from libs.helper import AvatarUrlField, TimestampField

# Minimal account projection embedded in comment payloads.
account_fields = {
    "id": fields.String,
    "name": fields.String,
    "email": fields.String,
    "avatar_url": AvatarUrlField,
}

# Shared nullable nested account (referenced accounts may have been deleted).
_nullable_account = fields.Nested(account_fields, allow_null=True)

# A single @-mention inside a comment or reply.
workflow_comment_mention_fields = {
    "mentioned_user_id": fields.String,
    "mentioned_user_account": _nullable_account,
    "reply_id": fields.String,
}

# A reply attached to a comment thread.
workflow_comment_reply_fields = {
    "id": fields.String,
    "content": fields.String,
    "created_by": fields.String,
    "created_by_account": _nullable_account,
    "created_at": TimestampField,
}

# Comment summary used by list views (counts instead of full threads).
workflow_comment_basic_fields = {
    "id": fields.String,
    "position_x": fields.Float,
    "position_y": fields.Float,
    "content": fields.String,
    "created_by": fields.String,
    "created_by_account": _nullable_account,
    "created_at": TimestampField,
    "updated_at": TimestampField,
    "resolved": fields.Boolean,
    "resolved_at": TimestampField,
    "resolved_by": fields.String,
    "resolved_by_account": _nullable_account,
    "reply_count": fields.Integer,
    "mention_count": fields.Integer,
    "participants": fields.List(fields.Nested(account_fields)),
}

# Full comment payload for the single-comment view (embeds replies/mentions).
workflow_comment_detail_fields = {
    "id": fields.String,
    "position_x": fields.Float,
    "position_y": fields.Float,
    "content": fields.String,
    "created_by": fields.String,
    "created_by_account": _nullable_account,
    "created_at": TimestampField,
    "updated_at": TimestampField,
    "resolved": fields.Boolean,
    "resolved_at": TimestampField,
    "resolved_by": fields.String,
    "resolved_by_account": _nullable_account,
    "replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
    "mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
}

# Minimal acknowledgement payloads for mutations.
workflow_comment_create_fields = {
    "id": fields.String,
    "created_at": TimestampField,
}

workflow_comment_update_fields = {
    "id": fields.String,
    "updated_at": TimestampField,
}

workflow_comment_resolve_fields = {
    "id": fields.String,
    "resolved": fields.Boolean,
    "resolved_at": TimestampField,
    "resolved_by": fields.String,
}

workflow_comment_reply_create_fields = {
    "id": fields.String,
    "created_at": TimestampField,
}

workflow_comment_reply_update_fields = {
    "id": fields.String,
    "updated_at": TimestampField,
}
"""Add workflow comments table

Revision ID: 227822d22895
Revises: 8574b23a38fd
Create Date: 2025-08-22 17:26:15.255980

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '227822d22895'
down_revision = '8574b23a38fd'
branch_labels = None
depends_on = None


def upgrade():
    # NOTE: the ORM models declare `server_default=db.text("uuidv7()")` on `id`,
    # so SQLAlchemy will NOT send an id value on insert. The DDL must therefore
    # carry the same server default, otherwise inserts fail with a NOT NULL
    # violation. (Fix: server_default added to the `id` column of all tables.)
    op.create_table('workflow_comments',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('app_id', models.types.StringUUID(), nullable=False),
    sa.Column('position_x', sa.Float(), nullable=False),
    sa.Column('position_y', sa.Float(), nullable=False),
    sa.Column('content', sa.Text(), nullable=False),
    sa.Column('created_by', models.types.StringUUID(), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False),
    sa.Column('resolved_at', sa.DateTime(), nullable=True),
    sa.Column('resolved_by', models.types.StringUUID(), nullable=True),
    sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey')
    )
    with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
        batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False)
        batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False)

    op.create_table('workflow_comment_replies',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
    sa.Column('comment_id', models.types.StringUUID(), nullable=False),
    sa.Column('content', sa.Text(), nullable=False),
    sa.Column('created_by', models.types.StringUUID(), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey')
    )
    with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
        batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False)
        batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False)

    op.create_table('workflow_comment_mentions',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
    sa.Column('comment_id', models.types.StringUUID(), nullable=False),
    sa.Column('reply_id', models.types.StringUUID(), nullable=True),
    sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False),
    sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey')
    )
    with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
        batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False)
        batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False)
        batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False)


def downgrade():
    # Drop in reverse dependency order: mentions -> replies -> comments.
    with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
        batch_op.drop_index('comment_mentions_user_idx')
        batch_op.drop_index('comment_mentions_reply_idx')
        batch_op.drop_index('comment_mentions_comment_idx')

    op.drop_table('workflow_comment_mentions')
    with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
        batch_op.drop_index('comment_replies_created_at_idx')
        batch_op.drop_index('comment_replies_comment_idx')

    op.drop_table('workflow_comment_replies')
    with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
        batch_op.drop_index('workflow_comments_created_at_idx')
        batch_op.drop_index('workflow_comments_app_idx')

    op.drop_table('workflow_comments')
class WorkflowComment(Base):
    """Canvas comment attached to a workflow app.

    Comments are associated with apps rather than specific workflow versions:
    an app has only one draft workflow at a time, and comments should persist
    across workflow version changes.
    """

    __tablename__ = "workflow_comments"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
        Index("workflow_comments_app_idx", "tenant_id", "app_id"),
        Index("workflow_comments_created_at_idx", "created_at"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    position_x: Mapped[float] = mapped_column(db.Float)  # canvas X coordinate
    position_y: Mapped[float] = mapped_column(db.Float)  # canvas Y coordinate
    content: Mapped[str] = mapped_column(db.Text, nullable=False)
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )
    resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
    resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
    resolved_by: Mapped[str | None] = mapped_column(StringUUID)

    # Child rows are removed together with the comment.
    replies: Mapped[list["WorkflowCommentReply"]] = relationship(
        "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
    )
    mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
        "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
    )

    @property
    def created_by_account(self):
        """Creator account; prefers the preloaded cache over a session lookup."""
        try:
            return self._created_by_account_cache
        except AttributeError:
            return db.session.get(Account, self.created_by)

    def cache_created_by_account(self, account: Account | None) -> None:
        """Preload the creator account to avoid extra queries."""
        self._created_by_account_cache = account

    @property
    def resolved_by_account(self):
        """Resolver account, or None while the comment is unresolved."""
        try:
            return self._resolved_by_account_cache
        except AttributeError:
            pass
        if self.resolved_by:
            return db.session.get(Account, self.resolved_by)
        return None

    def cache_resolved_by_account(self, account: Account | None) -> None:
        """Preload the resolver account to avoid extra queries."""
        self._resolved_by_account_cache = account

    @property
    def reply_count(self):
        """Number of replies (loads the relationship)."""
        return len(self.replies)

    @property
    def mention_count(self):
        """Number of mentions (loads the relationship)."""
        return len(self.mentions)

    @property
    def participants(self):
        """All distinct participants: creator, then repliers, then mentioned users.

        Accounts are fetched lazily and only for ids not yet seen, so
        preloaded caches are reused and no extra per-duplicate query runs.
        """
        seen: set[str] = set()
        collected: list[Account] = []

        def _collect(user_id, fetch_account):
            if user_id in seen:
                return
            seen.add(user_id)
            account = fetch_account()
            if account:
                collected.append(account)

        _collect(self.created_by, lambda: self.created_by_account)
        for reply in self.replies:
            _collect(reply.created_by, lambda r=reply: r.created_by_account)
        for mention in self.mentions:
            _collect(mention.mentioned_user_id, lambda m=mention: m.mentioned_user_account)

        return collected


class WorkflowCommentReply(Base):
    """A reply inside a workflow comment thread."""

    __tablename__ = "workflow_comment_replies"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
        Index("comment_replies_comment_idx", "comment_id"),
        Index("comment_replies_created_at_idx", "created_at"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
    comment_id: Mapped[str] = mapped_column(
        StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
    )
    content: Mapped[str] = mapped_column(db.Text, nullable=False)
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")

    @property
    def created_by_account(self):
        """Creator account; prefers the preloaded cache over a session lookup."""
        try:
            return self._created_by_account_cache
        except AttributeError:
            return db.session.get(Account, self.created_by)

    def cache_created_by_account(self, account: Account | None) -> None:
        """Preload the creator account to avoid extra queries."""
        self._created_by_account_cache = account


class WorkflowCommentMention(Base):
    """An @-mention on a comment or one of its replies.

    Only internal console accounts can be mentioned: end users never reach the
    workflow canvas or its commenting features.
    """

    __tablename__ = "workflow_comment_mentions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
        Index("comment_mentions_comment_idx", "comment_id"),
        Index("comment_mentions_reply_idx", "reply_id"),
        Index("comment_mentions_user_idx", "mentioned_user_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
    comment_id: Mapped[str] = mapped_column(
        StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
    )
    # Set when the mention occurred inside a reply rather than the root comment.
    reply_id: Mapped[str | None] = mapped_column(
        StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
    )
    mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
    reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")

    @property
    def mentioned_user_account(self):
        """Mentioned account; prefers the preloaded cache over a session lookup."""
        try:
            return self._mentioned_user_account_cache
        except AttributeError:
            return db.session.get(Account, self.mentioned_user_id)

    def cache_mentioned_user_account(self, account: Account | None) -> None:
        """Preload the mentioned account to avoid extra queries."""
        self._mentioned_user_account_cache = account
# TTL applied to every piece of collaboration session state.
SESSION_STATE_TTL_SECONDS = 3600
WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:"
WORKFLOW_LEADER_PREFIX = "workflow_leader:"
WS_SID_MAP_PREFIX = "ws_sid_map:"


class WorkflowSessionInfo(TypedDict):
    user_id: str
    username: str
    avatar: str | None
    sid: str
    connected_at: int


class SidMapping(TypedDict):
    workflow_id: str
    user_id: str


class WorkflowCollaborationRepository:
    """Redis-backed storage for per-workflow collaboration session state."""

    def __init__(self) -> None:
        # redis_client is imported at module load (extensions.ext_redis).
        self._redis = redis_client

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(redis_client={self._redis})"

    @staticmethod
    def workflow_key(workflow_id: str) -> str:
        """Hash of sid -> session-info JSON for one workflow."""
        return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}"

    @staticmethod
    def leader_key(workflow_id: str) -> str:
        """Plain key holding the current leader sid."""
        return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}"

    @staticmethod
    def sid_key(sid: str) -> str:
        """Reverse mapping: sid -> {workflow_id, user_id} JSON."""
        return f"{WS_SID_MAP_PREFIX}{sid}"

    @staticmethod
    def _decode(value: str | bytes | None) -> str | None:
        """Normalize redis return values (bytes or str) to str; None passes through."""
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value

    def refresh_session_state(self, workflow_id: str, sid: str) -> None:
        """Re-arm the TTL on both the room hash and the sid mapping, if present."""
        room_key = self.workflow_key(workflow_id)
        mapping_key = self.sid_key(sid)
        if self._redis.exists(room_key):
            self._redis.expire(room_key, SESSION_STATE_TTL_SECONDS)
        if self._redis.exists(mapping_key):
            self._redis.expire(mapping_key, SESSION_STATE_TTL_SECONDS)

    def set_session_info(self, workflow_id: str, session_info: WorkflowSessionInfo) -> None:
        """Store a session record in the room hash plus its reverse sid mapping."""
        sid = session_info["sid"]
        self._redis.hset(self.workflow_key(workflow_id), sid, json.dumps(session_info))
        self._redis.set(
            self.sid_key(sid),
            json.dumps({"workflow_id": workflow_id, "user_id": session_info["user_id"]}),
            ex=SESSION_STATE_TTL_SECONDS,
        )
        self.refresh_session_state(workflow_id, sid)

    def get_sid_mapping(self, sid: str) -> SidMapping | None:
        """Look up which workflow/user a sid belongs to; None on missing/bad data."""
        decoded = self._decode(self._redis.get(self.sid_key(sid)))
        if not decoded:
            return None
        try:
            return json.loads(decoded)
        except (TypeError, json.JSONDecodeError):
            return None

    def delete_session(self, workflow_id: str, sid: str) -> None:
        """Remove both the room-hash entry and the reverse sid mapping."""
        self._redis.hdel(self.workflow_key(workflow_id), sid)
        self._redis.delete(self.sid_key(sid))

    def session_exists(self, workflow_id: str, sid: str) -> bool:
        return bool(self._redis.hexists(self.workflow_key(workflow_id), sid))

    def sid_mapping_exists(self, sid: str) -> bool:
        return bool(self._redis.exists(self.sid_key(sid)))

    def get_session_sids(self, workflow_id: str) -> list[str]:
        """All sids registered in the workflow room (decoded, empties dropped)."""
        raw_sids = self._redis.hkeys(self.workflow_key(workflow_id))
        return [decoded for raw in raw_sids if (decoded := self._decode(raw))]

    def list_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]:
        """Return every well-formed session record in the workflow room.

        Entries that fail to decode, fail to parse as JSON, or lack the
        required keys are silently skipped.
        """
        raw_records = self._redis.hgetall(self.workflow_key(workflow_id))
        sessions: list[WorkflowSessionInfo] = []

        for payload in raw_records.values():
            text = self._decode(payload)
            if not text:
                continue
            try:
                record = json.loads(text)
            except (TypeError, json.JSONDecodeError):
                continue
            if not isinstance(record, dict):
                continue
            if any(key not in record for key in ("user_id", "username", "sid")):
                continue

            sessions.append(
                {
                    "user_id": str(record["user_id"]),
                    "username": str(record["username"]),
                    "avatar": record.get("avatar"),
                    "sid": str(record["sid"]),
                    "connected_at": int(record.get("connected_at") or 0),
                }
            )

        return sessions

    def get_current_leader(self, workflow_id: str) -> str | None:
        return self._decode(self._redis.get(self.leader_key(workflow_id)))

    def set_leader_if_absent(self, workflow_id: str, sid: str) -> bool:
        """Atomically claim leadership (SET NX EX); True when this sid won."""
        return bool(self._redis.set(self.leader_key(workflow_id), sid, nx=True, ex=SESSION_STATE_TTL_SECONDS))

    def set_leader(self, workflow_id: str, sid: str) -> None:
        self._redis.set(self.leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS)

    def delete_leader(self, workflow_id: str) -> None:
        self._redis.delete(self.leader_key(workflow_id))

    def expire_leader(self, workflow_id: str) -> None:
        self._redis.expire(self.leader_key(workflow_id), SESSION_STATE_TTL_SECONDS)
dify_config.MAIL_TYPE != "" diff --git a/api/services/workflow_collaboration_service.py b/api/services/workflow_collaboration_service.py new file mode 100644 index 0000000000..cf2f509052 --- /dev/null +++ b/api/services/workflow_collaboration_service.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import logging +import time +from collections.abc import Mapping + +from sqlalchemy import select + +from core.db.session_factory import session_factory +from models.account import Account +from models.model import App +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository, WorkflowSessionInfo + +logger = logging.getLogger(__name__) + + +class WorkflowCollaborationService: + def __init__(self, repository: WorkflowCollaborationRepository, socketio) -> None: + self._repository = repository + self._socketio = socketio + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(repository={self._repository})" + + def save_socket_identity(self, sid: str, user: Account) -> None: + """Persist the authenticated console user on the raw socket session.""" + self._socketio.save_session( + sid, + { + "user_id": user.id, + "username": user.name, + "avatar": user.avatar, + "tenant_id": user.current_tenant_id, + }, + ) + + def authorize_and_join_workflow_room(self, workflow_id: str, sid: str) -> tuple[str, bool] | None: + """ + Join a collaboration room only after validating the socket session and tenant-scoped app access. + + The Socket.IO payload still calls the room key `workflow_id`, but the identifier is the workflow app's + `App.id`. Returning `None` lets the controller reject the join before any Redis or room state is created. 
+ """ + session = self._socketio.get_session(sid) + user_id = session.get("user_id") + tenant_id = session.get("tenant_id") + if not user_id or not tenant_id: + return None + + if not self._can_access_workflow(workflow_id, str(tenant_id)): + logger.warning( + "Workflow collaboration join rejected: workflow_id=%s tenant_id=%s user_id=%s sid=%s", + workflow_id, + tenant_id, + user_id, + sid, + ) + return None + + session_info: WorkflowSessionInfo = { + "user_id": str(user_id), + "username": str(session.get("username", "Unknown")), + "avatar": session.get("avatar"), + "sid": sid, + "connected_at": int(time.time()), + } + + self._repository.set_session_info(workflow_id, session_info) + + leader_sid = self.get_or_set_leader(workflow_id, sid) + is_leader = leader_sid == sid + + self._socketio.enter_room(sid, workflow_id) + self.broadcast_online_users(workflow_id) + + self._socketio.emit("status", {"isLeader": is_leader}, room=sid) + + return str(user_id), is_leader + + def _can_access_workflow(self, workflow_id: str, tenant_id: str) -> bool: + """Check room access without relying on Flask's app-context-bound scoped session.""" + with session_factory.create_session() as session: + app_id = session.scalar(select(App.id).where(App.id == workflow_id, App.tenant_id == tenant_id).limit(1)) + return app_id is not None + + def disconnect_session(self, sid: str) -> None: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return + + workflow_id = mapping["workflow_id"] + self._repository.delete_session(workflow_id, sid) + + self.handle_leader_disconnect(workflow_id, sid) + self.broadcast_online_users(workflow_id) + + def relay_collaboration_event(self, sid: str, data: Mapping[str, object]) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + user_id = mapping["user_id"] + self.refresh_session_state(workflow_id, sid) + + event_type = 
data.get("type") + event_data = data.get("data") + timestamp = data.get("timestamp", int(time.time())) + + if not event_type: + return {"msg": "invalid event type"}, 400 + + if event_type == "sync_request": + leader_sid = self._repository.get_current_leader(workflow_id) + target_sid: str | None + if leader_sid and self.is_session_active(workflow_id, leader_sid): + target_sid = leader_sid + else: + if leader_sid: + self._repository.delete_leader(workflow_id) + target_sid = self._select_graph_leader(workflow_id, preferred_sid=sid) + if target_sid: + self._repository.set_leader(workflow_id, target_sid) + self.broadcast_leader_change(workflow_id, target_sid) + + if not target_sid: + return {"msg": "no_active_leader"}, 200 + + self._socketio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=target_sid, + ) + + return {"msg": "sync_request_forwarded"}, 200 + + self._socketio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=workflow_id, + skip_sid=sid, + ) + + return {"msg": "event_broadcasted"}, 200 + + def relay_graph_event(self, sid: str, data: object) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + self.refresh_session_state(workflow_id, sid) + + self._socketio.emit("graph_update", data, room=workflow_id, skip_sid=sid) + + return {"msg": "graph_update_broadcasted"}, 200 + + def get_or_set_leader(self, workflow_id: str, sid: str) -> str: + current_leader = self._repository.get_current_leader(workflow_id) + + if current_leader: + if self.is_session_active(workflow_id, current_leader): + return current_leader + self._repository.delete_session(workflow_id, current_leader) + self._repository.delete_leader(workflow_id) + + was_set = self._repository.set_leader_if_absent(workflow_id, sid) + + if 
was_set: + if current_leader: + self.broadcast_leader_change(workflow_id, sid) + return sid + + current_leader = self._repository.get_current_leader(workflow_id) + if current_leader: + return current_leader + + return sid + + def handle_leader_disconnect(self, workflow_id: str, disconnected_sid: str) -> None: + current_leader = self._repository.get_current_leader(workflow_id) + if not current_leader: + return + + if current_leader != disconnected_sid: + return + + new_leader_sid = self._select_graph_leader(workflow_id) + if new_leader_sid: + self._repository.set_leader(workflow_id, new_leader_sid) + self.broadcast_leader_change(workflow_id, new_leader_sid) + else: + self._repository.delete_leader(workflow_id) + + def broadcast_leader_change(self, workflow_id: str, new_leader_sid: str | None) -> None: + for sid in self._repository.get_session_sids(workflow_id): + try: + is_leader = new_leader_sid is not None and sid == new_leader_sid + self._socketio.emit("status", {"isLeader": is_leader}, room=sid) + except Exception: + logging.exception("Failed to emit leader status to session %s", sid) + + def get_current_leader(self, workflow_id: str) -> str | None: + return self._repository.get_current_leader(workflow_id) + + def _prune_inactive_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]: + """Remove inactive sessions from storage and return active sessions only.""" + sessions = self._repository.list_sessions(workflow_id) + if not sessions: + return [] + + active_sessions: list[WorkflowSessionInfo] = [] + stale_sids: list[str] = [] + for session in sessions: + sid = session["sid"] + if self.is_session_active(workflow_id, sid): + active_sessions.append(session) + else: + stale_sids.append(sid) + + for sid in stale_sids: + self._repository.delete_session(workflow_id, sid) + + return active_sessions + + def broadcast_online_users(self, workflow_id: str) -> None: + users = self._prune_inactive_sessions(workflow_id) + users.sort(key=lambda x: x.get("connected_at") 
or 0) + + leader_sid = self.get_current_leader(workflow_id) + previous_leader = leader_sid + active_sids = {user["sid"] for user in users} + if leader_sid and leader_sid not in active_sids: + self._repository.delete_leader(workflow_id) + leader_sid = None + + if not leader_sid and users: + leader_sid = self._select_graph_leader(workflow_id) + if leader_sid: + self._repository.set_leader(workflow_id, leader_sid) + + if leader_sid != previous_leader: + self.broadcast_leader_change(workflow_id, leader_sid) + + self._socketio.emit( + "online_users", + {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, + room=workflow_id, + ) + + def refresh_session_state(self, workflow_id: str, sid: str) -> None: + self._repository.refresh_session_state(workflow_id, sid) + self._ensure_leader(workflow_id, sid) + + def _ensure_leader(self, workflow_id: str, sid: str) -> None: + current_leader = self._repository.get_current_leader(workflow_id) + if current_leader and self.is_session_active(workflow_id, current_leader): + self._repository.expire_leader(workflow_id) + return + + if current_leader: + self._repository.delete_leader(workflow_id) + + self._repository.set_leader(workflow_id, sid) + self.broadcast_leader_change(workflow_id, sid) + + def _select_graph_leader(self, workflow_id: str, preferred_sid: str | None = None) -> str | None: + session_sids = [ + session["sid"] + for session in self._repository.list_sessions(workflow_id) + if session.get("graph_active", True) and self.is_session_active(workflow_id, session["sid"]) + ] + if not session_sids: + return None + if preferred_sid and preferred_sid in session_sids: + return preferred_sid + return session_sids[0] + + def is_session_active(self, workflow_id: str, sid: str) -> bool: + if not sid: + return False + + try: + if not self._socketio.manager.is_connected(sid, "/"): + return False + except AttributeError: + return False + + if not self._repository.session_exists(workflow_id, sid): + return False + + if not 
import logging
from collections.abc import Sequence

from sqlalchemy import desc, select
from sqlalchemy.orm import Session, selectinload
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models import App, TenantAccountJoin, WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
from models.account import Account
from tasks.mail_workflow_comment_task import send_workflow_comment_mention_email_task

logger = logging.getLogger(__name__)


class WorkflowCommentService:
    """Service for managing workflow comments, replies and @-mentions.

    All write paths validate content, scope rows to (tenant_id, app_id),
    enforce creator-only mutation, and enqueue mention notification emails
    only after the transaction has committed.
    """

    @staticmethod
    def _validate_content(content: str) -> None:
        """Reject empty (after stripping) or >1000-character bodies."""
        if len(content.strip()) == 0:
            raise ValueError("Comment content cannot be empty")

        if len(content) > 1000:
            raise ValueError("Comment content cannot exceed 1000 characters")

    @staticmethod
    def _filter_valid_mentioned_user_ids(
        mentioned_user_ids: Sequence[str], *, session: Session, tenant_id: str
    ) -> list[str]:
        """Return deduplicated UUID user IDs that belong to the tenant, preserving input order."""
        unique_user_ids: list[str] = []
        seen: set[str] = set()
        for user_id in mentioned_user_ids:
            if not isinstance(user_id, str):
                continue
            if not uuid_value(user_id):
                continue
            if user_id in seen:
                continue
            seen.add(user_id)
            unique_user_ids.append(user_id)
        if not unique_user_ids:
            return []

        # Only keep IDs that are actual members of this tenant.
        tenant_member_ids = {
            str(account_id)
            for account_id in session.scalars(
                select(TenantAccountJoin.account_id).where(
                    TenantAccountJoin.tenant_id == tenant_id,
                    TenantAccountJoin.account_id.in_(unique_user_ids),
                )
            ).all()
        }

        return [user_id for user_id in unique_user_ids if user_id in tenant_member_ids]

    @staticmethod
    def _format_comment_excerpt(content: str, max_length: int = 200) -> str:
        """Trim comment content for email display, adding an ellipsis when truncated."""
        trimmed = content.strip()
        if len(trimmed) <= max_length:
            return trimmed
        if max_length <= 3:
            # No room for the "..." suffix — hard truncate.
            return trimmed[:max_length]
        return f"{trimmed[: max_length - 3].rstrip()}..."

    @staticmethod
    def _build_mention_email_payloads(
        session: Session,
        tenant_id: str,
        app_id: str,
        mentioner_id: str,
        mentioned_user_ids: Sequence[str],
        content: str,
    ) -> list[dict[str, str]]:
        """Prepare email payloads for mentioned users, including workflow app link.

        Self-mentions are skipped; recipients without a usable email are skipped.
        """
        if not mentioned_user_ids:
            return []

        candidate_user_ids = [user_id for user_id in mentioned_user_ids if user_id != mentioner_id]
        if not candidate_user_ids:
            return []

        app_name_value = session.scalar(select(App.name).where(App.id == app_id, App.tenant_id == tenant_id))
        app_name = app_name_value if isinstance(app_name_value, str) and app_name_value else "Dify app"
        commenter_name_value = session.scalar(select(Account.name).where(Account.id == mentioner_id))
        commenter_name = (
            commenter_name_value if isinstance(commenter_name_value, str) and commenter_name_value else "Dify user"
        )
        comment_excerpt = WorkflowCommentService._format_comment_excerpt(content)
        base_url = dify_config.CONSOLE_WEB_URL.rstrip("/")
        app_url = f"{base_url}/app/{app_id}/workflow"

        # Join through TenantAccountJoin so only same-tenant accounts are mailed.
        accounts = session.scalars(
            select(Account)
            .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
            .where(TenantAccountJoin.tenant_id == tenant_id, Account.id.in_(candidate_user_ids))
        ).all()

        payloads: list[dict[str, str]] = []
        for account in accounts:
            email = account.email
            if not isinstance(email, str) or not email:
                continue
            mentioned_name = account.name if isinstance(account.name, str) and account.name else email
            language = (
                account.interface_language
                if isinstance(account.interface_language, str) and account.interface_language
                else "en-US"
            )
            payloads.append(
                {
                    "language": language,
                    "to": email,
                    "mentioned_name": mentioned_name,
                    "commenter_name": commenter_name,
                    "app_name": app_name,
                    "comment_content": comment_excerpt,
                    "app_url": app_url,
                }
            )
        return payloads

    @staticmethod
    def _dispatch_mention_emails(payloads: Sequence[dict[str, str]]) -> None:
        """Enqueue mention notification emails (Celery tasks); call after commit."""
        for payload in payloads:
            send_workflow_comment_mention_email_task.delay(**payload)

    @staticmethod
    def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]:
        """Get all comments for a workflow, newest first, with replies and mentions eager-loaded."""
        with Session(db.engine) as session:
            stmt = (
                select(WorkflowComment)
                .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
                .where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
                .order_by(desc(WorkflowComment.created_at))
            )

            comments = session.scalars(stmt).all()

            # Batch preload all Account objects to avoid N+1 queries.
            WorkflowCommentService._preload_accounts(session, comments)

            return comments

    @staticmethod
    def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None:
        """Batch preload Account objects for comments, replies, and mentions."""
        # Collect all user IDs referenced anywhere in the comment threads.
        user_ids: set[str] = set()
        for comment in comments:
            user_ids.add(comment.created_by)
            if comment.resolved_by:
                user_ids.add(comment.resolved_by)
            user_ids.update(reply.created_by for reply in comment.replies)
            user_ids.update(mention.mentioned_user_id for mention in comment.mentions)

        if not user_ids:
            return

        # One query for all accounts, then cache them on the ORM objects.
        accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
        account_map = {str(account.id): account for account in accounts}

        for comment in comments:
            comment.cache_created_by_account(account_map.get(comment.created_by))
            comment.cache_resolved_by_account(account_map.get(comment.resolved_by) if comment.resolved_by else None)
            for reply in comment.replies:
                reply.cache_created_by_account(account_map.get(reply.created_by))
            for mention in comment.mentions:
                mention.cache_mentioned_user_account(account_map.get(mention.mentioned_user_id))

    @staticmethod
    def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment:
        """Get a specific comment scoped to tenant/app.

        Raises NotFound when no matching comment exists. When *session* is None,
        a short-lived session (expire_on_commit=False) is used so the returned
        object stays readable after it closes.
        """

        def _get_comment(session: Session) -> WorkflowComment:
            stmt = (
                select(WorkflowComment)
                .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
                .where(
                    WorkflowComment.id == comment_id,
                    WorkflowComment.tenant_id == tenant_id,
                    WorkflowComment.app_id == app_id,
                )
            )
            comment = session.scalar(stmt)

            if not comment:
                raise NotFound("Comment not found")

            # Preload accounts to avoid N+1 queries.
            WorkflowCommentService._preload_accounts(session, [comment])

            return comment

        if session is not None:
            return _get_comment(session)
        else:
            with Session(db.engine, expire_on_commit=False) as session:
                return _get_comment(session)

    @staticmethod
    def create_comment(
        tenant_id: str,
        app_id: str,
        created_by: str,
        content: str,
        position_x: float,
        position_y: float,
        mentioned_user_ids: list[str] | None = None,
    ) -> dict:
        """Create a new workflow comment and send mention notification emails."""
        WorkflowCommentService._validate_content(content)

        # expire_on_commit=False for consistency with the other write paths and so
        # reading comment.id/created_at after commit does not trigger a refresh query.
        with Session(db.engine, expire_on_commit=False) as session:
            comment = WorkflowComment(
                tenant_id=tenant_id,
                app_id=app_id,
                position_x=position_x,
                position_y=position_y,
                content=content,
                created_by=created_by,
            )

            session.add(comment)
            session.flush()  # Get the comment ID for mentions

            # Create mentions for valid, same-tenant user IDs only.
            mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(
                mentioned_user_ids or [],
                session=session,
                tenant_id=tenant_id,
            )
            for user_id in mentioned_user_ids:
                mention = WorkflowCommentMention(
                    comment_id=comment.id,
                    reply_id=None,  # This is a comment mention, not reply mention
                    mentioned_user_id=user_id,
                )
                session.add(mention)

            mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
                session=session,
                tenant_id=tenant_id,
                app_id=app_id,
                mentioner_id=created_by,
                mentioned_user_ids=mentioned_user_ids,
                content=content,
            )

            session.commit()
            # Dispatch only after commit so emails never reference rolled-back rows.
            WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)

            # Return only what we need - id and created_at
            return {"id": comment.id, "created_at": comment.created_at}

    @staticmethod
    def update_comment(
        tenant_id: str,
        app_id: str,
        comment_id: str,
        user_id: str,
        content: str,
        position_x: float | None = None,
        position_y: float | None = None,
        mentioned_user_ids: list[str] | None = None,
    ) -> dict:
        """Update a workflow comment and notify newly mentioned users.

        `mentioned_user_ids=None` means "leave mentions unchanged".
        Passing an explicit list replaces the existing comment mentions, including clearing them with `[]`.
        """
        WorkflowCommentService._validate_content(content)

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(WorkflowComment).where(
                WorkflowComment.id == comment_id,
                WorkflowComment.tenant_id == tenant_id,
                WorkflowComment.app_id == app_id,
            )
            comment = session.scalar(stmt)

            if not comment:
                raise NotFound("Comment not found")

            # Only the creator can update the comment
            if comment.created_by != user_id:
                raise Forbidden("Only the comment creator can update it")

            comment.content = content
            if position_x is not None:
                comment.position_x = position_x
            if position_y is not None:
                comment.position_y = position_y

            mention_email_payloads: list[dict[str, str]] = []
            if mentioned_user_ids is not None:
                # Replace comment mentions only when the client explicitly sends the mention list.
                existing_mentions = session.scalars(
                    select(WorkflowCommentMention).where(
                        WorkflowCommentMention.comment_id == comment.id,
                        WorkflowCommentMention.reply_id.is_(None),  # Only comment mentions, not reply mentions
                    )
                ).all()
                existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions}
                for mention in existing_mentions:
                    session.delete(mention)

                filtered_mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(
                    mentioned_user_ids,
                    session=session,
                    tenant_id=tenant_id,
                )
                # Only users not already mentioned get an email.
                new_mentioned_user_ids = [
                    mentioned_user_id
                    for mentioned_user_id in filtered_mentioned_user_ids
                    if mentioned_user_id not in existing_mentioned_user_ids
                ]
                for mentioned_user_id in filtered_mentioned_user_ids:
                    mention = WorkflowCommentMention(
                        comment_id=comment.id,
                        reply_id=None,  # This is a comment mention
                        mentioned_user_id=mentioned_user_id,
                    )
                    session.add(mention)

                mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
                    session=session,
                    tenant_id=tenant_id,
                    app_id=app_id,
                    mentioner_id=user_id,
                    mentioned_user_ids=new_mentioned_user_ids,
                    content=content,
                )

            session.commit()
            WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)

            return {"id": comment.id, "updated_at": comment.updated_at}

    @staticmethod
    def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
        """Delete a workflow comment together with all its mentions and replies."""
        with Session(db.engine, expire_on_commit=False) as session:
            comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)

            # Only the creator can delete the comment
            if comment.created_by != user_id:
                raise Forbidden("Only the comment creator can delete it")

            # Delete associated mentions (both comment and reply mentions)
            mentions = session.scalars(
                select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
            ).all()
            for mention in mentions:
                session.delete(mention)

            # Delete associated replies
            replies = session.scalars(
                select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
            ).all()
            for reply in replies:
                session.delete(reply)

            session.delete(comment)
            session.commit()

    @staticmethod
    def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
        """Mark a comment resolved by *user_id*; idempotent if already resolved."""
        with Session(db.engine, expire_on_commit=False) as session:
            comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
            if comment.resolved:
                return comment

            comment.resolved = True
            comment.resolved_at = naive_utc_now()
            comment.resolved_by = user_id
            session.commit()

            return comment

    @staticmethod
    def create_reply(
        comment_id: str, content: str, created_by: str, mentioned_user_ids: list[str] | None = None
    ) -> dict:
        """Add a reply to a workflow comment and notify mentioned users."""
        WorkflowCommentService._validate_content(content)

        with Session(db.engine, expire_on_commit=False) as session:
            # Check if comment exists
            comment = session.get(WorkflowComment, comment_id)
            if not comment:
                raise NotFound("Comment not found")

            reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)

            session.add(reply)
            session.flush()  # Get the reply ID for mentions

            # Create mentions for valid, same-tenant user IDs only.
            mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(
                mentioned_user_ids or [],
                session=session,
                tenant_id=comment.tenant_id,
            )
            for mentioned_user_id in mentioned_user_ids:
                # Create mention linking to specific reply
                mention = WorkflowCommentMention(
                    comment_id=comment_id, reply_id=reply.id, mentioned_user_id=mentioned_user_id
                )
                session.add(mention)

            mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
                session=session,
                tenant_id=comment.tenant_id,
                app_id=comment.app_id,
                mentioner_id=created_by,
                mentioned_user_ids=mentioned_user_ids,
                content=content,
            )

            session.commit()
            WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)

            return {"id": reply.id, "created_at": reply.created_at}

    @staticmethod
    def _get_reply_in_comment_scope(
        *,
        session: Session,
        tenant_id: str,
        app_id: str,
        comment_id: str,
        reply_id: str,
    ) -> WorkflowCommentReply:
        """Get a reply scoped to tenant/app/comment to prevent cross-thread mutations."""
        stmt = (
            select(WorkflowCommentReply)
            .join(WorkflowComment, WorkflowComment.id == WorkflowCommentReply.comment_id)
            .where(
                WorkflowCommentReply.id == reply_id,
                WorkflowCommentReply.comment_id == comment_id,
                WorkflowComment.tenant_id == tenant_id,
                WorkflowComment.app_id == app_id,
            )
            .limit(1)
        )
        reply = session.scalar(stmt)
        if not reply:
            raise NotFound("Reply not found")
        return reply

    @staticmethod
    def update_reply(
        tenant_id: str,
        app_id: str,
        comment_id: str,
        reply_id: str,
        user_id: str,
        content: str,
        mentioned_user_ids: list[str] | None = None,
    ) -> dict:
        """Update a comment reply, replace its mentions, and notify newly mentioned users."""
        WorkflowCommentService._validate_content(content)

        with Session(db.engine, expire_on_commit=False) as session:
            reply = WorkflowCommentService._get_reply_in_comment_scope(
                session=session,
                tenant_id=tenant_id,
                app_id=app_id,
                comment_id=comment_id,
                reply_id=reply_id,
            )

            # Only the creator can update the reply
            if reply.created_by != user_id:
                raise Forbidden("Only the reply creator can update it")

            reply.content = content

            # Replace existing reply mentions.
            existing_mentions = session.scalars(
                select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
            ).all()
            existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions}
            for mention in existing_mentions:
                session.delete(mention)

            # NOTE: distinct names (mentioned_user_id vs the actor's user_id) — the
            # original shadowed the actor parameter inside these loops/comprehensions.
            comment = session.get(WorkflowComment, reply.comment_id)
            filtered_mentioned_user_ids: list[str] = []
            if comment:
                filtered_mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(
                    mentioned_user_ids or [],
                    session=session,
                    tenant_id=comment.tenant_id,
                )
            new_mentioned_user_ids = [
                mentioned_user_id
                for mentioned_user_id in filtered_mentioned_user_ids
                if mentioned_user_id not in existing_mentioned_user_ids
            ]
            for mentioned_user_id in filtered_mentioned_user_ids:
                mention = WorkflowCommentMention(
                    comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=mentioned_user_id
                )
                session.add(mention)

            mention_email_payloads: list[dict[str, str]] = []
            if comment:
                mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
                    session=session,
                    tenant_id=comment.tenant_id,
                    app_id=comment.app_id,
                    mentioner_id=user_id,
                    mentioned_user_ids=new_mentioned_user_ids,
                    content=content,
                )

            session.commit()
            session.refresh(reply)  # Refresh to get updated timestamp
            WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)

            return {"id": reply.id, "updated_at": reply.updated_at}

    @staticmethod
    def delete_reply(tenant_id: str, app_id: str, comment_id: str, reply_id: str, user_id: str) -> None:
        """Delete a comment reply and its mentions; creator-only."""
        with Session(db.engine, expire_on_commit=False) as session:
            reply = WorkflowCommentService._get_reply_in_comment_scope(
                session=session,
                tenant_id=tenant_id,
                app_id=app_id,
                comment_id=comment_id,
                reply_id=reply_id,
            )

            # Only the creator can delete the reply
            if reply.created_by != user_id:
                raise Forbidden("Only the reply creator can delete it")

            # Delete associated mentions first
            mentions = session.scalars(
                select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
            ).all()
            for mention in mentions:
                session.delete(mention)

            session.delete(reply)
            session.commit()

    @staticmethod
    def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
        """Validate that a comment belongs to the specified tenant and app."""
        return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)
+ """ + if not app_ids: + return set() + + stmt = select(App.id).where(App.id.in_(app_ids), App.tenant_id == tenant_id) + return {str(app_id) for app_id in db.session.scalars(stmt).all()} + def get_all_published_workflow( self, *, @@ -296,6 +306,78 @@ class WorkflowService: # return draft workflow return workflow + def update_draft_workflow_environment_variables( + self, + *, + app_model: App, + environment_variables: Sequence[VariableBase], + account: Account, + ): + """ + Update draft workflow environment variables + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + workflow.environment_variables = environment_variables + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + + def update_draft_workflow_conversation_variables( + self, + *, + app_model: App, + conversation_variables: Sequence[VariableBase], + account: Account, + ): + """ + Update draft workflow conversation variables + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + workflow.conversation_variables = conversation_variables + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + + def update_draft_workflow_features( + self, + *, + app_model: App, + features: dict, + account: Account, + ): + """ + Update draft workflow features + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + # validate features structure + self.validate_features_structure(app_model=app_model, features=features) + + workflow.features = json.dumps(features) + workflow.updated_by = account.id + workflow.updated_at = 
import logging
import time

import click
from celery import shared_task

from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service

logger = logging.getLogger(__name__)


@shared_task(queue="mail")
def send_workflow_comment_mention_email_task(
    language: str,
    to: str,
    mentioned_name: str,
    commenter_name: str,
    app_name: str,
    comment_content: str,
    app_url: str,
):
    """Send a localized "you were mentioned in a workflow comment" email.

    Args:
        language: Language code for email localization.
        to: Recipient email address.
        mentioned_name: Name of the mentioned user.
        commenter_name: Name of the comment author.
        app_name: Name of the app where the comment was made.
        comment_content: Comment content excerpt.
        app_url: Link to the app workflow page.

    No-ops when the mail extension is not initialized; failures are logged,
    never raised, so the Celery task does not retry-storm.
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start workflow comment mention mail to {to}", fg="green"))
    start_at = time.perf_counter()

    template_context = {
        "to": to,
        "mentioned_name": mentioned_name,
        "commenter_name": commenter_name,
        "app_name": app_name,
        "comment_content": comment_content,
        "app_url": app_url,
    }

    try:
        get_email_i18n_service().send_email(
            email_type=EmailType.WORKFLOW_COMMENT_MENTION,
            language_code=language,
            to=to,
            template_context=template_context,
        )
    except Exception:
        logger.exception("workflow comment mention email to %s failed", to)
    else:
        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Send workflow comment mention mail to {to} succeeded: latency: {end_at - start_at}",
                fg="green",
            )
        )
a/api/templates/without-brand/workflow_comment_mention_template_en-US.html b/api/templates/without-brand/workflow_comment_mention_template_en-US.html new file mode 100644 index 0000000000..1ef8fe4e3f --- /dev/null +++ b/api/templates/without-brand/workflow_comment_mention_template_en-US.html @@ -0,0 +1,119 @@ + + + + + + + + +
+
+ Dify Logo +
+

You were mentioned in a workflow comment

+
+

Hi {{ mentioned_name }},

+

{{ commenter_name }} mentioned you in {{ app_name }}.

+
+
+

{{ comment_content }}

+
+

Open {{ application_title }} to reply to the comment.

+
+ + + diff --git a/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html b/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html new file mode 100644 index 0000000000..8b9b2dbe71 --- /dev/null +++ b/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html @@ -0,0 +1,119 @@ + + + + + + + + +
+
+ Dify Logo +
+

你在工作流评论中被提及

+
+

你好,{{ mentioned_name }}:

+

{{ commenter_name }} 在 {{ app_name }} 中提及了你。

+
+
+

{{ comment_content }}

+
+

请在 {{ application_title }} 中查看并回复此评论。

+
+ + + diff --git a/api/templates/workflow_comment_mention_template_en-US.html b/api/templates/workflow_comment_mention_template_en-US.html new file mode 100644 index 0000000000..1ef8fe4e3f --- /dev/null +++ b/api/templates/workflow_comment_mention_template_en-US.html @@ -0,0 +1,119 @@ + + + + + + + + +
+
+ Dify Logo +
+

You were mentioned in a workflow comment

+
+

Hi {{ mentioned_name }},

+

{{ commenter_name }} mentioned you in {{ app_name }}.

+
+
+

{{ comment_content }}

+
+

Open {{ application_title }} to reply to the comment.

+
+ + + diff --git a/api/templates/workflow_comment_mention_template_zh-CN.html b/api/templates/workflow_comment_mention_template_zh-CN.html new file mode 100644 index 0000000000..8b9b2dbe71 --- /dev/null +++ b/api/templates/workflow_comment_mention_template_zh-CN.html @@ -0,0 +1,119 @@ + + + + + + + + +
+
+ Dify Logo +
+

你在工作流评论中被提及

+
+

你好,{{ mentioned_name }}:

+

{{ commenter_name }} 在 {{ app_name }} 中提及了你。

+
+
+

{{ comment_content }}

+
+

请在 {{ application_title }} 中查看并回复此评论。

+
+ + + diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index b2e8dda443..09078d196d 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -48,7 +48,7 @@ os.environ["OPENDAL_FS_ROOT"] = "/tmp/dify-storage" os.environ.setdefault("STORAGE_TYPE", "opendal") os.environ.setdefault("OPENDAL_SCHEME", "fs") -_CACHED_APP = create_app() +_SIO_APP, _CACHED_APP = create_app() @pytest.fixture(scope="session") diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index ef74893f07..66a25e5daf 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -369,7 +369,7 @@ def _create_app_with_containers() -> Flask: # Create and configure the Flask application logger.info("Initializing Flask application...") - app = create_app() + sio_app, app = create_app() logger.info("Flask application created successfully") # Initialize database schema diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index b3e7dd2a59..315936d721 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -274,6 +274,7 @@ class TestFeatureService: mock_config.ENABLE_EMAIL_CODE_LOGIN = True mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ENABLE_COLLABORATION_MODE = True mock_config.ALLOW_REGISTER = False mock_config.ALLOW_CREATE_WORKSPACE = False mock_config.MAIL_TYPE = "smtp" @@ -298,6 +299,7 @@ class TestFeatureService: # Verify authentication settings assert result.enable_email_code_login is True assert result.enable_email_password_login is False + assert result.enable_collaboration_mode 
is True assert result.is_allow_register is False assert result.is_allow_create_workspace is False @@ -401,6 +403,7 @@ class TestFeatureService: mock_config.ENABLE_EMAIL_CODE_LOGIN = True mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ENABLE_COLLABORATION_MODE = False mock_config.ALLOW_REGISTER = True mock_config.ALLOW_CREATE_WORKSPACE = True mock_config.MAIL_TYPE = "smtp" @@ -422,6 +425,7 @@ class TestFeatureService: assert result.enable_email_code_login is True assert result.enable_email_password_login is True assert result.enable_social_oauth_login is False + assert result.enable_collaboration_mode is False assert result.is_allow_register is True assert result.is_allow_create_workspace is True assert result.is_email_setup is True diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index f32d0ef0ec..9f20886a81 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from datetime import datetime from types import SimpleNamespace from unittest.mock import Mock @@ -347,3 +348,87 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk ): with pytest.raises(NotFound): handler(api, app_model=SimpleNamespace(id="app")) + + +def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: pytest.MonkeyPatch) -> None: + app_id_1 = "11111111-1111-1111-1111-111111111111" + app_id_2 = "22222222-2222-2222-2222-222222222222" + signed_avatar_url = "https://files.example.com/signed/avatar-1" + sign_avatar = Mock(return_value=signed_avatar_url) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: 
SimpleNamespace(get_accessible_app_ids=lambda app_ids, tenant_id: {app_id_1}), + ) + monkeypatch.setattr(workflow_module.file_helpers, "get_signed_file_url", sign_avatar) + + workflow_module.redis_client.hgetall.side_effect = lambda key: ( + { + b"sid-1": json.dumps( + { + "user_id": "u-1", + "username": "Alice", + "avatar": "avatar-file-id", + "sid": "sid-1", + } + ) + } + if key == f"{workflow_module.WORKFLOW_ONLINE_USERS_PREFIX}{app_id_1}" + else {} + ) + + api = workflow_module.WorkflowOnlineUsersApi() + handler = _unwrap(api.get) + + with app.test_request_context( + f"/apps/workflows/online-users?app_ids={app_id_1},{app_id_2}", + method="GET", + ): + response = handler(api) + + assert response == { + "data": [ + { + "app_id": app_id_1, + "users": [ + { + "user_id": "u-1", + "username": "Alice", + "avatar": signed_avatar_url, + "sid": "sid-1", + } + ], + } + ] + } + workflow_module.redis_client.hgetall.assert_called_once_with( + f"{workflow_module.WORKFLOW_ONLINE_USERS_PREFIX}{app_id_1}" + ) + sign_avatar.assert_called_once_with("avatar-file-id") + + +def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) + accessible_app_ids = Mock(return_value=set()) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace(get_accessible_app_ids=accessible_app_ids), + ) + + excessive_ids = ",".join(f"wf-{index}" for index in range(workflow_module.MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS + 1)) + + api = workflow_module.WorkflowOnlineUsersApi() + handler = _unwrap(api.get) + + with app.test_request_context( + f"/apps/workflows/online-users?app_ids={excessive_ids}", + method="GET", + ): + with pytest.raises(HTTPException) as exc: + handler(api) + + assert exc.value.code == 400 + assert "Maximum" in exc.value.description + accessible_app_ids.assert_not_called() diff --git 
a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py new file mode 100644 index 0000000000..85afcf0e60 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from contextlib import nullcontext +from dataclasses import dataclass +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from controllers.console import console_ns +from controllers.console import wraps as console_wraps +from controllers.console.app import workflow_comment as workflow_comment_module +from controllers.console.app import wraps as app_wraps +from libs import login as login_lib +from models.account import Account, AccountStatus, TenantAccountRole + + +def _make_account(role: TenantAccountRole) -> Account: + account = Account(name="tester", email="tester@example.com") + account.status = AccountStatus.ACTIVE + account.role = role + account.id = "account-123" # type: ignore[assignment] + account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined] + account._get_current_object = lambda: account # type: ignore[attr-defined] + return account + + +def _make_app() -> SimpleNamespace: + return SimpleNamespace(id="app-123", tenant_id="tenant-123", status="normal", mode="workflow") + + +def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None: + monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True) + monkeypatch.setattr(login_lib, "current_user", account) + monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None) + 
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") + monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) + monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model) + monkeypatch.setattr(workflow_comment_module, "current_user", account) + + +def _patch_write_services(monkeypatch: pytest.MonkeyPatch) -> None: + for method_name in ( + "create_comment", + "update_comment", + "delete_comment", + "resolve_comment", + "validate_comment_access", + "create_reply", + "update_reply", + "delete_reply", + ): + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, method_name, MagicMock()) + + +def _patch_payload(payload: dict[str, object] | None): + if payload is None: + return nullcontext() + return patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + +@dataclass(frozen=True) +class WriteCase: + resource_cls: type + method_name: str + path: str + kwargs: dict[str, str] + payload: dict[str, object] | None = None + + +@pytest.mark.parametrize( + "case", + [ + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentListApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments", + kwargs={"app_id": "app-123"}, + payload={"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentDetailApi, + method_name="put", + path="/console/api/apps/app-123/workflow/comments/comment-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + payload={"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentDetailApi, + method_name="delete", + 
path="/console/api/apps/app-123/workflow/comments/comment-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentResolveApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments/comment-1/resolve", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyApi, + method_name="post", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies", + kwargs={"app_id": "app-123", "comment_id": "comment-1"}, + payload={"content": "reply", "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi, + method_name="put", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"}, + payload={"content": "reply", "mentioned_user_ids": []}, + ), + WriteCase( + resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi, + method_name="delete", + path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1", + kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"}, + ), + ], +) +def test_write_endpoints_require_edit_permission(app: Flask, monkeypatch: pytest.MonkeyPatch, case: WriteCase) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.NORMAL) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + _patch_write_services(monkeypatch) + + with app.test_request_context(case.path, method=case.method_name.upper(), json=case.payload): + with _patch_payload(case.payload): + handler = getattr(case.resource_cls(), case.method_name) + with pytest.raises(Forbidden): + handler(**case.kwargs) + + +def test_create_comment_allows_editor(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + 
app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.EDITOR) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + + create_comment_mock = MagicMock(return_value={"id": "comment-1"}) + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "create_comment", create_comment_mock) + payload = {"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []} + + with app.test_request_context("/console/api/apps/app-123/workflow/comments", method="POST", json=payload): + with _patch_payload(payload): + result = workflow_comment_module.WorkflowCommentListApi().post(app_id="app-123") + + if isinstance(result, tuple): + response = result[0] + else: + response = result + assert response["id"] == "comment-1" + create_comment_mock.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + created_by="account-123", + content="hello", + position_x=1.0, + position_y=2.0, + mentioned_user_ids=[], + ) + + +def test_update_comment_omits_mentions_when_payload_does_not_include_them( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: + app.config.setdefault("RESTX_MASK_HEADER", "X-Fields") + account = _make_account(TenantAccountRole.EDITOR) + app_model = _make_app() + _patch_console_guards(monkeypatch, account, app_model) + + update_comment_mock = MagicMock(return_value={"id": "comment-1", "updated_at": datetime(2024, 1, 1, 12, 0, 0)}) + monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "update_comment", update_comment_mock) + payload = {"content": "hello", "position_x": 10.0, "position_y": 20.0} + + with app.test_request_context("/console/api/apps/app-123/workflow/comments/comment-1", method="PUT", json=payload): + with _patch_payload(payload): + workflow_comment_module.WorkflowCommentDetailApi().put(app_id="app-123", comment_id="comment-1") + + update_comment_mock.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + 
comment_id="comment-1", + user_id="account-123", + content="hello", + position_x=10.0, + position_y=20.0, + mentioned_user_ids=None, + ) diff --git a/api/tests/unit_tests/libs/test_email_i18n.py b/api/tests/unit_tests/libs/test_email_i18n.py index 962a36fe03..b4c0eaf7ee 100644 --- a/api/tests/unit_tests/libs/test_email_i18n.py +++ b/api/tests/unit_tests/libs/test_email_i18n.py @@ -503,6 +503,7 @@ class TestEmailI18nIntegration: EmailType.ACCOUNT_DELETION_VERIFICATION, EmailType.QUEUE_MONITOR_ALERT, EmailType.DOCUMENT_CLEAN_NOTIFY, + EmailType.WORKFLOW_COMMENT_MENTION, ] for email_type in expected_types: diff --git a/api/tests/unit_tests/models/test_comment_models.py b/api/tests/unit_tests/models/test_comment_models.py new file mode 100644 index 0000000000..277335cbef --- /dev/null +++ b/api/tests/unit_tests/models/test_comment_models.py @@ -0,0 +1,100 @@ +from unittest.mock import Mock, patch + +from models.comment import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply + + +def test_workflow_comment_account_properties_and_cache() -> None: + comment = WorkflowComment(created_by="user-1", resolved_by="user-2", content="hello", position_x=1, position_y=2) + created_account = Mock(id="user-1") + resolved_account = Mock(id="user-2") + + with patch("models.comment.db.session.get", side_effect=[created_account, resolved_account]) as get_mock: + assert comment.created_by_account is created_account + assert comment.resolved_by_account is resolved_account + assert get_mock.call_count == 2 + + comment.cache_created_by_account(created_account) + comment.cache_resolved_by_account(resolved_account) + with patch("models.comment.db.session.get") as get_mock: + assert comment.created_by_account is created_account + assert comment.resolved_by_account is resolved_account + get_mock.assert_not_called() + + comment_without_resolver = WorkflowComment( + created_by="user-1", + resolved_by=None, + content="hello", + position_x=1, + position_y=2, + ) + with 
patch("models.comment.db.session.get") as get_mock: + assert comment_without_resolver.resolved_by_account is None + get_mock.assert_not_called() + + +def test_workflow_comment_counts_and_participants() -> None: + reply_1 = WorkflowCommentReply(comment_id="comment-1", content="reply-1", created_by="user-2") + reply_2 = WorkflowCommentReply(comment_id="comment-1", content="reply-2", created_by="user-2") + mention_1 = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-3") + mention_2 = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-4") + comment = WorkflowComment(created_by="user-1", resolved_by=None, content="hello", position_x=1, position_y=2) + comment.replies = [reply_1, reply_2] + comment.mentions = [mention_1, mention_2] + + account_1 = Mock(id="user-1") + account_2 = Mock(id="user-2") + account_3 = Mock(id="user-3") + account_map = { + "user-1": account_1, + "user-2": account_2, + "user-3": account_3, + "user-4": None, + } + + with patch("models.comment.db.session.get", side_effect=lambda _model, user_id: account_map[user_id]) as get_mock: + participants = comment.participants + + assert comment.reply_count == 2 + assert comment.mention_count == 2 + assert set(participants) == {account_1, account_2, account_3} + assert get_mock.call_count == 4 + + +def test_workflow_comment_participants_use_cached_accounts() -> None: + reply = WorkflowCommentReply(comment_id="comment-1", content="reply-1", created_by="user-2") + mention = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-3") + comment = WorkflowComment(created_by="user-1", resolved_by=None, content="hello", position_x=1, position_y=2) + comment.replies = [reply] + comment.mentions = [mention] + + account_1 = Mock(id="user-1") + account_2 = Mock(id="user-2") + account_3 = Mock(id="user-3") + comment.cache_created_by_account(account_1) + reply.cache_created_by_account(account_2) + mention.cache_mentioned_user_account(account_3) + + with 
patch("models.comment.db.session.get") as get_mock: + participants = comment.participants + + assert set(participants) == {account_1, account_2, account_3} + get_mock.assert_not_called() + + +def test_reply_and_mention_account_properties_and_cache() -> None: + reply = WorkflowCommentReply(comment_id="comment-1", content="reply", created_by="user-1") + mention = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-2") + reply_account = Mock(id="user-1") + mention_account = Mock(id="user-2") + + with patch("models.comment.db.session.get", side_effect=[reply_account, mention_account]) as get_mock: + assert reply.created_by_account is reply_account + assert mention.mentioned_user_account is mention_account + assert get_mock.call_count == 2 + + reply.cache_created_by_account(reply_account) + mention.cache_mentioned_user_account(mention_account) + with patch("models.comment.db.session.get") as get_mock: + assert reply.created_by_account is reply_account + assert mention.mentioned_user_account is mention_account + get_mock.assert_not_called() diff --git a/api/tests/unit_tests/repositories/test_workflow_collaboration_repository.py b/api/tests/unit_tests/repositories/test_workflow_collaboration_repository.py new file mode 100644 index 0000000000..1f47e8b692 --- /dev/null +++ b/api/tests/unit_tests/repositories/test_workflow_collaboration_repository.py @@ -0,0 +1,121 @@ +import json +from unittest.mock import Mock + +import pytest + +from repositories import workflow_collaboration_repository as repo_module +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository + + +class TestWorkflowCollaborationRepository: + @pytest.fixture + def mock_redis(self, monkeypatch: pytest.MonkeyPatch) -> Mock: + mock_redis = Mock() + monkeypatch.setattr(repo_module, "redis_client", mock_redis) + return mock_redis + + def test_get_sid_mapping_returns_mapping(self, mock_redis: Mock) -> None: + # Arrange + mock_redis.get.return_value = 
b'{"workflow_id":"wf-1","user_id":"u-1"}' + repository = WorkflowCollaborationRepository() + + # Act + result = repository.get_sid_mapping("sid-1") + + # Assert + assert result == {"workflow_id": "wf-1", "user_id": "u-1"} + + def test_list_sessions_filters_invalid_entries(self, mock_redis: Mock) -> None: + # Arrange + mock_redis.hgetall.return_value = { + b"sid-1": b'{"user_id":"u-1","username":"Jane","sid":"sid-1","connected_at":2}', + b"sid-2": b'{"username":"Missing","sid":"sid-2"}', + b"sid-3": b"not-json", + } + repository = WorkflowCollaborationRepository() + + # Act + result = repository.list_sessions("wf-1") + + # Assert + assert result == [ + { + "user_id": "u-1", + "username": "Jane", + "avatar": None, + "sid": "sid-1", + "connected_at": 2, + } + ] + + def test_set_session_info_persists_payload(self, mock_redis: Mock) -> None: + # Arrange + mock_redis.exists.return_value = True + repository = WorkflowCollaborationRepository() + payload = { + "user_id": "u-1", + "username": "Jane", + "avatar": None, + "sid": "sid-1", + "connected_at": 1, + } + + # Act + repository.set_session_info("wf-1", payload) + + # Assert + assert mock_redis.hset.called + workflow_key, sid, session_json = mock_redis.hset.call_args.args + assert workflow_key == "workflow_online_users:wf-1" + assert sid == "sid-1" + assert json.loads(session_json)["user_id"] == "u-1" + assert mock_redis.set.called + + def test_refresh_session_state_expires_keys(self, mock_redis: Mock) -> None: + # Arrange + mock_redis.exists.return_value = True + repository = WorkflowCollaborationRepository() + + # Act + repository.refresh_session_state("wf-1", "sid-1") + + # Assert + assert mock_redis.expire.call_count == 2 + + def test_get_current_leader_decodes_bytes(self, mock_redis: Mock) -> None: + # Arrange + mock_redis.get.return_value = b"sid-1" + repository = WorkflowCollaborationRepository() + + # Act + result = repository.get_current_leader("wf-1") + + # Assert + assert result == "sid-1" + + def 
test_set_leader_if_absent_uses_nx(self, mock_redis: Mock) -> None: + # Arrange + mock_redis.set.return_value = True + repository = WorkflowCollaborationRepository() + + # Act + result = repository.set_leader_if_absent("wf-1", "sid-1") + + # Assert + assert result is True + _key, _value = mock_redis.set.call_args.args + assert _key == "workflow_leader:wf-1" + assert _value == "sid-1" + assert mock_redis.set.call_args.kwargs["nx"] is True + assert "ex" in mock_redis.set.call_args.kwargs + + def test_get_session_sids_decodes(self, mock_redis: Mock) -> None: + # Arrange + mock_redis.hkeys.return_value = [b"sid-1", "sid-2"] + repository = WorkflowCollaborationRepository() + + # Act + result = repository.get_session_sids("wf-1") + + # Assert + assert result == ["sid-1", "sid-2"] diff --git a/api/tests/unit_tests/services/test_workflow_collaboration_service.py b/api/tests/unit_tests/services/test_workflow_collaboration_service.py new file mode 100644 index 0000000000..8a6addfece --- /dev/null +++ b/api/tests/unit_tests/services/test_workflow_collaboration_service.py @@ -0,0 +1,608 @@ +from unittest.mock import Mock, patch + +import pytest + +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository +from services.workflow_collaboration_service import WorkflowCollaborationService + + +class TestWorkflowCollaborationService: + @pytest.fixture + def service(self) -> tuple[WorkflowCollaborationService, Mock, Mock]: + repository = Mock(spec=WorkflowCollaborationRepository) + socketio = Mock() + return WorkflowCollaborationService(repository, socketio), repository, socketio + + def test_authorize_and_join_workflow_room_returns_leader_status( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, repository, socketio = service + socketio.get_session.return_value = { + "user_id": "u-1", + "username": "Jane", + "avatar": None, + "tenant_id": "t-1", + } + + with ( + 
patch.object(collaboration_service, "_can_access_workflow", return_value=True), + patch.object(collaboration_service, "get_or_set_leader", return_value="sid-1"), + patch.object(collaboration_service, "broadcast_online_users"), + ): + # Act + result = collaboration_service.authorize_and_join_workflow_room("wf-1", "sid-1") + + # Assert + assert result == ("u-1", True) + repository.set_session_info.assert_called_once() + socketio.enter_room.assert_called_once_with("sid-1", "wf-1") + socketio.emit.assert_called_once_with("status", {"isLeader": True}, room="sid-1") + + def test_authorize_and_join_workflow_room_returns_none_when_missing_user( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, _repository, socketio = service + socketio.get_session.return_value = {} + + # Act + result = collaboration_service.authorize_and_join_workflow_room("wf-1", "sid-1") + + # Assert + assert result is None + + def test_authorize_and_join_workflow_room_returns_none_when_missing_tenant( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, socketio = service + socketio.get_session.return_value = {"user_id": "u-1", "username": "Jane", "avatar": None} + + result = collaboration_service.authorize_and_join_workflow_room("wf-1", "sid-1") + + assert result is None + repository.set_session_info.assert_not_called() + socketio.enter_room.assert_not_called() + socketio.emit.assert_not_called() + + def test_authorize_and_join_workflow_room_returns_none_when_workflow_is_not_accessible( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, socketio = service + socketio.get_session.return_value = { + "user_id": "u-1", + "username": "Jane", + "avatar": None, + "tenant_id": "t-1", + } + + with patch.object(collaboration_service, "_can_access_workflow", return_value=False): + result = 
collaboration_service.authorize_and_join_workflow_room("wf-1", "sid-1") + + assert result is None + repository.set_session_info.assert_not_called() + socketio.enter_room.assert_not_called() + socketio.emit.assert_not_called() + + def test_repr_and_save_socket_identity(self, service: tuple[WorkflowCollaborationService, Mock, Mock]) -> None: + collaboration_service, _repository, socketio = service + user = Mock() + user.id = "u-1" + user.name = "Jane" + user.avatar = "avatar.png" + user.current_tenant_id = "t-1" + + assert "WorkflowCollaborationService" in repr(collaboration_service) + + collaboration_service.save_socket_identity("sid-1", user) + + socketio.save_session.assert_called_once_with( + "sid-1", + {"user_id": "u-1", "username": "Jane", "avatar": "avatar.png", "tenant_id": "t-1"}, + ) + + def test_can_access_workflow_uses_session_factory( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, _repository, _socketio = service + session = Mock() + session.scalar.return_value = "wf-1" + session_context = Mock() + session_context.__enter__ = Mock(return_value=session) + session_context.__exit__ = Mock(return_value=False) + + with patch( + "services.workflow_collaboration_service.session_factory.create_session", + return_value=session_context, + ): + result = collaboration_service._can_access_workflow("wf-1", "tenant-1") + + assert result is True + session.scalar.assert_called_once() + + def test_relay_collaboration_event_unauthorized( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_sid_mapping.return_value = None + + # Act + result = collaboration_service.relay_collaboration_event("sid-1", {}) + + # Assert + assert result == ({"msg": "unauthorized"}, 401) + + def test_relay_collaboration_event_emits_update( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + 
collaboration_service, repository, socketio = service + repository.get_sid_mapping.return_value = {"workflow_id": "wf-1", "user_id": "u-1"} + payload = {"type": "mouse_move", "data": {"x": 1}, "timestamp": 123} + + # Act + result = collaboration_service.relay_collaboration_event("sid-1", payload) + + # Assert + assert result == ({"msg": "event_broadcasted"}, 200) + socketio.emit.assert_called_once_with( + "collaboration_update", + {"type": "mouse_move", "userId": "u-1", "data": {"x": 1}, "timestamp": 123}, + room="wf-1", + skip_sid="sid-1", + ) + + def test_relay_collaboration_event_requires_event_type( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, _socketio = service + repository.get_sid_mapping.return_value = {"workflow_id": "wf-1", "user_id": "u-1"} + + result = collaboration_service.relay_collaboration_event("sid-1", {"data": {"x": 1}}) + + assert result == ({"msg": "invalid event type"}, 400) + + def test_relay_collaboration_event_sync_request_forwards_to_active_leader( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, socketio = service + repository.get_sid_mapping.return_value = {"workflow_id": "wf-1", "user_id": "u-1"} + repository.get_current_leader.return_value = "sid-leader" + payload = {"type": "sync_request", "data": {"reason": "join"}, "timestamp": 123} + + with ( + patch.object(collaboration_service, "refresh_session_state"), + patch.object(collaboration_service, "is_session_active", return_value=True), + ): + result = collaboration_service.relay_collaboration_event("sid-1", payload) + + assert result == ({"msg": "sync_request_forwarded"}, 200) + socketio.emit.assert_called_once_with( + "collaboration_update", + {"type": "sync_request", "userId": "u-1", "data": {"reason": "join"}, "timestamp": 123}, + room="sid-leader", + ) + repository.set_leader.assert_not_called() + + def 
test_relay_collaboration_event_sync_request_reelects_active_leader( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, socketio = service + repository.get_sid_mapping.return_value = {"workflow_id": "wf-1", "user_id": "u-1"} + repository.get_current_leader.return_value = "sid-old" + repository.list_sessions.return_value = [ + { + "user_id": "u-2", + "username": "B", + "avatar": None, + "sid": "sid-2", + "connected_at": 1, + "graph_active": True, + }, + { + "user_id": "u-3", + "username": "C", + "avatar": None, + "sid": "sid-3", + "connected_at": 2, + "graph_active": True, + }, + ] + payload = {"type": "sync_request", "data": {"reason": "join"}, "timestamp": 123} + + def _is_session_active(_workflow_id: str, session_sid: str) -> bool: + return session_sid != "sid-old" + + with ( + patch.object(collaboration_service, "refresh_session_state"), + patch.object(collaboration_service, "broadcast_leader_change") as broadcast_leader_change, + patch.object(collaboration_service, "is_session_active", side_effect=_is_session_active), + ): + result = collaboration_service.relay_collaboration_event("sid-2", payload) + + assert result == ({"msg": "sync_request_forwarded"}, 200) + repository.delete_leader.assert_called_once_with("wf-1") + repository.set_leader.assert_called_once_with("wf-1", "sid-2") + broadcast_leader_change.assert_called_once_with("wf-1", "sid-2") + socketio.emit.assert_called_once_with( + "collaboration_update", + {"type": "sync_request", "userId": "u-1", "data": {"reason": "join"}, "timestamp": 123}, + room="sid-2", + ) + + def test_relay_collaboration_event_sync_request_returns_when_no_active_leader( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, socketio = service + repository.get_sid_mapping.return_value = {"workflow_id": "wf-1", "user_id": "u-1"} + repository.get_current_leader.return_value = "sid-old" + 
repository.list_sessions.return_value = [] + payload = {"type": "sync_request", "data": {"reason": "join"}, "timestamp": 123} + + with ( + patch.object(collaboration_service, "refresh_session_state"), + patch.object(collaboration_service, "is_session_active", return_value=False), + ): + result = collaboration_service.relay_collaboration_event("sid-2", payload) + + assert result == ({"msg": "no_active_leader"}, 200) + repository.delete_leader.assert_called_once_with("wf-1") + socketio.emit.assert_not_called() + + def test_relay_graph_event_unauthorized(self, service: tuple[WorkflowCollaborationService, Mock, Mock]) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_sid_mapping.return_value = None + + # Act + result = collaboration_service.relay_graph_event("sid-1", {"nodes": []}) + + # Assert + assert result == ({"msg": "unauthorized"}, 401) + + def test_disconnect_session_no_mapping(self, service: tuple[WorkflowCollaborationService, Mock, Mock]) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_sid_mapping.return_value = None + + # Act + collaboration_service.disconnect_session("sid-1") + + # Assert + repository.delete_session.assert_not_called() + + def test_disconnect_session_cleans_up(self, service: tuple[WorkflowCollaborationService, Mock, Mock]) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_sid_mapping.return_value = {"workflow_id": "wf-1", "user_id": "u-1"} + + with ( + patch.object(collaboration_service, "handle_leader_disconnect") as handle_leader_disconnect, + patch.object(collaboration_service, "broadcast_online_users") as broadcast_online_users, + ): + # Act + collaboration_service.disconnect_session("sid-1") + + # Assert + repository.delete_session.assert_called_once_with("wf-1", "sid-1") + handle_leader_disconnect.assert_called_once_with("wf-1", "sid-1") + broadcast_online_users.assert_called_once_with("wf-1") + + 
def test_get_or_set_leader_returns_active_leader( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_current_leader.return_value = "sid-1" + + with patch.object(collaboration_service, "is_session_active", return_value=True): + # Act + result = collaboration_service.get_or_set_leader("wf-1", "sid-2") + + # Assert + assert result == "sid-1" + repository.set_leader_if_absent.assert_not_called() + + def test_get_or_set_leader_replaces_dead_leader( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_current_leader.return_value = "sid-1" + repository.set_leader_if_absent.return_value = True + repository.list_sessions.return_value = [ + { + "user_id": "u-2", + "username": "B", + "avatar": None, + "sid": "sid-2", + "connected_at": 1, + "graph_active": True, + } + ] + + with ( + patch.object(collaboration_service, "is_session_active", side_effect=lambda _wf, sid: sid != "sid-1"), + patch.object(collaboration_service, "broadcast_leader_change") as broadcast_leader_change, + ): + # Act + result = collaboration_service.get_or_set_leader("wf-1", "sid-2") + + # Assert + assert result == "sid-2" + repository.delete_session.assert_called_once_with("wf-1", "sid-1") + repository.delete_leader.assert_called_once_with("wf-1") + broadcast_leader_change.assert_called_once_with("wf-1", "sid-2") + + def test_get_or_set_leader_falls_back_to_existing( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_current_leader.side_effect = [None, "sid-3"] + repository.set_leader_if_absent.return_value = False + repository.list_sessions.return_value = [ + { + "user_id": "u-2", + "username": "B", + "avatar": None, + "sid": "sid-2", + "connected_at": 1, + "graph_active": 
True, + } + ] + + # Act + result = collaboration_service.get_or_set_leader("wf-1", "sid-2") + + # Assert + assert result == "sid-3" + + def test_get_or_set_leader_returns_sid_when_leader_still_missing( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, _socketio = service + repository.get_current_leader.side_effect = [None, None] + repository.set_leader_if_absent.return_value = False + + result = collaboration_service.get_or_set_leader("wf-1", "sid-2") + + assert result == "sid-2" + + def test_handle_leader_disconnect_elects_new( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_current_leader.return_value = "sid-1" + repository.list_sessions.return_value = [ + { + "user_id": "u-2", + "username": "B", + "avatar": None, + "sid": "sid-2", + "connected_at": 1, + "graph_active": True, + } + ] + + with ( + patch.object(collaboration_service, "is_session_active", return_value=True), + patch.object(collaboration_service, "broadcast_leader_change") as broadcast_leader_change, + ): + # Act + collaboration_service.handle_leader_disconnect("wf-1", "sid-1") + + # Assert + repository.set_leader.assert_called_once_with("wf-1", "sid-2") + broadcast_leader_change.assert_called_once_with("wf-1", "sid-2") + + def test_handle_leader_disconnect_clears_when_empty( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_current_leader.return_value = "sid-1" + repository.list_sessions.return_value = [] + + # Act + collaboration_service.handle_leader_disconnect("wf-1", "sid-1") + + # Assert + repository.delete_leader.assert_called_once_with("wf-1") + + def test_handle_leader_disconnect_ignores_non_leader_or_missing_leader( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + 
collaboration_service, repository, _socketio = service + + repository.get_current_leader.return_value = None + collaboration_service.handle_leader_disconnect("wf-1", "sid-1") + + repository.get_current_leader.return_value = "sid-leader" + collaboration_service.handle_leader_disconnect("wf-1", "sid-other") + + repository.set_leader.assert_not_called() + repository.delete_leader.assert_not_called() + + def test_broadcast_leader_change_logs_emit_errors( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, socketio = service + repository.get_session_sids.return_value = ["sid-1", "sid-2"] + socketio.emit.side_effect = [RuntimeError("boom"), None] + + with patch("services.workflow_collaboration_service.logging.exception") as exception_mock: + collaboration_service.broadcast_leader_change("wf-1", "sid-2") + + assert exception_mock.call_count == 1 + + def test_broadcast_online_users_sorts_and_emits( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, repository, socketio = service + repository.list_sessions.return_value = [ + {"user_id": "u-1", "username": "A", "avatar": None, "sid": "sid-1", "connected_at": 3}, + {"user_id": "u-2", "username": "B", "avatar": None, "sid": "sid-2", "connected_at": 1}, + ] + repository.get_current_leader.return_value = "sid-1" + + with patch.object(collaboration_service, "is_session_active", return_value=True): + # Act + collaboration_service.broadcast_online_users("wf-1") + + # Assert + socketio.emit.assert_called_once_with( + "online_users", + { + "workflow_id": "wf-1", + "users": [ + {"user_id": "u-2", "username": "B", "avatar": None, "sid": "sid-2", "connected_at": 1}, + {"user_id": "u-1", "username": "A", "avatar": None, "sid": "sid-1", "connected_at": 3}, + ], + "leader": "sid-1", + }, + room="wf-1", + ) + + def test_broadcast_online_users_reassigns_missing_leader( + self, service: 
tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, socketio = service + users = [{"user_id": "u-2", "username": "B", "avatar": None, "sid": "sid-2", "connected_at": 1}] + repository.get_current_leader.return_value = "sid-old" + + with ( + patch.object(collaboration_service, "_prune_inactive_sessions", return_value=users), + patch.object(collaboration_service, "_select_graph_leader", return_value="sid-2"), + patch.object(collaboration_service, "broadcast_leader_change") as broadcast_leader_change, + ): + collaboration_service.broadcast_online_users("wf-1") + + repository.delete_leader.assert_called_once_with("wf-1") + repository.set_leader.assert_called_once_with("wf-1", "sid-2") + broadcast_leader_change.assert_called_once_with("wf-1", "sid-2") + socketio.emit.assert_called_once_with( + "online_users", + {"workflow_id": "wf-1", "users": users, "leader": "sid-2"}, + room="wf-1", + ) + + def test_refresh_session_state_expires_active_leader( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_current_leader.return_value = "sid-1" + + with patch.object(collaboration_service, "is_session_active", return_value=True): + # Act + collaboration_service.refresh_session_state("wf-1", "sid-1") + + # Assert + repository.refresh_session_state.assert_called_once_with("wf-1", "sid-1") + repository.expire_leader.assert_called_once_with("wf-1") + repository.set_leader.assert_not_called() + + def test_refresh_session_state_sets_leader_when_missing( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + # Arrange + collaboration_service, repository, _socketio = service + repository.get_current_leader.return_value = None + repository.list_sessions.return_value = [ + { + "user_id": "u-2", + "username": "B", + "avatar": None, + "sid": "sid-2", + "connected_at": 1, + "graph_active": True, + } + ] + + with ( + 
patch.object(collaboration_service, "is_session_active", return_value=True), + patch.object(collaboration_service, "broadcast_leader_change") as broadcast_leader_change, + ): + # Act + collaboration_service.refresh_session_state("wf-1", "sid-2") + + # Assert + repository.set_leader.assert_called_once_with("wf-1", "sid-2") + broadcast_leader_change.assert_called_once_with("wf-1", "sid-2") + + def test_refresh_session_state_replaces_inactive_existing_leader( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, _socketio = service + repository.get_current_leader.return_value = "sid-old" + + with ( + patch.object(collaboration_service, "is_session_active", return_value=False), + patch.object(collaboration_service, "broadcast_leader_change") as broadcast_leader_change, + ): + collaboration_service.refresh_session_state("wf-1", "sid-new") + + repository.delete_leader.assert_called_once_with("wf-1") + repository.set_leader.assert_called_once_with("wf-1", "sid-new") + broadcast_leader_change.assert_called_once_with("wf-1", "sid-new") + + def test_relay_graph_event_emits_update(self, service: tuple[WorkflowCollaborationService, Mock, Mock]) -> None: + # Arrange + collaboration_service, repository, socketio = service + repository.get_sid_mapping.return_value = {"workflow_id": "wf-1", "user_id": "u-1"} + + # Act + result = collaboration_service.relay_graph_event("sid-1", {"nodes": []}) + + # Assert + assert result == ({"msg": "graph_update_broadcasted"}, 200) + repository.refresh_session_state.assert_called_once_with("wf-1", "sid-1") + socketio.emit.assert_called_once_with("graph_update", {"nodes": []}, room="wf-1", skip_sid="sid-1") + + def test_prune_inactive_sessions_handles_empty_and_removes_stale( + self, service: tuple[WorkflowCollaborationService, Mock, Mock] + ) -> None: + collaboration_service, repository, _socketio = service + repository.list_sessions.return_value = [] + assert 
collaboration_service._prune_inactive_sessions("wf-1") == [] + + active = {"sid": "sid-1", "user_id": "u-1", "connected_at": 1} + stale = {"sid": "sid-2", "user_id": "u-2", "connected_at": 2} + repository.list_sessions.return_value = [active, stale] + + with patch.object( + collaboration_service, + "is_session_active", + side_effect=lambda _workflow_id, sid: sid == "sid-1", + ): + users = collaboration_service._prune_inactive_sessions("wf-1") + + assert users == [active] + repository.delete_session.assert_called_with("wf-1", "sid-2") + + def test_is_session_active_guard_branches(self, service: tuple[WorkflowCollaborationService, Mock, Mock]) -> None: + collaboration_service, repository, socketio = service + socketio.manager.is_connected.return_value = True + repository.session_exists.return_value = True + repository.sid_mapping_exists.return_value = True + + assert collaboration_service.is_session_active("wf-1", "") is False + + socketio.manager.is_connected.return_value = False + assert collaboration_service.is_session_active("wf-1", "sid-1") is False + + socketio.manager.is_connected.side_effect = AttributeError("missing manager") + assert collaboration_service.is_session_active("wf-1", "sid-1") is False + socketio.manager.is_connected.side_effect = None + + socketio.manager.is_connected.return_value = True + repository.session_exists.return_value = False + assert collaboration_service.is_session_active("wf-1", "sid-1") is False + + repository.session_exists.return_value = True + repository.sid_mapping_exists.return_value = False + assert collaboration_service.is_session_active("wf-1", "sid-1") is False diff --git a/api/tests/unit_tests/services/test_workflow_comment_service.py b/api/tests/unit_tests/services/test_workflow_comment_service.py new file mode 100644 index 0000000000..e6db068e07 --- /dev/null +++ b/api/tests/unit_tests/services/test_workflow_comment_service.py @@ -0,0 +1,578 @@ +from unittest.mock import MagicMock, Mock, patch + +import pytest +from 
werkzeug.exceptions import Forbidden, NotFound + +from services import workflow_comment_service as service_module +from services.workflow_comment_service import WorkflowCommentService + + +@pytest.fixture +def mock_session(monkeypatch: pytest.MonkeyPatch) -> Mock: + session = Mock() + context_manager = MagicMock() + context_manager.__enter__.return_value = session + context_manager.__exit__.return_value = False + mock_db = MagicMock() + mock_db.engine = Mock() + empty_scalars = Mock() + empty_scalars.all.return_value = [] + session.scalars.return_value = empty_scalars + monkeypatch.setattr(service_module, "Session", Mock(return_value=context_manager)) + monkeypatch.setattr(service_module, "db", mock_db) + monkeypatch.setattr(service_module.send_workflow_comment_mention_email_task, "delay", Mock()) + return session + + +def _mock_scalars(result_list: list[object]) -> Mock: + scalars = Mock() + scalars.all.return_value = result_list + return scalars + + +class TestWorkflowCommentService: + def test_validate_content_rejects_empty(self) -> None: + with pytest.raises(ValueError): + WorkflowCommentService._validate_content(" ") + + def test_validate_content_rejects_too_long(self) -> None: + with pytest.raises(ValueError): + WorkflowCommentService._validate_content("a" * 1001) + + def test_filter_valid_mentioned_user_ids_filters_by_tenant_and_preserves_order(self, mock_session: Mock) -> None: + tenant_member_1 = "123e4567-e89b-12d3-a456-426614174000" + tenant_member_2 = "123e4567-e89b-12d3-a456-426614174002" + non_tenant_member = "123e4567-e89b-12d3-a456-426614174001" + mock_session.scalars.return_value = _mock_scalars([tenant_member_1, tenant_member_2]) + + result = WorkflowCommentService._filter_valid_mentioned_user_ids( + [ + tenant_member_1, + "", + 123, # type: ignore[list-item] + tenant_member_1, + non_tenant_member, + tenant_member_2, + ], + session=mock_session, + tenant_id="tenant-1", + ) + + assert result == [ + tenant_member_1, + tenant_member_2, + ] + + def 
test_format_comment_excerpt_handles_short_and_long_limits(self) -> None: + assert WorkflowCommentService._format_comment_excerpt(" hello ", max_length=10) == "hello" + assert WorkflowCommentService._format_comment_excerpt("abcdefghijk", max_length=3) == "abc" + assert WorkflowCommentService._format_comment_excerpt(" abcdefghijk ", max_length=8) == "abcde..." + + def test_build_mention_email_payloads_returns_empty_for_no_candidates(self, mock_session: Mock) -> None: + assert ( + WorkflowCommentService._build_mention_email_payloads( + session=mock_session, + tenant_id="tenant-1", + app_id="app-1", + mentioner_id="user-1", + mentioned_user_ids=[], + content="hello", + ) + == [] + ) + assert ( + WorkflowCommentService._build_mention_email_payloads( + session=mock_session, + tenant_id="tenant-1", + app_id="app-1", + mentioner_id="user-1", + mentioned_user_ids=["user-1"], + content="hello", + ) + == [] + ) + + def test_dispatch_mention_emails_enqueues_each_payload(self) -> None: + delay_mock = Mock() + with patch.object(service_module.send_workflow_comment_mention_email_task, "delay", delay_mock): + WorkflowCommentService._dispatch_mention_emails( + [ + {"to": "a@example.com"}, + {"to": "b@example.com"}, + ] + ) + + assert delay_mock.call_count == 2 + + def test_build_mention_email_payloads_skips_accounts_without_email(self, mock_session: Mock) -> None: + account_without_email = Mock() + account_without_email.email = None + account_without_email.name = "No Email" + account_without_email.interface_language = "en-US" + + account_with_email = Mock() + account_with_email.email = "user@example.com" + account_with_email.name = "" + account_with_email.interface_language = None + + mock_session.scalar.side_effect = ["My App", "Commenter"] + mock_session.scalars.return_value = _mock_scalars([account_without_email, account_with_email]) + + payloads = WorkflowCommentService._build_mention_email_payloads( + session=mock_session, + tenant_id="tenant-1", + app_id="app-1", + 
mentioner_id="user-1", + mentioned_user_ids=["user-2"], + content="hello", + ) + expected_app_url = f"{service_module.dify_config.CONSOLE_WEB_URL.rstrip('/')}/app/app-1/workflow" + + assert payloads == [ + { + "language": "en-US", + "to": "user@example.com", + "mentioned_name": "user@example.com", + "commenter_name": "Commenter", + "app_name": "My App", + "comment_content": "hello", + "app_url": expected_app_url, + } + ] + + def test_create_comment_creates_mentions(self, mock_session: Mock) -> None: + comment = Mock() + comment.id = "comment-1" + comment.created_at = "ts" + + with ( + patch.object(service_module, "WorkflowComment", return_value=comment), + patch.object(service_module, "WorkflowCommentMention", return_value=Mock()), + patch.object(WorkflowCommentService, "_filter_valid_mentioned_user_ids", return_value=["user-2"]), + ): + result = WorkflowCommentService.create_comment( + tenant_id="tenant-1", + app_id="app-1", + created_by="user-1", + content="hello", + position_x=1.0, + position_y=2.0, + mentioned_user_ids=["user-2", "bad-id"], + ) + + assert result == {"id": "comment-1", "created_at": "ts"} + assert mock_session.add.call_args_list[0].args[0] is comment + assert mock_session.add.call_count == 2 + mock_session.commit.assert_called_once() + + def test_update_comment_raises_not_found(self, mock_session: Mock) -> None: + mock_session.scalar.return_value = None + + with pytest.raises(NotFound): + WorkflowCommentService.update_comment( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + user_id="user-1", + content="hello", + ) + + def test_update_comment_raises_forbidden(self, mock_session: Mock) -> None: + comment = Mock() + comment.created_by = "owner" + mock_session.scalar.return_value = comment + + with pytest.raises(Forbidden): + WorkflowCommentService.update_comment( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + user_id="intruder", + content="hello", + ) + + def test_update_comment_replaces_mentions(self, 
mock_session: Mock) -> None: + comment = Mock() + comment.id = "comment-1" + comment.created_by = "owner" + mock_session.scalar.return_value = comment + + existing_mentions = [Mock(), Mock()] + mock_session.scalars.return_value = _mock_scalars(existing_mentions) + + with patch.object(WorkflowCommentService, "_filter_valid_mentioned_user_ids", return_value=["user-2"]): + result = WorkflowCommentService.update_comment( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + user_id="owner", + content="updated", + mentioned_user_ids=["user-2", "bad-id"], + ) + + assert result == {"id": "comment-1", "updated_at": comment.updated_at} + assert mock_session.delete.call_count == 2 + assert mock_session.add.call_count == 1 + mock_session.commit.assert_called_once() + + def test_update_comment_preserves_mentions_when_mentioned_user_ids_omitted(self, mock_session: Mock) -> None: + comment = Mock() + comment.id = "comment-1" + comment.created_by = "owner" + mock_session.scalar.return_value = comment + + with ( + patch.object(WorkflowCommentService, "_filter_valid_mentioned_user_ids") as filter_mentions_mock, + patch.object(WorkflowCommentService, "_build_mention_email_payloads") as build_payloads_mock, + patch.object(WorkflowCommentService, "_dispatch_mention_emails") as dispatch_mock, + ): + result = WorkflowCommentService.update_comment( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + user_id="owner", + content="updated", + ) + + assert result == {"id": "comment-1", "updated_at": comment.updated_at} + mock_session.delete.assert_not_called() + mock_session.add.assert_not_called() + filter_mentions_mock.assert_not_called() + build_payloads_mock.assert_not_called() + dispatch_mock.assert_called_once_with([]) + mock_session.commit.assert_called_once() + + def test_update_comment_clears_mentions_when_empty_list_provided(self, mock_session: Mock) -> None: + comment = Mock() + comment.id = "comment-1" + comment.created_by = "owner" + 
mock_session.scalar.return_value = comment + + existing_mentions = [Mock(), Mock()] + mock_session.scalars.return_value = _mock_scalars(existing_mentions) + + with patch.object(WorkflowCommentService, "_filter_valid_mentioned_user_ids", return_value=[]): + result = WorkflowCommentService.update_comment( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + user_id="owner", + content="updated", + mentioned_user_ids=[], + ) + + assert result == {"id": "comment-1", "updated_at": comment.updated_at} + assert mock_session.delete.call_count == 2 + mock_session.add.assert_not_called() + mock_session.commit.assert_called_once() + + def test_update_comment_notifies_only_new_mentions(self, mock_session: Mock) -> None: + comment = Mock() + comment.id = "comment-1" + comment.created_by = "owner" + mock_session.scalar.return_value = comment + + existing_mention = Mock() + existing_mention.mentioned_user_id = "user-2" + mock_session.scalars.return_value = _mock_scalars([existing_mention]) + + with ( + patch.object( + WorkflowCommentService, + "_filter_valid_mentioned_user_ids", + return_value=["user-2", "user-3"], + ), + patch.object( + WorkflowCommentService, + "_build_mention_email_payloads", + return_value=[], + ) as build_payloads_mock, + patch.object(WorkflowCommentService, "_dispatch_mention_emails") as dispatch_mock, + ): + WorkflowCommentService.update_comment( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + user_id="owner", + content="updated", + mentioned_user_ids=["user-2", "user-3"], + ) + + assert build_payloads_mock.call_args.kwargs["mentioned_user_ids"] == ["user-3"] + dispatch_mock.assert_called_once_with([]) + + def test_get_comments_preloads_related_accounts(self, mock_session: Mock) -> None: + comment = Mock() + comment.created_by = "user-1" + comment.resolved_by = "user-2" + reply = Mock() + reply.created_by = "user-3" + mention = Mock() + mention.mentioned_user_id = "user-4" + comment.replies = [reply] + comment.mentions 
= [mention] + comment.cache_created_by_account = Mock() + comment.cache_resolved_by_account = Mock() + reply.cache_created_by_account = Mock() + mention.cache_mentioned_user_account = Mock() + + account_1 = Mock() + account_1.id = "user-1" + account_2 = Mock() + account_2.id = "user-2" + account_3 = Mock() + account_3.id = "user-3" + account_4 = Mock() + account_4.id = "user-4" + + mock_session.scalars.side_effect = [ + _mock_scalars([comment]), + _mock_scalars([account_1, account_2, account_3, account_4]), + ] + + result = WorkflowCommentService.get_comments("tenant-1", "app-1") + + assert result == [comment] + comment.cache_created_by_account.assert_called_once_with(account_1) + comment.cache_resolved_by_account.assert_called_once_with(account_2) + reply.cache_created_by_account.assert_called_once_with(account_3) + mention.cache_mentioned_user_account.assert_called_once_with(account_4) + + def test_preload_accounts_returns_early_for_empty_comments(self, mock_session: Mock) -> None: + WorkflowCommentService._preload_accounts(mock_session, []) + + mock_session.scalars.assert_not_called() + + def test_get_comment_raises_not_found_with_provided_session(self) -> None: + session = Mock() + session.scalar.return_value = None + + with pytest.raises(NotFound): + WorkflowCommentService.get_comment("tenant-1", "app-1", "comment-1", session=session) + + def test_get_comment_uses_context_manager_when_session_not_provided(self, mock_session: Mock) -> None: + comment = Mock() + comment.created_by = "user-1" + comment.resolved_by = None + comment.replies = [] + comment.mentions = [] + comment.cache_created_by_account = Mock() + comment.cache_resolved_by_account = Mock() + mock_session.scalar.return_value = comment + mock_session.scalars.return_value = _mock_scalars([]) + + result = WorkflowCommentService.get_comment("tenant-1", "app-1", "comment-1") + + assert result is comment + comment.cache_created_by_account.assert_called_once() + 
comment.cache_resolved_by_account.assert_called_once_with(None) + + def test_delete_comment_raises_forbidden(self, mock_session: Mock) -> None: + comment = Mock() + comment.created_by = "owner" + + with patch.object(WorkflowCommentService, "get_comment", return_value=comment): + with pytest.raises(Forbidden): + WorkflowCommentService.delete_comment("tenant-1", "app-1", "comment-1", "intruder") + + def test_delete_comment_removes_related_entities(self, mock_session: Mock) -> None: + comment = Mock() + comment.created_by = "owner" + + mentions = [Mock(), Mock()] + replies = [Mock()] + mock_session.scalars.side_effect = [_mock_scalars(mentions), _mock_scalars(replies)] + + with patch.object(WorkflowCommentService, "get_comment", return_value=comment): + WorkflowCommentService.delete_comment("tenant-1", "app-1", "comment-1", "owner") + + assert mock_session.delete.call_count == 4 + mock_session.commit.assert_called_once() + + def test_resolve_comment_sets_fields(self, mock_session: Mock) -> None: + comment = Mock() + comment.resolved = False + comment.resolved_at = None + comment.resolved_by = None + + with ( + patch.object(WorkflowCommentService, "get_comment", return_value=comment), + patch.object(service_module, "naive_utc_now", return_value="now"), + ): + result = WorkflowCommentService.resolve_comment("tenant-1", "app-1", "comment-1", "user-1") + + assert result is comment + assert comment.resolved is True + assert comment.resolved_at == "now" + assert comment.resolved_by == "user-1" + mock_session.commit.assert_called_once() + + def test_resolve_comment_noop_when_already_resolved(self, mock_session: Mock) -> None: + comment = Mock() + comment.resolved = True + + with patch.object(WorkflowCommentService, "get_comment", return_value=comment): + result = WorkflowCommentService.resolve_comment("tenant-1", "app-1", "comment-1", "user-1") + + assert result is comment + mock_session.commit.assert_not_called() + + def test_create_reply_requires_comment(self, 
mock_session: Mock) -> None: + mock_session.get.return_value = None + + with pytest.raises(NotFound): + WorkflowCommentService.create_reply("comment-1", "hello", "user-1") + + def test_create_reply_creates_mentions(self, mock_session: Mock) -> None: + mock_session.get.return_value = Mock() + reply = Mock() + reply.id = "reply-1" + reply.created_at = "ts" + + with ( + patch.object(service_module, "WorkflowCommentReply", return_value=reply), + patch.object(service_module, "WorkflowCommentMention", return_value=Mock()), + patch.object(WorkflowCommentService, "_filter_valid_mentioned_user_ids", return_value=["user-2"]), + ): + result = WorkflowCommentService.create_reply( + comment_id="comment-1", + content="hello", + created_by="user-1", + mentioned_user_ids=["user-2", "bad-id"], + ) + + assert result == {"id": "reply-1", "created_at": "ts"} + assert mock_session.add.call_count == 2 + mock_session.commit.assert_called_once() + + def test_update_reply_raises_not_found(self, mock_session: Mock) -> None: + mock_session.scalar.return_value = None + + with pytest.raises(NotFound): + WorkflowCommentService.update_reply( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + reply_id="reply-1", + user_id="user-1", + content="hello", + ) + + def test_update_reply_raises_forbidden(self, mock_session: Mock) -> None: + reply = Mock() + reply.created_by = "owner" + mock_session.scalar.return_value = reply + + with pytest.raises(Forbidden): + WorkflowCommentService.update_reply( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + reply_id="reply-1", + user_id="intruder", + content="hello", + ) + + def test_update_reply_replaces_mentions(self, mock_session: Mock) -> None: + reply = Mock() + reply.id = "reply-1" + reply.comment_id = "comment-1" + reply.created_by = "owner" + reply.updated_at = "updated" + mock_session.scalar.return_value = reply + mock_session.scalars.return_value = _mock_scalars([Mock()]) + comment = Mock() + comment.tenant_id = 
"tenant-1" + comment.app_id = "app-1" + mock_session.get.return_value = comment + + with patch.object(WorkflowCommentService, "_filter_valid_mentioned_user_ids", return_value=["user-2"]): + result = WorkflowCommentService.update_reply( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + reply_id="reply-1", + user_id="owner", + content="new", + mentioned_user_ids=["user-2", "bad-id"], + ) + + assert result == {"id": "reply-1", "updated_at": "updated"} + assert mock_session.delete.call_count == 1 + assert mock_session.add.call_count == 1 + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once_with(reply) + + def test_update_comment_updates_position_coordinates_when_provided(self, mock_session: Mock) -> None: + comment = Mock() + comment.id = "comment-1" + comment.created_by = "owner" + comment.position_x = 1.0 + comment.position_y = 2.0 + mock_session.scalar.return_value = comment + mock_session.scalars.return_value = _mock_scalars([]) + + WorkflowCommentService.update_comment( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + user_id="owner", + content="updated", + position_x=10.5, + position_y=20.5, + mentioned_user_ids=[], + ) + + assert comment.position_x == 10.5 + assert comment.position_y == 20.5 + + def test_delete_reply_raises_forbidden(self, mock_session: Mock) -> None: + reply = Mock() + reply.created_by = "owner" + mock_session.scalar.return_value = reply + + with pytest.raises(Forbidden): + WorkflowCommentService.delete_reply( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + reply_id="reply-1", + user_id="intruder", + ) + + def test_delete_reply_raises_not_found(self, mock_session: Mock) -> None: + mock_session.scalar.return_value = None + + with pytest.raises(NotFound): + WorkflowCommentService.delete_reply( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + reply_id="reply-1", + user_id="owner", + ) + + def test_delete_reply_removes_mentions(self, 
mock_session: Mock) -> None: + reply = Mock() + reply.created_by = "owner" + mock_session.scalar.return_value = reply + mock_session.scalars.return_value = _mock_scalars([Mock(), Mock()]) + + WorkflowCommentService.delete_reply( + tenant_id="tenant-1", + app_id="app-1", + comment_id="comment-1", + reply_id="reply-1", + user_id="owner", + ) + + assert mock_session.delete.call_count == 3 + mock_session.commit.assert_called_once() + + def test_validate_comment_access_delegates_to_get_comment(self) -> None: + comment = Mock() + with patch.object(WorkflowCommentService, "get_comment", return_value=comment) as get_comment_mock: + result = WorkflowCommentService.validate_comment_access("comment-1", "tenant-1", "app-1") + + assert result is comment + get_comment_mock.assert_called_once_with("tenant-1", "app-1", "comment-1") diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 7906daacfc..351f6ffb5f 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -12,7 +12,7 @@ This test suite covers: import json import uuid from typing import Any, cast -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import ANY, MagicMock, Mock, patch import pytest from graphon.entities import WorkflowNodeExecution @@ -713,6 +713,79 @@ class TestWorkflowService: with pytest.raises(ValueError, match="Invalid app mode"): workflow_service.validate_features_structure(app, features) + # ==================== Draft Workflow Variable Update Tests ==================== + # These tests verify updating draft workflow environment/conversation variables + + def test_update_draft_workflow_environment_variables_updates_workflow(self, workflow_service, mock_db_session): + """Test update_draft_workflow_environment_variables updates draft fields.""" + app = TestWorkflowAssociatedDataFactory.create_app_mock() + account = 
TestWorkflowAssociatedDataFactory.create_account_mock() + workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock() + variables = [Mock()] + + with ( + patch.object(workflow_service, "get_draft_workflow", return_value=workflow), + patch("services.workflow_service.naive_utc_now", return_value="now"), + ): + workflow_service.update_draft_workflow_environment_variables( + app_model=app, + environment_variables=variables, + account=account, + ) + + assert workflow.environment_variables == variables + assert workflow.updated_by == account.id + assert workflow.updated_at == "now" + mock_db_session.session.commit.assert_called_once() + + def test_update_draft_workflow_environment_variables_raises_when_missing(self, workflow_service): + """Test update_draft_workflow_environment_variables raises when draft missing.""" + app = TestWorkflowAssociatedDataFactory.create_app_mock() + account = TestWorkflowAssociatedDataFactory.create_account_mock() + + with patch.object(workflow_service, "get_draft_workflow", return_value=None): + with pytest.raises(ValueError, match="No draft workflow found."): + workflow_service.update_draft_workflow_environment_variables( + app_model=app, + environment_variables=[], + account=account, + ) + + def test_update_draft_workflow_conversation_variables_updates_workflow(self, workflow_service, mock_db_session): + """Test update_draft_workflow_conversation_variables updates draft fields.""" + app = TestWorkflowAssociatedDataFactory.create_app_mock() + account = TestWorkflowAssociatedDataFactory.create_account_mock() + workflow = TestWorkflowAssociatedDataFactory.create_workflow_mock() + variables = [Mock()] + + with ( + patch.object(workflow_service, "get_draft_workflow", return_value=workflow), + patch("services.workflow_service.naive_utc_now", return_value="now"), + ): + workflow_service.update_draft_workflow_conversation_variables( + app_model=app, + conversation_variables=variables, + account=account, + ) + + assert 
workflow.conversation_variables == variables + assert workflow.updated_by == account.id + assert workflow.updated_at == "now" + mock_db_session.session.commit.assert_called_once() + + def test_update_draft_workflow_conversation_variables_raises_when_missing(self, workflow_service): + """Test update_draft_workflow_conversation_variables raises when draft missing.""" + app = TestWorkflowAssociatedDataFactory.create_app_mock() + account = TestWorkflowAssociatedDataFactory.create_account_mock() + + with patch.object(workflow_service, "get_draft_workflow", return_value=None): + with pytest.raises(ValueError, match="No draft workflow found."): + workflow_service.update_draft_workflow_conversation_variables( + app_model=app, + conversation_variables=[], + account=account, + ) + # ==================== Publish Workflow Tests ==================== # These tests verify creating published versions from draft workflows diff --git a/api/uv.lock b/api/uv.lock index 5250a614b6..87588493ee 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -537,6 +537,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1a/39/47f9197bdd44df24d67ac8893641e16f386c984a0619ef2ee4c51fbbc019/beautifulsoup4-4.14.3-py3-none-any.whl", hash = "sha256:0918bfe44902e6ad8d57732ba310582e98da931428d231a5ecb9e7c703a735bb", size = 107721, upload-time = "2025-11-30T15:08:24.087Z" }, ] +[[package]] +name = "bidict" +version = "0.23.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/6e/026678aa5a830e07cd9498a05d3e7e650a4f56a42f267a53d22bcda1bdc9/bidict-0.23.1.tar.gz", hash = "sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71", size = 29093, upload-time = "2024-02-18T19:09:05.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/37/e8730c3587a65eb5645d4aba2d27aae48e8003614d6aaf15dda67f702f1f/bidict-0.23.1-py3-none-any.whl", hash = "sha256:5dae8d4d79b552a71cbabc7deb25dfe8ce710b17ff41711e13010ead2abfc3e5", size 
= 32764, upload-time = "2024-02-18T19:09:04.156Z" }, +] + [[package]] name = "billiard" version = "4.2.3" @@ -1290,6 +1299,7 @@ dependencies = [ { name = "flask-orjson" }, { name = "flask-restx" }, { name = "gevent" }, + { name = "gevent-websocket" }, { name = "gmpy2" }, { name = "google-api-python-client" }, { name = "google-cloud-aiplatform" }, @@ -1311,10 +1321,12 @@ dependencies = [ { name = "opik" }, { name = "psycogreen" }, { name = "psycopg2-binary" }, + { name = "python-socketio" }, { name = "readabilipy" }, { name = "redis", extra = ["hiredis"] }, { name = "resend" }, { name = "sendgrid" }, + { name = "sseclient-py" }, { name = "weave" }, ] @@ -1341,7 +1353,6 @@ dev = [ { name = "pytest-xdist" }, { name = "ruff" }, { name = "scipy-stubs" }, - { name = "sseclient-py" }, { name = "testcontainers" }, { name = "types-aiofiles" }, { name = "types-beautifulsoup4" }, @@ -1542,6 +1553,7 @@ requires-dist = [ { name = "flask-orjson", specifier = ">=2.0.0,<3.0.0" }, { name = "flask-restx", specifier = ">=1.3.2,<2.0.0" }, { name = "gevent", specifier = ">=26.4.0" }, + { name = "gevent-websocket", specifier = ">=0.10.1" }, { name = "gmpy2", specifier = ">=2.3.0" }, { name = "google-api-python-client", specifier = ">=2.194.0" }, { name = "google-cloud-aiplatform", specifier = ">=1.147.0,<2.0.0" }, @@ -1563,10 +1575,12 @@ requires-dist = [ { name = "opik", specifier = "~=1.11.2" }, { name = "psycogreen", specifier = ">=1.0.2" }, { name = "psycopg2-binary", specifier = ">=2.9.11" }, + { name = "python-socketio", specifier = ">=5.13.0" }, { name = "readabilipy", specifier = ">=0.3.0,<1.0.0" }, { name = "redis", extras = ["hiredis"], specifier = ">=7.4.0" }, { name = "resend", specifier = ">=2.27.0,<3.0.0" }, { name = "sendgrid", specifier = ">=6.12.5" }, + { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "weave", specifier = ">=0.52.36,<1.0.0" }, ] @@ -1593,7 +1607,6 @@ dev = [ { name = "pytest-xdist", specifier = ">=3.8.0" }, { name = "ruff", specifier = 
">=0.15.10" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, - { name = "sseclient-py", specifier = ">=1.8.0" }, { name = "testcontainers", specifier = ">=4.14.2" }, { name = "types-aiofiles", specifier = ">=25.1.0" }, { name = "types-beautifulsoup4", specifier = ">=4.12.0" }, @@ -2464,6 +2477,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f6/df/7875e08b06a95f4577b71708ec470d029fadf873a66eb813a2861d79dfb5/gevent-26.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1c737e6ac6ce1398df0e3f41c58d982e397c993cbe73ac05b7edbe39e128c9cb", size = 1680530, upload-time = "2026-04-08T23:15:38.714Z" }, ] +[[package]] +name = "gevent-websocket" +version = "0.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gevent" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/d2/6fa19239ff1ab072af40ebf339acd91fb97f34617c2ee625b8e34bf42393/gevent-websocket-0.10.1.tar.gz", hash = "sha256:7eaef32968290c9121f7c35b973e2cc302ffb076d018c9068d2f5ca8b2d85fb0", size = 18366, upload-time = "2017-03-12T22:46:05.68Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/84/2dc373eb6493e00c884cc11e6c059ec97abae2678d42f06bf780570b0193/gevent_websocket-0.10.1-py3-none-any.whl", hash = "sha256:17b67d91282f8f4c973eba0551183fc84f56f1c90c8f6b6b30256f31f66f5242", size = 22987, upload-time = "2017-03-12T22:46:03.611Z" }, +] + [[package]] name = "gitdb" version = "4.0.12" @@ -5313,6 +5338,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" }, ] +[[package]] +name = "python-engineio" +version = "4.13.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "simple-websocket" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/34/12/bdef9dbeedbe2cdeba2a2056ad27b1fb081557d34b69a97f574843462cae/python_engineio-4.13.1.tar.gz", hash = "sha256:0a853fcef52f5b345425d8c2b921ac85023a04dfcf75d7b74696c61e940fd066", size = 92348, upload-time = "2026-02-06T23:38:06.12Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/54/0cce26da03a981f949bb8449c9778537f75f5917c172e1d2992ff25cb57d/python_engineio-4.13.1-py3-none-any.whl", hash = "sha256:f32ad10589859c11053ad7d9bb3c9695cdf862113bfb0d20bc4d890198287399", size = 59847, upload-time = "2026-02-06T23:38:04.861Z" }, +] + [[package]] name = "python-http-client" version = "3.3.7" @@ -5369,6 +5406,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/4f/00be2196329ebbff56ce564aa94efb0fbc828d00de250b1980de1a34ab49/python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba", size = 472788, upload-time = "2024-08-07T17:33:28.192Z" }, ] +[[package]] +name = "python-socketio" +version = "5.16.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bidict" }, + { name = "python-engineio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/59/81/cf8284f45e32efa18d3848ed82cdd4dcc1b657b082458fbe01ad3e1f2f8d/python_socketio-5.16.1.tar.gz", hash = "sha256:f863f98eacce81ceea2e742f6388e10ca3cdd0764be21d30d5196470edf5ea89", size = 128508, upload-time = "2026-02-06T23:42:07Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/c7/deb8c5e604404dbf10a3808a858946ca3547692ff6316b698945bb72177e/python_socketio-5.16.1-py3-none-any.whl", hash = "sha256:a3eb1702e92aa2f2b5d3ba00261b61f062cce51f1cfb6900bf3ab4d1934d2d35", size = 82054, upload-time = "2026-02-06T23:42:05.772Z" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -5749,6 +5799,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, ] +[[package]] +name = "simple-websocket" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wsproto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/d4/bfa032f961103eba93de583b161f0e6a5b63cebb8f2c7d0c6e6efe1e3d2e/simple_websocket-1.1.0.tar.gz", hash = "sha256:7939234e7aa067c534abdab3a9ed933ec9ce4691b0713c78acb195560aa52ae4", size = 17300, upload-time = "2024-10-10T22:39:31.412Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/59/0782e51887ac6b07ffd1570e0364cf901ebc36345fea669969d2084baebb/simple_websocket-1.1.0-py3-none-any.whl", hash = "sha256:4af6069630a38ed6c561010f0e11a5bc0d4ca569b36306eb257cd9a192497c8c", size = 13842, upload-time = "2024-10-10T22:39:29.645Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -7185,6 +7247,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ff/21/abdedb4cdf6ff41ebf01a74087740a709e2edb146490e4d9beea054b0b7a/wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1", size = 23362, upload-time = "2023-11-09T06:33:28.271Z" }, ] +[[package]] +name = "wsproto" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" }, +] + [[package]] name = "xinference-client" version = "2.4.0" diff --git a/docker/.env.example b/docker/.env.example index 856b04a3df..8176155698 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -132,6 +132,10 @@ MIGRATION_ENABLED=true # The default value is 300 seconds. FILES_ACCESS_TIMEOUT=300 +# Collaboration mode toggle +# To open collaboration features, you also need to set SERVER_WORKER_CLASS=geventwebsocket.gunicorn.workers.GeventWebSocketWorker +ENABLE_COLLABORATION_MODE=false + # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 @@ -167,6 +171,7 @@ SERVER_WORKER_AMOUNT=1 # Modifying it may also decrease throughput. # # It is strongly discouraged to change this parameter. +# If enable collaboration mode, it must be set to geventwebsocket.gunicorn.workers.GeventWebSocketWorker SERVER_WORKER_CLASS=gevent # Default number of worker connections, the default is 10. @@ -428,6 +433,8 @@ CONSOLE_CORS_ALLOW_ORIGINS=* COOKIE_DOMAIN= # When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. NEXT_PUBLIC_COOKIE_DOMAIN= +# WebSocket server URL. 
+NEXT_PUBLIC_SOCKET_URL=ws://localhost NEXT_PUBLIC_BATCH_CONCURRENCY=5 # ------------------------------ diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 4f4b3851f6..888f96332c 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -159,6 +159,7 @@ services: APP_API_URL: ${APP_API_URL:-} AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-} + NEXT_PUBLIC_SOCKET_URL: ${NEXT_PUBLIC_SOCKET_URL:-ws://localhost} SENTRY_DSN: ${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} EXPERIMENTAL_ENABLE_VINEXT: ${EXPERIMENTAL_ENABLE_VINEXT:-false} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index c1ddba4f80..a10fdf77c6 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -34,6 +34,7 @@ x-shared-env: &shared-api-worker-env OPENAI_API_BASE: ${OPENAI_API_BASE:-https://api.openai.com/v1} MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} + ENABLE_COLLABORATION_MODE: ${ENABLE_COLLABORATION_MODE:-false} ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30} APP_DEFAULT_ACTIVE_REQUESTS: ${APP_DEFAULT_ACTIVE_REQUESTS:-0} @@ -119,6 +120,7 @@ x-shared-env: &shared-api-worker-env CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*} COOKIE_DOMAIN: ${COOKIE_DOMAIN:-} NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-} + NEXT_PUBLIC_SOCKET_URL: ${NEXT_PUBLIC_SOCKET_URL:-ws://localhost} NEXT_PUBLIC_BATCH_CONCURRENCY: ${NEXT_PUBLIC_BATCH_CONCURRENCY:-5} STORAGE_TYPE: ${STORAGE_TYPE:-opendal} OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs} @@ -878,6 +880,7 @@ services: APP_API_URL: ${APP_API_URL:-} AMPLITUDE_API_KEY: ${AMPLITUDE_API_KEY:-} NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-} + NEXT_PUBLIC_SOCKET_URL: ${NEXT_PUBLIC_SOCKET_URL:-ws://localhost} SENTRY_DSN: 
${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} EXPERIMENTAL_ENABLE_VINEXT: ${EXPERIMENTAL_ENABLE_VINEXT:-false} diff --git a/docker/nginx/conf.d/default.conf.template b/docker/nginx/conf.d/default.conf.template index 1d63c1b97d..94a748290f 100644 --- a/docker/nginx/conf.d/default.conf.template +++ b/docker/nginx/conf.d/default.conf.template @@ -14,6 +14,14 @@ server { include proxy.conf; } + location /socket.io/ { + proxy_pass http://api:5001; + include proxy.conf; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_cache_bypass $http_upgrade; + } + location /v1 { proxy_pass http://api:5001; include proxy.conf; diff --git a/packages/iconify-collections/assets/public/common/enter-key.svg b/packages/iconify-collections/assets/public/common/enter-key.svg new file mode 100644 index 0000000000..edfddfc188 --- /dev/null +++ b/packages/iconify-collections/assets/public/common/enter-key.svg @@ -0,0 +1,4 @@ + + + + diff --git a/packages/iconify-collections/assets/public/other/comment.svg b/packages/iconify-collections/assets/public/other/comment.svg new file mode 100644 index 0000000000..0f0609f0b6 --- /dev/null +++ b/packages/iconify-collections/assets/public/other/comment.svg @@ -0,0 +1,3 @@ + + + diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 094faf78cb..8d6dfa8b2f 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -384,6 +384,9 @@ catalogs: lexical: specifier: 0.43.0 version: 0.43.0 + loro-crdt: + specifier: 1.10.8 + version: 1.10.8 mermaid: specifier: 11.14.0 version: 11.14.0 @@ -471,6 +474,9 @@ catalogs: shiki: specifier: 4.0.2 version: 4.0.2 + socket.io-client: + specifier: 4.8.3 + version: 4.8.3 sortablejs: specifier: 1.15.7 version: 1.15.7 @@ -839,6 +845,9 @@ importers: lexical: specifier: 'catalog:' version: 0.43.0 + loro-crdt: + specifier: 'catalog:' + version: 1.10.8 mermaid: specifier: 'catalog:' version: 11.14.0 @@ -920,6 +929,9 @@ importers: shiki: specifier: 'catalog:' version: 4.0.2 + 
socket.io-client: + specifier: 'catalog:' + version: 4.8.3 sortablejs: specifier: 'catalog:' version: 1.15.7 @@ -3544,6 +3556,9 @@ packages: resolution: {integrity: sha512-TeheYy0ILzBEI/CO55CP6zJCSdSWeRtGnHy8U8dWSUH4I68iqTsy7HkMktR4xakThc9jotkPQUXT4ITdbV7cHA==} engines: {node: '>=18'} + '@socket.io/component-emitter@3.1.2': + resolution: {integrity: sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==} + '@solid-primitives/event-listener@2.4.5': resolution: {integrity: sha512-nwRV558mIabl4yVAhZKY8cb6G+O1F0M6Z75ttTu5hk+SxdOnKSGj+eetDIu7Oax1P138ZdUU01qnBPR8rnxaEA==} peerDependencies: @@ -5500,6 +5515,13 @@ packages: end-of-stream@1.4.5: resolution: {integrity: sha512-ooEGc6HP26xXq/N+GCGOT0JKCLDGrq2bQUZrQ7gyrJiZANJ/8YDTxTpQBXGMn+WbIQXNVpyWymm7KYVICQnyOg==} + engine.io-client@6.6.4: + resolution: {integrity: sha512-+kjUJnZGwzewFDw951CDWcwj35vMNf2fcj7xQWOctq1F2i1jkDdVvdFG9kM/BEChymCH36KgjnW0NsL58JYRxw==} + + engine.io-parser@5.2.3: + resolution: {integrity: sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==} + engines: {node: '>=10.0.0'} + enhanced-resolve@5.20.1: resolution: {integrity: sha512-Qohcme7V1inbAfvjItgw0EaxVX5q2rdVEZHRBrEQdRZTssLDGsL8Lwrznl8oQ/6kuTJONLaDcGjkNP247XEhcA==} engines: {node: '>=10.13.0'} @@ -6624,6 +6646,9 @@ packages: resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} hasBin: true + loro-crdt@1.10.8: + resolution: {integrity: sha512-GvH8fSJST1VDHRGzlQml80pBYoFbIP4ULeV1S8fD4ffmA8m+icoPORyVUW2AkJBY3dxKIcMMn0WqaJmpCmnbkQ==} + loupe@3.2.1: resolution: {integrity: sha512-CdzqowRJCeLU72bHvWqwRBBlLcMEtIvGrlvef74kMnV2AolS9Y8xUv1I0U/MNAWMhBlKIoyuEgoJ0t/bbwHbLQ==} @@ -7763,6 +7788,14 @@ packages: resolution: {integrity: sha512-dWUG8F5sIIARXih1DTaQAX4SsiTXhInKf1buxdY9DIg4ZYPZK5nGM1VRIYmEbDbsHt7USo99xSLFu5Q1IqTmsg==} engines: {node: '>= 18'} + socket.io-client@4.8.3: + resolution: {integrity: 
sha512-uP0bpjWrjQmUt5DTHq9RuoCBdFJF10cdX9X+a368j/Ft0wmaVgxlrjvK3kjvgCODOMMOz9lcaRzxmso0bTWZ/g==} + engines: {node: '>=10.0.0'} + + socket.io-parser@4.2.6: + resolution: {integrity: sha512-asJqbVBDsBCJx0pTqw3WfesSY0iRX+2xzWEWzrpcH7L6fLzrhyF8WPI8UaeM4YCuDfpwA/cgsdugMsmtz8EJeg==} + engines: {node: '>=10.0.0'} + solid-js@1.9.11: resolution: {integrity: sha512-WEJtcc5mkh/BnHA6Yrg4whlF8g6QwpmXXRg4P2ztPmcKeHHlH4+djYecBLhSpecZY2RRECXYUwIc/C2r3yzQ4Q==} @@ -8490,6 +8523,18 @@ packages: wrappy@1.0.2: resolution: {integrity: sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==} + ws@8.18.3: + resolution: {integrity: sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==} + engines: {node: '>=10.0.0'} + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: '>=5.0.2' + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + ws@8.20.0: resolution: {integrity: sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA==} engines: {node: '>=10.0.0'} @@ -8518,6 +8563,10 @@ packages: resolution: {integrity: sha512-yMqGBqtXyeN1e3TGYvgNgDVZ3j84W4cwkOXQswghol6APgZWaff9lnbvN7MHYJOiXsvGPXtjTYJEiC9J2wv9Eg==} engines: {node: '>=8.0'} + xmlhttprequest-ssl@2.1.2: + resolution: {integrity: sha512-TEU+nJVUUnA4CYJFLvK5X9AOeH4KvDvhIfm0vV1GaQRtchnG0hgK5p8hw/xjv8cunWYCsiPCSDzObPyhEwq3KQ==} + engines: {node: '>=0.4.0'} + yallist@3.1.1: resolution: {integrity: sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==} @@ -10902,6 +10951,8 @@ snapshots: '@sindresorhus/base62@1.0.0': {} + '@socket.io/component-emitter@3.1.2': {} + '@solid-primitives/event-listener@2.4.5(solid-js@1.9.11)': dependencies: '@solid-primitives/utils': 6.4.0(solid-js@1.9.11) @@ -12966,6 +13017,20 @@ snapshots: dependencies: once: 1.4.0 + engine.io-client@6.6.4: + dependencies: + '@socket.io/component-emitter': 3.1.2 + debug: 
4.4.3(supports-color@8.1.1) + engine.io-parser: 5.2.3 + ws: 8.18.3 + xmlhttprequest-ssl: 2.1.2 + transitivePeerDependencies: + - bufferutil + - supports-color + - utf-8-validate + + engine.io-parser@5.2.3: {} + enhanced-resolve@5.20.1: dependencies: graceful-fs: 4.2.11 @@ -14310,6 +14375,8 @@ snapshots: dependencies: js-tokens: 4.0.0 + loro-crdt@1.10.8: {} + loupe@3.2.1: {} lower-case@2.0.2: @@ -16002,6 +16069,24 @@ snapshots: smol-toml@1.6.1: {} + socket.io-client@4.8.3: + dependencies: + '@socket.io/component-emitter': 3.1.2 + debug: 4.4.3(supports-color@8.1.1) + engine.io-client: 6.6.4 + socket.io-parser: 4.2.6 + transitivePeerDependencies: + - bufferutil + - supports-color + - utf-8-validate + + socket.io-parser@4.2.6: + dependencies: + '@socket.io/component-emitter': 3.1.2 + debug: 4.4.3(supports-color@8.1.1) + transitivePeerDependencies: + - supports-color + solid-js@1.9.11: dependencies: csstype: 3.2.3 @@ -16776,6 +16861,8 @@ snapshots: wrappy@1.0.2: {} + ws@8.18.3: {} + ws@8.20.0: {} wsl-utils@0.1.0: @@ -16791,6 +16878,8 @@ snapshots: xmlbuilder@15.1.1: {} + xmlhttprequest-ssl@2.1.2: {} + yallist@3.1.1: {} yallist@5.0.0: {} diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index cd72c6bc0e..1e142fc3b5 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -142,6 +142,7 @@ catalog: ky: 2.0.0 lamejs: 1.2.1 lexical: 0.43.0 + loro-crdt: 1.10.8 mermaid: 11.14.0 mime: 4.1.0 mitt: 3.0.1 @@ -172,6 +173,7 @@ catalog: scheduler: 0.27.0 sharp: 0.34.5 shiki: 4.0.2 + socket.io-client: 4.8.3 sortablejs: 1.15.7 std-semver: 1.0.8 storybook: 10.3.5 diff --git a/web/.env.example b/web/.env.example index 93cbc22fc8..643aba482e 100644 --- a/web/.env.example +++ b/web/.env.example @@ -14,6 +14,8 @@ NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api # When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. NEXT_PUBLIC_COOKIE_DOMAIN= +# WebSocket server URL. 
+NEXT_PUBLIC_SOCKET_URL=ws://localhost:5001 # Dev-only Hono proxy targets. # The frontend keeps requesting http://localhost:5001 directly, diff --git a/web/__tests__/app/app-access-control-flow.test.tsx b/web/__tests__/app/app-access-control-flow.test.tsx index 49443eb4ec..63f7fd0378 100644 --- a/web/__tests__/app/app-access-control-flow.test.tsx +++ b/web/__tests__/app/app-access-control-flow.test.tsx @@ -1,3 +1,4 @@ +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import AppPublisher from '@/app/components/app/app-publisher' @@ -23,6 +24,27 @@ let mockAppDetail: { } } | null = null +const createTestQueryClient = () => + new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + mutations: { + retry: false, + }, + }, + }) + +const renderWithQueryClient = (ui: React.ReactElement) => { + const queryClient = createTestQueryClient() + return render( + + {ui} + , + ) +} + vi.mock('react-i18next', () => ({ useTranslation: () => ({ t: (key: string, options?: { ns?: string }) => options?.ns ? 
`${options.ns}.${key}` : key, @@ -76,6 +98,18 @@ vi.mock('@/app/components/app/overview/embedded', () => ({ default: () => null, })) +vi.mock('@/app/components/workflow/collaboration/core/websocket-manager', () => ({ + webSocketClient: { + getSocket: vi.fn(() => null), + }, +})) + +vi.mock('@/app/components/workflow/collaboration/core/collaboration-manager', () => ({ + collaborationManager: { + onAppPublishUpdate: vi.fn(() => vi.fn()), + }, +})) + vi.mock('@/app/components/app/app-access-control', () => ({ default: ({ onConfirm, @@ -115,7 +149,7 @@ describe('App Access Control Flow', () => { }) it('refreshes app detail after confirming access control updates', async () => { - render() + renderWithQueryClient() fireEvent.click(screen.getByRole('button', { name: 'workflow.common.publish' })) fireEvent.click(screen.getByText('app.accessControlDialog.accessItems.specific')) diff --git a/web/__tests__/app/app-publisher-flow.test.tsx b/web/__tests__/app/app-publisher-flow.test.tsx index 5c330cf71e..4f24fc310c 100644 --- a/web/__tests__/app/app-publisher-flow.test.tsx +++ b/web/__tests__/app/app-publisher-flow.test.tsx @@ -1,3 +1,4 @@ +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import AppPublisher from '@/app/components/app/app-publisher' @@ -27,6 +28,27 @@ let mockAppDetail: { } } | null = null +const createTestQueryClient = () => + new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + mutations: { + retry: false, + }, + }, + }) + +const renderWithQueryClient = (ui: React.ReactElement) => { + const queryClient = createTestQueryClient() + return render( + + {ui} + , + ) +} + vi.mock('react-i18next', () => ({ useTranslation: () => ({ t: (key: string) => key, @@ -106,6 +128,18 @@ vi.mock('@/app/components/app/overview/embedded', () => ({ ), })) 
+vi.mock('@/app/components/workflow/collaboration/core/websocket-manager', () => ({ + webSocketClient: { + getSocket: vi.fn(() => null), + }, +})) + +vi.mock('@/app/components/workflow/collaboration/core/collaboration-manager', () => ({ + collaborationManager: { + onAppPublishUpdate: vi.fn(() => vi.fn()), + }, +})) + vi.mock('@/app/components/app/app-access-control', () => ({ default: () =>
, })) @@ -183,7 +217,7 @@ describe('App Publisher Flow', () => { it('publishes from the summary panel and tracks the publish event', async () => { const onPublish = vi.fn().mockResolvedValue(undefined) - render( + renderWithQueryClient( { }) it('opens embedded modal and resolves the installed explore target', async () => { - render() + renderWithQueryClient() fireEvent.click(screen.getByText('common.publish')) fireEvent.click(screen.getByText('common.embedIntoSite')) @@ -231,7 +265,7 @@ describe('App Publisher Flow', () => { installed_apps: [], }) - render() + renderWithQueryClient() fireEvent.click(screen.getByText('common.publish')) fireEvent.click(screen.getByText('common.openInExplore')) diff --git a/web/__tests__/apps/app-list-browsing-flow.test.tsx b/web/__tests__/apps/app-list-browsing-flow.test.tsx index 1be7e56086..a5ed79a7bd 100644 --- a/web/__tests__/apps/app-list-browsing-flow.test.tsx +++ b/web/__tests__/apps/app-list-browsing-flow.test.tsx @@ -93,6 +93,10 @@ vi.mock('@/service/tag', () => ({ fetchTagList: vi.fn().mockResolvedValue([]), })) +vi.mock('@/service/apps', () => ({ + fetchWorkflowOnlineUsers: vi.fn().mockResolvedValue({}), +})) + vi.mock('@/service/use-apps', () => ({ useInfiniteAppList: () => ({ data: { pages: mockPages }, diff --git a/web/__tests__/apps/create-app-flow.test.tsx b/web/__tests__/apps/create-app-flow.test.tsx index bc1f7a3a06..9abc870ecf 100644 --- a/web/__tests__/apps/create-app-flow.test.tsx +++ b/web/__tests__/apps/create-app-flow.test.tsx @@ -80,6 +80,10 @@ vi.mock('@/service/tag', () => ({ fetchTagList: vi.fn().mockResolvedValue([]), })) +vi.mock('@/service/apps', () => ({ + fetchWorkflowOnlineUsers: vi.fn().mockResolvedValue({}), +})) + vi.mock('@/service/use-apps', () => ({ useInfiniteAppList: () => ({ data: { pages: mockPages }, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index 
fb2edf0102..289e50c257 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -5,7 +5,8 @@ import type { BlockEnum } from '@/app/components/workflow/types' import type { UpdateAppSiteCodeResponse } from '@/models/app' import type { App } from '@/types/app' import type { I18nKeysByPrefix } from '@/types/i18n' -import { useCallback, useMemo } from 'react' +import * as React from 'react' +import { useCallback, useEffect, useMemo } from 'react' import { useTranslation } from 'react-i18next' import AppCard from '@/app/components/app/overview/app-card' import TriggerCard from '@/app/components/app/overview/trigger-card' @@ -13,6 +14,8 @@ import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' import { toast } from '@/app/components/base/ui/toast' import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' +import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager' import { isTriggerNode } from '@/app/components/workflow/types' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { @@ -72,25 +75,56 @@ const CardView: FC = ({ appId, isInPanel, className }) => { ? buildTriggerModeMessage(t('mcp.server.title', { ns: 'tools' })) : null - const updateAppDetail = async () => { + const updateAppDetail = useCallback(async () => { try { const res = await fetchAppDetail({ url: '/apps', id: appId }) setAppDetail({ ...res }) } - catch (error) { console.error(error) } - } + catch (error) { + console.error(error) + } + }, [appId, setAppDetail]) const handleCallbackResult = (err: Error | null, message?: I18nKeysByPrefix<'common', 'actionMsg.'>) => { const type = err ? 'error' : 'success' message ||= (type === 'success' ? 
'modifiedSuccessfully' : 'modifiedUnsuccessfully') - if (type === 'success') + if (type === 'success') { updateAppDetail() + // Emit collaboration event to notify other clients of app state changes + const socket = webSocketClient.getSocket(appId) + if (socket) { + socket.emit('collaboration_event', { + type: 'app_state_update', + data: { timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + } + toast(t(`actionMsg.${message}`, { ns: 'common' }) as string, { type }) } + // Listen for collaborative app state updates from other clients + useEffect(() => { + if (!appId) + return + + const unsubscribe = collaborationManager.onAppStateUpdate(async () => { + try { + // Update app detail when other clients modify app state + await updateAppDetail() + } + catch (error) { + console.error('app state update failed:', error) + } + }) + + return unsubscribe + }, [appId, updateAppDetail]) + const onChangeSiteStatus = async (value: boolean) => { const [err] = await asyncRunSafe( updateAppSiteStatus({ diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx index c63d482c8f..6ff8725f63 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx @@ -12,7 +12,6 @@ import { PortalToFollowElem, PortalToFollowElemContent, } from '@/app/components/base/portal-to-follow-elem' -import { Button } from '@/app/components/base/ui/button' import { AlertDialog, AlertDialogActions, @@ -22,6 +21,7 @@ import { AlertDialogDescription, AlertDialogTitle, } from '@/app/components/base/ui/alert-dialog' +import { Button } from '@/app/components/base/ui/button' import { toast } from '@/app/components/base/ui/toast' import { addTracingConfig, removeTracingConfig, updateTracingConfig } from 
'@/service/apps' import { docURL } from './config' diff --git a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts index d5eaa4bfe4..5b10b4c32b 100644 --- a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts +++ b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts @@ -20,8 +20,11 @@ const mockUpdateAppInfo = vi.fn() const mockCopyApp = vi.fn() const mockExportAppConfig = vi.fn() const mockDeleteApp = vi.fn() +const mockFetchAppDetail = vi.fn() const mockFetchWorkflowDraft = vi.fn() const mockDownloadBlob = vi.fn() +const mockGetSocket = vi.fn() +const mockOnAppMetaUpdate = vi.fn() let mockAppDetail: Record | undefined = { id: 'app-1', @@ -68,6 +71,7 @@ vi.mock('@/service/apps', () => ({ copyApp: (...args: unknown[]) => mockCopyApp(...args), exportAppConfig: (...args: unknown[]) => mockExportAppConfig(...args), deleteApp: (...args: unknown[]) => mockDeleteApp(...args), + fetchAppDetail: (...args: unknown[]) => mockFetchAppDetail(...args), })) vi.mock('@/service/workflow', () => ({ @@ -82,6 +86,18 @@ vi.mock('@/utils/app-redirection', () => ({ getRedirection: vi.fn(), })) +vi.mock('@/app/components/workflow/collaboration/core/websocket-manager', () => ({ + webSocketClient: { + getSocket: (...args: unknown[]) => mockGetSocket(...args), + }, +})) + +vi.mock('@/app/components/workflow/collaboration/core/collaboration-manager', () => ({ + collaborationManager: { + onAppMetaUpdate: (...args: unknown[]) => mockOnAppMetaUpdate(...args), + }, +})) + vi.mock('@/config', () => ({ NEED_REFRESH_APP_LIST_KEY: 'test-refresh-key', })) @@ -89,6 +105,8 @@ vi.mock('@/config', () => ({ describe('useAppInfoActions', () => { beforeEach(() => { vi.clearAllMocks() + mockOnAppMetaUpdate.mockReturnValue(() => {}) + mockGetSocket.mockReturnValue(null) mockAppDetail = { id: 'app-1', name: 'Test App', @@ -191,6 +209,35 @@ 
describe('useAppInfoActions', () => { expect(toastMocks.call).toHaveBeenCalledWith({ type: 'success', message: 'app.editDone' }) }) + it('should emit app_meta_update after successful edit when collaboration socket exists', async () => { + const updatedApp = { ...mockAppDetail, name: 'Updated' } + const socket = { emit: vi.fn() } + mockUpdateAppInfo.mockResolvedValue(updatedApp) + mockGetSocket.mockReturnValue(socket) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onEdit({ + name: 'Updated', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + description: '', + use_icon_as_answer_icon: false, + }) + }) + await new Promise(resolve => setTimeout(resolve, 0)) + + expect(mockGetSocket).toHaveBeenCalledWith('app-1') + expect(socket.emit).toHaveBeenCalledWith( + 'collaboration_event', + expect.objectContaining({ + type: 'app_meta_update', + }), + ) + }) + it('should notify error on edit failure', async () => { mockUpdateAppInfo.mockRejectedValue(new Error('fail')) @@ -502,4 +549,31 @@ describe('useAppInfoActions', () => { }) }) }) + + describe('collaboration app meta updates', () => { + it('should refresh app detail when receiving app_meta_update', async () => { + const updated = { ...mockAppDetail, name: 'Remote Updated' } + const unsubscribe = vi.fn() + let onUpdate: (() => Promise) | undefined + + mockOnAppMetaUpdate.mockImplementation((callback: () => Promise) => { + onUpdate = callback + return unsubscribe + }) + mockFetchAppDetail.mockResolvedValue(updated) + + const { unmount } = renderHook(() => useAppInfoActions({})) + await new Promise(resolve => setTimeout(resolve, 0)) + + await act(async () => { + await onUpdate?.() + }) + + expect(mockFetchAppDetail).toHaveBeenCalledWith({ url: '/apps', id: 'app-1' }) + expect(mockSetAppDetail).toHaveBeenCalledWith(updated) + + unmount() + expect(unsubscribe).toHaveBeenCalled() + }) + }) }) diff --git 
a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts index 8b559f7bba..3192d48f81 100644 --- a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts +++ b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts @@ -1,14 +1,14 @@ import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { EnvironmentVariable } from '@/app/components/workflow/types' -import { useCallback, useState } from 'react' +import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { useStore as useAppStore } from '@/app/components/app/store' import { toast } from '@/app/components/base/ui/toast' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useProviderContext } from '@/context/provider-context' import { useRouter } from '@/next/navigation' -import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' +import { copyApp, deleteApp, exportAppConfig, fetchAppDetail, updateAppInfo } from '@/service/apps' import { useInvalidateAppList } from '@/service/use-apps' import { fetchWorkflowDraft } from '@/service/workflow' import { AppModeEnum } from '@/types/app' @@ -47,6 +47,56 @@ export function useAppInfoActions({ onDetailExpand }: UseAppInfoActionsParams) { setActiveModal(null) }, []) + const emitAppMetaUpdate = useCallback(() => { + if (!appDetail?.id) + return + + void import('@/app/components/workflow/collaboration/core/websocket-manager') + .then(({ webSocketClient }) => { + const socket = webSocketClient.getSocket(appDetail.id) + if (!socket) + return + socket.emit('collaboration_event', { + type: 'app_meta_update', + data: { timestamp: Date.now() }, + timestamp: Date.now(), + }) + }) + .catch(() => {}) + }, [appDetail?.id]) + + useEffect(() => { + if (!appDetail?.id) + return + + let 
unsubscribe: (() => void) | null = null + let disposed = false + + void import('@/app/components/workflow/collaboration/core/collaboration-manager') + .then(({ collaborationManager }) => { + if (disposed) + return + + unsubscribe = collaborationManager.onAppMetaUpdate(async () => { + try { + const res = await fetchAppDetail({ url: '/apps', id: appDetail.id }) + if (disposed) + return + setAppDetail({ ...res }) + } + catch (error) { + console.error('failed to refresh app detail from collaboration update:', error) + } + }) + }) + .catch(() => {}) + + return () => { + disposed = true + unsubscribe?.() + } + }, [appDetail?.id, setAppDetail]) + const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, icon_type, @@ -72,11 +122,12 @@ export function useAppInfoActions({ onDetailExpand }: UseAppInfoActionsParams) { closeModal() toast(t('editDone', { ns: 'app' }), { type: 'success' }) setAppDetail(app) + emitAppMetaUpdate() } catch { toast(t('editFailed', { ns: 'app' }), { type: 'error' }) } - }, [appDetail, closeModal, setAppDetail, t]) + }, [appDetail, closeModal, setAppDetail, t, emitAppMetaUpdate]) const onCopy: DuplicateAppModalProps['onConfirm'] = useCallback(async ({ name, diff --git a/web/app/components/app/app-access-control/__tests__/access-control-dialog.spec.tsx b/web/app/components/app/app-access-control/__tests__/access-control-dialog.spec.tsx index 5c7d2f2dc0..9b3dd8ee05 100644 --- a/web/app/components/app/app-access-control/__tests__/access-control-dialog.spec.tsx +++ b/web/app/components/app/app-access-control/__tests__/access-control-dialog.spec.tsx @@ -1,4 +1,3 @@ -/* eslint-disable ts/no-explicit-any */ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import AccessControlDialog from '../access-control-dialog' diff --git a/web/app/components/app/app-access-control/__tests__/index.spec.tsx b/web/app/components/app/app-access-control/__tests__/index.spec.tsx index f2fa09f98a..612d17d88a 100644 --- 
a/web/app/components/app/app-access-control/__tests__/index.spec.tsx +++ b/web/app/components/app/app-access-control/__tests__/index.spec.tsx @@ -1,4 +1,3 @@ -/* eslint-disable ts/no-explicit-any */ import type { App } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { toast } from '@/app/components/base/ui/toast' diff --git a/web/app/components/app/app-publisher/__tests__/index.spec.tsx b/web/app/components/app/app-publisher/__tests__/index.spec.tsx index e97efaa525..1e29e44c82 100644 --- a/web/app/components/app/app-publisher/__tests__/index.spec.tsx +++ b/web/app/components/app/app-publisher/__tests__/index.spec.tsx @@ -15,6 +15,7 @@ const mockOpenAsyncWindow = vi.fn() const mockFetchInstalledAppList = vi.fn() const mockFetchAppDetailDirect = vi.fn() const mockToastError = vi.fn() +const mockInvalidateAppWorkflow = vi.fn() const sectionProps = vi.hoisted(() => ({ summary: null as null | Record, @@ -88,6 +89,10 @@ vi.mock('@/service/apps', () => ({ fetchAppDetailDirect: (...args: unknown[]) => mockFetchAppDetailDirect(...args), })) +vi.mock('@/service/use-workflow', () => ({ + useInvalidateAppWorkflow: () => mockInvalidateAppWorkflow, +})) + vi.mock('@/app/components/base/ui/toast', () => ({ toast: { error: (...args: unknown[]) => mockToastError(...args), diff --git a/web/app/components/app/app-publisher/features-wrapper.tsx b/web/app/components/app/app-publisher/features-wrapper.tsx index 6a5c9582a7..47ce63645c 100644 --- a/web/app/components/app/app-publisher/features-wrapper.tsx +++ b/web/app/components/app/app-publisher/features-wrapper.tsx @@ -1,6 +1,7 @@ import type { AppPublisherProps } from '@/app/components/app/app-publisher' import type { ModelAndParameter } from '@/app/components/app/configuration/debug/types' import type { FileUpload } from '@/app/components/base/features/types' +import type { PublishWorkflowParams } from '@/types/workflow' import { produce } from 'immer' import * as React from 
'react' import { useCallback, useState } from 'react' @@ -21,7 +22,7 @@ import { SupportUploadFileTypes } from '@/app/components/workflow/types' import { Resolution } from '@/types/app' type Props = Omit & { - onPublish?: (modelAndParameter?: ModelAndParameter, features?: any) => Promise | any + onPublish?: (params?: ModelAndParameter | PublishWorkflowParams, features?: any) => Promise | any publishedConfig?: any resetAppConfig?: () => void } @@ -70,8 +71,8 @@ const FeaturesWrappedAppPublisher = (props: Props) => { setRestoreConfirmOpen(false) }, [featuresStore, props]) - const handlePublish = useCallback((modelAndParameter?: ModelAndParameter) => { - return props.onPublish?.(modelAndParameter, features) + const handlePublish = useCallback((params?: ModelAndParameter | PublishWorkflowParams) => { + return props.onPublish?.(params, features) }, [features, props]) return ( diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 9c50a98124..07bd04b954 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -1,10 +1,12 @@ import type { ModelAndParameter } from '../configuration/debug/types' +import type { CollaborationUpdate } from '@/app/components/workflow/collaboration/types/collaboration' import type { InputVar, Variable } from '@/app/components/workflow/types' import type { PublishWorkflowParams } from '@/types/workflow' import { useKeyPress } from 'ahooks' import { memo, useCallback, + useContext, useEffect, useMemo, useState, @@ -19,6 +21,9 @@ import { PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' import { Button } from '@/app/components/base/ui/button' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' +import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager' +import { WorkflowContext } from 
'@/app/components/workflow/context' import { useGlobalPublicStore } from '@/context/global-public-context' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' @@ -26,6 +31,8 @@ import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control' import { fetchAppDetailDirect } from '@/service/apps' import { fetchInstalledAppList } from '@/service/explore' +import { useInvalidateAppWorkflow } from '@/service/use-workflow' +import { fetchPublishedWorkflow } from '@/service/workflow' import { AppModeEnum } from '@/types/app' import { basePath } from '@/utils/var' import { toast } from '../../base/ui/toast' @@ -97,6 +104,7 @@ const AppPublisher = ({ const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false) + const workflowStore = useContext(WorkflowContext) const appDetail = useAppStore(state => state.appDetail) const setAppDetail = useAppStore(s => s.setAppDetail) const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) @@ -108,6 +116,7 @@ const AppPublisher = ({ const { data: userCanAccessApp, isLoading: isGettingUserCanAccessApp, refetch } = useGetUserCanAccessApp({ appId: appDetail?.id, enabled: false }) const { data: appAccessSubjects, isLoading: isGettingAppWhiteListSubjects } = useAppWhiteListSubjects(appDetail?.id, open && systemFeatures.webapp_auth.enabled && appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS) + const invalidateAppWorkflow = useInvalidateAppWorkflow() const openAsyncWindow = useAsyncWindowOpen() const isAppAccessSet = useMemo(() => isPublisherAccessConfigured(appDetail, appAccessSubjects), [appAccessSubjects, appDetail]) @@ -135,12 +144,35 @@ const AppPublisher = ({ try { await onPublish?.(params) setPublished(true) + + const appId = appDetail?.id + const socket = appId ? 
webSocketClient.getSocket(appId) : null + if (appId) + invalidateAppWorkflow(appId) + else + console.warn('[app-publisher] missing appId, skip workflow invalidate and socket emit') + if (socket) { + const timestamp = Date.now() + socket.emit('collaboration_event', { + type: 'app_publish_update', + data: { + action: 'published', + timestamp, + }, + timestamp, + }) + } + else if (appId) { + console.warn('[app-publisher] socket not ready, skip collaboration_event emit', { appId }) + } + trackEvent('app_published_time', { action_mode: 'app', app_id: appDetail?.id, app_name: appDetail?.name }) } - catch { + catch (error) { + console.warn('[app-publisher] publish failed', error) setPublished(false) } - }, [appDetail, onPublish]) + }, [appDetail, onPublish, invalidateAppWorkflow]) const handleRestore = useCallback(async () => { try { @@ -199,6 +231,29 @@ const AppPublisher = ({ handlePublish() }, { exactMatch: true, useCapture: true }) + useEffect(() => { + const appId = appDetail?.id + if (!appId) + return + + const unsubscribe = collaborationManager.onAppPublishUpdate((update: CollaborationUpdate) => { + const action = typeof update.data.action === 'string' ? update.data.action : undefined + if (action === 'published') { + invalidateAppWorkflow(appId) + fetchPublishedWorkflow(`/apps/${appId}/workflows/publish`) + .then((publishedWorkflow) => { + if (publishedWorkflow?.created_at) + workflowStore?.getState().setPublishedAt(publishedWorkflow.created_at) + }) + .catch((error) => { + console.warn('[app-publisher] refresh published workflow failed', error) + }) + } + }) + + return unsubscribe + }, [appDetail?.id, invalidateAppWorkflow, workflowStore]) + const hasPublishedVersion = !!publishedAt const workflowToolMessage = !hasPublishedVersion || !workflowToolAvailable ? 
t('common.workflowAsToolDisabledHint', { ns: 'workflow' }) diff --git a/web/app/components/app/configuration/hooks/use-configuration.ts b/web/app/components/app/configuration/hooks/use-configuration.ts index f2f708ad52..943642b545 100644 --- a/web/app/components/app/configuration/hooks/use-configuration.ts +++ b/web/app/components/app/configuration/hooks/use-configuration.ts @@ -21,6 +21,7 @@ import type { TextToSpeechConfig, } from '@/models/debug' import type { VisionSettings } from '@/types/app' +import type { PublishWorkflowParams } from '@/types/workflow' import { useBoolean, useGetState } from 'ahooks' import { clone } from 'es-toolkit/object' import { produce } from 'immer' @@ -480,34 +481,40 @@ export const useConfiguration = (): ConfigurationViewModel => { resolvedModelModeType, ]) - const onPublish = useCallback(async (modelAndParameter?: ModelAndParameter, features?: FeaturesData) => createPublishHandler({ - appId, - chatPromptConfig, - citationConfig, - completionParamsState, - completionPromptConfig, - contextVar, - contextVarEmpty, - dataSets, - datasetConfigs, - externalDataToolsConfig, - hasSetBlockStatus, - introduction, - isAdvancedMode, - isFunctionCall, - mode, - modelConfig, - moreLikeThisConfig, - promptEmpty, - promptMode, - resolvedModelModeType, - setCanReturnToSimpleMode, - setPublishedConfig, - speechToTextConfig, - suggestedQuestionsAfterAnswerConfig, - t, - textToSpeechConfig, - })(updateAppModelConfig, modelAndParameter, features), [ + const onPublish = useCallback(async (params?: ModelAndParameter | PublishWorkflowParams, features?: FeaturesData) => { + const modelAndParameter = params && 'model' in params && 'provider' in params && 'parameters' in params + ? 
params + : undefined + + return createPublishHandler({ + appId, + chatPromptConfig, + citationConfig, + completionParamsState, + completionPromptConfig, + contextVar, + contextVarEmpty, + dataSets, + datasetConfigs, + externalDataToolsConfig, + hasSetBlockStatus, + introduction, + isAdvancedMode, + isFunctionCall, + mode, + modelConfig, + moreLikeThisConfig, + promptEmpty, + promptMode, + resolvedModelModeType, + setCanReturnToSimpleMode, + setPublishedConfig, + speechToTextConfig, + suggestedQuestionsAfterAnswerConfig, + t, + textToSpeechConfig, + })(updateAppModelConfig, modelAndParameter, features) + }, [ appId, chatPromptConfig, citationConfig, diff --git a/web/app/components/app/create-app-modal/__tests__/index.spec.tsx b/web/app/components/app/create-app-modal/__tests__/index.spec.tsx index 3e06b89f0e..305b90981b 100644 --- a/web/app/components/app/create-app-modal/__tests__/index.spec.tsx +++ b/web/app/components/app/create-app-modal/__tests__/index.spec.tsx @@ -1,6 +1,6 @@ import type { App } from '@/types/app' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' @@ -121,7 +121,6 @@ const renderModal = () => { describe('CreateAppModal', () => { const mockSetItem = vi.fn() - const originalLocalStorage = window.localStorage beforeEach(() => { vi.clearAllMocks() @@ -153,13 +152,6 @@ describe('CreateAppModal', () => { }) }) - afterAll(() => { - Object.defineProperty(window, 'localStorage', { - value: originalLocalStorage, - writable: true, - }) - }) - it('creates an app, notifies success, and fires callbacks', async () => { const mockApp: Partial = { id: 'app-1', mode: AppModeEnum.ADVANCED_CHAT } mockCreateApp.mockResolvedValue(mockApp as App) diff --git 
a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index 877c392e6d..9b82648986 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -116,9 +116,13 @@ vi.mock('@/service/tag', () => ({ fetchTagList: vi.fn().mockResolvedValue([{ id: 'tag-1', name: 'Test Tag', type: 'app' }]), })) -vi.mock('@/config', () => ({ - NEED_REFRESH_APP_LIST_KEY: 'needRefreshAppList', -})) +vi.mock('@/config', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + NEED_REFRESH_APP_LIST_KEY: 'needRefreshAppList', + } +}) vi.mock('@/hooks/use-pay', () => ({ CheckModal: () => null, @@ -386,10 +390,11 @@ describe('List', () => { describe('Edge Cases', () => { it('should handle multiple renders without issues', () => { - const { rerender } = renderWithNuqs() + const { unmount } = renderWithNuqs() expect(screen.getByText('app.types.all')).toBeInTheDocument() - rerender() + unmount() + renderList() expect(screen.getByText('app.types.all')).toBeInTheDocument() }) diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index f4136b5270..d48372bdf0 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -5,6 +5,7 @@ import type { HtmlContentProps } from '@/app/components/base/popover' import type { Tag } from '@/app/components/base/tag-management/constant' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { EnvironmentVariable } from '@/app/components/workflow/types' +import type { WorkflowOnlineUser } from '@/models/app' import type { App } from '@/types/app' import { cn } from '@langgenius/dify-ui/cn' import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react' @@ -28,6 +29,7 @@ import { AlertDialogTitle, } from '@/app/components/base/ui/alert-dialog' import { toast } from 
'@/app/components/base/ui/toast' +import { UserAvatarList } from '@/app/components/base/user-avatar-list' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' @@ -65,10 +67,11 @@ const AccessControl = dynamic(() => import('@/app/components/app/app-access-cont type AppCardProps = { app: App + onlineUsers?: WorkflowOnlineUser[] onRefresh?: () => void } -const AppCard = ({ app, onRefresh }: AppCardProps) => { +const AppCard = ({ app, onlineUsers = [], onRefresh }: AppCardProps) => { const { t } = useTranslation() const deleteAppNameInputId = useId() const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) @@ -360,6 +363,20 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { return `${t('segment.editedAt', { ns: 'datasetDocuments' })} ${timeText}` }, [app.updated_at, app.created_at, t]) + const onlinePresenceUsers = useMemo(() => { + return onlineUsers + .map((user, index) => { + const id = user.user_id || user.sid || `${app.id}-online-${index}` + const name = user.username || user.user_id || user.sid || `${index + 1}` + return { + id, + name, + avatar_url: user.avatar || null, + } + }) + .filter(user => Boolean(user.id)) + }, [app.id, onlineUsers]) + return ( <>
{
{EditTimeText}
-
- {app.access_mode === AccessMode.PUBLIC && ( - - - - )} - {app.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && ( - - - - )} - {app.access_mode === AccessMode.ORGANIZATION && ( - - - - )} - {app.access_mode === AccessMode.EXTERNAL_MEMBERS && ( - - - +
+ {onlinePresenceUsers.length > 0 && ( + )} +
+ {app.access_mode === AccessMode.PUBLIC && ( + + + + )} + {app.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && ( + + + + )} + {app.access_mode === AccessMode.ORGANIZATION && ( + + + + )} + {app.access_mode === AccessMode.EXTERNAL_MEMBERS && ( + + + + )} +
diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 5e15102a49..a898e682f5 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -1,10 +1,11 @@ 'use client' import type { FC } from 'react' +import type { WorkflowOnlineUser } from '@/models/app' import { cn } from '@langgenius/dify-ui/cn' import { useDebounceFn } from 'ahooks' import { parseAsStringLiteral, useQueryState } from 'nuqs' -import { useCallback, useEffect, useRef, useState } from 'react' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Checkbox from '@/app/components/base/checkbox' import Input from '@/app/components/base/input' @@ -16,6 +17,7 @@ import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' import { CheckModal } from '@/hooks/use-pay' import dynamic from '@/next/dynamic' +import { fetchWorkflowOnlineUsers } from '@/service/apps' import { useInfiniteAppList } from '@/service/use-apps' import { AppModeEnum, AppModes } from '@/types/app' import AppCard from './app-card' @@ -68,6 +70,7 @@ const List: FC = ({ const containerRef = useRef(null) const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(false) const [droppedDSLFile, setDroppedDSLFile] = useState() + const [workflowOnlineUsersMap, setWorkflowOnlineUsersMap] = useState>({}) const setKeywords = useCallback((keywords: string) => { setQuery(prev => ({ ...prev, keywords })) }, [setQuery]) @@ -183,6 +186,53 @@ const List: FC = ({ }, [isCreatedByMe, setQuery]) const pages = data?.pages ?? 
[] + const appIds = useMemo(() => { + const ids = new Set() + pages.forEach((page) => { + page.data?.forEach((app) => { + if (app.id) + ids.add(app.id) + }) + }) + return Array.from(ids) + }, [pages]) + + const refreshWorkflowOnlineUsers = useCallback(async () => { + if (!systemFeatures.enable_collaboration_mode) { + setWorkflowOnlineUsersMap({}) + return + } + + if (!appIds.length) { + setWorkflowOnlineUsersMap({}) + return + } + + try { + const onlineUsersMap = await fetchWorkflowOnlineUsers({ appIds }) + setWorkflowOnlineUsersMap(onlineUsersMap) + } + catch { + setWorkflowOnlineUsersMap({}) + } + }, [appIds, systemFeatures.enable_collaboration_mode]) + + useEffect(() => { + void refreshWorkflowOnlineUsers() + }, [refreshWorkflowOnlineUsers]) + + useEffect(() => { + if (!systemFeatures.enable_collaboration_mode) + return + + const timer = window.setInterval(() => { + void refetch() + void refreshWorkflowOnlineUsers() + }, 10000) + + return () => window.clearInterval(timer) + }, [refetch, refreshWorkflowOnlineUsers, systemFeatures.enable_collaboration_mode]) + const hasAnyApp = (pages[0]?.total ?? 0) > 0 // Show skeleton during initial load or when refetching with no previous data const showSkeleton = isLoading || (isFetching && pages.length === 0) @@ -191,7 +241,7 @@ const List: FC = ({ <>
{dragging && ( -
+
)} @@ -242,7 +292,12 @@ const List: FC = ({ if (hasAnyApp) { return pages.flatMap(({ data: apps }) => apps).map(app => ( - + )) } diff --git a/web/app/components/base/chat/embedded-chatbot/header/__tests__/index.spec.tsx b/web/app/components/base/chat/embedded-chatbot/header/__tests__/index.spec.tsx index 5236beda60..51aa8a9b0c 100644 --- a/web/app/components/base/chat/embedded-chatbot/header/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/embedded-chatbot/header/__tests__/index.spec.tsx @@ -83,6 +83,7 @@ describe('EmbeddedChatbot Header', () => { allow_email_code_login: false, allow_email_password_login: false, }, + enable_collaboration_mode: false, enable_trial_app: false, enable_explore_banner: false, } diff --git a/web/app/components/base/content-dialog/index.tsx b/web/app/components/base/content-dialog/index.tsx index db48a9cb8f..7348879111 100644 --- a/web/app/components/base/content-dialog/index.tsx +++ b/web/app/components/base/content-dialog/index.tsx @@ -19,7 +19,7 @@ const ContentDialog = ({
& { + ref?: React.RefObject> + }, +) => + +Icon.displayName = 'EnterKey' + +export default Icon diff --git a/web/app/components/base/icons/src/public/other/Comment.json b/web/app/components/base/icons/src/public/other/Comment.json new file mode 100644 index 0000000000..780ddc4cf5 --- /dev/null +++ b/web/app/components/base/icons/src/public/other/Comment.json @@ -0,0 +1,26 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "xmlns": "http://www.w3.org/2000/svg", + "width": "14", + "height": "12", + "viewBox": "0 0 14 12", + "fill": "none" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M12.3334 4C12.3334 2.52725 11.1395 1.33333 9.66671 1.33333H4.33337C2.86062 1.33333 1.66671 2.52724 1.66671 4V10.6667H9.66671C11.1395 10.6667 12.3334 9.47274 12.3334 8V4ZM7.66671 6.66667V8H4.33337V6.66667H7.66671ZM9.66671 4V5.33333H4.33337V4H9.66671ZM13.6667 8C13.6667 10.2091 11.8758 12 9.66671 12H0.333374V4C0.333374 1.79086 2.12424 0 4.33337 0H9.66671C11.8758 0 13.6667 1.79086 13.6667 4V8Z", + "fill": "currentColor" + }, + "children": [] + } + ] + }, + "name": "Comment" +} diff --git a/web/app/components/base/icons/src/public/other/Comment.tsx b/web/app/components/base/icons/src/public/other/Comment.tsx new file mode 100644 index 0000000000..887754e48e --- /dev/null +++ b/web/app/components/base/icons/src/public/other/Comment.tsx @@ -0,0 +1,20 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import type { IconData } from '@/app/components/base/icons/IconBase' +import * as React from 'react' +import IconBase from '@/app/components/base/icons/IconBase' +import data from './Comment.json' + +const Icon = ( + { + ref, + ...props + }: React.SVGProps & { + ref?: React.RefObject> + }, +) => + +Icon.displayName = 'Comment' + +export default Icon diff --git a/web/app/components/base/icons/src/public/other/index.ts b/web/app/components/base/icons/src/public/other/index.ts index a8f91dd98b..167218e524 
100644 --- a/web/app/components/base/icons/src/public/other/index.ts +++ b/web/app/components/base/icons/src/public/other/index.ts @@ -1,5 +1,5 @@ +export { default as Comment } from './Comment' export { default as DefaultToolIcon } from './DefaultToolIcon' - export { default as Message3Fill } from './Message3Fill' export { default as RowStruct } from './RowStruct' export { default as Slack } from './Slack' diff --git a/web/app/components/base/markdown-blocks/__tests__/plugin-paragraph.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/plugin-paragraph.spec.tsx index b18ac1cdcc..1d3a1c34f0 100644 --- a/web/app/components/base/markdown-blocks/__tests__/plugin-paragraph.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/plugin-paragraph.spec.tsx @@ -1,4 +1,3 @@ -/* eslint-disable next/no-img-element */ import type { ExtraProps } from 'streamdown' import { render, screen } from '@testing-library/react' import userEvent from '@testing-library/user-event' diff --git a/web/app/components/base/prompt-editor/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/__tests__/index.spec.tsx index 93812bcd2a..9d75b9e061 100644 --- a/web/app/components/base/prompt-editor/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/__tests__/index.spec.tsx @@ -31,6 +31,9 @@ const mocks = vi.hoisted(() => { registerNodeTransform: vi.fn(() => vi.fn()), dispatchCommand: vi.fn(), getRootElement: vi.fn(() => rootElement), + getEditorState: vi.fn(() => ({ + read: (fn: () => boolean) => fn(), + })), parseEditorState: vi.fn(() => ({ state: 'parsed' })), setEditorState: vi.fn(), focus: vi.fn(), @@ -66,6 +69,7 @@ vi.mock('lexical', async (importOriginal) => { getChildren: () => mocks.rootLines.map(line => ({ getTextContent: () => line, })), + getAllTextNodes: () => [], }), TextNode: class TextNode { __text: string diff --git a/web/app/components/base/prompt-editor/index.tsx b/web/app/components/base/prompt-editor/index.tsx index 
3cc0e25016..fce70d2781 100644 --- a/web/app/components/base/prompt-editor/index.tsx +++ b/web/app/components/base/prompt-editor/index.tsx @@ -23,6 +23,7 @@ import type { import { cn } from '@langgenius/dify-ui/cn' import { CodeNode } from '@lexical/code' import { LexicalComposer } from '@lexical/react/LexicalComposer' +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { $getRoot, TextNode, @@ -67,6 +68,35 @@ import { import PromptEditorContent from './prompt-editor-content' import { textToEditorState } from './utils' +const ValueSyncPlugin: FC<{ value?: string }> = ({ value }) => { + const [editor] = useLexicalComposerContext() + + useEffect(() => { + if (value === undefined) + return + + const incomingValue = value ?? '' + const shouldUpdate = editor.getEditorState().read(() => { + const currentText = $getRoot().getChildren().map(node => node.getTextContent()).join('\n') + return currentText !== incomingValue + }) + + if (!shouldUpdate) + return + + const editorState = editor.parseEditorState(textToEditorState(incomingValue)) + editor.setEditorState(editorState) + editor.update(() => { + $getRoot().getAllTextNodes().forEach((node) => { + if (node instanceof CustomTextNode) + node.markDirty() + }) + }) + }, [editor, value]) + + return null +} + export type PromptEditorProps = { instanceId?: string compact?: boolean @@ -208,6 +238,7 @@ const PromptEditor: FC = ({ floatingAnchorElem={floatingAnchorElem} onEditorChange={handleEditorChange} /> +
) diff --git a/web/app/components/base/prompt-editor/plugins/update-block.tsx b/web/app/components/base/prompt-editor/plugins/update-block.tsx index 2d83573b1f..e67b027be7 100644 --- a/web/app/components/base/prompt-editor/plugins/update-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/update-block.tsx @@ -1,5 +1,5 @@ import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' -import { $insertNodes } from 'lexical' +import { $getRoot, $insertNodes } from 'lexical' import { useEventEmitterContextContext } from '@/context/event-emitter' import { textToEditorState } from '../utils' import { CustomTextNode } from './custom-text/node' @@ -20,6 +20,12 @@ const UpdateBlock = ({ if (v.type === PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER && v.instanceId === instanceId) { const editorState = editor.parseEditorState(textToEditorState(v.payload)) editor.setEditorState(editorState) + editor.update(() => { + $getRoot().getAllTextNodes().forEach((node) => { + if (node instanceof CustomTextNode) + node.markDirty() + }) + }) } }) diff --git a/web/app/components/base/ui/avatar/__tests__/index.spec.tsx b/web/app/components/base/ui/avatar/__tests__/index.spec.tsx index 8be3f8bf0f..8a384139c2 100644 --- a/web/app/components/base/ui/avatar/__tests__/index.spec.tsx +++ b/web/app/components/base/ui/avatar/__tests__/index.spec.tsx @@ -1,5 +1,5 @@ import { render, screen } from '@testing-library/react' -import { Avatar } from '..' +import { Avatar, AvatarFallback, AvatarImage, AvatarRoot } from '..' 
describe('Avatar', () => { describe('Rendering', () => { @@ -60,6 +60,23 @@ describe('Avatar', () => { }) }) + describe('Primitives', () => { + it('should support composed avatar usage through exported primitives', () => { + render( + + + + J + + , + ) + + expect(screen.getByTestId('avatar-root')).toHaveClass('size-6') + expect(screen.getByText('J')).toBeInTheDocument() + expect(screen.getByText('J')).toHaveStyle({ backgroundColor: 'rgb(1, 2, 3)' }) + }) + }) + describe('Edge Cases', () => { it('should handle empty string name gracefully', () => { const { container } = render() diff --git a/web/app/components/base/ui/avatar/index.stories.tsx b/web/app/components/base/ui/avatar/index.stories.tsx index bf4da697db..abb3c99771 100644 --- a/web/app/components/base/ui/avatar/index.stories.tsx +++ b/web/app/components/base/ui/avatar/index.stories.tsx @@ -1,5 +1,5 @@ import type { Meta, StoryObj } from '@storybook/nextjs-vite' -import { Avatar } from '.' +import { Avatar, AvatarFallback, AvatarRoot } from '.' const meta = { title: 'Base/Data Display/Avatar', @@ -84,3 +84,27 @@ export const AllFallbackSizes: Story = {
), } + +export const ComposedFallback: Story = { + render: () => ( + + + C + + + ), + parameters: { + docs: { + source: { + language: 'tsx', + code: ` + + + C + + + `.trim(), + }, + }, + }, +} diff --git a/web/app/components/base/ui/avatar/index.tsx b/web/app/components/base/ui/avatar/index.tsx index 18f040ff00..1587dd2d33 100644 --- a/web/app/components/base/ui/avatar/index.tsx +++ b/web/app/components/base/ui/avatar/index.tsx @@ -28,7 +28,7 @@ type AvatarRootProps = React.ComponentPropsWithRef & { size?: AvatarSize } -function AvatarRoot({ +export function AvatarRoot({ size = 'md', className, ...props @@ -45,25 +45,11 @@ function AvatarRoot({ ) } -type AvatarImageProps = React.ComponentPropsWithRef - -function AvatarImage({ - className, - ...props -}: AvatarImageProps) { - return ( - - ) -} - type AvatarFallbackProps = React.ComponentPropsWithRef & { size?: AvatarSize } -function AvatarFallback({ +export function AvatarFallback({ size = 'md', className, ...props @@ -80,6 +66,20 @@ function AvatarFallback({ ) } +type AvatarImageProps = React.ComponentPropsWithRef + +export function AvatarImage({ + className, + ...props +}: AvatarImageProps) { + return ( + + ) +} + export const Avatar = ({ name, avatar, diff --git a/web/app/components/base/user-avatar-list/index.tsx b/web/app/components/base/user-avatar-list/index.tsx new file mode 100644 index 0000000000..fc2ce60572 --- /dev/null +++ b/web/app/components/base/user-avatar-list/index.tsx @@ -0,0 +1,99 @@ +import type { FC } from 'react' +import type { AvatarSize } from '@/app/components/base/ui/avatar' +import { memo } from 'react' +import { AvatarFallback, AvatarImage, AvatarRoot } from '@/app/components/base/ui/avatar' +import { getUserColor } from '@/app/components/workflow/collaboration/utils/user-color' +import { useAppContext } from '@/context/app-context' + +type User = { + id: string + name: string + avatar_url?: string | null +} + +type UserAvatarListProps = { + users: User[] + maxVisible?: number + size?: 
AvatarSize + className?: string + showCount?: boolean +} + +const avatarSizeToPx: Record = { + 'xxs': 16, + 'xs': 20, + 'sm': 24, + 'md': 32, + 'lg': 36, + 'xl': 40, + '2xl': 48, + '3xl': 64, +} + +export const UserAvatarList: FC = memo(({ + users, + maxVisible = 3, + size = 'sm', + className = '', + showCount = true, +}) => { + const { userProfile } = useAppContext() + if (!users.length) + return null + + const shouldShowCount = showCount && users.length > maxVisible + const actualMaxVisible = shouldShowCount ? Math.max(1, maxVisible - 1) : maxVisible + const visibleUsers = users.slice(0, actualMaxVisible) + const remainingCount = users.length - actualMaxVisible + + const currentUserId = userProfile?.id + + return ( +
+ {visibleUsers.map((user, index) => { + const isCurrentUser = user.id === currentUserId + const userColor = isCurrentUser ? undefined : getUserColor(user.id) + return ( +
+ + {user.avatar_url && ( + + )} + + {user.name?.[0]?.toLocaleUpperCase()} + + +
+ ) + }, + + )} + {shouldShowCount && remainingCount > 0 && ( +
+
+ + + {remainingCount} +
+
+ )} +
+ ) +}) + +UserAvatarList.displayName = 'UserAvatarList' diff --git a/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx b/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx index 48e6b58766..478302d983 100644 --- a/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx +++ b/web/app/components/datasets/documents/components/document-list/__tests__/index.spec.tsx @@ -21,6 +21,16 @@ vi.mock('@/context/dataset-detail', () => ({ selector({ dataset: { doc_form: ChunkingMode.text } }), })) +vi.mock('@/app/components/datasets/metadata/hooks/use-batch-edit-document-metadata', () => ({ + default: () => ({ + isShowEditModal: false, + showEditModal: vi.fn(), + hideEditModal: vi.fn(), + originalList: [], + handleSave: vi.fn(), + }), +})) + const createTestQueryClient = () => new QueryClient({ defaultOptions: { queries: { retry: false, gcTime: 0 }, diff --git a/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts b/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts index f29d85b460..935f1329d1 100644 --- a/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts +++ b/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts @@ -157,7 +157,7 @@ describe('useDatasetCardState', () => { expect(result.current.modalState.showRenameModal).toBe(false) }) - it('should close confirm delete modal when closeConfirmDelete is called', () => { + it('should close confirm delete modal when closeConfirmDelete is called', async () => { const dataset = createMockDataset() const { result } = renderHook(() => useDatasetCardState({ dataset, onSuccess: vi.fn() }), @@ -168,7 +168,7 @@ describe('useDatasetCardState', () => { result.current.detectIsUsedByApp() }) - waitFor(() => { + await waitFor(() => { 
expect(result.current.modalState.showConfirmDelete).toBe(true) }) diff --git a/web/app/components/develop/secret-key/secret-key-modal.tsx b/web/app/components/develop/secret-key/secret-key-modal.tsx index dbd27606be..58895b26d6 100644 --- a/web/app/components/develop/secret-key/secret-key-modal.tsx +++ b/web/app/components/develop/secret-key/secret-key-modal.tsx @@ -10,7 +10,6 @@ import ActionButton from '@/app/components/base/action-button' import CopyFeedback from '@/app/components/base/copy-feedback' import Loading from '@/app/components/base/loading' import Modal from '@/app/components/base/modal' -import { Button } from '@/app/components/base/ui/button' import { AlertDialog, AlertDialogActions, @@ -20,6 +19,7 @@ import { AlertDialogDescription, AlertDialogTitle, } from '@/app/components/base/ui/alert-dialog' +import { Button } from '@/app/components/base/ui/button' import { useAppContext } from '@/context/app-context' import useTimestamp from '@/hooks/use-timestamp' import { diff --git a/web/app/components/header/account-setting/__tests__/index.spec.tsx b/web/app/components/header/account-setting/__tests__/index.spec.tsx index 279af0b114..d4e093bd30 100644 --- a/web/app/components/header/account-setting/__tests__/index.spec.tsx +++ b/web/app/components/header/account-setting/__tests__/index.spec.tsx @@ -47,6 +47,36 @@ vi.mock('@/hooks/use-breakpoints', () => ({ default: vi.fn(), })) +vi.mock('@/context/global-public-context', async (importOriginal) => { + const actual = await importOriginal() + const systemFeatures = { + ...actual.useGlobalPublicStore.getState().systemFeatures, + webapp_auth: { + ...actual.useGlobalPublicStore.getState().systemFeatures.webapp_auth, + enabled: true, + }, + branding: { + ...actual.useGlobalPublicStore.getState().systemFeatures.branding, + enabled: false, + }, + enable_marketplace: true, + enable_collaboration_mode: false, + } + + return { + ...actual, + useGlobalPublicStore: (selector: (state: Record) => unknown) => selector({ + 
systemFeatures, + }), + useSystemFeaturesQuery: () => ({ + data: systemFeatures, + isPending: false, + isLoading: false, + isFetching: false, + }), + } +}) + vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({ useDefaultModel: vi.fn(() => ({ data: null, isLoading: false })), useUpdateDefaultModel: vi.fn(() => ({ trigger: vi.fn() })), @@ -54,6 +84,7 @@ vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () useInvalidateDefaultModel: vi.fn(() => vi.fn()), useModelList: vi.fn(() => ({ data: [], isLoading: false })), useSystemDefaultModelAndModelList: vi.fn(() => [null, vi.fn()]), + useMarketplaceAllPlugins: vi.fn(() => ({ plugins: [], isLoading: false })), })) vi.mock('@/app/components/header/account-setting/model-provider-page/atoms', () => ({ @@ -70,6 +101,11 @@ vi.mock('@/service/use-common', () => ({ useProviderContext: vi.fn(), })) +vi.mock('@/app/components/billing/billing-page', () => ({ + __esModule: true, + default: () =>
, +})) + const baseAppContextValue: AppContextValue = { userProfile: { id: '1', diff --git a/web/app/components/tools/mcp/__tests__/mcp-service-card.spec.tsx b/web/app/components/tools/mcp/__tests__/mcp-service-card.spec.tsx index d408b3092e..62c6d5f3f6 100644 --- a/web/app/components/tools/mcp/__tests__/mcp-service-card.spec.tsx +++ b/web/app/components/tools/mcp/__tests__/mcp-service-card.spec.tsx @@ -26,6 +26,23 @@ const mockHandleGenCode = vi.fn() const mockOpenConfirmDelete = vi.fn() const mockCloseConfirmDelete = vi.fn() const mockOpenServerModal = vi.fn() +const mockOnMcpServerUpdate = vi.hoisted(() => vi.fn()) +const mockUnsubscribeMcpServerUpdate = vi.hoisted(() => vi.fn()) +const invalidateMCPServerDetailFns = vi.hoisted(() => [] as Array>) + +vi.mock('@/app/components/workflow/collaboration/core/collaboration-manager', () => ({ + collaborationManager: { + onMcpServerUpdate: mockOnMcpServerUpdate, + }, +})) + +vi.mock('@/service/use-tools', () => ({ + useInvalidateMCPServerDetail: () => { + const invalidateFn = vi.fn() + invalidateMCPServerDetailFns.push(invalidateFn) + return invalidateFn + }, +})) type MockHookState = { genLoading: boolean @@ -106,12 +123,15 @@ describe('MCPServiceCard', () => { beforeEach(() => { mockHookState = createDefaultHookState() + invalidateMCPServerDetailFns.length = 0 mockHandleStatusChange.mockClear().mockResolvedValue({ activated: true }) mockHandleServerModalHide.mockClear().mockReturnValue({ shouldDeactivate: false }) mockHandleGenCode.mockClear() mockOpenConfirmDelete.mockClear() mockCloseConfirmDelete.mockClear() mockOpenServerModal.mockClear() + mockUnsubscribeMcpServerUpdate.mockClear() + mockOnMcpServerUpdate.mockReset().mockReturnValue(mockUnsubscribeMcpServerUpdate) }) describe('Rendering', () => { @@ -431,4 +451,27 @@ describe('MCPServiceCard', () => { expect(screen.getByRole('switch')).toBeInTheDocument() }) }) + + describe('Collaboration Sync', () => { + it('should keep a stable MCP update subscription across 
rerenders and invalidate with the latest callback', async () => { + let mcpUpdateHandler: ((payload: unknown) => void) | undefined + mockOnMcpServerUpdate.mockImplementation((handler: (payload: unknown) => void) => { + mcpUpdateHandler = handler + return mockUnsubscribeMcpServerUpdate + }) + + const wrapper = createWrapper() + const { rerender } = render(, { wrapper }) + + rerender() + + expect(mockOnMcpServerUpdate).toHaveBeenCalledTimes(1) + expect(invalidateMCPServerDetailFns).toHaveLength(2) + + mcpUpdateHandler?.({ type: 'mcp_server_update' }) + + expect(invalidateMCPServerDetailFns[0]).not.toHaveBeenCalled() + expect(invalidateMCPServerDetailFns[1]).toHaveBeenCalledWith('app-123') + }) + }) }) diff --git a/web/app/components/tools/mcp/mcp-server-modal.tsx b/web/app/components/tools/mcp/mcp-server-modal.tsx index 383f6ec2b8..e9eb1d8043 100644 --- a/web/app/components/tools/mcp/mcp-server-modal.tsx +++ b/web/app/components/tools/mcp/mcp-server-modal.tsx @@ -11,6 +11,7 @@ import Modal from '@/app/components/base/modal' import Textarea from '@/app/components/base/textarea' import { Button } from '@/app/components/base/ui/button' import MCPServerParamItem from '@/app/components/tools/mcp/mcp-server-param-item' +import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager' import { useCreateMCPServer, useInvalidateMCPServerDetail, @@ -59,6 +60,22 @@ const MCPServerModal = ({ return res } + const emitMcpServerUpdate = (action: 'created' | 'updated') => { + const socket = webSocketClient.getSocket(appID) + if (!socket) + return + + const timestamp = Date.now() + socket.emit('collaboration_event', { + type: 'mcp_server_update', + data: { + action, + timestamp, + }, + timestamp, + }) + } + const submit = async () => { if (!data) { const payload: any = { @@ -71,6 +88,7 @@ const MCPServerModal = ({ await createMCPServer(payload) invalidateMCPServerDetail(appID) + emitMcpServerUpdate('created') onHide() } else { @@ -83,6 +101,7 @@ const 
MCPServerModal = ({ payload.description = description await updateMCPServer(payload) invalidateMCPServerDetail(appID) + emitMcpServerUpdate('updated') onHide() } } diff --git a/web/app/components/tools/mcp/mcp-service-card.tsx b/web/app/components/tools/mcp/mcp-service-card.tsx index 877fcd1ce7..33036969d8 100644 --- a/web/app/components/tools/mcp/mcp-service-card.tsx +++ b/web/app/components/tools/mcp/mcp-service-card.tsx @@ -1,11 +1,12 @@ 'use client' import type { TFunction } from 'i18next' import type { FC, ReactNode } from 'react' +import type { CollaborationUpdate } from '@/app/components/workflow/collaboration/types/collaboration' import type { AppDetailResponse } from '@/models/app' import type { AppSSO } from '@/types/app' import { cn } from '@langgenius/dify-ui/cn' import { RiEditLine, RiLoopLeftLine } from '@remixicon/react' -import { useState } from 'react' +import { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import CopyFeedback from '@/app/components/base/copy-feedback' import Divider from '@/app/components/base/divider' @@ -24,7 +25,9 @@ import { import { Button } from '@/app/components/base/ui/button' import Indicator from '@/app/components/header/indicator' import MCPServerModal from '@/app/components/tools/mcp/mcp-server-modal' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' import { useDocLink } from '@/context/i18n' +import { useInvalidateMCPServerDetail } from '@/service/use-tools' import { useMCPServiceCardState } from './hooks/use-mcp-service-card' // Sub-components @@ -171,6 +174,12 @@ const MCPServiceCard: FC = ({ const { t } = useTranslation() const docLink = useDocLink() const appId = appInfo.id + const invalidateMCPServerDetail = useInvalidateMCPServerDetail() + const invalidateMCPServerDetailRef = useRef(invalidateMCPServerDetail) + + useEffect(() => { + invalidateMCPServerDetailRef.current = invalidateMCPServerDetail + }, 
[invalidateMCPServerDetail]) const { genLoading, @@ -199,6 +208,28 @@ const MCPServiceCard: FC = ({ const [pendingStatus, setPendingStatus] = useState(null) const activated = pendingStatus ?? serverActivated + const emitMcpServerUpdate = async (data: Record) => { + try { + const { webSocketClient } = await import('@/app/components/workflow/collaboration/core/websocket-manager') + const socket = webSocketClient.getSocket(appId) + if (!socket) + return + + const timestamp = Date.now() + socket.emit('collaboration_event', { + type: 'mcp_server_update', + data: { + ...data, + timestamp, + }, + timestamp, + }) + } + catch (error) { + console.error('MCP collaboration event emit failed:', error) + } + } + const onChangeStatus = async (state: boolean) => { setPendingStatus(state) const result = await handleStatusChange(state) @@ -206,6 +237,15 @@ const MCPServiceCard: FC = ({ // Server modal was opened instead, clear pending status setPendingStatus(null) } + + if (result.activated !== state) + return + + // Emit collaboration event to notify other clients of MCP server status change + void emitMcpServerUpdate({ + action: 'statusChanged', + status: state ? 
'active' : 'inactive', + }) } const onServerModalHide = () => { @@ -215,10 +255,35 @@ const MCPServiceCard: FC = ({ } const onConfirmRegenerate = () => { - handleGenCode() closeConfirmDelete() + + void (async () => { + await handleGenCode() + + // Emit collaboration event to notify other clients of MCP server code changes + await emitMcpServerUpdate({ + action: 'codeRegenerated', + }) + })() } + // Listen for collaborative MCP server updates from other clients + useEffect(() => { + if (!appId) + return + + const unsubscribe = collaborationManager.onMcpServerUpdate((_update: CollaborationUpdate) => { + try { + invalidateMCPServerDetailRef.current(appId) + } + catch (error) { + console.error('MCP server update failed:', error) + } + }) + + return unsubscribe + }, [appId]) + if (isLoading) return null diff --git a/web/app/components/workflow-app/components/__tests__/workflow-main.spec.tsx b/web/app/components/workflow-app/components/__tests__/workflow-main.spec.tsx index 250b87069f..d16da63b93 100644 --- a/web/app/components/workflow-app/components/__tests__/workflow-main.spec.tsx +++ b/web/app/components/workflow-app/components/__tests__/workflow-main.spec.tsx @@ -1,11 +1,17 @@ import type { ReactNode } from 'react' import type { WorkflowProps } from '@/app/components/workflow' -import { fireEvent, render, screen } from '@testing-library/react' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { ChatVarType } from '@/app/components/workflow/panel/chat-variable-panel/type' import WorkflowMain from '../workflow-main' const mockSetFeatures = vi.fn() const mockSetConversationVariables = vi.fn() const mockSetEnvironmentVariables = vi.fn() +const mockHandleUpdateWorkflowCanvas = vi.hoisted(() => vi.fn()) +const mockFetchWorkflowDraft = vi.hoisted(() => vi.fn()) +const mockOnVarsAndFeaturesUpdate = vi.hoisted(() => vi.fn()) +const mockOnWorkflowUpdate = vi.hoisted(() => vi.fn()) +const mockOnSyncRequest = vi.hoisted(() => vi.fn()) const 
hookFns = { doSyncWorkflowDraft: vi.fn(), @@ -43,9 +49,24 @@ const hookFns = { invalidateConversationVarValues: vi.fn(), } +const collaborationRuntime = vi.hoisted(() => ({ + startCursorTracking: vi.fn(), + stopCursorTracking: vi.fn(), + onlineUsers: [] as Array<{ user_id: string, username: string, avatar: string, sid: string }>, + cursors: {} as Record, + isConnected: false, + isEnabled: false, +})) + +const collaborationListeners = vi.hoisted(() => ({ + varsAndFeaturesUpdate: null as null | ((update: unknown) => void | Promise), + workflowUpdate: null as null | (() => void | Promise), + syncRequest: null as null | (() => void), +})) + let capturedContextProps: Record | null = null -type MockWorkflowWithInnerContextProps = Pick & { +type MockWorkflowWithInnerContextProps = Pick & { hooksStore?: Record children?: ReactNode } @@ -59,6 +80,9 @@ vi.mock('@/app/components/base/features/hooks', () => ({ })) vi.mock('@/app/components/workflow/store', () => ({ + useStore: (selector: (state: { appId: string }) => T) => selector({ + appId: 'app-1', + }), useWorkflowStore: () => ({ getState: () => ({ setConversationVariables: mockSetConversationVariables, @@ -67,6 +91,53 @@ vi.mock('@/app/components/workflow/store', () => ({ }), })) +vi.mock('reactflow', () => ({ + useReactFlow: () => ({ + getNodes: () => [], + setNodes: vi.fn(), + getEdges: () => [], + setEdges: vi.fn(), + }), +})) + +vi.mock('@/app/components/workflow/collaboration/hooks/use-collaboration', () => ({ + useCollaboration: () => ({ + startCursorTracking: collaborationRuntime.startCursorTracking, + stopCursorTracking: collaborationRuntime.stopCursorTracking, + onlineUsers: collaborationRuntime.onlineUsers, + cursors: collaborationRuntime.cursors, + isConnected: collaborationRuntime.isConnected, + isEnabled: collaborationRuntime.isEnabled, + }), +})) + +vi.mock('@/app/components/workflow/hooks/use-workflow-interactions', () => ({ + useWorkflowUpdate: () => ({ + handleUpdateWorkflowCanvas: 
mockHandleUpdateWorkflowCanvas, + }), +})) + +vi.mock('@/app/components/workflow/collaboration/core/collaboration-manager', () => ({ + collaborationManager: { + onVarsAndFeaturesUpdate: mockOnVarsAndFeaturesUpdate.mockImplementation((handler: (update: unknown) => void | Promise) => { + collaborationListeners.varsAndFeaturesUpdate = handler + return vi.fn() + }), + onWorkflowUpdate: mockOnWorkflowUpdate.mockImplementation((handler: () => void | Promise) => { + collaborationListeners.workflowUpdate = handler + return vi.fn() + }), + onSyncRequest: mockOnSyncRequest.mockImplementation((handler: () => void) => { + collaborationListeners.syncRequest = handler + return vi.fn() + }), + }, +})) + +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args), +})) + vi.mock('@/app/components/workflow', () => ({ WorkflowWithInnerContext: ({ nodes, @@ -74,6 +145,9 @@ vi.mock('@/app/components/workflow', () => ({ viewport, onWorkflowDataUpdate, hooksStore, + cursors, + myUserId, + onlineUsers, children, }: MockWorkflowWithInnerContextProps) => { capturedContextProps = { @@ -81,15 +155,32 @@ vi.mock('@/app/components/workflow', () => ({ edges, viewport, hooksStore, + cursors, + myUserId, + onlineUsers, } return (
@@ -169,6 +268,16 @@ describe('WorkflowMain', () => { beforeEach(() => { vi.clearAllMocks() capturedContextProps = null + collaborationRuntime.startCursorTracking.mockReset() + collaborationRuntime.stopCursorTracking.mockReset() + collaborationRuntime.onlineUsers = [] + collaborationRuntime.cursors = {} + collaborationRuntime.isConnected = false + collaborationRuntime.isEnabled = false + collaborationListeners.varsAndFeaturesUpdate = null + collaborationListeners.workflowUpdate = null + collaborationListeners.syncRequest = null + mockFetchWorkflowDraft.mockReset() }) it('should render the inner workflow context with children and forwarded graph props', () => { @@ -204,9 +313,11 @@ describe('WorkflowMain', () => { fireEvent.click(screen.getByRole('button', { name: /update-workflow-data/i })) - expect(mockSetFeatures).toHaveBeenCalledWith({ file: { enabled: true } }) - expect(mockSetConversationVariables).toHaveBeenCalledWith([{ id: 'conversation-1' }]) - expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([{ id: 'env-1' }]) + expect(mockSetFeatures).toHaveBeenCalledWith(expect.objectContaining({ + file: expect.objectContaining({ enabled: true }), + })) + expect(mockSetConversationVariables).toHaveBeenCalledWith([expect.objectContaining({ id: 'conversation-1' })]) + expect(mockSetEnvironmentVariables).toHaveBeenCalledWith([expect.objectContaining({ id: 'env-1' })]) }) it('should only update the workflow store slices present in the payload', () => { @@ -220,7 +331,7 @@ describe('WorkflowMain', () => { fireEvent.click(screen.getByRole('button', { name: /update-conversation-only/i })) - expect(mockSetConversationVariables).toHaveBeenCalledWith([{ id: 'conversation-only' }]) + expect(mockSetConversationVariables).toHaveBeenCalledWith([expect.objectContaining({ id: 'conversation-only' })]) expect(mockSetFeatures).not.toHaveBeenCalled() expect(mockSetEnvironmentVariables).not.toHaveBeenCalled() }) @@ -274,4 +385,79 @@ describe('WorkflowMain', () => { configsMap: { 
flowId: 'app-1', flowType: 'app-flow', fileSettings: { enabled: true } }, }) }) + + it('passes collaboration props and tracks cursors when collaboration is enabled', () => { + collaborationRuntime.isEnabled = true + collaborationRuntime.isConnected = true + collaborationRuntime.onlineUsers = [{ user_id: 'u-1', username: 'Alice', avatar: '', sid: 'sid-1' }] + collaborationRuntime.cursors = { + 'current-user': { x: 1, y: 2, userId: 'current-user', timestamp: 1 }, + 'user-other': { x: 20, y: 30, userId: 'user-other', timestamp: 2 }, + } + + const { unmount } = render( + , + ) + + expect(collaborationRuntime.startCursorTracking).toHaveBeenCalled() + expect(capturedContextProps).toMatchObject({ + myUserId: 'current-user', + onlineUsers: [{ user_id: 'u-1' }], + cursors: { + 'user-other': expect.objectContaining({ userId: 'user-other' }), + }, + }) + + unmount() + expect(collaborationRuntime.stopCursorTracking).toHaveBeenCalled() + }) + + it('subscribes collaboration listeners and handles sync/workflow update callbacks', async () => { + collaborationRuntime.isEnabled = true + mockFetchWorkflowDraft.mockResolvedValue({ + features: { + file_upload: { enabled: true }, + opening_statement: 'hello', + }, + conversation_variables: [], + environment_variables: [], + graph: { + nodes: [{ id: 'n-1' }], + edges: [{ id: 'e-1' }], + viewport: { x: 3, y: 4, zoom: 1.2 }, + }, + }) + + render( + , + ) + + expect(mockOnVarsAndFeaturesUpdate).toHaveBeenCalled() + expect(mockOnWorkflowUpdate).toHaveBeenCalled() + expect(mockOnSyncRequest).toHaveBeenCalled() + + collaborationListeners.syncRequest?.() + expect(hookFns.doSyncWorkflowDraft).toHaveBeenCalled() + + await collaborationListeners.varsAndFeaturesUpdate?.({}) + await collaborationListeners.workflowUpdate?.() + + await waitFor(() => { + expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/apps/app-1/workflows/draft') + expect(mockSetFeatures).toHaveBeenCalled() + expect(mockHandleUpdateWorkflowCanvas).toHaveBeenCalledWith({ + nodes: 
[{ id: 'n-1' }], + edges: [{ id: 'e-1' }], + viewport: { x: 3, y: 4, zoom: 1.2 }, + }) + }) + }) }) diff --git a/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx b/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx index 5ce30d0701..94e487ba57 100644 --- a/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx +++ b/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx @@ -1,3 +1,4 @@ +import type { ModelAndParameter } from '@/app/components/app/configuration/debug/types' import type { EndNodeType } from '@/app/components/workflow/nodes/end/types' import type { StartNodeType } from '@/app/components/workflow/nodes/start/types' import type { @@ -143,7 +144,8 @@ const FeaturesTrigger = () => { const needWarningNodes = useChecklist(nodes, edges) const updatePublishedWorkflow = useInvalidateAppWorkflow() - const onPublish = useCallback(async (params?: PublishWorkflowParams) => { + const onPublish = useCallback(async (params?: ModelAndParameter | PublishWorkflowParams) => { + const publishParams = params && 'title' in params ? params : undefined // First check if there are any items in the checklist // if (!validateBeforeRun()) // throw new Error('Checklist has unresolved items') @@ -157,10 +159,9 @@ const FeaturesTrigger = () => { if (await handleCheckBeforePublish()) { const res = await publishWorkflow({ url: `/apps/${appID}/workflows/publish`, - title: params?.title || '', - releaseNotes: params?.releaseNotes || '', + title: publishParams?.title || '', + releaseNotes: publishParams?.releaseNotes || '', }) - if (res) { toast.success(t('api.actionSuccess', { ns: 'common' })) updatePublishedWorkflow(appID!) 
diff --git a/web/app/components/workflow-app/components/workflow-main.tsx b/web/app/components/workflow-app/components/workflow-main.tsx index 38a044f088..957344a1da 100644 --- a/web/app/components/workflow-app/components/workflow-main.tsx +++ b/web/app/components/workflow-app/components/workflow-main.tsx @@ -1,11 +1,25 @@ +import type { Features as FeaturesData } from '@/app/components/base/features/types' import type { WorkflowProps } from '@/app/components/workflow' +import type { CollaborationUpdate } from '@/app/components/workflow/collaboration/types/collaboration' +import type { Shape as HooksStoreShape } from '@/app/components/workflow/hooks-store/store' +import type { Edge, Node } from '@/app/components/workflow/types' +import type { FetchWorkflowDraftResponse } from '@/types/workflow' import { useCallback, + useEffect, useMemo, + useRef, } from 'react' +import { useReactFlow } from 'reactflow' import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' import { WorkflowWithInnerContext } from '@/app/components/workflow' -import { useWorkflowStore } from '@/app/components/workflow/store' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' +import { useCollaboration } from '@/app/components/workflow/collaboration/hooks/use-collaboration' +import { useWorkflowUpdate } from '@/app/components/workflow/hooks/use-workflow-interactions' +import { useStore, useWorkflowStore } from '@/app/components/workflow/store' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' +import { fetchWorkflowDraft } from '@/service/workflow' import { useAvailableNodesMetaData, useConfigsMap, @@ -21,6 +35,7 @@ import { import WorkflowChildren from './workflow-children' type WorkflowMainProps = Pick +type WorkflowDataUpdatePayload = Pick const WorkflowMain = ({ nodes, edges, @@ -28,8 +43,48 @@ const WorkflowMain = ({ }: 
WorkflowMainProps) => { const featuresStore = useFeaturesStore() const workflowStore = useWorkflowStore() + const appId = useStore(s => s.appId) + const containerRef = useRef(null) + const reactFlow = useReactFlow() - const handleWorkflowDataUpdate = useCallback((payload: any) => { + const reactFlowStore = useMemo(() => ({ + getState: () => ({ + getNodes: () => reactFlow.getNodes(), + setNodes: (nodesToSet: Node[]) => reactFlow.setNodes(nodesToSet), + getEdges: () => reactFlow.getEdges(), + setEdges: (edgesToSet: Edge[]) => reactFlow.setEdges(edgesToSet), + }), + }), [reactFlow]) + const { + startCursorTracking, + stopCursorTracking, + onlineUsers, + cursors, + isConnected, + isEnabled: isCollaborationEnabled, + } = useCollaboration(appId || '', reactFlowStore) + const myUserId = useMemo( + () => (isCollaborationEnabled && isConnected ? 'current-user' : null), + [isCollaborationEnabled, isConnected], + ) + + const filteredCursors = Object.fromEntries( + Object.entries(cursors).filter(([userId]) => userId !== myUserId), + ) + + useEffect(() => { + if (!isCollaborationEnabled) + return + + if (containerRef.current) + startCursorTracking(containerRef as React.RefObject, reactFlow) + + return () => { + stopCursorTracking() + } + }, [startCursorTracking, stopCursorTracking, reactFlow, isCollaborationEnabled]) + + const handleWorkflowDataUpdate = useCallback((payload: WorkflowDataUpdatePayload) => { const { features, conversation_variables, @@ -38,7 +93,33 @@ const WorkflowMain = ({ if (features && featuresStore) { const { setFeatures } = featuresStore.getState() - setFeatures(features) + const transformedFeatures: FeaturesData = { + file: { + image: { + enabled: !!features.file_upload?.image?.enabled, + number_limits: features.file_upload?.image?.number_limits || 3, + transfer_methods: features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + }, + enabled: !!(features.file_upload?.enabled || features.file_upload?.image?.enabled), + 
allowed_file_types: features.file_upload?.allowed_file_types || [SupportUploadFileTypes.image], + allowed_file_extensions: features.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), + allowed_file_upload_methods: features.file_upload?.allowed_file_upload_methods || features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + number_limits: features.file_upload?.number_limits || features.file_upload?.image?.number_limits || 3, + }, + opening: { + enabled: !!features.opening_statement, + opening_statement: features.opening_statement, + suggested_questions: features.suggested_questions, + }, + suggested: features.suggested_questions_after_answer || { enabled: false }, + speech2text: features.speech_to_text || { enabled: false }, + text2speech: features.text_to_speech || { enabled: false }, + citation: features.retriever_resource || { enabled: false }, + moderation: features.sensitive_word_avoidance || { enabled: false }, + annotationReply: features.annotation_reply || { enabled: false }, + } + + setFeatures(transformedFeatures) } if (conversation_variables) { const { setConversationVariables } = workflowStore.getState() @@ -55,6 +136,7 @@ const WorkflowMain = ({ syncWorkflowDraftWhenPageClose, } = useNodesSyncDraft() const { handleRefreshWorkflowDraft } = useWorkflowRefreshDraft() + const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() const { handleBackupDraft, handleLoadBackupDraft, @@ -62,6 +144,64 @@ const WorkflowMain = ({ handleRun, handleStopRun, } = useWorkflowRun() + + useEffect(() => { + if (!appId || !isCollaborationEnabled) + return + + const unsubscribe = collaborationManager.onVarsAndFeaturesUpdate(async (_update: CollaborationUpdate) => { + try { + const response = await fetchWorkflowDraft(`/apps/${appId}/workflows/draft`) + handleWorkflowDataUpdate(response) + } + catch (error) { + console.error('workflow vars and features update failed:', error) + } + }) + + return 
unsubscribe + }, [appId, handleWorkflowDataUpdate, isCollaborationEnabled]) + + // Listen for workflow updates from other users + useEffect(() => { + if (!appId || !isCollaborationEnabled) + return + + const unsubscribe = collaborationManager.onWorkflowUpdate(async () => { + try { + const response = await fetchWorkflowDraft(`/apps/${appId}/workflows/draft`) + + // Handle features, variables etc. + handleWorkflowDataUpdate(response) + + // Update workflow canvas (nodes, edges, viewport) + if (response.graph) { + handleUpdateWorkflowCanvas({ + nodes: response.graph.nodes || [], + edges: response.graph.edges || [], + viewport: response.graph.viewport || { x: 0, y: 0, zoom: 1 }, + }) + } + } + catch (error) { + console.error('Failed to fetch updated workflow:', error) + } + }) + + return unsubscribe + }, [appId, handleWorkflowDataUpdate, handleUpdateWorkflowCanvas, isCollaborationEnabled]) + + // Listen for sync requests from other users (only processed by leader) + useEffect(() => { + if (!appId || !isCollaborationEnabled) + return + + const unsubscribe = collaborationManager.onSyncRequest(() => { + doSyncWorkflowDraft() + }) + + return unsubscribe + }, [appId, doSyncWorkflowDraft, isCollaborationEnabled]) const { handleStartWorkflowRun, handleWorkflowStartRunInChatflow, @@ -79,6 +219,7 @@ const WorkflowMain = ({ } = useDSL() const configsMap = useConfigsMap() + const { fetchInspectVars } = useSetWorkflowVarsWithValue({ ...configsMap, }) @@ -176,15 +317,23 @@ const WorkflowMain = ({ ]) return ( - - - + } + cursors={filteredCursors} + myUserId={myUserId} + onlineUsers={onlineUsers} + > + + +
) } diff --git a/web/app/components/workflow-app/components/workflow-panel.tsx b/web/app/components/workflow-app/components/workflow-panel.tsx index 4b145339d7..1a9a76d946 100644 --- a/web/app/components/workflow-app/components/workflow-panel.tsx +++ b/web/app/components/workflow-app/components/workflow-panel.tsx @@ -6,6 +6,7 @@ import { import { useShallow } from 'zustand/react/shallow' import { useStore as useAppStore } from '@/app/components/app/store' import Panel from '@/app/components/workflow/panel' +import CommentsPanel from '@/app/components/workflow/panel/comments-panel' import { useStore } from '@/app/components/workflow/store' import dynamic from '@/next/dynamic' import { @@ -67,6 +68,7 @@ const WorkflowPanelOnRight = () => { const showDebugAndPreviewPanel = useStore(s => s.showDebugAndPreviewPanel) const showChatVariablePanel = useStore(s => s.showChatVariablePanel) const showGlobalVariablePanel = useStore(s => s.showGlobalVariablePanel) + const controlMode = useStore(s => s.controlMode) return ( <> @@ -100,6 +102,7 @@ const WorkflowPanelOnRight = () => { ) } + {controlMode === 'comment' && } ) } diff --git a/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts b/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts index c9fa535d51..1a8a7d3a59 100644 --- a/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts +++ b/web/app/components/workflow-app/hooks/__tests__/use-nodes-sync-draft.spec.ts @@ -8,6 +8,10 @@ const mockPostWithKeepalive = vi.fn() const mockSetSyncWorkflowDraftHash = vi.fn() const mockSetDraftUpdatedAt = vi.fn() const mockGetNodesReadOnly = vi.fn() +const mockCollaborationIsConnected = vi.fn() +const mockCollaborationGetIsLeader = vi.fn() +const mockCollaborationEmitSyncRequest = vi.fn() +let isCollaborationEnabled = false let reactFlowState: { getNodes: typeof mockGetNodes @@ -57,6 +61,23 @@ vi.mock('@/app/components/workflow/hooks/use-workflow', () => ({ 
useNodesReadOnly: () => ({ getNodesReadOnly: mockGetNodesReadOnly }), })) +vi.mock('@/app/components/workflow/collaboration/core/collaboration-manager', () => ({ + collaborationManager: { + isConnected: (...args: unknown[]) => mockCollaborationIsConnected(...args), + getIsLeader: (...args: unknown[]) => mockCollaborationGetIsLeader(...args), + emitSyncRequest: (...args: unknown[]) => mockCollaborationEmitSyncRequest(...args), + }, +})) + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_collaboration_mode: boolean } }) => unknown) => + selector({ + systemFeatures: { + enable_collaboration_mode: isCollaborationEnabled, + }, + }), +})) + vi.mock('@/app/components/workflow/hooks/use-serial-async-callback', () => ({ useSerialAsyncCallback: (fn: (...args: unknown[]) => Promise, checkFn: () => boolean) => (...args: unknown[]) => { @@ -109,6 +130,9 @@ describe('useNodesSyncDraft — handleRefreshWorkflowDraft(true) on 409', () => mockGetNodesReadOnly.mockReturnValue(false) mockGetNodes.mockReturnValue([{ id: 'n1', position: { x: 0, y: 0 }, data: { type: 'start' } }]) mockSyncWorkflowDraft.mockResolvedValue({ hash: 'new', updated_at: 1 }) + mockCollaborationIsConnected.mockReturnValue(false) + mockCollaborationGetIsLeader.mockReturnValue(true) + isCollaborationEnabled = false }) it('should call handleRefreshWorkflowDraft(true) — not updating canvas — on draft_workflow_not_sync', async () => { @@ -261,4 +285,41 @@ describe('useNodesSyncDraft — handleRefreshWorkflowDraft(true) on 409', () => hash: 'hash-123', })) }) + + it('should emit sync request instead of syncing when current user is collaboration follower', async () => { + isCollaborationEnabled = true + mockCollaborationIsConnected.mockReturnValue(true) + mockCollaborationGetIsLeader.mockReturnValue(false) + const callbacks = { + onSuccess: vi.fn(), + onError: vi.fn(), + onSettled: vi.fn(), + } + + const { result } = renderHook(() => 
useNodesSyncDraft()) + + await act(async () => { + await result.current.doSyncWorkflowDraft(false, callbacks) + }) + + expect(mockCollaborationEmitSyncRequest).toHaveBeenCalled() + expect(mockSyncWorkflowDraft).not.toHaveBeenCalled() + expect(callbacks.onSuccess).not.toHaveBeenCalled() + expect(callbacks.onError).not.toHaveBeenCalled() + expect(callbacks.onSettled).toHaveBeenCalled() + }) + + it('should skip keepalive sync on page close when current user is collaboration follower', () => { + isCollaborationEnabled = true + mockCollaborationIsConnected.mockReturnValue(true) + mockCollaborationGetIsLeader.mockReturnValue(false) + + const { result } = renderHook(() => useNodesSyncDraft()) + + act(() => { + result.current.syncWorkflowDraftWhenPageClose() + }) + + expect(mockPostWithKeepalive).not.toHaveBeenCalled() + }) }) diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts index 5f61997d9f..42946e14a8 100644 --- a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -1,12 +1,15 @@ import type { SyncDraftCallback } from '@/app/components/workflow/hooks-store' +import type { WorkflowDraftFeaturesPayload } from '@/service/workflow' import { produce } from 'immer' import { useCallback } from 'react' import { useStoreApi } from 'reactflow' import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' import { useSerialAsyncCallback } from '@/app/components/workflow/hooks/use-serial-async-callback' import { useNodesReadOnly } from '@/app/components/workflow/hooks/use-workflow' import { useWorkflowStore } from '@/app/components/workflow/store' import { API_PREFIX } from '@/config' +import { useGlobalPublicStore } from '@/context/global-public-context' import { postWithKeepalive } from 
'@/service/fetch' import { syncWorkflowDraft } from '@/service/workflow' import { useWorkflowRefreshDraft } from '.' @@ -17,6 +20,7 @@ export const useNodesSyncDraft = () => { const featuresStore = useFeaturesStore() const { getNodesReadOnly } = useNodesReadOnly() const { handleRefreshWorkflowDraft } = useWorkflowRefreshDraft() + const isCollaborationEnabled = useGlobalPublicStore(s => s.systemFeatures.enable_collaboration_mode) const getPostParams = useCallback(() => { const { @@ -54,7 +58,16 @@ export const useNodesSyncDraft = () => { }) }) }) - const viewport = { x, y, zoom } + const featuresPayload: WorkflowDraftFeaturesPayload = { + opening_statement: features.opening?.enabled ? (features.opening?.opening_statement || '') : '', + suggested_questions: features.opening?.enabled ? (features.opening?.suggested_questions || []) : [], + suggested_questions_after_answer: features.suggested, + text_to_speech: features.text2speech, + speech_to_text: features.speech2text, + retriever_resource: features.citation, + sensitive_word_avoidance: features.moderation, + file_upload: features.file, + } return { url: `/apps/${appId}/workflows/draft`, @@ -62,33 +75,37 @@ export const useNodesSyncDraft = () => { graph: { nodes: producedNodes, edges: producedEdges, - viewport, - }, - features: { - opening_statement: features.opening?.enabled ? (features.opening?.opening_statement || '') : '', - suggested_questions: features.opening?.enabled ? (features.opening?.suggested_questions || []) : [], - suggested_questions_after_answer: features.suggested, - text_to_speech: features.text2speech, - speech_to_text: features.speech2text, - retriever_resource: features.citation, - sensitive_word_avoidance: features.moderation, - file_upload: features.file, + viewport: { + x, + y, + zoom, + }, }, + features: featuresPayload, environment_variables: environmentVariables, conversation_variables: conversationVariables, hash: syncWorkflowDraftHash, + ...(isCollaborationEnabled ? 
{ _is_collaborative: true } : {}), }, } - }, [store, featuresStore, workflowStore]) + }, [store, featuresStore, workflowStore, isCollaborationEnabled]) const syncWorkflowDraftWhenPageClose = useCallback(() => { if (getNodesReadOnly()) return + + const isFollower = isCollaborationEnabled + && collaborationManager.isConnected() + && !collaborationManager.getIsLeader() + + if (isFollower) + return + const postParams = getPostParams() if (postParams) postWithKeepalive(`${API_PREFIX}${postParams.url}`, postParams.params) - }, [getPostParams, getNodesReadOnly]) + }, [getPostParams, getNodesReadOnly, isCollaborationEnabled]) const performSync = useCallback(async ( notRefreshWhenSyncError?: boolean, @@ -97,7 +114,16 @@ export const useNodesSyncDraft = () => { if (getNodesReadOnly()) return - // Get base params without hash + const isFollower = isCollaborationEnabled + && collaborationManager.isConnected() + && !collaborationManager.getIsLeader() + + if (isFollower) { + collaborationManager.emitSyncRequest() + callback?.onSettled?.() + return + } + const baseParams = getPostParams() if (!baseParams) return @@ -108,15 +134,13 @@ export const useNodesSyncDraft = () => { } = workflowStore.getState() try { - // IMPORTANT: Get the LATEST hash right before sending the request - // This ensures that even if queued, each request uses the most recent hash const latestHash = workflowStore.getState().syncWorkflowDraftHash const postParams = { ...baseParams, params: { ...baseParams.params, - hash: latestHash || null, // null for first-time, otherwise use latest hash + hash: latestHash || null, }, } @@ -137,7 +161,7 @@ export const useNodesSyncDraft = () => { finally { callback?.onSettled?.() } - }, [workflowStore, getPostParams, getNodesReadOnly, handleRefreshWorkflowDraft]) + }, [workflowStore, getPostParams, getNodesReadOnly, handleRefreshWorkflowDraft, isCollaborationEnabled]) const doSyncWorkflowDraft = useSerialAsyncCallback(performSync, getNodesReadOnly) diff --git 
a/web/app/components/workflow-app/index.tsx b/web/app/components/workflow-app/index.tsx index 57bff3fa6e..e6fdf0ddc6 100644 --- a/web/app/components/workflow-app/index.tsx +++ b/web/app/components/workflow-app/index.tsx @@ -78,15 +78,18 @@ const WorkflowAppWithAdditionalContext = () => { }, [workflowStore]) const nodesData = useMemo(() => { - if (data) - return initialNodes(data.graph.nodes, data.graph.edges) - + if (data) { + const processedNodes = initialNodes(data.graph.nodes, data.graph.edges) + return processedNodes + } return [] }, [data]) - const edgesData = useMemo(() => { - if (data) - return initialEdges(data.graph.edges, data.graph.nodes) + const edgesData = useMemo(() => { + if (data) { + const processedEdges = initialEdges(data.graph.edges, data.graph.nodes) + return processedEdges + } return [] }, [data]) diff --git a/web/app/components/workflow/__tests__/features.spec.tsx b/web/app/components/workflow/__tests__/features.spec.tsx index 8be40faea9..632a4c4e0b 100644 --- a/web/app/components/workflow/__tests__/features.spec.tsx +++ b/web/app/components/workflow/__tests__/features.spec.tsx @@ -8,8 +8,25 @@ import { InputVarType } from '../types' import { createStartNode } from './fixtures' import { renderWorkflowFlowComponent } from './workflow-test-env' -const mockHandleSyncWorkflowDraft = vi.fn() const mockHandleAddVariable = vi.fn() +const mockUpdateFeatures = vi.fn() +const mockFeaturesStore = { + getState: () => ({ + features: { + opening: { + enabled: false, + opening_statement: '', + suggested_questions: [], + }, + suggested: false, + text2speech: false, + speech2text: false, + citation: false, + moderation: false, + file: false, + }, + }), +} let mockIsChatMode = true let mockNodesReadOnly = false @@ -22,9 +39,6 @@ vi.mock('../hooks', async () => { useNodesReadOnly: () => ({ nodesReadOnly: mockNodesReadOnly, }), - useNodesSyncDraft: () => ({ - handleSyncWorkflowDraft: mockHandleSyncWorkflowDraft, - }), } }) @@ -34,6 +48,14 @@ 
vi.mock('../nodes/start/use-config', () => ({ }), })) +vi.mock('@/service/workflow', () => ({ + updateFeatures: (...args: unknown[]) => mockUpdateFeatures(...args), +})) + +vi.mock('@/app/components/base/features/hooks', () => ({ + useFeaturesStore: () => mockFeaturesStore, +})) + vi.mock('@/app/components/base/features/new-feature-panel', () => ({ default: ({ show, @@ -112,21 +134,29 @@ const DelayedFeatures = () => { return } -const renderFeatures = (options?: Omit[1], 'nodes' | 'edges'>) => - renderWorkflowFlowComponent( +const renderFeatures = (options?: Omit[1]>, 'nodes' | 'edges'>) => { + const mergedInitialStoreState = { + appId: 'app-1', + ...(options?.initialStoreState || {}), + } + + return renderWorkflowFlowComponent( , { nodes: [startNode], edges: [], ...options, + initialStoreState: mergedInitialStoreState, }, ) +} describe('Features', () => { beforeEach(() => { vi.clearAllMocks() mockIsChatMode = true mockNodesReadOnly = false + mockUpdateFeatures.mockResolvedValue(undefined) }) describe('Rendering', () => { @@ -146,8 +176,10 @@ describe('Features', () => { await user.click(screen.getByRole('button', { name: 'open features' })) - expect(mockHandleSyncWorkflowDraft).toHaveBeenCalledTimes(1) - expect(store.getState().showFeaturesPanel).toBe(true) + await vi.waitFor(() => { + expect(mockUpdateFeatures).toHaveBeenCalledTimes(1) + expect(store.getState().showFeaturesPanel).toBe(true) + }) }) it('should close the workflow feature panel and transform required prompt variables', async () => { diff --git a/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx b/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx index 914c1be617..7a02781c17 100644 --- a/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx +++ b/web/app/components/workflow/__tests__/panel-contextmenu.spec.tsx @@ -8,6 +8,7 @@ const mockUseStore = vi.hoisted(() => vi.fn()) const mockUseNodesInteractions = vi.hoisted(() => vi.fn()) const mockUsePanelInteractions 
= vi.hoisted(() => vi.fn()) const mockUseWorkflowStartRun = vi.hoisted(() => vi.fn()) +const mockUseWorkflowMoveMode = vi.hoisted(() => vi.fn()) const mockUseOperator = vi.hoisted(() => vi.fn()) const mockUseDSL = vi.hoisted(() => vi.fn()) @@ -23,6 +24,9 @@ vi.mock('@/app/components/workflow/store', () => ({ useStore: (selector: (state: { panelMenu?: { left: number, top: number } clipboardElements: unknown[] + pendingComment: null | { pageX: number, pageY: number, elementX: number, elementY: number } + setCommentPlacing: (placing: boolean) => void + setCommentQuickAdd: (quickAdd: boolean) => void setShowImportDSLModal: (visible: boolean) => void }) => unknown) => mockUseStore(selector), })) @@ -31,6 +35,7 @@ vi.mock('@/app/components/workflow/hooks', () => ({ useNodesInteractions: () => mockUseNodesInteractions(), usePanelInteractions: () => mockUsePanelInteractions(), useWorkflowStartRun: () => mockUseWorkflowStartRun(), + useWorkflowMoveMode: () => mockUseWorkflowMoveMode(), useDSL: () => mockUseDSL(), })) @@ -62,14 +67,18 @@ describe('PanelContextmenu', () => { const mockHandleAddNote = vi.fn() const mockExportCheck = vi.fn() const mockSetShowImportDSLModal = vi.fn() + const mockSetCommentPlacing = vi.fn() + const mockSetCommentQuickAdd = vi.fn() let panelMenu: { left: number, top: number } | undefined let clipboardElements: unknown[] + let pendingComment: null | { pageX: number, pageY: number, elementX: number, elementY: number } let clickAwayHandler: (() => void) | undefined beforeEach(() => { vi.clearAllMocks() panelMenu = undefined clipboardElements = [] + pendingComment = null clickAwayHandler = undefined mockUseClickAway.mockImplementation((handler: () => void) => { @@ -81,10 +90,16 @@ describe('PanelContextmenu', () => { mockUseStore.mockImplementation((selector: (state: { panelMenu?: { left: number, top: number } clipboardElements: unknown[] + pendingComment: null | { pageX: number, pageY: number, elementX: number, elementY: number } + setCommentPlacing: 
(placing: boolean) => void + setCommentQuickAdd: (quickAdd: boolean) => void setShowImportDSLModal: (visible: boolean) => void }) => unknown) => selector({ panelMenu, clipboardElements, + pendingComment, + setCommentPlacing: mockSetCommentPlacing, + setCommentQuickAdd: mockSetCommentQuickAdd, setShowImportDSLModal: mockSetShowImportDSLModal, })) mockUseNodesInteractions.mockReturnValue({ @@ -96,6 +111,9 @@ describe('PanelContextmenu', () => { mockUseWorkflowStartRun.mockReturnValue({ handleStartWorkflowRun: mockHandleStartWorkflowRun, }) + mockUseWorkflowMoveMode.mockReturnValue({ + isCommentModeAvailable: false, + }) mockUseOperator.mockReturnValue({ handleAddNote: mockHandleAddNote, }) diff --git a/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx b/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx index 961ab6ddb4..d56cf80779 100644 --- a/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx +++ b/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx @@ -17,6 +17,7 @@ class MockFileReader { vi.stubGlobal('FileReader', MockFileReader as unknown as typeof FileReader) const mockEmit = vi.fn() +const mockEmitWorkflowUpdate = vi.hoisted(() => vi.fn()) vi.mock('@/app/components/base/ui/toast', () => ({ toast: { @@ -39,6 +40,12 @@ vi.mock('@/service/workflow', () => ({ fetchWorkflowDraft: (path: string) => mockFetchWorkflowDraft(path), })) +vi.mock('../collaboration/core/collaboration-manager', () => ({ + collaborationManager: { + emitWorkflowUpdate: mockEmitWorkflowUpdate, + }, +})) + const mockHandleCheckPluginDependencies = vi.fn() vi.mock('@/app/components/workflow/plugin-dependency/hooks', () => ({ usePluginDependencies: () => ({ @@ -138,6 +145,7 @@ describe('UpdateDSLModal', () => { expect(mockEmit).toHaveBeenCalledWith(expect.objectContaining({ type: 'WORKFLOW_DATA_UPDATE', })) + expect(mockEmitWorkflowUpdate).toHaveBeenCalledWith('app-1') expect(defaultProps.onImport).toHaveBeenCalledTimes(1) 
expect(defaultProps.onCancel).toHaveBeenCalledTimes(1) }) @@ -187,6 +195,7 @@ describe('UpdateDSLModal', () => { await waitFor(() => { expect(mockImportDSLConfirm).toHaveBeenCalledWith({ import_id: 'import-2' }) }) + expect(mockEmitWorkflowUpdate).toHaveBeenCalledWith('app-1') }) it('should open the pending modal after the timeout and allow dismissing it', async () => { diff --git a/web/app/components/workflow/__tests__/workflow-edge-events.spec.tsx b/web/app/components/workflow/__tests__/workflow-edge-events.spec.tsx index 7ad6d0c13d..586a3c930e 100644 --- a/web/app/components/workflow/__tests__/workflow-edge-events.spec.tsx +++ b/web/app/components/workflow/__tests__/workflow-edge-events.spec.tsx @@ -5,6 +5,7 @@ import { BaseEdge, internalsSymbol, Position, ReactFlowProvider, useStoreApi } f import { FlowType } from '@/types/common' import { WORKFLOW_DATA_UPDATE } from '../constants' import { Workflow } from '../index' +import { ControlMode } from '../types' import { renderWorkflowComponent } from './workflow-test-env' type WorkflowUpdateEvent = { @@ -23,6 +24,32 @@ const reactFlowBridge = vi.hoisted(() => ({ store: null as null | ReturnType, })) +const collaborationBridge = vi.hoisted(() => ({ + graphImportHandler: null as null | ((payload: { nodes: Node[], edges: Edge[] }) => void), + historyActionHandler: null as null | ((payload: unknown) => void), +})) + +const toastInfoMock = vi.hoisted(() => vi.fn()) + +const workflowCommentState = vi.hoisted(() => ({ + comments: [] as Array>, + pendingComment: null as null | { elementX: number, elementY: number }, + activeComment: null as null | Record, + activeCommentLoading: false, + replySubmitting: false, + replyUpdating: false, + handleCommentSubmit: vi.fn(), + handleCommentCancel: vi.fn(), + handleCommentIconClick: vi.fn(), + handleActiveCommentClose: vi.fn(), + handleCommentResolve: vi.fn(), + handleCommentDelete: vi.fn(async () => {}), + handleCommentReply: vi.fn(), + handleCommentReplyUpdate: vi.fn(), + 
handleCommentReplyDelete: vi.fn(async () => {}), + handleCommentPositionUpdate: vi.fn(), +})) + const workflowHookMocks = vi.hoisted(() => ({ handleNodeDragStart: vi.fn(), handleNodeDrag: vi.fn(), @@ -110,6 +137,12 @@ vi.mock('@/next/dynamic', () => ({ default: () => () => null, })) +vi.mock('@/next/navigation', () => ({ + useParams: () => ({ + appId: 'app-1', + }), +})) + vi.mock('@/context/event-emitter', () => ({ useEventEmitterContextContext: () => ({ eventEmitter: { @@ -131,6 +164,109 @@ vi.mock('@/service/workflow', () => ({ fetchAllInspectVars: vi.fn().mockResolvedValue([]), })) +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + info: toastInfoMock, + }, +})) + +vi.mock('../collaboration/core/collaboration-manager', () => ({ + collaborationManager: { + onGraphImport: (handler: (payload: { nodes: Node[], edges: Edge[] }) => void) => { + collaborationBridge.graphImportHandler = handler + return vi.fn() + }, + onHistoryAction: (handler: (payload: unknown) => void) => { + collaborationBridge.historyActionHandler = handler + return vi.fn() + }, + }, +})) + +vi.mock('../comment-manager', () => ({ + default: () =>
, +})) + +vi.mock('../comment/cursor', () => ({ + CommentCursor: () =>
, +})) + +vi.mock('../comment/comment-input', () => ({ + CommentInput: ({ disabled, onCancel }: { disabled?: boolean, onCancel?: () => void }) => ( + + ), +})) + +vi.mock('../comment/comment-icon', () => ({ + CommentIcon: ({ + comment, + onClick, + onPositionUpdate, + }: { + comment: { id: string } + onClick?: () => void + onPositionUpdate?: (position: { elementX: number, elementY: number }) => void + }) => ( + + ), +})) + +vi.mock('../comment/thread', () => ({ + CommentThread: ({ + onDelete, + onReplyDelete, + onNext, + }: { + onDelete?: () => void + onReplyDelete?: (replyId: string) => void + onNext?: () => void + }) => ( +
+ + + +
+ ), +})) + +vi.mock('../hooks/use-workflow-comment', () => ({ + useWorkflowComment: () => workflowCommentState, +})) + +vi.mock('../base/confirm', () => ({ + default: ({ + isShow, + onConfirm, + onCancel, + }: { + isShow: boolean + onConfirm: () => void + onCancel: () => void + }) => isShow + ? ( +
+ + +
+ ) + : null, +})) + vi.mock('../candidate-node', () => ({ default: () => null, })) @@ -254,6 +390,7 @@ vi.mock('../hooks', () => ({ useWorkflowRefreshDraft: () => ({ handleRefreshWorkflowDraft: vi.fn(), }), + useLeaderRestoreListener: vi.fn(), })) vi.mock('../hooks/use-workflow-search', () => ({ @@ -329,6 +466,14 @@ describe('Workflow edge event wiring', () => { vi.clearAllMocks() eventEmitterState.subscription = null reactFlowBridge.store = null + collaborationBridge.graphImportHandler = null + collaborationBridge.historyActionHandler = null + workflowCommentState.comments = [] + workflowCommentState.pendingComment = null + workflowCommentState.activeComment = null + workflowCommentState.activeCommentLoading = false + workflowCommentState.replySubmitting = false + workflowCommentState.replyUpdating = false }) it('should forward pane, node and edge-change events to workflow handlers when emitted by the canvas', async () => { @@ -450,4 +595,98 @@ describe('Workflow edge event wiring', () => { expect(onConfirm).toHaveBeenCalledTimes(1) }) + + it('should sync graph import events and show history action toast', async () => { + renderSubject() + + const importedNodes = [createInitializedNode('node-3', 480, 'Workflow node node-3')] as unknown as Node[] + + act(() => { + collaborationBridge.graphImportHandler?.({ + nodes: importedNodes, + edges: [], + }) + collaborationBridge.historyActionHandler?.({ action: 'undo' }) + }) + + await waitFor(() => { + expect(screen.getByText('Workflow node node-3')).toBeInTheDocument() + expect(toastInfoMock).toHaveBeenCalledTimes(1) + }) + }) + + it('should render comment overlays and execute comment actions in comment mode', async () => { + workflowCommentState.comments = [ + { id: 'comment-1', resolved: false }, + { id: 'comment-2', resolved: false }, + ] + workflowCommentState.activeComment = { id: 'comment-1', resolved: false } + workflowCommentState.pendingComment = { elementX: 20, elementY: 30 } + + const { container, store } = 
renderSubject({ + initialStoreState: { + controlMode: ControlMode.Comment, + showUserComments: true, + showResolvedComments: false, + isCommentPlacing: true, + pendingComment: null, + isCommentPreviewHovering: true, + mousePosition: { + pageX: 100, + pageY: 120, + elementX: 40, + elementY: 60, + }, + }, + }) + + const pane = getPane(container) + act(() => { + fireEvent.mouseMove(pane, { clientX: 150, clientY: 180 }) + }) + + expect(screen.getByTestId('comment-cursor')).toBeInTheDocument() + expect(screen.getByTestId('comment-input-preview')).toBeInTheDocument() + expect(screen.getByTestId('comment-input-active')).toBeInTheDocument() + expect(screen.getByTestId('comment-icon-comment-1')).toBeInTheDocument() + expect(screen.getByTestId('comment-icon-comment-2')).toBeInTheDocument() + expect(screen.getByTestId('comment-thread')).toBeInTheDocument() + + act(() => { + fireEvent.click(screen.getByRole('button', { name: 'next-comment' })) + }) + expect(workflowCommentState.handleCommentIconClick).toHaveBeenCalledWith({ id: 'comment-2', resolved: false }) + + act(() => { + fireEvent.click(screen.getByRole('button', { name: 'delete-thread' })) + }) + expect(store.getState().showConfirm).toBeDefined() + + await act(async () => { + await store.getState().showConfirm?.onConfirm() + }) + expect(workflowCommentState.handleCommentDelete).toHaveBeenCalledWith('comment-1') + + act(() => { + fireEvent.click(screen.getByRole('button', { name: 'delete-reply' })) + }) + expect(store.getState().showConfirm).toBeDefined() + await act(async () => { + await store.getState().showConfirm?.onConfirm() + }) + expect(workflowCommentState.handleCommentReplyDelete).toHaveBeenCalledWith('comment-1', 'reply-1') + + const wheelEvent = new WheelEvent('wheel', { + cancelable: true, + ctrlKey: true, + }) + act(() => { + window.dispatchEvent(wheelEvent) + }) + + const gestureEvent = new Event('gesturestart', { cancelable: true }) + act(() => { + window.dispatchEvent(gestureEvent) + }) + }) }) diff --git 
a/web/app/components/workflow/candidate-node-main.tsx b/web/app/components/workflow/candidate-node-main.tsx index 9df5510627..a3a62345ff 100644 --- a/web/app/components/workflow/candidate-node-main.tsx +++ b/web/app/components/workflow/candidate-node-main.tsx @@ -11,9 +11,9 @@ import { } from 'react' import { useReactFlow, - useStoreApi, useViewport, } from 'reactflow' +import { useCollaborativeWorkflow } from '@/app/components/workflow/hooks/use-collaborative-workflow' import { CUSTOM_NODE } from './constants' import { useAutoGenerateWebhookUrl, useNodesInteractions, useNodesSyncDraft, useWorkflowHistory, WorkflowHistoryEvent } from './hooks' import CustomNode from './nodes' @@ -32,7 +32,6 @@ type Props = { const CandidateNodeMain: FC = ({ candidateNode, }) => { - const store = useStoreApi() const reactflow = useReactFlow() const workflowStore = useWorkflowStore() const mousePosition = useStore(s => s.mousePosition) @@ -41,15 +40,12 @@ const CandidateNodeMain: FC = ({ const { saveStateToHistory } = useWorkflowHistory() const { handleSyncWorkflowDraft } = useNodesSyncDraft() const autoGenerateWebhookUrl = useAutoGenerateWebhookUrl() + const collaborativeWorkflow = useCollaborativeWorkflow() useEventListener('click', (e) => { e.preventDefault() - const { - getNodes, - setNodes, - } = store.getState() const { screenToFlowPosition } = reactflow - const nodes = getNodes() + const { nodes, setNodes } = collaborativeWorkflow.getState() const { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY }) const newNodes = produce(nodes, (draft) => { draft.push({ diff --git a/web/app/components/workflow/collaboration/components/user-cursors.tsx b/web/app/components/workflow/collaboration/components/user-cursors.tsx new file mode 100644 index 0000000000..4fbf491d0e --- /dev/null +++ b/web/app/components/workflow/collaboration/components/user-cursors.tsx @@ -0,0 +1,78 @@ +import type { FC } from 'react' +import type { CursorPosition, OnlineUser } from 
'@/app/components/workflow/collaboration/types/collaboration' +import { useViewport } from 'reactflow' +import { getUserColor } from '../utils/user-color' + +type UserCursorsProps = { + cursors: Record + myUserId: string | null + onlineUsers: OnlineUser[] +} + +const UserCursors: FC = ({ + cursors, + myUserId, + onlineUsers, +}) => { + const viewport = useViewport() + + const convertToScreenCoordinates = (cursor: CursorPosition) => { + // Convert world coordinates to screen coordinates using current viewport + const screenX = cursor.x * viewport.zoom + viewport.x + const screenY = cursor.y * viewport.zoom + viewport.y + + return { x: screenX, y: screenY } + } + return ( + <> + {Object.entries(cursors || {}).map(([userId, cursor]) => { + if (userId === myUserId) + return null + + const userInfo = onlineUsers.find(user => user.user_id === userId) + const userName = userInfo?.username || `User ${userId.slice(-4)}` + const userColor = getUserColor(userId) + const screenPos = convertToScreenCoordinates(cursor) + + return ( +
+ + + + +
+ {userName} +
+
+ ) + })} + + ) +} + +export default UserCursors diff --git a/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.logs-and-events.spec.ts b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.logs-and-events.spec.ts new file mode 100644 index 0000000000..89888bf1cb --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.logs-and-events.spec.ts @@ -0,0 +1,398 @@ +import type { LoroMap } from 'loro-crdt' +import type { OnlineUser, RestoreRequestData } from '../../types/collaboration' +import type { NoteNodeType } from '@/app/components/workflow/note-node/types' +import type { Edge, Node } from '@/app/components/workflow/types' +import { LoroDoc } from 'loro-crdt' +import { BlockEnum } from '@/app/components/workflow/types' +import { CollaborationManager } from '../collaboration-manager' +import { webSocketClient } from '../websocket-manager' + +type ReactFlowStore = { + getState: () => { + getNodes: () => Node[] + setNodes: (nodes: Node[]) => void + getEdges: () => Edge[] + setEdges: (edges: Edge[]) => void + } +} + +type UndoManagerLike = { + canUndo: () => boolean + canRedo: () => boolean + undo: () => boolean + redo: () => boolean + clear: () => void +} + +type CollaborationManagerInternals = { + doc: LoroDoc | null + nodesMap: LoroMap | null + edgesMap: LoroMap | null + undoManager: UndoManagerLike | null + currentAppId: string | null + reactFlowStore: ReactFlowStore | null + leaderId: string | null + isLeader: boolean + graphViewActive: boolean | null + pendingInitialSync: boolean + onlineUsers: OnlineUser[] + graphImportLogs: unknown[] + setNodesAnomalyLogs: unknown[] + graphSyncDiagnostics: unknown[] + pendingImportLog: unknown | null +} + +const getManagerInternals = (manager: CollaborationManager): CollaborationManagerInternals => + manager as unknown as CollaborationManagerInternals + +const createNode = (id: string): Node => ({ + id, + type: 'custom', + 
position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + title: `Node-${id}`, + desc: '', + }, +}) + +const createEdge = (id: string, source: string, target: string): Edge => ({ + id, + source, + target, + type: 'default', + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.End, + }, +}) + +const createRestoreRequestData = (): RestoreRequestData => ({ + versionId: 'version-1', + versionName: 'Version One', + initiatorUserId: 'user-1', + initiatorName: 'Alice', + graphData: { + nodes: [createNode('n-restore')], + edges: [], + viewport: { x: 1, y: 2, zoom: 0.5 }, + }, +}) + +const setupManagerWithDoc = () => { + const manager = new CollaborationManager() + const doc = new LoroDoc() + const internals = getManagerInternals(manager) + internals.doc = doc + internals.nodesMap = doc.getMap('nodes') + internals.edgesMap = doc.getMap('edges') + return { manager, internals } +} + +describe('CollaborationManager logs and event helpers', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('refreshGraphSynchronously emits merged graph with local selected node state', () => { + const { manager, internals } = setupManagerWithDoc() + const node = createNode('n-1') + const edge = createEdge('e-1', 'n-1', 'n-2') + + manager.setNodes([], [node]) + manager.setEdges([], [edge]) + + internals.reactFlowStore = { + getState: () => ({ + getNodes: () => [{ + ...node, + data: { + ...node.data, + selected: true, + }, + }], + setNodes: vi.fn(), + getEdges: () => [edge], + setEdges: vi.fn(), + }), + } + + const graphPayloads: Array<{ nodes: Node[], edges: Edge[] }> = [] + manager.onGraphImport((graph) => { + graphPayloads.push(graph) + }) + + manager.refreshGraphSynchronously() + + expect(graphPayloads).toHaveLength(1) + const payload = graphPayloads[0] + if (!payload) + throw new Error('graph import payload should exist') + expect(payload.nodes).toHaveLength(1) + expect(payload.edges).toHaveLength(1) + expect(payload.nodes[0]?.data.selected).toBe(true) + }) + + 
it('seeds the full reactflow graph before applying a partial local node update to an empty CRDT', () => { + const { manager, internals } = setupManagerWithDoc() + const startNode = createNode('n-start') + const noteNode: Node = { + id: 'n-note', + type: 'custom-note', + position: { x: 100, y: 100 }, + data: { + type: BlockEnum.Start, + title: '', + desc: '', + text: 'note', + theme: 'yellow', + author: 'Dify', + showAuthor: true, + }, + } + const edge = createEdge('e-start-note', 'n-start', 'n-note') + + const oldNodes = [startNode, noteNode] + const nextNodes: Node[] = [ + startNode, + { + ...noteNode, + data: { + ...noteNode.data, + selected: true, + }, + }, + ] + + internals.reactFlowStore = { + getState: () => ({ + getNodes: () => oldNodes, + setNodes: vi.fn(), + getEdges: () => [edge], + setEdges: vi.fn(), + }), + } + + manager.setNodes(oldNodes, nextNodes, 'test:partial-note-update') + + expect(manager.getNodes().map(node => node.id).sort()).toEqual(['n-note', 'n-start']) + expect(manager.getEdges().map(currentEdge => currentEdge.id)).toEqual(['e-start-note']) + }) + + it('clearGraphImportLog clears logs and pending import snapshot', () => { + const { manager, internals } = setupManagerWithDoc() + internals.graphImportLogs = [{ id: 1 }] + internals.setNodesAnomalyLogs = [{ id: 2 }] + internals.graphSyncDiagnostics = [{ id: 3 }] + internals.pendingImportLog = { id: 4 } + + manager.clearGraphImportLog() + + expect(manager.getGraphImportLog()).toEqual([]) + expect(internals.setNodesAnomalyLogs).toEqual([]) + expect(internals.graphSyncDiagnostics).toEqual([]) + expect(internals.pendingImportLog).toBeNull() + }) + + it('downloadGraphImportLog exports a JSON snapshot and triggers browser download', async () => { + const { manager, internals } = setupManagerWithDoc() + const node = createNode('n-export') + const edge = createEdge('e-export', 'n-export', 'n-target') + + manager.setNodes([], [node]) + manager.setEdges([], [edge]) + + internals.currentAppId = 
'app-export' + internals.leaderId = 'leader-1' + internals.isLeader = true + internals.graphViewActive = true + internals.pendingInitialSync = false + internals.onlineUsers = [{ user_id: 'u-1', username: 'Alice', avatar: '', sid: 'sid-1' }] + internals.graphImportLogs = [{ timestamp: 1 }] + internals.setNodesAnomalyLogs = [{ timestamp: 2 }] + internals.graphSyncDiagnostics = [{ timestamp: 3 }] + internals.reactFlowStore = { + getState: () => ({ + getNodes: () => [createNode('rf-1'), createNode('rf-2')], + setNodes: vi.fn(), + getEdges: () => [createEdge('rf-e', 'rf-1', 'rf-2')], + setEdges: vi.fn(), + }), + } + + const createObjectURLSpy = vi.spyOn(URL, 'createObjectURL').mockReturnValue('blob:workflow-log') + const revokeObjectURLSpy = vi.spyOn(URL, 'revokeObjectURL').mockImplementation(() => {}) + const anchor = document.createElement('a') + const clickSpy = vi.spyOn(anchor, 'click').mockImplementation(() => {}) + const originalCreateElement = document.createElement.bind(document) + const createElementSpy = vi.spyOn(document, 'createElement').mockImplementation((tagName: string): HTMLElement => { + if (tagName === 'a') + return anchor + return originalCreateElement(tagName) + }) + + manager.downloadGraphImportLog() + + expect(createObjectURLSpy).toHaveBeenCalledTimes(1) + expect(clickSpy).toHaveBeenCalledTimes(1) + expect(anchor.download).toContain('workflow-graph-import-log-app-export-') + expect(anchor.download).toMatch(/\.json$/) + expect(revokeObjectURLSpy).toHaveBeenCalledWith('blob:workflow-log') + + const blobArg = createObjectURLSpy.mock.calls[0]?.[0] + expect(blobArg).toBeInstanceOf(Blob) + const payload = JSON.parse(await (blobArg as Blob).text()) as { + appId: string | null + summary: { + logCount: number + setNodesAnomalyCount: number + syncDiagnosticCount: number + onlineUsersCount: number + crdtCounts: { nodes: number, edges: number } + reactFlowCounts: { nodes: number, edges: number } + } + } + + expect(payload.appId).toBe('app-export') + 
expect(payload.summary.logCount).toBe(1) + expect(payload.summary.setNodesAnomalyCount).toBe(1) + expect(payload.summary.syncDiagnosticCount).toBe(1) + expect(payload.summary.onlineUsersCount).toBe(1) + expect(payload.summary.crdtCounts).toEqual({ nodes: 1, edges: 1 }) + expect(payload.summary.reactFlowCounts).toEqual({ nodes: 2, edges: 1 }) + + createElementSpy.mockRestore() + clickSpy.mockRestore() + }) + + it('emits collaboration events only when current app is connected', () => { + const { manager, internals } = setupManagerWithDoc() + const sendSpy = vi.spyOn( + manager as unknown as { sendCollaborationEvent: (payload: unknown) => void }, + 'sendCollaborationEvent', + ).mockImplementation(() => {}) + const isConnectedSpy = vi.spyOn(webSocketClient, 'isConnected').mockReturnValue(false) + + manager.emitCommentsUpdate('app-1') + manager.emitHistoryAction('undo') + manager.emitRestoreRequest(createRestoreRequestData()) + manager.emitRestoreIntent({ + versionId: 'version-1', + versionName: 'Version One', + initiatorUserId: 'u-1', + initiatorName: 'Alice', + }) + manager.emitRestoreComplete({ versionId: 'version-1', success: true }) + expect(sendSpy).not.toHaveBeenCalled() + + internals.currentAppId = 'app-1' + + manager.emitCommentsUpdate('app-1') + manager.emitHistoryAction('undo') + manager.emitRestoreRequest(createRestoreRequestData()) + expect(sendSpy).not.toHaveBeenCalled() + + isConnectedSpy.mockReturnValue(true) + manager.emitCommentsUpdate('app-1') + manager.emitHistoryAction('redo') + manager.emitRestoreRequest(createRestoreRequestData()) + manager.emitRestoreIntent({ + versionId: 'version-2', + initiatorUserId: 'u-2', + initiatorName: 'Bob', + }) + manager.emitRestoreComplete({ versionId: 'version-2', success: false, error: 'failed' }) + + const eventTypes = sendSpy.mock.calls.map(call => ( + (call[0] as { type: string }).type + )) + expect(eventTypes).toEqual([ + 'comments_update', + 'workflow_history_action', + 'workflow_restore_request', + 
'workflow_restore_intent', + 'workflow_restore_complete', + ]) + }) + + it('returns leader state through public getters', () => { + const { manager, internals } = setupManagerWithDoc() + internals.leaderId = 'leader-123' + internals.isLeader = true + + expect(manager.getLeaderId()).toBe('leader-123') + expect(manager.getIsLeader()).toBe(true) + }) + + it('undo and redo apply CRDT graph to ReactFlow store and emit undo/redo state', () => { + const { manager, internals } = setupManagerWithDoc() + const updatedNode = createNode('n-after-undo-redo') + const updatedEdge = createEdge('e-after-undo-redo', 'n-after-undo-redo', 'n-target') + internals.nodesMap?.set(updatedNode.id, updatedNode as unknown as Record) + internals.edgesMap?.set(updatedEdge.id, updatedEdge as unknown as Record) + + const setNodesSpy = vi.fn() + const setEdgesSpy = vi.fn() + internals.reactFlowStore = { + getState: () => ({ + getNodes: () => [createNode('old-node')], + setNodes: setNodesSpy, + getEdges: () => [createEdge('old-edge', 'old-node', 'old-target')], + setEdges: setEdgesSpy, + }), + } + + const undoManager: UndoManagerLike = { + canUndo: vi.fn(() => true), + canRedo: vi.fn(() => true), + undo: vi.fn(() => true), + redo: vi.fn(() => true), + clear: vi.fn(), + } + internals.undoManager = undoManager + + const rafSpy = vi.spyOn(globalThis, 'requestAnimationFrame').mockImplementation((callback: FrameRequestCallback) => { + callback(0) + return 1 + }) + + const historyStates: Array<{ canUndo: boolean, canRedo: boolean }> = [] + manager.onUndoRedoStateChange((state) => { + historyStates.push(state) + }) + + expect(manager.undo()).toBe(true) + expect(manager.redo()).toBe(true) + expect(setNodesSpy).toHaveBeenCalledTimes(2) + expect(setEdgesSpy).toHaveBeenCalledTimes(2) + expect(historyStates).toEqual([ + { canUndo: true, canRedo: true }, + { canUndo: true, canRedo: true }, + ]) + + rafSpy.mockRestore() + }) + + it('exposes undo stack state helpers and supports clearing the stack', () => { + 
const { manager, internals } = setupManagerWithDoc() + + expect(manager.canUndo()).toBe(false) + expect(manager.canRedo()).toBe(false) + expect(manager.undo()).toBe(false) + expect(manager.redo()).toBe(false) + + const undoManager: UndoManagerLike = { + canUndo: vi.fn(() => false), + canRedo: vi.fn(() => true), + undo: vi.fn(() => false), + redo: vi.fn(() => false), + clear: vi.fn(), + } + internals.undoManager = undoManager + + expect(manager.canUndo()).toBe(false) + expect(manager.canRedo()).toBe(true) + manager.clearUndoStack() + expect(undoManager.clear).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.merge-behavior.test.ts b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.merge-behavior.test.ts new file mode 100644 index 0000000000..1368851783 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.merge-behavior.test.ts @@ -0,0 +1,331 @@ +import type { LoroMap } from 'loro-crdt/base64' +import type { Node } from '@/app/components/workflow/types' +import { LoroDoc } from 'loro-crdt/base64' +import { BlockEnum } from '@/app/components/workflow/types' +import { CollaborationManager } from '../collaboration-manager' + +const NODE_ID = 'node-1' +const LLM_NODE_ID = 'llm-node' +const PARAM_NODE_ID = 'parameter-node' + +type WorkflowVariable = { + variable: string + label: string + type: string + required: boolean + default: string + max_length: number + placeholder: string + options: string[] + hint: string +} + +type PromptTemplateItem = { + id: string + role: string + text: string +} + +type ParameterItem = { + description: string + name: string + required: boolean + type: string +} + +type StartNodeData = { + variables: WorkflowVariable[] +} + +type LLMNodeData = { + model: { + mode: string + name: string + provider: string + completion_params: { + temperature: number + } + } + context: { + enabled: boolean + 
variable_selector: string[] + } + vision: { + enabled: boolean + } + prompt_template: PromptTemplateItem[] +} + +type ParameterExtractorNodeData = { + model: { + mode: string + name: string + provider: string + completion_params: { + temperature: number + } + } + parameters: ParameterItem[] + query: unknown[] + reasoning_mode: string + vision: { + enabled: boolean + } +} + +type CollaborationManagerInternals = { + doc: LoroDoc + nodesMap: LoroMap + edgesMap: LoroMap + syncNodes: (oldNodes: Node[], newNodes: Node[]) => void +} + +const createNode = (variables: string[]): Node => ({ + id: NODE_ID, + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + title: 'Start', + desc: '', + variables: variables.map(name => ({ + variable: name, + label: name, + type: 'text-input', + required: true, + default: '', + max_length: 48, + placeholder: '', + options: [], + hint: '', + })), + }, +}) + +const createLLMNode = (templates: PromptTemplateItem[]): Node => ({ + id: LLM_NODE_ID, + type: 'custom', + position: { x: 200, y: 200 }, + data: { + type: BlockEnum.LLM, + title: 'LLM', + desc: '', + selected: false, + model: { + mode: 'chat', + name: 'gemini-2.5-pro', + provider: 'langgenius/gemini/google', + completion_params: { + temperature: 0.7, + }, + }, + context: { + enabled: false, + variable_selector: [], + }, + vision: { + enabled: false, + }, + prompt_template: templates, + }, +}) + +const createParameterExtractorNode = (parameters: ParameterItem[]): Node => ({ + id: PARAM_NODE_ID, + type: 'custom', + position: { x: 400, y: 120 }, + data: { + type: BlockEnum.ParameterExtractor, + title: 'ParameterExtractor', + desc: '', + selected: true, + model: { + mode: 'chat', + name: '', + provider: '', + completion_params: { + temperature: 0.7, + }, + }, + query: [], + reasoning_mode: 'prompt', + parameters, + vision: { + enabled: false, + }, + }, +}) + +const getManagerInternals = (manager: CollaborationManager): CollaborationManagerInternals => + manager as 
unknown as CollaborationManagerInternals + +const getManager = (doc: LoroDoc) => { + const manager = new CollaborationManager() + const internals = getManagerInternals(manager) + internals.doc = doc + internals.nodesMap = doc.getMap('nodes') + internals.edgesMap = doc.getMap('edges') + return manager +} + +const deepClone = (value: T): T => JSON.parse(JSON.stringify(value)) + +const syncNodes = (manager: CollaborationManager, previous: Node[], next: Node[]) => { + const internals = getManagerInternals(manager) + internals.syncNodes(previous, next) +} + +const exportNodes = (manager: CollaborationManager) => manager.getNodes() + +describe('Loro merge behavior smoke test', () => { + it('inspects concurrent edits after merge', () => { + const docA = new LoroDoc() + const managerA = getManager(docA) + syncNodes(managerA, [], [createNode(['a'])]) + + const snapshot = docA.export({ mode: 'snapshot' }) + + const docB = LoroDoc.fromSnapshot(snapshot) + const managerB = getManager(docB) + + syncNodes(managerA, [createNode(['a'])], [createNode(['a', 'b'])]) + syncNodes(managerB, [createNode(['a'])], [createNode(['a', 'c'])]) + + const updateForA = docB.export({ mode: 'update', from: docA.version() }) + docA.import(updateForA) + + const updateForB = docA.export({ mode: 'update', from: docB.version() }) + docB.import(updateForB) + + const finalA = exportNodes(managerA) + const finalB = exportNodes(managerB) + expect(finalA.length).toBe(1) + expect(finalB.length).toBe(1) + }) + + it('merges prompt template insertions and edits across replicas', () => { + const baseTemplate = [ + { + id: 'system-1', + role: 'system', + text: 'base instruction', + }, + ] + + const docA = new LoroDoc() + const managerA = getManager(docA) + syncNodes(managerA, [], [createLLMNode(deepClone(baseTemplate))]) + + const snapshot = docA.export({ mode: 'snapshot' }) + const docB = LoroDoc.fromSnapshot(snapshot) + const managerB = getManager(docB) + + const additionTemplate = [ + ...baseTemplate, + { + id: 
'user-1', + role: 'user', + text: 'hello from docA', + }, + ] + syncNodes(managerA, [createLLMNode(deepClone(baseTemplate))], [createLLMNode(deepClone(additionTemplate))]) + + const editedTemplate = [ + { + id: 'system-1', + role: 'system', + text: 'updated by docB', + }, + ] + syncNodes(managerB, [createLLMNode(deepClone(baseTemplate))], [createLLMNode(deepClone(editedTemplate))]) + + const updateForA = docB.export({ mode: 'update', from: docA.version() }) + docA.import(updateForA) + + const updateForB = docA.export({ mode: 'update', from: docB.version() }) + docB.import(updateForB) + + const finalA = exportNodes(managerA).find(node => node.id === LLM_NODE_ID) as Node | undefined + const finalB = exportNodes(managerB).find(node => node.id === LLM_NODE_ID) as Node | undefined + + expect(finalA).toBeDefined() + expect(finalB).toBeDefined() + + const expectedTemplates = [ + { + id: 'system-1', + role: 'system', + text: 'updated by docB', + }, + { + id: 'user-1', + role: 'user', + text: 'hello from docA', + }, + ] + + expect(finalA!.data.prompt_template).toEqual(expectedTemplates) + expect(finalB!.data.prompt_template).toEqual(expectedTemplates) + }) + + it('converges when parameter lists are edited concurrently', () => { + const baseParameters = [ + { description: 'bb', name: 'aa', required: false, type: 'string' }, + { description: 'dd', name: 'cc', required: false, type: 'string' }, + ] + + const docA = new LoroDoc() + const managerA = getManager(docA) + syncNodes(managerA, [], [createParameterExtractorNode(deepClone(baseParameters))]) + + const snapshot = docA.export({ mode: 'snapshot' }) + const docB = LoroDoc.fromSnapshot(snapshot) + const managerB = getManager(docB) + + const docAUpdate = [ + { description: 'bb updated by A', name: 'aa', required: true, type: 'string' }, + { description: 'dd', name: 'cc', required: false, type: 'string' }, + { description: 'new from A', name: 'ee', required: false, type: 'number' }, + ] + syncNodes( + managerA, + 
[createParameterExtractorNode(deepClone(baseParameters))], + [createParameterExtractorNode(deepClone(docAUpdate))], + ) + + const docBUpdate = [ + { description: 'bb', name: 'aa', required: false, type: 'string' }, + { description: 'dd updated by B', name: 'cc', required: true, type: 'string' }, + ] + syncNodes( + managerB, + [createParameterExtractorNode(deepClone(baseParameters))], + [createParameterExtractorNode(deepClone(docBUpdate))], + ) + + const updateForA = docB.export({ mode: 'update', from: docA.version() }) + docA.import(updateForA) + + const updateForB = docA.export({ mode: 'update', from: docB.version() }) + docB.import(updateForB) + + const finalA = exportNodes(managerA).find(node => node.id === PARAM_NODE_ID) as + | Node + | undefined + const finalB = exportNodes(managerB).find(node => node.id === PARAM_NODE_ID) as + | Node + | undefined + + expect(finalA).toBeDefined() + expect(finalB).toBeDefined() + + const expectedParameters = [ + { description: 'bb updated by A', name: 'aa', required: true, type: 'string' }, + { description: 'dd updated by B', name: 'cc', required: true, type: 'string' }, + { description: 'new from A', name: 'ee', required: false, type: 'number' }, + ] + + expect(finalA!.data.parameters).toEqual(expectedParameters) + expect(finalB!.data.parameters).toEqual(expectedParameters) + }) +}) diff --git a/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.socket-and-subscriptions.spec.ts b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.socket-and-subscriptions.spec.ts new file mode 100644 index 0000000000..6768e5b4e2 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.socket-and-subscriptions.spec.ts @@ -0,0 +1,1136 @@ +import type { Socket } from 'socket.io-client' +import type { + CollaborationUpdate, + NodePanelPresenceMap, + OnlineUser, + RestoreCompleteData, + RestoreIntentData, +} from '../../types/collaboration' +import 
type { Edge, Node } from '@/app/components/workflow/types' +import { LoroDoc, LoroMap } from 'loro-crdt' +import { BlockEnum } from '@/app/components/workflow/types' +import { CollaborationManager } from '../collaboration-manager' +import { webSocketClient } from '../websocket-manager' + +type ReactFlowStore = { + getState: () => { + getNodes: () => Node[] + setNodes: (nodes: Node[]) => void + getEdges: () => Edge[] + setEdges: (edges: Edge[]) => void + } +} + +type LoroSubscribeEvent = { + by?: string +} + +type UndoManagerLike = { + canUndo: () => boolean + canRedo: () => boolean + undo: () => boolean + redo: () => boolean + clear: () => void +} + +type MockSocket = { + id: string + connected: boolean + emit: ReturnType + on: ReturnType + off: ReturnType + trigger: (event: string, ...args: unknown[]) => void +} + +type CollaborationManagerInternals = { + doc: LoroDoc | null + nodesMap: LoroMap | null + edgesMap: LoroMap | null + undoManager: UndoManagerLike | null + activeConnections: Set + currentAppId: string | null + reactFlowStore: ReactFlowStore | null + eventEmitter: { + emit: (event: string, ...args: unknown[]) => void + } + isUndoRedoInProgress: boolean + isLeader: boolean + leaderId: string | null + pendingInitialSync: boolean + pendingGraphImportEmit: boolean + rejoinInProgress: boolean + onlineUsers: OnlineUser[] + nodePanelPresence: NodePanelPresenceMap + cursors: Record + graphSyncDiagnostics: unknown[] + setNodesAnomalyLogs: unknown[] + handleSessionUnauthorized: () => void + forceDisconnect: () => void + setupSocketEventListeners: (socket: Socket) => void + setupSubscriptions: () => void + scheduleGraphImportEmit: () => void + emitGraphResyncRequest: () => void + broadcastCurrentGraph: () => void + requestInitialSyncIfNeeded: () => void + cleanupNodePanelPresence: (activeClientIds: Set) => void + recordGraphSyncDiagnostic: ( + stage: 'nodes_subscribe' | 'edges_subscribe' | 'nodes_import_apply' | 'edges_import_apply' | 'schedule_graph_import_emit' | 
'graph_import_emit' | 'start_import_log' | 'finalize_import_log', + status: 'triggered' | 'skipped' | 'applied' | 'queued' | 'emitted' | 'snapshot', + reason?: string, + details?: Record, + ) => void + captureSetNodesAnomaly: (oldNodes: Node[], newNodes: Node[], source: string) => void +} + +const getManagerInternals = (manager: CollaborationManager): CollaborationManagerInternals => + manager as unknown as CollaborationManagerInternals + +const createNode = (id: string, title = `Node-${id}`): Node => ({ + id, + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + title, + desc: '', + }, +}) + +const createEdge = (id: string, source: string, target: string): Edge => ({ + id, + source, + target, + type: 'custom', + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.End, + }, +}) + +const createMockSocket = (id = 'socket-1'): MockSocket => { + const handlers = new Map void>() + + return { + id, + connected: true, + emit: vi.fn(), + on: vi.fn((event: string, handler: (...args: unknown[]) => void) => { + handlers.set(event, handler) + }), + off: vi.fn(), + trigger: (event: string, ...args: unknown[]) => { + const handler = handlers.get(event) + if (handler) + handler(...args) + }, + } +} + +const setupManagerWithDoc = () => { + const manager = new CollaborationManager() + const doc = new LoroDoc() + const internals = getManagerInternals(manager) + internals.doc = doc + internals.nodesMap = doc.getMap('nodes') + internals.edgesMap = doc.getMap('edges') + return { manager, internals } +} + +describe('CollaborationManager socket and subscription behavior', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('emits cursor/sync/workflow events via collaboration_event when connected', () => { + const { manager, internals } = setupManagerWithDoc() + const socket = createMockSocket('socket-connected') + + internals.currentAppId = 'app-1' + vi.spyOn(webSocketClient, 'isConnected').mockReturnValue(true) + 
vi.spyOn(webSocketClient, 'getSocket').mockReturnValue(socket as unknown as Socket) + + manager.emitCursorMove({ x: 11, y: 22, userId: 'u-1', timestamp: Date.now() }) + manager.emitSyncRequest() + manager.emitWorkflowUpdate('wf-1') + + expect(socket.emit).toHaveBeenCalledTimes(3) + const payloads = socket.emit.mock.calls.map(call => call[1] as { type: string, data: Record }) + expect(payloads.map(item => item.type)).toEqual(['mouse_move', 'sync_request', 'workflow_update']) + expect(payloads[0]?.data).toMatchObject({ x: 11, y: 22 }) + expect(payloads[2]?.data).toMatchObject({ appId: 'wf-1' }) + }) + + it('tries to rejoin on unauthorized and forces disconnect on unauthorized ack', () => { + const { internals } = setupManagerWithDoc() + const socket = createMockSocket('socket-rejoin') + const getSocketSpy = vi.spyOn(webSocketClient, 'getSocket').mockReturnValue(socket as unknown as Socket) + const forceDisconnectSpy = vi.spyOn(internals, 'forceDisconnect').mockImplementation(() => undefined) + + internals.currentAppId = 'app-rejoin' + internals.rejoinInProgress = true + internals.handleSessionUnauthorized() + expect(socket.emit).not.toHaveBeenCalled() + + internals.rejoinInProgress = false + internals.handleSessionUnauthorized() + expect(socket.emit).toHaveBeenCalledWith( + 'user_connect', + { workflow_id: 'app-rejoin' }, + expect.any(Function), + ) + + const ack = socket.emit.mock.calls[0]?.[2] as ((...ackArgs: unknown[]) => void) | undefined + expect(ack).toBeDefined() + ack?.({ msg: 'unauthorized' }) + + expect(forceDisconnectSpy).toHaveBeenCalledTimes(1) + expect(internals.rejoinInProgress).toBe(false) + expect(getSocketSpy).toHaveBeenCalled() + }) + + it('routes collaboration_update payloads to corresponding event channels', () => { + const { manager, internals } = setupManagerWithDoc() + const socket = createMockSocket('socket-events') + + const broadcastSpy = vi.spyOn(internals, 'broadcastCurrentGraph').mockImplementation(() => undefined) + internals.isLeader 
= true + internals.setupSocketEventListeners(socket as unknown as Socket) + + const varsFeatureHandler = vi.fn() + const appStateHandler = vi.fn() + const appMetaHandler = vi.fn() + const appPublishHandler = vi.fn() + const mcpHandler = vi.fn() + const workflowUpdateHandler = vi.fn() + const commentsHandler = vi.fn() + const restoreRequestHandler = vi.fn() + const restoreIntentHandler = vi.fn() + const restoreCompleteHandler = vi.fn() + const historyHandler = vi.fn() + const syncRequestHandler = vi.fn() + let latestPresence: NodePanelPresenceMap | null = null + let latestCursors: Record | null = null + + manager.onVarsAndFeaturesUpdate(varsFeatureHandler) + manager.onAppStateUpdate(appStateHandler) + manager.onAppMetaUpdate(appMetaHandler) + manager.onAppPublishUpdate(appPublishHandler) + manager.onMcpServerUpdate(mcpHandler) + manager.onWorkflowUpdate(workflowUpdateHandler) + manager.onCommentsUpdate(commentsHandler) + manager.onRestoreRequest(restoreRequestHandler) + manager.onRestoreIntent(restoreIntentHandler) + manager.onRestoreComplete(restoreCompleteHandler) + manager.onHistoryAction(historyHandler) + manager.onSyncRequest(syncRequestHandler) + manager.onNodePanelPresenceUpdate((presence) => { + latestPresence = presence + }) + manager.onCursorUpdate((cursors) => { + latestCursors = cursors as Record + }) + + const baseUpdate: Pick = { + userId: 'u-1', + timestamp: 1000, + } + + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'mouse_move', + data: { x: 1, y: 2 }, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'vars_and_features_update', + data: { value: 1 }, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'app_state_update', + data: { value: 2 }, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'app_meta_update', + data: { value: 3 }, + } satisfies CollaborationUpdate) + 
socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'app_publish_update', + data: { value: 4 }, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'mcp_server_update', + data: { value: 5 }, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'workflow_update', + data: { appId: 'wf', timestamp: 9 }, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'comments_update', + data: { appId: 'wf', timestamp: 10 }, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'node_panel_presence', + data: { + nodeId: 'n-1', + action: 'open', + user: { userId: 'u-1', username: 'Alice' }, + clientId: 'socket-events', + timestamp: 11, + }, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'sync_request', + data: {}, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'graph_resync_request', + data: {}, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'workflow_restore_request', + data: { versionId: 'v1', initiatorUserId: 'u-1', initiatorName: 'Alice', graphData: { nodes: [], edges: [] } } as unknown as Record, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'workflow_restore_intent', + data: { versionId: 'v1', initiatorUserId: 'u-1', initiatorName: 'Alice' } as unknown as Record, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'workflow_restore_complete', + data: { versionId: 'v1', success: true } as unknown as Record, + } satisfies CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'workflow_history_action', + data: { action: 'undo' }, + } satisfies 
CollaborationUpdate) + socket.trigger('collaboration_update', { + ...baseUpdate, + type: 'workflow_history_action', + data: {}, + } satisfies CollaborationUpdate) + + expect(latestCursors).toMatchObject({ + 'u-1': { x: 1, y: 2, userId: 'u-1' }, + }) + expect(varsFeatureHandler).toHaveBeenCalledTimes(1) + expect(appStateHandler).toHaveBeenCalledTimes(1) + expect(appMetaHandler).toHaveBeenCalledTimes(1) + expect(appPublishHandler).toHaveBeenCalledTimes(1) + expect(mcpHandler).toHaveBeenCalledTimes(1) + expect(workflowUpdateHandler).toHaveBeenCalledWith({ appId: 'wf', timestamp: 9 }) + expect(commentsHandler).toHaveBeenCalledWith({ appId: 'wf', timestamp: 10 }) + expect(latestPresence).toMatchObject({ 'n-1': { 'socket-events': { userId: 'u-1' } } }) + expect(syncRequestHandler).toHaveBeenCalledTimes(1) + expect(broadcastSpy).toHaveBeenCalledTimes(1) + expect(restoreRequestHandler).toHaveBeenCalledTimes(1) + expect(restoreIntentHandler).toHaveBeenCalledWith({ versionId: 'v1', initiatorUserId: 'u-1', initiatorName: 'Alice' } satisfies RestoreIntentData) + expect(restoreCompleteHandler).toHaveBeenCalledWith({ versionId: 'v1', success: true } satisfies RestoreCompleteData) + expect(historyHandler).toHaveBeenCalledWith({ action: 'undo', userId: 'u-1' }) + }) + + it('processes online_users/status/connect/disconnect/error socket events', () => { + const { manager, internals } = setupManagerWithDoc() + const socket = createMockSocket('socket-state') + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => undefined) + const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => undefined) + const emitGraphResyncRequestSpy = vi.spyOn(internals, 'emitGraphResyncRequest').mockImplementation(() => undefined) + + internals.cursors = { + stale: { + x: 1, + y: 1, + userId: 'offline-user', + timestamp: 1, + }, + } + internals.nodePanelPresence = { + 'n-1': { + 'offline-client': { + userId: 'offline-user', + username: 'Offline', + clientId: 'offline-client', + 
timestamp: 1, + }, + }, + } + + internals.setupSocketEventListeners(socket as unknown as Socket) + + const onlineUsersHandler = vi.fn() + const leaderChangeHandler = vi.fn() + const stateChanges: Array<{ isConnected: boolean, disconnectReason?: string, error?: string }> = [] + manager.onOnlineUsersUpdate(onlineUsersHandler) + manager.onLeaderChange(leaderChangeHandler) + manager.onStateChange((state) => { + stateChanges.push(state as { isConnected: boolean, disconnectReason?: string, error?: string }) + }) + + socket.trigger('online_users', { users: 'invalid-structure' }) + expect(warnSpy).toHaveBeenCalled() + + socket.trigger('online_users', { + users: [{ + user_id: 'online-user', + username: 'Alice', + avatar: '', + sid: 'socket-state', + }], + leader: 'leader-1', + }) + + expect(onlineUsersHandler).toHaveBeenCalledWith([{ + user_id: 'online-user', + username: 'Alice', + avatar: '', + sid: 'socket-state', + } satisfies OnlineUser]) + expect(internals.cursors).toEqual({}) + expect(internals.nodePanelPresence).toEqual({}) + expect(internals.leaderId).toBe('leader-1') + + socket.trigger('status', { isLeader: 'invalid' }) + expect(warnSpy).toHaveBeenCalled() + + internals.pendingInitialSync = true + internals.isLeader = false + socket.trigger('status', { isLeader: false }) + expect(emitGraphResyncRequestSpy).toHaveBeenCalledTimes(1) + expect(internals.pendingInitialSync).toBe(false) + + socket.trigger('status', { isLeader: true }) + expect(leaderChangeHandler).toHaveBeenCalledWith(true) + + socket.trigger('connect') + socket.trigger('disconnect', 'transport close') + socket.trigger('connect_error', new Error('connect failed')) + socket.trigger('error', new Error('generic socket error')) + + expect(stateChanges).toEqual([ + { isConnected: true }, + { isConnected: false, disconnectReason: 'transport close' }, + { isConnected: false, error: 'connect failed' }, + ]) + expect(errorSpy).toHaveBeenCalled() + }) + + it('removes stale node panel viewers by inactive client 
even when the same user is still online in another tab', () => { + const { manager, internals } = setupManagerWithDoc() + const socket = createMockSocket('socket-tab-b') + + internals.nodePanelPresence = { + 'n-1': { + 'socket-tab-a': { + userId: 'u-1', + username: 'Alice', + clientId: 'socket-tab-a', + timestamp: 1, + }, + 'socket-tab-b': { + userId: 'u-1', + username: 'Alice', + clientId: 'socket-tab-b', + timestamp: 2, + }, + }, + } + + const presenceUpdates: NodePanelPresenceMap[] = [] + manager.onNodePanelPresenceUpdate((presence) => { + presenceUpdates.push(presence) + }) + + internals.setupSocketEventListeners(socket as unknown as Socket) + + socket.trigger('online_users', { + users: [{ + user_id: 'u-1', + username: 'Alice', + avatar: '', + sid: 'socket-tab-b', + }], + }) + + expect(internals.nodePanelPresence).toEqual({ + 'n-1': { + 'socket-tab-b': { + userId: 'u-1', + username: 'Alice', + clientId: 'socket-tab-b', + timestamp: 2, + }, + }, + }) + expect(presenceUpdates.at(-1)).toEqual({ + 'n-1': { + 'socket-tab-b': { + userId: 'u-1', + username: 'Alice', + clientId: 'socket-tab-b', + timestamp: 2, + }, + }, + }) + }) + + it('setupSubscriptions applies import updates and emits merged graph payload', () => { + const { manager, internals } = setupManagerWithDoc() + const rafSpy = vi.spyOn(globalThis, 'requestAnimationFrame').mockImplementation((callback: FrameRequestCallback) => { + callback(0) + return 1 + }) + const initialNode = { + ...createNode('n-1', 'Initial'), + data: { + ...createNode('n-1', 'Initial').data, + selected: false, + }, + } + const remoteNode = { + ...initialNode, + data: { + ...initialNode.data, + title: 'RemoteTitle', + }, + } + const edge = createEdge('e-1', 'n-1', 'n-2') + + manager.setNodes([], [initialNode]) + manager.setEdges([], [edge]) + manager.setNodes([initialNode], [remoteNode]) + + let reactFlowNodes: Node[] = [{ + ...initialNode, + data: ({ + ...initialNode.data, + selected: true, + _localMeta: 'keep-me', + } as 
Node['data'] & Record), + }] + let reactFlowEdges: Edge[] = [edge] + const setNodesSpy = vi.fn((nodes: Node[]) => { + reactFlowNodes = nodes + }) + const setEdgesSpy = vi.fn((edges: Edge[]) => { + reactFlowEdges = edges + }) + internals.reactFlowStore = { + getState: () => ({ + getNodes: () => reactFlowNodes, + setNodes: setNodesSpy, + getEdges: () => reactFlowEdges, + setEdges: setEdgesSpy, + }), + } + + let nodesSubscribeHandler: (event: LoroSubscribeEvent) => void = () => {} + let edgesSubscribeHandler: (event: LoroSubscribeEvent) => void = () => {} + vi.spyOn(internals.nodesMap as object as { subscribe: (handler: (event: LoroSubscribeEvent) => void) => void }, 'subscribe') + .mockImplementation((handler: (event: LoroSubscribeEvent) => void) => { + nodesSubscribeHandler = handler + }) + vi.spyOn(internals.edgesMap as object as { subscribe: (handler: (event: LoroSubscribeEvent) => void) => void }, 'subscribe') + .mockImplementation((handler: (event: LoroSubscribeEvent) => void) => { + edgesSubscribeHandler = handler + }) + + const importedGraphs: Array<{ nodes: Node[], edges: Edge[] }> = [] + manager.onGraphImport((payload) => { + importedGraphs.push(payload) + }) + + internals.setupSubscriptions() + nodesSubscribeHandler({ by: 'local' }) + nodesSubscribeHandler({ by: 'import' }) + edgesSubscribeHandler({ by: 'import' }) + + expect(setNodesSpy).toHaveBeenCalled() + expect(setEdgesSpy).toHaveBeenCalled() + expect(importedGraphs.length).toBeGreaterThan(0) + const importedGraph = importedGraphs.at(-1) + if (!importedGraph) + throw new Error('imported graph should exist') + expect(importedGraph.nodes[0]?.data).toMatchObject({ + title: 'RemoteTitle', + selected: true, + _localMeta: 'keep-me', + }) + + internals.pendingGraphImportEmit = true + internals.scheduleGraphImportEmit() + expect(internals.pendingGraphImportEmit).toBe(true) + + internals.reactFlowStore = null + nodesSubscribeHandler({ by: 'import' }) + + rafSpy.mockRestore() + }) + + it('respects diagnostic and 
anomaly log limits', () => { + const { internals } = setupManagerWithDoc() + const oldNode = createNode('old') + + for (let i = 0; i < 401; i += 1) { + internals.recordGraphSyncDiagnostic('nodes_subscribe', 'triggered', undefined, { index: i }) + } + for (let i = 0; i < 101; i += 1) { + internals.captureSetNodesAnomaly([oldNode], [], `source-${i}`) + } + + expect(internals.graphSyncDiagnostics).toHaveLength(400) + expect(internals.setNodesAnomalyLogs).toHaveLength(100) + + // no anomaly should be recorded when node count and start node invariants are unchanged + const beforeLength = internals.setNodesAnomalyLogs.length + internals.captureSetNodesAnomaly([oldNode], [createNode('old')], 'no-op') + expect(internals.setNodesAnomalyLogs).toHaveLength(beforeLength) + }) + + it('guards graph resync emission and graph snapshot broadcast', () => { + const { manager, internals } = setupManagerWithDoc() + const socket = createMockSocket('socket-resync') + const sendGraphEventSpy = vi.spyOn( + manager as unknown as { sendGraphEvent: (payload: Uint8Array) => void }, + 'sendGraphEvent', + ).mockImplementation(() => undefined) + + internals.currentAppId = null + vi.spyOn(webSocketClient, 'isConnected').mockReturnValue(false) + vi.spyOn(webSocketClient, 'getSocket').mockReturnValue(socket as unknown as Socket) + internals.emitGraphResyncRequest() + expect(socket.emit).not.toHaveBeenCalled() + + internals.currentAppId = 'app-graph' + vi.spyOn(webSocketClient, 'isConnected').mockReturnValue(true) + internals.emitGraphResyncRequest() + expect(socket.emit).toHaveBeenCalledWith( + 'collaboration_event', + expect.objectContaining({ type: 'graph_resync_request' }), + expect.any(Function), + ) + + internals.doc = null + internals.broadcastCurrentGraph() + expect(sendGraphEventSpy).not.toHaveBeenCalled() + + const doc = new LoroDoc() + internals.doc = doc + internals.nodesMap = doc.getMap('nodes') + internals.edgesMap = doc.getMap('edges') + internals.broadcastCurrentGraph() + 
expect(sendGraphEventSpy).not.toHaveBeenCalled() + + manager.setNodes([], [createNode('n-broadcast')]) + internals.broadcastCurrentGraph() + expect(sendGraphEventSpy).toHaveBeenCalledTimes(1) + }) + + it('covers connect lifecycle branches including reconnect and force disconnect cleanup', async () => { + const { manager, internals } = setupManagerWithDoc() + const socket = createMockSocket('socket-connect') + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => undefined) + const disconnectSpy = vi.spyOn(webSocketClient, 'disconnect').mockImplementation(() => undefined) + vi.spyOn(webSocketClient, 'getSocket').mockReturnValue(socket as unknown as Socket) + vi.spyOn(webSocketClient, 'connect').mockReturnValue(socket as unknown as Socket) + + manager.init('app-1', undefined as unknown as ReactFlowStore) + expect(warnSpy).toHaveBeenCalledWith( + 'CollaborationManager.init called without reactFlowStore, deferring to connect()', + ) + + const reactFlowStore = { + getState: () => ({ + getNodes: () => [], + setNodes: vi.fn(), + getEdges: () => [], + setEdges: vi.fn(), + }), + } + + const eventEmitSpy = vi.spyOn(internals.eventEmitter, 'emit') + + const firstConnectionId = await manager.connect('app-1', reactFlowStore) + expect(firstConnectionId).toBeTruthy() + expect(internals.currentAppId).toBe('app-1') + expect(internals.activeConnections.size).toBe(1) + + const secondConnectionId = await manager.connect('app-1') + expect(secondConnectionId).toBeTruthy() + expect(disconnectSpy).not.toHaveBeenCalled() + + await manager.connect('app-2', reactFlowStore) + expect(disconnectSpy).toHaveBeenCalledWith('app-1') + expect(internals.currentAppId).toBe('app-2') + + internals.isLeader = true + manager.disconnect(secondConnectionId) + manager.disconnect(firstConnectionId) + expect(disconnectSpy).toHaveBeenCalledWith('app-2') + expect(eventEmitSpy).toHaveBeenCalledWith('leaderChange', false) + expect(internals.currentAppId).toBeNull() + 
expect(internals.activeConnections.size).toBe(0) + }) + + it('covers setNodes/setEdges guards and destroy delegation', () => { + const { manager, internals } = setupManagerWithDoc() + const destroyDisconnectSpy = vi.spyOn( + manager as unknown as { disconnect: () => void }, + 'disconnect', + ).mockImplementation(() => undefined) + + manager.setNodes([], [createNode('n-guard')]) + manager.setEdges([], [createEdge('e-guard', 'n-a', 'n-b')]) + + const commitSpy = vi.fn() + internals.doc = { commit: commitSpy } as unknown as LoroDoc + const syncNodesSpy = vi.spyOn( + internals as unknown as { syncNodes: (oldNodes: Node[], newNodes: Node[]) => void }, + 'syncNodes', + ).mockImplementation(() => undefined) + const syncEdgesSpy = vi.spyOn( + internals as unknown as { syncEdges: (oldEdges: Edge[], newEdges: Edge[]) => void }, + 'syncEdges', + ).mockImplementation(() => undefined) + + internals.isUndoRedoInProgress = true + manager.setNodes([], [createNode('n-skip')]) + manager.setEdges([], [createEdge('e-skip', 'n-a', 'n-b')]) + + internals.isUndoRedoInProgress = false + manager.setNodes([], [createNode('n-apply')]) + manager.setEdges([], [createEdge('e-apply', 'n-a', 'n-b')]) + + expect(syncNodesSpy).toHaveBeenCalledTimes(1) + expect(syncEdgesSpy).toHaveBeenCalledTimes(1) + expect(commitSpy).toHaveBeenCalledTimes(2) + + manager.destroy() + expect(destroyDisconnectSpy).toHaveBeenCalledTimes(1) + }) + + it('covers emit guards and node panel presence local updates', () => { + const { manager, internals } = setupManagerWithDoc() + const socket = createMockSocket('socket-presence') + const sendSpy = vi.spyOn( + manager as unknown as { sendCollaborationEvent: (payload: unknown) => void }, + 'sendCollaborationEvent', + ).mockImplementation(() => undefined) + const isConnectedSpy = vi.spyOn(webSocketClient, 'isConnected').mockReturnValue(false) + const getSocketSpy = vi.spyOn(webSocketClient, 'getSocket').mockReturnValue(null) + + manager.emitCursorMove({ x: 1, y: 1, userId: 
'u-1', timestamp: 1 }) + manager.emitSyncRequest() + manager.emitWorkflowUpdate('app-1') + manager.emitNodePanelPresence('node-1', true, { userId: 'u-1', username: 'Alice' }) + expect(sendSpy).not.toHaveBeenCalled() + + internals.currentAppId = 'app-1' + isConnectedSpy.mockReturnValue(true) + manager.emitCursorMove({ x: 2, y: 2, userId: 'u-2', timestamp: 2 }) + expect(sendSpy).not.toHaveBeenCalled() + + getSocketSpy.mockReturnValue(socket as unknown as Socket) + manager.emitNodePanelPresence('', true, { userId: 'u-3', username: 'Bob' }) + manager.emitNodePanelPresence('node-2', true, { userId: '', username: 'Bob' }) + expect(sendSpy).not.toHaveBeenCalled() + + let latestPresence: NodePanelPresenceMap | null = null + manager.onNodePanelPresenceUpdate((presence) => { + latestPresence = presence + }) + manager.emitNodePanelPresence('node-3', true, { userId: 'u-4', username: 'Carol' }) + + expect(sendSpy).toHaveBeenCalledTimes(1) + expect(latestPresence).toMatchObject({ + 'node-3': { + 'socket-presence': { + userId: 'u-4', + }, + }, + }) + }) + + it('covers merge/import log helper branches and log cap', () => { + const { manager, internals } = setupManagerWithDoc() + const reactFlowStore = { + getState: () => ({ + getNodes: () => [{ ...createNode('local-node'), selected: true }], + setNodes: vi.fn(), + getEdges: () => [], + setEdges: vi.fn(), + }), + } + internals.reactFlowStore = reactFlowStore + + const helperInternals = internals as unknown as { + mergeLocalNodeState: (nodes: Node[]) => Node[] + snapshotReactFlowGraph: () => { nodes: Node[], edges: Edge[] } + startImportLog: (source: 'nodes' | 'edges') => void + finalizeImportLog: () => void + } + + const merged = helperInternals.mergeLocalNodeState([createNode('remote-node')]) + expect(merged[0]?.id).toBe('remote-node') + + const mergedWithLocalSelection = helperInternals.mergeLocalNodeState([createNode('local-node')]) + expect(mergedWithLocalSelection[0]?.data.selected).toBe(true) + + internals.reactFlowStore = 
null + const snapshot = helperInternals.snapshotReactFlowGraph() + expect(snapshot).toEqual({ nodes: manager.getNodes(), edges: manager.getEdges() }) + + helperInternals.startImportLog('nodes') + helperInternals.startImportLog('edges') + helperInternals.finalizeImportLog() + helperInternals.finalizeImportLog() + + for (let i = 0; i < 25; i += 1) { + helperInternals.startImportLog('nodes') + helperInternals.finalizeImportLog() + } + + expect(manager.getGraphImportLog()).toHaveLength(20) + }) + + it('covers socket handler catch branches and initial sync leader short-circuit', () => { + const { internals } = setupManagerWithDoc() + const socket = createMockSocket('socket-catch') + const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => undefined) + const cleanupSpy = vi.spyOn(internals, 'cleanupNodePanelPresence').mockImplementation(() => { + throw new Error('cleanup-failed') + }) + + internals.setupSocketEventListeners(socket as unknown as Socket) + socket.trigger('online_users', { + users: [{ + user_id: 'u-1', + username: 'Alice', + avatar: '', + sid: 'socket-catch', + }], + }) + expect(cleanupSpy).toHaveBeenCalled() + + const requestSyncSpy = vi.spyOn(internals, 'requestInitialSyncIfNeeded').mockImplementationOnce(() => { + throw new Error('status-failed') + }) + socket.trigger('status', { isLeader: false }) + expect(requestSyncSpy).toHaveBeenCalled() + expect(errorSpy).toHaveBeenCalled() + + const resyncSpy = vi.spyOn(internals, 'emitGraphResyncRequest').mockImplementation(() => undefined) + internals.pendingInitialSync = true + internals.isLeader = true + internals.requestInitialSyncIfNeeded() + expect(internals.pendingInitialSync).toBe(false) + expect(resyncSpy).not.toHaveBeenCalled() + }) + + it('covers graph broadcast guard and error path', () => { + const { manager, internals } = setupManagerWithDoc() + const socket = createMockSocket('socket-broadcast') + const sendGraphEventSpy = vi.spyOn( + manager as unknown as { sendGraphEvent: (payload: 
Uint8Array) => void }, + 'sendGraphEvent', + ).mockImplementation(() => undefined) + const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => undefined) + + internals.currentAppId = 'app-broadcast' + const isConnectedSpy = vi.spyOn(webSocketClient, 'isConnected').mockReturnValue(false) + const getSocketSpy = vi.spyOn(webSocketClient, 'getSocket').mockReturnValue(null) + internals.broadcastCurrentGraph() + expect(sendGraphEventSpy).not.toHaveBeenCalled() + + isConnectedSpy.mockReturnValue(true) + internals.broadcastCurrentGraph() + expect(sendGraphEventSpy).not.toHaveBeenCalled() + + const doc = new LoroDoc() + internals.doc = doc + internals.nodesMap = doc.getMap('nodes') + internals.edgesMap = doc.getMap('edges') + manager.setNodes([], [createNode('node-error')]) + getSocketSpy.mockReturnValue(socket as unknown as Socket) + vi.spyOn(internals.doc, 'export').mockImplementation(() => { + throw new Error('export-failed') + }) + + internals.broadcastCurrentGraph() + expect(sendGraphEventSpy).not.toHaveBeenCalled() + expect(errorSpy).toHaveBeenCalledWith('Failed to broadcast graph snapshot:', expect.any(Error)) + }) + + it('covers private guard branches for socket helpers and container migration', async () => { + const manager = new CollaborationManager() + const internals = getManagerInternals(manager) + const socket = createMockSocket('socket-private') + const getSocketSpy = vi.spyOn(webSocketClient, 'getSocket').mockReturnValue(null) + vi.spyOn(webSocketClient, 'connect').mockReturnValue(socket as unknown as Socket) + + type PrivateInternals = { + getActiveSocket: () => Socket | null + sendCollaborationEvent: (payload: CollaborationUpdate) => void + sendGraphEvent: (payload: Uint8Array) => void + getNodeContainer: (nodeId: string) => LoroMap + populateNodeContainer: (container: LoroMap, node: Node) => void + mergeLocalNodeState: (nodes: Node[]) => Node[] + requestInitialSyncIfNeeded: () => void + } + const privateInternals = internals as unknown as 
PrivateInternals + + expect(privateInternals.getActiveSocket()).toBeNull() + privateInternals.sendCollaborationEvent({ + type: 'sync_request', + data: {}, + timestamp: Date.now(), + userId: 'u-1', + } satisfies CollaborationUpdate) + privateInternals.sendGraphEvent(new Uint8Array([1, 2])) + + internals.currentAppId = 'app-private' + expect(privateInternals.getActiveSocket()).toBeNull() + + getSocketSpy.mockReturnValue(socket as unknown as Socket) + privateInternals.sendCollaborationEvent({ + type: 'sync_request', + data: {}, + timestamp: Date.now(), + userId: 'u-1', + } satisfies CollaborationUpdate) + privateInternals.sendGraphEvent(new Uint8Array([3, 4])) + expect(socket.emit).toHaveBeenCalled() + + expect(() => privateInternals.getNodeContainer('no-map')).toThrow('Nodes map not initialized') + + const doc = new LoroDoc() + internals.doc = doc + internals.nodesMap = doc.getMap('nodes') + internals.edgesMap = doc.getMap('edges') + internals.nodesMap.set('legacy-node', { + id: 'legacy-node', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + title: 'Legacy', + desc: '', + }, + } as unknown as Record) + privateInternals.getNodeContainer('legacy-node') + + const modernContainer = privateInternals.getNodeContainer('modern-node') + const dataContainer = modernContainer.setContainer('data', new LoroMap()) as LoroMap + dataContainer.set('_internal_only', 'do-not-sync') + privateInternals.populateNodeContainer(modernContainer, createNode('modern-node')) + + const noLocalState = privateInternals.mergeLocalNodeState([createNode('no-local')]) + expect(noLocalState[0]?.id).toBe('no-local') + + internals.pendingInitialSync = false + const resyncSpy = vi.spyOn(internals, 'emitGraphResyncRequest').mockImplementation(() => undefined) + privateInternals.requestInitialSyncIfNeeded() + expect(resyncSpy).not.toHaveBeenCalled() + + const reactFlowStore = { + getState: () => ({ + getNodes: () => [], + setNodes: vi.fn(), + getEdges: () => [], + setEdges: 
vi.fn(), + }), + } + manager.init('app-init-with-store', reactFlowStore) + await manager.connect('app-no-store') + await manager.connect('app-no-store', reactFlowStore) + expect(internals.reactFlowStore).toBe(reactFlowStore) + }) + + it('covers undo/redo and sync negative branches', () => { + const { manager, internals } = setupManagerWithDoc() + const undoManager = { + canUndo: vi.fn(() => true), + canRedo: vi.fn(() => true), + undo: vi.fn(() => false), + redo: vi.fn(() => false), + clear: vi.fn(), + } + internals.undoManager = undoManager + internals.reactFlowStore = null + + expect(manager.undo()).toBe(false) + expect(manager.redo()).toBe(false) + + undoManager.canUndo.mockReturnValue(false) + undoManager.canRedo.mockReturnValue(false) + expect(manager.undo()).toBe(false) + expect(manager.redo()).toBe(false) + + internals.undoManager = null + manager.clearUndoStack() + + const privateInternals = internals as unknown as { + syncNodes: (oldNodes: Node[], newNodes: Node[]) => void + syncEdges: (oldEdges: Edge[], newEdges: Edge[]) => void + } + + const oldNode = createNode('old-node') + internals.doc = null + internals.nodesMap = null + privateInternals.syncNodes([oldNode], []) + + const doc = new LoroDoc() + internals.doc = doc + internals.nodesMap = doc.getMap('nodes') + internals.edgesMap = doc.getMap('edges') + + privateInternals.syncNodes([], [oldNode]) + privateInternals.syncNodes([oldNode], []) + privateInternals.syncNodes([oldNode], [oldNode]) + privateInternals.syncNodes([createNode('old-node')], [createNode('old-node')]) + + internals.edgesMap = null + privateInternals.syncEdges([createEdge('e-old', 'a', 'b')], []) + }) + + it('covers import subscription skip branches', () => { + const { internals } = setupManagerWithDoc() + const reactFlowStore = { + getState: () => ({ + getNodes: () => [], + setNodes: vi.fn(), + getEdges: () => [], + setEdges: vi.fn(), + }), + } + internals.reactFlowStore = reactFlowStore + + let nodesHandler: (event: LoroSubscribeEvent) 
=> void = () => {} + let edgesHandler: (event: LoroSubscribeEvent) => void = () => {} + vi.spyOn(internals.nodesMap as object as { subscribe: (handler: (event: LoroSubscribeEvent) => void) => void }, 'subscribe') + .mockImplementation((handler: (event: LoroSubscribeEvent) => void) => { + nodesHandler = handler + }) + vi.spyOn(internals.edgesMap as object as { subscribe: (handler: (event: LoroSubscribeEvent) => void) => void }, 'subscribe') + .mockImplementation((handler: (event: LoroSubscribeEvent) => void) => { + edgesHandler = handler + }) + + internals.setupSubscriptions() + internals.isUndoRedoInProgress = true + nodesHandler({ by: 'import' }) + edgesHandler({ by: 'import' }) + + internals.isUndoRedoInProgress = false + edgesHandler({ by: 'local' }) + internals.reactFlowStore = null + edgesHandler({ by: 'import' }) + }) + + it('covers missing-doc guards and unauthorized rejoin early returns', () => { + const manager = new CollaborationManager() + const internals = getManagerInternals(manager) + const getSocketSpy = vi.spyOn(webSocketClient, 'getSocket').mockReturnValue(null) + + manager.setNodes([], [createNode('doc-missing-node')]) + manager.setEdges([], [createEdge('doc-missing-edge', 'a', 'b')]) + expect(manager.getNodes()).toEqual([]) + + internals.handleSessionUnauthorized() + internals.currentAppId = 'app-unauthorized' + internals.handleSessionUnauthorized() + expect(getSocketSpy).toHaveBeenCalled() + }) + + it('covers undo manager push/pop metadata path with real connect flow', async () => { + const manager = new CollaborationManager() + const internals = getManagerInternals(manager) + const socket = createMockSocket('socket-undo-pop') + const rafSpy = vi.spyOn(globalThis, 'requestAnimationFrame').mockImplementation((callback: FrameRequestCallback) => { + callback(0) + return 1 + }) + vi.useFakeTimers() + vi.spyOn(webSocketClient, 'connect').mockReturnValue(socket as unknown as Socket) + vi.spyOn(webSocketClient, 'disconnect').mockImplementation(() => 
undefined) + vi.spyOn(webSocketClient, 'getSocket').mockReturnValue(socket as unknown as Socket) + + let nodes: Node[] = [ + { + ...createNode('undo-node-1'), + data: { + ...createNode('undo-node-1').data, + selected: true, + }, + }, + createNode('undo-node-2'), + ] + let edges: Edge[] = [] + const setNodesSpy = vi.fn((nextNodes: Node[]) => { + nodes = nextNodes + }) + const setEdgesSpy = vi.fn((nextEdges: Edge[]) => { + edges = nextEdges + }) + const reactFlowStore = { + getState: () => ({ + getNodes: () => nodes, + setNodes: setNodesSpy, + getEdges: () => edges, + setEdges: setEdgesSpy, + }), + } + + const undoStateSpy = vi.fn() + manager.onUndoRedoStateChange(undoStateSpy) + + const connectionId = await manager.connect('app-undo-pop', reactFlowStore) + manager.setNodes([], nodes) + const nextNodes = nodes.map((node) => { + if (node.id === 'undo-node-1') { + return { + ...node, + data: { + ...node.data, + selected: false, + }, + } + } + if (node.id === 'undo-node-2') { + return { + ...node, + data: { + ...node.data, + selected: true, + }, + } + } + return node + }) + manager.setNodes(nodes, nextNodes) + nodes = nextNodes + + expect(manager.canUndo()).toBe(true) + expect(manager.undo()).toBe(true) + + vi.runAllTimers() + expect(setNodesSpy).toHaveBeenCalled() + expect(undoStateSpy).toHaveBeenCalled() + + manager.disconnect(connectionId) + expect(internals.isUndoRedoInProgress).toBe(false) + vi.useRealTimers() + rafSpy.mockRestore() + }) +}) diff --git a/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.test.ts b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.test.ts new file mode 100644 index 0000000000..957ef76e86 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.test.ts @@ -0,0 +1,763 @@ +import type { LoroMap } from 'loro-crdt/base64' +import type { + NodePanelPresenceMap, + NodePanelPresenceUser, +} from 
'@/app/components/workflow/collaboration/types/collaboration' +import type { CommonNodeType, Edge, Node } from '@/app/components/workflow/types' +import { LoroDoc } from 'loro-crdt/base64' +import { Position } from 'reactflow' +import { CollaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' +import { BlockEnum } from '@/app/components/workflow/types' + +const NODE_ID = '1760342909316' + +type WorkflowVariable = { + default: string + hint: string + label: string + max_length: number + options: string[] + placeholder: string + required: boolean + type: string + variable: string +} + +type PromptTemplateItem = { + id: string + role: string + text: string +} + +type ParameterItem = { + description: string + name: string + required: boolean + type: string +} + +type NodePanelPresenceEventData = { + nodeId: string + action: 'open' | 'close' + user: NodePanelPresenceUser + clientId: string + timestamp?: number +} + +type StartNodeData = { + variables: WorkflowVariable[] +} + +type LLMNodeData = { + context: { + enabled: boolean + variable_selector: string[] + } + model: { + mode: string + name: string + provider: string + completion_params: { + temperature: number + } + } + prompt_template: PromptTemplateItem[] + vision: { + enabled: boolean + } +} + +type ParameterExtractorNodeData = { + model: { + mode: string + name: string + provider: string + completion_params: { + temperature: number + } + } + parameters: ParameterItem[] + query: unknown[] + reasoning_mode: string + vision: { + enabled: boolean + } +} + +type LLMNodeDataWithUnknownTemplate = Omit & { + prompt_template: unknown +} + +type ManagerDoc = LoroDoc | { commit: () => void } + +type CollaborationManagerInternals = { + doc: ManagerDoc + nodesMap: LoroMap + edgesMap: LoroMap + syncNodes: (oldNodes: Node[], newNodes: Node[]) => void + syncEdges: (oldEdges: Edge[], newEdges: Edge[]) => void + applyNodePanelPresenceUpdate: (update: NodePanelPresenceEventData) => void + 
forceDisconnect: () => void + activeConnections: Set + isUndoRedoInProgress: boolean +} + +const createVariable = (name: string, overrides: Partial = {}): WorkflowVariable => ({ + default: '', + hint: '', + label: name, + max_length: 48, + options: [], + placeholder: '', + required: true, + type: 'text-input', + variable: name, + ...overrides, +}) + +const deepClone = (value: T): T => JSON.parse(JSON.stringify(value)) + +const createNodeSnapshot = (variableNames: string[]): Node => ({ + id: NODE_ID, + type: 'custom', + position: { x: 0, y: 24 }, + positionAbsolute: { x: 0, y: 24 }, + height: 88, + width: 242, + selected: true, + selectable: true, + draggable: true, + sourcePosition: Position.Right, + targetPosition: Position.Left, + data: { + selected: true, + title: '开始', + desc: '', + type: BlockEnum.Start, + variables: variableNames.map(name => createVariable(name)), + }, +}) + +const LLM_NODE_ID = 'llm-node' +const PARAM_NODE_ID = 'param-extractor-node' + +const createLLMNodeSnapshot = (promptTemplates: PromptTemplateItem[]): Node => ({ + id: LLM_NODE_ID, + type: 'custom', + position: { x: 200, y: 120 }, + positionAbsolute: { x: 200, y: 120 }, + height: 320, + width: 460, + selected: false, + selectable: true, + draggable: true, + sourcePosition: Position.Right, + targetPosition: Position.Left, + data: { + type: BlockEnum.LLM, + title: 'LLM', + desc: '', + selected: false, + context: { + enabled: false, + variable_selector: [], + }, + model: { + mode: 'chat', + name: 'gemini-2.5-pro', + provider: 'langgenius/gemini/google', + completion_params: { + temperature: 0.7, + }, + }, + vision: { + enabled: false, + }, + prompt_template: promptTemplates, + }, +}) + +const createParameterExtractorNodeSnapshot = (parameters: ParameterItem[]): Node => ({ + id: PARAM_NODE_ID, + type: 'custom', + position: { x: 420, y: 220 }, + positionAbsolute: { x: 420, y: 220 }, + height: 260, + width: 420, + selected: true, + selectable: true, + draggable: true, + sourcePosition: 
Position.Right, + targetPosition: Position.Left, + data: { + type: BlockEnum.ParameterExtractor, + title: '参数提取器', + desc: '', + selected: true, + model: { + mode: 'chat', + name: '', + provider: '', + completion_params: { + temperature: 0.7, + }, + }, + reasoning_mode: 'prompt', + parameters, + query: [], + vision: { + enabled: false, + }, + }, +}) + +const getVariables = (node: Node): string[] => { + const data = node.data as CommonNodeType<{ variables?: WorkflowVariable[] }> + const variables = data.variables ?? [] + return variables.map(item => item.variable) +} + +const getVariableObject = (node: Node, name: string): WorkflowVariable | undefined => { + const data = node.data as CommonNodeType<{ variables?: WorkflowVariable[] }> + const variables = data.variables ?? [] + return variables.find(item => item.variable === name) +} + +const getPromptTemplates = (node: Node): PromptTemplateItem[] => { + const data = node.data as CommonNodeType<{ prompt_template?: PromptTemplateItem[] }> + return data.prompt_template ?? [] +} + +const getParameters = (node: Node): ParameterItem[] => { + const data = node.data as CommonNodeType<{ parameters?: ParameterItem[] }> + return data.parameters ?? 
[] +} + +const getManagerInternals = (manager: CollaborationManager): CollaborationManagerInternals => + manager as unknown as CollaborationManagerInternals + +const setupManager = (): { manager: CollaborationManager, internals: CollaborationManagerInternals } => { + const manager = new CollaborationManager() + const doc = new LoroDoc() + const internals = getManagerInternals(manager) + internals.doc = doc + internals.nodesMap = doc.getMap('nodes') + internals.edgesMap = doc.getMap('edges') + return { manager, internals } +} + +describe('CollaborationManager syncNodes', () => { + let manager: CollaborationManager + let internals: CollaborationManagerInternals + + beforeEach(() => { + const setup = setupManager() + manager = setup.manager + internals = setup.internals + + const initialNode = createNodeSnapshot(['a']) + internals.syncNodes([], [deepClone(initialNode)]) + }) + + it('updates collaborators map when a single client adds a variable', () => { + const base = [createNodeSnapshot(['a'])] + const next = [createNodeSnapshot(['a', 'b'])] + + internals.syncNodes(base, next) + + const stored = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID) + expect(stored).toBeDefined() + expect(getVariables(stored!)).toEqual(['a', 'b']) + }) + + it('applies the latest parallel additions derived from the same base snapshot', () => { + const base = [createNodeSnapshot(['a'])] + const userA = [createNodeSnapshot(['a', 'b'])] + const userB = [createNodeSnapshot(['a', 'c'])] + + internals.syncNodes(base, userA) + + const afterUserA = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID) + expect(getVariables(afterUserA!)).toEqual(['a', 'b']) + + internals.syncNodes(base, userB) + + const finalNode = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID) + const finalVariables = getVariables(finalNode!) 
+ + expect(finalVariables).toEqual(['a', 'c']) + }) + + it('prefers the incoming mutation when the same variable is edited concurrently', () => { + const base = [createNodeSnapshot(['a'])] + const userA = [ + { + ...createNodeSnapshot(['a']), + data: { + ...createNodeSnapshot(['a']).data, + variables: [ + createVariable('a', { label: 'A from userA', hint: 'hintA' }), + ], + }, + }, + ] + const userB = [ + { + ...createNodeSnapshot(['a']), + data: { + ...createNodeSnapshot(['a']).data, + variables: [ + createVariable('a', { label: 'A from userB', hint: 'hintB' }), + ], + }, + }, + ] + + internals.syncNodes(base, userA) + internals.syncNodes(base, userB) + + const finalNode = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID) + const finalVariable = getVariableObject(finalNode!, 'a') + + expect(finalVariable?.label).toBe('A from userB') + expect(finalVariable?.hint).toBe('hintB') + }) + + it('reflects the last writer when concurrent removal and edits happen', () => { + const base = [createNodeSnapshot(['a', 'b'])] + internals.syncNodes([], [deepClone(base[0])]) + const userA = [ + { + ...createNodeSnapshot(['a']), + data: { + ...createNodeSnapshot(['a']).data, + variables: [ + createVariable('a', { label: 'A after deletion' }), + ], + }, + }, + ] + const userB = [ + { + ...createNodeSnapshot(['a', 'b']), + data: { + ...createNodeSnapshot(['a']).data, + variables: [ + createVariable('a'), + createVariable('b', { label: 'B edited but should vanish' }), + ], + }, + }, + ] + + internals.syncNodes(base, userA) + internals.syncNodes(base, userB) + + const finalNode = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID) + const finalVariables = getVariables(finalNode!) 
+ expect(finalVariables).toEqual(['a', 'b']) + expect(getVariableObject(finalNode!, 'b')).toBeDefined() + }) + + it('synchronizes prompt_template list updates across collaborators', () => { + const { manager: promptManager, internals: promptInternals } = setupManager() + + const baseTemplate = [ + { + id: 'abcfa5f9-3c44-4252-aeba-4b6eaf0acfc4', + role: 'system', + text: 'avc', + }, + ] + + const baseNode = createLLMNodeSnapshot(baseTemplate) + promptInternals.syncNodes([], [deepClone(baseNode)]) + + const updatedTemplates = [ + ...baseTemplate, + { + id: 'user-1', + role: 'user', + text: 'hello world', + }, + ] + + const updatedNode = createLLMNodeSnapshot(updatedTemplates) + promptInternals.syncNodes([deepClone(baseNode)], [deepClone(updatedNode)]) + + const stored = (promptManager.getNodes() as Node[]).find(node => node.id === LLM_NODE_ID) + expect(stored).toBeDefined() + + const storedTemplates = getPromptTemplates(stored!) + expect(storedTemplates).toHaveLength(2) + expect(storedTemplates[0]).toEqual(baseTemplate[0]) + expect(storedTemplates[1]).toEqual(updatedTemplates[1]) + + const editedTemplates = [ + { + id: 'abcfa5f9-3c44-4252-aeba-4b6eaf0acfc4', + role: 'system', + text: 'updated system prompt', + }, + ] + const editedNode = createLLMNodeSnapshot(editedTemplates) + + promptInternals.syncNodes([deepClone(updatedNode)], [deepClone(editedNode)]) + + const final = (promptManager.getNodes() as Node[]).find(node => node.id === LLM_NODE_ID) + const finalTemplates = getPromptTemplates(final!) 
+ expect(finalTemplates).toHaveLength(1) + expect(finalTemplates[0].text).toBe('updated system prompt') + }) + + it('keeps parameter list in sync when nodes add, edit, or remove parameters', () => { + const { manager: parameterManager, internals: parameterInternals } = setupManager() + + const baseParameters: ParameterItem[] = [ + { description: 'bb', name: 'aa', required: false, type: 'string' }, + { description: 'dd', name: 'cc', required: false, type: 'string' }, + ] + + const baseNode = createParameterExtractorNodeSnapshot(baseParameters) + parameterInternals.syncNodes([], [deepClone(baseNode)]) + + const updatedParameters: ParameterItem[] = [ + ...baseParameters, + { description: 'ff', name: 'ee', required: true, type: 'number' }, + ] + + const updatedNode = createParameterExtractorNodeSnapshot(updatedParameters) + parameterInternals.syncNodes([deepClone(baseNode)], [deepClone(updatedNode)]) + + const stored = (parameterManager.getNodes() as Node[]).find(node => node.id === PARAM_NODE_ID) + expect(stored).toBeDefined() + expect(getParameters(stored!)).toEqual(updatedParameters) + + const editedParameters: ParameterItem[] = [ + { description: 'bb edited', name: 'aa', required: true, type: 'string' }, + ] + const editedNode = createParameterExtractorNodeSnapshot(editedParameters) + + parameterInternals.syncNodes([deepClone(updatedNode)], [deepClone(editedNode)]) + + const final = (parameterManager.getNodes() as Node[]).find(node => node.id === PARAM_NODE_ID) + expect(getParameters(final!)).toEqual(editedParameters) + }) + + it('handles nodes without data gracefully', () => { + const emptyNode: Node = { + id: 'empty-node', + type: 'custom', + position: { x: 0, y: 0 }, + data: undefined as unknown as CommonNodeType>, + } + + internals.syncNodes([], [deepClone(emptyNode)]) + + const stored = (manager.getNodes() as Node[]).find(node => node.id === 'empty-node') + expect(stored).toBeDefined() + expect(stored?.data).toEqual({}) + }) + + it('preserves CRDT list 
instances when synchronizing parsed state back into the manager', () => { + const { manager: promptManager, internals: promptInternals } = setupManager() + + const base = createLLMNodeSnapshot([ + { id: 'system', role: 'system', text: 'base' }, + ]) + promptInternals.syncNodes([], [deepClone(base)]) + + const storedBefore = promptManager.getNodes().find(node => node.id === LLM_NODE_ID) as Node | undefined + expect(storedBefore).toBeDefined() + const firstTemplate = storedBefore?.data.prompt_template?.[0] + expect(firstTemplate?.text).toBe('base') + + // simulate consumer mutating the plain JSON array and syncing back + const baseNode = storedBefore! + const mutatedNode = deepClone(baseNode) + mutatedNode.data.prompt_template.push({ + id: 'user', + role: 'user', + text: 'mutated', + }) + + promptInternals.syncNodes([baseNode], [mutatedNode]) + + const storedAfter = promptManager.getNodes().find(node => node.id === LLM_NODE_ID) as Node | undefined + const templatesAfter = storedAfter?.data.prompt_template + expect(Array.isArray(templatesAfter)).toBe(true) + expect(templatesAfter).toHaveLength(2) + }) + + it('reuses CRDT list when syncing parameters repeatedly', () => { + const { manager: parameterManager, internals: parameterInternals } = setupManager() + + const initialParameters: ParameterItem[] = [ + { description: 'desc', name: 'param', required: false, type: 'string' }, + ] + const node = createParameterExtractorNodeSnapshot(initialParameters) + parameterInternals.syncNodes([], [deepClone(node)]) + + const stored = parameterManager.getNodes().find(n => n.id === PARAM_NODE_ID) as Node + const mutatedNode = deepClone(stored) + mutatedNode.data.parameters[0].description = 'updated' + + parameterInternals.syncNodes([stored], [mutatedNode]) + + const storedAfter = parameterManager.getNodes().find(n => n.id === PARAM_NODE_ID) as + | Node + | undefined + const params = storedAfter?.data.parameters ?? 
[] + expect(params).toHaveLength(1) + expect(params[0].description).toBe('updated') + }) + + it('filters out transient/private data keys while keeping allowlisted ones', () => { + const nodeWithPrivate: Node<{ _foo: string, variables: WorkflowVariable[] }> = { + id: 'private-node', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + title: 'private', + desc: '', + _foo: 'should disappear', + _children: [{ nodeId: 'child-a', nodeType: BlockEnum.Start }], + selected: true, + variables: [], + }, + } + + internals.syncNodes([], [deepClone(nodeWithPrivate)]) + + const stored = (manager.getNodes() as Node[]).find(node => node.id === 'private-node')! + const storedData = stored.data as CommonNodeType<{ _foo?: string }> + expect(storedData._foo).toBeUndefined() + expect(storedData._children).toEqual([{ nodeId: 'child-a', nodeType: BlockEnum.Start }]) + expect(storedData.selected).toBeUndefined() + }) + + it('removes list fields when they are omitted in the update snapshot', () => { + const baseNode = createNodeSnapshot(['alpha']) + internals.syncNodes([], [deepClone(baseNode)]) + + const withoutVariables: Node = { + ...deepClone(baseNode), + data: { + ...deepClone(baseNode).data, + }, + } + delete (withoutVariables.data as CommonNodeType<{ variables?: WorkflowVariable[] }>).variables + + internals.syncNodes([deepClone(baseNode)], [withoutVariables]) + + const stored = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID)! 
+ const storedData = stored.data as CommonNodeType<{ variables?: WorkflowVariable[] }> + expect(storedData.variables).toBeUndefined() + }) + + it('treats non-array list inputs as empty lists during synchronization', () => { + const { manager: promptManager, internals: promptInternals } = setupManager() + + const nodeWithInvalidTemplate = createLLMNodeSnapshot([]) + promptInternals.syncNodes([], [deepClone(nodeWithInvalidTemplate)]) + + const mutated = deepClone(nodeWithInvalidTemplate) as Node + mutated.data.prompt_template = 'not-an-array' + + promptInternals.syncNodes([deepClone(nodeWithInvalidTemplate)], [mutated]) + + const stored = promptManager.getNodes().find(node => node.id === LLM_NODE_ID) as Node + expect(Array.isArray(stored.data.prompt_template)).toBe(true) + expect(stored.data.prompt_template).toHaveLength(0) + }) + + it('updates edges map when edges are added, modified, and removed', () => { + const { manager: edgeManager } = setupManager() + + const edge: Edge = { + id: 'edge-1', + source: 'node-a', + target: 'node-b', + type: 'default', + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.LLM, + _waitingRun: false, + }, + } + + edgeManager.setEdges([], [edge]) + expect(edgeManager.getEdges()).toHaveLength(1) + const storedEdge = edgeManager.getEdges()[0]! + expect(storedEdge.data).toBeDefined() + expect(storedEdge.data!._waitingRun).toBe(false) + + const updatedEdge: Edge = { + ...edge, + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.LLM, + _waitingRun: true, + }, + } + edgeManager.setEdges([edge], [updatedEdge]) + expect(edgeManager.getEdges()).toHaveLength(1) + const updatedStoredEdge = edgeManager.getEdges()[0]! 
+ expect(updatedStoredEdge.data).toBeDefined() + expect(updatedStoredEdge.data!._waitingRun).toBe(true) + + edgeManager.setEdges([updatedEdge], []) + expect(edgeManager.getEdges()).toHaveLength(0) + }) +}) + +describe('CollaborationManager public API wrappers', () => { + let manager: CollaborationManager + let internals: CollaborationManagerInternals + const baseNodes: Node[] = [] + const updatedNodes: Node[] = [ + { + id: 'new-node', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + title: 'New node', + desc: '', + }, + }, + ] + const baseEdges: Edge[] = [] + const updatedEdges: Edge[] = [ + { + id: 'edge-1', + source: 'source', + target: 'target', + type: 'default', + data: { + sourceType: BlockEnum.Start, + targetType: BlockEnum.End, + }, + }, + ] + + beforeEach(() => { + manager = new CollaborationManager() + internals = getManagerInternals(manager) + }) + + it('setNodes delegates to syncNodes and commits the CRDT document', () => { + const commit = vi.fn() + internals.doc = { commit } + const syncSpy = vi.spyOn(internals, 'syncNodes').mockImplementation(() => undefined) + + manager.setNodes(baseNodes, updatedNodes) + + expect(syncSpy).toHaveBeenCalledWith(baseNodes, updatedNodes) + expect(commit).toHaveBeenCalled() + syncSpy.mockRestore() + }) + + it('setNodes skips syncing when undo/redo replay is running', () => { + const commit = vi.fn() + internals.doc = { commit } + internals.isUndoRedoInProgress = true + const syncSpy = vi.spyOn(internals, 'syncNodes').mockImplementation(() => undefined) + + manager.setNodes(baseNodes, updatedNodes) + + expect(syncSpy).not.toHaveBeenCalled() + expect(commit).not.toHaveBeenCalled() + syncSpy.mockRestore() + }) + + it('setEdges delegates to syncEdges and commits the CRDT document', () => { + const commit = vi.fn() + internals.doc = { commit } + const syncSpy = vi.spyOn(internals, 'syncEdges').mockImplementation(() => undefined) + + manager.setEdges(baseEdges, updatedEdges) + + 
expect(syncSpy).toHaveBeenCalledWith(baseEdges, updatedEdges) + expect(commit).toHaveBeenCalled() + syncSpy.mockRestore() + }) + + it('disconnect tears down the collaboration state only when last connection closes', () => { + const forceSpy = vi.spyOn(internals, 'forceDisconnect').mockImplementation(() => undefined) + internals.activeConnections.add('conn-a') + internals.activeConnections.add('conn-b') + + manager.disconnect('conn-a') + expect(forceSpy).not.toHaveBeenCalled() + + manager.disconnect('conn-b') + expect(forceSpy).toHaveBeenCalledTimes(1) + forceSpy.mockRestore() + }) + + it('applyNodePanelPresenceUpdate keeps a client visible on a single node at a time', () => { + const updates: NodePanelPresenceMap[] = [] + manager.onNodePanelPresenceUpdate((presence) => { + updates.push(presence) + }) + + const user: NodePanelPresenceUser = { userId: 'user-1', username: 'Dana' } + + internals.applyNodePanelPresenceUpdate({ + nodeId: 'node-a', + action: 'open', + user, + clientId: 'client-1', + timestamp: 100, + }) + + internals.applyNodePanelPresenceUpdate({ + nodeId: 'node-b', + action: 'open', + user, + clientId: 'client-1', + timestamp: 200, + }) + + const finalSnapshot = updates[updates.length - 1]! 
+ expect(finalSnapshot).toEqual({ + 'node-b': { + 'client-1': { + userId: 'user-1', + username: 'Dana', + clientId: 'client-1', + timestamp: 200, + }, + }, + }) + }) + + it('applyNodePanelPresenceUpdate clears node entries when last viewer closes the panel', () => { + const updates: NodePanelPresenceMap[] = [] + manager.onNodePanelPresenceUpdate((presence) => { + updates.push(presence) + }) + + const user: NodePanelPresenceUser = { userId: 'user-2', username: 'Kai' } + + internals.applyNodePanelPresenceUpdate({ + nodeId: 'node-a', + action: 'open', + user, + clientId: 'client-9', + timestamp: 300, + }) + + internals.applyNodePanelPresenceUpdate({ + nodeId: 'node-a', + action: 'close', + user, + clientId: 'client-9', + timestamp: 301, + }) + + expect(updates[updates.length - 1]).toEqual({}) + }) +}) diff --git a/web/app/components/workflow/collaboration/core/__tests__/crdt-provider.test.ts b/web/app/components/workflow/collaboration/core/__tests__/crdt-provider.test.ts new file mode 100644 index 0000000000..613c2d1b75 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/crdt-provider.test.ts @@ -0,0 +1,138 @@ +import type { LoroDoc } from 'loro-crdt/base64' +import type { Socket } from 'socket.io-client' +import { CRDTProvider } from '../crdt-provider' + +type FakeDocEvent = { + by: string +} + +type FakeDoc = { + export: ReturnType + import: ReturnType + subscribe: ReturnType + trigger: (event: FakeDocEvent) => void +} + +const createFakeDoc = (): FakeDoc => { + let handler: ((payload: FakeDocEvent) => void) | null = null + + const exportFn = vi.fn(() => new Uint8Array([1, 2, 3])) + const importFn = vi.fn() + const subscribeFn = vi.fn((cb: (payload: FakeDocEvent) => void) => { + handler = cb + }) + + return { + export: exportFn, + import: importFn, + subscribe: subscribeFn, + trigger: (event: FakeDocEvent) => { + handler?.(event) + }, + } +} + +type MockSocket = { + trigger: (event: string, ...args: unknown[]) => void + emit: ReturnType + 
on: ReturnType + off: ReturnType +} + +const createMockSocket = (): MockSocket => { + const handlers = new Map void>() + + const socket: MockSocket = { + emit: vi.fn(), + on: vi.fn((event: string, handler: (...args: unknown[]) => void) => { + handlers.set(event, handler) + }), + off: vi.fn((event: string) => { + handlers.delete(event) + }), + trigger: (event: string, ...args: unknown[]) => { + const handler = handlers.get(event) + if (handler) + handler(...args) + }, + } + + return socket +} + +describe('CRDTProvider', () => { + it('emits graph_event when local changes happen', () => { + const doc = createFakeDoc() + const socket = createMockSocket() + + const provider = new CRDTProvider(socket as unknown as Socket, doc as unknown as LoroDoc) + expect(provider).toBeInstanceOf(CRDTProvider) + + doc.trigger({ by: 'local' }) + + expect(socket.emit).toHaveBeenCalledWith( + 'graph_event', + expect.any(Uint8Array), + expect.any(Function), + ) + expect(doc.export).toHaveBeenCalledWith({ mode: 'update' }) + }) + + it('ignores non-local events', () => { + const doc = createFakeDoc() + const socket = createMockSocket() + + const provider = new CRDTProvider(socket as unknown as Socket, doc as unknown as LoroDoc) + + doc.trigger({ by: 'remote' }) + + expect(socket.emit).not.toHaveBeenCalled() + provider.destroy() + }) + + it('imports remote updates on graph_update', () => { + const doc = createFakeDoc() + const socket = createMockSocket() + + const provider = new CRDTProvider(socket as unknown as Socket, doc as unknown as LoroDoc) + + const payload = new Uint8Array([9, 9, 9]) + socket.trigger('graph_update', payload) + + expect(doc.import).toHaveBeenCalledWith(expect.any(Uint8Array)) + expect(Array.from(doc.import.mock.calls[0][0])).toEqual([9, 9, 9]) + provider.destroy() + }) + + it('removes graph_update listener on destroy', () => { + const doc = createFakeDoc() + const socket = createMockSocket() + + const provider = new CRDTProvider(socket as unknown as Socket, doc as 
unknown as LoroDoc) + provider.destroy() + + expect(socket.off).toHaveBeenCalledWith('graph_update') + }) + + it('logs an error when graph_update import fails but continues operating', () => { + const doc = createFakeDoc() + const socket = createMockSocket() + doc.import.mockImplementation(() => { + throw new Error('boom') + }) + + const provider = new CRDTProvider(socket as unknown as Socket, doc as unknown as LoroDoc) + + const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => undefined) + + socket.trigger('graph_update', new Uint8Array([1])) + expect(errorSpy).toHaveBeenCalledWith('Error importing graph update:', expect.any(Error)) + + doc.import.mockReset() + socket.trigger('graph_update', new Uint8Array([2, 3])) + expect(doc.import).toHaveBeenCalled() + + provider.destroy() + errorSpy.mockRestore() + }) +}) diff --git a/web/app/components/workflow/collaboration/core/__tests__/event-emitter.test.ts b/web/app/components/workflow/collaboration/core/__tests__/event-emitter.test.ts new file mode 100644 index 0000000000..19c4990856 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/event-emitter.test.ts @@ -0,0 +1,93 @@ +import { EventEmitter } from '../event-emitter' + +describe('EventEmitter', () => { + it('registers and invokes handlers via on/emit', () => { + const emitter = new EventEmitter() + const handler = vi.fn() + + emitter.on('test', handler) + emitter.emit('test', { value: 42 }) + + expect(handler).toHaveBeenCalledWith({ value: 42 }) + }) + + it('removes specific handler with off', () => { + const emitter = new EventEmitter() + const handlerA = vi.fn() + const handlerB = vi.fn() + + emitter.on('test', handlerA) + emitter.on('test', handlerB) + + emitter.off('test', handlerA) + emitter.emit('test', 'payload') + + expect(handlerA).not.toHaveBeenCalled() + expect(handlerB).toHaveBeenCalledWith('payload') + }) + + it('clears all listeners when off is called without handler', () => { + const emitter = new EventEmitter() 
+ const handlerA = vi.fn() + const handlerB = vi.fn() + + emitter.on('trigger', handlerA) + emitter.on('trigger', handlerB) + + emitter.off('trigger') + emitter.emit('trigger', 'payload') + + expect(handlerA).not.toHaveBeenCalled() + expect(handlerB).not.toHaveBeenCalled() + expect(emitter.getListenerCount('trigger')).toBe(0) + }) + + it('removeAllListeners clears every registered event', () => { + const emitter = new EventEmitter() + emitter.on('one', vi.fn()) + emitter.on('two', vi.fn()) + + emitter.removeAllListeners() + + expect(emitter.getListenerCount('one')).toBe(0) + expect(emitter.getListenerCount('two')).toBe(0) + }) + + it('returns an unsubscribe function from on', () => { + const emitter = new EventEmitter() + const handler = vi.fn() + + const unsubscribe = emitter.on('detach', handler) + unsubscribe() + + emitter.emit('detach', 'value') + + expect(handler).not.toHaveBeenCalled() + }) + + it('continues emitting when a handler throws', () => { + const emitter = new EventEmitter() + const errorHandler = vi + .spyOn(console, 'error') + .mockImplementation(() => undefined) + + const failingHandler = vi.fn(() => { + throw new Error('boom') + }) + const succeedingHandler = vi.fn() + + emitter.on('safe', failingHandler) + emitter.on('safe', succeedingHandler) + + emitter.emit('safe', 7) + + expect(failingHandler).toHaveBeenCalledWith(7) + expect(succeedingHandler).toHaveBeenCalledWith(7) + expect(errorHandler).toHaveBeenCalledWith( + expect.stringContaining('Error in event handler for safe:'), + expect.any(Error), + ) + + errorHandler.mockRestore() + }) +}) diff --git a/web/app/components/workflow/collaboration/core/__tests__/websocket-manager.test.ts b/web/app/components/workflow/collaboration/core/__tests__/websocket-manager.test.ts new file mode 100644 index 0000000000..b9982164c5 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/websocket-manager.test.ts @@ -0,0 +1,161 @@ +type MockSocket = { + trigger: (event: string, ...args: 
unknown[]) => void + emit: ReturnType + on: ReturnType + disconnect: ReturnType + connected: boolean +} + +type IoOptions = { + auth?: unknown + path?: string + transports?: string[] + withCredentials?: boolean +} + +const ioMock = vi.hoisted(() => vi.fn()) + +vi.mock('socket.io-client', () => ({ + io: (...args: Parameters) => ioMock(...args), +})) + +const createMockSocket = (id: string): MockSocket => { + const handlers = new Map void>() + + const socket: MockSocket & { id: string } = { + id, + connected: true, + emit: vi.fn(), + disconnect: vi.fn(() => { + socket.connected = false + }), + on: vi.fn((event: string, handler: (...args: unknown[]) => void) => { + handlers.set(event, handler) + }), + trigger: (event: string, ...args: unknown[]) => { + const handler = handlers.get(event) + if (handler) + handler(...args) + }, + } + + return socket +} + +describe('WebSocketClient', () => { + beforeEach(() => { + vi.resetModules() + ioMock.mockReset() + }) + + it('connects with default url and registers base listeners', async () => { + const mockSocket = createMockSocket('socket-fallback') + ioMock.mockImplementation(() => mockSocket) + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + const socket = client.connect('app-1') + + expect(ioMock).toHaveBeenCalledWith( + 'ws://localhost:5001', + expect.objectContaining({ + path: '/socket.io', + transports: ['websocket'], + withCredentials: true, + }), + ) + expect(socket).toBe(mockSocket) + expect(mockSocket.on).toHaveBeenCalledWith('connect', expect.any(Function)) + expect(mockSocket.on).toHaveBeenCalledWith('disconnect', expect.any(Function)) + expect(mockSocket.on).toHaveBeenCalledWith('connect_error', expect.any(Function)) + }) + + it('reuses existing connected socket and avoids duplicate connections', async () => { + const mockSocket = createMockSocket('socket-reuse') + ioMock.mockImplementation(() => mockSocket) + + const { WebSocketClient } = await 
import('../websocket-manager') + const client = new WebSocketClient() + + const first = client.connect('app-reuse') + const second = client.connect('app-reuse') + + expect(ioMock).toHaveBeenCalledTimes(1) + expect(second).toBe(first) + }) + + it('emits user_connect on connect without auth payload', async () => { + const mockSocket = createMockSocket('socket-auth') + ioMock.mockImplementation((url: string, options: IoOptions) => { + expect(options.auth).toBeUndefined() + return mockSocket + }) + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + client.connect('app-auth') + + const connectHandler = mockSocket.on.mock.calls.find(call => call[0] === 'connect')?.[1] as () => void + expect(connectHandler).toBeDefined() + connectHandler() + + expect(mockSocket.emit).toHaveBeenCalledWith( + 'user_connect', + { workflow_id: 'app-auth' }, + expect.any(Function), + ) + }) + + it('disconnects a specific app and clears internal maps', async () => { + const mockSocket = createMockSocket('socket-disconnect-one') + ioMock.mockImplementation(() => mockSocket) + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + client.connect('app-disconnect') + + expect(client.isConnected('app-disconnect')).toBe(true) + client.disconnect('app-disconnect') + + expect(mockSocket.disconnect).toHaveBeenCalled() + expect(client.getSocket('app-disconnect')).toBeNull() + expect(client.isConnected('app-disconnect')).toBe(false) + }) + + it('disconnects all apps when no id is provided', async () => { + const socketA = createMockSocket('socket-a') + const socketB = createMockSocket('socket-b') + ioMock.mockImplementationOnce(() => socketA).mockImplementationOnce(() => socketB) + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + client.connect('app-a') + client.connect('app-b') + + client.disconnect() + + 
expect(socketA.disconnect).toHaveBeenCalled() + expect(socketB.disconnect).toHaveBeenCalled() + expect(client.getConnectedApps()).toEqual([]) + }) + + it('reports connected apps, sockets, and debug info correctly', async () => { + const socketA = createMockSocket('socket-debug-a') + const socketB = createMockSocket('socket-debug-b') + socketB.connected = false + ioMock.mockImplementationOnce(() => socketA).mockImplementationOnce(() => socketB) + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + client.connect('app-a') + client.connect('app-b') + + expect(client.getConnectedApps()).toEqual(['app-a']) + + const debugInfo = client.getDebugInfo() + expect(debugInfo).toMatchObject({ + 'app-a': { connected: true, socketId: 'socket-debug-a' }, + 'app-b': { connected: false, socketId: 'socket-debug-b' }, + }) + }) +}) diff --git a/web/app/components/workflow/collaboration/core/collaboration-manager.ts b/web/app/components/workflow/collaboration/core/collaboration-manager.ts new file mode 100644 index 0000000000..5000886222 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/collaboration-manager.ts @@ -0,0 +1,1658 @@ +'use client' + +import type { Value } from 'loro-crdt' +import type { Socket } from 'socket.io-client' +import type { + CommonNodeType, + Edge, + Node, +} from '../../types' +import type { + CollaborationState, + CollaborationUpdate, + CursorPosition, + NodePanelPresenceMap, + NodePanelPresenceUser, + OnlineUser, + RestoreCompleteData, + RestoreIntentData, + RestoreRequestData, +} from '../types/collaboration' +import { cloneDeep } from 'es-toolkit/object' +import { isEqual } from 'es-toolkit/predicate' +import { LoroDoc, LoroList, LoroMap, UndoManager } from 'loro-crdt' +import { CRDTProvider } from './crdt-provider' +import { EventEmitter } from './event-emitter' +import { emitWithAuthGuard, webSocketClient } from './websocket-manager' + +type NodePanelPresenceEventData = { + nodeId: 
string + action: 'open' | 'close' + user: NodePanelPresenceUser + clientId: string + timestamp: number +} + +type ReactFlowStore = { + getState: () => { + getNodes: () => Node[] + setNodes: (nodes: Node[]) => void + getEdges: () => Edge[] + setEdges: (edges: Edge[]) => void + } +} + +type CollaborationEventPayload = { + type: CollaborationUpdate['type'] + data: Record + timestamp: number + userId?: string +} + +type LoroSubscribeEvent = { + by?: string +} + +type LoroContainer = { + kind?: () => string + getAttached?: () => unknown +} + +type GraphImportLogEntry = { + timestamp: number + appId: string | null + sources: Array<'nodes' | 'edges'> + before: { + nodes: Node[] + edges: Edge[] + } + after: { + nodes: Node[] + edges: Edge[] + } + meta: { + leaderId: string | null + isLeader: boolean + graphViewActive: boolean | null + pendingInitialSync: boolean + } +} + +type SetNodesAnomalyReason = 'node_count_decrease' | 'start_removed' + +type SetNodesAnomalyLogEntry = { + timestamp: number + appId: string | null + source: string + reasons: SetNodesAnomalyReason[] + oldCount: number + newCount: number + removedNodeIds: string[] + oldStartNodeIds: string[] + newStartNodeIds: string[] + oldNodeIds: string[] + newNodeIds: string[] + visibilityState: DocumentVisibilityState | 'unknown' + meta: { + leaderId: string | null + isLeader: boolean + graphViewActive: boolean | null + pendingInitialSync: boolean + isConnected: boolean + } +} + +type GraphSyncDiagnosticStage + = | 'nodes_subscribe' + | 'edges_subscribe' + | 'nodes_import_apply' + | 'edges_import_apply' + | 'schedule_graph_import_emit' + | 'graph_import_emit' + | 'start_import_log' + | 'finalize_import_log' + +type GraphSyncDiagnosticEvent = { + timestamp: number + appId: string | null + stage: GraphSyncDiagnosticStage + status: 'triggered' | 'skipped' | 'applied' | 'queued' | 'emitted' | 'snapshot' + reason?: string + details?: Record + meta: { + leaderId: string | null + isLeader: boolean + isUndoRedoInProgress: 
boolean + pendingInitialSync: boolean + pendingGraphImportEmit: boolean + isConnected: boolean + } +} + +const GRAPH_IMPORT_LOG_LIMIT = 20 +const SET_NODES_ANOMALY_LOG_LIMIT = 100 +const GRAPH_SYNC_DIAGNOSTIC_LOG_LIMIT = 400 + +const toLoroValue = (value: unknown): Value => cloneDeep(value) as Value +const toLoroRecord = (value: unknown): Record => cloneDeep(value) as Record +export class CollaborationManager { + private doc: LoroDoc | null = null + private undoManager: UndoManager | null = null + private provider: CRDTProvider | null = null + private nodesMap: LoroMap> | null = null + private edgesMap: LoroMap> | null = null + private eventEmitter = new EventEmitter() + private currentAppId: string | null = null + private reactFlowStore: ReactFlowStore | null = null + private isLeader = false + private leaderId: string | null = null + private onlineUsers: OnlineUser[] = [] + private cursors: Record = {} + private nodePanelPresence: NodePanelPresenceMap = {} + private activeConnections = new Set() + private isUndoRedoInProgress = false + private pendingInitialSync = false + private rejoinInProgress = false + private pendingGraphImportEmit = false + private graphViewActive: boolean | null = null + private graphImportLogs: GraphImportLogEntry[] = [] + private setNodesAnomalyLogs: SetNodesAnomalyLogEntry[] = [] + private graphSyncDiagnostics: GraphSyncDiagnosticEvent[] = [] + private pendingImportLog: { + timestamp: number + sources: Set<'nodes' | 'edges'> + before: { + nodes: Node[] + edges: Edge[] + } + } | null = null + + private getActiveSocket(): Socket | null { + if (!this.currentAppId) + return null + return webSocketClient.getSocket(this.currentAppId) + } + + private handleSessionUnauthorized = (): void => { + if (this.rejoinInProgress) + return + if (!this.currentAppId) + return + + const socket = this.getActiveSocket() + if (!socket) + return + + this.rejoinInProgress = true + console.warn('Collaboration session expired, attempting to rejoin workflow.') + 
emitWithAuthGuard( + socket, + 'user_connect', + { workflow_id: this.currentAppId }, + { + onAck: () => { + this.rejoinInProgress = false + }, + onUnauthorized: () => { + this.rejoinInProgress = false + console.error('Rejoin failed due to authorization error, forcing disconnect.') + this.forceDisconnect() + }, + }, + ) + } + + private sendCollaborationEvent(payload: CollaborationEventPayload): void { + const socket = this.getActiveSocket() + if (!socket) + return + + emitWithAuthGuard(socket, 'collaboration_event', payload, { onUnauthorized: this.handleSessionUnauthorized }) + } + + private sendGraphEvent(payload: Uint8Array): void { + const socket = this.getActiveSocket() + if (!socket) + return + + emitWithAuthGuard(socket, 'graph_event', payload, { onUnauthorized: this.handleSessionUnauthorized }) + } + + private getNodeContainer(nodeId: string): LoroMap> { + if (!this.nodesMap) + throw new Error('Nodes map not initialized') + + let container = this.nodesMap.get(nodeId) as unknown + + const isMapContainer = (value: unknown): value is LoroMap> & LoroContainer => { + return !!value && typeof (value as LoroContainer).kind === 'function' && (value as LoroContainer).kind?.() === 'Map' + } + + if (!container || !isMapContainer(container)) { + const previousValue = container + const newContainer = this.nodesMap.setContainer(nodeId, new LoroMap()) + const attached = (newContainer as LoroContainer).getAttached?.() ?? newContainer + container = attached + if (previousValue && typeof previousValue === 'object') + this.populateNodeContainer(container as LoroMap>, previousValue as Node) + } + else { + const attached = (container as LoroContainer).getAttached?.() ?? 
container + container = attached + } + + return container as LoroMap> + } + + private ensureDataContainer(nodeContainer: LoroMap>): LoroMap> { + let dataContainer = nodeContainer.get('data') as unknown + + if (!dataContainer || typeof (dataContainer as LoroContainer).kind !== 'function' || (dataContainer as LoroContainer).kind?.() !== 'Map') + dataContainer = nodeContainer.setContainer('data', new LoroMap()) + + const attached = (dataContainer as LoroContainer).getAttached?.() ?? dataContainer + return attached as LoroMap> + } + + private ensureList(nodeContainer: LoroMap>, key: string): LoroList { + const dataContainer = this.ensureDataContainer(nodeContainer) + let list = dataContainer.get(key) as unknown + + if (!list || typeof (list as LoroContainer).kind !== 'function' || (list as LoroContainer).kind?.() !== 'List') + list = dataContainer.setContainer(key, new LoroList()) + + const attached = (list as LoroContainer).getAttached?.() ?? list + return attached as LoroList + } + + private exportNode(nodeId: string): Node { + const container = this.getNodeContainer(nodeId) + const json = container.toJSON() as Node + return { + ...json, + data: json.data || {}, + } + } + + private populateNodeContainer(container: LoroMap>, node: Node): void { + const listFields = new Set(['variables', 'prompt_template', 'parameters']) + container.set('id', node.id) + container.set('type', node.type) + container.set('position', toLoroValue(node.position)) + container.set('sourcePosition', node.sourcePosition) + container.set('targetPosition', node.targetPosition) + + if (node.width === undefined) + container.delete('width') + else container.set('width', node.width) + + if (node.height === undefined) + container.delete('height') + else container.set('height', node.height) + + if (node.selected === undefined) + container.delete('selected') + else container.set('selected', node.selected) + + const optionalProps: Array = [ + 'parentId', + 'positionAbsolute', + 'extent', + 'zIndex', + 
'draggable', + 'selectable', + 'dragHandle', + 'dragging', + 'connectable', + 'expandParent', + 'focusable', + 'hidden', + 'style', + 'className', + 'ariaLabel', + 'resizing', + 'deletable', + ] + + optionalProps.forEach((prop) => { + const value = node[prop] + if (value === undefined) + container.delete(prop as string) + else + container.set(prop as string, toLoroValue(value)) + }) + + const dataContainer = this.ensureDataContainer(container) + const handledKeys = new Set() + + Object.entries(node.data || {}).forEach(([key, value]) => { + if (!this.shouldSyncDataKey(key)) + return + handledKeys.add(key) + + if (listFields.has(key)) + this.syncList(container, key, Array.isArray(value) ? value : []) + else + dataContainer.set(key, toLoroValue(value)) + }) + + const existingData = dataContainer.toJSON() || {} + Object.keys(existingData).forEach((key) => { + if (!this.shouldSyncDataKey(key)) + return + if (handledKeys.has(key)) + return + + dataContainer.delete(key) + }) + } + + private shouldSyncDataKey(key: string): boolean { + const syncDataAllowList = new Set(['_children', '_connectedSourceHandleIds', '_connectedTargetHandleIds', '_targetBranches']) + return (syncDataAllowList.has(key) || !key.startsWith('_')) && key !== 'selected' + } + + private syncList(nodeContainer: LoroMap>, key: string, desired: Array): void { + const list = this.ensureList(nodeContainer, key) + const current = list.toJSON() as Array + const target = Array.isArray(desired) ? 
desired : [] + const minLength = Math.min(current.length, target.length) + + for (let i = 0; i < minLength; i += 1) { + if (!isEqual(current[i], target[i])) { + list.delete(i, 1) + list.insert(i, cloneDeep(target[i])) + } + } + + if (current.length > target.length) { + list.delete(target.length, current.length - target.length) + } + else if (target.length > current.length) { + for (let i = current.length; i < target.length; i += 1) + list.insert(i, cloneDeep(target[i])) + } + } + + private getNodePanelPresenceSnapshot(): NodePanelPresenceMap { + const snapshot: NodePanelPresenceMap = {} + Object.entries(this.nodePanelPresence).forEach(([nodeId, viewers]) => { + snapshot[nodeId] = { ...viewers } + }) + return snapshot + } + + private applyNodePanelPresenceUpdate(update: NodePanelPresenceEventData): void { + const { nodeId, action, clientId, user, timestamp } = update + + if (action === 'open') { + // ensure a client only appears on a single node at a time + Object.entries(this.nodePanelPresence).forEach(([id, viewers]) => { + if (viewers[clientId]) { + delete viewers[clientId] + if (Object.keys(viewers).length === 0) + delete this.nodePanelPresence[id] + } + }) + + if (!this.nodePanelPresence[nodeId]) + this.nodePanelPresence[nodeId] = {} + + this.nodePanelPresence[nodeId][clientId] = { + ...user, + clientId, + timestamp: timestamp || Date.now(), + } + } + else { + const viewers = this.nodePanelPresence[nodeId] + if (viewers) { + delete viewers[clientId] + if (Object.keys(viewers).length === 0) + delete this.nodePanelPresence[nodeId] + } + } + + this.eventEmitter.emit('nodePanelPresence', this.getNodePanelPresenceSnapshot()) + } + + private cleanupNodePanelPresence(activeClientIds: Set): void { + let hasChanges = false + + Object.entries(this.nodePanelPresence).forEach(([nodeId, viewers]) => { + Object.keys(viewers).forEach((clientId) => { + const clientActive = activeClientIds.has(clientId) + + if (!clientActive) { + delete viewers[clientId] + hasChanges = true + } 
+ }) + + if (Object.keys(viewers).length === 0) + delete this.nodePanelPresence[nodeId] + }) + + if (hasChanges) + this.eventEmitter.emit('nodePanelPresence', this.getNodePanelPresenceSnapshot()) + } + + init = (appId: string, reactFlowStore: ReactFlowStore): void => { + if (!reactFlowStore) { + console.warn('CollaborationManager.init called without reactFlowStore, deferring to connect()') + return + } + this.connect(appId, reactFlowStore) + } + + setNodes = (oldNodes: Node[], newNodes: Node[], source = 'collaboration-manager:setNodes'): void => { + if (!this.doc) + return + + // Don't track operations during undo/redo to prevent loops + if (this.isUndoRedoInProgress) + return + + this.seedCrdtGraphFromReactFlowIfNeeded() + this.captureSetNodesAnomaly(oldNodes, newNodes, source) + this.syncNodes(oldNodes, newNodes) + this.doc.commit() + } + + setEdges = (oldEdges: Edge[], newEdges: Edge[]): void => { + if (!this.doc) + return + + // Don't track operations during undo/redo to prevent loops + if (this.isUndoRedoInProgress) + return + + this.seedCrdtGraphFromReactFlowIfNeeded() + this.syncEdges(oldEdges, newEdges) + this.doc.commit() + } + + destroy = (): void => { + this.disconnect() + } + + async connect(appId: string, reactFlowStore?: ReactFlowStore): Promise { + const connectionId = Math.random().toString(36).substring(2, 11) + + this.activeConnections.add(connectionId) + + if (this.currentAppId === appId && this.doc) { + // Already connected to the same app, only update store if provided and we don't have one + if (reactFlowStore && !this.reactFlowStore) + this.reactFlowStore = reactFlowStore + + return connectionId + } + + // Only disconnect if switching to a different app + if (this.currentAppId && this.currentAppId !== appId) + this.forceDisconnect() + + this.currentAppId = appId + // Only set store if provided + if (reactFlowStore) + this.reactFlowStore = reactFlowStore + + const socket = webSocketClient.connect(appId) + + // Setup event listeners BEFORE any 
other operations + this.setupSocketEventListeners(socket) + + this.doc = new LoroDoc() + this.nodesMap = this.doc.getMap('nodes') as LoroMap> + this.edgesMap = this.doc.getMap('edges') as LoroMap> + + // Initialize UndoManager for collaborative undo/redo + this.undoManager = new UndoManager(this.doc, { + maxUndoSteps: 100, + mergeInterval: 500, // Merge operations within 500ms + excludeOriginPrefixes: [], // Don't exclude anything - let UndoManager track all local operations + onPush: (_isUndo, _range, _event) => { + // Store current selection state when an operation is pushed + const selectedNode = this.reactFlowStore?.getState().getNodes().find((n: Node) => n.data?.selected) + + // Emit event to update UI button states when new operation is pushed + setTimeout(() => { + this.eventEmitter.emit('undoRedoStateChange', { + canUndo: this.undoManager?.canUndo() || false, + canRedo: this.undoManager?.canRedo() || false, + }) + }, 0) + + return { + value: { + selectedNodeId: selectedNode?.id || null, + timestamp: Date.now(), + }, + cursors: [], + } + }, + onPop: (_isUndo, value, _counterRange) => { + // Restore selection state when undoing/redoing + if (value?.value && typeof value.value === 'object' && 'selectedNodeId' in value.value && this.reactFlowStore) { + const selectedNodeId = (value.value as { selectedNodeId?: string | null }).selectedNodeId + if (selectedNodeId) { + const state = this.reactFlowStore.getState() + const { setNodes } = state + const nodes = state.getNodes() + const newNodes = nodes.map((n: Node) => ({ + ...n, + data: { + ...n.data, + selected: n.id === selectedNodeId, + }, + })) + this.captureSetNodesAnomaly(nodes, newNodes, 'reactflow-native:undo-redo-selection-restore') + setNodes(newNodes) + } + } + }, + }) + + this.provider = new CRDTProvider(socket, this.doc, this.handleSessionUnauthorized) + + this.setupSubscriptions() + + // Force user_connect if already connected + if (socket.connected) + emitWithAuthGuard(socket, 'user_connect', { 
workflow_id: appId }, { onUnauthorized: this.handleSessionUnauthorized }) + + return connectionId + } + + disconnect = (connectionId?: string): void => { + if (connectionId) + this.activeConnections.delete(connectionId) + + // Only disconnect when no more connections + if (this.activeConnections.size === 0) + this.forceDisconnect() + } + + private forceDisconnect = (): void => { + if (this.currentAppId) + webSocketClient.disconnect(this.currentAppId) + + this.provider?.destroy() + this.undoManager = null + this.doc = null + this.provider = null + this.nodesMap = null + this.edgesMap = null + this.currentAppId = null + this.reactFlowStore = null + this.cursors = {} + this.onlineUsers = [] + this.nodePanelPresence = {} + this.isUndoRedoInProgress = false + this.rejoinInProgress = false + this.clearGraphImportLog() + + // Only reset leader status when actually disconnecting + const wasLeader = this.isLeader + this.isLeader = false + this.leaderId = null + + if (wasLeader) + this.eventEmitter.emit('leaderChange', false) + + this.activeConnections.clear() + this.eventEmitter.removeAllListeners() + } + + isConnected(): boolean { + return this.currentAppId ? webSocketClient.isConnected(this.currentAppId) : false + } + + getNodes(): Node[] { + if (!this.nodesMap) + return [] + return Array.from(this.nodesMap.keys()).map(id => this.exportNode(id as string)) + } + + getEdges(): Edge[] { + return this.edgesMap ? 
Array.from(this.edgesMap.values()) as Edge[] : [] + } + + emitCursorMove(position: CursorPosition): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + + const socket = this.getActiveSocket() + if (!socket) + return + + this.sendCollaborationEvent({ + type: 'mouse_move', + userId: socket.id, + data: { x: position.x, y: position.y }, + timestamp: Date.now(), + }) + } + + emitSyncRequest(): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + + this.sendCollaborationEvent({ + type: 'sync_request', + data: { timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + + emitWorkflowUpdate(appId: string): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + + this.sendCollaborationEvent({ + type: 'workflow_update', + data: { appId, timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + + emitNodePanelPresence(nodeId: string, isOpen: boolean, user: NodePanelPresenceUser): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + + const socket = this.getActiveSocket() + if (!socket || !nodeId || !user?.userId) + return + + const payload: NodePanelPresenceEventData = { + nodeId, + action: isOpen ? 
'open' : 'close', + user, + clientId: socket.id as string, + timestamp: Date.now(), + } + + this.sendCollaborationEvent({ + type: 'node_panel_presence', + data: payload, + timestamp: payload.timestamp, + }) + + this.applyNodePanelPresenceUpdate(payload) + } + + onSyncRequest(callback: () => void): () => void { + return this.eventEmitter.on('syncRequest', callback) + } + + onGraphImport(callback: (payload: { nodes: Node[], edges: Edge[] }) => void): () => void { + return this.eventEmitter.on('graphImport', callback) + } + + onStateChange(callback: (state: Partial) => void): () => void { + return this.eventEmitter.on('stateChange', callback) + } + + onCursorUpdate(callback: (cursors: Record) => void): () => void { + return this.eventEmitter.on('cursors', callback) + } + + onOnlineUsersUpdate(callback: (users: OnlineUser[]) => void): () => void { + return this.eventEmitter.on('onlineUsers', callback) + } + + onWorkflowUpdate(callback: (update: { appId: string, timestamp: number }) => void): () => void { + return this.eventEmitter.on('workflowUpdate', callback) + } + + onVarsAndFeaturesUpdate(callback: (update: CollaborationUpdate) => void): () => void { + return this.eventEmitter.on('varsAndFeaturesUpdate', callback) + } + + onAppStateUpdate(callback: (update: CollaborationUpdate) => void): () => void { + return this.eventEmitter.on('appStateUpdate', callback) + } + + onAppPublishUpdate(callback: (update: CollaborationUpdate) => void): () => void { + return this.eventEmitter.on('appPublishUpdate', callback) + } + + onAppMetaUpdate(callback: (update: CollaborationUpdate) => void): () => void { + return this.eventEmitter.on('appMetaUpdate', callback) + } + + onMcpServerUpdate(callback: (update: CollaborationUpdate) => void): () => void { + return this.eventEmitter.on('mcpServerUpdate', callback) + } + + onNodePanelPresenceUpdate(callback: (presence: NodePanelPresenceMap) => void): () => void { + const off = this.eventEmitter.on('nodePanelPresence', callback) + 
callback(this.getNodePanelPresenceSnapshot()) + return off + } + + onLeaderChange(callback: (isLeader: boolean) => void): () => void { + return this.eventEmitter.on('leaderChange', callback) + } + + onCommentsUpdate(callback: (update: { appId: string, timestamp: number }) => void): () => void { + return this.eventEmitter.on('commentsUpdate', callback) + } + + emitCommentsUpdate(appId: string): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + + this.sendCollaborationEvent({ + type: 'comments_update', + data: { appId, timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + + emitHistoryAction(action: 'undo' | 'redo' | 'jump'): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + + this.sendCollaborationEvent({ + type: 'workflow_history_action', + data: { action }, + timestamp: Date.now(), + }) + } + + onUndoRedoStateChange(callback: (state: { canUndo: boolean, canRedo: boolean }) => void): () => void { + return this.eventEmitter.on('undoRedoStateChange', callback) + } + + onHistoryAction(callback: (payload: { action: 'undo' | 'redo' | 'jump', userId?: string }) => void): () => void { + return this.eventEmitter.on('historyAction', callback) + } + + emitRestoreRequest(data: RestoreRequestData): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + + this.sendCollaborationEvent({ + type: 'workflow_restore_request', + data: data as unknown as Record, + timestamp: Date.now(), + }) + } + + emitRestoreIntent(data: RestoreIntentData): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + + this.sendCollaborationEvent({ + type: 'workflow_restore_intent', + data: data as unknown as Record, + timestamp: Date.now(), + }) + } + + emitRestoreComplete(data: RestoreCompleteData): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + + this.sendCollaborationEvent({ + 
type: 'workflow_restore_complete', + data: data as unknown as Record, + timestamp: Date.now(), + }) + } + + onRestoreRequest(callback: (data: RestoreRequestData) => void): () => void { + return this.eventEmitter.on('restoreRequest', callback) + } + + onRestoreIntent(callback: (data: RestoreIntentData) => void): () => void { + return this.eventEmitter.on('restoreIntent', callback) + } + + onRestoreComplete(callback: (data: RestoreCompleteData) => void): () => void { + return this.eventEmitter.on('restoreComplete', callback) + } + + getLeaderId(): string | null { + return this.leaderId + } + + getIsLeader(): boolean { + return this.isLeader + } + + // Collaborative undo/redo methods + undo(): boolean { + if (!this.undoManager) + return false + + const canUndo = this.undoManager.canUndo() + if (canUndo) { + this.isUndoRedoInProgress = true + const result = this.undoManager.undo() + + // After undo, manually update React state from CRDT without triggering collaboration + const reactFlowStore = this.reactFlowStore + if (result && reactFlowStore) { + requestAnimationFrame(() => { + // Get ReactFlow's native setters, not the collaborative ones + const state = reactFlowStore.getState() + const previousNodes = state.getNodes() + const updatedNodes = Array.from(this.nodesMap?.values() || []) as Node[] + const updatedEdges = Array.from(this.edgesMap?.values() || []) as Edge[] + // Call ReactFlow's native setters directly to avoid triggering collaboration + this.captureSetNodesAnomaly(previousNodes, updatedNodes, 'reactflow-native:undo-apply') + state.setNodes(updatedNodes) + state.setEdges(updatedEdges) + + this.isUndoRedoInProgress = false + + // Emit event to update UI button states + this.eventEmitter.emit('undoRedoStateChange', { + canUndo: this.undoManager?.canUndo() || false, + canRedo: this.undoManager?.canRedo() || false, + }) + }) + } + else { + this.isUndoRedoInProgress = false + } + + return result + } + + return false + } + + redo(): boolean { + if 
(!this.undoManager) + return false + + const canRedo = this.undoManager.canRedo() + if (canRedo) { + this.isUndoRedoInProgress = true + const result = this.undoManager.redo() + + // After redo, manually update React state from CRDT without triggering collaboration + const reactFlowStore = this.reactFlowStore + if (result && reactFlowStore) { + requestAnimationFrame(() => { + // Get ReactFlow's native setters, not the collaborative ones + const state = reactFlowStore.getState() + const previousNodes = state.getNodes() + const updatedNodes = Array.from(this.nodesMap?.values() || []) as Node[] + const updatedEdges = Array.from(this.edgesMap?.values() || []) as Edge[] + // Call ReactFlow's native setters directly to avoid triggering collaboration + this.captureSetNodesAnomaly(previousNodes, updatedNodes, 'reactflow-native:redo-apply') + state.setNodes(updatedNodes) + state.setEdges(updatedEdges) + + this.isUndoRedoInProgress = false + + // Emit event to update UI button states + this.eventEmitter.emit('undoRedoStateChange', { + canUndo: this.undoManager?.canUndo() || false, + canRedo: this.undoManager?.canRedo() || false, + }) + }) + } + else { + this.isUndoRedoInProgress = false + } + + return result + } + + return false + } + + canUndo(): boolean { + if (!this.undoManager) + return false + return this.undoManager.canUndo() + } + + canRedo(): boolean { + if (!this.undoManager) + return false + return this.undoManager.canRedo() + } + + clearUndoStack(): void { + if (!this.undoManager) + return + this.undoManager.clear() + } + + private syncNodes(oldNodes: Node[], newNodes: Node[]): void { + if (!this.nodesMap || !this.doc) + return + + const oldNodesMap = new Map(oldNodes.map(node => [node.id, node])) + const newNodesMap = new Map(newNodes.map(node => [node.id, node])) + + oldNodes.forEach((oldNode) => { + if (!newNodesMap.has(oldNode.id)) { + this.nodesMap?.delete(oldNode.id) + } + }) + + newNodes.forEach((newNode) => { + const oldNode = oldNodesMap.get(newNode.id) + 
if (oldNode && oldNode === newNode) + return + if (oldNode && isEqual(oldNode, newNode)) + return + + const nodeContainer = this.getNodeContainer(newNode.id) + this.populateNodeContainer(nodeContainer, newNode) + }) + } + + private syncEdges(oldEdges: Edge[], newEdges: Edge[]): void { + if (!this.edgesMap) + return + + const oldEdgesMap = new Map(oldEdges.map(edge => [edge.id, edge])) + const newEdgesMap = new Map(newEdges.map(edge => [edge.id, edge])) + + oldEdges.forEach((oldEdge) => { + if (!newEdgesMap.has(oldEdge.id)) { + this.edgesMap?.delete(oldEdge.id) + } + }) + + newEdges.forEach((newEdge) => { + const oldEdge = oldEdgesMap.get(newEdge.id) + if (!oldEdge || !isEqual(oldEdge, newEdge)) { + const clonedEdge = toLoroRecord(newEdge) + this.edgesMap?.set(newEdge.id, clonedEdge) + } + }) + } + + private setupSubscriptions(): void { + this.nodesMap?.subscribe((event: LoroSubscribeEvent) => { + const reactFlowStore = this.reactFlowStore + const eventBy = event.by ?? 'unknown' + this.recordGraphSyncDiagnostic( + 'nodes_subscribe', + 'triggered', + undefined, + { + eventBy, + hasReactFlowStore: Boolean(reactFlowStore), + }, + ) + + if (eventBy !== 'import') { + this.recordGraphSyncDiagnostic('nodes_subscribe', 'skipped', 'event_by_not_import', { eventBy }) + return + } + + if (!reactFlowStore) { + this.recordGraphSyncDiagnostic('nodes_subscribe', 'skipped', 'reactflow_store_missing') + return + } + + // Don't update React nodes during undo/redo to prevent loops + if (this.isUndoRedoInProgress) { + this.recordGraphSyncDiagnostic('nodes_subscribe', 'skipped', 'undo_redo_in_progress') + return + } + + this.recordGraphSyncDiagnostic('nodes_subscribe', 'queued', 'raf_scheduled') + requestAnimationFrame(() => { + const state = reactFlowStore.getState() + const previousNodes: Node[] = state.getNodes() + const previousEdges: Edge[] = state.getEdges() + this.startImportLog('nodes', { nodes: previousNodes, edges: previousEdges }) + const previousNodeMap = new 
Map(previousNodes.map(node => [node.id, node])) + const selectedIds = new Set( + previousNodes + .filter(node => node.data?.selected) + .map(node => node.id), + ) + + this.pendingInitialSync = false + + const updatedNodes = Array + .from(this.nodesMap?.keys() || []) + .map((nodeId) => { + const node = this.exportNode(nodeId as string) + const clonedNode: Node = { + ...node, + data: { + ...(node.data || {}), + }, + } + const clonedNodeData = clonedNode.data as (CommonNodeType & Record) + // Keep the previous node's private data properties (starting with _) + const previousNode = previousNodeMap.get(clonedNode.id) + if (previousNode?.data) { + const previousData = previousNode.data as Record + Object.entries(previousData) + .filter(([key]) => key.startsWith('_')) + .forEach(([key, value]) => { + if (!(key in clonedNodeData)) + clonedNodeData[key] = value + }) + } + + if (selectedIds.has(clonedNode.id)) + clonedNode.data.selected = true + + return clonedNode + }) + + // Call ReactFlow's native setter directly to avoid triggering collaboration + this.captureSetNodesAnomaly(previousNodes, updatedNodes, 'reactflow-native:import-nodes-map-subscribe') + state.setNodes(updatedNodes) + this.recordGraphSyncDiagnostic( + 'nodes_import_apply', + 'applied', + undefined, + { + eventBy, + previousNodeCount: previousNodes.length, + updatedNodeCount: updatedNodes.length, + previousEdgeCount: previousEdges.length, + selectedCount: selectedIds.size, + }, + ) + + this.scheduleGraphImportEmit() + }) + }) + + this.edgesMap?.subscribe((event: LoroSubscribeEvent) => { + const reactFlowStore = this.reactFlowStore + const eventBy = event.by ?? 
'unknown' + this.recordGraphSyncDiagnostic( + 'edges_subscribe', + 'triggered', + undefined, + { + eventBy, + hasReactFlowStore: Boolean(reactFlowStore), + }, + ) + + if (eventBy !== 'import') { + this.recordGraphSyncDiagnostic('edges_subscribe', 'skipped', 'event_by_not_import', { eventBy }) + return + } + + if (!reactFlowStore) { + this.recordGraphSyncDiagnostic('edges_subscribe', 'skipped', 'reactflow_store_missing') + return + } + + // Don't update React edges during undo/redo to prevent loops + if (this.isUndoRedoInProgress) { + this.recordGraphSyncDiagnostic('edges_subscribe', 'skipped', 'undo_redo_in_progress') + return + } + + this.recordGraphSyncDiagnostic('edges_subscribe', 'queued', 'raf_scheduled') + requestAnimationFrame(() => { + // Get ReactFlow's native setters, not the collaborative ones + const state = reactFlowStore.getState() + const previousNodes = state.getNodes() + const previousEdges = state.getEdges() + this.startImportLog('edges', { nodes: previousNodes, edges: previousEdges }) + const updatedEdges = Array.from(this.edgesMap?.values() || []) as Edge[] + + this.pendingInitialSync = false + + // Call ReactFlow's native setter directly to avoid triggering collaboration + state.setEdges(updatedEdges) + this.recordGraphSyncDiagnostic( + 'edges_import_apply', + 'applied', + undefined, + { + eventBy, + previousNodeCount: previousNodes.length, + previousEdgeCount: previousEdges.length, + updatedEdgeCount: updatedEdges.length, + }, + ) + + this.scheduleGraphImportEmit() + }) + }) + } + + private scheduleGraphImportEmit(): void { + if (this.pendingGraphImportEmit) { + this.recordGraphSyncDiagnostic( + 'schedule_graph_import_emit', + 'skipped', + 'already_pending', + ) + return + } + + this.recordGraphSyncDiagnostic('schedule_graph_import_emit', 'queued') + this.pendingGraphImportEmit = true + requestAnimationFrame(() => { + const beforeFinalizeNodes = this.getNodes().length + const beforeFinalizeEdges = this.getEdges().length + 
this.pendingGraphImportEmit = false + this.finalizeImportLog() + const mergedNodes = this.mergeLocalNodeState(this.getNodes()) + const mergedEdges = this.getEdges() + this.recordGraphSyncDiagnostic( + 'graph_import_emit', + 'emitted', + undefined, + { + mergedNodeCount: mergedNodes.length, + mergedEdgeCount: mergedEdges.length, + crdtNodeCountBeforeFinalize: beforeFinalizeNodes, + crdtEdgeCountBeforeFinalize: beforeFinalizeEdges, + }, + ) + this.eventEmitter.emit('graphImport', { + nodes: mergedNodes, + edges: mergedEdges, + }) + }) + } + + refreshGraphSynchronously(): void { + const mergedNodes = this.mergeLocalNodeState(this.getNodes()) + this.eventEmitter.emit('graphImport', { + nodes: mergedNodes, + edges: this.getEdges(), + }) + } + + private mergeLocalNodeState(nodes: Node[]): Node[] { + const reactFlowStore = this.reactFlowStore + const state = reactFlowStore?.getState() + const localNodes = state?.getNodes() || [] + + if (localNodes.length === 0) + return nodes + + const localNodesMap = new Map(localNodes.map(node => [node.id, node])) + return nodes.map((node) => { + const localNode = localNodesMap.get(node.id) + if (!localNode) + return node + + const nextNode = cloneDeep(node) + const nextData = { ...(nextNode.data || {}) } as Node['data'] + const nextDataRecord = nextData as Record + const localData = localNode.data as Record | undefined + + if (localData) { + Object.entries(localData).forEach(([key, value]) => { + if (key === 'selected' || key.startsWith('_')) + nextDataRecord[key] = value + }) + } + + if (!Object.prototype.hasOwnProperty.call(nextDataRecord, 'selected') && localNode.selected !== undefined) + nextDataRecord.selected = localNode.selected + + nextNode.data = nextData + return nextNode + }) + } + + getGraphImportLog(): GraphImportLogEntry[] { + return cloneDeep(this.graphImportLogs) + } + + clearGraphImportLog(): void { + this.graphImportLogs = [] + this.setNodesAnomalyLogs = [] + this.graphSyncDiagnostics = [] + this.pendingImportLog = 
null + } + + downloadGraphImportLog(): void { + const reactFlowState = this.reactFlowStore?.getState() + const payload = { + appId: this.currentAppId, + generatedAt: new Date().toISOString(), + entries: this.graphImportLogs, + setNodesAnomalies: this.setNodesAnomalyLogs, + syncDiagnostics: this.graphSyncDiagnostics, + summary: { + logCount: this.graphImportLogs.length, + setNodesAnomalyCount: this.setNodesAnomalyLogs.length, + syncDiagnosticCount: this.graphSyncDiagnostics.length, + leaderId: this.leaderId, + isLeader: this.isLeader, + graphViewActive: this.graphViewActive, + pendingInitialSync: this.pendingInitialSync, + isConnected: this.isConnected(), + hasDoc: Boolean(this.doc), + hasReactFlowStore: Boolean(this.reactFlowStore), + onlineUsersCount: this.onlineUsers.length, + crdtCounts: { + nodes: this.getNodes().length, + edges: this.getEdges().length, + }, + reactFlowCounts: { + nodes: reactFlowState?.getNodes().length ?? 0, + edges: reactFlowState?.getEdges().length ?? 0, + }, + }, + } + const stamp = new Date().toISOString().replace(/[:.]/g, '-') + const appSuffix = this.currentAppId ?? 
'unknown' + const fileName = `workflow-graph-import-log-${appSuffix}-${stamp}.json` + const blob = new Blob([JSON.stringify(payload, null, 2)], { type: 'application/json' }) + const url = URL.createObjectURL(blob) + const link = document.createElement('a') + link.href = url + link.download = fileName + link.click() + URL.revokeObjectURL(url) + } + + private recordGraphSyncDiagnostic( + stage: GraphSyncDiagnosticStage, + status: GraphSyncDiagnosticEvent['status'], + reason?: string, + details?: Record, + ): void { + const entry: GraphSyncDiagnosticEvent = { + timestamp: Date.now(), + appId: this.currentAppId, + stage, + status, + reason, + details, + meta: { + leaderId: this.leaderId, + isLeader: this.isLeader, + isUndoRedoInProgress: this.isUndoRedoInProgress, + pendingInitialSync: this.pendingInitialSync, + pendingGraphImportEmit: this.pendingGraphImportEmit, + isConnected: this.isConnected(), + }, + } + + this.graphSyncDiagnostics.push(entry) + if (this.graphSyncDiagnostics.length > GRAPH_SYNC_DIAGNOSTIC_LOG_LIMIT) + this.graphSyncDiagnostics.splice(0, this.graphSyncDiagnostics.length - GRAPH_SYNC_DIAGNOSTIC_LOG_LIMIT) + } + + private captureSetNodesAnomaly(oldNodes: Node[], newNodes: Node[], source: string): void { + const oldNodeIds = oldNodes.map(node => node.id) + const newNodeIds = newNodes.map(node => node.id) + const newNodeIdSet = new Set(newNodeIds) + const removedNodeIds = oldNodeIds.filter(nodeId => !newNodeIdSet.has(nodeId)) + + const oldStartNodeIds = oldNodes + .filter(node => (node.data as CommonNodeType | undefined)?.type === 'start') + .map(node => node.id) + const newStartNodeIds = newNodes + .filter(node => (node.data as CommonNodeType | undefined)?.type === 'start') + .map(node => node.id) + + const reasons: SetNodesAnomalyReason[] = [] + if (newNodes.length < oldNodes.length) + reasons.push('node_count_decrease') + if (oldStartNodeIds.length > 0 && newStartNodeIds.length === 0) + reasons.push('start_removed') + + if (!reasons.length) + return 
+ + const entry: SetNodesAnomalyLogEntry = { + timestamp: Date.now(), + appId: this.currentAppId, + source, + reasons, + oldCount: oldNodes.length, + newCount: newNodes.length, + removedNodeIds, + oldStartNodeIds, + newStartNodeIds, + oldNodeIds, + newNodeIds, + visibilityState: typeof document === 'undefined' ? 'unknown' : document.visibilityState, + meta: { + leaderId: this.leaderId, + isLeader: this.isLeader, + graphViewActive: this.graphViewActive, + pendingInitialSync: this.pendingInitialSync, + isConnected: this.isConnected(), + }, + } + this.setNodesAnomalyLogs.push(entry) + if (this.setNodesAnomalyLogs.length > SET_NODES_ANOMALY_LOG_LIMIT) + this.setNodesAnomalyLogs.splice(0, this.setNodesAnomalyLogs.length - SET_NODES_ANOMALY_LOG_LIMIT) + } + + private snapshotReactFlowGraph(): { nodes: Node[], edges: Edge[] } { + if (!this.reactFlowStore) { + return { + nodes: this.getNodes(), + edges: this.getEdges(), + } + } + + const state = this.reactFlowStore.getState() + return { + nodes: cloneDeep(state.getNodes()), + edges: cloneDeep(state.getEdges()), + } + } + + private startImportLog(source: 'nodes' | 'edges', before?: { nodes: Node[], edges: Edge[] }): void { + if (!this.pendingImportLog) { + const snapshot = before ?? 
this.snapshotReactFlowGraph() + this.pendingImportLog = { + timestamp: Date.now(), + sources: new Set([source]), + before: { + nodes: cloneDeep(snapshot.nodes), + edges: cloneDeep(snapshot.edges), + }, + } + this.recordGraphSyncDiagnostic( + 'start_import_log', + 'snapshot', + 'created', + { + source, + beforeNodes: snapshot.nodes.length, + beforeEdges: snapshot.edges.length, + }, + ) + return + } + this.pendingImportLog.sources.add(source) + this.recordGraphSyncDiagnostic( + 'start_import_log', + 'snapshot', + 'merged_source', + { + source, + sourceCount: this.pendingImportLog.sources.size, + }, + ) + } + + private finalizeImportLog(): void { + if (!this.pendingImportLog) { + this.recordGraphSyncDiagnostic('finalize_import_log', 'skipped', 'no_pending_import') + return + } + + const afterSnapshot = this.snapshotReactFlowGraph() + const entry: GraphImportLogEntry = { + timestamp: this.pendingImportLog.timestamp, + appId: this.currentAppId, + sources: Array.from(this.pendingImportLog.sources), + before: { + nodes: this.pendingImportLog.before.nodes, + edges: this.pendingImportLog.before.edges, + }, + after: { + nodes: cloneDeep(afterSnapshot.nodes), + edges: cloneDeep(afterSnapshot.edges), + }, + meta: { + leaderId: this.leaderId, + isLeader: this.isLeader, + graphViewActive: this.graphViewActive, + pendingInitialSync: this.pendingInitialSync, + }, + } + + this.graphImportLogs.push(entry) + this.recordGraphSyncDiagnostic( + 'finalize_import_log', + 'snapshot', + undefined, + { + sources: entry.sources, + beforeNodes: entry.before.nodes.length, + beforeEdges: entry.before.edges.length, + afterNodes: entry.after.nodes.length, + afterEdges: entry.after.edges.length, + }, + ) + if (this.graphImportLogs.length > GRAPH_IMPORT_LOG_LIMIT) + this.graphImportLogs.splice(0, this.graphImportLogs.length - GRAPH_IMPORT_LOG_LIMIT) + this.pendingImportLog = null + } + + private setupSocketEventListeners(socket: Socket): void { + socket.on('collaboration_update', (update: 
CollaborationUpdate) => { + if (update.type === 'mouse_move') { + // Update cursor state for this user + const data = update.data as { x: number, y: number } + this.cursors[update.userId] = { + x: data.x, + y: data.y, + userId: update.userId, + timestamp: update.timestamp, + } + + this.eventEmitter.emit('cursors', { ...this.cursors }) + } + else if (update.type === 'vars_and_features_update') { + this.eventEmitter.emit('varsAndFeaturesUpdate', update) + } + else if (update.type === 'app_state_update') { + this.eventEmitter.emit('appStateUpdate', update) + } + else if (update.type === 'app_meta_update') { + this.eventEmitter.emit('appMetaUpdate', update) + } + else if (update.type === 'app_publish_update') { + this.eventEmitter.emit('appPublishUpdate', update) + } + else if (update.type === 'mcp_server_update') { + this.eventEmitter.emit('mcpServerUpdate', update) + } + else if (update.type === 'workflow_update') { + this.eventEmitter.emit('workflowUpdate', update.data) + } + else if (update.type === 'comments_update') { + this.eventEmitter.emit('commentsUpdate', update.data) + } + else if (update.type === 'node_panel_presence') { + this.applyNodePanelPresenceUpdate(update.data as NodePanelPresenceEventData) + } + else if (update.type === 'sync_request') { + // Only process if we are the leader + if (this.isLeader) + this.eventEmitter.emit('syncRequest', {}) + } + else if (update.type === 'graph_resync_request') { + if (this.isLeader) + this.broadcastCurrentGraph() + } + else if (update.type === 'workflow_restore_request') { + if (this.isLeader) + this.eventEmitter.emit('restoreRequest', update.data as RestoreRequestData) + } + else if (update.type === 'workflow_restore_intent') { + this.eventEmitter.emit('restoreIntent', update.data as RestoreIntentData) + } + else if (update.type === 'workflow_restore_complete') { + this.eventEmitter.emit('restoreComplete', update.data as RestoreCompleteData) + } + else if (update.type === 'workflow_history_action') { + const data 
= update.data as { action?: 'undo' | 'redo' | 'jump' } | undefined + if (data?.action) + this.eventEmitter.emit('historyAction', { action: data.action, userId: update.userId }) + } + }) + + socket.on('online_users', (data: { users: OnlineUser[], leader?: string }) => { + try { + if (!data || !Array.isArray(data.users)) { + console.warn('Invalid online_users data structure:', data) + return + } + + const onlineUserIds = new Set(data.users.map((user: OnlineUser) => user.user_id)) + const onlineClientIds = new Set( + data.users + .map((user: OnlineUser) => user.sid) + .filter((sid): sid is string => typeof sid === 'string' && sid.length > 0), + ) + + // Remove cursors for offline users + Object.keys(this.cursors).forEach((userId) => { + if (!onlineUserIds.has(userId)) + delete this.cursors[userId] + }) + + this.cleanupNodePanelPresence(onlineClientIds) + + // Update leader information + if (data.leader && typeof data.leader === 'string') + this.leaderId = data.leader + + this.onlineUsers = data.users + this.eventEmitter.emit('onlineUsers', data.users) + this.eventEmitter.emit('cursors', { ...this.cursors }) + } + catch (error) { + console.error('Error processing online_users update:', error) + } + }) + + socket.on('status', (data: { isLeader: boolean }) => { + try { + if (!data || typeof data.isLeader !== 'boolean') { + console.warn('Invalid status data:', data) + return + } + + const wasLeader = this.isLeader + this.isLeader = data.isLeader + + if (this.isLeader) { + this.seedCrdtGraphFromReactFlowIfNeeded() + this.pendingInitialSync = false + } + else { + this.requestInitialSyncIfNeeded() + } + + if (wasLeader !== this.isLeader) + this.eventEmitter.emit('leaderChange', this.isLeader) + } + catch (error) { + console.error('Error processing status update:', error) + } + }) + + socket.on('connect', () => { + this.eventEmitter.emit('stateChange', { isConnected: true }) + this.pendingInitialSync = true + }) + + socket.on('disconnect', (reason) => { + this.cursors = {} + 
this.onlineUsers = [] + this.isLeader = false + this.leaderId = null + this.pendingInitialSync = false + this.eventEmitter.emit('stateChange', { isConnected: false, disconnectReason: reason }) + this.eventEmitter.emit('onlineUsers', []) + this.eventEmitter.emit('cursors', {}) + }) + + socket.on('connect_error', (error: Error) => { + console.error('WebSocket connection error:', error) + this.eventEmitter.emit('stateChange', { isConnected: false, error: error.message }) + }) + + socket.on('error', (error: Error) => { + console.error('WebSocket error:', error) + }) + } + + // We currently only relay CRDT updates; the server doesn't persist them. + // When a follower joins mid-session, it might miss earlier broadcasts and render stale data. + // This lightweight checkpoint asks the leader to rebroadcast the latest graph snapshot once. + private requestInitialSyncIfNeeded(): void { + if (!this.pendingInitialSync) + return + if (this.isLeader) { + this.pendingInitialSync = false + return + } + + this.emitGraphResyncRequest() + this.pendingInitialSync = false + } + + private emitGraphResyncRequest(): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + + this.sendCollaborationEvent({ + type: 'graph_resync_request', + data: { timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + + private seedCrdtGraphFromReactFlowIfNeeded(): void { + if (!this.doc) + return + if (!this.reactFlowStore) + return + + // CRDT may still be empty when the canvas was initially loaded from HTTP draft data + // before collaboration finished connecting, and no local mutation has been written yet. + // Seed once from the current ReactFlow graph so leader resync can broadcast a full snapshot. 
+ if (this.getNodes().length > 0 || this.getEdges().length > 0) + return + + const state = this.reactFlowStore.getState() + const nodes = state.getNodes() + const edges = state.getEdges() + + if (!nodes.length && !edges.length) + return + + this.syncNodes([], nodes) + this.syncEdges([], edges) + this.doc.commit() + } + + private broadcastCurrentGraph(): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) + return + if (!this.doc) + return + + const socket = webSocketClient.getSocket(this.currentAppId) + if (!socket) + return + + try { + this.seedCrdtGraphFromReactFlowIfNeeded() + + if (this.getNodes().length === 0 && this.getEdges().length === 0) + return + + const snapshot = this.doc.export({ mode: 'snapshot' }) + this.sendGraphEvent(snapshot) + } + catch (error) { + console.error('Failed to broadcast graph snapshot:', error) + } + } +} + +export const collaborationManager = new CollaborationManager() diff --git a/web/app/components/workflow/collaboration/core/crdt-provider.ts b/web/app/components/workflow/collaboration/core/crdt-provider.ts new file mode 100644 index 0000000000..53528c9170 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/crdt-provider.ts @@ -0,0 +1,41 @@ +'use client' + +import type { LoroDoc } from 'loro-crdt' +import type { Socket } from 'socket.io-client' +import { emitWithAuthGuard } from './websocket-manager' + +export class CRDTProvider { + private doc: LoroDoc + private socket: Socket + private onUnauthorized?: () => void + + constructor(socket: Socket, doc: LoroDoc, onUnauthorized?: () => void) { + this.socket = socket + this.doc = doc + this.onUnauthorized = onUnauthorized + this.setupEventListeners() + } + + private setupEventListeners(): void { + this.doc.subscribe((event: { by?: string }) => { + if (event.by === 'local') { + const update = this.doc.export({ mode: 'update' }) + emitWithAuthGuard(this.socket, 'graph_event', update, { onUnauthorized: this.onUnauthorized }) + } + }) + + 
this.socket.on('graph_update', (updateData: Uint8Array) => { + try { + const data = new Uint8Array(updateData) + this.doc.import(data) + } + catch (error) { + console.error('Error importing graph update:', error) + } + }) + } + + destroy(): void { + this.socket.off('graph_update') + } +} diff --git a/web/app/components/workflow/collaboration/core/event-emitter.ts b/web/app/components/workflow/collaboration/core/event-emitter.ts new file mode 100644 index 0000000000..c562ae3083 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/event-emitter.ts @@ -0,0 +1,51 @@ +type EventHandler = (data: T) => void + +export class EventEmitter { + private events: Map>> = new Map() + + on(event: string, handler: EventHandler): () => void { + if (!this.events.has(event)) + this.events.set(event, new Set()) + + this.events.get(event)!.add(handler as EventHandler) + + return () => this.off(event, handler) + } + + off(event: string, handler?: EventHandler): void { + if (!this.events.has(event)) + return + + const handlers = this.events.get(event)! + if (handler) + handlers.delete(handler as EventHandler) + else + handlers.clear() + + if (handlers.size === 0) + this.events.delete(event) + } + + emit(event: string, data: T): void { + if (!this.events.has(event)) + return + + const handlers = this.events.get(event)! 
+ handlers.forEach((handler) => { + try { + handler(data) + } + catch (error) { + console.error(`Error in event handler for ${event}:`, error) + } + }) + } + + removeAllListeners(): void { + this.events.clear() + } + + getListenerCount(event: string): number { + return this.events.get(event)?.size || 0 + } +} diff --git a/web/app/components/workflow/collaboration/core/websocket-manager.ts b/web/app/components/workflow/collaboration/core/websocket-manager.ts new file mode 100644 index 0000000000..62ffe2cf0c --- /dev/null +++ b/web/app/components/workflow/collaboration/core/websocket-manager.ts @@ -0,0 +1,157 @@ +import type { Socket } from 'socket.io-client' +import type { DebugInfo, WebSocketConfig } from '../types/websocket' +import { io } from 'socket.io-client' +import { SOCKET_URL } from '@/config' + +type AckArgs = unknown[] + +const isUnauthorizedAck = (...ackArgs: AckArgs): boolean => { + const [first, second] = ackArgs + + if (second === 401 || first === 401) + return true + + if (first && typeof first === 'object' && 'msg' in first) { + const message = (first as { msg?: unknown }).msg + return message === 'unauthorized' + } + + return false +} + +export type EmitAckOptions = { + onAck?: (...ackArgs: AckArgs) => void + onUnauthorized?: (...ackArgs: AckArgs) => void +} + +export const emitWithAuthGuard = ( + socket: Socket | null | undefined, + event: string, + payload: unknown, + options?: EmitAckOptions, +): void => { + if (!socket) + return + + socket.emit( + event, + payload, + (...ackArgs: AckArgs) => { + options?.onAck?.(...ackArgs) + if (isUnauthorizedAck(...ackArgs)) + options?.onUnauthorized?.(...ackArgs) + }, + ) +} + +export class WebSocketClient { + private connections: Map = new Map() + private connecting: Set = new Set() + private readonly url: string + private readonly transports: WebSocketConfig['transports'] + private readonly withCredentials?: boolean + + constructor(config: WebSocketConfig = {}) { + this.url = SOCKET_URL + this.transports 
= config.transports || ['websocket'] + this.withCredentials = config.withCredentials !== false + } + + connect(appId: string): Socket { + const existingSocket = this.connections.get(appId) + if (existingSocket?.connected) + return existingSocket + + if (this.connecting.has(appId)) { + const pendingSocket = this.connections.get(appId) + if (pendingSocket) + return pendingSocket + } + + if (existingSocket && !existingSocket.connected) { + existingSocket.disconnect() + this.connections.delete(appId) + } + + this.connecting.add(appId) + + const socketOptions: { + path: string + transports: WebSocketConfig['transports'] + withCredentials?: boolean + } = { + path: '/socket.io', + transports: this.transports, + withCredentials: this.withCredentials, + } + + const socket = io(this.url, socketOptions) + + this.connections.set(appId, socket) + this.setupBaseEventListeners(socket, appId) + + return socket + } + + disconnect(appId?: string): void { + if (appId) { + const socket = this.connections.get(appId) + if (socket) { + socket.disconnect() + this.connections.delete(appId) + this.connecting.delete(appId) + } + } + else { + this.connections.forEach(socket => socket.disconnect()) + this.connections.clear() + this.connecting.clear() + } + } + + getSocket(appId: string): Socket | null { + return this.connections.get(appId) || null + } + + isConnected(appId: string): boolean { + return this.connections.get(appId)?.connected || false + } + + getConnectedApps(): string[] { + const connectedApps: string[] = [] + this.connections.forEach((socket, appId) => { + if (socket.connected) + connectedApps.push(appId) + }) + return connectedApps + } + + getDebugInfo(): DebugInfo { + const info: DebugInfo = {} + this.connections.forEach((socket, appId) => { + info[appId] = { + connected: socket.connected, + connecting: this.connecting.has(appId), + socketId: socket.id, + } + }) + return info + } + + private setupBaseEventListeners(socket: Socket, appId: string): void { + socket.on('connect', 
() => { + this.connecting.delete(appId) + emitWithAuthGuard(socket, 'user_connect', { workflow_id: appId }) + }) + + socket.on('disconnect', () => { + this.connecting.delete(appId) + }) + + socket.on('connect_error', () => { + this.connecting.delete(appId) + }) + } +} + +export const webSocketClient = new WebSocketClient() diff --git a/web/app/components/workflow/collaboration/hooks/__tests__/use-collaboration.spec.ts b/web/app/components/workflow/collaboration/hooks/__tests__/use-collaboration.spec.ts new file mode 100644 index 0000000000..0f8a9e2c9a --- /dev/null +++ b/web/app/components/workflow/collaboration/hooks/__tests__/use-collaboration.spec.ts @@ -0,0 +1,151 @@ +import type { CursorPosition, NodePanelPresenceMap, OnlineUser } from '../../types/collaboration' +import { renderHook, waitFor } from '@testing-library/react' +import { useCollaboration } from '../use-collaboration' + +type HookReactFlowStore = NonNullable[1]> +type HookReactFlowInstance = Parameters['startCursorTracking']>[1] + +const mockConnect = vi.hoisted(() => vi.fn()) +const mockDisconnect = vi.hoisted(() => vi.fn()) +const mockIsConnected = vi.hoisted(() => vi.fn(() => true)) +const mockEmitCursorMove = vi.hoisted(() => vi.fn()) +const mockGetLeaderId = vi.hoisted(() => vi.fn(() => 'leader-1')) + +let onStateChangeCallback: ((state: { isConnected?: boolean, disconnectReason?: string, error?: string }) => void) | null = null +let onCursorCallback: ((cursors: Record) => void) | null = null +let onUsersCallback: ((users: OnlineUser[]) => void) | null = null +let onPresenceCallback: ((presence: NodePanelPresenceMap) => void) | null = null +let onLeaderCallback: ((isLeader: boolean) => void) | null = null + +const unsubscribeState = vi.hoisted(() => vi.fn()) +const unsubscribeCursor = vi.hoisted(() => vi.fn()) +const unsubscribeUsers = vi.hoisted(() => vi.fn()) +const unsubscribePresence = vi.hoisted(() => vi.fn()) +const unsubscribeLeader = vi.hoisted(() => vi.fn()) + +let 
isCollaborationEnabled = true + +const mockStartTracking = vi.hoisted(() => vi.fn()) +const mockStopTracking = vi.hoisted(() => vi.fn()) +const cursorServiceInstances: Array<{ startTracking: typeof mockStartTracking, stopTracking: typeof mockStopTracking }> = [] + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: { enable_collaboration_mode: boolean } }) => boolean) => + selector({ systemFeatures: { enable_collaboration_mode: isCollaborationEnabled } }), +})) + +vi.mock('../../core/collaboration-manager', () => ({ + collaborationManager: { + connect: (...args: unknown[]) => mockConnect(...args), + disconnect: (...args: unknown[]) => mockDisconnect(...args), + isConnected: () => mockIsConnected(), + emitCursorMove: (...args: unknown[]) => mockEmitCursorMove(...args), + getLeaderId: () => mockGetLeaderId(), + onStateChange: (callback: (state: { isConnected?: boolean, disconnectReason?: string, error?: string }) => void) => { + onStateChangeCallback = callback + return unsubscribeState + }, + onCursorUpdate: (callback: (cursors: Record) => void) => { + onCursorCallback = callback + return unsubscribeCursor + }, + onOnlineUsersUpdate: (callback: (users: OnlineUser[]) => void) => { + onUsersCallback = callback + return unsubscribeUsers + }, + onNodePanelPresenceUpdate: (callback: (presence: NodePanelPresenceMap) => void) => { + onPresenceCallback = callback + return unsubscribePresence + }, + onLeaderChange: (callback: (isLeader: boolean) => void) => { + onLeaderCallback = callback + return unsubscribeLeader + }, + }, +})) + +vi.mock('../../services/cursor-service', () => ({ + CursorService: class { + startTracking = mockStartTracking + stopTracking = mockStopTracking + constructor() { + cursorServiceInstances.push({ startTracking: this.startTracking, stopTracking: this.stopTracking }) + } + }, +})) + +describe('useCollaboration', () => { + beforeEach(() => { + vi.clearAllMocks() + onStateChangeCallback = 
null + onCursorCallback = null + onUsersCallback = null + onPresenceCallback = null + onLeaderCallback = null + isCollaborationEnabled = true + cursorServiceInstances.length = 0 + mockConnect.mockResolvedValue('conn-1') + mockIsConnected.mockReturnValue(true) + }) + + it('connects, reacts to manager updates, and disconnects on unmount', async () => { + const reactFlowStore: HookReactFlowStore = { + getState: vi.fn(), + } + const { result, unmount } = renderHook(() => useCollaboration('app-1', reactFlowStore)) + + await waitFor(() => { + expect(mockConnect).toHaveBeenCalledWith('app-1', reactFlowStore) + }) + + onStateChangeCallback?.({ isConnected: true }) + onUsersCallback?.([{ user_id: 'u1', username: 'U1', avatar: '', sid: 'sid-1' } as OnlineUser]) + onCursorCallback?.({ u1: { x: 10, y: 20, userId: 'u1', timestamp: 1 } }) + onPresenceCallback?.({ nodeA: { sid1: { userId: 'u1', username: 'U1', clientId: 'sid1', timestamp: 1 } } }) + onLeaderCallback?.(true) + + await waitFor(() => { + expect(result.current.isConnected).toBe(true) + expect(result.current.onlineUsers).toHaveLength(1) + expect(result.current.cursors.u1?.x).toBe(10) + expect(result.current.nodePanelPresence.nodeA).toBeDefined() + expect(result.current.isLeader).toBe(true) + expect(result.current.leaderId).toBe('leader-1') + }) + + const ref = { current: document.createElement('div') } + const reactFlowInstance: HookReactFlowInstance = { + getZoom: () => 1, + getViewport: () => ({ x: 0, y: 0, zoom: 1 }), + } as HookReactFlowInstance + result.current.startCursorTracking(ref, reactFlowInstance) + expect(mockStartTracking).toHaveBeenCalledTimes(1) + const emitPosition = mockStartTracking.mock.calls[0]?.[1] as ((position: CursorPosition) => void) + emitPosition({ x: 1, y: 2, userId: 'u1', timestamp: 2 }) + expect(mockEmitCursorMove).toHaveBeenCalledWith({ x: 1, y: 2, userId: 'u1', timestamp: 2 }) + + result.current.stopCursorTracking() + expect(mockStopTracking).toHaveBeenCalled() + + unmount() + 
expect(unsubscribeState).toHaveBeenCalled() + expect(unsubscribeCursor).toHaveBeenCalled() + expect(unsubscribeUsers).toHaveBeenCalled() + expect(unsubscribePresence).toHaveBeenCalled() + expect(unsubscribeLeader).toHaveBeenCalled() + expect(mockDisconnect).toHaveBeenCalledWith('conn-1') + }) + + it('does not connect or start cursor tracking when collaboration is disabled', async () => { + isCollaborationEnabled = false + const { result } = renderHook(() => useCollaboration('app-1')) + + await waitFor(() => { + expect(mockConnect).not.toHaveBeenCalled() + expect(result.current.isEnabled).toBe(false) + }) + + result.current.startCursorTracking({ current: document.createElement('div') }) + expect(mockStartTracking).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/collaboration/hooks/use-collaboration.ts b/web/app/components/workflow/collaboration/hooks/use-collaboration.ts new file mode 100644 index 0000000000..b24d9faea8 --- /dev/null +++ b/web/app/components/workflow/collaboration/hooks/use-collaboration.ts @@ -0,0 +1,150 @@ +import type { ReactFlowInstance } from 'reactflow' +import type { + CollaborationState, + CursorPosition, + NodePanelPresenceMap, + OnlineUser, +} from '../types/collaboration' +import { useEffect, useRef, useState } from 'react' +import { useGlobalPublicStore } from '@/context/global-public-context' +import { collaborationManager } from '../core/collaboration-manager' +import { CursorService } from '../services/cursor-service' + +type CollaborationViewState = { + isConnected: boolean + onlineUsers: OnlineUser[] + cursors: Record + nodePanelPresence: NodePanelPresenceMap + isLeader: boolean +} + +type ReactFlowStore = NonNullable[1]> + +const initialState: CollaborationViewState = { + isConnected: false, + onlineUsers: [], + cursors: {}, + nodePanelPresence: {}, + isLeader: false, +} + +export function useCollaboration(appId: string, reactFlowStore?: ReactFlowStore) { + const [state, setState] = useState(initialState) + 
+ const cursorServiceRef = useRef(null) + const lastDisconnectReasonRef = useRef(null) + const isCollaborationEnabled = useGlobalPublicStore(s => s.systemFeatures.enable_collaboration_mode) + + useEffect(() => { + if (!appId || !isCollaborationEnabled) { + Promise.resolve().then(() => { + setState(initialState) + }) + return + } + + let connectionId: string | null = null + let isUnmounted = false + + if (!cursorServiceRef.current) + cursorServiceRef.current = new CursorService() + + const initCollaboration = async () => { + try { + const id = await collaborationManager.connect(appId, reactFlowStore) + if (isUnmounted) { + collaborationManager.disconnect(id) + return + } + connectionId = id + setState(prev => ({ ...prev, isConnected: collaborationManager.isConnected() })) + } + catch (error) { + console.error('Failed to initialize collaboration:', error) + } + } + + initCollaboration() + + const unsubscribeStateChange = collaborationManager.onStateChange((newState: Partial) => { + if (newState.isConnected === false) + lastDisconnectReasonRef.current = newState.disconnectReason || newState.error || null + if (newState.isConnected === true) + lastDisconnectReasonRef.current = null + + if (newState.isConnected === undefined) + return + + setState(prev => ({ ...prev, isConnected: newState.isConnected ?? 
prev.isConnected })) + }) + + const unsubscribeCursors = collaborationManager.onCursorUpdate((cursors: Record) => { + setState(prev => ({ ...prev, cursors })) + }) + + const unsubscribeUsers = collaborationManager.onOnlineUsersUpdate((users: OnlineUser[]) => { + setState(prev => ({ ...prev, onlineUsers: users })) + }) + + const unsubscribeNodePanelPresence = collaborationManager.onNodePanelPresenceUpdate((presence: NodePanelPresenceMap) => { + setState(prev => ({ ...prev, nodePanelPresence: presence })) + }) + + const unsubscribeLeaderChange = collaborationManager.onLeaderChange((isLeader: boolean) => { + setState(prev => ({ ...prev, isLeader })) + }) + + return () => { + isUnmounted = true + unsubscribeStateChange() + unsubscribeCursors() + unsubscribeUsers() + unsubscribeNodePanelPresence() + unsubscribeLeaderChange() + cursorServiceRef.current?.stopTracking() + if (connectionId) + collaborationManager.disconnect(connectionId) + } + }, [appId, reactFlowStore, isCollaborationEnabled]) + + const prevIsConnected = useRef(false) + useEffect(() => { + if (prevIsConnected.current && !state.isConnected) { + const reason = lastDisconnectReasonRef.current + if (reason) + console.warn('WebSocket disconnected:', reason) + else + console.warn('WebSocket disconnected.') + } + prevIsConnected.current = state.isConnected || false + }, [state.isConnected]) + + const startCursorTracking = (containerRef: React.RefObject, reactFlowInstance?: ReactFlowInstance) => { + if (!isCollaborationEnabled || !cursorServiceRef.current) + return + + if (cursorServiceRef.current) { + cursorServiceRef.current.startTracking(containerRef, (position) => { + collaborationManager.emitCursorMove(position) + }, reactFlowInstance) + } + } + + const stopCursorTracking = () => { + cursorServiceRef.current?.stopTracking() + } + + const result = { + isConnected: state.isConnected || false, + onlineUsers: state.onlineUsers || [], + cursors: state.cursors || {}, + nodePanelPresence: state.nodePanelPresence || 
{}, + isLeader: state.isLeader || false, + leaderId: collaborationManager.getLeaderId(), + isEnabled: isCollaborationEnabled, + startCursorTracking, + stopCursorTracking, + } + + return result +} diff --git a/web/app/components/workflow/collaboration/services/__tests__/cursor-service.spec.ts b/web/app/components/workflow/collaboration/services/__tests__/cursor-service.spec.ts new file mode 100644 index 0000000000..239435bec0 --- /dev/null +++ b/web/app/components/workflow/collaboration/services/__tests__/cursor-service.spec.ts @@ -0,0 +1,86 @@ +import type { ReactFlowInstance } from 'reactflow' +import { CursorService } from '../cursor-service' + +describe('CursorService', () => { + let service: CursorService + let container: HTMLDivElement + let now = 0 + + beforeEach(() => { + service = new CursorService() + container = document.createElement('div') + document.body.appendChild(container) + vi.spyOn(container, 'getBoundingClientRect').mockReturnValue({ + x: 10, + y: 20, + top: 20, + left: 10, + right: 410, + bottom: 220, + width: 400, + height: 200, + toJSON: () => ({}), + } as DOMRect) + now = 1000 + vi.spyOn(Date, 'now').mockImplementation(() => now) + }) + + afterEach(() => { + vi.restoreAllMocks() + container.remove() + }) + + it('emits transformed positions with throttle and distance guard', () => { + const onEmit = vi.fn() + const reactFlow = { + getViewport: () => ({ x: 5, y: 10, zoom: 2 }), + getZoom: () => 2, + } as unknown as ReactFlowInstance + + service.startTracking({ current: container }, onEmit, reactFlow) + + container.dispatchEvent(new MouseEvent('mousemove', { clientX: 30, clientY: 50 })) + expect(onEmit).toHaveBeenCalledTimes(1) + expect(onEmit).toHaveBeenLastCalledWith(expect.objectContaining({ + x: 7.5, + y: 10, + timestamp: 1000, + })) + + now = 1100 + container.dispatchEvent(new MouseEvent('mousemove', { clientX: 40, clientY: 60 })) + expect(onEmit).toHaveBeenCalledTimes(1) + + now = 1401 + container.dispatchEvent(new MouseEvent('mousemove', 
{ clientX: 33, clientY: 53 })) + expect(onEmit).toHaveBeenCalledTimes(1) + + now = 1800 + container.dispatchEvent(new MouseEvent('mousemove', { clientX: 60, clientY: 90 })) + expect(onEmit).toHaveBeenCalledTimes(2) + expect(onEmit).toHaveBeenLastCalledWith(expect.objectContaining({ + x: 22.5, + y: 30, + timestamp: 1800, + })) + }) + + it('stops tracking and forwards cursor updates to registered handler', () => { + const onEmit = vi.fn() + const onCursorUpdate = vi.fn() + service.startTracking({ current: container }, onEmit) + service.setCursorUpdateHandler(onCursorUpdate) + + service.updateCursors({ + u1: { x: 1, y: 2, userId: 'u1', timestamp: 1 }, + }) + expect(onCursorUpdate).toHaveBeenCalledWith({ + u1: { x: 1, y: 2, userId: 'u1', timestamp: 1 }, + }) + + service.stopTracking() + now = 2000 + container.dispatchEvent(new MouseEvent('mousemove', { clientX: 40, clientY: 60 })) + expect(onEmit).not.toHaveBeenCalled() + }) +}) diff --git a/web/app/components/workflow/collaboration/services/cursor-service.ts b/web/app/components/workflow/collaboration/services/cursor-service.ts new file mode 100644 index 0000000000..7af6f2f27f --- /dev/null +++ b/web/app/components/workflow/collaboration/services/cursor-service.ts @@ -0,0 +1,90 @@ +import type { RefObject } from 'react' +import type { ReactFlowInstance } from 'reactflow' +import type { CursorPosition } from '../types/collaboration' + +const CURSOR_MIN_MOVE_DISTANCE = 10 +const CURSOR_THROTTLE_MS = 300 + +export class CursorService { + private containerRef: RefObject | null = null + private reactFlowInstance: ReactFlowInstance | null = null + private isTracking = false + private onCursorUpdate: ((cursors: Record) => void) | null = null + private onEmitPosition: ((position: CursorPosition) => void) | null = null + private lastEmitTime = 0 + private lastPosition: { x: number, y: number } | null = null + + startTracking( + containerRef: RefObject, + onEmitPosition: (position: CursorPosition) => void, + reactFlowInstance?: 
ReactFlowInstance, + ): void { + if (this.isTracking) + this.stopTracking() + + this.containerRef = containerRef + this.onEmitPosition = onEmitPosition + this.reactFlowInstance = reactFlowInstance || null + this.isTracking = true + + if (containerRef.current) + containerRef.current.addEventListener('mousemove', this.handleMouseMove) + } + + stopTracking(): void { + if (this.containerRef?.current) + this.containerRef.current.removeEventListener('mousemove', this.handleMouseMove) + + this.containerRef = null + this.reactFlowInstance = null + this.onEmitPosition = null + this.isTracking = false + this.lastPosition = null + } + + setCursorUpdateHandler(handler: (cursors: Record) => void): void { + this.onCursorUpdate = handler + } + + updateCursors(cursors: Record): void { + if (this.onCursorUpdate) + this.onCursorUpdate(cursors) + } + + private handleMouseMove = (event: MouseEvent): void => { + if (!this.containerRef?.current || !this.onEmitPosition) + return + + const rect = this.containerRef.current.getBoundingClientRect() + let x = event.clientX - rect.left + let y = event.clientY - rect.top + + // Transform coordinates to ReactFlow world coordinates if ReactFlow instance is available + if (this.reactFlowInstance) { + const viewport = this.reactFlowInstance.getViewport() + // Convert screen coordinates to world coordinates + // World coordinates = (screen coordinates - viewport translation) / zoom + x = (x - viewport.x) / viewport.zoom + y = (y - viewport.y) / viewport.zoom + } + + // Always emit cursor position (remove boundary check since world coordinates can be negative) + const now = Date.now() + const timeThrottled = now - this.lastEmitTime > CURSOR_THROTTLE_MS + const minDistance = CURSOR_MIN_MOVE_DISTANCE / (this.reactFlowInstance?.getZoom() || 1) + const distanceThrottled = !this.lastPosition + || (Math.abs(x - this.lastPosition.x) > minDistance) + || (Math.abs(y - this.lastPosition.y) > minDistance) + + if (timeThrottled && distanceThrottled) { + 
this.lastPosition = { x, y } + this.lastEmitTime = now + this.onEmitPosition({ + x, + y, + userId: '', + timestamp: now, + }) + } + } +} diff --git a/web/app/components/workflow/collaboration/types/collaboration.ts b/web/app/components/workflow/collaboration/types/collaboration.ts new file mode 100644 index 0000000000..ae355a7b51 --- /dev/null +++ b/web/app/components/workflow/collaboration/types/collaboration.ts @@ -0,0 +1,92 @@ +import type { Viewport } from 'reactflow' +import type { ConversationVariable, Edge, EnvironmentVariable, Node } from '../../types' +import type { Features } from '@/app/components/base/features/types' + +export type OnlineUser = { + user_id: string + username: string + avatar: string + sid: string +} + +export type CursorPosition = { + x: number + y: number + userId: string + timestamp: number +} + +export type NodePanelPresenceUser = { + userId: string + username: string + avatar?: string | null +} + +export type NodePanelPresenceInfo = NodePanelPresenceUser & { + clientId: string + timestamp: number +} + +export type NodePanelPresenceMap = Record> + +export type CollaborationState = { + appId: string + isConnected: boolean + onlineUsers: OnlineUser[] + cursors: Record + nodePanelPresence: NodePanelPresenceMap + disconnectReason?: string + error?: string +} + +export type CollaborationEventType + = | 'mouse_move' + | 'vars_and_features_update' + | 'sync_request' + | 'app_state_update' + | 'app_meta_update' + | 'mcp_server_update' + | 'workflow_update' + | 'comments_update' + | 'node_panel_presence' + | 'app_publish_update' + | 'graph_resync_request' + | 'workflow_restore_request' + | 'workflow_restore_intent' + | 'workflow_restore_complete' + | 'workflow_history_action' + +export type CollaborationUpdate = { + type: CollaborationEventType + userId: string + data: Record + timestamp: number +} + +export type RestoreRequestData = { + versionId: string + versionName?: string + initiatorUserId: string + initiatorName: string + graphData: { 
+ nodes: Node[] + edges: Edge[] + viewport?: Viewport + } + features?: Features + environmentVariables?: EnvironmentVariable[] + conversationVariables?: ConversationVariable[] +} + +export type RestoreIntentData = { + versionId: string + versionName?: string + initiatorUserId: string + initiatorName: string +} + +export type RestoreCompleteData = { + versionId: string + success: boolean + error?: string +} diff --git a/web/app/components/workflow/collaboration/types/websocket.ts b/web/app/components/workflow/collaboration/types/websocket.ts new file mode 100644 index 0000000000..dd89df323f --- /dev/null +++ b/web/app/components/workflow/collaboration/types/websocket.ts @@ -0,0 +1,15 @@ +export type WebSocketConfig = { + token?: string + transports?: string[] + withCredentials?: boolean +} + +export type ConnectionInfo = { + connected: boolean + connecting: boolean + socketId?: string +} + +export type DebugInfo = { + [appId: string]: ConnectionInfo +} diff --git a/web/app/components/workflow/collaboration/utils/user-color.ts b/web/app/components/workflow/collaboration/utils/user-color.ts new file mode 100644 index 0000000000..51aee6a038 --- /dev/null +++ b/web/app/components/workflow/collaboration/utils/user-color.ts @@ -0,0 +1,12 @@ +/** + * Generate a consistent color for a user based on their ID + * Used for cursor colors and avatar backgrounds + */ +export const getUserColor = (id: string): string => { + const colors = ['#155AEF', '#0BA5EC', '#444CE7', '#7839EE', '#4CA30D', '#0E9384', '#DD2590', '#FF4405', '#D92D20', '#F79009', '#828DAD'] + const hash = id.split('').reduce((a, b) => { + a = ((a << 5) - a) + b.charCodeAt(0) + return a & a + }, 0) + return colors[Math.abs(hash) % colors.length] +} diff --git a/web/app/components/workflow/comment-manager.tsx b/web/app/components/workflow/comment-manager.tsx new file mode 100644 index 0000000000..41175ae04d --- /dev/null +++ b/web/app/components/workflow/comment-manager.tsx @@ -0,0 +1,55 @@ +import { 
useEventListener } from 'ahooks' +import { useWorkflowComment } from './hooks/use-workflow-comment' +import { useWorkflowStore } from './store' + +const CommentManager = () => { + const workflowStore = useWorkflowStore() + const { handleCreateComment, handleCommentCancel } = useWorkflowComment() + + useEventListener('click', (e) => { + const { controlMode, mousePosition, pendingComment, isCommentPlacing } = workflowStore.getState() + const target = e.target as HTMLElement + const isInDropdown = target.closest('[data-mention-dropdown]') + const isInCommentInput = target.closest('[data-comment-input]') + const isOnCanvasPane = target.closest('.react-flow__pane') + + if (isCommentPlacing) { + if (!isInDropdown && !isInCommentInput && isOnCanvasPane) { + e.preventDefault() + e.stopPropagation() + workflowStore.setState({ + pendingComment: mousePosition, + isCommentPlacing: false, + }) + } + return + } + + if (controlMode === 'comment') { + // Only when clicking on the React Flow canvas pane (background), + // and not inside comment input or its dropdown + if (!isInDropdown && !isInCommentInput && isOnCanvasPane) { + e.preventDefault() + e.stopPropagation() + if (pendingComment) + handleCommentCancel() + else + handleCreateComment(mousePosition) + } + } + }) + + useEventListener('contextmenu', () => { + const { isCommentPlacing } = workflowStore.getState() + if (!isCommentPlacing) + return + workflowStore.setState({ + isCommentPlacing: false, + isCommentQuickAdd: false, + }) + }) + + return null +} + +export default CommentManager diff --git a/web/app/components/workflow/comment/comment-icon.spec.tsx b/web/app/components/workflow/comment/comment-icon.spec.tsx new file mode 100644 index 0000000000..aee8c64fa3 --- /dev/null +++ b/web/app/components/workflow/comment/comment-icon.spec.tsx @@ -0,0 +1,148 @@ +import type { WorkflowCommentList } from '@/service/workflow-comment' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, 
expect, it, vi } from 'vitest' +import { CommentIcon } from './comment-icon' + +type Position = { x: number, y: number } + +let mockUserId = 'user-1' + +const mockFlowToScreenPosition = vi.fn((position: Position) => position) +const mockScreenToFlowPosition = vi.fn((position: Position) => position) + +vi.mock('reactflow', () => ({ + useReactFlow: () => ({ + flowToScreenPosition: mockFlowToScreenPosition, + screenToFlowPosition: mockScreenToFlowPosition, + }), + useViewport: () => ({ + x: 0, + y: 0, + zoom: 1, + }), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + userProfile: { + id: mockUserId, + name: 'User', + avatar_url: 'avatar', + }, + }), +})) + +vi.mock('@/app/components/base/user-avatar-list', () => ({ + UserAvatarList: ({ users }: { users: Array<{ id: string }> }) => ( +
{users.map(user => user.id).join(',')}
+ ), +})) + +vi.mock('./comment-preview', () => ({ + default: ({ onClick }: { onClick?: () => void }) => ( + + ), +})) + +const createComment = (overrides: Partial = {}): WorkflowCommentList => ({ + id: 'comment-1', + position_x: 0, + position_y: 0, + content: 'Hello', + created_by: 'user-1', + created_by_account: { + id: 'user-1', + name: 'Alice', + email: 'alice@example.com', + }, + created_at: 1, + updated_at: 2, + resolved: false, + mention_count: 0, + reply_count: 0, + participants: [], + ...overrides, +}) + +describe('CommentIcon', () => { + beforeEach(() => { + vi.clearAllMocks() + mockUserId = 'user-1' + }) + + it('toggles preview on hover when inactive', () => { + const comment = createComment() + const { container } = render( + , + ) + const marker = container.querySelector('[data-role="comment-marker"]') as HTMLElement + const hoverTarget = marker.firstElementChild as HTMLElement + + fireEvent.mouseEnter(hoverTarget) + expect(screen.getByTestId('comment-preview')).toBeInTheDocument() + + fireEvent.mouseLeave(hoverTarget) + expect(screen.queryByTestId('comment-preview')).not.toBeInTheDocument() + }) + + it('calls onPositionUpdate after dragging by author', () => { + const comment = createComment({ position_x: 0, position_y: 0 }) + const onClick = vi.fn() + const onPositionUpdate = vi.fn() + const { container } = render( + , + ) + const marker = container.querySelector('[data-role="comment-marker"]') as HTMLElement + + fireEvent.pointerDown(marker, { + pointerId: 1, + button: 0, + clientX: 100, + clientY: 100, + }) + fireEvent.pointerMove(marker, { + pointerId: 1, + clientX: 110, + clientY: 110, + }) + fireEvent.pointerUp(marker, { + pointerId: 1, + clientX: 110, + clientY: 110, + }) + + expect(mockScreenToFlowPosition).toHaveBeenCalledWith({ x: 10, y: 10 }) + expect(onPositionUpdate).toHaveBeenCalledWith({ x: 10, y: 10 }) + expect(onClick).not.toHaveBeenCalled() + }) + + it('calls onClick for non-author clicks', () => { + mockUserId = 'user-2' + const 
comment = createComment() + const onClick = vi.fn() + const { container } = render( + , + ) + const marker = container.querySelector('[data-role="comment-marker"]') as HTMLElement + + fireEvent.pointerDown(marker, { + pointerId: 1, + button: 0, + clientX: 50, + clientY: 60, + }) + fireEvent.pointerUp(marker, { + pointerId: 1, + clientX: 50, + clientY: 60, + }) + + expect(onClick).toHaveBeenCalledTimes(1) + }) +}) diff --git a/web/app/components/workflow/comment/comment-icon.tsx b/web/app/components/workflow/comment/comment-icon.tsx new file mode 100644 index 0000000000..7f005f3465 --- /dev/null +++ b/web/app/components/workflow/comment/comment-icon.tsx @@ -0,0 +1,269 @@ +'use client' + +import type { FC, PointerEvent as ReactPointerEvent } from 'react' +import type { WorkflowCommentList } from '@/service/workflow-comment' +import { memo, useCallback, useMemo, useRef, useState } from 'react' +import { useReactFlow, useViewport } from 'reactflow' +import { UserAvatarList } from '@/app/components/base/user-avatar-list' +import { useAppContext } from '@/context/app-context' +import CommentPreview from './comment-preview' + +type CommentIconProps = { + comment: WorkflowCommentList + onClick: () => void + isActive?: boolean + onPositionUpdate?: (position: { x: number, y: number }) => void +} + +export const CommentIcon: FC = memo(({ comment, onClick, isActive = false, onPositionUpdate }) => { + const { flowToScreenPosition, screenToFlowPosition } = useReactFlow() + const viewport = useViewport() + const { userProfile } = useAppContext() + const isAuthor = comment.created_by_account?.id === userProfile?.id + const [showPreview, setShowPreview] = useState(false) + const [dragPosition, setDragPosition] = useState<{ x: number, y: number } | null>(null) + const [isDragging, setIsDragging] = useState(false) + const dragStateRef = useRef<{ + offsetX: number + offsetY: number + startX: number + startY: number + hasMoved: boolean + } | null>(null) + + const workflowContainerRect 
= typeof document !== 'undefined' + ? document.getElementById('workflow-container')?.getBoundingClientRect() + : null + const containerLeft = workflowContainerRect?.left ?? 0 + const containerTop = workflowContainerRect?.top ?? 0 + + const screenPosition = useMemo(() => { + return flowToScreenPosition({ + x: comment.position_x, + y: comment.position_y, + }) + }, [comment.position_x, comment.position_y, viewport.x, viewport.y, viewport.zoom, flowToScreenPosition]) + + const effectiveScreenPosition = dragPosition ?? screenPosition + const canvasPosition = useMemo(() => ({ + x: effectiveScreenPosition.x - containerLeft, + y: effectiveScreenPosition.y - containerTop, + }), [effectiveScreenPosition.x, effectiveScreenPosition.y, containerLeft, containerTop]) + const cursorClass = useMemo(() => { + if (!isAuthor) + return 'cursor-pointer' + if (isActive) + return isDragging ? 'cursor-grabbing' : '' + return isDragging ? 'cursor-grabbing' : 'cursor-pointer' + }, [isActive, isAuthor, isDragging]) + + const handlePointerDown = useCallback((event: ReactPointerEvent) => { + if (event.button !== 0) + return + + event.stopPropagation() + event.preventDefault() + + if (!isAuthor) { + if (event.currentTarget.dataset.role !== 'comment-preview') + setShowPreview(false) + return + } + + dragStateRef.current = { + offsetX: event.clientX - screenPosition.x, + offsetY: event.clientY - screenPosition.y, + startX: event.clientX, + startY: event.clientY, + hasMoved: false, + } + + setDragPosition(screenPosition) + setIsDragging(false) + + if (event.currentTarget.dataset.role !== 'comment-preview') + setShowPreview(false) + + if (event.currentTarget.setPointerCapture) + event.currentTarget.setPointerCapture(event.pointerId) + }, [isAuthor, screenPosition]) + + const handlePointerMove = useCallback((event: ReactPointerEvent) => { + const dragState = dragStateRef.current + if (!dragState) + return + + event.stopPropagation() + event.preventDefault() + + const nextX = event.clientX - 
dragState.offsetX + const nextY = event.clientY - dragState.offsetY + + if (!dragState.hasMoved) { + const distance = Math.hypot(event.clientX - dragState.startX, event.clientY - dragState.startY) + if (distance > 4) { + dragState.hasMoved = true + setIsDragging(true) + } + } + + setDragPosition({ x: nextX, y: nextY }) + }, []) + + const finishDrag = useCallback((event: ReactPointerEvent) => { + const dragState = dragStateRef.current + if (!dragState) + return false + + if (event.currentTarget.hasPointerCapture?.(event.pointerId)) + event.currentTarget.releasePointerCapture(event.pointerId) + + dragStateRef.current = null + setDragPosition(null) + setIsDragging(false) + return dragState.hasMoved + }, []) + + const handlePointerUp = useCallback((event: ReactPointerEvent) => { + event.stopPropagation() + event.preventDefault() + + const finalScreenPosition = dragPosition ?? screenPosition + const didDrag = finishDrag(event) + + setShowPreview(false) + + if (didDrag) { + if (onPositionUpdate) { + const flowPosition = screenToFlowPosition({ + x: finalScreenPosition.x, + y: finalScreenPosition.y, + }) + onPositionUpdate(flowPosition) + } + } + else if (!isActive) { + onClick() + } + }, [dragPosition, finishDrag, isActive, onClick, onPositionUpdate, screenPosition, screenToFlowPosition]) + + const handlePointerCancel = useCallback((event: ReactPointerEvent) => { + event.stopPropagation() + event.preventDefault() + finishDrag(event) + }, [finishDrag]) + + const handleMouseEnter = useCallback(() => { + if (isActive || isDragging) + return + setShowPreview(true) + }, [isActive, isDragging]) + + const handleMouseLeave = useCallback(() => { + setShowPreview(false) + }, []) + + const participants = useMemo(() => { + const list = comment.participants ?? 
[] + const author = comment.created_by_account + if (!author) + return [...list] + const rest = list.filter(user => user.id !== author.id) + return [author, ...rest] + }, [comment.created_by_account, comment.participants]) + + // Calculate dynamic width based on number of participants + const participantCount = participants.length + const maxVisible = Math.min(3, participantCount) + const showCount = participantCount > 3 + const avatarSize = 24 + const avatarSpacing = 4 // -space-x-1 is about 4px overlap + + // Width calculation: first avatar + (additional avatars * (size - spacing)) + padding + const dynamicWidth = Math.max(40, // minimum width + 8 + avatarSize + Math.max(0, (showCount ? 2 : maxVisible - 1)) * (avatarSize - avatarSpacing) + 8) + + const pointerEventHandlers = useMemo(() => ({ + onPointerDown: handlePointerDown, + onPointerMove: handlePointerMove, + onPointerUp: handlePointerUp, + onPointerCancel: handlePointerCancel, + }), [handlePointerCancel, handlePointerDown, handlePointerMove, handlePointerUp]) + + return ( + <> +
+
+
+
+
+ +
+
+
+
+
+ + {/* Preview panel */} + {showPreview && !isActive && ( +
setShowPreview(true)} + onMouseLeave={() => setShowPreview(false)} + > + { + setShowPreview(false) + onClick() + }} + /> +
+ )} + + ) +}, (prevProps, nextProps) => { + return ( + prevProps.comment.id === nextProps.comment.id + && prevProps.comment.position_x === nextProps.comment.position_x + && prevProps.comment.position_y === nextProps.comment.position_y + && prevProps.onClick === nextProps.onClick + && prevProps.isActive === nextProps.isActive + && prevProps.onPositionUpdate === nextProps.onPositionUpdate + ) +}) + +CommentIcon.displayName = 'CommentIcon' diff --git a/web/app/components/workflow/comment/comment-input.spec.tsx b/web/app/components/workflow/comment/comment-input.spec.tsx new file mode 100644 index 0000000000..41e583aeed --- /dev/null +++ b/web/app/components/workflow/comment/comment-input.spec.tsx @@ -0,0 +1,109 @@ +import type { FC } from 'react' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { CommentInput } from './comment-input' + +type MentionInputProps = { + value: string + onChange: (value: string) => void + onSubmit: (content: string, mentionedUserIds: string[]) => void + placeholder?: string + disabled?: boolean + autoFocus?: boolean + className?: string +} + +const stableT = (key: string, options?: { ns?: string }) => ( + options?.ns ? `${options.ns}.${key}` : key +) + +let mentionInputProps: MentionInputProps | null = null + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: stableT, + }), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + userProfile: { + id: 'user-1', + name: 'Alice', + avatar_url: 'avatar', + }, + }), +})) + +vi.mock('@/app/components/base/ui/avatar', () => ({ + Avatar: ({ name }: { name: string }) =>
{name}
, + default: ({ name }: { name: string }) =>
{name}
, +})) + +vi.mock('./mention-input', () => ({ + MentionInput: ((props: MentionInputProps) => { + mentionInputProps = props + return ( + + ) + }) as FC, +})) + +describe('CommentInput', () => { + beforeEach(() => { + vi.clearAllMocks() + mentionInputProps = null + }) + + it('passes translated placeholder to mention input', () => { + render( + , + ) + + expect(mentionInputProps?.placeholder).toBe('workflow.comments.placeholder.add') + expect(mentionInputProps?.autoFocus).toBe(true) + expect(mentionInputProps?.disabled).toBe(false) + }) + + it('calls onCancel when Escape is pressed', () => { + const onCancel = vi.fn() + + render( + , + ) + + fireEvent.keyDown(document, { key: 'Escape' }) + + expect(onCancel).toHaveBeenCalledTimes(1) + }) + + it('forwards mention submit to onSubmit', () => { + const onSubmit = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByTestId('mention-input')) + + expect(onSubmit).toHaveBeenCalledWith('Hello', ['user-2']) + }) +}) diff --git a/web/app/components/workflow/comment/comment-input.tsx b/web/app/components/workflow/comment/comment-input.tsx new file mode 100644 index 0000000000..51cef610c4 --- /dev/null +++ b/web/app/components/workflow/comment/comment-input.tsx @@ -0,0 +1,184 @@ +import type { FC, PointerEvent as ReactPointerEvent } from 'react' +import { cn } from '@langgenius/dify-ui/cn' +import { memo, useCallback, useEffect, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { Avatar } from '@/app/components/base/ui/avatar' +import { useAppContext } from '@/context/app-context' +import { MentionInput } from './mention-input' + +type CommentInputProps = { + position: { x: number, y: number } + onSubmit: (content: string, mentionedUserIds: string[]) => void + onCancel: () => void + autoFocus?: boolean + disabled?: boolean + onPositionChange?: (position: { + pageX: number + pageY: number + elementX: number + elementY: number + }) => void +} + +export const CommentInput: FC = memo(({ + 
position, + onSubmit, + onCancel, + autoFocus = true, + disabled = false, + onPositionChange, +}) => { + const [content, setContent] = useState('') + const { t } = useTranslation() + const { userProfile } = useAppContext() + const dragStateRef = useRef<{ + pointerId: number | null + startPointerX: number + startPointerY: number + startX: number + startY: number + active: boolean + } & { + endHandler?: (event: PointerEvent) => void + }>({ + pointerId: null, + startPointerX: 0, + startPointerY: 0, + startX: 0, + startY: 0, + active: false, + endHandler: undefined, + }) + + useEffect(() => { + const handleGlobalKeyDown = (e: KeyboardEvent) => { + if (e.key === 'Escape') { + e.preventDefault() + e.stopPropagation() + onCancel() + } + } + + document.addEventListener('keydown', handleGlobalKeyDown, true) + return () => { + document.removeEventListener('keydown', handleGlobalKeyDown, true) + } + }, [onCancel]) + + const handleMentionSubmit = useCallback((content: string, mentionedUserIds: string[]) => { + onSubmit(content, mentionedUserIds) + setContent('') + }, [onSubmit]) + + const handleDragPointerMove = useCallback((event: PointerEvent) => { + const state = dragStateRef.current + if (!state.active || (state.pointerId !== null && event.pointerId !== state.pointerId)) + return + if (!onPositionChange) + return + event.preventDefault() + const deltaX = event.clientX - state.startPointerX + const deltaY = event.clientY - state.startPointerY + onPositionChange({ + pageX: event.clientX, + pageY: event.clientY, + elementX: state.startX + deltaX, + elementY: state.startY + deltaY, + }) + }, [onPositionChange]) + + const stopDragging = useCallback((event?: PointerEvent) => { + const state = dragStateRef.current + if (!state.active) + return + if (event && state.pointerId !== null && event.pointerId !== state.pointerId) + return + state.active = false + state.pointerId = null + window.removeEventListener('pointermove', handleDragPointerMove) + if (state.endHandler) { + 
window.removeEventListener('pointerup', state.endHandler) + window.removeEventListener('pointercancel', state.endHandler) + state.endHandler = undefined + } + }, [handleDragPointerMove]) + + const handleDragPointerDown = useCallback((event: ReactPointerEvent) => { + if (event.button !== 0) + return + event.stopPropagation() + event.preventDefault() + if (!onPositionChange) + return + const endHandler = (pointerEvent: PointerEvent) => { + stopDragging(pointerEvent) + } + dragStateRef.current = { + pointerId: event.pointerId, + startPointerX: event.clientX, + startPointerY: event.clientY, + startX: position.x, + startY: position.y, + active: true, + endHandler, + } + window.addEventListener('pointermove', handleDragPointerMove, { passive: false }) + window.addEventListener('pointerup', endHandler) + window.addEventListener('pointercancel', endHandler) + }, [handleDragPointerMove, onPositionChange, position.x, position.y, stopDragging]) + + useEffect(() => () => { + stopDragging() + }, [stopDragging]) + + return ( +
+
+
+
+
+ +
+
+
+
+
+ +
+
+
+
+ ) +}) + +CommentInput.displayName = 'CommentInput' diff --git a/web/app/components/workflow/comment/comment-preview.spec.tsx b/web/app/components/workflow/comment/comment-preview.spec.tsx new file mode 100644 index 0000000000..d411c67ecd --- /dev/null +++ b/web/app/components/workflow/comment/comment-preview.spec.tsx @@ -0,0 +1,86 @@ +import type { WorkflowCommentList } from '@/service/workflow-comment' +import { fireEvent, render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import CommentPreview from './comment-preview' + +type UserProfile = WorkflowCommentList['created_by_account'] + +const mockSetHovering = vi.fn() +let capturedUsers: UserProfile[] = [] + +vi.mock('@/app/components/base/user-avatar-list', () => ({ + UserAvatarList: ({ users }: { users: UserProfile[] }) => { + capturedUsers = users + return
{users.map(user => user.id).join(',')}
+ }, +})) + +vi.mock('@/hooks/use-format-time-from-now', () => ({ + useFormatTimeFromNow: () => ({ + formatTimeFromNow: (value: number) => `time:${value}`, + }), +})) + +vi.mock('../store', () => ({ + useStore: (selector: (state: { setCommentPreviewHovering: (value: boolean) => void }) => unknown) => + selector({ setCommentPreviewHovering: mockSetHovering }), +})) + +const createComment = (overrides: Partial = {}): WorkflowCommentList => { + const author = { id: 'user-1', name: 'Alice', email: 'alice@example.com' } + const participant = { id: 'user-2', name: 'Bob', email: 'bob@example.com' } + + return { + id: 'comment-1', + position_x: 0, + position_y: 0, + content: 'Hello', + created_by: author.id, + created_by_account: author, + created_at: 1, + updated_at: 10, + resolved: false, + mention_count: 0, + reply_count: 0, + participants: [author, participant], + ...overrides, + } +} + +describe('CommentPreview', () => { + beforeEach(() => { + vi.clearAllMocks() + capturedUsers = [] + }) + + it('orders participants with author first and formats time', () => { + const comment = createComment() + + render() + + expect(capturedUsers.map(user => user.id)).toEqual(['user-1', 'user-2']) + expect(screen.getByText('Hello')).toBeInTheDocument() + expect(screen.getByText('time:10000')).toBeInTheDocument() + }) + + it('updates hover state on enter and leave', () => { + const comment = createComment() + const { container } = render() + const root = container.firstElementChild as HTMLElement + + fireEvent.mouseEnter(root) + fireEvent.mouseLeave(root) + + expect(mockSetHovering).toHaveBeenCalledWith(true) + expect(mockSetHovering).toHaveBeenCalledWith(false) + }) + + it('clears hover state on unmount', () => { + const comment = createComment() + const { unmount } = render() + + unmount() + + expect(mockSetHovering).toHaveBeenCalledWith(false) + }) +}) diff --git a/web/app/components/workflow/comment/comment-preview.tsx b/web/app/components/workflow/comment/comment-preview.tsx new 
file mode 100644 index 0000000000..5985ed848b --- /dev/null +++ b/web/app/components/workflow/comment/comment-preview.tsx @@ -0,0 +1,59 @@ +'use client' + +import type { FC } from 'react' +import type { WorkflowCommentList } from '@/service/workflow-comment' +import { memo, useEffect, useMemo } from 'react' +import { UserAvatarList } from '@/app/components/base/user-avatar-list' +import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { useStore } from '../store' + +type CommentPreviewProps = { + comment: WorkflowCommentList + onClick?: () => void +} + +const CommentPreview: FC = ({ comment, onClick }) => { + const { formatTimeFromNow } = useFormatTimeFromNow() + const setCommentPreviewHovering = useStore(s => s.setCommentPreviewHovering) + const participants = useMemo(() => { + const list = comment.participants ?? [] + const author = comment.created_by_account + if (!author) + return [...list] + const rest = list.filter(user => user.id !== author.id) + return [author, ...rest] + }, [comment.created_by_account, comment.participants]) + useEffect(() => () => { + setCommentPreviewHovering(false) + }, [setCommentPreviewHovering]) + + return ( +
setCommentPreviewHovering(true)} + onMouseLeave={() => setCommentPreviewHovering(false)} + > +
+ +
+ +
+
+
{comment.created_by_account.name}
+
+ {formatTimeFromNow(comment.updated_at * 1000)} +
+
+
+ +
{comment.content}
+
+ ) +} + +export default memo(CommentPreview) diff --git a/web/app/components/workflow/comment/cursor.spec.tsx b/web/app/components/workflow/comment/cursor.spec.tsx new file mode 100644 index 0000000000..c6cd64c0aa --- /dev/null +++ b/web/app/components/workflow/comment/cursor.spec.tsx @@ -0,0 +1,56 @@ +import { render, screen } from '@testing-library/react' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { ControlMode } from '../types' +import { CommentCursor } from './cursor' + +const mockState = { + controlMode: ControlMode.Pointer, + isCommentPlacing: false, + mousePosition: { + elementX: 10, + elementY: 20, + }, +} + +vi.mock('@/app/components/base/icons/src/public/other', () => ({ + Comment: (props: { className?: string }) => , +})) + +vi.mock('../store', () => ({ + useStore: (selector: (state: typeof mockState) => unknown) => selector(mockState), +})) + +describe('CommentCursor', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('renders nothing when not in comment mode', () => { + mockState.controlMode = ControlMode.Pointer + + render() + + expect(screen.queryByTestId('comment-icon')).not.toBeInTheDocument() + }) + + it('renders at current mouse position when in comment mode', () => { + mockState.controlMode = ControlMode.Comment + mockState.isCommentPlacing = false + + render() + + const icon = screen.getByTestId('comment-icon') + const container = icon.parentElement as HTMLElement + + expect(container).toHaveStyle({ left: '10px', top: '20px' }) + }) + + it('renders nothing when comment is in placing mode', () => { + mockState.controlMode = ControlMode.Comment + mockState.isCommentPlacing = true + + render() + + expect(screen.queryByTestId('comment-icon')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/workflow/comment/cursor.tsx b/web/app/components/workflow/comment/cursor.tsx new file mode 100644 index 0000000000..02f09f7d92 --- /dev/null +++ b/web/app/components/workflow/comment/cursor.tsx @@ -0,0 
+1,29 @@ +import type { FC } from 'react' +import { memo } from 'react' +import { Comment } from '@/app/components/base/icons/src/public/other' +import { useStore } from '../store' +import { ControlMode } from '../types' + +export const CommentCursor: FC = memo(() => { + const controlMode = useStore(s => s.controlMode) + const mousePosition = useStore(s => s.mousePosition) + const isCommentPlacing = useStore(s => s.isCommentPlacing) + + if (controlMode !== ControlMode.Comment || isCommentPlacing) + return null + + return ( +
+ +
+ ) +}) + +CommentCursor.displayName = 'CommentCursor' diff --git a/web/app/components/workflow/comment/mention-input.spec.tsx b/web/app/components/workflow/comment/mention-input.spec.tsx new file mode 100644 index 0000000000..5ff211d1f2 --- /dev/null +++ b/web/app/components/workflow/comment/mention-input.spec.tsx @@ -0,0 +1,151 @@ +import type { UserProfile } from '@/service/workflow-comment' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import { useState } from 'react' +import { MentionInput } from './mention-input' + +const mockFetchMentionableUsers = vi.hoisted(() => vi.fn()) +const mockSetMentionableUsersLoading = vi.hoisted(() => vi.fn()) +const mockSetMentionableUsersCache = vi.hoisted(() => vi.fn()) + +const mentionStoreState = vi.hoisted(() => ({ + mentionableUsersCache: {} as Record, + mentionableUsersLoading: {} as Record, + setMentionableUsersLoading: (appId: string, loading: boolean) => { + mockSetMentionableUsersLoading(appId, loading) + mentionStoreState.mentionableUsersLoading[appId] = loading + }, + setMentionableUsersCache: (appId: string, users: UserProfile[]) => { + mockSetMentionableUsersCache(appId, users) + mentionStoreState.mentionableUsersCache[appId] = users + }, +})) + +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, options?: { ns?: string }) => options?.ns ? `${options.ns}.${key}` : key, + }), +})) + +vi.mock('@/next/navigation', () => ({ + useParams: () => ({ appId: 'app-1' }), +})) + +vi.mock('@/service/workflow-comment', () => ({ + fetchMentionableUsers: (...args: unknown[]) => mockFetchMentionableUsers(...args), +})) + +vi.mock('../store', () => ({ + useStore: (selector: (state: typeof mentionStoreState) => unknown) => selector(mentionStoreState), + useWorkflowStore: () => ({ + getState: () => mentionStoreState, + }), +})) + +vi.mock('@/app/components/base/ui/avatar', () => ({ + Avatar: ({ name }: { name: string }) =>
{name}
, +})) + +const mentionUsers: UserProfile[] = [ + { + id: 'user-2', + name: 'Alice', + email: 'alice@example.com', + avatar_url: 'alice.png', + }, + { + id: 'user-3', + name: 'Bob', + email: 'bob@example.com', + avatar_url: 'bob.png', + }, +] + +function ControlledMentionInput({ + onSubmit, +}: { + onSubmit: (content: string, mentionedUserIds: string[]) => void +}) { + const [value, setValue] = useState('') + return ( + + ) +} + +describe('MentionInput', () => { + beforeEach(() => { + vi.clearAllMocks() + mentionStoreState.mentionableUsersCache = {} + mentionStoreState.mentionableUsersLoading = {} + mockFetchMentionableUsers.mockResolvedValue(mentionUsers) + }) + + it('loads mentionable users when cache is empty', async () => { + render( + , + ) + + await waitFor(() => { + expect(mockFetchMentionableUsers).toHaveBeenCalledWith('app-1') + }) + + expect(mockSetMentionableUsersLoading).toHaveBeenCalledWith('app-1', true) + expect(mockSetMentionableUsersCache).toHaveBeenCalledWith('app-1', mentionUsers) + expect(mockSetMentionableUsersLoading).toHaveBeenCalledWith('app-1', false) + }) + + it('selects a mention and submits with mentioned user ids', async () => { + mentionStoreState.mentionableUsersCache['app-1'] = mentionUsers + const onSubmit = vi.fn() + + render() + + const textarea = screen.getByPlaceholderText('workflow.comments.placeholder.add') as HTMLTextAreaElement + textarea.focus() + textarea.setSelectionRange(4, 4) + fireEvent.change(textarea, { target: { value: '@Ali' } }) + + await waitFor(() => { + expect(screen.getByText('alice@example.com')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByText('alice@example.com')) + fireEvent.change(textarea, { target: { value: '@Alice hi' } }) + fireEvent.keyDown(textarea, { key: 'Enter' }) + + await waitFor(() => { + expect(onSubmit).toHaveBeenCalledWith('@Alice hi', ['user-2']) + }) + }) + + it('supports editing mode cancel and save actions', async () => { + mentionStoreState.mentionableUsersCache['app-1'] = 
mentionUsers + const onSubmit = vi.fn() + const onCancel = vi.fn() + + render( + , + ) + + fireEvent.click(screen.getByText('common.operation.cancel')) + expect(onCancel).toHaveBeenCalledTimes(1) + + fireEvent.click(screen.getByText('common.operation.save')) + await waitFor(() => { + expect(onSubmit).toHaveBeenCalledWith('updated reply', []) + }) + }) +}) diff --git a/web/app/components/workflow/comment/mention-input.tsx b/web/app/components/workflow/comment/mention-input.tsx new file mode 100644 index 0000000000..ac708b979d --- /dev/null +++ b/web/app/components/workflow/comment/mention-input.tsx @@ -0,0 +1,661 @@ +'use client' + +import type { ReactNode } from 'react' +import type { UserProfile } from '@/service/workflow-comment' +import { cn } from '@langgenius/dify-ui/cn' +import { RiArrowUpLine, RiAtLine, RiLoader2Line } from '@remixicon/react' +import { + forwardRef, + memo, + useCallback, + useEffect, + useImperativeHandle, + useLayoutEffect, + useMemo, + useRef, + useState, +} from 'react' +import { createPortal } from 'react-dom' +import { useTranslation } from 'react-i18next' +import Textarea from 'react-textarea-autosize' +import EnterKey from '@/app/components/base/icons/src/public/common/EnterKey' +import { Avatar } from '@/app/components/base/ui/avatar' +import { Button } from '@/app/components/base/ui/button' +import { useParams } from '@/next/navigation' +import { fetchMentionableUsers } from '@/service/workflow-comment' +import { useStore, useWorkflowStore } from '../store' + +type MentionInputProps = { + value: string + onChange: (value: string) => void + onSubmit: (content: string, mentionedUserIds: string[]) => void + onCancel?: () => void + placeholder?: string + disabled?: boolean + loading?: boolean + className?: string + isEditing?: boolean + autoFocus?: boolean +} + +const MentionInputInner = forwardRef(({ + value, + onChange, + onSubmit, + onCancel, + placeholder, + disabled = false, + loading = false, + className, + isEditing = false, + 
autoFocus = false, +}, forwardedRef) => { + const params = useParams() + const { t } = useTranslation() + const appId = params.appId as string + const textareaRef = useRef(null) + const highlightContentRef = useRef(null) + const actionContainerRef = useRef(null) + const actionRightRef = useRef(null) + const baseTextareaHeightRef = useRef(null) + + // Expose textarea ref to parent component + useImperativeHandle(forwardedRef, () => textareaRef.current!, []) + + const workflowStore = useWorkflowStore() + const mentionUsersFromStore = useStore(state => ( + appId ? state.mentionableUsersCache[appId] : undefined + )) + const mentionUsers = mentionUsersFromStore ?? [] + + const [showMentionDropdown, setShowMentionDropdown] = useState(false) + const [mentionQuery, setMentionQuery] = useState('') + const [mentionPosition, setMentionPosition] = useState(0) + const [selectedMentionIndex, setSelectedMentionIndex] = useState(0) + const [mentionedUserIds, setMentionedUserIds] = useState([]) + const resolvedPlaceholder = placeholder ?? t('comments.placeholder.add', { ns: 'workflow' }) + const BASE_PADDING = 4 + const [shouldReserveButtonGap, setShouldReserveButtonGap] = useState(isEditing) + const [shouldReserveHorizontalSpace, setShouldReserveHorizontalSpace] = useState(() => !isEditing) + const [paddingRight, setPaddingRight] = useState(() => BASE_PADDING + (isEditing ? 0 : 48)) + const [paddingBottom, setPaddingBottom] = useState(() => BASE_PADDING + (isEditing ? 
32 : 0)) + + const mentionNameList = useMemo(() => { + const names = mentionUsers + .map(user => user.name?.trim()) + .filter((name): name is string => Boolean(name)) + + const uniqueNames = Array.from(new Set(names)) + uniqueNames.sort((a, b) => b.length - a.length) + return uniqueNames + }, [mentionUsers]) + + const highlightedValue = useMemo(() => { + if (!value) + return '' + + if (mentionNameList.length === 0) + return value + + const segments: ReactNode[] = [] + let cursor = 0 + let hasMention = false + + while (cursor < value.length) { + let nextMatchStart = -1 + let matchedName = '' + + for (const name of mentionNameList) { + const searchStart = value.indexOf(`@${name}`, cursor) + if (searchStart === -1) + continue + + const previousChar = searchStart > 0 ? value[searchStart - 1] : '' + if (searchStart > 0 && !/\s/.test(previousChar)) + continue + + if ( + nextMatchStart === -1 + || searchStart < nextMatchStart + || (searchStart === nextMatchStart && name.length > matchedName.length) + ) { + nextMatchStart = searchStart + matchedName = name + } + } + + if (nextMatchStart === -1) + break + + if (nextMatchStart > cursor) + segments.push({value.slice(cursor, nextMatchStart)}) + + const mentionEnd = nextMatchStart + matchedName.length + 1 + segments.push( + + {value.slice(nextMatchStart, mentionEnd)} + , + ) + + hasMention = true + cursor = mentionEnd + } + + if (!hasMention) + return value + + if (cursor < value.length) + segments.push({value.slice(cursor)}) + + return segments + }, [value, mentionNameList]) + + const loadMentionableUsers = useCallback(async () => { + if (!appId) + return + + const state = workflowStore.getState() + if (state.mentionableUsersCache[appId] !== undefined) + return + + if (state.mentionableUsersLoading[appId]) + return + + state.setMentionableUsersLoading(appId, true) + try { + const users = await fetchMentionableUsers(appId) + workflowStore.getState().setMentionableUsersCache(appId, users) + } + catch (error) { + 
console.error('Failed to load mentionable users:', error) + } + finally { + workflowStore.getState().setMentionableUsersLoading(appId, false) + } + }, [appId, workflowStore]) + + useEffect(() => { + loadMentionableUsers() + }, [loadMentionableUsers]) + const syncHighlightScroll = useCallback(() => { + const textarea = textareaRef.current + const highlightContent = highlightContentRef.current + if (!textarea || !highlightContent) + return + + const { scrollTop, scrollLeft } = textarea + highlightContent.style.transform = `translate(${-scrollLeft}px, ${-scrollTop}px)` + }, []) + + const evaluateContentLayout = useCallback(() => { + const textarea = textareaRef.current + if (!textarea) + return + + const extraBottom = Math.max(0, paddingBottom - BASE_PADDING) + const effectiveClientHeight = textarea.clientHeight - extraBottom + + if (baseTextareaHeightRef.current === null) + baseTextareaHeightRef.current = effectiveClientHeight + + const baseHeight = baseTextareaHeightRef.current ?? effectiveClientHeight + const hasMultiline = effectiveClientHeight > baseHeight + 1 + const shouldReserveVertical = isEditing ? true : hasMultiline + + setShouldReserveButtonGap(shouldReserveVertical) + setShouldReserveHorizontalSpace(!hasMultiline) + }, [isEditing, paddingBottom]) + + const updateLayoutPadding = useCallback(() => { + const actionEl = actionContainerRef.current + const rect = actionEl?.getBoundingClientRect() + const rightRect = actionRightRef.current?.getBoundingClientRect() + let actionWidth = 0 + if (rightRect) + actionWidth = Math.ceil(rightRect.width) + else if (rect) + actionWidth = Math.ceil(rect.width) + + const actionHeight = rect ? Math.ceil(rect.height) : 0 + const fallbackWidth = Math.max(0, paddingRight - BASE_PADDING) + const fallbackHeight = Math.max(0, paddingBottom - BASE_PADDING) + const effectiveWidth = actionWidth > 0 ? actionWidth : fallbackWidth + const effectiveHeight = actionHeight > 0 ? 
actionHeight : fallbackHeight + + const nextRight = BASE_PADDING + (shouldReserveHorizontalSpace ? effectiveWidth : 0) + const nextBottom = BASE_PADDING + (shouldReserveButtonGap ? effectiveHeight : 0) + + setPaddingRight(prev => (prev === nextRight ? prev : nextRight)) + setPaddingBottom(prev => (prev === nextBottom ? prev : nextBottom)) + }, [shouldReserveButtonGap, shouldReserveHorizontalSpace, paddingRight, paddingBottom]) + + const setActionContainerRef = useCallback((node: HTMLDivElement | null) => { + actionContainerRef.current = node + + if (!isEditing) + actionRightRef.current = node + else if (!node) + actionRightRef.current = null + + if (node && typeof window !== 'undefined') + window.requestAnimationFrame(() => updateLayoutPadding()) + }, [isEditing, updateLayoutPadding]) + + const setActionRightRef = useCallback((node: HTMLDivElement | null) => { + actionRightRef.current = node + + if (node && typeof window !== 'undefined') + window.requestAnimationFrame(() => updateLayoutPadding()) + }, [updateLayoutPadding]) + + useLayoutEffect(() => { + syncHighlightScroll() + }, [value, syncHighlightScroll]) + + useLayoutEffect(() => { + Promise.resolve().then(() => { + evaluateContentLayout() + }) + }, [value, evaluateContentLayout]) + + useLayoutEffect(() => { + Promise.resolve().then(() => { + updateLayoutPadding() + }) + }, [updateLayoutPadding, isEditing, shouldReserveButtonGap]) + + useEffect(() => { + const handleResize = () => { + evaluateContentLayout() + updateLayoutPadding() + } + + window.addEventListener('resize', handleResize) + return () => window.removeEventListener('resize', handleResize) + }, [evaluateContentLayout, updateLayoutPadding]) + + useEffect(() => { + Promise.resolve().then(() => { + baseTextareaHeightRef.current = null + evaluateContentLayout() + setShouldReserveHorizontalSpace(!isEditing) + }) + }, [isEditing, evaluateContentLayout]) + + const filteredMentionUsers = useMemo(() => { + if (!mentionQuery) + return mentionUsers + return 
mentionUsers.filter(user => + user.name.toLowerCase().includes(mentionQuery.toLowerCase()) + || user.email.toLowerCase().includes(mentionQuery.toLowerCase()), + ) + }, [mentionUsers, mentionQuery]) + + const shouldDisableMentionButton = useMemo(() => { + if (showMentionDropdown) + return true + + const textarea = textareaRef.current + if (!textarea) + return false + + const cursorPosition = textarea.selectionStart || 0 + const textBeforeCursor = value.slice(0, cursorPosition) + return /@\w*$/.test(textBeforeCursor) + }, [showMentionDropdown, value]) + + const dropdownPosition = useMemo(() => { + if (!showMentionDropdown || !textareaRef.current) + return { x: 0, y: 0, placement: 'bottom' as const } + + const textareaRect = textareaRef.current.getBoundingClientRect() + const dropdownHeight = 160 // max-h-40 = 10rem = 160px + const viewportHeight = window.innerHeight + const spaceBelow = viewportHeight - textareaRect.bottom + const spaceAbove = textareaRect.top + + const shouldPlaceAbove = spaceBelow < dropdownHeight && spaceAbove > spaceBelow + + return { + x: textareaRect.left, + y: shouldPlaceAbove ? textareaRect.top - 4 : textareaRect.bottom + 4, + placement: shouldPlaceAbove ? 
'top' as const : 'bottom' as const, + } + }, [showMentionDropdown]) + + const handleContentChange = useCallback((newValue: string) => { + onChange(newValue) + + setTimeout(() => { + const cursorPosition = textareaRef.current?.selectionStart || 0 + const textBeforeCursor = newValue.slice(0, cursorPosition) + const mentionMatch = textBeforeCursor.match(/@(\w*)$/) + + if (mentionMatch) { + setMentionQuery(mentionMatch[1]) + setMentionPosition(cursorPosition - mentionMatch[0].length) + setShowMentionDropdown(true) + setSelectedMentionIndex(0) + } + else { + setShowMentionDropdown(false) + } + + if (typeof window !== 'undefined') { + window.requestAnimationFrame(() => { + evaluateContentLayout() + syncHighlightScroll() + }) + } + }, 0) + }, [onChange, evaluateContentLayout, syncHighlightScroll]) + + const handleMentionButtonClick = useCallback((e: React.MouseEvent) => { + e.preventDefault() + e.stopPropagation() + + const textarea = textareaRef.current + if (!textarea) + return + + const cursorPosition = textarea.selectionStart || 0 + const textBeforeCursor = value.slice(0, cursorPosition) + + if (showMentionDropdown) + return + + if (/@\w*$/.test(textBeforeCursor)) + return + + const newContent = `${value.slice(0, cursorPosition)}@${value.slice(cursorPosition)}` + + onChange(newContent) + + setTimeout(() => { + const newCursorPos = cursorPosition + 1 + textarea.setSelectionRange(newCursorPos, newCursorPos) + textarea.focus() + + setMentionQuery('') + setMentionPosition(cursorPosition) + setShowMentionDropdown(true) + setSelectedMentionIndex(0) + + if (typeof window !== 'undefined') { + window.requestAnimationFrame(() => { + evaluateContentLayout() + syncHighlightScroll() + }) + } + }, 0) + }, [value, onChange, evaluateContentLayout, syncHighlightScroll, showMentionDropdown]) + + const insertMention = useCallback((user: UserProfile) => { + const textarea = textareaRef.current + if (!textarea) + return + + const beforeMention = value.slice(0, mentionPosition) + const 
afterMention = value.slice(textarea.selectionStart || 0) + + const needsSpaceBefore = mentionPosition > 0 && !/\s/.test(value[mentionPosition - 1]) + const prefix = needsSpaceBefore ? ' ' : '' + const newContent = `${beforeMention}${prefix}@${user.name} ${afterMention}` + + onChange(newContent) + setShowMentionDropdown(false) + + const newMentionedUserIds = [...mentionedUserIds, user.id] + setMentionedUserIds(newMentionedUserIds) + + setTimeout(() => { + const extraSpace = needsSpaceBefore ? 1 : 0 + const newCursorPos = mentionPosition + extraSpace + user.name.length + 2 // (space) + @ + name + space + textarea.setSelectionRange(newCursorPos, newCursorPos) + textarea.focus() + if (typeof window !== 'undefined') { + window.requestAnimationFrame(() => { + evaluateContentLayout() + syncHighlightScroll() + }) + } + }, 0) + }, [value, mentionPosition, onChange, mentionedUserIds, evaluateContentLayout, syncHighlightScroll]) + + const handleSubmit = useCallback(async (e?: React.MouseEvent) => { + if (e) { + e.preventDefault() + e.stopPropagation() + } + + if (value.trim()) { + try { + await onSubmit(value.trim(), mentionedUserIds) + setMentionedUserIds([]) + setShowMentionDropdown(false) + } + catch (error) { + console.error('Failed to submit', error) + } + } + }, [value, mentionedUserIds, onSubmit]) + + const handleKeyDown = useCallback((e: React.KeyboardEvent) => { + // Ignore key events during IME composition (e.g., Chinese, Japanese input) + if (e.nativeEvent.isComposing) + return + + if (showMentionDropdown) { + if (e.key === 'ArrowDown') { + e.preventDefault() + setSelectedMentionIndex(prev => + prev < filteredMentionUsers.length - 1 ? prev + 1 : 0, + ) + } + else if (e.key === 'ArrowUp') { + e.preventDefault() + setSelectedMentionIndex(prev => + prev > 0 ? 
prev - 1 : filteredMentionUsers.length - 1, + ) + } + else if (e.key === 'Enter') { + e.preventDefault() + if (filteredMentionUsers[selectedMentionIndex]) + insertMention(filteredMentionUsers[selectedMentionIndex]) + + return + } + else if (e.key === 'Escape') { + e.preventDefault() + setShowMentionDropdown(false) + return + } + } + + if (e.key === 'Enter' && !e.shiftKey && !showMentionDropdown) { + e.preventDefault() + handleSubmit() + } + }, [showMentionDropdown, filteredMentionUsers, selectedMentionIndex, insertMention, handleSubmit]) + + const resetMentionState = useCallback(() => { + setMentionedUserIds([]) + setShowMentionDropdown(false) + setMentionQuery('') + setMentionPosition(0) + setSelectedMentionIndex(0) + }, []) + + useEffect(() => { + if (!value) { + Promise.resolve().then(() => { + resetMentionState() + }) + } + }, [value, resetMentionState]) + + useEffect(() => { + if (autoFocus && textareaRef.current) { + const textarea = textareaRef.current + setTimeout(() => { + textarea.focus() + const length = textarea.value.length + textarea.setSelectionRange(length, length) + }, 0) + } + }, [autoFocus]) + + return ( + <> +
+
+
+ {highlightedValue} + +
+
+