diff --git a/api/.env.example b/api/.env.example index b1ac15d25b..24f496a7b1 100644 --- a/api/.env.example +++ b/api/.env.example @@ -33,6 +33,9 @@ TRIGGER_URL=http://localhost:5001 # The time in seconds after the signature is rejected FILES_ACCESS_TIMEOUT=300 +# Collaboration mode toggle +ENABLE_COLLABORATION_MODE=false + # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 diff --git a/api/app.py b/api/app.py index 99f70f32d5..ee0f91de3b 100644 --- a/api/app.py +++ b/api/app.py @@ -1,3 +1,4 @@ +import os import sys @@ -8,10 +9,16 @@ def is_db_command() -> bool: # create app +celery = None +flask_app = None +socketio_app = None + if is_db_command(): from app_factory import create_migrations_app app = create_migrations_app() + socketio_app = app + flask_app = app else: # Gunicorn and Celery handle monkey patching automatically in production by # specifying the `gevent` worker class. Manual monkey patching is not required here. @@ -22,8 +29,15 @@ else: from app_factory import create_app - app = create_app() - celery = app.extensions["celery"] + socketio_app, flask_app = create_app() + app = flask_app + celery = flask_app.extensions["celery"] if __name__ == "__main__": - app.run(host="0.0.0.0", port=5001) + from gevent import pywsgi + from geventwebsocket.handler import WebSocketHandler + + host = os.environ.get("HOST", "0.0.0.0") + port = int(os.environ.get("PORT", 5001)) + server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler) + server.serve_forever() diff --git a/api/app_factory.py b/api/app_factory.py index 17c376de77..449e67971d 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -31,14 +31,22 @@ def create_flask_app_with_configs() -> DifyApp: return dify_app -def create_app() -> DifyApp: +def create_app() -> tuple[any, DifyApp]: start_time = time.perf_counter() app = create_flask_app_with_configs() initialize_extensions(app) + + import socketio + + from extensions.ext_socketio import sio + + sio.app = app + socketio_app = socketio.WSGIApp(sio, app) + end_time = time.perf_counter() if dify_config.DEBUG: logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2)) - return app + return socketio_app, app def initialize_extensions(app: DifyApp): diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index ff1f983f94..1ba3eb4ffb 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1150,6 +1150,13 @@ class PositionConfig(BaseSettings): return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} +class CollaborationConfig(BaseSettings): + ENABLE_COLLABORATION_MODE: bool = Field( + description="Whether to enable collaboration mode features across the workspace", + default=False, + ) + + class LoginConfig(BaseSettings): ENABLE_EMAIL_CODE_LOGIN: bool = Field( description="whether to enable email code login", @@ -1248,6 +1255,7 @@ class FeatureConfig( WorkflowConfig, WorkflowNodeExecutionConfig, WorkspaceConfig, + CollaborationConfig, LoginConfig, AccountConfig, SwaggerUIConfig, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index ad878fc266..b341b18778 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -58,11 +58,13 @@ from .app import ( mcp_server, message, model_config, + online_user, ops_trace, site, statistic, workflow, workflow_app_log, + workflow_comment, workflow_draft_variable, workflow_run, workflow_statistic, diff --git a/api/controllers/console/app/online_user.py b/api/controllers/console/app/online_user.py new file mode 100644 index 0000000000..ae829681ce --- /dev/null +++ b/api/controllers/console/app/online_user.py @@ -0,0 +1,339 @@ +import json +import time + +from werkzeug.wrappers import Request as WerkzeugRequest + +from extensions.ext_redis import redis_client +from extensions.ext_socketio import sio +from libs.passport import PassportService +from libs.token import extract_access_token +from services.account_service import AccountService + +SESSION_STATE_TTL_SECONDS = 3600 +WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:" +WORKFLOW_LEADER_PREFIX = "workflow_leader:" +WS_SID_MAP_PREFIX = "ws_sid_map:" + + +def _workflow_key(workflow_id: str) -> str: + return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}" + + +def _leader_key(workflow_id: str) -> str: + return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}" + + +def _sid_key(sid: str) -> str: + return f"{WS_SID_MAP_PREFIX}{sid}" + + +def _refresh_session_state(workflow_id: str, sid: str) -> None: + """ + Refresh TTLs for workflow + session keys so healthy sessions do not linger forever after crashes. + """ + workflow_key = _workflow_key(workflow_id) + sid_key = _sid_key(sid) + if redis_client.exists(workflow_key): + redis_client.expire(workflow_key, SESSION_STATE_TTL_SECONDS) + if redis_client.exists(sid_key): + redis_client.expire(sid_key, SESSION_STATE_TTL_SECONDS) + + +@sio.on("connect") +def socket_connect(sid, environ, auth): + """ + WebSocket connect event, do authentication here. + """ + token = None + if auth and isinstance(auth, dict): + token = auth.get("token") + + if not token: + try: + request_environ = WerkzeugRequest(environ) + token = extract_access_token(request_environ) + except Exception: + token = None + + if not token: + return False + + try: + decoded = PassportService().verify(token) + user_id = decoded.get("user_id") + if not user_id: + return False + + with sio.app.app_context(): + user = AccountService.load_logged_in_account(account_id=user_id) + if not user: + return False + + sio.save_session(sid, {"user_id": user.id, "username": user.name, "avatar": user.avatar}) + + return True + + except Exception: + return False + + +@sio.on("user_connect") +def handle_user_connect(sid, data): + """ + Handle user connect event. Each session (tab) is treated as an independent collaborator. + """ + + workflow_id = data.get("workflow_id") + if not workflow_id: + return {"msg": "workflow_id is required"}, 400 + + session = sio.get_session(sid) + user_id = session.get("user_id") + + if not user_id: + return {"msg": "unauthorized"}, 401 + + # Each session is stored independently with sid as key + session_info = { + "user_id": user_id, + "username": session.get("username", "Unknown"), + "avatar": session.get("avatar", None), + "sid": sid, + "connected_at": int(time.time()), # Add timestamp to differentiate tabs + } + + workflow_key = _workflow_key(workflow_id) + # Store session info with sid as key + redis_client.hset(workflow_key, sid, json.dumps(session_info)) + redis_client.set( + _sid_key(sid), + json.dumps({"workflow_id": workflow_id, "user_id": user_id}), + ex=SESSION_STATE_TTL_SECONDS, + ) + _refresh_session_state(workflow_id, sid) + + # Leader election: first session becomes the leader + leader_sid = get_or_set_leader(workflow_id, sid) + is_leader = leader_sid == sid + + sio.enter_room(sid, workflow_id) + broadcast_online_users(workflow_id) + + # Notify this session of their leader status + sio.emit("status", {"isLeader": is_leader}, room=sid) + + return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader} + + +@sio.on("disconnect") +def handle_disconnect(sid): + """ + Handle session disconnect event. Remove the specific session from online users. + """ + mapping = redis_client.get(_sid_key(sid)) + if mapping: + data = json.loads(mapping) + workflow_id = data["workflow_id"] + + # Remove this specific session + redis_client.hdel(_workflow_key(workflow_id), sid) + redis_client.delete(_sid_key(sid)) + + # Handle leader re-election if the leader session disconnected + handle_leader_disconnect(workflow_id, sid) + + broadcast_online_users(workflow_id) + + +def _clear_session_state(workflow_id: str, sid: str) -> None: + redis_client.hdel(_workflow_key(workflow_id), sid) + redis_client.delete(_sid_key(sid)) + + +def _is_session_active(workflow_id: str, sid: str) -> bool: + if not sid: + return False + + try: + if not sio.manager.is_connected(sid, "/"): + return False + except AttributeError: + return False + + if not redis_client.hexists(_workflow_key(workflow_id), sid): + return False + + if not redis_client.exists(_sid_key(sid)): + return False + + return True + + +def get_or_set_leader(workflow_id: str, sid: str) -> str: + """ + Get current leader session or set this session as leader if no valid leader exists. + Returns the leader session id (sid). + """ + raw_leader = redis_client.get(_leader_key(workflow_id)) + current_leader = raw_leader.decode("utf-8") if isinstance(raw_leader, bytes) else raw_leader + leader_replaced = False + + if current_leader and not _is_session_active(workflow_id, current_leader): + _clear_session_state(workflow_id, current_leader) + redis_client.delete(_leader_key(workflow_id)) + current_leader = None + leader_replaced = True + + if not current_leader: + redis_client.set(_leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS) # Expire in 1 hour + if leader_replaced: + broadcast_leader_change(workflow_id, sid) + return sid + + return current_leader + + +def handle_leader_disconnect(workflow_id, disconnected_sid): + """ + Handle leader re-election when a session disconnects. + If the disconnected session was the leader, elect a new leader from remaining sessions. + """ + current_leader = redis_client.get(_leader_key(workflow_id)) + + if current_leader: + current_leader = current_leader.decode("utf-8") if isinstance(current_leader, bytes) else current_leader + + if current_leader == disconnected_sid: + # Leader session disconnected, elect a new leader + sessions_json = redis_client.hgetall(_workflow_key(workflow_id)) + + if sessions_json: + # Get the first remaining session as new leader + new_leader_sid = list(sessions_json.keys())[0] + if isinstance(new_leader_sid, bytes): + new_leader_sid = new_leader_sid.decode("utf-8") + + redis_client.set(_leader_key(workflow_id), new_leader_sid, ex=SESSION_STATE_TTL_SECONDS) + + # Notify all sessions about the new leader + broadcast_leader_change(workflow_id, new_leader_sid) + else: + # No sessions left, remove leader + redis_client.delete(_leader_key(workflow_id)) + + +def broadcast_leader_change(workflow_id, new_leader_sid): + """ + Broadcast leader change to all sessions in the workflow. + """ + sessions_json = redis_client.hgetall(_workflow_key(workflow_id)) + + for sid, session_info_json in sessions_json.items(): + try: + sid_str = sid.decode("utf-8") if isinstance(sid, bytes) else sid + is_leader = sid_str == new_leader_sid + # Emit to each session whether they are the new leader + sio.emit("status", {"isLeader": is_leader}, room=sid_str) + except Exception: + continue + + +def get_current_leader(workflow_id): + """ + Get the current leader for a workflow. + """ + leader = redis_client.get(_leader_key(workflow_id)) + return leader.decode("utf-8") if leader and isinstance(leader, bytes) else leader + + +def broadcast_online_users(workflow_id): + """ + Broadcast online users to the workflow room. + Each session is shown as a separate user (even if same person has multiple tabs). + """ + sessions_json = redis_client.hgetall(_workflow_key(workflow_id)) + users = [] + + for sid, session_info_json in sessions_json.items(): + try: + session_info = json.loads(session_info_json) + # Each session appears as a separate "user" in the UI + users.append( + { + "user_id": session_info["user_id"], + "username": session_info["username"], + "avatar": session_info.get("avatar"), + "sid": session_info["sid"], + "connected_at": session_info.get("connected_at"), + } + ) + except Exception: + continue + + # Sort by connection time to maintain consistent order + users.sort(key=lambda x: x.get("connected_at") or 0) + + # Get current leader session + leader_sid = get_current_leader(workflow_id) + + sio.emit("online_users", {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, room=workflow_id) + + +@sio.on("collaboration_event") +def handle_collaboration_event(sid, data): + """ + Handle general collaboration events, include: + 1. mouse_move + 2. vars_and_features_update + 3. sync_request (ask leader to update graph) + 4. app_state_update + 5. mcp_server_update + 6. workflow_update + 7. comments_update + 8. node_panel_presence + + """ + mapping = redis_client.get(_sid_key(sid)) + + if not mapping: + return {"msg": "unauthorized"}, 401 + + mapping_data = json.loads(mapping) + workflow_id = mapping_data["workflow_id"] + user_id = mapping_data["user_id"] + _refresh_session_state(workflow_id, sid) + + event_type = data.get("type") + event_data = data.get("data") + timestamp = data.get("timestamp", int(time.time())) + + if not event_type: + return {"msg": "invalid event type"}, 400 + + sio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=workflow_id, + skip_sid=sid, + ) + + return {"msg": "event_broadcasted"} + + +@sio.on("graph_event") +def handle_graph_event(sid, data): + """ + Handle graph events - simple broadcast relay. + """ + mapping = redis_client.get(_sid_key(sid)) + + if not mapping: + return {"msg": "unauthorized"}, 401 + + mapping_data = json.loads(mapping) + workflow_id = mapping_data["workflow_id"] + _refresh_session_state(workflow_id, sid) + + sio.emit("graph_update", data, room=workflow_id, skip_sid=sid) + + return {"msg": "graph_update_broadcasted"} diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 8c451cd08c..bfd3edb41d 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -9,6 +9,7 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services +from configs import dify_config from controllers.console import api, console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model @@ -31,7 +32,9 @@ from core.trigger.debug.event_selectors import ( from core.workflow.enums import NodeType from core.workflow.graph_engine.manager import GraphEngineManager from extensions.ext_database import db +from extensions.ext_redis import redis_client from factories import file_factory, variable_factory +from fields.online_user_fields import online_user_list_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper @@ -144,6 +147,7 @@ class DraftWorkflowApi(Resource): .add_argument("hash", type=str, required=False, location="json") .add_argument("environment_variables", type=list, required=True, location="json") .add_argument("conversation_variables", type=list, required=False, location="json") + .add_argument("force_upload", type=bool, required=False, default=False, location="json") ) args = parser.parse_args() elif "text/plain" in content_type: @@ -161,6 +165,7 @@ class DraftWorkflowApi(Resource): "hash": data.get("hash"), "environment_variables": data.get("environment_variables"), "conversation_variables": data.get("conversation_variables"), + "force_upload": data.get("force_upload", False), } except json.JSONDecodeError: return {"message": "Invalid JSON data"}, 400 @@ -185,6 +190,7 @@ class DraftWorkflowApi(Resource): account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, + force_upload=args.get("force_upload", False), ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() @@ -756,6 +762,46 @@ class ConvertToWorkflowApi(Resource): } +@console_ns.route("/apps//workflows/draft/config") +class WorkflowConfigApi(Resource): + """Resource for workflow configuration.""" + + @api.doc("get_workflow_config") + @api.doc(description="Get workflow configuration") + @api.doc(params={"app_id": "Application ID"}) + @api.response(200, "Workflow configuration retrieved successfully") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def get(self, app_model: App): + return { + "parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT, + } + + +@console_ns.route("/apps//workflows/draft/features") +class WorkflowFeaturesApi(Resource): + """Update draft workflow features.""" + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + current_user, _ = current_account_with_tenant() + + parser = reqparse.RequestParser().add_argument("features", type=dict, required=True, location="json") + args = parser.parse_args() + + features = args.get("features") + + workflow_service = WorkflowService() + workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user) + + return {"result": "success"} + + @console_ns.route("/apps//workflows") class PublishedAllWorkflowApi(Resource): @api.doc("get_all_published_workflows") @@ -1168,3 +1214,30 @@ class DraftWorkflowTriggerRunAllApi(Resource): "status": "error", } ), 400 + + +@console_ns.route("/apps/workflows/online-users") +class WorkflowOnlineUsersApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(online_user_list_fields) + def get(self): + parser = reqparse.RequestParser().add_argument("workflow_ids", type=str, required=True, location="args") + args = parser.parse_args() + + workflow_ids = [workflow_id.strip() for workflow_id in args["workflow_ids"].split(",")] + + results = [] + for workflow_id in workflow_ids: + users_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}") + + users = [] + for _, user_info_json in users_json.items(): + try: + users.append(json.loads(user_info_json)) + except Exception: + continue + results.append({"workflow_id": workflow_id, "users": users}) + + return {"data": results} diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py new file mode 100644 index 0000000000..4e3a311de2 --- /dev/null +++ b/api/controllers/console/app/workflow_comment.py @@ -0,0 +1,240 @@ +import logging + +from flask_restx import Resource, fields, marshal_with, reqparse + +from controllers.console import api +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, setup_required +from fields.member_fields import account_with_role_fields +from fields.workflow_comment_fields import ( + workflow_comment_basic_fields, + workflow_comment_create_fields, + workflow_comment_detail_fields, + workflow_comment_reply_create_fields, + workflow_comment_reply_update_fields, + workflow_comment_resolve_fields, + workflow_comment_update_fields, +) +from libs.login import current_user, login_required +from models import App +from services.account_service import TenantService +from services.workflow_comment_service import WorkflowCommentService + +logger = logging.getLogger(__name__) + + +class WorkflowCommentListApi(Resource): + """API for listing and creating workflow comments.""" + + @login_required + @setup_required + @account_initialization_required + @get_app_model + @marshal_with(workflow_comment_basic_fields, envelope="data") + def get(self, app_model: App): + """Get all comments for a workflow.""" + comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id) + + return comments + + @login_required + @setup_required + @account_initialization_required + @get_app_model + @marshal_with(workflow_comment_create_fields) + def post(self, app_model: App): + """Create a new workflow comment.""" + parser = reqparse.RequestParser() + parser.add_argument("position_x", type=float, required=True, location="json") + parser.add_argument("position_y", type=float, required=True, location="json") + parser.add_argument("content", type=str, required=True, location="json") + parser.add_argument("mentioned_user_ids", type=list, location="json", default=[]) + args = parser.parse_args() + + result = WorkflowCommentService.create_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + created_by=current_user.id, + content=args.content, + position_x=args.position_x, + position_y=args.position_y, + mentioned_user_ids=args.mentioned_user_ids, + ) + + return result, 201 + + +class WorkflowCommentDetailApi(Resource): + """API for managing individual workflow comments.""" + + @login_required + @setup_required + @account_initialization_required + @get_app_model + @marshal_with(workflow_comment_detail_fields) + def get(self, app_model: App, comment_id: str): + """Get a specific workflow comment.""" + comment = WorkflowCommentService.get_comment( + tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id + ) + + return comment + + @login_required + @setup_required + @account_initialization_required + @get_app_model + @marshal_with(workflow_comment_update_fields) + def put(self, app_model: App, comment_id: str): + """Update a workflow comment.""" + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, location="json") + parser.add_argument("position_x", type=float, required=False, location="json") + parser.add_argument("position_y", type=float, required=False, location="json") + parser.add_argument("mentioned_user_ids", type=list, location="json", default=[]) + args = parser.parse_args() + + result = WorkflowCommentService.update_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + content=args.content, + position_x=args.position_x, + position_y=args.position_y, + mentioned_user_ids=args.mentioned_user_ids, + ) + + return result + + @login_required + @setup_required + @account_initialization_required + @get_app_model + def delete(self, app_model: App, comment_id: str): + """Delete a workflow comment.""" + WorkflowCommentService.delete_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + ) + + return {"result": "success"}, 204 + + +class WorkflowCommentResolveApi(Resource): + """API for resolving and reopening workflow comments.""" + + @login_required + @setup_required + @account_initialization_required + @get_app_model + @marshal_with(workflow_comment_resolve_fields) + def post(self, app_model: App, comment_id: str): + """Resolve a workflow comment.""" + comment = WorkflowCommentService.resolve_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + ) + + return comment + + +class WorkflowCommentReplyApi(Resource): + """API for managing comment replies.""" + + @login_required + @setup_required + @account_initialization_required + @get_app_model + @marshal_with(workflow_comment_reply_create_fields) + def post(self, app_model: App, comment_id: str): + """Add a reply to a workflow comment.""" + # Validate comment access first + WorkflowCommentService.validate_comment_access( + comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + ) + + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, location="json") + parser.add_argument("mentioned_user_ids", type=list, location="json", default=[]) + args = parser.parse_args() + + result = WorkflowCommentService.create_reply( + comment_id=comment_id, + content=args.content, + created_by=current_user.id, + mentioned_user_ids=args.mentioned_user_ids, + ) + + return result, 201 + + +class WorkflowCommentReplyDetailApi(Resource): + """API for managing individual comment replies.""" + + @login_required + @setup_required + @account_initialization_required + @get_app_model + @marshal_with(workflow_comment_reply_update_fields) + def put(self, app_model: App, comment_id: str, reply_id: str): + """Update a comment reply.""" + # Validate comment access first + WorkflowCommentService.validate_comment_access( + comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + ) + + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, location="json") + parser.add_argument("mentioned_user_ids", type=list, location="json", default=[]) + args = parser.parse_args() + + reply = WorkflowCommentService.update_reply( + reply_id=reply_id, user_id=current_user.id, content=args.content, mentioned_user_ids=args.mentioned_user_ids + ) + + return reply + + @login_required + @setup_required + @account_initialization_required + @get_app_model + def delete(self, app_model: App, comment_id: str, reply_id: str): + """Delete a comment reply.""" + # Validate comment access first + WorkflowCommentService.validate_comment_access( + comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id + ) + + WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id) + + return {"result": "success"}, 204 + + +class WorkflowCommentMentionUsersApi(Resource): + """API for getting mentionable users for workflow comments.""" + + @login_required + @setup_required + @account_initialization_required + @get_app_model + @marshal_with({"users": fields.List(fields.Nested(account_with_role_fields))}) + def get(self, app_model: App): + """Get all users in current tenant for mentions.""" + members = TenantService.get_tenant_members(current_user.current_tenant) + return {"users": members} + + +# Register API routes +api.add_resource(WorkflowCommentListApi, "/apps//workflow/comments") +api.add_resource(WorkflowCommentDetailApi, "/apps//workflow/comments/") +api.add_resource(WorkflowCommentResolveApi, "/apps//workflow/comments//resolve") +api.add_resource(WorkflowCommentReplyApi, "/apps//workflow/comments//replies") +api.add_resource( + WorkflowCommentReplyDetailApi, "/apps//workflow/comments//replies/" +) +api.add_resource(WorkflowCommentMentionUsersApi, "/apps//workflow/comments/mention-users") diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 0722eb40d2..694f413d92 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -19,8 +19,8 @@ from core.variables.segments import ArrayFileSegment, FileSegment, Segment from core.variables.types import SegmentType from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db +from factories import variable_factory from factories.file_factory import build_from_mapping, build_from_mappings -from factories.variable_factory import build_segment_with_type from libs.login import current_user, login_required from models import Account, App, AppMode from models.workflow import WorkflowDraftVariable @@ -355,7 +355,7 @@ class VariableApi(Resource): if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id) - new_value = build_segment_with_type(variable.value_type, raw_value) + new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() return variable @@ -448,8 +448,35 @@ class ConversationVariableCollectionApi(Resource): db.session.commit() return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=AppMode.ADVANCED_CHAT) + def post(self, app_model: App): + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("conversation_variables", type=list, required=True, location="json") + args = parser.parse_args() + + workflow_service = WorkflowService() + + conversation_variables_list = args.get("conversation_variables") or [] + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list + ] + + workflow_service.update_draft_workflow_conversation_variables( + app_model=app_model, + account=current_user, + conversation_variables=conversation_variables, + ) + + return {"result": "success"} + -@console_ns.route("/apps//workflows/draft/system-variables") class SystemVariableCollectionApi(Resource): @api.doc("get_system_variables") @api.doc(description="Get system variables for workflow") @@ -499,3 +526,44 @@ class EnvironmentVariableCollectionApi(Resource): ) return {"items": env_vars_list} + + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("environment_variables", type=list, required=True, location="json") + args = parser.parse_args() + + workflow_service = WorkflowService() + + environment_variables_list = args.get("environment_variables") or [] + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + + workflow_service.update_draft_workflow_environment_variables( + app_model=app_model, + account=current_user, + environment_variables=environment_variables, + ) + + return {"result": "success"} + + +api.add_resource( + WorkflowVariableCollectionApi, + "/apps//workflows/draft/variables", +) +api.add_resource(NodeVariableCollectionApi, "/apps//workflows/draft/nodes//variables") +api.add_resource(VariableApi, "/apps//workflows/draft/variables/") +api.add_resource(VariableResetApi, "/apps//workflows/draft/variables//reset") + +api.add_resource(ConversationVariableCollectionApi, "/apps//workflows/draft/conversation-variables") +api.add_resource(SystemVariableCollectionApi, "/apps//workflows/draft/system-variables") +api.add_resource(EnvironmentVariableCollectionApi, "/apps//workflows/draft/environment-variables") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 499a52370f..3543f12c16 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -32,6 +32,7 @@ from controllers.console.wraps import ( only_edition_cloud, setup_required, ) +from core.file import helpers as file_helpers from extensions.ext_database import db from fields.member_fields import account_fields from libs.datetime_utils import naive_utc_now @@ -128,6 +129,17 @@ class AccountNameApi(Resource): @console_ns.route("/account/avatar") class AccountAvatarApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("avatar", type=str, required=True, location="args") + args = parser.parse_args() + + avatar_url = file_helpers.get_signed_file_url(args["avatar"]) + return {"avatar_url": avatar_url} + @setup_required @login_required @account_initialization_required diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index d289cde9e4..2f9eaf145f 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -289,7 +289,8 @@ class OracleVector(BaseVector): words = pseg.cut(query) current_entity = "" for word, pos in words: - if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名 + # nr: person name, ns: place name, nt: organization name + if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: current_entity += word else: if current_entity: diff --git a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py index 86b6ace3f6..282e83865d 100644 --- a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py +++ b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py @@ -213,7 +213,7 @@ class VastbaseVector(BaseVector): with self._get_cursor() as cur: cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) - # Vastbase 支持的向量维度取值范围为 [1,16000] + # Vastbase supports vector dimensions in range [1, 16000] if dimension <= 16000: cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) redis_client.set(collection_exist_cache_key, 1, ex=3600) diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 6313085e64..7c1edcd075 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -71,14 +71,16 @@ elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} else if [[ "${DEBUG}" == "true" ]]; then - exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug + export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0} + export PORT=${DIFY_PORT:-5001} + exec python -m app else exec gunicorn \ --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --workers ${SERVER_WORKER_AMOUNT:-1} \ - --worker-class ${SERVER_WORKER_CLASS:-gevent} \ + --worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \ --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \ --timeout ${GUNICORN_TIMEOUT:-200} \ - app:app + app:socketio_app fi fi diff --git a/api/extensions/ext_socketio.py b/api/extensions/ext_socketio.py new file mode 100644 index 0000000000..d898f795b1 --- /dev/null +++ b/api/extensions/ext_socketio.py @@ -0,0 +1,3 @@ +import socketio + +sio = socketio.Server(async_mode="gevent", cors_allowed_origins="*") diff --git a/api/fields/online_user_fields.py b/api/fields/online_user_fields.py new file mode 100644 index 0000000000..8fe0dc6a64 --- /dev/null +++ b/api/fields/online_user_fields.py @@ -0,0 +1,17 @@ +from flask_restx import fields + +online_user_partial_fields = { + "user_id": fields.String, + "username": fields.String, + "avatar": fields.String, + "sid": fields.String, +} + +workflow_online_users_fields = { + "workflow_id": fields.String, + "users": fields.List(fields.Nested(online_user_partial_fields)), +} + +online_user_list_fields = { + "data": fields.List(fields.Nested(workflow_online_users_fields)), +} diff --git a/api/fields/workflow_comment_fields.py b/api/fields/workflow_comment_fields.py new file mode 100644 index 0000000000..c708dd3460 --- /dev/null +++ b/api/fields/workflow_comment_fields.py @@ -0,0 +1,96 @@ +from flask_restx import fields + +from libs.helper import AvatarUrlField, TimestampField + +# basic account fields for comments +account_fields = { + "id": fields.String, + "name": fields.String, + "email": fields.String, + "avatar_url": AvatarUrlField, +} + +# Comment mention fields +workflow_comment_mention_fields = { + "mentioned_user_id": fields.String, + "mentioned_user_account": fields.Nested(account_fields, allow_null=True), + "reply_id": fields.String, +} + +# Comment reply fields +workflow_comment_reply_fields = { + "id": fields.String, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, +} + +# Basic comment fields (for list views) +workflow_comment_basic_fields = { + "id": fields.String, + "position_x": fields.Float, + "position_y": fields.Float, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, + "updated_at": TimestampField, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, + "resolved_by_account": fields.Nested(account_fields, allow_null=True), + "reply_count": fields.Integer, + "mention_count": fields.Integer, + "participants": fields.List(fields.Nested(account_fields)), +} + +# Detailed comment fields (for single comment view) +workflow_comment_detail_fields = { + "id": fields.String, + "position_x": fields.Float, + "position_y": fields.Float, + "content": fields.String, + "created_by": fields.String, + "created_by_account": fields.Nested(account_fields, allow_null=True), + "created_at": TimestampField, + "updated_at": TimestampField, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, + "resolved_by_account": fields.Nested(account_fields, allow_null=True), + "replies": fields.List(fields.Nested(workflow_comment_reply_fields)), + "mentions": fields.List(fields.Nested(workflow_comment_mention_fields)), +} + +# Comment creation response fields (simplified) +workflow_comment_create_fields = { + "id": fields.String, + "created_at": TimestampField, +} + +# Comment update response fields (simplified) +workflow_comment_update_fields = { + "id": fields.String, + "updated_at": TimestampField, +} + +# Comment resolve response fields +workflow_comment_resolve_fields = { + "id": fields.String, + "resolved": fields.Boolean, + "resolved_at": TimestampField, + "resolved_by": fields.String, +} + +# Reply creation response fields (simplified) +workflow_comment_reply_create_fields = { + "id": fields.String, + "created_at": TimestampField, +} + +# Reply update response fields +workflow_comment_reply_update_fields = { + "id": fields.String, + "updated_at": TimestampField, +} diff --git a/api/migrations/versions/2025_09_18_1726-227822d22895_add_workflow_comments_table.py b/api/migrations/versions/2025_09_18_1726-227822d22895_add_workflow_comments_table.py new file mode 100644 index 0000000000..b9f91ca141 --- /dev/null +++ b/api/migrations/versions/2025_09_18_1726-227822d22895_add_workflow_comments_table.py @@ -0,0 +1,90 @@ +"""Add workflow comments table + +Revision ID: 227822d22895 +Revises: 68519ad5cd18 +Create Date: 2025-08-22 17:26:15.255980 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '227822d22895' +down_revision = '68519ad5cd18' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('workflow_comments', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('position_x', sa.Float(), nullable=False), + sa.Column('position_y', sa.Float(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('resolved_at', sa.DateTime(), nullable=True), + sa.Column('resolved_by', models.types.StringUUID(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey') + ) + with op.batch_alter_table('workflow_comments', schema=None) as batch_op: + batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False) + batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False) + + op.create_table('workflow_comment_replies', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('comment_id', models.types.StringUUID(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey') + ) + with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op: + batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False) + batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False) + + op.create_table('workflow_comment_mentions', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('comment_id', models.types.StringUUID(), nullable=False), + sa.Column('reply_id', models.types.StringUUID(), nullable=True), + sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False), + sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'), + sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey') + ) + with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op: + batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False) + batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False) + batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op: + batch_op.drop_index('comment_mentions_user_idx') + batch_op.drop_index('comment_mentions_reply_idx') + batch_op.drop_index('comment_mentions_comment_idx') + + op.drop_table('workflow_comment_mentions') + with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op: + batch_op.drop_index('comment_replies_created_at_idx') + batch_op.drop_index('comment_replies_comment_idx') + + op.drop_table('workflow_comment_replies') + with op.batch_alter_table('workflow_comments', schema=None) as batch_op: + batch_op.drop_index('workflow_comments_created_at_idx') + batch_op.drop_index('workflow_comments_app_idx') + + op.drop_table('workflow_comments') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 906bc3198e..b94f26fa11 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -9,6 +9,11 @@ from .account import ( TenantStatus, ) from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from .comment import ( + WorkflowComment, + WorkflowCommentMention, + WorkflowCommentReply, +) from .dataset import ( AppDatasetJoin, Dataset, @@ -195,6 +200,9 @@ __all__ = [ "Workflow", "WorkflowAppLog", "WorkflowAppLogCreatedFrom", + "WorkflowComment", + "WorkflowCommentMention", + "WorkflowCommentReply", "WorkflowNodeExecutionModel", "WorkflowNodeExecutionOffload", "WorkflowNodeExecutionTriggeredFrom", diff --git a/api/models/comment.py b/api/models/comment.py new file mode 100644 index 0000000000..059f974dc7 --- /dev/null +++ b/api/models/comment.py @@ -0,0 +1,189 @@ +"""Workflow comment models.""" + +from datetime import datetime +from typing import TYPE_CHECKING, Optional + +from sqlalchemy import Index, func +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .account import Account +from .base import Base +from .engine import db +from .types import StringUUID + +if TYPE_CHECKING: + pass + + +class WorkflowComment(Base): + """Workflow comment model for canvas commenting functionality. + + Comments are associated with apps rather than specific workflow versions, + since an app has only one draft workflow at a time and comments should persist + across workflow version changes. + + Attributes: + id: Comment ID + tenant_id: Workspace ID + app_id: App ID (primary association, comments belong to apps) + position_x: X coordinate on canvas + position_y: Y coordinate on canvas + content: Comment content + created_by: Creator account ID + created_at: Creation time + updated_at: Last update time + resolved: Whether comment is resolved + resolved_at: Resolution time + resolved_by: Resolver account ID + """ + + __tablename__ = "workflow_comments" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), + Index("workflow_comments_app_idx", "tenant_id", "app_id"), + Index("workflow_comments_created_at_idx", "created_at"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + position_x: Mapped[float] = mapped_column(db.Float) + position_y: Mapped[float] = mapped_column(db.Float) + content: Mapped[str] = mapped_column(db.Text, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + resolved_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + resolved_by: Mapped[Optional[str]] = mapped_column(StringUUID) + + # Relationships + replies: Mapped[list["WorkflowCommentReply"]] = relationship( + "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan" + ) + mentions: Mapped[list["WorkflowCommentMention"]] = relationship( + "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan" + ) + + @property + def created_by_account(self): + """Get creator account.""" + return db.session.get(Account, self.created_by) + + @property + def resolved_by_account(self): + """Get resolver account.""" + if self.resolved_by: + return db.session.get(Account, self.resolved_by) + return None + + @property + def reply_count(self): + """Get reply count.""" + return len(self.replies) + + @property + def mention_count(self): + """Get mention count.""" + return len(self.mentions) + + @property + def participants(self): + """Get all participants (creator + repliers + mentioned users).""" + participant_ids = set() + + # Add comment creator + participant_ids.add(self.created_by) + + # Add reply creators + participant_ids.update(reply.created_by for reply in self.replies) + + # Add mentioned users + participant_ids.update(mention.mentioned_user_id for mention in self.mentions) + + # Get account objects + participants = [] + for user_id in participant_ids: + account = db.session.get(Account, user_id) + if account: + participants.append(account) + + return participants + + +class WorkflowCommentReply(Base): + """Workflow comment reply model. + + Attributes: + id: Reply ID + comment_id: Parent comment ID + content: Reply content + created_by: Creator account ID + created_at: Creation time + """ + + __tablename__ = "workflow_comment_replies" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), + Index("comment_replies_comment_idx", "comment_id"), + Index("comment_replies_created_at_idx", "created_at"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + comment_id: Mapped[str] = mapped_column( + StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + ) + content: Mapped[str] = mapped_column(db.Text, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + # Relationships + comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies") + + @property + def created_by_account(self): + """Get creator account.""" + return db.session.get(Account, self.created_by) + + +class WorkflowCommentMention(Base): + """Workflow comment mention model. + + Mentions are only for internal accounts since end users + cannot access workflow canvas and commenting features. + + Attributes: + id: Mention ID + comment_id: Parent comment ID + mentioned_user_id: Mentioned account ID + """ + + __tablename__ = "workflow_comment_mentions" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), + Index("comment_mentions_comment_idx", "comment_id"), + Index("comment_mentions_reply_idx", "reply_id"), + Index("comment_mentions_user_idx", "mentioned_user_id"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + comment_id: Mapped[str] = mapped_column( + StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + ) + reply_id: Mapped[Optional[str]] = mapped_column( + StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True + ) + mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # Relationships + comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions") + reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply") + + @property + def mentioned_user_account(self): + """Get mentioned account.""" + return db.session.get(Account, self.mentioned_user_id) diff --git a/api/models/workflow.py b/api/models/workflow.py index 4eff16dda2..5420bc2a4d 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -392,7 +392,7 @@ class Workflow(Base): :return: hash """ - entity = {"graph": self.graph_dict, "features": self.features_dict} + entity = {"graph": self.graph_dict} return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) diff --git a/api/pyproject.toml b/api/pyproject.toml index 5d72b18204..70a6786b66 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "flask-orjson~=2.0.0", "flask-sqlalchemy~=3.1.1", "gevent~=25.9.1", + "gevent-websocket~=0.10.1", "gmpy2~=2.2.1", "google-api-core==2.18.0", "google-api-python-client==2.90.0", @@ -69,6 +70,7 @@ dependencies = [ "pypdfium2==4.30.0", "python-docx~=1.1.0", "python-dotenv==1.0.1", + "python-socketio~=5.13.0", "pyyaml~=6.0.1", "readabilipy~=0.3.0", "redis[hiredis]~=6.1.0", diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 44bea57769..a3a0930b00 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -152,6 +152,7 @@ class SystemFeatureModel(BaseModel): enable_email_code_login: bool = False enable_email_password_login: bool = True enable_social_oauth_login: bool = False + enable_collaboration_mode: bool = False is_allow_register: bool = False is_allow_create_workspace: bool = False is_email_setup: bool = False @@ -213,6 +214,7 @@ class FeatureService: system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN + system_features.enable_collaboration_mode = dify_config.ENABLE_COLLABORATION_MODE system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" diff --git a/api/services/workflow_comment_service.py b/api/services/workflow_comment_service.py new file mode 100644 index 0000000000..4b5fbf7a05 --- /dev/null +++ b/api/services/workflow_comment_service.py @@ -0,0 +1,311 @@ +import logging +from typing import Optional + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session, selectinload +from werkzeug.exceptions import Forbidden, NotFound + +from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now +from libs.helper import uuid_value +from models import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply + +logger = logging.getLogger(__name__) + + +class WorkflowCommentService: + """Service for managing workflow comments.""" + + @staticmethod + def _validate_content(content: str) -> None: + if len(content.strip()) == 0: + raise ValueError("Comment content cannot be empty") + + if len(content) > 1000: + raise ValueError("Comment content cannot exceed 1000 characters") + + @staticmethod + def get_comments(tenant_id: str, app_id: str) -> list[WorkflowComment]: + """Get all comments for a workflow.""" + with Session(db.engine) as session: + # Get all comments with eager loading + stmt = ( + select(WorkflowComment) + .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions)) + .where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id) + .order_by(desc(WorkflowComment.created_at)) + ) + + comments = session.scalars(stmt).all() + return comments + + @staticmethod + def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session = None) -> WorkflowComment: + """Get a specific comment.""" + + def _get_comment(session: Session) -> WorkflowComment: + stmt = ( + select(WorkflowComment) + .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions)) + .where( + WorkflowComment.id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + ) + comment = session.scalar(stmt) + + if not comment: + raise NotFound("Comment not found") + + return comment + + if session is not None: + return _get_comment(session) + else: + with Session(db.engine, expire_on_commit=False) as session: + return _get_comment(session) + + @staticmethod + def create_comment( + tenant_id: str, + app_id: str, + created_by: str, + content: str, + position_x: float, + position_y: float, + mentioned_user_ids: Optional[list[str]] = None, + ) -> WorkflowComment: + """Create a new workflow comment.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine) as session: + comment = WorkflowComment( + tenant_id=tenant_id, + app_id=app_id, + position_x=position_x, + position_y=position_y, + content=content, + created_by=created_by, + ) + + session.add(comment) + session.flush() # Get the comment ID for mentions + + # Create mentions if specified + mentioned_user_ids = mentioned_user_ids or [] + for user_id in mentioned_user_ids: + if isinstance(user_id, str) and uuid_value(user_id): + mention = WorkflowCommentMention( + comment_id=comment.id, + reply_id=None, # This is a comment mention, not reply mention + mentioned_user_id=user_id, + ) + session.add(mention) + + session.commit() + + # Return only what we need - id and created_at + return {"id": comment.id, "created_at": comment.created_at} + + @staticmethod + def update_comment( + tenant_id: str, + app_id: str, + comment_id: str, + user_id: str, + content: str, + position_x: Optional[float] = None, + position_y: Optional[float] = None, + mentioned_user_ids: Optional[list[str]] = None, + ) -> dict: + """Update a workflow comment.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + # Get comment with validation + stmt = select(WorkflowComment).where( + WorkflowComment.id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + comment = session.scalar(stmt) + + if not comment: + raise NotFound("Comment not found") + + # Only the creator can update the comment + if comment.created_by != user_id: + raise Forbidden("Only the comment creator can update it") + + # Update comment fields + comment.content = content + if position_x is not None: + comment.position_x = position_x + if position_y is not None: + comment.position_y = position_y + + # Update mentions - first remove existing mentions for this comment only (not replies) + existing_mentions = session.scalars( + select(WorkflowCommentMention).where( + WorkflowCommentMention.comment_id == comment.id, + WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions + ) + ).all() + for mention in existing_mentions: + session.delete(mention) + + # Add new mentions + mentioned_user_ids = mentioned_user_ids or [] + for user_id_str in mentioned_user_ids: + if isinstance(user_id_str, str) and uuid_value(user_id_str): + mention = WorkflowCommentMention( + comment_id=comment.id, + reply_id=None, # This is a comment mention + mentioned_user_id=user_id_str, + ) + session.add(mention) + + session.commit() + + return {"id": comment.id, "updated_at": comment.updated_at} + + @staticmethod + def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None: + """Delete a workflow comment.""" + with Session(db.engine, expire_on_commit=False) as session: + comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session) + + # Only the creator can delete the comment + if comment.created_by != user_id: + raise Forbidden("Only the comment creator can delete it") + + # Delete associated mentions (both comment and reply mentions) + mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id) + ).all() + for mention in mentions: + session.delete(mention) + + # Delete associated replies + replies = session.scalars( + select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id) + ).all() + for reply in replies: + session.delete(reply) + + session.delete(comment) + session.commit() + + @staticmethod + def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment: + """Resolve a workflow comment.""" + with Session(db.engine, expire_on_commit=False) as session: + comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session) + if comment.resolved: + return comment + + comment.resolved = True + comment.resolved_at = naive_utc_now() + comment.resolved_by = user_id + session.commit() + + return comment + + @staticmethod + def create_reply( + comment_id: str, content: str, created_by: str, mentioned_user_ids: Optional[list[str]] = None + ) -> dict: + """Add a reply to a workflow comment.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + # Check if comment exists + comment = session.get(WorkflowComment, comment_id) + if not comment: + raise NotFound("Comment not found") + + reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by) + + session.add(reply) + session.flush() # Get the reply ID for mentions + + # Create mentions if specified + mentioned_user_ids = mentioned_user_ids or [] + for user_id in mentioned_user_ids: + if isinstance(user_id, str) and uuid_value(user_id): + # Create mention linking to specific reply + mention = WorkflowCommentMention( + comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id + ) + session.add(mention) + + session.commit() + + return {"id": reply.id, "created_at": reply.created_at} + + @staticmethod + def update_reply( + reply_id: str, user_id: str, content: str, mentioned_user_ids: Optional[list[str]] = None + ) -> WorkflowCommentReply: + """Update a comment reply.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + reply = session.get(WorkflowCommentReply, reply_id) + if not reply: + raise NotFound("Reply not found") + + # Only the creator can update the reply + if reply.created_by != user_id: + raise Forbidden("Only the reply creator can update it") + + reply.content = content + + # Update mentions - first remove existing mentions for this reply + existing_mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id) + ).all() + for mention in existing_mentions: + session.delete(mention) + + # Add mentions + mentioned_user_ids = mentioned_user_ids or [] + for user_id_str in mentioned_user_ids: + if isinstance(user_id_str, str) and uuid_value(user_id_str): + mention = WorkflowCommentMention( + comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str + ) + session.add(mention) + + session.commit() + session.refresh(reply) # Refresh to get updated timestamp + + return {"id": reply.id, "updated_at": reply.updated_at} + + @staticmethod + def delete_reply(reply_id: str, user_id: str) -> None: + """Delete a comment reply.""" + with Session(db.engine, expire_on_commit=False) as session: + reply = session.get(WorkflowCommentReply, reply_id) + if not reply: + raise NotFound("Reply not found") + + # Only the creator can delete the reply + if reply.created_by != user_id: + raise Forbidden("Only the reply creator can delete it") + + # Delete associated mentions first + mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id) + ).all() + for mention in mentions: + session.delete(mention) + + session.delete(reply) + session.commit() + + @staticmethod + def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment: + """Validate that a comment belongs to the specified tenant and app.""" + return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index b6d64d95da..ae9cb5825d 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -200,15 +200,17 @@ class WorkflowService: account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], + force_upload: bool = False, ) -> Workflow: """ Sync draft workflow + :param force_upload: Skip hash validation when True (for restore operations) :raises WorkflowHashNotEqualError """ # fetch draft workflow by app_model workflow = self.get_draft_workflow(app_model=app_model) - if workflow and workflow.unique_hash != unique_hash: + if workflow and workflow.unique_hash != unique_hash and not force_upload: raise WorkflowHashNotEqualError() # validate features structure @@ -249,6 +251,78 @@ class WorkflowService: # return draft workflow return workflow + def update_draft_workflow_environment_variables( + self, + *, + app_model: App, + environment_variables: Sequence[Variable], + account: Account, + ): + """ + Update draft workflow environment variables + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + workflow.environment_variables = environment_variables + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + + def update_draft_workflow_conversation_variables( + self, + *, + app_model: App, + conversation_variables: Sequence[Variable], + account: Account, + ): + """ + Update draft workflow conversation variables + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + workflow.conversation_variables = conversation_variables + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + + def update_draft_workflow_features( + self, + *, + app_model: App, + features: dict, + account: Account, + ): + """ + Update draft workflow features + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + # validate features structure + self.validate_features_structure(app_model=app_model, features=features) + + workflow.features = json.dumps(features) + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + def publish_workflow( self, *, diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index 40380b09d2..8dca879e27 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -268,6 +268,7 @@ class TestFeatureService: mock_config.ENABLE_EMAIL_CODE_LOGIN = True mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ENABLE_COLLABORATION_MODE = True mock_config.ALLOW_REGISTER = False mock_config.ALLOW_CREATE_WORKSPACE = False mock_config.MAIL_TYPE = "smtp" @@ -292,6 +293,7 @@ class TestFeatureService: # Verify authentication settings assert result.enable_email_code_login is True assert result.enable_email_password_login is False + assert result.enable_collaboration_mode is True assert result.is_allow_register is False assert result.is_allow_create_workspace is False @@ -341,6 +343,7 @@ class TestFeatureService: mock_config.ENABLE_EMAIL_CODE_LOGIN = True mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False + mock_config.ENABLE_COLLABORATION_MODE = False mock_config.ALLOW_REGISTER = True mock_config.ALLOW_CREATE_WORKSPACE = True mock_config.MAIL_TYPE = "smtp" @@ -362,6 +365,7 @@ class TestFeatureService: assert result.enable_email_code_login is True assert result.enable_email_password_login is True assert result.enable_social_oauth_login is False + assert result.enable_collaboration_mode is False assert result.is_allow_register is True assert result.is_allow_create_workspace is True assert result.is_email_setup is True diff --git a/api/uv.lock b/api/uv.lock index db4827e143..5ba0eda03e 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -562,6 +562,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/57/f4/a69c20ee4f660081a7dedb1ac57f29be9378e04edfcb90c526b923d4bebc/beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a", size = 142979, upload-time = "2023-04-07T15:02:50.77Z" }, ] +[[package]] +name = "bidict" +version = "0.23.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/6e/026678aa5a830e07cd9498a05d3e7e650a4f56a42f267a53d22bcda1bdc9/bidict-0.23.1.tar.gz", hash = "sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71", size = 29093, upload-time = "2024-02-18T19:09:05.748Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/37/e8730c3587a65eb5645d4aba2d27aae48e8003614d6aaf15dda67f702f1f/bidict-0.23.1-py3-none-any.whl", hash = "sha256:5dae8d4d79b552a71cbabc7deb25dfe8ce710b17ff41711e13010ead2abfc3e5", size = 32764, upload-time = "2024-02-18T19:09:04.156Z" }, +] + [[package]] name = "billiard" version = "4.2.2" @@ -1322,6 +1331,7 @@ dependencies = [ { name = "flask-restx" }, { name = "flask-sqlalchemy" }, { name = "gevent" }, + { name = "gevent-websocket" }, { name = "gmpy2" }, { name = "google-api-core" }, { name = "google-api-python-client" }, @@ -1370,6 +1380,7 @@ dependencies = [ { name = "pypdfium2" }, { name = "python-docx" }, { name = "python-dotenv" }, + { name = "python-socketio" }, { name = "pyyaml" }, { name = "readabilipy" }, { name = "redis", extra = ["hiredis"] }, @@ -1516,6 +1527,7 @@ requires-dist = [ { name = "flask-restx", specifier = "~=1.3.0" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "gevent", specifier = "~=25.9.1" }, + { name = "gevent-websocket", specifier = "~=0.10.1" }, { name = "gmpy2", specifier = "~=2.2.1" }, { name = "google-api-core", specifier = "==2.18.0" }, { name = "google-api-python-client", specifier = "==2.90.0" }, @@ -1564,6 +1576,7 @@ requires-dist = [ { name = "pypdfium2", specifier = "==4.30.0" }, { name = "python-docx", specifier = "~=1.1.0" }, { name = "python-dotenv", specifier = "==1.0.1" }, + { name = "python-socketio", specifier = "~=5.13.0" }, { name = "pyyaml", specifier = "~=6.0.1" }, { name = "readabilipy", specifier = "~=0.3.0" }, { name = "redis", extras = ["hiredis"], specifier = "~=6.1.0" }, @@ -2119,6 +2132,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/98/caf06d5d22a7c129c1fb2fc1477306902a2c8ddfd399cd26bbbd4caf2141/gevent-25.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:4acd6bcd5feabf22c7c5174bd3b9535ee9f088d2bbce789f740ad8d6554b18f3", size = 1682837, upload-time = "2025-09-17T19:48:47.318Z" }, ] +[[package]] +name = "gevent-websocket" +version = "0.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gevent" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/d2/6fa19239ff1ab072af40ebf339acd91fb97f34617c2ee625b8e34bf42393/gevent-websocket-0.10.1.tar.gz", hash = "sha256:7eaef32968290c9121f7c35b973e2cc302ffb076d018c9068d2f5ca8b2d85fb0", size = 18366, upload-time = "2017-03-12T22:46:05.68Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/84/2dc373eb6493e00c884cc11e6c059ec97abae2678d42f06bf780570b0193/gevent_websocket-0.10.1-py3-none-any.whl", hash = "sha256:17b67d91282f8f4c973eba0551183fc84f56f1c90c8f6b6b30256f31f66f5242", size = 22987, upload-time = "2017-03-12T22:46:03.611Z" }, +] + [[package]] name = "gitdb" version = "4.0.12" @@ -5078,6 +5103,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863, upload-time = "2024-01-23T06:32:58.246Z" }, ] +[[package]] +name = "python-engineio" +version = "4.12.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "simple-websocket" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/d8/63e5535ab21dc4998ba1cfe13690ccf122883a38f025dca24d6e56c05eba/python_engineio-4.12.3.tar.gz", hash = "sha256:35633e55ec30915e7fc8f7e34ca8d73ee0c080cec8a8cd04faf2d7396f0a7a7a", size = 91910, upload-time = "2025-09-28T06:31:36.765Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d8/f0/c5aa0a69fd9326f013110653543f36ece4913c17921f3e1dbd78e1b423ee/python_engineio-4.12.3-py3-none-any.whl", hash = "sha256:7c099abb2a27ea7ab429c04da86ab2d82698cdd6c52406cb73766fe454feb7e1", size = 59637, upload-time = "2025-09-28T06:31:35.354Z" }, +] + [[package]] name = "python-http-client" version = "3.3.7" @@ -5134,6 +5171,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/4f/00be2196329ebbff56ce564aa94efb0fbc828d00de250b1980de1a34ab49/python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba", size = 472788, upload-time = "2024-08-07T17:33:28.192Z" }, ] +[[package]] +name = "python-socketio" +version = "5.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bidict" }, + { name = "python-engineio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/21/1a/396d50ccf06ee539fa758ce5623b59a9cb27637fc4b2dc07ed08bf495e77/python_socketio-5.13.0.tar.gz", hash = "sha256:ac4e19a0302ae812e23b712ec8b6427ca0521f7c582d6abb096e36e24a263029", size = 121125, upload-time = "2025-04-12T15:46:59.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/32/b4fb8585d1be0f68bde7e110dffbcf354915f77ad8c778563f0ad9655c02/python_socketio-5.13.0-py3-none-any.whl", hash = "sha256:51f68d6499f2df8524668c24bcec13ba1414117cfb3a90115c559b601ab10caf", size = 77800, upload-time = "2025-04-12T15:46:58.412Z" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -5639,6 +5689,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, ] +[[package]] +name = "simple-websocket" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wsproto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/d4/bfa032f961103eba93de583b161f0e6a5b63cebb8f2c7d0c6e6efe1e3d2e/simple_websocket-1.1.0.tar.gz", hash = "sha256:7939234e7aa067c534abdab3a9ed933ec9ce4691b0713c78acb195560aa52ae4", size = 17300, upload-time = "2024-10-10T22:39:31.412Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/59/0782e51887ac6b07ffd1570e0364cf901ebc36345fea669969d2084baebb/simple_websocket-1.1.0-py3-none-any.whl", hash = "sha256:4af6069630a38ed6c561010f0e11a5bc0d4ca569b36306eb257cd9a192497c8c", size = 13842, upload-time = "2024-10-10T22:39:29.645Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -7038,6 +7100,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/f6/a933bd70f98e9cf3e08167fc5cd7aaaca49147e48411c0bd5ae701bb2194/wrapt-1.17.3-py3-none-any.whl", hash = "sha256:7171ae35d2c33d326ac19dd8facb1e82e5fd04ef8c6c0e394d7af55a55051c22", size = 23591, upload-time = "2025-08-12T05:53:20.674Z" }, ] +[[package]] +name = "wsproto" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/4a/44d3c295350d776427904d73c189e10aeae66d7f555bb2feee16d1e4ba5a/wsproto-1.2.0.tar.gz", hash = "sha256:ad565f26ecb92588a3e43bc3d96164de84cd9902482b130d0ddbaa9664a85065", size = 53425, upload-time = "2022-08-23T19:58:21.447Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/58/e860788190eba3bcce367f74d29c4675466ce8dddfba85f7827588416f01/wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736", size = 24226, upload-time = "2022-08-23T19:58:19.96Z" }, +] + [[package]] name = "xinference-client" version = "1.2.2" diff --git a/docker/.env.example b/docker/.env.example index 519f4aa3e0..f39b85db76 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -127,6 +127,10 @@ MIGRATION_ENABLED=true # The default value is 300 seconds. FILES_ACCESS_TIMEOUT=300 +# Collaboration mode toggle +# To open collaboration features, you also need to set SERVER_WORKER_CLASS=geventwebsocket.gunicorn.workers.GeventWebSocketWorker +ENABLE_COLLABORATION_MODE=false + # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 @@ -160,6 +164,7 @@ SERVER_WORKER_AMOUNT=1 # Modifying it may also decrease throughput. # # It is strongly discouraged to change this parameter. +# If enable collaboration mode, it must be set to geventwebsocket.gunicorn.workers.GeventWebSocketWorker SERVER_WORKER_CLASS=gevent # Default number of worker connections, the default is 10. diff --git a/docker/nginx/conf.d/default.conf.template b/docker/nginx/conf.d/default.conf.template index 1d63c1b97d..94a748290f 100644 --- a/docker/nginx/conf.d/default.conf.template +++ b/docker/nginx/conf.d/default.conf.template @@ -14,6 +14,14 @@ server { include proxy.conf; } + location /socket.io/ { + proxy_pass http://api:5001; + include proxy.conf; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_cache_bypass $http_upgrade; + } + location /v1 { proxy_pass http://api:5001; include proxy.conf; diff --git a/docker/nginx/proxy.conf.template b/docker/nginx/proxy.conf.template index 117f806146..3c39e507ff 100644 --- a/docker/nginx/proxy.conf.template +++ b/docker/nginx/proxy.conf.template @@ -5,7 +5,7 @@ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; proxy_set_header X-Forwarded-Port $server_port; proxy_http_version 1.1; -proxy_set_header Connection ""; +# proxy_set_header Connection ""; proxy_buffering off; proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT}; proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT}; diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index 7e592729a5..b7292b0acb 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React, { useMemo } from 'react' +import React, { useEffect, useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import AppCard from '@/app/components/app/overview/app-card' @@ -24,6 +24,8 @@ import { useStore as useAppStore } from '@/app/components/app/store' import { useAppWorkflow } from '@/service/use-workflow' import type { BlockEnum } from '@/app/components/workflow/types' import { isTriggerNode } from '@/app/components/workflow/types' +import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' export type ICardViewProps = { appId: string @@ -63,15 +65,44 @@ const CardView: FC = ({ appId, isInPanel, className }) => { message ||= (type === 'success' ? 'modifiedSuccessfully' : 'modifiedUnsuccessfully') - if (type === 'success') + if (type === 'success') { updateAppDetail() + // Emit collaboration event to notify other clients of app state changes + const socket = webSocketClient.getSocket(appId) + if (socket) { + socket.emit('collaboration_event', { + type: 'app_state_update', + data: { timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + } + notify({ type, message: t(`common.actionMsg.${message}`), }) } + // Listen for collaborative app state updates from other clients + useEffect(() => { + if (!appId) return + + const unsubscribe = collaborationManager.onAppStateUpdate(async (update: any) => { + try { + console.log('Received app state update from collaboration:', update) + // Update app detail when other clients modify app state + await updateAppDetail() + } + catch (error) { + console.error('app state update failed:', error) + } + }) + + return unsubscribe + }, [appId]) + const onChangeSiteStatus = async (value: boolean) => { const [err] = await asyncRunSafe( updateAppSiteStatus({ diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index c2bda8d8fc..4384673f02 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -1,7 +1,7 @@ import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' import { useContext } from 'use-context-selector' -import React, { useCallback, useState } from 'react' +import React, { useCallback, useEffect, useState } from 'react' import { RiDeleteBinLine, RiEditLine, @@ -16,7 +16,7 @@ import { useStore as useAppStore } from '@/app/components/app/store' import { ToastContext } from '@/app/components/base/toast' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' -import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' +import { copyApp, deleteApp, exportAppConfig, fetchAppDetail, updateAppInfo } from '@/service/apps' import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' @@ -31,6 +31,8 @@ import AppOperations from './app-operations' import dynamic from 'next/dynamic' import cn from '@/utils/classnames' import { AppModeEnum } from '@/types/app' +import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false, @@ -74,6 +76,19 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx const [secretEnvList, setSecretEnvList] = useState([]) const [showExportWarning, setShowExportWarning] = useState(false) + const emitAppMetaUpdate = useCallback(() => { + if (!appDetail?.id) + return + const socket = webSocketClient.getSocket(appDetail.id) + if (socket) { + socket.emit('collaboration_event', { + type: 'app_meta_update', + data: { timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + }, [appDetail?.id]) + const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, icon_type, @@ -102,11 +117,12 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx message: t('app.editDone'), }) setAppDetail(app) + emitAppMetaUpdate() } catch { notify({ type: 'error', message: t('app.editFailed') }) } - }, [appDetail, notify, setAppDetail, t]) + }, [appDetail, notify, setAppDetail, t, emitAppMetaUpdate]) const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => { if (!appDetail) @@ -203,6 +219,23 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx setShowConfirmDelete(false) }, [appDetail, notify, onPlanInfoChanged, replace, setAppDetail, t]) + useEffect(() => { + if (!appDetail?.id) + return + + const unsubscribe = collaborationManager.onAppMetaUpdate(async () => { + try { + const res = await fetchAppDetail({ url: '/apps', id: appDetail.id }) + setAppDetail({ ...res }) + } + catch (error) { + console.error('failed to refresh app detail from collaboration update:', error) + } + }) + + return unsubscribe + }, [appDetail?.id, setAppDetail]) + const { isCurrentWorkspaceEditor } = useAppContext() if (!appDetail) diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 64ce869c5d..d153d76c68 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -42,6 +42,9 @@ import type { InputVar } from '@/app/components/workflow/types' import { appDefaultIconBackground } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' +import { useInvalidateAppWorkflow } from '@/service/use-workflow' import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects, useGetUserCanAccessApp } from '@/service/access-control' import { fetchAppDetailDirect } from '@/service/apps' @@ -148,6 +151,7 @@ const AppPublisher = ({ const { data: userCanAccessApp, isLoading: isGettingUserCanAccessApp, refetch } = useGetUserCanAccessApp({ appId: appDetail?.id, enabled: false }) const { data: appAccessSubjects, isLoading: isGettingAppWhiteListSubjects } = useAppWhiteListSubjects(appDetail?.id, open && systemFeatures.webapp_auth.enabled && appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS) + const invalidateAppWorkflow = useInvalidateAppWorkflow() const noAccessPermission = useMemo(() => systemFeatures.webapp_auth.enabled && appDetail && appDetail.access_mode !== AccessMode.EXTERNAL_MEMBERS && !userCanAccessApp?.result, [systemFeatures, appDetail, userCanAccessApp]) const disabledFunctionButton = useMemo(() => (!publishedAt || missingStartNode || noAccessPermission), [publishedAt, missingStartNode, noAccessPermission]) @@ -182,11 +186,27 @@ const AppPublisher = ({ try { await onPublish?.(params) setPublished(true) + + const appId = appDetail?.id + const socket = appId ? webSocketClient.getSocket(appId) : null + if (appId) + invalidateAppWorkflow(appId) + if (socket) { + const timestamp = Date.now() + socket.emit('collaboration_event', { + type: 'app_publish_update', + data: { + action: 'published', + timestamp, + }, + timestamp, + }) + } } catch { setPublished(false) } - }, [onPublish]) + }, [appDetail?.id, onPublish, invalidateAppWorkflow]) const handleRestore = useCallback(async () => { try { @@ -243,6 +263,18 @@ const AppPublisher = ({ handlePublish() }, { exactMatch: true, useCapture: true }) + useEffect(() => { + const appId = appDetail?.id + if (!appId) return + + const unsubscribe = collaborationManager.onAppPublishUpdate((update: any) => { + if (update?.data?.action === 'published') + invalidateAppWorkflow(appId) + }) + + return unsubscribe + }, [appDetail?.id, invalidateAppWorkflow]) + const hasPublishedVersion = !!publishedAt const workflowToolDisabled = !hasPublishedVersion || !workflowToolAvailable const workflowToolMessage = workflowToolDisabled ? t('workflow.common.workflowAsToolDisabledHint') : undefined diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 564eb493e5..31c196f0df 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -32,6 +32,8 @@ import { useGlobalPublicStore } from '@/context/global-public-context' import { formatTime } from '@/utils/time' import { useGetUserCanAccessApp } from '@/service/access-control' import dynamic from 'next/dynamic' +import { UserAvatarList } from '@/app/components/base/user-avatar-list' +import type { WorkflowOnlineUser } from '@/models/app' const EditAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), { ssr: false, @@ -55,9 +57,10 @@ const AccessControl = dynamic(() => import('@/app/components/app/app-access-cont export type AppCardProps = { app: App onRefresh?: () => void + onlineUsers?: WorkflowOnlineUser[] } -const AppCard = ({ app, onRefresh }: AppCardProps) => { +const AppCard = ({ app, onRefresh, onlineUsers = [] }: AppCardProps) => { const { t } = useTranslation() const { notify } = useContext(ToastContext) const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) @@ -333,6 +336,19 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { return `${t('datasetDocuments.segment.editedAt')} ${timeText}` }, [app.updated_at, app.created_at]) + const onlineUserAvatars = useMemo(() => { + if (!onlineUsers.length) + return [] + + return onlineUsers + .map(user => ({ + id: user.user_id || user.sid || '', + name: user.username || 'User', + avatar_url: user.avatar || undefined, + })) + .filter(user => !!user.id) + }, [onlineUsers]) + return ( <>
{ }
+
+ {onlineUserAvatars.length > 0 && ( + + )} +
{ }, ) + const apps = useMemo(() => data?.flatMap(page => page.data) ?? [], [data]) + + const workflowIds = useMemo(() => { + const ids = new Set() + apps.forEach((appItem) => { + const workflowId = appItem.id + if (!workflowId) + return + + if (appItem.mode === 'workflow' || appItem.mode === 'advanced-chat') + ids.add(workflowId) + }) + return Array.from(ids) + }, [apps]) + + const { data: onlineUsersByWorkflow, mutate: refreshOnlineUsers } = useSWR>( + workflowIds.length ? { workflowIds } : null, + fetchWorkflowOnlineUsers, + ) + + useEffect(() => { + const timer = window.setInterval(() => { + mutate() + if (workflowIds.length) + refreshOnlineUsers() + }, 10000) + + return () => window.clearInterval(timer) + }, [workflowIds.join(','), mutate, refreshOnlineUsers]) + const anchorRef = useRef(null) const options = [ { value: 'all', text: t('app.types.all'), icon: }, @@ -222,7 +253,12 @@ const List = () => { {isCurrentWorkspaceEditor && } {data.map(({ data: apps }) => apps.map(app => ( - + )))}
:
diff --git a/web/app/components/base/avatar/index.tsx b/web/app/components/base/avatar/index.tsx index 89019a19b0..8226320c2a 100644 --- a/web/app/components/base/avatar/index.tsx +++ b/web/app/components/base/avatar/index.tsx @@ -9,6 +9,7 @@ export type AvatarProps = { className?: string textClassName?: string onError?: (x: boolean) => void + backgroundColor?: string } const Avatar = ({ name, @@ -17,9 +18,18 @@ const Avatar = ({ className, textClassName, onError, + backgroundColor, }: AvatarProps) => { - const avatarClassName = 'shrink-0 flex items-center rounded-full bg-primary-600' - const style = { width: `${size}px`, height: `${size}px`, fontSize: `${size}px`, lineHeight: `${size}px` } + const avatarClassName = backgroundColor + ? 'shrink-0 flex items-center rounded-full' + : 'shrink-0 flex items-center rounded-full bg-primary-600' + const style = { + width: `${size}px`, + height: `${size}px`, + fontSize: `${size}px`, + lineHeight: `${size}px`, + ...(backgroundColor && !avatar ? { backgroundColor } : {}), + } const [imgError, setImgError] = useState(false) const handleError = () => { @@ -35,14 +45,18 @@ const Avatar = ({ if (avatar && !imgError) { return ( - {name} onError?.(false)} - /> + > + {name} onError?.(false)} + /> + ) } diff --git a/web/app/components/base/content-dialog/index.tsx b/web/app/components/base/content-dialog/index.tsx index 5efab57a40..588ef67bc1 100644 --- a/web/app/components/base/content-dialog/index.tsx +++ b/web/app/components/base/content-dialog/index.tsx @@ -15,11 +15,12 @@ const ContentDialog = ({ onClose, children, }: ContentDialogProps) => { + // z-[70]: Ensures dialog appears above workflow operators (z-[60]) and other UI elements return (
+ + + diff --git a/web/app/components/base/icons/assets/public/other/comment.svg b/web/app/components/base/icons/assets/public/other/comment.svg new file mode 100644 index 0000000000..7f48f22fbd --- /dev/null +++ b/web/app/components/base/icons/assets/public/other/comment.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/web/app/components/base/icons/src/public/common/EnterKey.json b/web/app/components/base/icons/src/public/common/EnterKey.json new file mode 100644 index 0000000000..17c8e645ae --- /dev/null +++ b/web/app/components/base/icons/src/public/common/EnterKey.json @@ -0,0 +1,36 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M0 4C0 1.79086 1.79086 0 4 0H12C14.2091 0 16 1.79086 16 4V12C16 14.2091 14.2091 16 12 16H4C1.79086 16 0 14.2091 0 12V4Z", + "fill": "white", + "fill-opacity": "0.12" + }, + "children": [] + }, + { + "type": "element", + "name": "path", + "attributes": { + "d": "M3.42756 8.7358V7.62784H10.8764C11.2003 7.62784 11.4957 7.5483 11.7628 7.3892C12.0298 7.23011 12.2415 7.01705 12.3977 6.75C12.5568 6.48295 12.6364 6.1875 12.6364 5.86364C12.6364 5.53977 12.5568 5.24574 12.3977 4.98153C12.2386 4.71449 12.0256 4.50142 11.7585 4.34233C11.4943 4.18324 11.2003 4.10369 10.8764 4.10369H10.3991V3H10.8764C11.4048 3 11.8849 3.12926 12.3168 3.38778C12.7486 3.64631 13.0938 3.99148 13.3523 4.4233C13.6108 4.85511 13.7401 5.33523 13.7401 5.86364C13.7401 6.25852 13.6648 6.62926 13.5142 6.97585C13.3665 7.32244 13.1619 7.62784 12.9006 7.89205C12.6392 8.15625 12.3352 8.36364 11.9886 8.5142C11.642 8.66193 11.2713 8.7358 10.8764 8.7358H3.42756ZM6.16761 12.0554L2.29403 8.18182L6.16761 4.30824L6.9304 5.07102L3.81534 8.18182L6.9304 11.2926L6.16761 12.0554Z", + "fill": "white" + }, + "children": [] + } + ] + }, + "name": "EnterKey" +} diff --git a/web/app/components/base/icons/src/public/common/EnterKey.tsx b/web/app/components/base/icons/src/public/common/EnterKey.tsx new file mode 100644 index 0000000000..5365f48344 --- /dev/null +++ b/web/app/components/base/icons/src/public/common/EnterKey.tsx @@ -0,0 +1,20 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './EnterKey.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconData } from '@/app/components/base/icons/IconBase' + +const Icon = ( + { + ref, + ...props + }: React.SVGProps & { + ref?: React.RefObject>; + }, +) => + +Icon.displayName = 'EnterKey' + +export default Icon diff --git a/web/app/components/base/icons/src/public/common/index.ts b/web/app/components/base/icons/src/public/common/index.ts index e672e52613..f4769d42de 100644 --- a/web/app/components/base/icons/src/public/common/index.ts +++ b/web/app/components/base/icons/src/public/common/index.ts @@ -1,6 +1,7 @@ export { default as D } from './D' export { default as DiagonalDividingLine } from './DiagonalDividingLine' export { default as Dify } from './Dify' +export { default as EnterKey } from './EnterKey' export { default as Gdpr } from './Gdpr' export { default as Github } from './Github' export { default as Highlight } from './Highlight' diff --git a/web/app/components/base/icons/src/public/other/Comment.json b/web/app/components/base/icons/src/public/other/Comment.json new file mode 100644 index 0000000000..c4865a010c --- /dev/null +++ b/web/app/components/base/icons/src/public/other/Comment.json @@ -0,0 +1,26 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "xmlns": "http://www.w3.org/2000/svg", + "width": "14", + "height": "12", + "viewBox": "0 0 14 12", + "fill": "none" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M12.3334 4C12.3334 2.52725 11.1395 1.33333 9.66671 1.33333H4.33337C2.86062 1.33333 1.66671 2.52724 1.66671 4V10.6667H9.66671C11.1395 10.6667 12.3334 9.47274 12.3334 8V4ZM7.66671 6.66667V8H4.33337V6.66667H7.66671ZM9.66671 4V5.33333H4.33337V4H9.66671ZM13.6667 8C13.6667 10.2091 11.8758 12 9.66671 12H0.333374V4C0.333374 1.79086 2.12424 0 4.33337 0H9.66671C11.8758 0 13.6667 1.79086 13.6667 4V8Z", + "fill": "currentColor" + }, + "children": [] + } + ] + }, + "name": "Comment" +} diff --git a/web/app/components/base/icons/src/public/other/Comment.tsx b/web/app/components/base/icons/src/public/other/Comment.tsx new file mode 100644 index 0000000000..85c2559b76 --- /dev/null +++ b/web/app/components/base/icons/src/public/other/Comment.tsx @@ -0,0 +1,20 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './Comment.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconData } from '@/app/components/base/icons/IconBase' + +const Icon = ( + { + ref, + ...props + }: React.SVGProps & { + ref?: React.RefObject>; + }, +) => + +Icon.displayName = 'Comment' + +export default Icon diff --git a/web/app/components/base/icons/src/public/other/index.ts b/web/app/components/base/icons/src/public/other/index.ts index a7558ca0ab..3637525202 100644 --- a/web/app/components/base/icons/src/public/other/index.ts +++ b/web/app/components/base/icons/src/public/other/index.ts @@ -1,4 +1,5 @@ export { default as Icon3Dots } from './Icon3Dots' +export { default as Comment } from './Comment' export { default as DefaultToolIcon } from './DefaultToolIcon' export { default as Message3Fill } from './Message3Fill' export { default as RowStruct } from './RowStruct' diff --git a/web/app/components/base/prompt-editor/index.tsx b/web/app/components/base/prompt-editor/index.tsx index 50fdc1f920..0b73a7b8c9 100644 --- a/web/app/components/base/prompt-editor/index.tsx +++ b/web/app/components/base/prompt-editor/index.tsx @@ -2,6 +2,7 @@ import type { FC } from 'react' import React, { useEffect } from 'react' +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import type { EditorState, } from 'lexical' @@ -80,6 +81,29 @@ import { import { useEventEmitterContextContext } from '@/context/event-emitter' import cn from '@/utils/classnames' +const ValueSyncPlugin: FC<{ value?: string }> = ({ value }) => { + const [editor] = useLexicalComposerContext() + + useEffect(() => { + if (value === undefined) + return + + const incomingValue = value ?? '' + const shouldUpdate = editor.getEditorState().read(() => { + const currentText = $getRoot().getChildren().map(node => node.getTextContent()).join('\n') + return currentText !== incomingValue + }) + + if (!shouldUpdate) + return + + const editorState = editor.parseEditorState(textToEditorState(incomingValue)) + editor.setEditorState(editorState) + }, [editor, value]) + + return null +} + export type PromptEditorProps = { instanceId?: string compact?: boolean @@ -293,6 +317,7 @@ const PromptEditor: FC = ({ ) } + diff --git a/web/app/components/base/user-avatar-list/index.tsx b/web/app/components/base/user-avatar-list/index.tsx new file mode 100644 index 0000000000..b0d1989521 --- /dev/null +++ b/web/app/components/base/user-avatar-list/index.tsx @@ -0,0 +1,77 @@ +import type { FC } from 'react' +import { memo } from 'react' +import { getUserColor } from '@/app/components/workflow/collaboration/utils/user-color' +import { useAppContext } from '@/context/app-context' +import Avatar from '@/app/components/base/avatar' + +type User = { + id: string + name: string + avatar_url?: string | null +} + +type UserAvatarListProps = { + users: User[] + maxVisible?: number + size?: number + className?: string + showCount?: boolean +} + +export const UserAvatarList: FC = memo(({ + users, + maxVisible = 3, + size = 24, + className = '', + showCount = true, +}) => { + const { userProfile } = useAppContext() + if (!users.length) return null + + const shouldShowCount = showCount && users.length > maxVisible + const actualMaxVisible = shouldShowCount ? Math.max(1, maxVisible - 1) : maxVisible + const visibleUsers = users.slice(0, actualMaxVisible) + const remainingCount = users.length - actualMaxVisible + + const currentUserId = userProfile?.id + + return ( +
+ {visibleUsers.map((user, index) => { + const isCurrentUser = user.id === currentUserId + const userColor = isCurrentUser ? undefined : getUserColor(user.id) + return ( +
+ +
+ ) + }, + + )} + {shouldShowCount && remainingCount > 0 && ( +
+ +{remainingCount} +
+ )} +
+ ) +}) + +UserAvatarList.displayName = 'UserAvatarList' diff --git a/web/app/components/datasets/hit-testing/index.tsx b/web/app/components/datasets/hit-testing/index.tsx index ffda65e671..8bc33371cd 100644 --- a/web/app/components/datasets/hit-testing/index.tsx +++ b/web/app/components/datasets/hit-testing/index.tsx @@ -49,7 +49,7 @@ const HitTestingPage: FC = ({ datasetId }: Props) => { const media = useBreakpoints() const isMobile = media === MediaType.mobile - const [hitResult, setHitResult] = useState() // 初始化记录为空数组 + const [hitResult, setHitResult] = useState() // Initialize records as empty array const [externalHitResult, setExternalHitResult] = useState() const [submitLoading, setSubmitLoading] = useState(false) const [text, setText] = useState('') diff --git a/web/app/components/header/account-setting/menu-dialog.tsx b/web/app/components/header/account-setting/menu-dialog.tsx index ad3a1e7109..c0f89c3c7d 100644 --- a/web/app/components/header/account-setting/menu-dialog.tsx +++ b/web/app/components/header/account-setting/menu-dialog.tsx @@ -35,7 +35,7 @@ const MenuDialog = ({ return ( - +
diff --git a/web/app/components/tools/mcp/mcp-server-modal.tsx b/web/app/components/tools/mcp/mcp-server-modal.tsx index 11af81ec1a..659db2d737 100644 --- a/web/app/components/tools/mcp/mcp-server-modal.tsx +++ b/web/app/components/tools/mcp/mcp-server-modal.tsx @@ -16,6 +16,7 @@ import { useUpdateMCPServer, } from '@/service/use-tools' import cn from '@/utils/classnames' +import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager' export type ModalProps = { appID: string @@ -59,6 +60,21 @@ const MCPServerModal = ({ return res } + const emitMcpServerUpdate = (action: 'created' | 'updated') => { + const socket = webSocketClient.getSocket(appID) + if (!socket) return + + const timestamp = Date.now() + socket.emit('collaboration_event', { + type: 'mcp_server_update', + data: { + action, + timestamp, + }, + timestamp, + }) + } + const submit = async () => { if (!data) { const payload: any = { @@ -71,6 +87,7 @@ const MCPServerModal = ({ await createMCPServer(payload) invalidateMCPServerDetail(appID) + emitMcpServerUpdate('created') onHide() } else { @@ -83,6 +100,7 @@ const MCPServerModal = ({ payload.description = description await updateMCPServer(payload) invalidateMCPServerDetail(appID) + emitMcpServerUpdate('updated') onHide() } } @@ -92,6 +110,7 @@ const MCPServerModal = ({ isShow={show} onClose={onHide} className={cn('relative !max-w-[520px] !p-0')} + highPriority >
diff --git a/web/app/components/tools/mcp/mcp-service-card.tsx b/web/app/components/tools/mcp/mcp-service-card.tsx index 1f40b1e4b3..de027ce453 100644 --- a/web/app/components/tools/mcp/mcp-service-card.tsx +++ b/web/app/components/tools/mcp/mcp-service-card.tsx @@ -27,6 +27,8 @@ import { BlockEnum } from '@/app/components/workflow/types' import cn from '@/utils/classnames' import { fetchAppDetail } from '@/service/apps' import { useDocLink } from '@/context/i18n' +import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' export type IAppCardProps = { appInfo: AppDetailResponse & Partial @@ -97,6 +99,19 @@ function MCPServiceCard({ const onGenCode = async () => { await refreshMCPServerCode(detail?.id || '') invalidateMCPServerDetail(appId) + + // Emit collaboration event to notify other clients of MCP server changes + const socket = webSocketClient.getSocket(appId) + if (socket) { + socket.emit('collaboration_event', { + type: 'mcp_server_update', + data: { + action: 'codeRegenerated', + timestamp: Date.now(), + }, + timestamp: Date.now(), + }) + } } const onChangeStatus = async (state: boolean) => { @@ -126,6 +141,20 @@ function MCPServiceCard({ }) invalidateMCPServerDetail(appId) } + + // Emit collaboration event to notify other clients of MCP server status change + const socket = webSocketClient.getSocket(appId) + if (socket) { + socket.emit('collaboration_event', { + type: 'mcp_server_update', + data: { + action: 'statusChanged', + status: state ? 'active' : 'inactive', + timestamp: Date.now(), + }, + timestamp: Date.now(), + }) + } } const handleServerModalHide = () => { @@ -138,6 +167,23 @@ function MCPServiceCard({ setActivated(serverActivated) }, [serverActivated]) + // Listen for collaborative MCP server updates from other clients + useEffect(() => { + if (!appId) return + + const unsubscribe = collaborationManager.onMcpServerUpdate(async (update: any) => { + try { + console.log('Received MCP server update from collaboration:', update) + invalidateMCPServerDetail(appId) + } + catch (error) { + console.error('MCP server update failed:', error) + } + }) + + return unsubscribe + }, [appId, invalidateMCPServerDetail]) + if (!currentWorkflow && isAdvancedApp) return null diff --git a/web/app/components/workflow-app/components/workflow-main.tsx b/web/app/components/workflow-app/components/workflow-main.tsx index e90b2904c9..d63c717947 100644 --- a/web/app/components/workflow-app/components/workflow-main.tsx +++ b/web/app/components/workflow-app/components/workflow-main.tsx @@ -1,11 +1,18 @@ import { useCallback, + useEffect, useMemo, + useRef, + useState, } from 'react' import { useFeaturesStore } from '@/app/components/base/features/hooks' +import type { Features as FeaturesData } from '@/app/components/base/features/types' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' +import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' import { WorkflowWithInnerContext } from '@/app/components/workflow' import type { WorkflowProps } from '@/app/components/workflow' import WorkflowChildren from './workflow-children' + import { useAvailableNodesMetaData, useConfigsMap, @@ -18,7 +25,12 @@ import { useWorkflowRun, useWorkflowStartRun, } from '../hooks' -import { useWorkflowStore } from '@/app/components/workflow/store' +import { useWorkflowUpdate } from '@/app/components/workflow/hooks/use-workflow-interactions' +import { useStore, useWorkflowStore } from '@/app/components/workflow/store' +import { useCollaboration } from '@/app/components/workflow/collaboration' +import { collaborationManager } from '@/app/components/workflow/collaboration' +import { fetchWorkflowDraft } from '@/service/workflow' +import { useReactFlow, useStoreApi } from 'reactflow' type WorkflowMainProps = Pick const WorkflowMain = ({ @@ -28,6 +40,43 @@ const WorkflowMain = ({ }: WorkflowMainProps) => { const featuresStore = useFeaturesStore() const workflowStore = useWorkflowStore() + const appId = useStore(s => s.appId) + const containerRef = useRef(null) + const reactFlow = useReactFlow() + + const store = useStoreApi() + const { + startCursorTracking, + stopCursorTracking, + onlineUsers, + cursors, + isConnected, + isEnabled: isCollaborationEnabled, + } = useCollaboration(appId || '', store) + const [myUserId, setMyUserId] = useState(null) + + useEffect(() => { + if (isCollaborationEnabled && isConnected) + setMyUserId('current-user') + else + setMyUserId(null) + }, [isCollaborationEnabled, isConnected]) + + const filteredCursors = Object.fromEntries( + Object.entries(cursors).filter(([userId]) => userId !== myUserId), + ) + + useEffect(() => { + if (!isCollaborationEnabled) + return + + if (containerRef.current) + startCursorTracking(containerRef as React.RefObject, reactFlow) + + return () => { + stopCursorTracking() + } + }, [startCursorTracking, stopCursorTracking, reactFlow, isCollaborationEnabled]) const handleWorkflowDataUpdate = useCallback((payload: any) => { const { @@ -38,7 +87,33 @@ const WorkflowMain = ({ if (features && featuresStore) { const { setFeatures } = featuresStore.getState() - setFeatures(features) + const transformedFeatures: FeaturesData = { + file: { + image: { + enabled: !!features.file_upload?.image?.enabled, + number_limits: features.file_upload?.image?.number_limits || 3, + transfer_methods: features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + }, + enabled: !!(features.file_upload?.enabled || features.file_upload?.image?.enabled), + allowed_file_types: features.file_upload?.allowed_file_types || [SupportUploadFileTypes.image], + allowed_file_extensions: features.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), + allowed_file_upload_methods: features.file_upload?.allowed_file_upload_methods || features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + number_limits: features.file_upload?.number_limits || features.file_upload?.image?.number_limits || 3, + }, + opening: { + enabled: !!features.opening_statement, + opening_statement: features.opening_statement, + suggested_questions: features.suggested_questions, + }, + suggested: features.suggested_questions_after_answer || { enabled: false }, + speech2text: features.speech_to_text || { enabled: false }, + text2speech: features.text_to_speech || { enabled: false }, + citation: features.retriever_resource || { enabled: false }, + moderation: features.sensitive_word_avoidance || { enabled: false }, + annotationReply: features.annotation_reply || { enabled: false }, + } + + setFeatures(transformedFeatures) } if (conversation_variables) { const { setConversationVariables } = workflowStore.getState() @@ -55,6 +130,7 @@ const WorkflowMain = ({ syncWorkflowDraftWhenPageClose, } = useNodesSyncDraft() const { handleRefreshWorkflowDraft } = useWorkflowRefreshDraft() + const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() const { handleBackupDraft, handleLoadBackupDraft, @@ -62,6 +138,63 @@ const WorkflowMain = ({ handleRun, handleStopRun, } = useWorkflowRun() + + useEffect(() => { + if (!appId || !isCollaborationEnabled) return + + const unsubscribe = collaborationManager.onVarsAndFeaturesUpdate(async (update: any) => { + try { + const response = await fetchWorkflowDraft(`/apps/${appId}/workflows/draft`) + handleWorkflowDataUpdate(response) + } + catch (error) { + console.error('workflow vars and features update failed:', error) + } + }) + + return unsubscribe + }, [appId, handleWorkflowDataUpdate, isCollaborationEnabled]) + + // Listen for workflow updates from other users + useEffect(() => { + if (!appId || !isCollaborationEnabled) return + + const unsubscribe = collaborationManager.onWorkflowUpdate(async () => { + console.log('Received workflow update from collaborator, fetching latest workflow data') + try { + const response = await fetchWorkflowDraft(`/apps/${appId}/workflows/draft`) + + // Handle features, variables etc. + handleWorkflowDataUpdate(response) + + // Update workflow canvas (nodes, edges, viewport) + if (response.graph) { + handleUpdateWorkflowCanvas({ + nodes: response.graph.nodes || [], + edges: response.graph.edges || [], + viewport: response.graph.viewport || { x: 0, y: 0, zoom: 1 }, + }) + } + } + catch (error) { + console.error('Failed to fetch updated workflow:', error) + } + }) + + return unsubscribe + }, [appId, handleWorkflowDataUpdate, handleUpdateWorkflowCanvas, isCollaborationEnabled]) + + // Listen for sync requests from other users (only processed by leader) + useEffect(() => { + if (!appId || !isCollaborationEnabled) return + + const unsubscribe = collaborationManager.onSyncRequest(() => { + console.log('Leader received sync request, performing sync') + doSyncWorkflowDraft() + }) + + return unsubscribe + }, [appId, doSyncWorkflowDraft, isCollaborationEnabled]) const { handleStartWorkflowRun, handleWorkflowStartRunInChatflow, @@ -79,6 +212,7 @@ const WorkflowMain = ({ } = useDSL() const configsMap = useConfigsMap() + const { fetchInspectVars } = useSetWorkflowVarsWithValue({ ...configsMap, }) @@ -176,15 +310,23 @@ const WorkflowMain = ({ ]) return ( - - - + + + +
) } diff --git a/web/app/components/workflow-app/components/workflow-panel.tsx b/web/app/components/workflow-app/components/workflow-panel.tsx index 6e0504710e..693e6e78cb 100644 --- a/web/app/components/workflow-app/components/workflow-panel.tsx +++ b/web/app/components/workflow-app/components/workflow-panel.tsx @@ -7,6 +7,7 @@ import { useStore } from '@/app/components/workflow/store' import { useIsChatMode, } from '../hooks' +import CommentsPanel from '@/app/components/workflow/panel/comments-panel' import { useStore as useAppStore } from '@/app/components/app/store' import type { PanelProps } from '@/app/components/workflow/panel' import Panel from '@/app/components/workflow/panel' @@ -67,6 +68,7 @@ const WorkflowPanelOnRight = () => { const showDebugAndPreviewPanel = useStore(s => s.showDebugAndPreviewPanel) const showChatVariablePanel = useStore(s => s.showChatVariablePanel) const showGlobalVariablePanel = useStore(s => s.showGlobalVariablePanel) + const controlMode = useStore(s => s.controlMode) return ( <> @@ -100,6 +102,7 @@ const WorkflowPanelOnRight = () => { ) } + {controlMode === 'comment' && } ) } diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts index 56d9021feb..beac3d327e 100644 --- a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -1,13 +1,16 @@ import { useCallback } from 'react' import { produce } from 'immer' import { useStoreApi } from 'reactflow' +import { useParams } from 'next/navigation' import { useWorkflowStore } from '@/app/components/workflow/store' import { useNodesReadOnly } from '@/app/components/workflow/hooks/use-workflow' import { useSerialAsyncCallback } from '@/app/components/workflow/hooks/use-serial-async-callback' -import { syncWorkflowDraft } from '@/service/workflow' +import { type WorkflowDraftFeaturesPayload, syncWorkflowDraft } from '@/service/workflow' import { useFeaturesStore } from '@/app/components/base/features/hooks' import { API_PREFIX } from '@/config' import { useWorkflowRefreshDraft } from '.' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' +import { useGlobalPublicStore } from '@/context/global-public-context' export const useNodesSyncDraft = () => { const store = useStoreApi() @@ -15,6 +18,8 @@ export const useNodesSyncDraft = () => { const featuresStore = useFeaturesStore() const { getNodesReadOnly } = useNodesReadOnly() const { handleRefreshWorkflowDraft } = useWorkflowRefreshDraft() + const params = useParams() + const isCollaborationEnabled = useGlobalPublicStore(s => s.systemFeatures.enable_collaboration_mode) const getPostParams = useCallback(() => { const { @@ -52,7 +57,16 @@ export const useNodesSyncDraft = () => { }) }) }) - const viewport = { x, y, zoom } + const featuresPayload: WorkflowDraftFeaturesPayload = { + opening_statement: features.opening?.enabled ? (features.opening?.opening_statement || '') : '', + suggested_questions: features.opening?.enabled ? (features.opening?.suggested_questions || []) : [], + suggested_questions_after_answer: features.suggested, + text_to_speech: features.text2speech, + speech_to_text: features.speech2text, + retriever_resource: features.citation, + sensitive_word_avoidance: features.moderation, + file_upload: features.file, + } return { url: `/apps/${appId}/workflows/draft`, @@ -60,33 +74,44 @@ export const useNodesSyncDraft = () => { graph: { nodes: producedNodes, edges: producedEdges, - viewport, - }, - features: { - opening_statement: features.opening?.enabled ? (features.opening?.opening_statement || '') : '', - suggested_questions: features.opening?.enabled ? (features.opening?.suggested_questions || []) : [], - suggested_questions_after_answer: features.suggested, - text_to_speech: features.text2speech, - speech_to_text: features.speech2text, - retriever_resource: features.citation, - sensitive_word_avoidance: features.moderation, - file_upload: features.file, + viewport: { + x, + y, + zoom, + }, }, + features: featuresPayload, environment_variables: environmentVariables, conversation_variables: conversationVariables, hash: syncWorkflowDraftHash, + _is_collaborative: isCollaborationEnabled, }, } - }, [store, featuresStore, workflowStore]) + }, [store, featuresStore, workflowStore, isCollaborationEnabled]) const syncWorkflowDraftWhenPageClose = useCallback(() => { if (getNodesReadOnly()) return + + // Check leader status at sync time + const currentIsLeader = isCollaborationEnabled ? collaborationManager.getIsLeader() : true + + // Only allow leader to sync data + if (isCollaborationEnabled && !currentIsLeader) { + console.log('Not leader, skipping sync on page close') + return + } + const postParams = getPostParams() - if (postParams) - navigator.sendBeacon(`${API_PREFIX}${postParams.url}`, JSON.stringify(postParams.params)) - }, [getPostParams, getNodesReadOnly]) + if (postParams) { + console.log('Leader syncing workflow draft on page close') + navigator.sendBeacon( + `${API_PREFIX}/apps/${params.appId}/workflows/draft`, + JSON.stringify(postParams.params), + ) + } + }, [getPostParams, params.appId, getNodesReadOnly, isCollaborationEnabled]) const performSync = useCallback(async ( notRefreshWhenSyncError?: boolean, @@ -95,9 +120,24 @@ export const useNodesSyncDraft = () => { onError?: () => void onSettled?: () => void }, + forceUpload?: boolean, ) => { if (getNodesReadOnly()) return + + // Check leader status at sync time + const currentIsLeader = isCollaborationEnabled ? collaborationManager.getIsLeader() : true + + // If not leader and not forcing upload, request the leader to sync + if (isCollaborationEnabled && !currentIsLeader && !forceUpload) { + console.log('Not leader, requesting leader to sync workflow draft') + if (isCollaborationEnabled) + collaborationManager.emitSyncRequest() + callback?.onSettled?.() + return + } + + console.log(forceUpload ? 'Force uploading workflow draft' : 'Leader performing workflow draft sync') const postParams = getPostParams() if (postParams) { @@ -105,17 +145,31 @@ export const useNodesSyncDraft = () => { setSyncWorkflowDraftHash, setDraftUpdatedAt, } = workflowStore.getState() + + // Add force_upload parameter if needed + const finalParams = { + ...postParams.params, + ...(forceUpload && { force_upload: true }), + } + try { - const res = await syncWorkflowDraft(postParams) + const res = await syncWorkflowDraft({ + url: postParams.url, + params: finalParams, + }) setSyncWorkflowDraftHash(res.hash) setDraftUpdatedAt(res.updated_at) + console.log('Leader successfully synced workflow draft') callback?.onSuccess?.() } catch (error: any) { + console.error('Leader failed to sync workflow draft:', error) if (error && error.json && !error.bodyUsed) { error.json().then((err: any) => { - if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) + if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) { + console.error('draft_workflow_not_sync', err) handleRefreshWorkflowDraft() + } }) } callback?.onError?.() @@ -124,7 +178,7 @@ export const useNodesSyncDraft = () => { callback?.onSettled?.() } } - }, [workflowStore, getPostParams, getNodesReadOnly, handleRefreshWorkflowDraft]) + }, [workflowStore, getPostParams, getNodesReadOnly, handleRefreshWorkflowDraft, isCollaborationEnabled]) const doSyncWorkflowDraft = useSerialAsyncCallback(performSync, getNodesReadOnly) diff --git a/web/app/components/workflow-app/index.tsx b/web/app/components/workflow-app/index.tsx index fcd247ef22..d5242c76e4 100644 --- a/web/app/components/workflow-app/index.tsx +++ b/web/app/components/workflow-app/index.tsx @@ -30,6 +30,7 @@ import { import type { InjectWorkflowStoreSliceFn } from '@/app/components/workflow/store' import { createWorkflowSlice } from './store/workflow/workflow-slice' import WorkflowAppMain from './components/workflow-main' +import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' import { useSearchParams } from 'next/navigation' import { fetchRunDetail } from '@/service/log' @@ -83,15 +84,20 @@ const WorkflowAppWithAdditionalContext = () => { }, [workflowStore]) const nodesData = useMemo(() => { - if (data) - return initialNodes(data.graph.nodes, data.graph.edges) - + if (data) { + const processedNodes = initialNodes(data.graph.nodes, data.graph.edges) + collaborationManager.setNodes([], processedNodes) + return processedNodes + } return [] }, [data]) - const edgesData = useMemo(() => { - if (data) - return initialEdges(data.graph.edges, data.graph.nodes) + const edgesData = useMemo(() => { + if (data) { + const processedEdges = initialEdges(data.graph.edges, data.graph.nodes) + collaborationManager.setEdges([], processedEdges) + return processedEdges + } return [] }, [data]) diff --git a/web/app/components/workflow/candidate-node.tsx b/web/app/components/workflow/candidate-node.tsx index 54daf13ebc..e2e70add69 100644 --- a/web/app/components/workflow/candidate-node.tsx +++ b/web/app/components/workflow/candidate-node.tsx @@ -4,7 +4,6 @@ import { import { produce } from 'immer' import { useReactFlow, - useStoreApi, useViewport, } from 'reactflow' import { useEventListener } from 'ahooks' @@ -19,9 +18,9 @@ import CustomNode from './nodes' import CustomNoteNode from './note-node' import { CUSTOM_NOTE_NODE } from './note-node/constants' import { BlockEnum } from './types' +import { useCollaborativeWorkflow } from '@/app/components/workflow/hooks/use-collaborative-workflow' const CandidateNode = () => { - const store = useStoreApi() const reactflow = useReactFlow() const workflowStore = useWorkflowStore() const candidateNode = useStore(s => s.candidateNode) @@ -31,18 +30,15 @@ const CandidateNode = () => { const { saveStateToHistory } = useWorkflowHistory() const { handleSyncWorkflowDraft } = useNodesSyncDraft() const autoGenerateWebhookUrl = useAutoGenerateWebhookUrl() + const collaborativeWorkflow = useCollaborativeWorkflow() useEventListener('click', (e) => { const { candidateNode, mousePosition } = workflowStore.getState() if (candidateNode) { e.preventDefault() - const { - getNodes, - setNodes, - } = store.getState() + const { nodes, setNodes } = collaborativeWorkflow.getState() const { screenToFlowPosition } = reactflow - const nodes = getNodes() const { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY }) const newNodes = produce(nodes, (draft) => { draft.push({ diff --git a/web/app/components/workflow/collaboration/components/user-cursors.tsx b/web/app/components/workflow/collaboration/components/user-cursors.tsx new file mode 100644 index 0000000000..fff4d5bb8c --- /dev/null +++ b/web/app/components/workflow/collaboration/components/user-cursors.tsx @@ -0,0 +1,78 @@ +import type { FC } from 'react' +import { useViewport } from 'reactflow' +import type { CursorPosition, OnlineUser } from '@/app/components/workflow/collaboration/types' +import { getUserColor } from '../utils/user-color' + +type UserCursorsProps = { + cursors: Record + myUserId: string | null + onlineUsers: OnlineUser[] +} + +const UserCursors: FC = ({ + cursors, + myUserId, + onlineUsers, +}) => { + const viewport = useViewport() + + const convertToScreenCoordinates = (cursor: CursorPosition) => { + // Convert world coordinates to screen coordinates using current viewport + const screenX = cursor.x * viewport.zoom + viewport.x + const screenY = cursor.y * viewport.zoom + viewport.y + + return { x: screenX, y: screenY } + } + return ( + <> + {Object.entries(cursors || {}).map(([userId, cursor]) => { + if (userId === myUserId) + return null + + const userInfo = onlineUsers.find(user => user.user_id === userId) + const userName = userInfo?.username || `User ${userId.slice(-4)}` + const userColor = getUserColor(userId) + const screenPos = convertToScreenCoordinates(cursor) + + return ( +
+ + + + +
+ {userName} +
+
+ ) + })} + + ) +} + +export default UserCursors diff --git a/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.merge-behavior.test.ts b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.merge-behavior.test.ts new file mode 100644 index 0000000000..7893ed74a4 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.merge-behavior.test.ts @@ -0,0 +1,239 @@ +import { LoroDoc } from 'loro-crdt' +import { CollaborationManager } from '../collaboration-manager' +import type { Node } from '@/app/components/workflow/types' +import { BlockEnum } from '@/app/components/workflow/types' + +const NODE_ID = 'node-1' +const LLM_NODE_ID = 'llm-node' +const PARAM_NODE_ID = 'parameter-node' + +const createNode = (variables: string[]): Node => ({ + id: NODE_ID, + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + title: 'Start', + desc: '', + variables: variables.map(name => ({ + variable: name, + label: name, + type: 'text-input', + required: true, + default: '', + max_length: 48, + placeholder: '', + options: [], + hint: '', + })), + }, +}) + +const createLLMNode = (templates: Array<{ id: string; role: string; text: string }>): Node => ({ + id: LLM_NODE_ID, + type: 'custom', + position: { x: 200, y: 200 }, + data: { + type: BlockEnum.LLM, + title: 'LLM', + selected: false, + model: { + mode: 'chat', + name: 'gemini-2.5-pro', + provider: 'langgenius/gemini/google', + completion_params: { + temperature: 0.7, + }, + }, + context: { + enabled: false, + variable_selector: [], + }, + vision: { + enabled: false, + }, + prompt_template: templates, + }, +}) + +const createParameterExtractorNode = (parameters: Array<{ description: string; name: string; required: boolean; type: string }>): Node => ({ + id: PARAM_NODE_ID, + type: 'custom', + position: { x: 400, y: 120 }, + data: { + type: BlockEnum.ParameterExtractor, + title: 'ParameterExtractor', + selected: true, + model: { + mode: 'chat', + name: '', + provider: '', + completion_params: { + temperature: 0.7, + }, + }, + query: [], + reasoning_mode: 'prompt', + parameters, + vision: { + enabled: false, + }, + }, +}) + +const getManager = (doc: LoroDoc) => { + const manager = new CollaborationManager() + ;(manager as any).doc = doc + ;(manager as any).nodesMap = doc.getMap('nodes') + ;(manager as any).edgesMap = doc.getMap('edges') + return manager +} + +const deepClone = (value: T): T => JSON.parse(JSON.stringify(value)) + +const exportNodes = (manager: CollaborationManager) => manager.getNodes() + +describe('Loro merge behavior smoke test', () => { + it('inspects concurrent edits after merge', () => { + const docA = new LoroDoc() + const managerA = getManager(docA) + managerA.syncNodes([], [createNode(['a'])]) + + const snapshot = docA.export({ mode: 'snapshot' }) + + const docB = LoroDoc.fromSnapshot(snapshot) + const managerB = getManager(docB) + + managerA.syncNodes([createNode(['a'])], [createNode(['a', 'b'])]) + managerB.syncNodes([createNode(['a'])], [createNode(['a', 'c'])]) + + const updateForA = docB.export({ mode: 'update', from: docA.version() }) + docA.import(updateForA) + + const updateForB = docA.export({ mode: 'update', from: docB.version() }) + docB.import(updateForB) + + const finalA = exportNodes(managerA) + const finalB = exportNodes(managerB) + + console.log('Final nodes on docA:', JSON.stringify(finalA, null, 2)) + + console.log('Final nodes on docB:', JSON.stringify(finalB, null, 2)) + expect(finalA.length).toBe(1) + expect(finalB.length).toBe(1) + }) + + it('merges prompt template insertions and edits across replicas', () => { + const baseTemplate = [ + { + id: 'system-1', + role: 'system', + text: 'base instruction', + }, + ] + + const docA = new LoroDoc() + const managerA = getManager(docA) + managerA.syncNodes([], [createLLMNode(deepClone(baseTemplate))]) + + const snapshot = docA.export({ mode: 'snapshot' }) + const docB = LoroDoc.fromSnapshot(snapshot) + const managerB = getManager(docB) + + const additionTemplate = [ + ...baseTemplate, + { + id: 'user-1', + role: 'user', + text: 'hello from docA', + }, + ] + managerA.syncNodes([createLLMNode(deepClone(baseTemplate))], [createLLMNode(deepClone(additionTemplate))]) + + const editedTemplate = [ + { + id: 'system-1', + role: 'system', + text: 'updated by docB', + }, + ] + managerB.syncNodes([createLLMNode(deepClone(baseTemplate))], [createLLMNode(deepClone(editedTemplate))]) + + const updateForA = docB.export({ mode: 'update', from: docA.version() }) + docA.import(updateForA) + + const updateForB = docA.export({ mode: 'update', from: docB.version() }) + docB.import(updateForB) + + const finalA = exportNodes(managerA).find(node => node.id === LLM_NODE_ID) + const finalB = exportNodes(managerB).find(node => node.id === LLM_NODE_ID) + + expect(finalA).toBeDefined() + expect(finalB).toBeDefined() + + const expectedTemplates = [ + { + id: 'system-1', + role: 'system', + text: 'updated by docB', + }, + { + id: 'user-1', + role: 'user', + text: 'hello from docA', + }, + ] + + expect((finalA!.data as any).prompt_template).toEqual(expectedTemplates) + expect((finalB!.data as any).prompt_template).toEqual(expectedTemplates) + }) + + it('converges when parameter lists are edited concurrently', () => { + const baseParameters = [ + { description: 'bb', name: 'aa', required: false, type: 'string' }, + { description: 'dd', name: 'cc', required: false, type: 'string' }, + ] + + const docA = new LoroDoc() + const managerA = getManager(docA) + managerA.syncNodes([], [createParameterExtractorNode(deepClone(baseParameters))]) + + const snapshot = docA.export({ mode: 'snapshot' }) + const docB = LoroDoc.fromSnapshot(snapshot) + const managerB = getManager(docB) + + const docAUpdate = [ + { description: 'bb updated by A', name: 'aa', required: true, type: 'string' }, + { description: 'dd', name: 'cc', required: false, type: 'string' }, + { description: 'new from A', name: 'ee', required: false, type: 'number' }, + ] + managerA.syncNodes([createParameterExtractorNode(deepClone(baseParameters))], [createParameterExtractorNode(deepClone(docAUpdate))]) + + const docBUpdate = [ + { description: 'bb', name: 'aa', required: false, type: 'string' }, + { description: 'dd updated by B', name: 'cc', required: true, type: 'string' }, + ] + managerB.syncNodes([createParameterExtractorNode(deepClone(baseParameters))], [createParameterExtractorNode(deepClone(docBUpdate))]) + + const updateForA = docB.export({ mode: 'update', from: docA.version() }) + docA.import(updateForA) + + const updateForB = docA.export({ mode: 'update', from: docB.version() }) + docB.import(updateForB) + + const finalA = exportNodes(managerA).find(node => node.id === PARAM_NODE_ID) + const finalB = exportNodes(managerB).find(node => node.id === PARAM_NODE_ID) + + expect(finalA).toBeDefined() + expect(finalB).toBeDefined() + + const expectedParameters = [ + { description: 'bb updated by A', name: 'aa', required: true, type: 'string' }, + { description: 'dd updated by B', name: 'cc', required: true, type: 'string' }, + { description: 'new from A', name: 'ee', required: false, type: 'number' }, + ] + + expect((finalA!.data as any).parameters).toEqual(expectedParameters) + expect((finalB!.data as any).parameters).toEqual(expectedParameters) + }) +}) diff --git a/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.test.ts b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.test.ts new file mode 100644 index 0000000000..e15e7e17e3 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/collaboration-manager.test.ts @@ -0,0 +1,659 @@ +import { LoroDoc } from 'loro-crdt' +import { CollaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager' +import { BlockEnum } from '@/app/components/workflow/types' +import type { Edge, Node } from '@/app/components/workflow/types' +import type { NodePanelPresenceMap, NodePanelPresenceUser } from '@/app/components/workflow/collaboration/types/collaboration' + +const NODE_ID = '1760342909316' + +type WorkflowVariable = { + default: string + hint: string + label: string + max_length: number + options: string[] + placeholder: string + required: boolean + type: string + variable: string +} + +type PromptTemplateItem = { + id: string + role: string + text: string +} + +type ParameterItem = { + description: string + name: string + required: boolean + type: string +} + +const createVariable = (name: string, overrides: Partial = {}): WorkflowVariable => ({ + default: '', + hint: '', + label: name, + max_length: 48, + options: [], + placeholder: '', + required: true, + type: 'text-input', + variable: name, + ...overrides, +}) + +const deepClone = (value: T): T => JSON.parse(JSON.stringify(value)) + +const createNodeSnapshot = (variableNames: string[]): Node<{ variables: WorkflowVariable[] }> => ({ + id: NODE_ID, + type: 'custom', + position: { x: 0, y: 24 }, + positionAbsolute: { x: 0, y: 24 }, + height: 88, + width: 242, + selected: true, + selectable: true, + draggable: true, + sourcePosition: 'right', + targetPosition: 'left', + data: { + selected: true, + title: '开始', + desc: '', + type: BlockEnum.Start, + variables: variableNames.map(createVariable), + }, +}) + +const LLM_NODE_ID = 'llm-node' +const PARAM_NODE_ID = 'param-extractor-node' + +const createLLMNodeSnapshot = (promptTemplates: PromptTemplateItem[]): Node => ({ + id: LLM_NODE_ID, + type: 'custom', + position: { x: 200, y: 120 }, + positionAbsolute: { x: 200, y: 120 }, + height: 320, + width: 460, + selected: false, + selectable: true, + draggable: true, + sourcePosition: 'right', + targetPosition: 'left', + data: { + type: 'llm', + title: 'LLM', + selected: false, + context: { + enabled: false, + variable_selector: [], + }, + model: { + mode: 'chat', + name: 'gemini-2.5-pro', + provider: 'langgenius/gemini/google', + completion_params: { + temperature: 0.7, + }, + }, + vision: { + enabled: false, + }, + prompt_template: promptTemplates, + }, +}) + +const createParameterExtractorNodeSnapshot = (parameters: ParameterItem[]): Node => ({ + id: PARAM_NODE_ID, + type: 'custom', + position: { x: 420, y: 220 }, + positionAbsolute: { x: 420, y: 220 }, + height: 260, + width: 420, + selected: true, + selectable: true, + draggable: true, + sourcePosition: 'right', + targetPosition: 'left', + data: { + type: 'parameter-extractor', + title: '参数提取器', + selected: true, + model: { + mode: 'chat', + name: '', + provider: '', + completion_params: { + temperature: 0.7, + }, + }, + reasoning_mode: 'prompt', + parameters, + query: [], + vision: { + enabled: false, + }, + }, +}) + +const getVariables = (node: Node): string[] => { + const variables = (node.data as any)?.variables ?? [] + return variables.map((item: WorkflowVariable) => item.variable) +} + +const getVariableObject = (node: Node, name: string): WorkflowVariable | undefined => { + const variables = (node.data as any)?.variables ?? [] + return variables.find((item: WorkflowVariable) => item.variable === name) +} + +const getPromptTemplates = (node: Node): PromptTemplateItem[] => { + return ((node.data as any)?.prompt_template ?? []) as PromptTemplateItem[] +} + +const getParameters = (node: Node): ParameterItem[] => { + return ((node.data as any)?.parameters ?? []) as ParameterItem[] +} + +describe('CollaborationManager syncNodes', () => { + let manager: CollaborationManager + + beforeEach(() => { + manager = new CollaborationManager() + // Bypass private guards for targeted unit testing + const doc = new LoroDoc() + ;(manager as any).doc = doc + ;(manager as any).nodesMap = doc.getMap('nodes') + ;(manager as any).edgesMap = doc.getMap('edges') + + const initialNode = createNodeSnapshot(['a']) + ;(manager as any).syncNodes([], [deepClone(initialNode)]) + }) + + it('updates collaborators map when a single client adds a variable', () => { + const base = [createNodeSnapshot(['a'])] + const next = [createNodeSnapshot(['a', 'b'])] + + ;(manager as any).syncNodes(base, next) + + const stored = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID) + expect(stored).toBeDefined() + expect(getVariables(stored!)).toEqual(['a', 'b']) + }) + + it('applies the latest parallel additions derived from the same base snapshot', () => { + const base = [createNodeSnapshot(['a'])] + const userA = [createNodeSnapshot(['a', 'b'])] + const userB = [createNodeSnapshot(['a', 'c'])] + + ;(manager as any).syncNodes(base, userA) + + const afterUserA = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID) + expect(getVariables(afterUserA!)).toEqual(['a', 'b']) + + ;(manager as any).syncNodes(base, userB) + + const finalNode = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID) + const finalVariables = getVariables(finalNode!) + + expect(finalVariables).toEqual(['a', 'c']) + }) + + it('prefers the incoming mutation when the same variable is edited concurrently', () => { + const base = [createNodeSnapshot(['a'])] + const userA = [ + { + ...createNodeSnapshot(['a']), + data: { + ...createNodeSnapshot(['a']).data, + variables: [ + createVariable('a', { label: 'A from userA', hint: 'hintA' }), + ], + }, + }, + ] + const userB = [ + { + ...createNodeSnapshot(['a']), + data: { + ...createNodeSnapshot(['a']).data, + variables: [ + createVariable('a', { label: 'A from userB', hint: 'hintB' }), + ], + }, + }, + ] + + ;(manager as any).syncNodes(base, userA) + ;(manager as any).syncNodes(base, userB) + + const finalNode = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID) + const finalVariable = getVariableObject(finalNode!, 'a') + + expect(finalVariable?.label).toBe('A from userB') + expect(finalVariable?.hint).toBe('hintB') + }) + + it('reflects the last writer when concurrent removal and edits happen', () => { + const base = [createNodeSnapshot(['a', 'b'])] + ;(manager as any).syncNodes([], [deepClone(base[0])]) + const userA = [ + { + ...createNodeSnapshot(['a']), + data: { + ...createNodeSnapshot(['a']).data, + variables: [ + createVariable('a', { label: 'A after deletion' }), + ], + }, + }, + ] + const userB = [ + { + ...createNodeSnapshot(['a', 'b']), + data: { + ...createNodeSnapshot(['a']).data, + variables: [ + createVariable('a'), + createVariable('b', { label: 'B edited but should vanish' }), + ], + }, + }, + ] + + ;(manager as any).syncNodes(base, userA) + ;(manager as any).syncNodes(base, userB) + + const finalNode = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID) + const finalVariables = getVariables(finalNode!) + expect(finalVariables).toEqual(['a', 'b']) + expect(getVariableObject(finalNode!, 'b')).toBeDefined() + }) + + it('synchronizes prompt_template list updates across collaborators', () => { + const promptManager = new CollaborationManager() + const doc = new LoroDoc() + ;(promptManager as any).doc = doc + ;(promptManager as any).nodesMap = doc.getMap('nodes') + ;(promptManager as any).edgesMap = doc.getMap('edges') + + const baseTemplate = [ + { + id: 'abcfa5f9-3c44-4252-aeba-4b6eaf0acfc4', + role: 'system', + text: 'avc', + }, + ] + + const baseNode = createLLMNodeSnapshot(baseTemplate) + ;(promptManager as any).syncNodes([], [deepClone(baseNode)]) + + const updatedTemplates = [ + ...baseTemplate, + { + id: 'user-1', + role: 'user', + text: 'hello world', + }, + ] + + const updatedNode = createLLMNodeSnapshot(updatedTemplates) + ;(promptManager as any).syncNodes([deepClone(baseNode)], [deepClone(updatedNode)]) + + const stored = (promptManager.getNodes() as Node[]).find(node => node.id === LLM_NODE_ID) + expect(stored).toBeDefined() + + const storedTemplates = getPromptTemplates(stored!) + expect(storedTemplates).toHaveLength(2) + expect(storedTemplates[0]).toEqual(baseTemplate[0]) + expect(storedTemplates[1]).toEqual(updatedTemplates[1]) + + const editedTemplates = [ + { + id: 'abcfa5f9-3c44-4252-aeba-4b6eaf0acfc4', + role: 'system', + text: 'updated system prompt', + }, + ] + const editedNode = createLLMNodeSnapshot(editedTemplates) + + ;(promptManager as any).syncNodes([deepClone(updatedNode)], [deepClone(editedNode)]) + + const final = (promptManager.getNodes() as Node[]).find(node => node.id === LLM_NODE_ID) + const finalTemplates = getPromptTemplates(final!) + expect(finalTemplates).toHaveLength(1) + expect(finalTemplates[0].text).toBe('updated system prompt') + }) + + it('keeps parameter list in sync when nodes add, edit, or remove parameters', () => { + const parameterManager = new CollaborationManager() + const doc = new LoroDoc() + ;(parameterManager as any).doc = doc + ;(parameterManager as any).nodesMap = doc.getMap('nodes') + ;(parameterManager as any).edgesMap = doc.getMap('edges') + + const baseParameters: ParameterItem[] = [ + { description: 'bb', name: 'aa', required: false, type: 'string' }, + { description: 'dd', name: 'cc', required: false, type: 'string' }, + ] + + const baseNode = createParameterExtractorNodeSnapshot(baseParameters) + ;(parameterManager as any).syncNodes([], [deepClone(baseNode)]) + + const updatedParameters: ParameterItem[] = [ + ...baseParameters, + { description: 'ff', name: 'ee', required: true, type: 'number' }, + ] + + const updatedNode = createParameterExtractorNodeSnapshot(updatedParameters) + ;(parameterManager as any).syncNodes([deepClone(baseNode)], [deepClone(updatedNode)]) + + const stored = (parameterManager.getNodes() as Node[]).find(node => node.id === PARAM_NODE_ID) + expect(stored).toBeDefined() + expect(getParameters(stored!)).toEqual(updatedParameters) + + const editedParameters: ParameterItem[] = [ + { description: 'bb edited', name: 'aa', required: true, type: 'string' }, + ] + const editedNode = createParameterExtractorNodeSnapshot(editedParameters) + + ;(parameterManager as any).syncNodes([deepClone(updatedNode)], [deepClone(editedNode)]) + + const final = (parameterManager.getNodes() as Node[]).find(node => node.id === PARAM_NODE_ID) + expect(getParameters(final!)).toEqual(editedParameters) + }) + + it('handles nodes without data gracefully', () => { + const emptyNode: Node = { + id: 'empty-node', + type: 'custom', + position: { x: 0, y: 0 }, + data: undefined as any, + } + + ;(manager as any).syncNodes([], [deepClone(emptyNode)]) + + const stored = (manager.getNodes() as Node[]).find(node => node.id === 'empty-node') + expect(stored).toBeDefined() + expect(stored?.data).toEqual({}) + }) + + it('preserves CRDT list instances when synchronizing parsed state back into the manager', () => { + const promptManager = new CollaborationManager() + const doc = new LoroDoc() + ;(promptManager as any).doc = doc + ;(promptManager as any).nodesMap = doc.getMap('nodes') + ;(promptManager as any).edgesMap = doc.getMap('edges') + + const base = createLLMNodeSnapshot([ + { id: 'system', role: 'system', text: 'base' }, + ]) + ;(promptManager as any).syncNodes([], [deepClone(base)]) + + const storedBefore = promptManager.getNodes().find(node => node.id === LLM_NODE_ID) + const firstTemplate = (storedBefore?.data as any).prompt_template?.[0] + expect(firstTemplate?.text).toBe('base') + + // simulate consumer mutating the plain JSON array and syncing back + const mutatedNode = deepClone(storedBefore!) + mutatedNode.data.prompt_template.push({ + id: 'user', + role: 'user', + text: 'mutated', + }) + + ;(promptManager as any).syncNodes([storedBefore], [mutatedNode]) + + const storedAfter = promptManager.getNodes().find(node => node.id === LLM_NODE_ID) + const templatesAfter = (storedAfter?.data as any).prompt_template + expect(Array.isArray(templatesAfter)).toBe(true) + expect(templatesAfter).toHaveLength(2) + }) + + it('reuses CRDT list when syncing parameters repeatedly', () => { + const parameterManager = new CollaborationManager() + const doc = new LoroDoc() + ;(parameterManager as any).doc = doc + ;(parameterManager as any).nodesMap = doc.getMap('nodes') + ;(parameterManager as any).edgesMap = doc.getMap('edges') + + const initialParameters: ParameterItem[] = [ + { description: 'desc', name: 'param', required: false, type: 'string' }, + ] + const node = createParameterExtractorNodeSnapshot(initialParameters) + ;(parameterManager as any).syncNodes([], [deepClone(node)]) + + const stored = parameterManager.getNodes().find(n => n.id === PARAM_NODE_ID)! + const mutatedNode = deepClone(stored) + mutatedNode.data.parameters[0].description = 'updated' + + ;(parameterManager as any).syncNodes([stored], [mutatedNode]) + + const storedAfter = parameterManager.getNodes().find(n => n.id === PARAM_NODE_ID)! + const params = (storedAfter.data as any).parameters + expect(params).toHaveLength(1) + expect(params[0].description).toBe('updated') + }) + + it('filters out transient/private data keys while keeping allowlisted ones', () => { + const nodeWithPrivate: Node = { + id: 'private-node', + type: 'custom', + position: { x: 0, y: 0 }, + data: { + type: BlockEnum.Start, + _foo: 'should disappear', + _children: ['child-a'], + selected: true, + variables: [], + }, + } + + ;(manager as any).syncNodes([], [deepClone(nodeWithPrivate)]) + + const stored = (manager.getNodes() as Node[]).find(node => node.id === 'private-node')! + expect((stored.data as any)._foo).toBeUndefined() + expect((stored.data as any)._children).toEqual(['child-a']) + expect((stored.data as any).selected).toBeUndefined() + }) + + it('removes list fields when they are omitted in the update snapshot', () => { + const baseNode = createNodeSnapshot(['alpha']) + ;(manager as any).syncNodes([], [deepClone(baseNode)]) + + const withoutVariables: Node = { + ...deepClone(baseNode), + data: { + ...deepClone(baseNode).data, + }, + } + delete (withoutVariables.data as any).variables + + ;(manager as any).syncNodes([deepClone(baseNode)], [withoutVariables]) + + const stored = (manager.getNodes() as Node[]).find(node => node.id === NODE_ID)! + expect((stored.data as any).variables).toBeUndefined() + }) + + it('treats non-array list inputs as empty lists during synchronization', () => { + const promptManager = new CollaborationManager() + const doc = new LoroDoc() + ;(promptManager as any).doc = doc + ;(promptManager as any).nodesMap = doc.getMap('nodes') + ;(promptManager as any).edgesMap = doc.getMap('edges') + + const nodeWithInvalidTemplate = createLLMNodeSnapshot([] as any) + ;(promptManager as any).syncNodes([], [deepClone(nodeWithInvalidTemplate)]) + + const mutated = deepClone(nodeWithInvalidTemplate) + ;(mutated.data as any).prompt_template = 'not-an-array' + + ;(promptManager as any).syncNodes([deepClone(nodeWithInvalidTemplate)], [mutated]) + + const stored = promptManager.getNodes().find(node => node.id === LLM_NODE_ID)! + expect(Array.isArray((stored.data as any).prompt_template)).toBe(true) + expect((stored.data as any).prompt_template).toHaveLength(0) + }) + + it('updates edges map when edges are added, modified, and removed', () => { + const edgeManager = new CollaborationManager() + const doc = new LoroDoc() + ;(edgeManager as any).doc = doc + ;(edgeManager as any).nodesMap = doc.getMap('nodes') + ;(edgeManager as any).edgesMap = doc.getMap('edges') + + const edge: Edge = { + id: 'edge-1', + source: 'node-a', + target: 'node-b', + type: 'default', + data: { label: 'initial' }, + } as Edge + + ;(edgeManager as any).setEdges([], [edge]) + expect(edgeManager.getEdges()).toHaveLength(1) + expect((edgeManager.getEdges()[0].data as any).label).toBe('initial') + + const updatedEdge: Edge = { + ...edge, + data: { label: 'updated' }, + } + ;(edgeManager as any).setEdges([edge], [updatedEdge]) + expect(edgeManager.getEdges()).toHaveLength(1) + expect((edgeManager.getEdges()[0].data as any).label).toBe('updated') + + ;(edgeManager as any).setEdges([updatedEdge], []) + expect(edgeManager.getEdges()).toHaveLength(0) + }) +}) + +describe('CollaborationManager public API wrappers', () => { + let manager: CollaborationManager + const baseNodes: Node[] = [] + const updatedNodes: Node[] = [ + { id: 'new-node', type: 'custom', position: { x: 0, y: 0 }, data: {} } as Node, + ] + const baseEdges: Edge[] = [] + const updatedEdges: Edge[] = [ + { id: 'edge-1', source: 'source', target: 'target', type: 'default', data: {} } as Edge, + ] + + beforeEach(() => { + manager = new CollaborationManager() + }) + + it('setNodes delegates to syncNodes and commits the CRDT document', () => { + const commit = jest.fn() + ;(manager as any).doc = { commit } + const syncSpy = jest.spyOn(manager as any, 'syncNodes').mockImplementation(() => undefined) + + manager.setNodes(baseNodes, updatedNodes) + + expect(syncSpy).toHaveBeenCalledWith(baseNodes, updatedNodes) + expect(commit).toHaveBeenCalled() + syncSpy.mockRestore() + }) + + it('setNodes skips syncing when undo/redo replay is running', () => { + const commit = jest.fn() + ;(manager as any).doc = { commit } + ;(manager as any).isUndoRedoInProgress = true + const syncSpy = jest.spyOn(manager as any, 'syncNodes').mockImplementation(() => undefined) + + manager.setNodes(baseNodes, updatedNodes) + + expect(syncSpy).not.toHaveBeenCalled() + expect(commit).not.toHaveBeenCalled() + syncSpy.mockRestore() + }) + + it('setEdges delegates to syncEdges and commits the CRDT document', () => { + const commit = jest.fn() + ;(manager as any).doc = { commit } + const syncSpy = jest.spyOn(manager as any, 'syncEdges').mockImplementation(() => undefined) + + manager.setEdges(baseEdges, updatedEdges) + + expect(syncSpy).toHaveBeenCalledWith(baseEdges, updatedEdges) + expect(commit).toHaveBeenCalled() + syncSpy.mockRestore() + }) + + it('disconnect tears down the collaboration state only when last connection closes', () => { + const forceSpy = jest.spyOn(manager as any, 'forceDisconnect').mockImplementation(() => undefined) + ;(manager as any).activeConnections.add('conn-a') + ;(manager as any).activeConnections.add('conn-b') + + manager.disconnect('conn-a') + expect(forceSpy).not.toHaveBeenCalled() + + manager.disconnect('conn-b') + expect(forceSpy).toHaveBeenCalledTimes(1) + forceSpy.mockRestore() + }) + + it('applyNodePanelPresenceUpdate keeps a client visible on a single node at a time', () => { + const updates: NodePanelPresenceMap[] = [] + manager.onNodePanelPresenceUpdate((presence) => { + updates.push(presence) + }) + + const user: NodePanelPresenceUser = { userId: 'user-1', username: 'Dana' } + + ;(manager as any).applyNodePanelPresenceUpdate({ + nodeId: 'node-a', + action: 'open', + user, + clientId: 'client-1', + timestamp: 100, + }) + + ;(manager as any).applyNodePanelPresenceUpdate({ + nodeId: 'node-b', + action: 'open', + user, + clientId: 'client-1', + timestamp: 200, + }) + + const finalSnapshot = updates[updates.length - 1]! + expect(finalSnapshot).toEqual({ + 'node-b': { + 'client-1': { + userId: 'user-1', + username: 'Dana', + clientId: 'client-1', + timestamp: 200, + }, + }, + }) + }) + + it('applyNodePanelPresenceUpdate clears node entries when last viewer closes the panel', () => { + const updates: NodePanelPresenceMap[] = [] + manager.onNodePanelPresenceUpdate((presence) => { + updates.push(presence) + }) + + const user: NodePanelPresenceUser = { userId: 'user-2', username: 'Kai' } + + ;(manager as any).applyNodePanelPresenceUpdate({ + nodeId: 'node-a', + action: 'open', + user, + clientId: 'client-9', + timestamp: 300, + }) + + ;(manager as any).applyNodePanelPresenceUpdate({ + nodeId: 'node-a', + action: 'close', + user, + clientId: 'client-9', + timestamp: 301, + }) + + expect(updates[updates.length - 1]).toEqual({}) + }) +}) diff --git a/web/app/components/workflow/collaboration/core/__tests__/crdt-provider.test.ts b/web/app/components/workflow/collaboration/core/__tests__/crdt-provider.test.ts new file mode 100644 index 0000000000..b5fc3ca69e --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/crdt-provider.test.ts @@ -0,0 +1,121 @@ +import type { Socket } from 'socket.io-client' +import { CRDTProvider } from '../crdt-provider' + +type FakeDoc = { + export: jest.Mock + import: jest.Mock + subscribe: jest.Mock void]> + trigger: (event: any) => void +} + +const createFakeDoc = (): FakeDoc => { + let handler: ((payload: any) => void) | null = null + + return { + export: jest.fn(() => new Uint8Array([1, 2, 3])), + import: jest.fn(), + subscribe: jest.fn((cb: (payload: any) => void) => { + handler = cb + }), + trigger: (event: any) => { + handler?.(event) + }, + } +} + +const createMockSocket = () => { + const handlers = new Map void>() + + const socket: any = { + emit: jest.fn(), + on: jest.fn((event: string, handler: (...args: any[]) => void) => { + handlers.set(event, handler) + }), + off: jest.fn((event: string) => { + handlers.delete(event) + }), + trigger: (event: string, ...args: any[]) => { + const handler = handlers.get(event) + if (handler) + handler(...args) + }, + } + + return socket as Socket & { trigger: (event: string, ...args: any[]) => void } +} + +describe('CRDTProvider', () => { + it('emits graph_event when local changes happen', () => { + const doc = createFakeDoc() + const socket = createMockSocket() + + const provider = new CRDTProvider(socket, doc as unknown as any) + expect(provider).toBeInstanceOf(CRDTProvider) + + doc.trigger({ by: 'local' }) + + expect(socket.emit).toHaveBeenCalledWith( + 'graph_event', + expect.any(Uint8Array), + ) + expect(doc.export).toHaveBeenCalledWith({ mode: 'update' }) + }) + + it('ignores non-local events', () => { + const doc = createFakeDoc() + const socket = createMockSocket() + + const provider = new CRDTProvider(socket, doc as unknown as any) + + doc.trigger({ by: 'remote' }) + + expect(socket.emit).not.toHaveBeenCalled() + provider.destroy() + }) + + it('imports remote updates on graph_update', () => { + const doc = createFakeDoc() + const socket = createMockSocket() + + const provider = new CRDTProvider(socket, doc as unknown as any) + + const payload = new Uint8Array([9, 9, 9]) + socket.trigger('graph_update', payload) + + expect(doc.import).toHaveBeenCalledWith(expect.any(Uint8Array)) + expect(Array.from(doc.import.mock.calls[0][0])).toEqual([9, 9, 9]) + provider.destroy() + }) + + it('removes graph_update listener on destroy', () => { + const doc = createFakeDoc() + const socket = createMockSocket() + + const provider = new CRDTProvider(socket, doc as unknown as any) + provider.destroy() + + expect(socket.off).toHaveBeenCalledWith('graph_update') + }) + + it('logs an error when graph_update import fails but continues operating', () => { + const doc = createFakeDoc() + const socket = createMockSocket() + doc.import.mockImplementation(() => { + throw new Error('boom') + }) + + const provider = new CRDTProvider(socket, doc as unknown as any) + + const errorSpy = jest.spyOn(console, 'error').mockImplementation(() => undefined) + + socket.trigger('graph_update', new Uint8Array([1])) + expect(errorSpy).toHaveBeenCalledWith('Error importing graph update:', expect.any(Error)) + + doc.import.mockReset() + socket.trigger('graph_update', new Uint8Array([2, 3])) + expect(doc.import).toHaveBeenCalled() + + provider.destroy() + errorSpy.mockRestore() + }) +}) diff --git a/web/app/components/workflow/collaboration/core/__tests__/event-emitter.test.ts b/web/app/components/workflow/collaboration/core/__tests__/event-emitter.test.ts new file mode 100644 index 0000000000..b57350209e --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/event-emitter.test.ts @@ -0,0 +1,93 @@ +import { EventEmitter } from '../event-emitter' + +describe('EventEmitter', () => { + it('registers and invokes handlers via on/emit', () => { + const emitter = new EventEmitter() + const handler = jest.fn() + + emitter.on('test', handler) + emitter.emit('test', { value: 42 }) + + expect(handler).toHaveBeenCalledWith({ value: 42 }) + }) + + it('removes specific handler with off', () => { + const emitter = new EventEmitter() + const handlerA = jest.fn() + const handlerB = jest.fn() + + emitter.on('test', handlerA) + emitter.on('test', handlerB) + + emitter.off('test', handlerA) + emitter.emit('test', 'payload') + + expect(handlerA).not.toHaveBeenCalled() + expect(handlerB).toHaveBeenCalledWith('payload') + }) + + it('clears all listeners when off is called without handler', () => { + const emitter = new EventEmitter() + const handlerA = jest.fn() + const handlerB = jest.fn() + + emitter.on('trigger', handlerA) + emitter.on('trigger', handlerB) + + emitter.off('trigger') + emitter.emit('trigger', 'payload') + + expect(handlerA).not.toHaveBeenCalled() + expect(handlerB).not.toHaveBeenCalled() + expect(emitter.getListenerCount('trigger')).toBe(0) + }) + + it('removeAllListeners clears every registered event', () => { + const emitter = new EventEmitter() + emitter.on('one', jest.fn()) + emitter.on('two', jest.fn()) + + emitter.removeAllListeners() + + expect(emitter.getListenerCount('one')).toBe(0) + expect(emitter.getListenerCount('two')).toBe(0) + }) + + it('returns an unsubscribe function from on', () => { + const emitter = new EventEmitter() + const handler = jest.fn() + + const unsubscribe = emitter.on('detach', handler) + unsubscribe() + + emitter.emit('detach', 'value') + + expect(handler).not.toHaveBeenCalled() + }) + + it('continues emitting when a handler throws', () => { + const emitter = new EventEmitter() + const errorHandler = jest + .spyOn(console, 'error') + .mockImplementation() + + const failingHandler = jest.fn(() => { + throw new Error('boom') + }) + const succeedingHandler = jest.fn() + + emitter.on('safe', failingHandler) + emitter.on('safe', succeedingHandler) + + emitter.emit('safe', 7) + + expect(failingHandler).toHaveBeenCalledWith(7) + expect(succeedingHandler).toHaveBeenCalledWith(7) + expect(errorHandler).toHaveBeenCalledWith( + expect.stringContaining('Error in event handler for safe:'), + expect.any(Error), + ) + + errorHandler.mockRestore() + }) +}) diff --git a/web/app/components/workflow/collaboration/core/__tests__/websocket-manager.test.ts b/web/app/components/workflow/collaboration/core/__tests__/websocket-manager.test.ts new file mode 100644 index 0000000000..c7b26d1bd3 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/__tests__/websocket-manager.test.ts @@ -0,0 +1,165 @@ +import type { Socket } from 'socket.io-client' + +const ioMock = jest.fn() + +jest.mock('socket.io-client', () => ({ + io: (...args: any[]) => ioMock(...args), +})) + +const createMockSocket = (id: string): Socket & { + trigger: (event: string, ...args: any[]) => void +} => { + const handlers = new Map void>() + + const socket: any = { + id, + connected: true, + emit: jest.fn(), + disconnect: jest.fn(() => { + socket.connected = false + }), + on: jest.fn((event: string, handler: (...args: any[]) => void) => { + handlers.set(event, handler) + }), + trigger: (event: string, ...args: any[]) => { + const handler = handlers.get(event) + if (handler) + handler(...args) + }, + } + + return socket as Socket & { trigger: (event: string, ...args: any[]) => void } +} + +describe('WebSocketClient', () => { + let originalWindow: typeof window | undefined + + beforeEach(() => { + jest.resetModules() + ioMock.mockReset() + originalWindow = globalThis.window + }) + + afterEach(() => { + if (originalWindow) + globalThis.window = originalWindow + else + delete (globalThis as any).window + }) + + it('connects with fallback url and registers base listeners when window is undefined', async () => { + delete (globalThis as any).window + + const mockSocket = createMockSocket('socket-fallback') + ioMock.mockImplementation(() => mockSocket) + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + const socket = client.connect('app-1') + + expect(ioMock).toHaveBeenCalledWith( + 'ws://localhost:5001', + expect.objectContaining({ + path: '/socket.io', + transports: ['websocket'], + withCredentials: true, + }), + ) + expect(socket).toBe(mockSocket) + expect(mockSocket.on).toHaveBeenCalledWith('connect', expect.any(Function)) + expect(mockSocket.on).toHaveBeenCalledWith('disconnect', expect.any(Function)) + expect(mockSocket.on).toHaveBeenCalledWith('connect_error', expect.any(Function)) + }) + + it('reuses existing connected socket and avoids duplicate connections', async () => { + const mockSocket = createMockSocket('socket-reuse') + ioMock.mockImplementation(() => mockSocket) + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + + const first = client.connect('app-reuse') + const second = client.connect('app-reuse') + + expect(ioMock).toHaveBeenCalledTimes(1) + expect(second).toBe(first) + }) + + it('attaches auth token from localStorage and emits user_connect on connect', async () => { + const mockSocket = createMockSocket('socket-auth') + ioMock.mockImplementation((url, options) => { + expect(options.auth).toEqual({ token: 'secret-token' }) + return mockSocket + }) + + globalThis.window = { + location: { protocol: 'https:', host: 'example.com' }, + localStorage: { + getItem: jest.fn(() => 'secret-token'), + }, + } as unknown as typeof window + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + client.connect('app-auth') + + const connectHandler = mockSocket.on.mock.calls.find(call => call[0] === 'connect')?.[1] as () => void + expect(connectHandler).toBeDefined() + connectHandler() + + expect(mockSocket.emit).toHaveBeenCalledWith('user_connect', { workflow_id: 'app-auth' }) + }) + + it('disconnects a specific app and clears internal maps', async () => { + const mockSocket = createMockSocket('socket-disconnect-one') + ioMock.mockImplementation(() => mockSocket) + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + client.connect('app-disconnect') + + expect(client.isConnected('app-disconnect')).toBe(true) + client.disconnect('app-disconnect') + + expect(mockSocket.disconnect).toHaveBeenCalled() + expect(client.getSocket('app-disconnect')).toBeNull() + expect(client.isConnected('app-disconnect')).toBe(false) + }) + + it('disconnects all apps when no id is provided', async () => { + const socketA = createMockSocket('socket-a') + const socketB = createMockSocket('socket-b') + ioMock.mockImplementationOnce(() => socketA).mockImplementationOnce(() => socketB) + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + client.connect('app-a') + client.connect('app-b') + + client.disconnect() + + expect(socketA.disconnect).toHaveBeenCalled() + expect(socketB.disconnect).toHaveBeenCalled() + expect(client.getConnectedApps()).toEqual([]) + }) + + it('reports connected apps, sockets, and debug info correctly', async () => { + const socketA = createMockSocket('socket-debug-a') + const socketB = createMockSocket('socket-debug-b') + socketB.connected = false + ioMock.mockImplementationOnce(() => socketA).mockImplementationOnce(() => socketB) + + const { WebSocketClient } = await import('../websocket-manager') + const client = new WebSocketClient() + client.connect('app-a') + client.connect('app-b') + + expect(client.getConnectedApps()).toEqual(['app-a']) + + const debugInfo = client.getDebugInfo() + expect(debugInfo).toMatchObject({ + 'app-a': { connected: true, socketId: 'socket-debug-a' }, + 'app-b': { connected: false, socketId: 'socket-debug-b' }, + }) + }) +}) diff --git a/web/app/components/workflow/collaboration/core/collaboration-manager.ts b/web/app/components/workflow/collaboration/core/collaboration-manager.ts new file mode 100644 index 0000000000..669601e819 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/collaboration-manager.ts @@ -0,0 +1,1073 @@ +import { LoroDoc, LoroList, LoroMap, UndoManager } from 'loro-crdt' +import { cloneDeep, isEqual } from 'lodash-es' +import type { Socket } from 'socket.io-client' +import { emitWithAuthGuard, webSocketClient } from './websocket-manager' +import { CRDTProvider } from './crdt-provider' +import { EventEmitter } from './event-emitter' +import type { + CommonNodeType, + Edge, + Node, +} from '../../types' +import type { + CollaborationState, + CursorPosition, + NodePanelPresenceMap, + NodePanelPresenceUser, + OnlineUser, +} from '../types/collaboration' + +type NodePanelPresenceEventData = { + nodeId: string + action: 'open' | 'close' + user: NodePanelPresenceUser + clientId: string + timestamp?: number +} + +export class CollaborationManager { + private doc: LoroDoc | null = null + private undoManager: UndoManager | null = null + private provider: CRDTProvider | null = null + private nodesMap: any = null + private edgesMap: any = null + private eventEmitter = new EventEmitter() + private currentAppId: string | null = null + private reactFlowStore: any = null + private isLeader = false + private leaderId: string | null = null + private cursors: Record = {} + private nodePanelPresence: NodePanelPresenceMap = {} + private activeConnections = new Set() + private isUndoRedoInProgress = false + private pendingInitialSync = false + private rejoinInProgress = false + + private getActiveSocket(): Socket | null { + if (!this.currentAppId) + return null + return webSocketClient.getSocket(this.currentAppId) + } + + private handleSessionUnauthorized = (): void => { + if (this.rejoinInProgress) + return + if (!this.currentAppId) + return + + const socket = this.getActiveSocket() + if (!socket) + return + + this.rejoinInProgress = true + console.warn('Collaboration session expired, attempting to rejoin workflow.') + emitWithAuthGuard( + socket, + 'user_connect', + { workflow_id: this.currentAppId }, + { + onAck: () => { + this.rejoinInProgress = false + }, + onUnauthorized: () => { + this.rejoinInProgress = false + console.error('Rejoin failed due to authorization error, forcing disconnect.') + this.forceDisconnect() + }, + }, + ) + } + + private sendCollaborationEvent(payload: any): void { + const socket = this.getActiveSocket() + if (!socket) + return + + emitWithAuthGuard(socket, 'collaboration_event', payload, { onUnauthorized: this.handleSessionUnauthorized }) + } + + private sendGraphEvent(payload: any): void { + const socket = this.getActiveSocket() + if (!socket) + return + + emitWithAuthGuard(socket, 'graph_event', payload, { onUnauthorized: this.handleSessionUnauthorized }) + } + + private getNodeContainer(nodeId: string): LoroMap { + if (!this.nodesMap) + throw new Error('Nodes map not initialized') + + let container = this.nodesMap.get(nodeId) as any + + if (!container || typeof container.kind !== 'function' || container.kind() !== 'Map') { + const previousValue = container + const newContainer = this.nodesMap.setContainer(nodeId, new LoroMap()) + container = typeof newContainer.getAttached === 'function' ? newContainer.getAttached() ?? newContainer : newContainer + if (previousValue && typeof previousValue === 'object') + this.populateNodeContainer(container, previousValue as Node) + } + else { + container = typeof container.getAttached === 'function' ? container.getAttached() ?? container : container + } + + return container + } + + private ensureDataContainer(nodeContainer: LoroMap): LoroMap { + let dataContainer = nodeContainer.get('data') as any + + if (!dataContainer || typeof dataContainer.kind !== 'function' || dataContainer.kind() !== 'Map') + dataContainer = nodeContainer.setContainer('data', new LoroMap()) + + return typeof dataContainer.getAttached === 'function' ? dataContainer.getAttached() ?? dataContainer : dataContainer + } + + private ensureList(nodeContainer: LoroMap, key: string): LoroList { + const dataContainer = this.ensureDataContainer(nodeContainer) + let list = dataContainer.get(key) as any + + if (!list || typeof list.kind !== 'function' || list.kind() !== 'List') + list = dataContainer.setContainer(key, new LoroList()) + + return typeof list.getAttached === 'function' ? list.getAttached() ?? list : list + } + + private exportNode(nodeId: string): Node { + const container = this.getNodeContainer(nodeId) + const json = container.toJSON() as any + return { + ...json, + data: json.data || {}, + } + } + + private populateNodeContainer(container: LoroMap, node: Node): void { + const listFields = new Set(['variables', 'prompt_template', 'parameters']) + container.set('id', node.id) + container.set('type', node.type) + container.set('position', cloneDeep(node.position)) + container.set('sourcePosition', node.sourcePosition) + container.set('targetPosition', node.targetPosition) + + if (node.width === undefined) container.delete('width') + else container.set('width', node.width) + + if (node.height === undefined) container.delete('height') + else container.set('height', node.height) + + if (node.selected === undefined) container.delete('selected') + else container.set('selected', node.selected) + + const optionalProps: Array = [ + 'parentId', + 'positionAbsolute', + 'extent', + 'zIndex', + 'draggable', + 'selectable', + 'dragHandle', + 'dragging', + 'connectable', + 'expandParent', + 'focusable', + 'hidden', + 'style', + 'className', + 'ariaLabel', + 'resizing', + 'deletable', + ] + + optionalProps.forEach((prop) => { + const value = node[prop] + if (value === undefined) + container.delete(prop as string) + else + container.set(prop as string, cloneDeep(value as any)) + }) + + const dataContainer = this.ensureDataContainer(container) + const handledKeys = new Set() + + Object.entries(node.data || {}).forEach(([key, value]) => { + if (!this.shouldSyncDataKey(key)) return + handledKeys.add(key) + + if (listFields.has(key)) + this.syncList(container, key, Array.isArray(value) ? value : []) + else + dataContainer.set(key, cloneDeep(value)) + }) + + const existingData = dataContainer.toJSON() || {} + Object.keys(existingData).forEach((key) => { + if (!this.shouldSyncDataKey(key)) return + if (handledKeys.has(key)) return + + dataContainer.delete(key) + }) + } + + private shouldSyncDataKey(key: string): boolean { + const syncDataAllowList = new Set(['_children', '_connectedSourceHandleIds', '_connectedTargetHandleIds', '_targetBranches']) + return (syncDataAllowList.has(key) || !key.startsWith('_')) && key !== 'selected' + } + + private syncList(nodeContainer: LoroMap, key: string, desired: any[]): void { + const list = this.ensureList(nodeContainer, key) + const current = list.toJSON() as any[] + const target = Array.isArray(desired) ? desired : [] + const minLength = Math.min(current.length, target.length) + + for (let i = 0; i < minLength; i += 1) { + if (!isEqual(current[i], target[i])) { + list.delete(i, 1) + list.insert(i, cloneDeep(target[i])) + } + } + + if (current.length > target.length) { + list.delete(target.length, current.length - target.length) + } + else if (target.length > current.length) { + for (let i = current.length; i < target.length; i += 1) + list.insert(i, cloneDeep(target[i])) + } + } + + private getNodePanelPresenceSnapshot(): NodePanelPresenceMap { + const snapshot: NodePanelPresenceMap = {} + Object.entries(this.nodePanelPresence).forEach(([nodeId, viewers]) => { + snapshot[nodeId] = { ...viewers } + }) + return snapshot + } + + private applyNodePanelPresenceUpdate(update: NodePanelPresenceEventData): void { + const { nodeId, action, clientId, user, timestamp } = update + + if (action === 'open') { + // ensure a client only appears on a single node at a time + Object.entries(this.nodePanelPresence).forEach(([id, viewers]) => { + if (viewers[clientId]) { + delete viewers[clientId] + if (Object.keys(viewers).length === 0) + delete this.nodePanelPresence[id] + } + }) + + if (!this.nodePanelPresence[nodeId]) + this.nodePanelPresence[nodeId] = {} + + this.nodePanelPresence[nodeId][clientId] = { + ...user, + clientId, + timestamp: timestamp || Date.now(), + } + } + else { + const viewers = this.nodePanelPresence[nodeId] + if (viewers) { + delete viewers[clientId] + if (Object.keys(viewers).length === 0) + delete this.nodePanelPresence[nodeId] + } + } + + this.eventEmitter.emit('nodePanelPresence', this.getNodePanelPresenceSnapshot()) + } + + private cleanupNodePanelPresence(activeClientIds: Set, activeUserIds: Set): void { + let hasChanges = false + + Object.entries(this.nodePanelPresence).forEach(([nodeId, viewers]) => { + Object.keys(viewers).forEach((clientId) => { + const viewer = viewers[clientId] + const clientActive = activeClientIds.has(clientId) + const userActive = viewer?.userId ? activeUserIds.has(viewer.userId) : false + + if (!clientActive && !userActive) { + delete viewers[clientId] + hasChanges = true + } + }) + + if (Object.keys(viewers).length === 0) + delete this.nodePanelPresence[nodeId] + }) + + if (hasChanges) + this.eventEmitter.emit('nodePanelPresence', this.getNodePanelPresenceSnapshot()) + } + + init = (appId: string, reactFlowStore: any): void => { + if (!reactFlowStore) { + console.warn('CollaborationManager.init called without reactFlowStore, deferring to connect()') + return + } + this.connect(appId, reactFlowStore) + } + + setNodes = (oldNodes: Node[], newNodes: Node[]): void => { + if (!this.doc) return + + // Don't track operations during undo/redo to prevent loops + if (this.isUndoRedoInProgress) { + console.log('Skipping setNodes during undo/redo') + return + } + + console.log('Setting nodes with tracking') + this.syncNodes(oldNodes, newNodes) + this.doc.commit() + } + + setEdges = (oldEdges: Edge[], newEdges: Edge[]): void => { + if (!this.doc) return + + // Don't track operations during undo/redo to prevent loops + if (this.isUndoRedoInProgress) { + console.log('Skipping setEdges during undo/redo') + return + } + + console.log('Setting edges with tracking') + this.syncEdges(oldEdges, newEdges) + this.doc.commit() + } + + destroy = (): void => { + this.disconnect() + } + + async connect(appId: string, reactFlowStore?: any): Promise { + const connectionId = Math.random().toString(36).substring(2, 11) + + this.activeConnections.add(connectionId) + + if (this.currentAppId === appId && this.doc) { + // Already connected to the same app, only update store if provided and we don't have one + if (reactFlowStore && !this.reactFlowStore) + this.reactFlowStore = reactFlowStore + + return connectionId + } + + // Only disconnect if switching to a different app + if (this.currentAppId && this.currentAppId !== appId) + this.forceDisconnect() + + this.currentAppId = appId + // Only set store if provided + if (reactFlowStore) + this.reactFlowStore = reactFlowStore + + const socket = webSocketClient.connect(appId) + + // Setup event listeners BEFORE any other operations + this.setupSocketEventListeners(socket) + + this.doc = new LoroDoc() + this.nodesMap = this.doc.getMap('nodes') + this.edgesMap = this.doc.getMap('edges') + + // Initialize UndoManager for collaborative undo/redo + this.undoManager = new UndoManager(this.doc, { + maxUndoSteps: 100, + mergeInterval: 500, // Merge operations within 500ms + excludeOriginPrefixes: [], // Don't exclude anything - let UndoManager track all local operations + onPush: (isUndo, range, event) => { + console.log('UndoManager onPush:', { isUndo, range, event }) + // Store current selection state when an operation is pushed + const selectedNode = this.reactFlowStore?.getState().getNodes().find((n: Node) => n.data?.selected) + + // Emit event to update UI button states when new operation is pushed + setTimeout(() => { + this.eventEmitter.emit('undoRedoStateChange', { + canUndo: this.undoManager?.canUndo() || false, + canRedo: this.undoManager?.canRedo() || false, + }) + }, 0) + + return { + value: { + selectedNodeId: selectedNode?.id || null, + timestamp: Date.now(), + }, + cursors: [], + } + }, + onPop: (isUndo, value, counterRange) => { + console.log('UndoManager onPop:', { isUndo, value, counterRange }) + // Restore selection state when undoing/redoing + if (value?.value && typeof value.value === 'object' && 'selectedNodeId' in value.value && this.reactFlowStore) { + const selectedNodeId = (value.value as any).selectedNodeId + if (selectedNodeId) { + const { setNodes } = this.reactFlowStore.getState() + const nodes = this.reactFlowStore.getState().getNodes() + const newNodes = nodes.map((n: Node) => ({ + ...n, + data: { + ...n.data, + selected: n.id === selectedNodeId, + }, + })) + setNodes(newNodes) + } + } + }, + }) + + this.provider = new CRDTProvider(socket, this.doc, this.handleSessionUnauthorized) + + this.setupSubscriptions() + + // Force user_connect if already connected + if (socket.connected) + emitWithAuthGuard(socket, 'user_connect', { workflow_id: appId }, { onUnauthorized: this.handleSessionUnauthorized }) + + return connectionId + } + + disconnect = (connectionId?: string): void => { + if (connectionId) + this.activeConnections.delete(connectionId) + + // Only disconnect when no more connections + if (this.activeConnections.size === 0) + this.forceDisconnect() + } + + private forceDisconnect = (): void => { + if (this.currentAppId) + webSocketClient.disconnect(this.currentAppId) + + this.provider?.destroy() + this.undoManager = null + this.doc = null + this.provider = null + this.nodesMap = null + this.edgesMap = null + this.currentAppId = null + this.reactFlowStore = null + this.cursors = {} + this.nodePanelPresence = {} + this.isUndoRedoInProgress = false + this.rejoinInProgress = false + + // Only reset leader status when actually disconnecting + const wasLeader = this.isLeader + this.isLeader = false + this.leaderId = null + + if (wasLeader) + this.eventEmitter.emit('leaderChange', false) + + this.activeConnections.clear() + this.eventEmitter.removeAllListeners() + } + + isConnected(): boolean { + return this.currentAppId ? webSocketClient.isConnected(this.currentAppId) : false + } + + getNodes(): Node[] { + if (!this.nodesMap) return [] + return Array.from(this.nodesMap.keys()).map(id => this.exportNode(id as string)) + } + + getEdges(): Edge[] { + return this.edgesMap ? Array.from(this.edgesMap.values()) : [] + } + + emitCursorMove(position: CursorPosition): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return + + const socket = this.getActiveSocket() + if (!socket) + return + + this.sendCollaborationEvent({ + type: 'mouse_move', + userId: socket.id, + data: { x: position.x, y: position.y }, + timestamp: Date.now(), + }) + } + + emitSyncRequest(): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return + + console.log('Emitting sync request to leader') + this.sendCollaborationEvent({ + type: 'sync_request', + data: { timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + + emitWorkflowUpdate(appId: string): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return + + console.log('Emitting Workflow update event') + this.sendCollaborationEvent({ + type: 'workflow_update', + data: { appId, timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + + emitNodePanelPresence(nodeId: string, isOpen: boolean, user: NodePanelPresenceUser): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return + + const socket = this.getActiveSocket() + if (!socket || !nodeId || !user?.userId) return + + const payload: NodePanelPresenceEventData = { + nodeId, + action: isOpen ? 'open' : 'close', + user, + clientId: socket.id as string, + timestamp: Date.now(), + } + + this.sendCollaborationEvent({ + type: 'node_panel_presence', + data: payload, + timestamp: payload.timestamp, + }) + + this.applyNodePanelPresenceUpdate(payload) + } + + onSyncRequest(callback: () => void): () => void { + return this.eventEmitter.on('syncRequest', callback) + } + + onStateChange(callback: (state: Partial) => void): () => void { + return this.eventEmitter.on('stateChange', callback) + } + + onCursorUpdate(callback: (cursors: Record) => void): () => void { + return this.eventEmitter.on('cursors', callback) + } + + onOnlineUsersUpdate(callback: (users: OnlineUser[]) => void): () => void { + return this.eventEmitter.on('onlineUsers', callback) + } + + onWorkflowUpdate(callback: (update: { appId: string; timestamp: number }) => void): () => void { + return this.eventEmitter.on('workflowUpdate', callback) + } + + onVarsAndFeaturesUpdate(callback: (update: any) => void): () => void { + return this.eventEmitter.on('varsAndFeaturesUpdate', callback) + } + + onAppStateUpdate(callback: (update: any) => void): () => void { + return this.eventEmitter.on('appStateUpdate', callback) + } + + onAppPublishUpdate(callback: (update: any) => void): () => void { + return this.eventEmitter.on('appPublishUpdate', callback) + } + + onAppMetaUpdate(callback: (update: any) => void): () => void { + return this.eventEmitter.on('appMetaUpdate', callback) + } + + onMcpServerUpdate(callback: (update: any) => void): () => void { + return this.eventEmitter.on('mcpServerUpdate', callback) + } + + onNodePanelPresenceUpdate(callback: (presence: NodePanelPresenceMap) => void): () => void { + const off = this.eventEmitter.on('nodePanelPresence', callback) + callback(this.getNodePanelPresenceSnapshot()) + return off + } + + onLeaderChange(callback: (isLeader: boolean) => void): () => void { + return this.eventEmitter.on('leaderChange', callback) + } + + onCommentsUpdate(callback: (update: { appId: string; timestamp: number }) => void): () => void { + return this.eventEmitter.on('commentsUpdate', callback) + } + + emitCommentsUpdate(appId: string): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return + + console.log('Emitting Comments update event') + this.sendCollaborationEvent({ + type: 'comments_update', + data: { appId, timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + + onUndoRedoStateChange(callback: (state: { canUndo: boolean; canRedo: boolean }) => void): () => void { + return this.eventEmitter.on('undoRedoStateChange', callback) + } + + getLeaderId(): string | null { + return this.leaderId + } + + getIsLeader(): boolean { + return this.isLeader + } + + // Collaborative undo/redo methods + undo(): boolean { + if (!this.undoManager) { + console.log('UndoManager not initialized') + return false + } + + const canUndo = this.undoManager.canUndo() + console.log('Can undo:', canUndo) + + if (canUndo) { + this.isUndoRedoInProgress = true + const result = this.undoManager.undo() + + // After undo, manually update React state from CRDT without triggering collaboration + if (result && this.reactFlowStore) { + requestAnimationFrame(() => { + // Get ReactFlow's native setters, not the collaborative ones + const state = this.reactFlowStore.getState() + const updatedNodes = Array.from(this.nodesMap.values()) + const updatedEdges = Array.from(this.edgesMap.values()) + console.log('Manually updating React state after undo') + + // Call ReactFlow's native setters directly to avoid triggering collaboration + state.setNodes(updatedNodes) + state.setEdges(updatedEdges) + + this.isUndoRedoInProgress = false + + // Emit event to update UI button states + this.eventEmitter.emit('undoRedoStateChange', { + canUndo: this.undoManager?.canUndo() || false, + canRedo: this.undoManager?.canRedo() || false, + }) + }) + } + else { + this.isUndoRedoInProgress = false + } + + console.log('Undo result:', result) + return result + } + + return false + } + + redo(): boolean { + if (!this.undoManager) { + console.log('RedoManager not initialized') + return false + } + + const canRedo = this.undoManager.canRedo() + console.log('Can redo:', canRedo) + + if (canRedo) { + this.isUndoRedoInProgress = true + const result = this.undoManager.redo() + + // After redo, manually update React state from CRDT without triggering collaboration + if (result && this.reactFlowStore) { + requestAnimationFrame(() => { + // Get ReactFlow's native setters, not the collaborative ones + const state = this.reactFlowStore.getState() + const updatedNodes = Array.from(this.nodesMap.values()) + const updatedEdges = Array.from(this.edgesMap.values()) + console.log('Manually updating React state after redo') + + // Call ReactFlow's native setters directly to avoid triggering collaboration + state.setNodes(updatedNodes) + state.setEdges(updatedEdges) + + this.isUndoRedoInProgress = false + + // Emit event to update UI button states + this.eventEmitter.emit('undoRedoStateChange', { + canUndo: this.undoManager?.canUndo() || false, + canRedo: this.undoManager?.canRedo() || false, + }) + }) + } + else { + this.isUndoRedoInProgress = false + } + + console.log('Redo result:', result) + return result + } + + return false + } + + canUndo(): boolean { + if (!this.undoManager) return false + return this.undoManager.canUndo() + } + + canRedo(): boolean { + if (!this.undoManager) return false + return this.undoManager.canRedo() + } + + clearUndoStack(): void { + if (!this.undoManager) return + this.undoManager.clear() + } + + debugLeaderStatus(): void { + console.log('=== Leader Status Debug ===') + console.log('Current leader status:', this.isLeader) + console.log('Current leader ID:', this.leaderId) + console.log('Active connections:', this.activeConnections.size) + console.log('Connected:', this.isConnected()) + console.log('Current app ID:', this.currentAppId) + console.log('Has ReactFlow store:', !!this.reactFlowStore) + console.log('========================') + } + + private syncNodes(oldNodes: Node[], newNodes: Node[]): void { + if (!this.nodesMap || !this.doc) return + + const newIdSet = new Set(newNodes.map(node => node.id)) + + oldNodes.forEach((oldNode) => { + if (!newIdSet.has(oldNode.id)) + this.nodesMap.delete(oldNode.id) + }) + + newNodes.forEach((newNode) => { + const nodeContainer = this.getNodeContainer(newNode.id) + this.populateNodeContainer(nodeContainer, newNode) + }) + } + + private syncEdges(oldEdges: Edge[], newEdges: Edge[]): void { + if (!this.edgesMap) return + + const oldEdgesMap = new Map(oldEdges.map(edge => [edge.id, edge])) + const newEdgesMap = new Map(newEdges.map(edge => [edge.id, edge])) + + oldEdges.forEach((oldEdge) => { + if (!newEdgesMap.has(oldEdge.id)) + this.edgesMap.delete(oldEdge.id) + }) + + newEdges.forEach((newEdge) => { + const oldEdge = oldEdgesMap.get(newEdge.id) + if (!oldEdge) { + const clonedEdge = cloneDeep(newEdge) + this.edgesMap.set(newEdge.id, clonedEdge) + } + else if (!isEqual(oldEdge, newEdge)) { + const clonedEdge = cloneDeep(newEdge) + this.edgesMap.set(newEdge.id, clonedEdge) + } + }) + } + + private setupSubscriptions(): void { + this.nodesMap?.subscribe((event: any) => { + console.log('nodesMap subscription event:', event) + if (event.by === 'import' && this.reactFlowStore) { + // Don't update React nodes during undo/redo to prevent loops + if (this.isUndoRedoInProgress) { + console.log('Skipping nodes subscription update during undo/redo') + return + } + + requestAnimationFrame(() => { + const state = this.reactFlowStore.getState() + const previousNodes: Node[] = state.getNodes() + const previousNodeMap = new Map(previousNodes.map(node => [node.id, node])) + const selectedIds = new Set( + previousNodes + .filter(node => node.data?.selected) + .map(node => node.id), + ) + + this.pendingInitialSync = false + + const updatedNodes = Array + .from(this.nodesMap.keys()) + .map((nodeId) => { + const node = this.exportNode(nodeId as string) + const clonedNode: Node = { + ...node, + data: { + ...(node.data || {}), + }, + } + const clonedNodeData = clonedNode.data as (CommonNodeType & Record) + // Keep the previous node's private data properties (starting with _) + const previousNode = previousNodeMap.get(clonedNode.id) + if (previousNode?.data) { + const previousData = previousNode.data as Record + Object.entries(previousData) + .filter(([key]) => key.startsWith('_')) + .forEach(([key, value]) => { + if (!(key in clonedNodeData)) + clonedNodeData[key] = value + }) + } + + if (selectedIds.has(clonedNode.id)) + clonedNode.data.selected = true + + return clonedNode + }) + + console.log('Updating React nodes from subscription') + + // Call ReactFlow's native setter directly to avoid triggering collaboration + state.setNodes(updatedNodes) + }) + } + }) + + this.edgesMap?.subscribe((event: any) => { + console.log('edgesMap subscription event:', event) + if (event.by === 'import' && this.reactFlowStore) { + // Don't update React edges during undo/redo to prevent loops + if (this.isUndoRedoInProgress) { + console.log('Skipping edges subscription update during undo/redo') + return + } + + requestAnimationFrame(() => { + // Get ReactFlow's native setters, not the collaborative ones + const state = this.reactFlowStore.getState() + const updatedEdges = Array.from(this.edgesMap.values()) + console.log('Updating React edges from subscription') + + this.pendingInitialSync = false + + // Call ReactFlow's native setter directly to avoid triggering collaboration + state.setEdges(updatedEdges) + }) + } + }) + } + + private setupSocketEventListeners(socket: any): void { + console.log('Setting up socket event listeners for collaboration') + + socket.on('collaboration_update', (update: any) => { + if (update.type === 'mouse_move') { + // Update cursor state for this user + this.cursors[update.userId] = { + x: update.data.x, + y: update.data.y, + userId: update.userId, + timestamp: update.timestamp, + } + + this.eventEmitter.emit('cursors', { ...this.cursors }) + } + else if (update.type === 'vars_and_features_update') { + console.log('Processing vars_and_features_update event:', update) + this.eventEmitter.emit('varsAndFeaturesUpdate', update) + } + else if (update.type === 'app_state_update') { + console.log('Processing app_state_update event:', update) + this.eventEmitter.emit('appStateUpdate', update) + } + else if (update.type === 'app_meta_update') { + console.log('Processing app_meta_update event:', update) + this.eventEmitter.emit('appMetaUpdate', update) + } + else if (update.type === 'app_publish_update') { + console.log('Processing app_publish_update event:', update) + this.eventEmitter.emit('appPublishUpdate', update) + } + else if (update.type === 'mcp_server_update') { + console.log('Processing mcp_server_update event:', update) + this.eventEmitter.emit('mcpServerUpdate', update) + } + else if (update.type === 'workflow_update') { + console.log('Processing workflow_update event:', update) + this.eventEmitter.emit('workflowUpdate', update.data) + } + else if (update.type === 'comments_update') { + console.log('Processing comments_update event:', update) + this.eventEmitter.emit('commentsUpdate', update.data) + } + else if (update.type === 'node_panel_presence') { + console.log('Processing node_panel_presence event:', update) + this.applyNodePanelPresenceUpdate(update.data as NodePanelPresenceEventData) + } + else if (update.type === 'sync_request') { + console.log('Received sync request from another user') + // Only process if we are the leader + if (this.isLeader) { + console.log('Leader received sync request, triggering sync') + this.eventEmitter.emit('syncRequest', {}) + } + } + else if (update.type === 'graph_resync_request') { + console.log('Received graph resync request from collaborator') + if (this.isLeader) + this.broadcastCurrentGraph() + } + }) + + socket.on('online_users', (data: { users: OnlineUser[]; leader?: string }) => { + try { + if (!data || !Array.isArray(data.users)) { + console.warn('Invalid online_users data structure:', data) + return + } + + const onlineUserIds = new Set(data.users.map((user: OnlineUser) => user.user_id)) + const onlineClientIds = new Set( + data.users + .map((user: OnlineUser) => user.sid) + .filter((sid): sid is string => typeof sid === 'string' && sid.length > 0), + ) + + // Remove cursors for offline users + Object.keys(this.cursors).forEach((userId) => { + if (!onlineUserIds.has(userId)) + delete this.cursors[userId] + }) + + this.cleanupNodePanelPresence(onlineClientIds, onlineUserIds) + + // Update leader information + if (data.leader && typeof data.leader === 'string') + this.leaderId = data.leader + + this.eventEmitter.emit('onlineUsers', data.users) + this.eventEmitter.emit('cursors', { ...this.cursors }) + } + catch (error) { + console.error('Error processing online_users update:', error) + } + }) + + socket.on('status', (data: any) => { + try { + if (!data || typeof data.isLeader !== 'boolean') { + console.warn('Invalid status data:', data) + return + } + + const wasLeader = this.isLeader + this.isLeader = data.isLeader + + if (this.isLeader) + this.pendingInitialSync = false + else + this.requestInitialSyncIfNeeded() + + if (wasLeader !== this.isLeader) + this.eventEmitter.emit('leaderChange', this.isLeader) + } + catch (error) { + console.error('Error processing status update:', error) + } + }) + + socket.on('status', (data: { isLeader: boolean }) => { + if (this.isLeader !== data.isLeader) { + this.isLeader = data.isLeader + console.log(`Collaboration: I am now the ${this.isLeader ? 'Leader' : 'Follower'}.`) + this.eventEmitter.emit('leaderChange', this.isLeader) + } + if (this.isLeader) + this.pendingInitialSync = false + else + this.requestInitialSyncIfNeeded() + }) + + socket.on('status', (data: { isLeader: boolean }) => { + if (this.isLeader !== data.isLeader) { + this.isLeader = data.isLeader + console.log(`Collaboration: I am now the ${this.isLeader ? 'Leader' : 'Follower'}.`) + this.eventEmitter.emit('leaderChange', this.isLeader) + } + if (this.isLeader) + this.pendingInitialSync = false + else + this.requestInitialSyncIfNeeded() + }) + + socket.on('connect', () => { + console.log('WebSocket connected successfully') + this.eventEmitter.emit('stateChange', { isConnected: true }) + this.pendingInitialSync = true + }) + + socket.on('disconnect', (reason: string) => { + console.log('WebSocket disconnected:', reason) + this.cursors = {} + this.isLeader = false + this.leaderId = null + this.pendingInitialSync = false + this.eventEmitter.emit('stateChange', { isConnected: false }) + this.eventEmitter.emit('cursors', {}) + }) + + socket.on('connect_error', (error: any) => { + console.error('WebSocket connection error:', error) + this.eventEmitter.emit('stateChange', { isConnected: false, error: error.message }) + }) + + socket.on('error', (error: any) => { + console.error('WebSocket error:', error) + }) + } + + // We currently only relay CRDT updates; the server doesn't persist them. + // When a follower joins mid-session, it might miss earlier broadcasts and render stale data. + // This lightweight checkpoint asks the leader to rebroadcast the latest graph snapshot once. + private requestInitialSyncIfNeeded(): void { + if (!this.pendingInitialSync) return + if (this.isLeader) { + this.pendingInitialSync = false + return + } + + this.emitGraphResyncRequest() + this.pendingInitialSync = false + } + + private emitGraphResyncRequest(): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return + + this.sendCollaborationEvent({ + type: 'graph_resync_request', + data: { timestamp: Date.now() }, + timestamp: Date.now(), + }) + } + + private broadcastCurrentGraph(): void { + if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return + if (!this.doc) return + + const socket = webSocketClient.getSocket(this.currentAppId) + if (!socket) return + + try { + const snapshot = this.doc.export({ mode: 'snapshot' }) + this.sendGraphEvent(snapshot) + } + catch (error) { + console.error('Failed to broadcast graph snapshot:', error) + } + } +} + +export const collaborationManager = new CollaborationManager() diff --git a/web/app/components/workflow/collaboration/core/crdt-provider.ts b/web/app/components/workflow/collaboration/core/crdt-provider.ts new file mode 100644 index 0000000000..fbe4b13e02 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/crdt-provider.ts @@ -0,0 +1,39 @@ +import type { LoroDoc } from 'loro-crdt' +import type { Socket } from 'socket.io-client' +import { emitWithAuthGuard } from './websocket-manager' + +export class CRDTProvider { + private doc: LoroDoc + private socket: Socket + private onUnauthorized?: () => void + + constructor(socket: Socket, doc: LoroDoc, onUnauthorized?: () => void) { + this.socket = socket + this.doc = doc + this.onUnauthorized = onUnauthorized + this.setupEventListeners() + } + + private setupEventListeners(): void { + this.doc.subscribe((event: any) => { + if (event.by === 'local') { + const update = this.doc.export({ mode: 'update' }) + emitWithAuthGuard(this.socket, 'graph_event', update, { onUnauthorized: this.onUnauthorized }) + } + }) + + this.socket.on('graph_update', (updateData: Uint8Array) => { + try { + const data = new Uint8Array(updateData) + this.doc.import(data) + } + catch (error) { + console.error('Error importing graph update:', error) + } + }) + } + + destroy(): void { + this.socket.off('graph_update') + } +} diff --git a/web/app/components/workflow/collaboration/core/event-emitter.ts b/web/app/components/workflow/collaboration/core/event-emitter.ts new file mode 100644 index 0000000000..250b344b12 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/event-emitter.ts @@ -0,0 +1,49 @@ +export type EventHandler = (data: T) => void + +export class EventEmitter { + private events: Map> = new Map() + + on(event: string, handler: EventHandler): () => void { + if (!this.events.has(event)) + this.events.set(event, new Set()) + + this.events.get(event)!.add(handler) + + return () => this.off(event, handler) + } + + off(event: string, handler?: EventHandler): void { + if (!this.events.has(event)) return + + const handlers = this.events.get(event)! + if (handler) + handlers.delete(handler) + else + handlers.clear() + + if (handlers.size === 0) + this.events.delete(event) + } + + emit(event: string, data: T): void { + if (!this.events.has(event)) return + + const handlers = this.events.get(event)! + handlers.forEach((handler) => { + try { + handler(data) + } + catch (error) { + console.error(`Error in event handler for ${event}:`, error) + } + }) + } + + removeAllListeners(): void { + this.events.clear() + } + + getListenerCount(event: string): number { + return this.events.get(event)?.size || 0 + } +} diff --git a/web/app/components/workflow/collaboration/core/websocket-manager.ts b/web/app/components/workflow/collaboration/core/websocket-manager.ts new file mode 100644 index 0000000000..bef68e5269 --- /dev/null +++ b/web/app/components/workflow/collaboration/core/websocket-manager.ts @@ -0,0 +1,168 @@ +import type { Socket } from 'socket.io-client' +import { io } from 'socket.io-client' +import { ACCESS_TOKEN_LOCAL_STORAGE_NAME } from '@/config' +import type { DebugInfo, WebSocketConfig } from '../types/websocket' + +const isUnauthorizedAck = (...ackArgs: any[]): boolean => { + const [first, second] = ackArgs + + if (second === 401 || first === 401) + return true + + if (first && typeof first === 'object' && first.msg === 'unauthorized') + return true + + return false +} + +export type EmitAckOptions = { + onAck?: (...ackArgs: any[]) => void + onUnauthorized?: (...ackArgs: any[]) => void +} + +export const emitWithAuthGuard = ( + socket: Socket | null | undefined, + event: string, + payload: any, + options?: EmitAckOptions, +): void => { + if (!socket) + return + + socket.emit( + event, + payload, + (...ackArgs: any[]) => { + options?.onAck?.(...ackArgs) + if (isUnauthorizedAck(...ackArgs)) + options?.onUnauthorized?.(...ackArgs) + }, + ) +} + +export class WebSocketClient { + private connections: Map = new Map() + private connecting: Set = new Set() + private config: WebSocketConfig + + constructor(config: WebSocketConfig = {}) { + const inferUrl = () => { + if (typeof window === 'undefined') + return 'ws://localhost:5001' + const scheme = window.location.protocol === 'https:' ? 'wss:' : 'ws:' + return `${scheme}//${window.location.host}` + } + this.config = { + url: config.url || process.env.NEXT_PUBLIC_SOCKET_URL || inferUrl(), + transports: config.transports || ['websocket'], + withCredentials: config.withCredentials !== false, + ...config, + } + } + + connect(appId: string): Socket { + const existingSocket = this.connections.get(appId) + if (existingSocket?.connected) + return existingSocket + + if (this.connecting.has(appId)) { + const pendingSocket = this.connections.get(appId) + if (pendingSocket) + return pendingSocket + } + + if (existingSocket && !existingSocket.connected) { + existingSocket.disconnect() + this.connections.delete(appId) + } + + this.connecting.add(appId) + + const authToken = typeof window === 'undefined' + ? undefined + : window.localStorage.getItem(ACCESS_TOKEN_LOCAL_STORAGE_NAME) ?? undefined + + const socketOptions: { + path: string + transports: WebSocketConfig['transports'] + withCredentials?: boolean + auth?: { token: string } + } = { + path: '/socket.io', + transports: this.config.transports, + withCredentials: this.config.withCredentials, + } + + if (authToken) + socketOptions.auth = { token: authToken } + + const socket = io(this.config.url!, socketOptions) + + this.connections.set(appId, socket) + this.setupBaseEventListeners(socket, appId) + + return socket + } + + disconnect(appId?: string): void { + if (appId) { + const socket = this.connections.get(appId) + if (socket) { + socket.disconnect() + this.connections.delete(appId) + this.connecting.delete(appId) + } + } + else { + this.connections.forEach(socket => socket.disconnect()) + this.connections.clear() + this.connecting.clear() + } + } + + getSocket(appId: string): Socket | null { + return this.connections.get(appId) || null + } + + isConnected(appId: string): boolean { + return this.connections.get(appId)?.connected || false + } + + getConnectedApps(): string[] { + const connectedApps: string[] = [] + this.connections.forEach((socket, appId) => { + if (socket.connected) + connectedApps.push(appId) + }) + return connectedApps + } + + getDebugInfo(): DebugInfo { + const info: DebugInfo = {} + this.connections.forEach((socket, appId) => { + info[appId] = { + connected: socket.connected, + connecting: this.connecting.has(appId), + socketId: socket.id, + } + }) + return info + } + + private setupBaseEventListeners(socket: Socket, appId: string): void { + socket.on('connect', () => { + this.connecting.delete(appId) + emitWithAuthGuard(socket, 'user_connect', { workflow_id: appId }) + }) + + socket.on('disconnect', () => { + this.connecting.delete(appId) + }) + + socket.on('connect_error', () => { + this.connecting.delete(appId) + }) + } +} + +export const webSocketClient = new WebSocketClient() diff --git a/web/app/components/workflow/collaboration/hooks/use-collaboration.ts b/web/app/components/workflow/collaboration/hooks/use-collaboration.ts new file mode 100644 index 0000000000..3aec92a2e6 --- /dev/null +++ b/web/app/components/workflow/collaboration/hooks/use-collaboration.ts @@ -0,0 +1,119 @@ +import { useEffect, useRef, useState } from 'react' +import type { ReactFlowInstance } from 'reactflow' +import { collaborationManager } from '../core/collaboration-manager' +import { CursorService } from '../services/cursor-service' +import type { CollaborationState } from '../types/collaboration' +import { useGlobalPublicStore } from '@/context/global-public-context' + +export function useCollaboration(appId: string, reactFlowStore?: any) { + const [state, setState] = useState>({ + isConnected: false, + onlineUsers: [], + cursors: {}, + nodePanelPresence: {}, + isLeader: false, + }) + + const cursorServiceRef = useRef(null) + const isCollaborationEnabled = useGlobalPublicStore(s => s.systemFeatures.enable_collaboration_mode) + + useEffect(() => { + if (!appId || !isCollaborationEnabled) { + setState({ + isConnected: false, + onlineUsers: [], + cursors: {}, + nodePanelPresence: {}, + isLeader: false, + }) + return + } + + let connectionId: string | null = null + let isUnmounted = false + + if (!cursorServiceRef.current) + cursorServiceRef.current = new CursorService() + + const initCollaboration = async () => { + try { + const id = await collaborationManager.connect(appId, reactFlowStore) + if (isUnmounted) { + collaborationManager.disconnect(id) + return + } + connectionId = id + setState((prev: any) => ({ ...prev, appId, isConnected: collaborationManager.isConnected() })) + } + catch (error) { + console.error('Failed to initialize collaboration:', error) + } + } + + initCollaboration() + + const unsubscribeStateChange = collaborationManager.onStateChange((newState: any) => { + console.log('Collaboration state change:', newState) + setState((prev: any) => ({ ...prev, ...newState })) + }) + + const unsubscribeCursors = collaborationManager.onCursorUpdate((cursors: any) => { + setState((prev: any) => ({ ...prev, cursors })) + }) + + const unsubscribeUsers = collaborationManager.onOnlineUsersUpdate((users: any) => { + console.log('Online users update:', users) + setState((prev: any) => ({ ...prev, onlineUsers: users })) + }) + + const unsubscribeNodePanelPresence = collaborationManager.onNodePanelPresenceUpdate((presence) => { + setState((prev: any) => ({ ...prev, nodePanelPresence: presence })) + }) + + const unsubscribeLeaderChange = collaborationManager.onLeaderChange((isLeader: boolean) => { + console.log('Leader status changed:', isLeader) + setState((prev: any) => ({ ...prev, isLeader })) + }) + + return () => { + isUnmounted = true + unsubscribeStateChange() + unsubscribeCursors() + unsubscribeUsers() + unsubscribeNodePanelPresence() + unsubscribeLeaderChange() + cursorServiceRef.current?.stopTracking() + if (connectionId) + collaborationManager.disconnect(connectionId) + } + }, [appId, reactFlowStore, isCollaborationEnabled]) + + const startCursorTracking = (containerRef: React.RefObject, reactFlowInstance?: ReactFlowInstance) => { + if (!isCollaborationEnabled || !cursorServiceRef.current) + return + + if (cursorServiceRef.current) { + cursorServiceRef.current.startTracking(containerRef, (position) => { + collaborationManager.emitCursorMove(position) + }, reactFlowInstance) + } + } + + const stopCursorTracking = () => { + cursorServiceRef.current?.stopTracking() + } + + const result = { + isConnected: state.isConnected || false, + onlineUsers: state.onlineUsers || [], + cursors: state.cursors || {}, + nodePanelPresence: state.nodePanelPresence || {}, + isLeader: state.isLeader || false, + leaderId: collaborationManager.getLeaderId(), + isEnabled: isCollaborationEnabled, + startCursorTracking, + stopCursorTracking, + } + + return result +} diff --git a/web/app/components/workflow/collaboration/index.ts b/web/app/components/workflow/collaboration/index.ts new file mode 100644 index 0000000000..89cbb3aa43 --- /dev/null +++ b/web/app/components/workflow/collaboration/index.ts @@ -0,0 +1,5 @@ +export { collaborationManager } from './core/collaboration-manager' +export { webSocketClient } from './core/websocket-manager' +export { CursorService } from './services/cursor-service' +export { useCollaboration } from './hooks/use-collaboration' +export * from './types' diff --git a/web/app/components/workflow/collaboration/services/cursor-service.ts b/web/app/components/workflow/collaboration/services/cursor-service.ts new file mode 100644 index 0000000000..3fa976a90d --- /dev/null +++ b/web/app/components/workflow/collaboration/services/cursor-service.ts @@ -0,0 +1,88 @@ +import type { RefObject } from 'react' +import type { CursorPosition } from '../types/collaboration' +import type { ReactFlowInstance } from 'reactflow' + +const CURSOR_MIN_MOVE_DISTANCE = 10 +const CURSOR_THROTTLE_MS = 500 + +export class CursorService { + private containerRef: RefObject | null = null + private reactFlowInstance: ReactFlowInstance | null = null + private isTracking = false + private onCursorUpdate: ((cursors: Record) => void) | null = null + private onEmitPosition: ((position: CursorPosition) => void) | null = null + private lastEmitTime = 0 + private lastPosition: { x: number; y: number } | null = null + + startTracking( + containerRef: RefObject, + onEmitPosition: (position: CursorPosition) => void, + reactFlowInstance?: ReactFlowInstance, + ): void { + if (this.isTracking) this.stopTracking() + + this.containerRef = containerRef + this.onEmitPosition = onEmitPosition + this.reactFlowInstance = reactFlowInstance || null + this.isTracking = true + + if (containerRef.current) + containerRef.current.addEventListener('mousemove', this.handleMouseMove) + } + + stopTracking(): void { + if (this.containerRef?.current) + this.containerRef.current.removeEventListener('mousemove', this.handleMouseMove) + + this.containerRef = null + this.reactFlowInstance = null + this.onEmitPosition = null + this.isTracking = false + this.lastPosition = null + } + + setCursorUpdateHandler(handler: (cursors: Record) => void): void { + this.onCursorUpdate = handler + } + + updateCursors(cursors: Record): void { + if (this.onCursorUpdate) + this.onCursorUpdate(cursors) + } + + private handleMouseMove = (event: MouseEvent): void => { + if (!this.containerRef?.current || !this.onEmitPosition) return + + const rect = this.containerRef.current.getBoundingClientRect() + let x = event.clientX - rect.left + let y = event.clientY - rect.top + + // Transform coordinates to ReactFlow world coordinates if ReactFlow instance is available + if (this.reactFlowInstance) { + const viewport = this.reactFlowInstance.getViewport() + // Convert screen coordinates to world coordinates + // World coordinates = (screen coordinates - viewport translation) / zoom + x = (x - viewport.x) / viewport.zoom + y = (y - viewport.y) / viewport.zoom + } + + // Always emit cursor position (remove boundary check since world coordinates can be negative) + const now = Date.now() + const timeThrottled = now - this.lastEmitTime > CURSOR_THROTTLE_MS + const minDistance = CURSOR_MIN_MOVE_DISTANCE / (this.reactFlowInstance?.getZoom() || 1) + const distanceThrottled = !this.lastPosition + || (Math.abs(x - this.lastPosition.x) > minDistance) + || (Math.abs(y - this.lastPosition.y) > minDistance) + + if (timeThrottled && distanceThrottled) { + this.lastPosition = { x, y } + this.lastEmitTime = now + this.onEmitPosition({ + x, + y, + userId: '', + timestamp: now, + }) + } + } +} diff --git a/web/app/components/workflow/collaboration/types/collaboration.ts b/web/app/components/workflow/collaboration/types/collaboration.ts new file mode 100644 index 0000000000..4e74d8fbda --- /dev/null +++ b/web/app/components/workflow/collaboration/types/collaboration.ts @@ -0,0 +1,57 @@ +import type { Edge, Node } from '../../types' + +export type OnlineUser = { + user_id: string + username: string + avatar: string + sid: string +} + +export type WorkflowOnlineUsers = { + workflow_id: string + users: OnlineUser[] +} + +export type OnlineUserListResponse = { + data: WorkflowOnlineUsers[] +} + +export type CursorPosition = { + x: number + y: number + userId: string + timestamp: number +} + +export type NodePanelPresenceUser = { + userId: string + username: string + avatar?: string | null +} + +export type NodePanelPresenceInfo = NodePanelPresenceUser & { + clientId: string + timestamp: number +} + +export type NodePanelPresenceMap = Record> + +export type CollaborationState = { + appId: string + isConnected: boolean + onlineUsers: OnlineUser[] + cursors: Record + nodePanelPresence: NodePanelPresenceMap +} + +export type GraphSyncData = { + nodes: Node[] + edges: Edge[] +} + +export type CollaborationUpdate = { + type: 'mouse_move' | 'vars_and_features_update' | 'sync_request' | 'app_state_update' | 'app_meta_update' | 'mcp_server_update' | 'workflow_update' | 'comments_update' | 'node_panel_presence' | 'app_publish_update' + userId: string + data: any + timestamp: number +} diff --git a/web/app/components/workflow/collaboration/types/events.ts b/web/app/components/workflow/collaboration/types/events.ts new file mode 100644 index 0000000000..e995f9e876 --- /dev/null +++ b/web/app/components/workflow/collaboration/types/events.ts @@ -0,0 +1,38 @@ +export type CollaborationEvent = { + type: string + data: any + timestamp: number +} + +export type GraphUpdateEvent = { + type: 'graph_update' + data: Uint8Array +} & CollaborationEvent + +export type CursorMoveEvent = { + type: 'cursor_move' + data: { + x: number + y: number + userId: string + } +} & CollaborationEvent + +export type UserConnectEvent = { + type: 'user_connect' + data: { + workflow_id: string + } +} & CollaborationEvent + +export type OnlineUsersEvent = { + type: 'online_users' + data: { + users: Array<{ + user_id: string + username: string + avatar: string + sid: string + }> + } +} & CollaborationEvent diff --git a/web/app/components/workflow/collaboration/types/index.ts b/web/app/components/workflow/collaboration/types/index.ts new file mode 100644 index 0000000000..e79ed35da0 --- /dev/null +++ b/web/app/components/workflow/collaboration/types/index.ts @@ -0,0 +1,3 @@ +export * from './websocket' +export * from './collaboration' +export * from './events' diff --git a/web/app/components/workflow/collaboration/types/websocket.ts b/web/app/components/workflow/collaboration/types/websocket.ts new file mode 100644 index 0000000000..37b3a8ad17 --- /dev/null +++ b/web/app/components/workflow/collaboration/types/websocket.ts @@ -0,0 +1,16 @@ +export type WebSocketConfig = { + url?: string + token?: string + transports?: string[] + withCredentials?: boolean +} + +export type ConnectionInfo = { + connected: boolean + connecting: boolean + socketId?: string +} + +export type DebugInfo = { + [appId: string]: ConnectionInfo +} diff --git a/web/app/components/workflow/collaboration/utils/user-color.ts b/web/app/components/workflow/collaboration/utils/user-color.ts new file mode 100644 index 0000000000..51aee6a038 --- /dev/null +++ b/web/app/components/workflow/collaboration/utils/user-color.ts @@ -0,0 +1,12 @@ +/** + * Generate a consistent color for a user based on their ID + * Used for cursor colors and avatar backgrounds + */ +export const getUserColor = (id: string): string => { + const colors = ['#155AEF', '#0BA5EC', '#444CE7', '#7839EE', '#4CA30D', '#0E9384', '#DD2590', '#FF4405', '#D92D20', '#F79009', '#828DAD'] + const hash = id.split('').reduce((a, b) => { + a = ((a << 5) - a) + b.charCodeAt(0) + return a & a + }, 0) + return colors[Math.abs(hash) % colors.length] +} diff --git a/web/app/components/workflow/comment-manager.tsx b/web/app/components/workflow/comment-manager.tsx new file mode 100644 index 0000000000..e5817eb6b1 --- /dev/null +++ b/web/app/components/workflow/comment-manager.tsx @@ -0,0 +1,34 @@ +import { useEventListener } from 'ahooks' +import { useWorkflowStore } from './store' +import { useWorkflowComment } from './hooks/use-workflow-comment' + +const CommentManager = () => { + const workflowStore = useWorkflowStore() + const { handleCreateComment, handleCommentCancel } = useWorkflowComment() + + useEventListener('click', (e) => { + const { controlMode, mousePosition, pendingComment } = workflowStore.getState() + + if (controlMode === 'comment') { + const target = e.target as HTMLElement + const isInDropdown = target.closest('[data-mention-dropdown]') + const isInCommentInput = target.closest('[data-comment-input]') + const isOnCanvasPane = target.closest('.react-flow__pane') + + // Only when clicking on the React Flow canvas pane (background), + // and not inside comment input or its dropdown + if (!isInDropdown && !isInCommentInput && isOnCanvasPane) { + e.preventDefault() + e.stopPropagation() + if (pendingComment) + handleCommentCancel() + else + handleCreateComment(mousePosition) + } + } + }) + + return null +} + +export default CommentManager diff --git a/web/app/components/workflow/comment/comment-icon.tsx b/web/app/components/workflow/comment/comment-icon.tsx new file mode 100644 index 0000000000..f2b3a785b1 --- /dev/null +++ b/web/app/components/workflow/comment/comment-icon.tsx @@ -0,0 +1,265 @@ +'use client' + +import type { FC, PointerEvent as ReactPointerEvent } from 'react' +import { memo, useCallback, useMemo, useRef, useState } from 'react' +import { useReactFlow, useViewport } from 'reactflow' +import { UserAvatarList } from '@/app/components/base/user-avatar-list' +import CommentPreview from './comment-preview' +import type { WorkflowCommentList } from '@/service/workflow-comment' +import { useAppContext } from '@/context/app-context' + +type CommentIconProps = { + comment: WorkflowCommentList + onClick: () => void + isActive?: boolean + onPositionUpdate?: (position: { x: number; y: number }) => void +} + +export const CommentIcon: FC = memo(({ comment, onClick, isActive = false, onPositionUpdate }) => { + const { flowToScreenPosition, screenToFlowPosition } = useReactFlow() + const viewport = useViewport() + const { userProfile } = useAppContext() + const isAuthor = comment.created_by_account?.id === userProfile?.id + const [showPreview, setShowPreview] = useState(false) + const [dragPosition, setDragPosition] = useState<{ x: number; y: number } | null>(null) + const [isDragging, setIsDragging] = useState(false) + const dragStateRef = useRef<{ + offsetX: number + offsetY: number + startX: number + startY: number + hasMoved: boolean + } | null>(null) + + const workflowContainerRect = typeof document !== 'undefined' + ? document.getElementById('workflow-container')?.getBoundingClientRect() + : null + const containerLeft = workflowContainerRect?.left ?? 0 + const containerTop = workflowContainerRect?.top ?? 0 + + const screenPosition = useMemo(() => { + return flowToScreenPosition({ + x: comment.position_x, + y: comment.position_y, + }) + }, [comment.position_x, comment.position_y, viewport.x, viewport.y, viewport.zoom, flowToScreenPosition]) + + const effectiveScreenPosition = dragPosition ?? screenPosition + const canvasPosition = useMemo(() => ({ + x: effectiveScreenPosition.x - containerLeft, + y: effectiveScreenPosition.y - containerTop, + }), [effectiveScreenPosition.x, effectiveScreenPosition.y, containerLeft, containerTop]) + const cursorClass = useMemo(() => { + if (!isAuthor) + return 'cursor-pointer' + if (isActive) + return isDragging ? 'cursor-grabbing' : '' + return isDragging ? 'cursor-grabbing' : 'cursor-pointer' + }, [isActive, isAuthor, isDragging]) + + const handlePointerDown = useCallback((event: ReactPointerEvent) => { + if (event.button !== 0) + return + + event.stopPropagation() + event.preventDefault() + + if (!isAuthor) { + if (event.currentTarget.dataset.role !== 'comment-preview') + setShowPreview(false) + return + } + + dragStateRef.current = { + offsetX: event.clientX - screenPosition.x, + offsetY: event.clientY - screenPosition.y, + startX: event.clientX, + startY: event.clientY, + hasMoved: false, + } + + setDragPosition(screenPosition) + setIsDragging(false) + + if (event.currentTarget.dataset.role !== 'comment-preview') + setShowPreview(false) + + if (event.currentTarget.setPointerCapture) + event.currentTarget.setPointerCapture(event.pointerId) + }, [isAuthor, screenPosition]) + + const handlePointerMove = useCallback((event: ReactPointerEvent) => { + const dragState = dragStateRef.current + if (!dragState) + return + + event.stopPropagation() + event.preventDefault() + + const nextX = event.clientX - dragState.offsetX + const nextY = event.clientY - dragState.offsetY + + if (!dragState.hasMoved) { + const distance = Math.hypot(event.clientX - dragState.startX, event.clientY - dragState.startY) + if (distance > 4) { + dragState.hasMoved = true + setIsDragging(true) + } + } + + setDragPosition({ x: nextX, y: nextY }) + }, []) + + const finishDrag = useCallback((event: ReactPointerEvent) => { + const dragState = dragStateRef.current + if (!dragState) + return false + + if (event.currentTarget.hasPointerCapture?.(event.pointerId)) + event.currentTarget.releasePointerCapture(event.pointerId) + + dragStateRef.current = null + setDragPosition(null) + setIsDragging(false) + return dragState.hasMoved + }, []) + + const handlePointerUp = useCallback((event: ReactPointerEvent) => { + event.stopPropagation() + event.preventDefault() + + const finalScreenPosition = dragPosition ?? screenPosition + const didDrag = finishDrag(event) + + setShowPreview(false) + + if (didDrag) { + if (onPositionUpdate) { + const flowPosition = screenToFlowPosition({ + x: finalScreenPosition.x, + y: finalScreenPosition.y, + }) + onPositionUpdate(flowPosition) + } + } + else if (!isActive) { + onClick() + } + }, [dragPosition, finishDrag, isActive, onClick, onPositionUpdate, screenPosition, screenToFlowPosition]) + + const handlePointerCancel = useCallback((event: ReactPointerEvent) => { + event.stopPropagation() + event.preventDefault() + finishDrag(event) + }, [finishDrag]) + + const handleMouseEnter = useCallback(() => { + if (isActive || isDragging) + return + setShowPreview(true) + }, [isActive, isDragging]) + + const handleMouseLeave = useCallback(() => { + setShowPreview(false) + }, []) + + const participants = useMemo(() => { + const list = comment.participants ?? [] + const author = comment.created_by_account + if (!author) + return [...list] + const rest = list.filter(user => user.id !== author.id) + return [author, ...rest] + }, [comment.created_by_account, comment.participants]) + + // Calculate dynamic width based on number of participants + const participantCount = participants.length + const maxVisible = Math.min(3, participantCount) + const showCount = participantCount > 3 + const avatarSize = 24 + const avatarSpacing = 4 // -space-x-1 is about 4px overlap + + // Width calculation: first avatar + (additional avatars * (size - spacing)) + padding + const dynamicWidth = Math.max(40, // minimum width + 8 + avatarSize + Math.max(0, (showCount ? 2 : maxVisible - 1)) * (avatarSize - avatarSpacing) + 8, + ) + + const pointerEventHandlers = useMemo(() => ({ + onPointerDown: handlePointerDown, + onPointerMove: handlePointerMove, + onPointerUp: handlePointerUp, + onPointerCancel: handlePointerCancel, + }), [handlePointerCancel, handlePointerDown, handlePointerMove, handlePointerUp]) + + return ( + <> +
+
+
+
+
+ +
+
+
+
+
+ + {/* Preview panel */} + {showPreview && !isActive && ( +
setShowPreview(true)} + onMouseLeave={() => setShowPreview(false)} + > + { + setShowPreview(false) + onClick() + }} /> +
+ )} + + ) +}, (prevProps, nextProps) => { + return ( + prevProps.comment.id === nextProps.comment.id + && prevProps.comment.position_x === nextProps.comment.position_x + && prevProps.comment.position_y === nextProps.comment.position_y + && prevProps.onClick === nextProps.onClick + && prevProps.isActive === nextProps.isActive + && prevProps.onPositionUpdate === nextProps.onPositionUpdate + ) +}) + +CommentIcon.displayName = 'CommentIcon' diff --git a/web/app/components/workflow/comment/comment-input.tsx b/web/app/components/workflow/comment/comment-input.tsx new file mode 100644 index 0000000000..f8c96a0595 --- /dev/null +++ b/web/app/components/workflow/comment/comment-input.tsx @@ -0,0 +1,87 @@ +import type { FC } from 'react' +import { memo, useCallback, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Avatar from '@/app/components/base/avatar' +import { useAppContext } from '@/context/app-context' +import { MentionInput } from './mention-input' +import cn from '@/utils/classnames' + +type CommentInputProps = { + position: { x: number; y: number } + onSubmit: (content: string, mentionedUserIds: string[]) => void + onCancel: () => void +} + +export const CommentInput: FC = memo(({ position, onSubmit, onCancel }) => { + const [content, setContent] = useState('') + const { t } = useTranslation() + const { userProfile } = useAppContext() + + useEffect(() => { + const handleGlobalKeyDown = (e: KeyboardEvent) => { + if (e.key === 'Escape') { + e.preventDefault() + e.stopPropagation() + onCancel() + } + } + + document.addEventListener('keydown', handleGlobalKeyDown, true) + return () => { + document.removeEventListener('keydown', handleGlobalKeyDown, true) + } + }, [onCancel]) + + const handleMentionSubmit = useCallback((content: string, mentionedUserIds: string[]) => { + onSubmit(content, mentionedUserIds) + setContent('') + }, [onSubmit]) + + return ( +
+
+
+
+
+
+
+ +
+
+
+
+
+
+
+ +
+
+
+
+ ) +}) + +CommentInput.displayName = 'CommentInput' diff --git a/web/app/components/workflow/comment/comment-preview.tsx b/web/app/components/workflow/comment/comment-preview.tsx new file mode 100644 index 0000000000..be43434dd5 --- /dev/null +++ b/web/app/components/workflow/comment/comment-preview.tsx @@ -0,0 +1,59 @@ +'use client' + +import type { FC } from 'react' +import { memo, useEffect, useMemo } from 'react' +import { UserAvatarList } from '@/app/components/base/user-avatar-list' +import type { WorkflowCommentList } from '@/service/workflow-comment' +import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now' +import { useStore } from '../store' + +type CommentPreviewProps = { + comment: WorkflowCommentList + onClick?: () => void +} + +const CommentPreview: FC = ({ comment, onClick }) => { + const { formatTimeFromNow } = useFormatTimeFromNow() + const setCommentPreviewHovering = useStore(s => s.setCommentPreviewHovering) + const participants = useMemo(() => { + const list = comment.participants ?? [] + const author = comment.created_by_account + if (!author) + return [...list] + const rest = list.filter(user => user.id !== author.id) + return [author, ...rest] + }, [comment.created_by_account, comment.participants]) + useEffect(() => () => { + setCommentPreviewHovering(false) + }, [setCommentPreviewHovering]) + + return ( +
setCommentPreviewHovering(true)} + onMouseLeave={() => setCommentPreviewHovering(false)} + > +
+ +
+ +
+
+
{comment.created_by_account.name}
+
+ {formatTimeFromNow(comment.updated_at * 1000)} +
+
+
+ +
{comment.content}
+
+ ) +} + +export default memo(CommentPreview) diff --git a/web/app/components/workflow/comment/cursor.tsx b/web/app/components/workflow/comment/cursor.tsx new file mode 100644 index 0000000000..56b5c24f16 --- /dev/null +++ b/web/app/components/workflow/comment/cursor.tsx @@ -0,0 +1,28 @@ +import type { FC } from 'react' +import { memo } from 'react' +import { useStore } from '../store' +import { ControlMode } from '../types' +import { Comment } from '@/app/components/base/icons/src/public/other' + +export const CommentCursor: FC = memo(() => { + const controlMode = useStore(s => s.controlMode) + const mousePosition = useStore(s => s.mousePosition) + + if (controlMode !== ControlMode.Comment) + return null + + return ( +
+ +
+ ) +}) + +CommentCursor.displayName = 'CommentCursor' diff --git a/web/app/components/workflow/comment/index.tsx b/web/app/components/workflow/comment/index.tsx new file mode 100644 index 0000000000..f80ddac5bf --- /dev/null +++ b/web/app/components/workflow/comment/index.tsx @@ -0,0 +1,5 @@ +export { CommentCursor } from './cursor' +export { CommentInput } from './comment-input' +export { CommentIcon } from './comment-icon' +export { CommentThread } from './thread' +export { MentionInput } from './mention-input' diff --git a/web/app/components/workflow/comment/mention-input.tsx b/web/app/components/workflow/comment/mention-input.tsx new file mode 100644 index 0000000000..18810e9ad1 --- /dev/null +++ b/web/app/components/workflow/comment/mention-input.tsx @@ -0,0 +1,649 @@ +'use client' + +import type { ReactNode } from 'react' +import { + forwardRef, + memo, + useCallback, + useEffect, + useImperativeHandle, + useLayoutEffect, + useMemo, + useRef, + useState, +} from 'react' +import { createPortal } from 'react-dom' +import { useParams } from 'next/navigation' +import { useTranslation } from 'react-i18next' +import { RiArrowUpLine, RiAtLine, RiLoader2Line } from '@remixicon/react' +import Textarea from 'react-textarea-autosize' +import Button from '@/app/components/base/button' +import Avatar from '@/app/components/base/avatar' +import cn from '@/utils/classnames' +import { type UserProfile, fetchMentionableUsers } from '@/service/workflow-comment' +import { useStore, useWorkflowStore } from '../store' +import { EnterKey } from '@/app/components/base/icons/src/public/common' + +type MentionInputProps = { + value: string + onChange: (value: string) => void + onSubmit: (content: string, mentionedUserIds: string[]) => void + onCancel?: () => void + placeholder?: string + disabled?: boolean + loading?: boolean + className?: string + isEditing?: boolean + autoFocus?: boolean +} + +const MentionInputInner = forwardRef(({ + value, + onChange, + onSubmit, + onCancel, + placeholder, + disabled = false, + loading = false, + className, + isEditing = false, + autoFocus = false, +}, forwardedRef) => { + const params = useParams() + const { t } = useTranslation() + const appId = params.appId as string + const textareaRef = useRef(null) + const highlightContentRef = useRef(null) + const actionContainerRef = useRef(null) + const actionRightRef = useRef(null) + const baseTextareaHeightRef = useRef(null) + + // Expose textarea ref to parent component + useImperativeHandle(forwardedRef, () => textareaRef.current!, []) + + const workflowStore = useWorkflowStore() + const mentionUsersFromStore = useStore(state => ( + appId ? state.mentionableUsersCache[appId] : undefined + )) + const mentionUsers = mentionUsersFromStore ?? [] + + const [showMentionDropdown, setShowMentionDropdown] = useState(false) + const [mentionQuery, setMentionQuery] = useState('') + const [mentionPosition, setMentionPosition] = useState(0) + const [selectedMentionIndex, setSelectedMentionIndex] = useState(0) + const [mentionedUserIds, setMentionedUserIds] = useState([]) + const resolvedPlaceholder = placeholder ?? t('workflow.comments.placeholder.add') + const BASE_PADDING = 4 + const [shouldReserveButtonGap, setShouldReserveButtonGap] = useState(isEditing) + const [shouldReserveHorizontalSpace, setShouldReserveHorizontalSpace] = useState(() => !isEditing) + const [paddingRight, setPaddingRight] = useState(() => BASE_PADDING + (isEditing ? 0 : 48)) + const [paddingBottom, setPaddingBottom] = useState(() => BASE_PADDING + (isEditing ? 32 : 0)) + + const mentionNameList = useMemo(() => { + const names = mentionUsers + .map(user => user.name?.trim()) + .filter((name): name is string => Boolean(name)) + + const uniqueNames = Array.from(new Set(names)) + uniqueNames.sort((a, b) => b.length - a.length) + return uniqueNames + }, [mentionUsers]) + + const highlightedValue = useMemo(() => { + if (!value) + return '' + + if (mentionNameList.length === 0) + return value + + const segments: ReactNode[] = [] + let cursor = 0 + let hasMention = false + + while (cursor < value.length) { + let nextMatchStart = -1 + let matchedName = '' + + for (const name of mentionNameList) { + const searchStart = value.indexOf(`@${name}`, cursor) + if (searchStart === -1) + continue + + const previousChar = searchStart > 0 ? value[searchStart - 1] : '' + if (searchStart > 0 && !/\s/.test(previousChar)) + continue + + if ( + nextMatchStart === -1 + || searchStart < nextMatchStart + || (searchStart === nextMatchStart && name.length > matchedName.length) + ) { + nextMatchStart = searchStart + matchedName = name + } + } + + if (nextMatchStart === -1) + break + + if (nextMatchStart > cursor) + segments.push({value.slice(cursor, nextMatchStart)}) + + const mentionEnd = nextMatchStart + matchedName.length + 1 + segments.push( + + {value.slice(nextMatchStart, mentionEnd)} + , + ) + + hasMention = true + cursor = mentionEnd + } + + if (!hasMention) + return value + + if (cursor < value.length) + segments.push({value.slice(cursor)}) + + return segments + }, [value, mentionNameList]) + + const loadMentionableUsers = useCallback(async () => { + if (!appId) + return + + const state = workflowStore.getState() + if (state.mentionableUsersCache[appId] !== undefined) + return + + if (state.mentionableUsersLoading[appId]) + return + + state.setMentionableUsersLoading(appId, true) + try { + const users = await fetchMentionableUsers(appId) + workflowStore.getState().setMentionableUsersCache(appId, users) + } + catch (error) { + console.error('Failed to load mentionable users:', error) + } + finally { + workflowStore.getState().setMentionableUsersLoading(appId, false) + } + }, [appId, workflowStore]) + + useEffect(() => { + loadMentionableUsers() + }, [loadMentionableUsers]) + const syncHighlightScroll = useCallback(() => { + const textarea = textareaRef.current + const highlightContent = highlightContentRef.current + if (!textarea || !highlightContent) + return + + const { scrollTop, scrollLeft } = textarea + highlightContent.style.transform = `translate(${-scrollLeft}px, ${-scrollTop}px)` + }, []) + + const evaluateContentLayout = useCallback(() => { + const textarea = textareaRef.current + if (!textarea) + return + + const extraBottom = Math.max(0, paddingBottom - BASE_PADDING) + const effectiveClientHeight = textarea.clientHeight - extraBottom + + if (baseTextareaHeightRef.current === null) + baseTextareaHeightRef.current = effectiveClientHeight + + const baseHeight = baseTextareaHeightRef.current ?? effectiveClientHeight + const hasMultiline = effectiveClientHeight > baseHeight + 1 + const shouldReserveVertical = isEditing ? true : hasMultiline + + setShouldReserveButtonGap(shouldReserveVertical) + setShouldReserveHorizontalSpace(!hasMultiline) + }, [isEditing, paddingBottom]) + + const updateLayoutPadding = useCallback(() => { + const actionEl = actionContainerRef.current + const rect = actionEl?.getBoundingClientRect() + const rightRect = actionRightRef.current?.getBoundingClientRect() + let actionWidth = 0 + if (rightRect) + actionWidth = Math.ceil(rightRect.width) + else if (rect) + actionWidth = Math.ceil(rect.width) + + const actionHeight = rect ? Math.ceil(rect.height) : 0 + const fallbackWidth = Math.max(0, paddingRight - BASE_PADDING) + const fallbackHeight = Math.max(0, paddingBottom - BASE_PADDING) + const effectiveWidth = actionWidth > 0 ? actionWidth : fallbackWidth + const effectiveHeight = actionHeight > 0 ? actionHeight : fallbackHeight + + const nextRight = BASE_PADDING + (shouldReserveHorizontalSpace ? effectiveWidth : 0) + const nextBottom = BASE_PADDING + (shouldReserveButtonGap ? effectiveHeight : 0) + + setPaddingRight(prev => (prev === nextRight ? prev : nextRight)) + setPaddingBottom(prev => (prev === nextBottom ? prev : nextBottom)) + }, [shouldReserveButtonGap, shouldReserveHorizontalSpace, paddingRight, paddingBottom]) + + const setActionContainerRef = useCallback((node: HTMLDivElement | null) => { + actionContainerRef.current = node + + if (!isEditing) + actionRightRef.current = node + else if (!node) + actionRightRef.current = null + + if (node && typeof window !== 'undefined') + window.requestAnimationFrame(() => updateLayoutPadding()) + }, [isEditing, updateLayoutPadding]) + + const setActionRightRef = useCallback((node: HTMLDivElement | null) => { + actionRightRef.current = node + + if (node && typeof window !== 'undefined') + window.requestAnimationFrame(() => updateLayoutPadding()) + }, [updateLayoutPadding]) + + useLayoutEffect(() => { + syncHighlightScroll() + }, [value, syncHighlightScroll]) + + useLayoutEffect(() => { + evaluateContentLayout() + }, [value, evaluateContentLayout]) + + useLayoutEffect(() => { + updateLayoutPadding() + }, [updateLayoutPadding, isEditing, shouldReserveButtonGap]) + + useEffect(() => { + const handleResize = () => { + evaluateContentLayout() + updateLayoutPadding() + } + + window.addEventListener('resize', handleResize) + return () => window.removeEventListener('resize', handleResize) + }, [evaluateContentLayout, updateLayoutPadding]) + + useEffect(() => { + baseTextareaHeightRef.current = null + evaluateContentLayout() + setShouldReserveHorizontalSpace(!isEditing) + }, [isEditing, evaluateContentLayout]) + + const filteredMentionUsers = useMemo(() => { + if (!mentionQuery) return mentionUsers + return mentionUsers.filter(user => + user.name.toLowerCase().includes(mentionQuery.toLowerCase()) + || user.email.toLowerCase().includes(mentionQuery.toLowerCase()), + ) + }, [mentionUsers, mentionQuery]) + + const shouldDisableMentionButton = useMemo(() => { + if (showMentionDropdown) + return true + + const textarea = textareaRef.current + if (!textarea) + return false + + const cursorPosition = textarea.selectionStart || 0 + const textBeforeCursor = value.slice(0, cursorPosition) + return /@\w*$/.test(textBeforeCursor) + }, [showMentionDropdown, value]) + + const dropdownPosition = useMemo(() => { + if (!showMentionDropdown || !textareaRef.current) + return { x: 0, y: 0, placement: 'bottom' as const } + + const textareaRect = textareaRef.current.getBoundingClientRect() + const dropdownHeight = 160 // max-h-40 = 10rem = 160px + const viewportHeight = window.innerHeight + const spaceBelow = viewportHeight - textareaRect.bottom + const spaceAbove = textareaRect.top + + const shouldPlaceAbove = spaceBelow < dropdownHeight && spaceAbove > spaceBelow + + return { + x: textareaRect.left, + y: shouldPlaceAbove ? textareaRect.top - 4 : textareaRect.bottom + 4, + placement: shouldPlaceAbove ? 'top' as const : 'bottom' as const, + } + }, [showMentionDropdown]) + + const handleContentChange = useCallback((newValue: string) => { + onChange(newValue) + + setTimeout(() => { + const cursorPosition = textareaRef.current?.selectionStart || 0 + const textBeforeCursor = newValue.slice(0, cursorPosition) + const mentionMatch = textBeforeCursor.match(/@(\w*)$/) + + if (mentionMatch) { + setMentionQuery(mentionMatch[1]) + setMentionPosition(cursorPosition - mentionMatch[0].length) + setShowMentionDropdown(true) + setSelectedMentionIndex(0) + } + else { + setShowMentionDropdown(false) + } + + if (typeof window !== 'undefined') { + window.requestAnimationFrame(() => { + evaluateContentLayout() + syncHighlightScroll() + }) + } + }, 0) + }, [onChange, evaluateContentLayout, syncHighlightScroll]) + + const handleMentionButtonClick = useCallback((e: React.MouseEvent) => { + e.preventDefault() + e.stopPropagation() + + const textarea = textareaRef.current + if (!textarea) + return + + const cursorPosition = textarea.selectionStart || 0 + const textBeforeCursor = value.slice(0, cursorPosition) + + if (showMentionDropdown) + return + + if (/@\w*$/.test(textBeforeCursor)) + return + + const newContent = `${value.slice(0, cursorPosition)}@${value.slice(cursorPosition)}` + + onChange(newContent) + + setTimeout(() => { + const newCursorPos = cursorPosition + 1 + textarea.setSelectionRange(newCursorPos, newCursorPos) + textarea.focus() + + setMentionQuery('') + setMentionPosition(cursorPosition) + setShowMentionDropdown(true) + setSelectedMentionIndex(0) + + if (typeof window !== 'undefined') { + window.requestAnimationFrame(() => { + evaluateContentLayout() + syncHighlightScroll() + }) + } + }, 0) + }, [value, onChange, evaluateContentLayout, syncHighlightScroll, showMentionDropdown]) + + const insertMention = useCallback((user: UserProfile) => { + const textarea = textareaRef.current + if (!textarea) return + + const beforeMention = value.slice(0, mentionPosition) + const afterMention = value.slice(textarea.selectionStart || 0) + + const needsSpaceBefore = mentionPosition > 0 && !/\s/.test(value[mentionPosition - 1]) + const prefix = needsSpaceBefore ? ' ' : '' + const newContent = `${beforeMention}${prefix}@${user.name} ${afterMention}` + + onChange(newContent) + setShowMentionDropdown(false) + + const newMentionedUserIds = [...mentionedUserIds, user.id] + setMentionedUserIds(newMentionedUserIds) + + setTimeout(() => { + const extraSpace = needsSpaceBefore ? 1 : 0 + const newCursorPos = mentionPosition + extraSpace + user.name.length + 2 // (space) + @ + name + space + textarea.setSelectionRange(newCursorPos, newCursorPos) + textarea.focus() + if (typeof window !== 'undefined') { + window.requestAnimationFrame(() => { + evaluateContentLayout() + syncHighlightScroll() + }) + } + }, 0) + }, [value, mentionPosition, onChange, mentionedUserIds, evaluateContentLayout, syncHighlightScroll]) + + const handleSubmit = useCallback(async (e?: React.MouseEvent) => { + if (e) { + e.preventDefault() + e.stopPropagation() + } + + if (value.trim()) { + try { + await onSubmit(value.trim(), mentionedUserIds) + setMentionedUserIds([]) + setShowMentionDropdown(false) + } + catch (error) { + console.error('Failed to submit', error) + } + } + }, [value, mentionedUserIds, onSubmit]) + + const handleKeyDown = useCallback((e: React.KeyboardEvent) => { + // Ignore key events during IME composition (e.g., Chinese, Japanese input) + if (e.nativeEvent.isComposing) + return + + if (showMentionDropdown) { + if (e.key === 'ArrowDown') { + e.preventDefault() + setSelectedMentionIndex(prev => + prev < filteredMentionUsers.length - 1 ? prev + 1 : 0, + ) + } + else if (e.key === 'ArrowUp') { + e.preventDefault() + setSelectedMentionIndex(prev => + prev > 0 ? prev - 1 : filteredMentionUsers.length - 1, + ) + } + else if (e.key === 'Enter') { + e.preventDefault() + if (filteredMentionUsers[selectedMentionIndex]) + insertMention(filteredMentionUsers[selectedMentionIndex]) + + return + } + else if (e.key === 'Escape') { + e.preventDefault() + setShowMentionDropdown(false) + return + } + } + + if (e.key === 'Enter' && !e.shiftKey && !showMentionDropdown) { + e.preventDefault() + handleSubmit() + } + }, [showMentionDropdown, filteredMentionUsers, selectedMentionIndex, insertMention, handleSubmit]) + + const resetMentionState = useCallback(() => { + setMentionedUserIds([]) + setShowMentionDropdown(false) + setMentionQuery('') + setMentionPosition(0) + setSelectedMentionIndex(0) + }, []) + + useEffect(() => { + if (!value) + resetMentionState() + }, [value, resetMentionState]) + + useEffect(() => { + if (autoFocus && textareaRef.current) { + const textarea = textareaRef.current + setTimeout(() => { + textarea.focus() + const length = textarea.value.length + textarea.setSelectionRange(length, length) + }, 0) + } + }, [autoFocus]) + + return ( + <> +
+
+
+ {highlightedValue} + {'​'} +
+
+