Merge remote-tracking branch 'origin/deploy/dev' into feat/memory-orchestration-be-dev-env

# Conflicts:
#	api/models/__init__.py
#	api/uv.lock
This commit is contained in:
Stream 2025-10-11 15:01:26 +08:00
commit 1a4600ce77
No known key found for this signature in database
GPG Key ID: 033728094B100D70
141 changed files with 11908 additions and 3514 deletions

View File

@ -1,3 +1,4 @@
import os
import sys
@ -8,10 +9,16 @@ def is_db_command():
# create app
celery = None
flask_app = None
socketio_app = None
if is_db_command():
from app_factory import create_migrations_app
app = create_migrations_app()
socketio_app = app
flask_app = app
else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
@ -33,8 +40,15 @@ else:
from app_factory import create_app
app = create_app()
celery = app.extensions["celery"]
socketio_app, flask_app = create_app()
app = flask_app
celery = flask_app.extensions["celery"]
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
host = os.environ.get("HOST", "0.0.0.0")
port = int(os.environ.get("PORT", 5001))
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
server.serve_forever()

View File

@ -31,14 +31,22 @@ def create_flask_app_with_configs() -> DifyApp:
return dify_app
def create_app() -> DifyApp:
def create_app() -> tuple[any, DifyApp]:
start_time = time.perf_counter()
app = create_flask_app_with_configs()
initialize_extensions(app)
import socketio
from extensions.ext_socketio import sio
sio.app = app
socketio_app = socketio.WSGIApp(sio, app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
return app
return socketio_app, app
def initialize_extensions(app: DifyApp):

View File

@ -836,6 +836,16 @@ class MailConfig(BaseSettings):
default=None,
)
ENABLE_TRIAL_APP: bool = Field(
description="Enable trial app",
default=False,
)
ENABLE_EXPLORE_BANNER: bool = Field(
description="Enable explore banner",
default=False,
)
class RagEtlConfig(BaseSettings):
"""

View File

@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings):
default="",
)
HOSTED_POOL_CREDITS: int = Field(
description="Pool credits for hosted service",
default=200,
)
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
@ -70,11 +75,6 @@ class HostedOpenAiConfig(BaseSettings):
"text-davinci-003",
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted OpenAI service usage",
default=200,
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted OpenAI service",
default=False,
@ -98,6 +98,129 @@ class HostedOpenAiConfig(BaseSettings):
)
class HostedGeminiConfig(BaseSettings):
"""
Configuration for fetching Gemini service
"""
HOSTED_GEMINI_API_KEY: str | None = Field(
description="API key for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_API_BASE: str | None = Field(
description="Base URL for hosted Gemini API",
default=None,
)
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Gemini service",
default=False,
)
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted gemini service",
default=False,
)
HOSTED_GEMINI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
class HostedXAIConfig(BaseSettings):
"""
Configuration for fetching XAI service
"""
HOSTED_XAI_API_KEY: str | None = Field(
description="API key for hosted XAI service",
default=None,
)
HOSTED_XAI_API_BASE: str | None = Field(
description="Base URL for hosted XAI API",
default=None,
)
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted XAI service",
default=None,
)
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted XAI service",
default=False,
)
HOSTED_XAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
HOSTED_XAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_XAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedDeepseekConfig(BaseSettings):
"""
Configuration for fetching Deepseek service
"""
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
description="API key for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
description="Base URL for hosted Deepseek API",
default=None,
)
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Deepseek service",
default=False,
)
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="deepseek-chat,deepseek-reasoner",
)
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedAzureOpenAiConfig(BaseSettings):
"""
Configuration for hosted Azure OpenAI service
@ -144,16 +267,32 @@ class HostedAnthropicConfig(BaseSettings):
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Anthropic service usage",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
class HostedMinmaxConfig(BaseSettings):
"""
@ -250,5 +389,8 @@ class HostedServiceConfig(
HostedModerationConfig,
# credit config
HostedCreditConfig,
HostedGeminiConfig,
HostedXAIConfig,
HostedDeepseekConfig,
):
pass

View File

@ -58,11 +58,13 @@ from .app import (
mcp_server,
message,
model_config,
online_user,
ops_trace,
site,
statistic,
workflow,
workflow_app_log,
workflow_comment,
workflow_draft_variable,
workflow_run,
workflow_statistic,
@ -106,10 +108,12 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
banner,
installed_app,
parameter,
recommended_app,
saved_message,
trial,
)
# Import tag controllers
@ -143,6 +147,7 @@ __all__ = [
"apikey",
"app",
"audio",
"banner",
"billing",
"bp",
"completion",
@ -196,6 +201,7 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trial",
"version",
"website",
"workflow",

View File

@ -15,7 +15,7 @@ from constants.languages import supported_language
from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
def admin_required(view: Callable[P, R]):
@ -61,6 +61,8 @@ class InsertExploreAppListApi(Resource):
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
"can_trial": fields.Boolean(required=True, description="Can trial"),
"trial_limit": fields.Integer(required=True, description="Trial limit"),
},
)
)
@ -79,6 +81,8 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
parser.add_argument("can_trial", type=bool, required=True, nullable=False, location="json")
parser.add_argument("trial_limit", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
@ -115,6 +119,20 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
if args["can_trial"]:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == args["app_id"])
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=args["app_id"],
tenant_id=app.tenant_id,
trial_limit=args["trial_limit"],
)
)
else:
trial_app.trial_limit = args["trial_limit"]
app.is_public = True
db.session.commit()
@ -129,6 +147,20 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = args["category"]
recommended_app.position = args["position"]
if args["can_trial"]:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == args["app_id"])
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=args["app_id"],
tenant_id=app.tenant_id,
trial_limit=args["trial_limit"],
)
)
else:
trial_app.trial_limit = args["trial_limit"]
app.is_public = True
db.session.commit()
@ -174,7 +206,67 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
trial_app = session.execute(
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
).scalar_one_or_none()
if trial_app:
session.delete(trial_app)
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
@console_ns.route("/admin/insert-explore-banner")
class InsertExploreBanner(Resource):
@api.doc("insert_explore_banner")
@api.doc(description="Insert an explore banner")
@api.expect(
api.model(
"InsertExploreBannerRequest",
{
"content": fields.String(required=True, description="Banner content"),
"link": fields.String(required=True, description="Banner link"),
"sort": fields.Integer(required=True, description="Banner sort"),
},
)
)
@api.response(200, "Banner inserted successfully")
@admin_required
@only_edition_cloud
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("link", type=str, required=True, nullable=False, location="json")
parser.add_argument("sort", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
banner = ExporleBanner(
content=args["content"],
link=args["link"],
sort=args["sort"],
)
db.session.add(banner)
db.session.commit()
return {"result": "success"}, 200
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
class DeleteExploreBanner(Resource):
@api.doc("delete_explore_banner")
@api.doc(description="Delete an explore banner")
@api.response(204, "Banner deleted successfully")
@admin_required
@only_edition_cloud
def delete(self, banner_id):
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
if not banner:
raise NotFound(f"Banner '{banner_id}' is not found")
db.session.delete(banner)
db.session.commit()
return {"result": "success"}, 204

View File

@ -0,0 +1,291 @@
import json
import time
from extensions.ext_redis import redis_client
from extensions.ext_socketio import sio
from libs.passport import PassportService
from services.account_service import AccountService
@sio.on("connect")
def socket_connect(sid, environ, auth):
"""
WebSocket connect event, do authentication here.
"""
token = None
if auth and isinstance(auth, dict):
token = auth.get("token")
if not token:
return False
try:
decoded = PassportService().verify(token)
user_id = decoded.get("user_id")
if not user_id:
return False
with sio.app.app_context():
user = AccountService.load_logged_in_account(account_id=user_id)
if not user:
return False
sio.save_session(sid, {"user_id": user.id, "username": user.name, "avatar": user.avatar})
return True
except Exception:
return False
@sio.on("user_connect")
def handle_user_connect(sid, data):
"""
Handle user connect event. Each session (tab) is treated as an independent collaborator.
"""
workflow_id = data.get("workflow_id")
if not workflow_id:
return {"msg": "workflow_id is required"}, 400
session = sio.get_session(sid)
user_id = session.get("user_id")
if not user_id:
return {"msg": "unauthorized"}, 401
# Each session is stored independently with sid as key
session_info = {
"user_id": user_id,
"username": session.get("username", "Unknown"),
"avatar": session.get("avatar", None),
"sid": sid,
"connected_at": int(time.time()), # Add timestamp to differentiate tabs
}
# Store session info with sid as key
redis_client.hset(f"workflow_online_users:{workflow_id}", sid, json.dumps(session_info))
redis_client.set(f"ws_sid_map:{sid}", json.dumps({"workflow_id": workflow_id, "user_id": user_id}))
# Leader election: first session becomes the leader
leader_sid = get_or_set_leader(workflow_id, sid)
is_leader = leader_sid == sid
sio.enter_room(sid, workflow_id)
broadcast_online_users(workflow_id)
# Notify this session of their leader status
sio.emit("status", {"isLeader": is_leader}, room=sid)
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
@sio.on("disconnect")
def handle_disconnect(sid):
"""
Handle session disconnect event. Remove the specific session from online users.
"""
mapping = redis_client.get(f"ws_sid_map:{sid}")
if mapping:
data = json.loads(mapping)
workflow_id = data["workflow_id"]
# Remove this specific session
redis_client.hdel(f"workflow_online_users:{workflow_id}", sid)
redis_client.delete(f"ws_sid_map:{sid}")
# Handle leader re-election if the leader session disconnected
handle_leader_disconnect(workflow_id, sid)
broadcast_online_users(workflow_id)
def _clear_session_state(workflow_id: str, sid: str) -> None:
redis_client.hdel(f"workflow_online_users:{workflow_id}", sid)
redis_client.delete(f"ws_sid_map:{sid}")
def _is_session_active(workflow_id: str, sid: str) -> bool:
if not sid:
return False
try:
if not sio.manager.is_connected(sid, "/"):
return False
except AttributeError:
return False
if not redis_client.hexists(f"workflow_online_users:{workflow_id}", sid):
return False
if not redis_client.exists(f"ws_sid_map:{sid}"):
return False
return True
def get_or_set_leader(workflow_id: str, sid: str) -> str:
"""
Get current leader session or set this session as leader if no valid leader exists.
Returns the leader session id (sid).
"""
leader_key = f"workflow_leader:{workflow_id}"
raw_leader = redis_client.get(leader_key)
current_leader = raw_leader.decode("utf-8") if isinstance(raw_leader, bytes) else raw_leader
leader_replaced = False
if current_leader and not _is_session_active(workflow_id, current_leader):
_clear_session_state(workflow_id, current_leader)
redis_client.delete(leader_key)
current_leader = None
leader_replaced = True
if not current_leader:
redis_client.set(leader_key, sid, ex=3600) # Expire in 1 hour
if leader_replaced:
broadcast_leader_change(workflow_id, sid)
return sid
return current_leader
def handle_leader_disconnect(workflow_id, disconnected_sid):
"""
Handle leader re-election when a session disconnects.
If the disconnected session was the leader, elect a new leader from remaining sessions.
"""
leader_key = f"workflow_leader:{workflow_id}"
current_leader = redis_client.get(leader_key)
if current_leader:
current_leader = current_leader.decode("utf-8") if isinstance(current_leader, bytes) else current_leader
if current_leader == disconnected_sid:
# Leader session disconnected, elect a new leader
sessions_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
if sessions_json:
# Get the first remaining session as new leader
new_leader_sid = list(sessions_json.keys())[0]
if isinstance(new_leader_sid, bytes):
new_leader_sid = new_leader_sid.decode("utf-8")
redis_client.set(leader_key, new_leader_sid, ex=3600)
# Notify all sessions about the new leader
broadcast_leader_change(workflow_id, new_leader_sid)
else:
# No sessions left, remove leader
redis_client.delete(leader_key)
def broadcast_leader_change(workflow_id, new_leader_sid):
"""
Broadcast leader change to all sessions in the workflow.
"""
sessions_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
for sid, session_info_json in sessions_json.items():
try:
sid_str = sid.decode("utf-8") if isinstance(sid, bytes) else sid
is_leader = sid_str == new_leader_sid
# Emit to each session whether they are the new leader
sio.emit("status", {"isLeader": is_leader}, room=sid_str)
except Exception:
continue
def get_current_leader(workflow_id):
"""
Get the current leader for a workflow.
"""
leader_key = f"workflow_leader:{workflow_id}"
leader = redis_client.get(leader_key)
return leader.decode("utf-8") if leader and isinstance(leader, bytes) else leader
def broadcast_online_users(workflow_id):
"""
Broadcast online users to the workflow room.
Each session is shown as a separate user (even if same person has multiple tabs).
"""
sessions_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
users = []
for sid, session_info_json in sessions_json.items():
try:
session_info = json.loads(session_info_json)
# Each session appears as a separate "user" in the UI
users.append(
{
"user_id": session_info["user_id"],
"username": session_info["username"],
"avatar": session_info.get("avatar"),
"sid": session_info["sid"],
"connected_at": session_info.get("connected_at"),
}
)
except Exception:
continue
# Sort by connection time to maintain consistent order
users.sort(key=lambda x: x.get("connected_at") or 0)
# Get current leader session
leader_sid = get_current_leader(workflow_id)
sio.emit("online_users", {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, room=workflow_id)
@sio.on("collaboration_event")
def handle_collaboration_event(sid, data):
"""
Handle general collaboration events, include:
1. mouseMove
2. varsAndFeaturesUpdate
3. syncRequest(ask leader to update graph)
4. appStateUpdate
5. mcpServerUpdate
"""
mapping = redis_client.get(f"ws_sid_map:{sid}")
if not mapping:
return {"msg": "unauthorized"}, 401
mapping_data = json.loads(mapping)
workflow_id = mapping_data["workflow_id"]
user_id = mapping_data["user_id"]
event_type = data.get("type")
event_data = data.get("data")
timestamp = data.get("timestamp", int(time.time()))
if not event_type:
return {"msg": "invalid event type"}, 400
sio.emit(
"collaboration_update",
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
room=workflow_id,
skip_sid=sid,
)
return {"msg": "event_broadcasted"}
@sio.on("graph_event")
def handle_graph_event(sid, data):
"""
Handle graph events - simple broadcast relay.
"""
mapping = redis_client.get(f"ws_sid_map:{sid}")
if not mapping:
return {"msg": "unauthorized"}, 401
mapping_data = json.loads(mapping)
workflow_id = mapping_data["workflow_id"]
sio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
return {"msg": "graph_update_broadcasted"}

View File

@ -21,7 +21,9 @@ from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory, variable_factory
from fields.online_user_fields import online_user_list_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@ -127,6 +129,7 @@ class DraftWorkflowApi(Resource):
parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=True, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument("force_upload", type=bool, required=False, default=False, location="json")
args = parser.parse_args()
elif "text/plain" in content_type:
try:
@ -143,6 +146,7 @@ class DraftWorkflowApi(Resource):
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
"force_upload": data.get("force_upload", False),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
@ -171,6 +175,7 @@ class DraftWorkflowApi(Resource):
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
force_upload=args.get("force_upload", False),
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
@ -796,6 +801,45 @@ class ConvertToWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
class WorkflowConfigApi(Resource):
"""Resource for workflow configuration."""
@api.doc("get_workflow_config")
@api.doc(description="Get workflow configuration")
@api.doc(params={"app_id": "Application ID"})
@api.response(200, "Workflow configuration retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
class WorkflowFeaturesApi(Resource):
"""Update draft workflow features."""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
parser = reqparse.RequestParser()
parser.add_argument("features", type=dict, required=True, location="json")
args = parser.parse_args()
features = args.get("features")
# Update draft workflow features
workflow_service = WorkflowService()
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@api.doc("get_all_published_workflows")
@ -985,3 +1029,105 @@ class DraftWorkflowNodeLastRunApi(Resource):
if node_exec is None:
raise NotFound("last run not found")
return node_exec
class WorkflowOnlineUsersApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(online_user_list_fields)
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("workflow_ids", type=str, required=True, location="args")
args = parser.parse_args()
workflow_ids = [id.strip() for id in args["workflow_ids"].split(",")]
results = []
for workflow_id in workflow_ids:
users_json = redis_client.hgetall(f"workflow_online_users:{workflow_id}")
users = []
for _, user_info_json in users_json.items():
try:
users.append(json.loads(user_info_json))
except Exception:
continue
results.append({"workflow_id": workflow_id, "users": users})
return {"data": results}
api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
)
api.add_resource(
WorkflowConfigApi,
"/apps/<uuid:app_id>/workflows/draft/config",
)
api.add_resource(
WorkflowFeaturesApi,
"/apps/<uuid:app_id>/workflows/draft/features",
)
api.add_resource(
AdvancedChatDraftWorkflowRunApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
)
api.add_resource(
DraftWorkflowRunApi,
"/apps/<uuid:app_id>/workflows/draft/run",
)
api.add_resource(
WorkflowTaskStopApi,
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
)
api.add_resource(
DraftWorkflowNodeRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
)
api.add_resource(
AdvancedChatDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
WorkflowDraftRunIterationNodeApi,
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
)
api.add_resource(
AdvancedChatDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
WorkflowDraftRunLoopNodeApi,
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
)
api.add_resource(
PublishedWorkflowApi,
"/apps/<uuid:app_id>/workflows/publish",
)
api.add_resource(
PublishedAllWorkflowApi,
"/apps/<uuid:app_id>/workflows",
)
api.add_resource(
DefaultBlockConfigsApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
)
api.add_resource(
DefaultBlockConfigApi,
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
)
api.add_resource(
ConvertToWorkflowApi,
"/apps/<uuid:app_id>/convert-to-workflow",
)
api.add_resource(
WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
)
api.add_resource(
DraftWorkflowNodeLastRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
)
api.add_resource(WorkflowOnlineUsersApi, "/apps/workflows/online-users")

View File

@ -0,0 +1,240 @@
import logging
from flask_restx import Resource, fields, marshal_with, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.member_fields import account_with_role_fields
from fields.workflow_comment_fields import (
workflow_comment_basic_fields,
workflow_comment_create_fields,
workflow_comment_detail_fields,
workflow_comment_reply_create_fields,
workflow_comment_reply_update_fields,
workflow_comment_resolve_fields,
workflow_comment_update_fields,
)
from libs.login import current_user, login_required
from models import App
from services.account_service import TenantService
from services.workflow_comment_service import WorkflowCommentService
logger = logging.getLogger(__name__)
class WorkflowCommentListApi(Resource):
"""API for listing and creating workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_basic_fields, envelope="data")
def get(self, app_model: App):
"""Get all comments for a workflow."""
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
return comments
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_create_fields)
def post(self, app_model: App):
"""Create a new workflow comment."""
parser = reqparse.RequestParser()
parser.add_argument("position_x", type=float, required=True, location="json")
parser.add_argument("position_y", type=float, required=True, location="json")
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
result = WorkflowCommentService.create_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
created_by=current_user.id,
content=args.content,
position_x=args.position_x,
position_y=args.position_y,
mentioned_user_ids=args.mentioned_user_ids,
)
return result, 201
class WorkflowCommentDetailApi(Resource):
"""API for managing individual workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_detail_fields)
def get(self, app_model: App, comment_id: str):
"""Get a specific workflow comment."""
comment = WorkflowCommentService.get_comment(
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
)
return comment
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_update_fields)
def put(self, app_model: App, comment_id: str):
"""Update a workflow comment."""
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("position_x", type=float, required=False, location="json")
parser.add_argument("position_y", type=float, required=False, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
result = WorkflowCommentService.update_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
content=args.content,
position_x=args.position_x,
position_y=args.position_y,
mentioned_user_ids=args.mentioned_user_ids,
)
return result
@login_required
@setup_required
@account_initialization_required
@get_app_model
def delete(self, app_model: App, comment_id: str):
"""Delete a workflow comment."""
WorkflowCommentService.delete_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return {"result": "success"}, 204
class WorkflowCommentResolveApi(Resource):
"""API for resolving and reopening workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_resolve_fields)
def post(self, app_model: App, comment_id: str):
"""Resolve a workflow comment."""
comment = WorkflowCommentService.resolve_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return comment
class WorkflowCommentReplyApi(Resource):
"""API for managing comment replies."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_reply_create_fields)
def post(self, app_model: App, comment_id: str):
"""Add a reply to a workflow comment."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
result = WorkflowCommentService.create_reply(
comment_id=comment_id,
content=args.content,
created_by=current_user.id,
mentioned_user_ids=args.mentioned_user_ids,
)
return result, 201
class WorkflowCommentReplyDetailApi(Resource):
"""API for managing individual comment replies."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with(workflow_comment_reply_update_fields)
def put(self, app_model: App, comment_id: str, reply_id: str):
"""Update a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, location="json")
parser.add_argument("mentioned_user_ids", type=list, location="json", default=[])
args = parser.parse_args()
reply = WorkflowCommentService.update_reply(
reply_id=reply_id, user_id=current_user.id, content=args.content, mentioned_user_ids=args.mentioned_user_ids
)
return reply
@login_required
@setup_required
@account_initialization_required
@get_app_model
def delete(self, app_model: App, comment_id: str, reply_id: str):
"""Delete a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
return {"result": "success"}, 204
class WorkflowCommentMentionUsersApi(Resource):
"""API for getting mentionable users for workflow comments."""
@login_required
@setup_required
@account_initialization_required
@get_app_model
@marshal_with({"users": fields.List(fields.Nested(account_with_role_fields))})
def get(self, app_model: App):
"""Get all users in current tenant for mentions."""
members = TenantService.get_tenant_members(current_user.current_tenant)
return {"users": members}
# Register API routes
api.add_resource(WorkflowCommentListApi, "/apps/<uuid:app_id>/workflow/comments")
api.add_resource(WorkflowCommentDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
api.add_resource(WorkflowCommentResolveApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
api.add_resource(WorkflowCommentReplyApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
api.add_resource(
WorkflowCommentReplyDetailApi, "/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>"
)
api.add_resource(WorkflowCommentMentionUsersApi, "/apps/<uuid:app_id>/workflow/comments/mention-users")

View File

@ -19,8 +19,8 @@ from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories import variable_factory
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode
from models.account import Account
@ -353,7 +353,7 @@ class VariableApi(Resource):
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@ -446,8 +446,35 @@ class ConversationVariableCollectionApi(Resource):
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
def post(self, app_model: App):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("conversation_variables", type=list, required=True, location="json")
args = parser.parse_args()
workflow_service = WorkflowService()
conversation_variables_list = args.get("conversation_variables") or []
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
workflow_service.update_draft_workflow_conversation_variables(
app_model=app_model,
account=current_user,
conversation_variables=conversation_variables,
)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
class SystemVariableCollectionApi(Resource):
@api.doc("get_system_variables")
@api.doc(description="Get system variables for workflow")
@ -497,3 +524,44 @@ class EnvironmentVariableCollectionApi(Resource):
)
return {"items": env_vars_list}
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("environment_variables", type=list, required=True, location="json")
args = parser.parse_args()
workflow_service = WorkflowService()
environment_variables_list = args.get("environment_variables") or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
workflow_service.update_draft_workflow_environment_variables(
app_model=app_model,
account=current_user,
environment_variables=environment_variables,
)
return {"result": "success"}
api.add_resource(
WorkflowVariableCollectionApi,
"/apps/<uuid:app_id>/workflows/draft/variables",
)
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")

View File

@ -0,0 +1,34 @@
from flask_restx import Resource
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.model import ExporleBanner
class BannerApi(Resource):
"""Resource for banner list."""
@explore_banner_enabled
def get(self):
"""Get banner list."""
banners = (
db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled").order_by(ExporleBanner.sort).all()
)
# Convert banners to serializable format
result = []
for banner in banners:
banner_data = {
"content": banner.content, # Already parsed as JSON by SQLAlchemy
"link": banner.link,
"sort": banner.sort,
"status": banner.status,
"created_at": banner.created_at.isoformat() if banner.created_at else None,
}
result.append(banner_data)
return result
api.add_resource(BannerApi, "/explore/banners")

View File

@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
class TrialAppNotAllowed(BaseHTTPException):
"""*403* `Trial App Not Allowed`
Raise if the user has reached the trial app limit.
"""
error_code = "trial_app_not_allowed"
code = 403
description = "the app is not allowed to be trial."
class TrialAppLimitExceeded(BaseHTTPException):
"""*403* `Trial App Limit Exceeded`
Raise if the user has exceeded the trial app limit.
"""
error_code = "trial_app_limit_exceeded"
code = 403
description = "The user has exceeded the trial app limit."

View File

@ -27,6 +27,7 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_list_fields = {

View File

@ -0,0 +1,375 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common import fields
from controllers.common.fields import build_site_model
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
ConversationCompletedError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import (
AppSuggestedQuestionsAfterAnswerDisabledError,
NotChatAppError,
NotCompletionAppError,
)
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
from controllers.service_api import service_api_ns
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.app_fields import app_detail_fields_with_site
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.account import TenantStatus
from models.model import AppMode, Site
from services.app_generate_service import AppGenerateService
from services.app_service import AppService
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import (
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
from services.recommended_app_service import RecommendedAppService
logger = logging.getLogger(__name__)
class TrialChatApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialMessageSuggestedQuestionApi(TrialAppResource):
@trial_feature_enable
def get(self, trial_app, message_id):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
class TrialChatAudioApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
file = request.files["file"]
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialChatTextApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialCompletionApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
RecommendedAppService.add_trial_app_record(app_model.id, current_user.id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialSitApi(Resource):
"""Resource for trial app sites."""
@trial_feature_enable
@get_app_model
@service_api_ns.marshal_with(build_site_model(service_api_ns))
def get(self, app_model):
"""Retrieve app site info.
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
assert app_model.tenant
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return site
class TrialAppParameterApi(Resource):
"""Resource for app variables."""
@trial_feature_enable
@get_app_model
@marshal_with(fields.parameters_fields)
def get(self, app_model):
"""Retrieve app parameters."""
if app_model is None:
raise AppUnavailableError()
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
class AppApi(Resource):
@trial_feature_enable
@get_app_model
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
app_model = app_service.get_app(app_model)
return app_model
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
api.add_resource(
TrialMessageSuggestedQuestionApi,
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="trial_app_suggested_question",
)
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")

View File

@ -2,15 +2,16 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_login import current_user
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import login_required
from models import InstalledApp
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -74,6 +75,59 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
if trial_app is None:
raise TrialAppNotAllowed()
app = trial_app.app
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.first()
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:
raise TrialAppLimitExceeded()
return view(app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_trial_app:
abort(403, "Trial app feature is not enabled.")
return view(*args, **kwargs)
return decorated
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_explore_banner:
abort(403, "Explore banner feature is not enabled.")
return view(*args, **kwargs)
return decorated
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@ -83,3 +137,13 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
class TrialAppResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [
trial_app_required,
account_initialization_required,
login_required,
]

View File

@ -33,6 +33,7 @@ from controllers.console.wraps import (
only_edition_cloud,
setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
@ -135,6 +136,17 @@ class AccountNameApi(Resource):
@console_ns.route("/account/avatar")
class AccountAvatarApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("avatar", type=str, required=True, location="args")
args = parser.parse_args()
avatar_url = file_helpers.get_signed_file_url(args["avatar"])
return {"avatar_url": avatar_url}
@setup_required
@login_required
@account_initialization_required

View File

@ -51,6 +51,8 @@ tenant_fields = {
"in_trial": fields.Boolean,
"trial_end_reason": fields.String,
"custom_config": fields.Raw(attribute="custom_config"),
"trial_credits": fields.Integer,
"trial_credits_used": fields.Integer,
}
tenants_fields = {

View File

@ -56,6 +56,9 @@ class HostingConfiguration:
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
self.moderation_config = self.init_moderation_config()
@ -128,7 +131,7 @@ class HostingConfiguration:
quotas: list[HostingQuota] = []
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
@ -156,18 +159,49 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
@staticmethod
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
def init_gemini(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
}
if dify_config.HOSTED_GEMINI_API_BASE:
credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_anthropic(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_quota = PaidHostingQuota()
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
@ -185,6 +219,66 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
def init_xai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_XAI_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_XAI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_XAI_API_KEY,
}
if dify_config.HOSTED_XAI_API_BASE:
credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_deepseek(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
}
if dify_config.HOSTED_DEEPSEEK_API_BASE:
credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
@staticmethod
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS

View File

@ -618,9 +618,9 @@ class ProviderManager:
)
for quota in configuration.quotas:
if quota.quota_type == ProviderQuotaType.TRIAL:
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
# Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
if quota.quota_type not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
new_provider_record = Provider(
@ -628,8 +628,8 @@ class ProviderManager:
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=quota.quota_limit, # type: ignore
quota_type=quota.quota_type,
quota_limit=0, # type: ignore
quota_used=0,
is_valid=True,
)
@ -642,7 +642,7 @@ class ProviderManager:
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.quota_type == quota.quota_type,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
@ -652,7 +652,7 @@ class ProviderManager:
existed_provider_record.is_valid = True
db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
return provider_name_to_provider_records_dict
@ -912,6 +912,22 @@ class ProviderManager:
provider_record
)
quota_configurations = []
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
trail_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL.value,
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.PAID.value,
)
else:
trail_pool = None
paid_pool = None
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
if provider_quota.quota_type == ProviderQuotaType.FREE:
@ -932,16 +948,36 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
else:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configurations.append(quota_configuration)

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import QuotaUnit
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@ -136,21 +136,36 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
session.execute(stmt)
session.commit()
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

View File

@ -38,14 +38,16 @@ elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
else
if [[ "${DEBUG}" == "true" ]]; then
exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0}
export PORT=${DIFY_PORT:-5001}
exec python -m app
else
exec gunicorn \
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
--workers ${SERVER_WORKER_AMOUNT:-1} \
--worker-class ${SERVER_WORKER_CLASS:-gevent} \
--worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \
--worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
--timeout ${GUNICORN_TIMEOUT:-200} \
app:app
app:socketio_app
fi
fi

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
@ -133,22 +133,38 @@ def handle(sender: Message, **kwargs):
system_configuration=system_configuration,
model_name=model_config.model,
)
if used_quota is not None:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
if provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
credits_required=used_quota,
pool_type="trial",
)
elif provider_configuration.system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
# Execute all updates
start_time = time_module.perf_counter()

View File

@ -0,0 +1,3 @@
import socketio
sio = socketio.Server(async_mode="gevent", cors_allowed_origins="*")

View File

@ -0,0 +1,17 @@
from flask_restx import fields
online_user_partial_fields = {
"user_id": fields.String,
"username": fields.String,
"avatar": fields.String,
"sid": fields.String,
}
workflow_online_users_fields = {
"workflow_id": fields.String,
"users": fields.List(fields.Nested(online_user_partial_fields)),
}
online_user_list_fields = {
"data": fields.List(fields.Nested(workflow_online_users_fields)),
}

View File

@ -0,0 +1,96 @@
from flask_restx import fields
from libs.helper import AvatarUrlField, TimestampField
# basic account fields for comments
account_fields = {
"id": fields.String,
"name": fields.String,
"email": fields.String,
"avatar_url": AvatarUrlField,
}
# Comment mention fields
workflow_comment_mention_fields = {
"mentioned_user_id": fields.String,
"mentioned_user_account": fields.Nested(account_fields, allow_null=True),
"reply_id": fields.String,
}
# Comment reply fields
workflow_comment_reply_fields = {
"id": fields.String,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
}
# Basic comment fields (for list views)
workflow_comment_basic_fields = {
"id": fields.String,
"position_x": fields.Float,
"position_y": fields.Float,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
"updated_at": TimestampField,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
"reply_count": fields.Integer,
"mention_count": fields.Integer,
"participants": fields.List(fields.Nested(account_fields)),
}
# Detailed comment fields (for single comment view)
workflow_comment_detail_fields = {
"id": fields.String,
"position_x": fields.Float,
"position_y": fields.Float,
"content": fields.String,
"created_by": fields.String,
"created_by_account": fields.Nested(account_fields, allow_null=True),
"created_at": TimestampField,
"updated_at": TimestampField,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
"replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
"mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
}
# Comment creation response fields (simplified)
workflow_comment_create_fields = {
"id": fields.String,
"created_at": TimestampField,
}
# Comment update response fields (simplified)
workflow_comment_update_fields = {
"id": fields.String,
"updated_at": TimestampField,
}
# Comment resolve response fields
workflow_comment_resolve_fields = {
"id": fields.String,
"resolved": fields.Boolean,
"resolved_at": TimestampField,
"resolved_by": fields.String,
}
# Reply creation response fields (simplified)
workflow_comment_reply_create_fields = {
"id": fields.String,
"created_at": TimestampField,
}
# Reply update response fields
workflow_comment_reply_update_fields = {
"id": fields.String,
"updated_at": TimestampField,
}

View File

@ -0,0 +1,90 @@
"""Add workflow comments table
Revision ID: 227822d22895
Revises: 68519ad5cd18
Create Date: 2025-08-22 17:26:15.255980
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '227822d22895'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('workflow_comments',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('position_x', sa.Float(), nullable=False),
sa.Column('position_y', sa.Float(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('resolved_at', sa.DateTime(), nullable=True),
sa.Column('resolved_by', models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey')
)
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False)
batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False)
op.create_table('workflow_comment_replies',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey')
)
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False)
batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False)
op.create_table('workflow_comment_mentions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
sa.Column('reply_id', models.types.StringUUID(), nullable=True),
sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False),
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'),
sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey')
)
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False)
batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False)
batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
batch_op.drop_index('comment_mentions_user_idx')
batch_op.drop_index('comment_mentions_reply_idx')
batch_op.drop_index('comment_mentions_comment_idx')
op.drop_table('workflow_comment_mentions')
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
batch_op.drop_index('comment_replies_created_at_idx')
batch_op.drop_index('comment_replies_comment_idx')
op.drop_table('workflow_comment_replies')
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
batch_op.drop_index('workflow_comments_created_at_idx')
batch_op.drop_index('workflow_comments_app_idx')
op.drop_table('workflow_comments')
# ### end Alembic commands ###

View File

@ -0,0 +1,79 @@
"""add table explore banner and trial
Revision ID: 1b435d90db42
Revises: cf7c38a32b2d
Create Date: 2025-09-19 14:42:58.416649
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1b435d90db42'
down_revision = 'cf7c38a32b2d'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('account_trial_app_records',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('account_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('count', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
)
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
op.create_table('exporle_banners',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('content', sa.JSON(), nullable=False),
sa.Column('link', sa.String(length=255), nullable=False),
sa.Column('sort', sa.Integer(), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
)
op.create_table('trial_apps',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('trial_limit', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
sa.UniqueConstraint('app_id', name='unique_trail_app_id')
)
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.drop_column('credential_status')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
batch_op.drop_index('trial_app_tenant_id_idx')
batch_op.drop_index('trial_app_app_id_idx')
op.drop_table('trial_apps')
op.drop_table('exporle_banners')
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
batch_op.drop_index('account_trial_app_record_app_id_idx')
batch_op.drop_index('account_trial_app_record_account_id_idx')
op.drop_table('account_trial_app_records')
# ### end Alembic commands ###

View File

@ -0,0 +1,104 @@
"""add table credit pool
Revision ID: 58a70d22fdbd
Revises: 68519ad5cd18
Create Date: 2025-09-25 15:20:40.367078
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '58a70d22fdbd'
down_revision = '68519ad5cd18'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_credit_pools',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('pool_type', sa.String(length=40), nullable=False),
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
sa.Column('quota_used', sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
)
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
# Data migration: Move quota data from providers to tenant_credit_pools
migrate_quota_data()
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.drop_index('tenant_credit_pool_tenant_id_idx')
batch_op.drop_index('tenant_credit_pool_pool_type_idx')
op.drop_table('tenant_credit_pools')
# ### end Alembic commands ###
def migrate_quota_data():
"""
Migrate quota data from providers table to tenant_credit_pools table
for providers with quota_type='trial' or 'paid', provider_name='openai', provider_type='system'
"""
# Create connection
bind = op.get_bind()
# Define quota type mappings
quota_type_mappings = ['trial', 'paid']
for quota_type in quota_type_mappings:
# Query providers that match the criteria
select_sql = sa.text("""
SELECT tenant_id, quota_limit, quota_used
FROM providers
WHERE quota_type = :quota_type
AND provider_name = 'openai'
AND provider_type = 'system'
AND quota_limit IS NOT NULL
""")
result = bind.execute(select_sql, {"quota_type": quota_type})
providers_data = result.fetchall()
# Insert data into tenant_credit_pools
for provider_data in providers_data:
tenant_id, quota_limit, quota_used = provider_data
# Check if credit pool already exists for this tenant and pool type
check_sql = sa.text("""
SELECT COUNT(*)
FROM tenant_credit_pools
WHERE tenant_id = :tenant_id AND pool_type = :pool_type
""")
existing_count = bind.execute(check_sql, {
"tenant_id": tenant_id,
"pool_type": quota_type
}).scalar()
if existing_count == 0:
# Insert new credit pool record
insert_sql = sa.text("""
INSERT INTO tenant_credit_pools (tenant_id, pool_type, quota_limit, quota_used, created_at, updated_at)
VALUES (:tenant_id, :pool_type, :quota_limit, :quota_used, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
""")
bind.execute(insert_sql, {
"tenant_id": tenant_id,
"pool_type": quota_type,
"quota_limit": quota_limit or 0,
"quota_used": quota_used or 0
})

View File

@ -10,6 +10,11 @@ from .account import (
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .chatflow_memory import ChatflowConversation, ChatflowMemoryVariable, ChatflowMessage
from .comment import (
WorkflowComment,
WorkflowCommentMention,
WorkflowCommentReply,
)
from .dataset import (
AppDatasetJoin,
Dataset,
@ -29,6 +34,7 @@ from .dataset import (
)
from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
AccountTrialAppRecord,
ApiRequest,
ApiToken,
App,
@ -41,6 +47,7 @@ from .model import (
DatasetRetrieverResource,
DifySetup,
EndUser,
ExporleBanner,
IconType,
InstalledApp,
Message,
@ -54,7 +61,9 @@ from .model import (
Site,
Tag,
TagBinding,
TenantCreditPool,
TraceAppConfig,
TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -99,6 +108,7 @@ __all__ = [
"Account",
"AccountIntegrate",
"AccountStatus",
"AccountTrialAppRecord",
"ApiRequest",
"ApiToken",
"ApiToolProvider",
@ -135,6 +145,7 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"ExporleBanner",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
@ -163,6 +174,7 @@ __all__ = [
"Tenant",
"TenantAccountJoin",
"TenantAccountRole",
"TenantCreditPool",
"TenantDefaultModel",
"TenantPreferredModelProvider",
"TenantStatus",
@ -172,12 +184,16 @@ __all__ = [
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
"TrialApp",
"UploadFile",
"UserFrom",
"Whitelist",
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowComment",
"WorkflowCommentMention",
"WorkflowCommentReply",
"WorkflowNodeExecutionModel",
"WorkflowNodeExecutionOffload",
"WorkflowNodeExecutionTriggeredFrom",

189
api/models/comment.py Normal file
View File

@ -0,0 +1,189 @@
"""Workflow comment models."""
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .account import Account
from .base import Base
from .engine import db
from .types import StringUUID
if TYPE_CHECKING:
pass
class WorkflowComment(Base):
"""Workflow comment model for canvas commenting functionality.
Comments are associated with apps rather than specific workflow versions,
since an app has only one draft workflow at a time and comments should persist
across workflow version changes.
Attributes:
id: Comment ID
tenant_id: Workspace ID
app_id: App ID (primary association, comments belong to apps)
position_x: X coordinate on canvas
position_y: Y coordinate on canvas
content: Comment content
created_by: Creator account ID
created_at: Creation time
updated_at: Last update time
resolved: Whether comment is resolved
resolved_at: Resolution time
resolved_by: Resolver account ID
"""
__tablename__ = "workflow_comments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(db.Float)
position_y: Mapped[float] = mapped_column(db.Float)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
resolved_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
resolved_by: Mapped[Optional[str]] = mapped_column(StringUUID)
# Relationships
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
)
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
)
@property
def created_by_account(self):
"""Get creator account."""
return db.session.get(Account, self.created_by)
@property
def resolved_by_account(self):
"""Get resolver account."""
if self.resolved_by:
return db.session.get(Account, self.resolved_by)
return None
@property
def reply_count(self):
"""Get reply count."""
return len(self.replies)
@property
def mention_count(self):
"""Get mention count."""
return len(self.mentions)
@property
def participants(self):
"""Get all participants (creator + repliers + mentioned users)."""
participant_ids = set()
# Add comment creator
participant_ids.add(self.created_by)
# Add reply creators
participant_ids.update(reply.created_by for reply in self.replies)
# Add mentioned users
participant_ids.update(mention.mentioned_user_id for mention in self.mentions)
# Get account objects
participants = []
for user_id in participant_ids:
account = db.session.get(Account, user_id)
if account:
participants.append(account)
return participants
class WorkflowCommentReply(Base):
"""Workflow comment reply model.
Attributes:
id: Reply ID
comment_id: Parent comment ID
content: Reply content
created_by: Creator account ID
created_at: Creation time
"""
__tablename__ = "workflow_comment_replies"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
Index("comment_replies_comment_idx", "comment_id"),
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
content: Mapped[str] = mapped_column(db.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
@property
def created_by_account(self):
"""Get creator account."""
return db.session.get(Account, self.created_by)
class WorkflowCommentMention(Base):
"""Workflow comment mention model.
Mentions are only for internal accounts since end users
cannot access workflow canvas and commenting features.
Attributes:
id: Mention ID
comment_id: Parent comment ID
mentioned_user_id: Mentioned account ID
"""
__tablename__ = "workflow_comment_mentions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
Index("comment_mentions_comment_idx", "comment_id"),
Index("comment_mentions_reply_idx", "reply_id"),
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
comment_id: Mapped[str] = mapped_column(
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[Optional[str]] = mapped_column(
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
@property
def mentioned_user_account(self):
"""Get mentioned account."""
return db.session.get(Account, self.mentioned_user_id)

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column
from configs import dify_config
@ -581,6 +581,63 @@ class InstalledApp(Base):
return tenant
class TrialApp(Base):
__tablename__ = "trial_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
sa.Index("trial_app_app_id_idx", "app_id"),
sa.Index("trial_app_tenant_id_idx", "tenant_id"),
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
class AccountTrialAppRecord(Base):
__tablename__ = "account_trial_app_records"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
sa.Index("account_trial_app_record_account_id_idx", "account_id"),
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def user(self) -> Account | None:
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
class ExporleBanner(Base):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
content = mapped_column(sa.JSON, nullable=False)
link = mapped_column(String(255), nullable=False)
sort = mapped_column(sa.Integer, nullable=False)
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class OAuthProviderApp(Base):
"""
Globally shared OAuth provider app information.
@ -1944,3 +2001,29 @@ class TraceAppConfig(Base):
"created_at": str(self.created_at) if self.created_at else None,
"updated_at": str(self.updated_at) if self.updated_at else None,
}
class TenantCreditPool(Base):
__tablename__ = "tenant_credit_pools"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="tenant_credit_pool_pkey"),
sa.Index("tenant_credit_pool_tenant_id_idx", "tenant_id"),
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
)
id = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"))
tenant_id = mapped_column(StringUUID, nullable=False)
pool_type = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
quota_limit = mapped_column(BigInteger, nullable=False, default=0)
quota_used = mapped_column(BigInteger, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
updated_at = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def remaining_credits(self) -> int:
return max(0, self.quota_limit - self.quota_used)
def has_sufficient_credits(self, required_credits: int) -> bool:
return self.remaining_credits >= required_credits

View File

@ -342,7 +342,7 @@ class Workflow(Base):
:return: hash
"""
entity = {"graph": self.graph_dict, "features": self.features_dict}
entity = {"graph": self.graph_dict}
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))

View File

@ -20,6 +20,7 @@ dependencies = [
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gevent-websocket~=0.10.1",
"gmpy2~=2.2.1",
"google-api-core==2.18.0",
"google-api-python-client==2.90.0",
@ -68,6 +69,7 @@ dependencies = [
"pypdfium2==4.30.0",
"python-docx~=1.1.0",
"python-dotenv==1.0.1",
"python-socketio~=5.13.0",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=6.1.0",
@ -86,6 +88,7 @@ dependencies = [
"sendgrid~=6.12.3",
"flask-restx~=1.3.0",
"packaging~=23.2",
"gevent-websocket>=0.10.1",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.

View File

@ -995,6 +995,11 @@ class TenantService:
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
from services.credit_pool_service import CreditPoolService
CreditPoolService.create_default_pool(tenant.id)
return tenant
@staticmethod

View File

@ -0,0 +1,68 @@
import logging
from typing import Optional
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.errors.error import QuotaExceededError
from extensions.ext_database import db
from models import TenantCreditPool
logger = logging.getLogger(__name__)
class CreditPoolService:
@classmethod
def create_default_pool(cls, tenant_id: str) -> TenantCreditPool:
"""create default credit pool for new tenant"""
credit_pool = TenantCreditPool(
tenant_id=tenant_id, quota_limit=dify_config.HOSTED_POOL_CREDITS, quota_used=0, pool_type="trial"
)
db.session.add(credit_pool)
db.session.commit()
return credit_pool
@classmethod
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> Optional[TenantCreditPool]:
"""get tenant credit pool"""
return (
db.session.query(TenantCreditPool)
.filter_by(
tenant_id=tenant_id,
pool_type=pool_type,
)
.first()
)
@classmethod
def check_and_deduct_credits(
cls,
tenant_id: str,
credits_required: int,
pool_type: str = "trial",
):
"""check and deduct credits"""
pool = cls.get_pool(tenant_id, pool_type)
if not pool:
raise QuotaExceededError("Credit pool not found")
if pool.remaining_credits < credits_required:
raise QuotaExceededError(
f"Insufficient credits. Required: {credits_required}, Available: {pool.remaining_credits}"
)
try:
with Session(db.engine) as session:
update_values = {"quota_used": pool.quota_used + credits_required}
where_conditions = [
TenantCreditPool.pool_type == pool_type,
TenantCreditPool.tenant_id == tenant_id,
TenantCreditPool.quota_used + credits_required <= TenantCreditPool.quota_limit,
]
stmt = update(TenantCreditPool).where(*where_conditions).values(**update_values)
session.execute(stmt)
session.commit()
except Exception:
raise QuotaExceededError("Failed to deduct credits")

View File

@ -160,6 +160,8 @@ class SystemFeatureModel(BaseModel):
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
enable_trial_app: bool = False
enable_explore_banner: bool = False
class FeatureService:
@ -214,6 +216,8 @@ class FeatureService:
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
@classmethod
def _fulfill_params_from_env(cls, features: FeatureModel):

View File

@ -1,4 +1,9 @@
from sqlalchemy.orm import Session
from configs import dify_config
from extensions.ext_database import db
from models.model import AccountTrialAppRecord, TrialApp
from services.feature_service import FeatureService
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@ -20,6 +25,15 @@ class RecommendedAppService:
)
)
if FeatureService.get_system_features().enable_trial_app:
apps = result["recommended_apps"]
for app in apps:
app_id = app["app_id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
if trial_app_model:
app["can_trial"] = True
else:
app["can_trial"] = False
return result
@classmethod
@ -32,4 +46,27 @@ class RecommendedAppService:
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
if FeatureService.get_system_features().enable_trial_app:
app_id = result["id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
if trial_app_model:
result["can_trial"] = True
else:
result["can_trial"] = False
return result
@classmethod
def add_trial_app_record(cls, app_id: str, account_id: str):
"""
Add trial app record.
:param app_id: app id
:return:
"""
with Session(db.engine) as session:
account_trial_app_record = session.query(AccountTrialAppRecord).where(TrialApp.app_id == app_id).first()
if account_trial_app_record:
account_trial_app_record.count += 1
session.commit()
else:
session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
session.commit()

View File

@ -0,0 +1,311 @@
import logging
from typing import Optional
from sqlalchemy import desc, select
from sqlalchemy.orm import Session, selectinload
from werkzeug.exceptions import Forbidden, NotFound
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
logger = logging.getLogger(__name__)
class WorkflowCommentService:
"""Service for managing workflow comments."""
@staticmethod
def _validate_content(content: str) -> None:
if len(content.strip()) == 0:
raise ValueError("Comment content cannot be empty")
if len(content) > 1000:
raise ValueError("Comment content cannot exceed 1000 characters")
@staticmethod
def get_comments(tenant_id: str, app_id: str) -> list[WorkflowComment]:
"""Get all comments for a workflow."""
with Session(db.engine) as session:
# Get all comments with eager loading
stmt = (
select(WorkflowComment)
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
.where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
.order_by(desc(WorkflowComment.created_at))
)
comments = session.scalars(stmt).all()
return comments
@staticmethod
def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session = None) -> WorkflowComment:
"""Get a specific comment."""
def _get_comment(session: Session) -> WorkflowComment:
stmt = (
select(WorkflowComment)
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
.where(
WorkflowComment.id == comment_id,
WorkflowComment.tenant_id == tenant_id,
WorkflowComment.app_id == app_id,
)
)
comment = session.scalar(stmt)
if not comment:
raise NotFound("Comment not found")
return comment
if session is not None:
return _get_comment(session)
else:
with Session(db.engine, expire_on_commit=False) as session:
return _get_comment(session)
@staticmethod
def create_comment(
tenant_id: str,
app_id: str,
created_by: str,
content: str,
position_x: float,
position_y: float,
mentioned_user_ids: Optional[list[str]] = None,
) -> WorkflowComment:
"""Create a new workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine) as session:
comment = WorkflowComment(
tenant_id=tenant_id,
app_id=app_id,
position_x=position_x,
position_y=position_y,
content=content,
created_by=created_by,
)
session.add(comment)
session.flush() # Get the comment ID for mentions
# Create mentions if specified
mentioned_user_ids = mentioned_user_ids or []
for user_id in mentioned_user_ids:
if isinstance(user_id, str) and uuid_value(user_id):
mention = WorkflowCommentMention(
comment_id=comment.id,
reply_id=None, # This is a comment mention, not reply mention
mentioned_user_id=user_id,
)
session.add(mention)
session.commit()
# Return only what we need - id and created_at
return {"id": comment.id, "created_at": comment.created_at}
@staticmethod
def update_comment(
tenant_id: str,
app_id: str,
comment_id: str,
user_id: str,
content: str,
position_x: Optional[float] = None,
position_y: Optional[float] = None,
mentioned_user_ids: Optional[list[str]] = None,
) -> dict:
"""Update a workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
# Get comment with validation
stmt = select(WorkflowComment).where(
WorkflowComment.id == comment_id,
WorkflowComment.tenant_id == tenant_id,
WorkflowComment.app_id == app_id,
)
comment = session.scalar(stmt)
if not comment:
raise NotFound("Comment not found")
# Only the creator can update the comment
if comment.created_by != user_id:
raise Forbidden("Only the comment creator can update it")
# Update comment fields
comment.content = content
if position_x is not None:
comment.position_x = position_x
if position_y is not None:
comment.position_y = position_y
# Update mentions - first remove existing mentions for this comment only (not replies)
existing_mentions = session.scalars(
select(WorkflowCommentMention).where(
WorkflowCommentMention.comment_id == comment.id,
WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions
)
).all()
for mention in existing_mentions:
session.delete(mention)
# Add new mentions
mentioned_user_ids = mentioned_user_ids or []
for user_id_str in mentioned_user_ids:
if isinstance(user_id_str, str) and uuid_value(user_id_str):
mention = WorkflowCommentMention(
comment_id=comment.id,
reply_id=None, # This is a comment mention
mentioned_user_id=user_id_str,
)
session.add(mention)
session.commit()
return {"id": comment.id, "updated_at": comment.updated_at}
@staticmethod
def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
"""Delete a workflow comment."""
with Session(db.engine, expire_on_commit=False) as session:
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
# Only the creator can delete the comment
if comment.created_by != user_id:
raise Forbidden("Only the comment creator can delete it")
# Delete associated mentions (both comment and reply mentions)
mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
).all()
for mention in mentions:
session.delete(mention)
# Delete associated replies
replies = session.scalars(
select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
).all()
for reply in replies:
session.delete(reply)
session.delete(comment)
session.commit()
@staticmethod
def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
"""Resolve a workflow comment."""
with Session(db.engine, expire_on_commit=False) as session:
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
if comment.resolved:
return comment
comment.resolved = True
comment.resolved_at = naive_utc_now()
comment.resolved_by = user_id
session.commit()
return comment
@staticmethod
def create_reply(
comment_id: str, content: str, created_by: str, mentioned_user_ids: Optional[list[str]] = None
) -> dict:
"""Add a reply to a workflow comment."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
# Check if comment exists
comment = session.get(WorkflowComment, comment_id)
if not comment:
raise NotFound("Comment not found")
reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)
session.add(reply)
session.flush() # Get the reply ID for mentions
# Create mentions if specified
mentioned_user_ids = mentioned_user_ids or []
for user_id in mentioned_user_ids:
if isinstance(user_id, str) and uuid_value(user_id):
# Create mention linking to specific reply
mention = WorkflowCommentMention(
comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id
)
session.add(mention)
session.commit()
return {"id": reply.id, "created_at": reply.created_at}
@staticmethod
def update_reply(
reply_id: str, user_id: str, content: str, mentioned_user_ids: Optional[list[str]] = None
) -> WorkflowCommentReply:
"""Update a comment reply."""
WorkflowCommentService._validate_content(content)
with Session(db.engine, expire_on_commit=False) as session:
reply = session.get(WorkflowCommentReply, reply_id)
if not reply:
raise NotFound("Reply not found")
# Only the creator can update the reply
if reply.created_by != user_id:
raise Forbidden("Only the reply creator can update it")
reply.content = content
# Update mentions - first remove existing mentions for this reply
existing_mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
).all()
for mention in existing_mentions:
session.delete(mention)
# Add mentions
mentioned_user_ids = mentioned_user_ids or []
for user_id_str in mentioned_user_ids:
if isinstance(user_id_str, str) and uuid_value(user_id_str):
mention = WorkflowCommentMention(
comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str
)
session.add(mention)
session.commit()
session.refresh(reply) # Refresh to get updated timestamp
return {"id": reply.id, "updated_at": reply.updated_at}
@staticmethod
def delete_reply(reply_id: str, user_id: str) -> None:
"""Delete a comment reply."""
with Session(db.engine, expire_on_commit=False) as session:
reply = session.get(WorkflowCommentReply, reply_id)
if not reply:
raise NotFound("Reply not found")
# Only the creator can delete the reply
if reply.created_by != user_id:
raise Forbidden("Only the reply creator can delete it")
# Delete associated mentions first
mentions = session.scalars(
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
).all()
for mention in mentions:
session.delete(mention)
session.delete(reply)
session.commit()
@staticmethod
def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
"""Validate that a comment belongs to the specified tenant and app."""
return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)

View File

@ -198,15 +198,17 @@ class WorkflowService:
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
force_upload: bool = False,
) -> Workflow:
"""
Sync draft workflow
:param force_upload: Skip hash validation when True (for restore operations)
:raises WorkflowHashNotEqualError
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if workflow and workflow.unique_hash != unique_hash:
if workflow and workflow.unique_hash != unique_hash and not force_upload:
raise WorkflowHashNotEqualError()
# validate features structure
@ -244,6 +246,78 @@ class WorkflowService:
# return draft workflow
return workflow
def update_draft_workflow_environment_variables(
self,
*,
app_model: App,
environment_variables: Sequence[Variable],
account: Account,
):
"""
Update draft workflow environment variables
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
workflow.environment_variables = environment_variables
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def update_draft_workflow_conversation_variables(
self,
*,
app_model: App,
conversation_variables: Sequence[Variable],
account: Account,
):
"""
Update draft workflow conversation variables
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
workflow.conversation_variables = conversation_variables
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def update_draft_workflow_features(
self,
*,
app_model: App,
features: dict,
account: Account,
):
"""
Update draft workflow features
"""
# fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model)
if not workflow:
raise ValueError("No draft workflow found.")
# validate features structure
self.validate_features_structure(app_model=app_model, features=features)
workflow.features = json.dumps(features)
workflow.updated_by = account.id
workflow.updated_at = naive_utc_now()
# commit db session changes
db.session.commit()
def publish_workflow(
self,
*,

View File

@ -46,5 +46,17 @@ class WorkspaceService:
"remove_webapp_brand": remove_webapp_brand,
"replace_webapp_logo": replace_webapp_logo,
}
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
paid_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="paid")
if paid_pool:
tenant_info["trial_credits"] = paid_pool.quota_limit
tenant_info["trial_credits_used"] = paid_pool.quota_used
else:
trial_pool = CreditPoolService.get_pool(tenant_id=tenant.id, pool_type="trial")
if trial_pool:
tenant_info["trial_credits"] = trial_pool.quota_limit
tenant_info["trial_credits_used"] = trial_pool.quota_used
return tenant_info

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,14 @@ server {
include proxy.conf;
}
location /socket.io/ {
proxy_pass http://api:5001;
include proxy.conf;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_cache_bypass $http_upgrade;
}
location /v1 {
proxy_pass http://api:5001;
include proxy.conf;

View File

@ -5,7 +5,7 @@ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header X-Forwarded-Port $server_port;
proxy_http_version 1.1;
proxy_set_header Connection "";
# proxy_set_header Connection "";
proxy_buffering off;
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT};
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT};

View File

@ -1,6 +1,6 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import React, { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import AppCard from '@/app/components/app/overview/app-card'
@ -19,6 +19,8 @@ import { asyncRunSafe } from '@/utils'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import type { IAppCardProps } from '@/app/components/app/overview/app-card'
import { useStore as useAppStore } from '@/app/components/app/store'
import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager'
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
export type ICardViewProps = {
appId: string
@ -47,15 +49,44 @@ const CardView: FC<ICardViewProps> = ({ appId, isInPanel, className }) => {
message ||= (type === 'success' ? 'modifiedSuccessfully' : 'modifiedUnsuccessfully')
if (type === 'success')
if (type === 'success') {
updateAppDetail()
// Emit collaboration event to notify other clients of app state changes
const socket = webSocketClient.getSocket(appId)
if (socket) {
socket.emit('collaboration_event', {
type: 'appStateUpdate',
data: { timestamp: Date.now() },
timestamp: Date.now(),
})
}
}
notify({
type,
message: t(`common.actionMsg.${message}`),
})
}
// Listen for collaborative app state updates from other clients
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onAppStateUpdate(async (update: any) => {
try {
console.log('Received app state update from collaboration:', update)
// Update app detail when other clients modify app state
await updateAppDetail()
}
catch (error) {
console.error('app state update failed:', error)
}
})
return unsubscribe
}, [appId])
const onChangeSiteStatus = async (value: boolean) => {
const [err] = await asyncRunSafe<App>(
updateAppSiteStatus({

View File

@ -32,6 +32,8 @@ import { useGlobalPublicStore } from '@/context/global-public-context'
import { formatTime } from '@/utils/time'
import { useGetUserCanAccessApp } from '@/service/access-control'
import dynamic from 'next/dynamic'
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
import type { WorkflowOnlineUser } from '@/models/app'
const EditAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), {
ssr: false,
@ -55,9 +57,10 @@ const AccessControl = dynamic(() => import('@/app/components/app/app-access-cont
export type AppCardProps = {
app: App
onRefresh?: () => void
onlineUsers?: WorkflowOnlineUser[]
}
const AppCard = ({ app, onRefresh }: AppCardProps) => {
const AppCard = ({ app, onRefresh, onlineUsers = [] }: AppCardProps) => {
const { t } = useTranslation()
const { notify } = useContext(ToastContext)
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
@ -333,6 +336,19 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
return `${t('datasetDocuments.segment.editedAt')} ${timeText}`
}, [app.updated_at, app.created_at])
const onlineUserAvatars = useMemo(() => {
if (!onlineUsers.length)
return []
return onlineUsers
.map(user => ({
id: user.user_id || user.sid || '',
name: user.username || 'User',
avatar_url: user.avatar || undefined,
}))
.filter(user => !!user.id)
}, [onlineUsers])
return (
<>
<div
@ -377,6 +393,11 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
<RiVerifiedBadgeLine className='h-4 w-4 text-text-quaternary' />
</Tooltip>}
</div>
<div>
{onlineUserAvatars.length > 0 && (
<UserAvatarList users={onlineUserAvatars} maxVisible={3} size={20} />
)}
</div>
</div>
<div className='title-wrapper h-[90px] px-[14px] text-xs leading-normal text-text-tertiary'>
<div

View File

@ -1,10 +1,11 @@
'use client'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import {
useRouter,
} from 'next/navigation'
import useSWRInfinite from 'swr/infinite'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import { useDebounceFn } from 'ahooks'
import {
@ -19,8 +20,8 @@ import AppCard from './app-card'
import NewAppCard from './new-app-card'
import useAppsQueryState from './hooks/use-apps-query-state'
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
import type { AppListResponse } from '@/models/app'
import { fetchAppList } from '@/service/apps'
import type { AppListResponse, WorkflowOnlineUser } from '@/models/app'
import { fetchAppList, fetchWorkflowOnlineUsers } from '@/service/apps'
import { useAppContext } from '@/context/app-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
import { CheckModal } from '@/hooks/use-pay'
@ -112,6 +113,37 @@ const List = () => {
},
)
const apps = useMemo(() => data?.flatMap(page => page.data) ?? [], [data])
const workflowIds = useMemo(() => {
const ids = new Set<string>()
apps.forEach((appItem) => {
const workflowId = appItem.id
if (!workflowId)
return
if (appItem.mode === 'workflow' || appItem.mode === 'advanced-chat')
ids.add(workflowId)
})
return Array.from(ids)
}, [apps])
const { data: onlineUsersByWorkflow, mutate: refreshOnlineUsers } = useSWR<Record<string, WorkflowOnlineUser[]>>(
workflowIds.length ? { workflowIds } : null,
fetchWorkflowOnlineUsers,
)
useEffect(() => {
if (!workflowIds.length)
return
const timer = window.setInterval(() => {
refreshOnlineUsers()
}, 10000)
return () => window.clearInterval(timer)
}, [workflowIds.join(','), refreshOnlineUsers])
const anchorRef = useRef<HTMLDivElement>(null)
const options = [
{ value: 'all', text: t('app.types.all'), icon: <RiApps2Line className='mr-1 h-[14px] w-[14px]' /> },
@ -213,7 +245,12 @@ const List = () => {
{isCurrentWorkspaceEditor
&& <NewAppCard ref={newAppCardRef} onSuccess={mutate} selectedAppType={activeTab} />}
{data.map(({ data: apps }) => apps.map(app => (
<AppCard key={app.id} app={app} onRefresh={mutate} />
<AppCard
key={app.id}
app={app}
onRefresh={mutate}
onlineUsers={onlineUsersByWorkflow?.[app.id] ?? []}
/>
)))}
</div>
: <div className='relative grid grow grid-cols-1 content-start gap-4 overflow-hidden px-12 pt-2 sm:grid-cols-1 md:grid-cols-2 xl:grid-cols-4 2xl:grid-cols-5 2k:grid-cols-6'>

View File

@ -9,6 +9,7 @@ export type AvatarProps = {
className?: string
textClassName?: string
onError?: (x: boolean) => void
backgroundColor?: string
}
const Avatar = ({
name,
@ -17,9 +18,18 @@ const Avatar = ({
className,
textClassName,
onError,
backgroundColor,
}: AvatarProps) => {
const avatarClassName = 'shrink-0 flex items-center rounded-full bg-primary-600'
const style = { width: `${size}px`, height: `${size}px`, fontSize: `${size}px`, lineHeight: `${size}px` }
const avatarClassName = backgroundColor
? 'shrink-0 flex items-center rounded-full'
: 'shrink-0 flex items-center rounded-full bg-primary-600'
const style = {
width: `${size}px`,
height: `${size}px`,
fontSize: `${size}px`,
lineHeight: `${size}px`,
...(backgroundColor && !avatar ? { backgroundColor } : {}),
}
const [imgError, setImgError] = useState(false)
const handleError = () => {
@ -35,14 +45,18 @@ const Avatar = ({
if (avatar && !imgError) {
return (
<img
<span
className={cn(avatarClassName, className)}
style={style}
alt={name}
src={avatar}
onError={handleError}
onLoad={() => onError?.(false)}
/>
>
<img
className='h-full w-full rounded-full object-cover'
alt={name}
src={avatar}
onError={handleError}
onLoad={() => onError?.(false)}
/>
</span>
)
}

View File

@ -0,0 +1,3 @@
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="12" viewBox="0 0 14 12" fill="none">
<path d="M12.3334 4C12.3334 2.52725 11.1395 1.33333 9.66671 1.33333H4.33337C2.86062 1.33333 1.66671 2.52724 1.66671 4V10.6667H9.66671C11.1395 10.6667 12.3334 9.47274 12.3334 8V4ZM7.66671 6.66667V8H4.33337V6.66667H7.66671ZM9.66671 4V5.33333H4.33337V4H9.66671ZM13.6667 8C13.6667 10.2091 11.8758 12 9.66671 12H0.333374V4C0.333374 1.79086 2.12424 0 4.33337 0H9.66671C11.8758 0 13.6667 1.79086 13.6667 4V8Z" fill="currentColor"/>
</svg>

After

Width:  |  Height:  |  Size: 527 B

View File

@ -0,0 +1,26 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"xmlns": "http://www.w3.org/2000/svg",
"width": "14",
"height": "12",
"viewBox": "0 0 14 12",
"fill": "none"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M12.3334 4C12.3334 2.52725 11.1395 1.33333 9.66671 1.33333H4.33337C2.86062 1.33333 1.66671 2.52724 1.66671 4V10.6667H9.66671C11.1395 10.6667 12.3334 9.47274 12.3334 8V4ZM7.66671 6.66667V8H4.33337V6.66667H7.66671ZM9.66671 4V5.33333H4.33337V4H9.66671ZM13.6667 8C13.6667 10.2091 11.8758 12 9.66671 12H0.333374V4C0.333374 1.79086 2.12424 0 4.33337 0H9.66671C11.8758 0 13.6667 1.79086 13.6667 4V8Z",
"fill": "currentColor"
},
"children": []
}
]
},
"name": "Comment"
}

View File

@ -0,0 +1,20 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './Comment.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>;
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'Comment'
export default Icon

View File

@ -1,4 +1,5 @@
export { default as Icon3Dots } from './Icon3Dots'
export { default as Comment } from './Comment'
export { default as DefaultToolIcon } from './DefaultToolIcon'
export { default as Message3Fill } from './Message3Fill'
export { default as RowStruct } from './RowStruct'

View File

@ -2,6 +2,7 @@
import type { FC } from 'react'
import React, { useEffect } from 'react'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import type {
EditorState,
} from 'lexical'
@ -80,6 +81,29 @@ import {
import { useEventEmitterContextContext } from '@/context/event-emitter'
import cn from '@/utils/classnames'
const ValueSyncPlugin: FC<{ value?: string }> = ({ value }) => {
const [editor] = useLexicalComposerContext()
useEffect(() => {
if (value === undefined)
return
const incomingValue = value ?? ''
const shouldUpdate = editor.getEditorState().read(() => {
const currentText = $getRoot().getChildren().map(node => node.getTextContent()).join('\n')
return currentText !== incomingValue
})
if (!shouldUpdate)
return
const editorState = editor.parseEditorState(textToEditorState(incomingValue))
editor.setEditorState(editorState)
}, [editor, value])
return null
}
export type PromptEditorProps = {
instanceId?: string
compact?: boolean
@ -293,6 +317,7 @@ const PromptEditor: FC<PromptEditorProps> = ({
<VariableValueBlock />
)
}
<ValueSyncPlugin value={value} />
<OnChangePlugin onChange={handleEditorChange} />
<OnBlurBlock onBlur={onBlur} onFocus={onFocus} />
<UpdateBlock instanceId={instanceId} />

View File

@ -0,0 +1,77 @@
import type { FC } from 'react'
import { memo } from 'react'
import { getUserColor } from '@/app/components/workflow/collaboration/utils/user-color'
import { useAppContext } from '@/context/app-context'
import Avatar from '@/app/components/base/avatar'
type User = {
id: string
name: string
avatar_url?: string | null
}
type UserAvatarListProps = {
users: User[]
maxVisible?: number
size?: number
className?: string
showCount?: boolean
}
export const UserAvatarList: FC<UserAvatarListProps> = memo(({
users,
maxVisible = 3,
size = 24,
className = '',
showCount = true,
}) => {
const { userProfile } = useAppContext()
if (!users.length) return null
const shouldShowCount = showCount && users.length > maxVisible
const actualMaxVisible = shouldShowCount ? Math.max(1, maxVisible - 1) : maxVisible
const visibleUsers = users.slice(0, actualMaxVisible)
const remainingCount = users.length - actualMaxVisible
const currentUserId = userProfile?.id
return (
<div className={`flex items-center -space-x-1 ${className}`}>
{visibleUsers.map((user, index) => {
const isCurrentUser = user.id === currentUserId
const userColor = isCurrentUser ? undefined : getUserColor(user.id)
return (
<div
key={`${user.id}-${index}`}
className='relative'
style={{ zIndex: visibleUsers.length - index }}
>
<Avatar
name={user.name}
avatar={user.avatar_url || null}
size={size}
className='ring-2 ring-components-panel-bg'
backgroundColor={userColor}
/>
</div>
)
},
)}
{shouldShowCount && remainingCount > 0 && (
<div
className={'flex items-center justify-center rounded-full bg-gray-500 text-[10px] leading-none text-white ring-2 ring-components-panel-bg'}
style={{
zIndex: 0,
width: size,
height: size,
}}
>
+{remainingCount}
</div>
)}
</div>
)
})
UserAvatarList.displayName = 'UserAvatarList'

View File

@ -26,6 +26,8 @@ import {
import { BlockEnum } from '@/app/components/workflow/types'
import cn from '@/utils/classnames'
import { fetchAppDetail } from '@/service/apps'
import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager'
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
export type IAppCardProps = {
appInfo: AppDetailResponse & Partial<AppSSO>
@ -90,6 +92,19 @@ function MCPServiceCard({
const onGenCode = async () => {
await refreshMCPServerCode(detail?.id || '')
invalidateMCPServerDetail(appId)
// Emit collaboration event to notify other clients of MCP server changes
const socket = webSocketClient.getSocket(appId)
if (socket) {
socket.emit('collaboration_event', {
type: 'mcpServerUpdate',
data: {
action: 'codeRegenerated',
timestamp: Date.now(),
},
timestamp: Date.now(),
})
}
}
const onChangeStatus = async (state: boolean) => {
@ -119,6 +134,20 @@ function MCPServiceCard({
})
invalidateMCPServerDetail(appId)
}
// Emit collaboration event to notify other clients of MCP server status change
const socket = webSocketClient.getSocket(appId)
if (socket) {
socket.emit('collaboration_event', {
type: 'mcpServerUpdate',
data: {
action: 'statusChanged',
status: state ? 'active' : 'inactive',
timestamp: Date.now(),
},
timestamp: Date.now(),
})
}
}
const handleServerModalHide = () => {
@ -131,6 +160,23 @@ function MCPServiceCard({
setActivated(serverActivated)
}, [serverActivated])
// Listen for collaborative MCP server updates from other clients
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onMcpServerUpdate(async (update: any) => {
try {
console.log('Received MCP server update from collaboration:', update)
invalidateMCPServerDetail(appId)
}
catch (error) {
console.error('MCP server update failed:', error)
}
})
return unsubscribe
}, [appId, invalidateMCPServerDetail])
if (!currentWorkflow && isAdvancedApp)
return null

View File

@ -1,11 +1,18 @@
import {
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from 'react'
import { useFeaturesStore } from '@/app/components/base/features/hooks'
import type { Features as FeaturesData } from '@/app/components/base/features/types'
import { SupportUploadFileTypes } from '@/app/components/workflow/types'
import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants'
import { WorkflowWithInnerContext } from '@/app/components/workflow'
import type { WorkflowProps } from '@/app/components/workflow'
import WorkflowChildren from './workflow-children'
import {
useAvailableNodesMetaData,
useConfigsMap,
@ -18,7 +25,12 @@ import {
useWorkflowRun,
useWorkflowStartRun,
} from '../hooks'
import { useWorkflowStore } from '@/app/components/workflow/store'
import { useWorkflowUpdate } from '@/app/components/workflow/hooks/use-workflow-interactions'
import { useStore, useWorkflowStore } from '@/app/components/workflow/store'
import { useCollaboration } from '@/app/components/workflow/collaboration'
import { collaborationManager } from '@/app/components/workflow/collaboration'
import { fetchWorkflowDraft } from '@/service/workflow'
import { useReactFlow, useStoreApi } from 'reactflow'
type WorkflowMainProps = Pick<WorkflowProps, 'nodes' | 'edges' | 'viewport'>
const WorkflowMain = ({
@ -28,6 +40,31 @@ const WorkflowMain = ({
}: WorkflowMainProps) => {
const featuresStore = useFeaturesStore()
const workflowStore = useWorkflowStore()
const appId = useStore(s => s.appId)
const containerRef = useRef<HTMLDivElement>(null)
const reactFlow = useReactFlow()
const store = useStoreApi()
const { startCursorTracking, stopCursorTracking, onlineUsers, cursors, isConnected } = useCollaboration(appId || '', store)
const [myUserId, setMyUserId] = useState<string | null>(null)
useEffect(() => {
if (isConnected)
setMyUserId('current-user')
}, [isConnected])
const filteredCursors = Object.fromEntries(
Object.entries(cursors).filter(([userId]) => userId !== myUserId),
)
useEffect(() => {
if (containerRef.current)
startCursorTracking(containerRef as React.RefObject<HTMLElement>, reactFlow)
return () => {
stopCursorTracking()
}
}, [startCursorTracking, stopCursorTracking, reactFlow])
const handleWorkflowDataUpdate = useCallback((payload: any) => {
const {
@ -38,7 +75,33 @@ const WorkflowMain = ({
if (features && featuresStore) {
const { setFeatures } = featuresStore.getState()
setFeatures(features)
const transformedFeatures: FeaturesData = {
file: {
image: {
enabled: !!features.file_upload?.image?.enabled,
number_limits: features.file_upload?.image?.number_limits || 3,
transfer_methods: features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
},
enabled: !!(features.file_upload?.enabled || features.file_upload?.image?.enabled),
allowed_file_types: features.file_upload?.allowed_file_types || [SupportUploadFileTypes.image],
allowed_file_extensions: features.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`),
allowed_file_upload_methods: features.file_upload?.allowed_file_upload_methods || features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
number_limits: features.file_upload?.number_limits || features.file_upload?.image?.number_limits || 3,
},
opening: {
enabled: !!features.opening_statement,
opening_statement: features.opening_statement,
suggested_questions: features.suggested_questions,
},
suggested: features.suggested_questions_after_answer || { enabled: false },
speech2text: features.speech_to_text || { enabled: false },
text2speech: features.text_to_speech || { enabled: false },
citation: features.retriever_resource || { enabled: false },
moderation: features.sensitive_word_avoidance || { enabled: false },
annotationReply: features.annotation_reply || { enabled: false },
}
setFeatures(transformedFeatures)
}
if (conversation_variables) {
const { setConversationVariables } = workflowStore.getState()
@ -55,6 +118,7 @@ const WorkflowMain = ({
syncWorkflowDraftWhenPageClose,
} = useNodesSyncDraft()
const { handleRefreshWorkflowDraft } = useWorkflowRefreshDraft()
const { handleUpdateWorkflowCanvas } = useWorkflowUpdate()
const {
handleBackupDraft,
handleLoadBackupDraft,
@ -62,6 +126,63 @@ const WorkflowMain = ({
handleRun,
handleStopRun,
} = useWorkflowRun()
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onVarsAndFeaturesUpdate(async (update: any) => {
try {
const response = await fetchWorkflowDraft(`/apps/${appId}/workflows/draft`)
handleWorkflowDataUpdate(response)
}
catch (error) {
console.error('workflow vars and features update failed:', error)
}
})
return unsubscribe
}, [appId, handleWorkflowDataUpdate])
// Listen for workflow updates from other users
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onWorkflowUpdate(async () => {
console.log('Received workflow update from collaborator, fetching latest workflow data')
try {
const response = await fetchWorkflowDraft(`/apps/${appId}/workflows/draft`)
// Handle features, variables etc.
handleWorkflowDataUpdate(response)
// Update workflow canvas (nodes, edges, viewport)
if (response.graph) {
handleUpdateWorkflowCanvas({
nodes: response.graph.nodes || [],
edges: response.graph.edges || [],
viewport: response.graph.viewport || { x: 0, y: 0, zoom: 1 },
})
}
}
catch (error) {
console.error('Failed to fetch updated workflow:', error)
}
})
return unsubscribe
}, [appId, handleWorkflowDataUpdate, handleUpdateWorkflowCanvas])
// Listen for sync requests from other users (only processed by leader)
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onSyncRequest(() => {
console.log('Leader received sync request, performing sync')
doSyncWorkflowDraft()
})
return unsubscribe
}, [appId, doSyncWorkflowDraft])
const {
handleStartWorkflowRun,
handleWorkflowStartRunInChatflow,
@ -75,6 +196,7 @@ const WorkflowMain = ({
} = useDSL()
const configsMap = useConfigsMap()
const { fetchInspectVars } = useSetWorkflowVarsWithValue({
...configsMap,
})
@ -164,15 +286,23 @@ const WorkflowMain = ({
])
return (
<WorkflowWithInnerContext
nodes={nodes}
edges={edges}
viewport={viewport}
onWorkflowDataUpdate={handleWorkflowDataUpdate}
hooksStore={hooksStore as any}
<div
ref={containerRef}
className="relative h-full w-full"
>
<WorkflowChildren />
</WorkflowWithInnerContext>
<WorkflowWithInnerContext
nodes={nodes}
edges={edges}
viewport={viewport}
onWorkflowDataUpdate={handleWorkflowDataUpdate}
hooksStore={hooksStore as any}
cursors={filteredCursors}
myUserId={myUserId}
onlineUsers={onlineUsers}
>
<WorkflowChildren />
</WorkflowWithInnerContext>
</div>
)
}

View File

@ -7,6 +7,7 @@ import { useStore } from '@/app/components/workflow/store'
import {
useIsChatMode,
} from '../hooks'
import CommentsPanel from '@/app/components/workflow/panel/comments-panel'
import { useStore as useAppStore } from '@/app/components/app/store'
import type { PanelProps } from '@/app/components/workflow/panel'
import Panel from '@/app/components/workflow/panel'
@ -67,6 +68,7 @@ const WorkflowPanelOnRight = () => {
const showDebugAndPreviewPanel = useStore(s => s.showDebugAndPreviewPanel)
const showChatVariablePanel = useStore(s => s.showChatVariablePanel)
const showGlobalVariablePanel = useStore(s => s.showGlobalVariablePanel)
const controlMode = useStore(s => s.controlMode)
return (
<>
@ -100,6 +102,7 @@ const WorkflowPanelOnRight = () => {
<GlobalVariablePanel />
)
}
{controlMode === 'comment' && <CommentsPanel />}
</>
)
}

View File

@ -13,6 +13,7 @@ import { syncWorkflowDraft } from '@/service/workflow'
import { useFeaturesStore } from '@/app/components/base/features/hooks'
import { API_PREFIX } from '@/config'
import { useWorkflowRefreshDraft } from '.'
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
export const useNodesSyncDraft = () => {
const store = useStoreApi()
@ -85,6 +86,7 @@ export const useNodesSyncDraft = () => {
environment_variables: environmentVariables,
conversation_variables: conversationVariables,
hash: syncWorkflowDraftHash,
_is_collaborative: true,
},
}
}
@ -93,9 +95,20 @@ export const useNodesSyncDraft = () => {
const syncWorkflowDraftWhenPageClose = useCallback(() => {
if (getNodesReadOnly())
return
// Check leader status at sync time
const currentIsLeader = collaborationManager.getIsLeader()
// Only allow leader to sync data
if (!currentIsLeader) {
console.log('Not leader, skipping sync on page close')
return
}
const postParams = getPostParams()
if (postParams) {
console.log('Leader syncing workflow draft on page close')
navigator.sendBeacon(
`${API_PREFIX}/apps/${params.appId}/workflows/draft?_token=${localStorage.getItem('console_token')}`,
JSON.stringify(postParams.params),
@ -110,9 +123,23 @@ export const useNodesSyncDraft = () => {
onError?: () => void
onSettled?: () => void
},
forceUpload?: boolean,
) => {
if (getNodesReadOnly())
return
// Check leader status at sync time
const currentIsLeader = collaborationManager.getIsLeader()
// If not leader and not forcing upload, request the leader to sync
if (!currentIsLeader && !forceUpload) {
console.log('Not leader, requesting leader to sync workflow draft')
collaborationManager.emitSyncRequest()
callback?.onSettled?.()
return
}
console.log(forceUpload ? 'Force uploading workflow draft' : 'Leader performing workflow draft sync')
const postParams = getPostParams()
if (postParams) {
@ -120,17 +147,30 @@ export const useNodesSyncDraft = () => {
setSyncWorkflowDraftHash,
setDraftUpdatedAt,
} = workflowStore.getState()
// Add force_upload parameter if needed
const finalParams = {
...postParams.params,
...(forceUpload && { force_upload: true }),
}
try {
const res = await syncWorkflowDraft(postParams)
const res = await syncWorkflowDraft({
url: postParams.url,
params: finalParams,
})
setSyncWorkflowDraftHash(res.hash)
setDraftUpdatedAt(res.updated_at)
callback?.onSuccess?.()
}
catch (error: any) {
console.error('Leader failed to sync workflow draft:', error)
if (error && error.json && !error.bodyUsed) {
error.json().then((err: any) => {
if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError)
if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) {
console.error('draft_workflow_not_sync', err)
handleRefreshWorkflowDraft()
}
})
}
callback?.onError?.()

View File

@ -25,6 +25,7 @@ import {
import type { InjectWorkflowStoreSliceFn } from '@/app/components/workflow/store'
import { createWorkflowSlice } from './store/workflow/workflow-slice'
import WorkflowAppMain from './components/workflow-main'
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
const WorkflowAppWithAdditionalContext = () => {
const {
@ -35,15 +36,20 @@ const WorkflowAppWithAdditionalContext = () => {
const { isLoadingCurrentWorkspace, currentWorkspace } = useAppContext()
const nodesData = useMemo(() => {
if (data)
return initialNodes(data.graph.nodes, data.graph.edges)
if (data) {
const processedNodes = initialNodes(data.graph.nodes, data.graph.edges)
collaborationManager.setNodes([], processedNodes)
return processedNodes
}
return []
}, [data])
const edgesData = useMemo(() => {
if (data)
return initialEdges(data.graph.edges, data.graph.nodes)
const edgesData = useMemo(() => {
if (data) {
const processedEdges = initialEdges(data.graph.edges, data.graph.nodes)
collaborationManager.setEdges([], processedEdges)
return processedEdges
}
return []
}, [data])

View File

@ -4,7 +4,6 @@ import {
import produce from 'immer'
import {
useReactFlow,
useStoreApi,
useViewport,
} from 'reactflow'
import { useEventListener } from 'ahooks'
@ -19,9 +18,9 @@ import CustomNode from './nodes'
import CustomNoteNode from './note-node'
import { CUSTOM_NOTE_NODE } from './note-node/constants'
import { BlockEnum } from './types'
import { useCollaborativeWorkflow } from '@/app/components/workflow/hooks/use-collaborative-workflow'
const CandidateNode = () => {
const store = useStoreApi()
const reactflow = useReactFlow()
const workflowStore = useWorkflowStore()
const candidateNode = useStore(s => s.candidateNode)
@ -29,18 +28,15 @@ const CandidateNode = () => {
const { zoom } = useViewport()
const { handleNodeSelect } = useNodesInteractions()
const { saveStateToHistory } = useWorkflowHistory()
const collaborativeWorkflow = useCollaborativeWorkflow()
useEventListener('click', (e) => {
const { candidateNode, mousePosition } = workflowStore.getState()
if (candidateNode) {
e.preventDefault()
const {
getNodes,
setNodes,
} = store.getState()
const { nodes, setNodes } = collaborativeWorkflow.getState()
const { screenToFlowPosition } = reactflow
const nodes = getNodes()
const { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
const newNodes = produce(nodes, (draft) => {
draft.push({

View File

@ -0,0 +1,78 @@
import type { FC } from 'react'
import { useViewport } from 'reactflow'
import type { CursorPosition, OnlineUser } from '@/app/components/workflow/collaboration/types'
import { getUserColor } from '../utils/user-color'
type UserCursorsProps = {
cursors: Record<string, CursorPosition>
myUserId: string | null
onlineUsers: OnlineUser[]
}
const UserCursors: FC<UserCursorsProps> = ({
cursors,
myUserId,
onlineUsers,
}) => {
const viewport = useViewport()
const convertToScreenCoordinates = (cursor: CursorPosition) => {
// Convert world coordinates to screen coordinates using current viewport
const screenX = cursor.x * viewport.zoom + viewport.x
const screenY = cursor.y * viewport.zoom + viewport.y
return { x: screenX, y: screenY }
}
return (
<>
{Object.entries(cursors || {}).map(([userId, cursor]) => {
if (userId === myUserId)
return null
const userInfo = onlineUsers.find(user => user.user_id === userId)
const userName = userInfo?.username || `User ${userId.slice(-4)}`
const userColor = getUserColor(userId)
const screenPos = convertToScreenCoordinates(cursor)
return (
<div
key={userId}
className="pointer-events-none absolute z-[8] transition-all duration-150 ease-out"
style={{
left: screenPos.x,
top: screenPos.y,
}}
>
<svg
width="20"
height="20"
viewBox="0 0 20 20"
fill="none"
xmlns="http://www.w3.org/2000/svg"
className="drop-shadow-md"
>
<path
d="M5 3L5 15L8 11.5L11 16L13 15L10 10.5L14 10.5L5 3Z"
fill={userColor}
stroke="white"
strokeWidth="1.5"
strokeLinejoin="round"
/>
</svg>
<div
className="absolute left-4 top-4 max-w-[120px] overflow-hidden text-ellipsis whitespace-nowrap rounded px-1.5 py-0.5 text-[11px] font-medium text-white shadow-sm"
style={{
backgroundColor: userColor,
}}
>
{userName}
</div>
</div>
)
})}
</>
)
}
export default UserCursors

View File

@ -0,0 +1,924 @@
import { LoroDoc, UndoManager } from 'loro-crdt'
import { isEqual } from 'lodash-es'
import { webSocketClient } from './websocket-manager'
import { CRDTProvider } from './crdt-provider'
import { EventEmitter } from './event-emitter'
import type { Edge, Node } from '../../types'
import type {
CollaborationState,
CursorPosition,
NodePanelPresenceMap,
NodePanelPresenceUser,
OnlineUser,
} from '../types/collaboration'
type NodePanelPresenceEventData = {
nodeId: string
action: 'open' | 'close'
user: NodePanelPresenceUser
clientId: string
timestamp?: number
}
export class CollaborationManager {
private doc: LoroDoc | null = null
private undoManager: UndoManager | null = null
private provider: CRDTProvider | null = null
private nodesMap: any = null
private edgesMap: any = null
private eventEmitter = new EventEmitter()
private currentAppId: string | null = null
private reactFlowStore: any = null
private isLeader = false
private leaderId: string | null = null
private cursors: Record<string, CursorPosition> = {}
private nodePanelPresence: NodePanelPresenceMap = {}
private activeConnections = new Set<string>()
private isUndoRedoInProgress = false
private getNodePanelPresenceSnapshot(): NodePanelPresenceMap {
const snapshot: NodePanelPresenceMap = {}
Object.entries(this.nodePanelPresence).forEach(([nodeId, viewers]) => {
snapshot[nodeId] = { ...viewers }
})
return snapshot
}
private applyNodePanelPresenceUpdate(update: NodePanelPresenceEventData): void {
const { nodeId, action, clientId, user, timestamp } = update
if (action === 'open') {
// ensure a client only appears on a single node at a time
Object.entries(this.nodePanelPresence).forEach(([id, viewers]) => {
if (viewers[clientId]) {
delete viewers[clientId]
if (Object.keys(viewers).length === 0)
delete this.nodePanelPresence[id]
}
})
if (!this.nodePanelPresence[nodeId])
this.nodePanelPresence[nodeId] = {}
this.nodePanelPresence[nodeId][clientId] = {
...user,
clientId,
timestamp: timestamp || Date.now(),
}
}
else {
const viewers = this.nodePanelPresence[nodeId]
if (viewers) {
delete viewers[clientId]
if (Object.keys(viewers).length === 0)
delete this.nodePanelPresence[nodeId]
}
}
this.eventEmitter.emit('nodePanelPresence', this.getNodePanelPresenceSnapshot())
}
private cleanupNodePanelPresence(activeClientIds: Set<string>, activeUserIds: Set<string>): void {
let hasChanges = false
Object.entries(this.nodePanelPresence).forEach(([nodeId, viewers]) => {
Object.keys(viewers).forEach((clientId) => {
const viewer = viewers[clientId]
const clientActive = activeClientIds.has(clientId)
const userActive = viewer?.userId ? activeUserIds.has(viewer.userId) : false
if (!clientActive && !userActive) {
delete viewers[clientId]
hasChanges = true
}
})
if (Object.keys(viewers).length === 0)
delete this.nodePanelPresence[nodeId]
})
if (hasChanges)
this.eventEmitter.emit('nodePanelPresence', this.getNodePanelPresenceSnapshot())
}
init = (appId: string, reactFlowStore: any): void => {
if (!reactFlowStore) {
console.warn('CollaborationManager.init called without reactFlowStore, deferring to connect()')
return
}
this.connect(appId, reactFlowStore)
}
setNodes = (oldNodes: Node[], newNodes: Node[]): void => {
if (!this.doc) return
// Don't track operations during undo/redo to prevent loops
if (this.isUndoRedoInProgress) {
console.log('Skipping setNodes during undo/redo')
return
}
console.log('Setting nodes with tracking')
this.syncNodes(oldNodes, newNodes)
this.doc.commit()
}
setEdges = (oldEdges: Edge[], newEdges: Edge[]): void => {
if (!this.doc) return
// Don't track operations during undo/redo to prevent loops
if (this.isUndoRedoInProgress) {
console.log('Skipping setEdges during undo/redo')
return
}
console.log('Setting edges with tracking')
this.syncEdges(oldEdges, newEdges)
this.doc.commit()
}
destroy = (): void => {
this.disconnect()
}
async connect(appId: string, reactFlowStore?: any): Promise<string> {
const connectionId = Math.random().toString(36).substring(2, 11)
this.activeConnections.add(connectionId)
if (this.currentAppId === appId && this.doc) {
// Already connected to the same app, only update store if provided and we don't have one
if (reactFlowStore && !this.reactFlowStore)
this.reactFlowStore = reactFlowStore
return connectionId
}
// Only disconnect if switching to a different app
if (this.currentAppId && this.currentAppId !== appId)
this.forceDisconnect()
this.currentAppId = appId
// Only set store if provided
if (reactFlowStore)
this.reactFlowStore = reactFlowStore
const socket = webSocketClient.connect(appId)
// Setup event listeners BEFORE any other operations
this.setupSocketEventListeners(socket)
this.doc = new LoroDoc()
this.nodesMap = this.doc.getMap('nodes')
this.edgesMap = this.doc.getMap('edges')
// Initialize UndoManager for collaborative undo/redo
this.undoManager = new UndoManager(this.doc, {
maxUndoSteps: 100,
mergeInterval: 500, // Merge operations within 500ms
excludeOriginPrefixes: [], // Don't exclude anything - let UndoManager track all local operations
onPush: (isUndo, range, event) => {
console.log('UndoManager onPush:', { isUndo, range, event })
// Store current selection state when an operation is pushed
const selectedNode = this.reactFlowStore?.getState().getNodes().find((n: Node) => n.data?.selected)
// Emit event to update UI button states when new operation is pushed
setTimeout(() => {
this.eventEmitter.emit('undoRedoStateChange', {
canUndo: this.undoManager?.canUndo() || false,
canRedo: this.undoManager?.canRedo() || false,
})
}, 0)
return {
value: {
selectedNodeId: selectedNode?.id || null,
timestamp: Date.now(),
},
cursors: [],
}
},
onPop: (isUndo, value, counterRange) => {
console.log('UndoManager onPop:', { isUndo, value, counterRange })
// Restore selection state when undoing/redoing
if (value?.value && typeof value.value === 'object' && 'selectedNodeId' in value.value && this.reactFlowStore) {
const selectedNodeId = (value.value as any).selectedNodeId
if (selectedNodeId) {
const { setNodes } = this.reactFlowStore.getState()
const nodes = this.reactFlowStore.getState().getNodes()
const newNodes = nodes.map((n: Node) => ({
...n,
data: {
...n.data,
selected: n.id === selectedNodeId,
},
}))
setNodes(newNodes)
}
}
},
})
this.provider = new CRDTProvider(socket, this.doc)
this.setupSubscriptions()
// Force user_connect if already connected
if (socket.connected)
socket.emit('user_connect', { workflow_id: appId })
return connectionId
}
disconnect = (connectionId?: string): void => {
if (connectionId)
this.activeConnections.delete(connectionId)
// Only disconnect when no more connections
if (this.activeConnections.size === 0)
this.forceDisconnect()
}
private forceDisconnect = (): void => {
if (this.currentAppId)
webSocketClient.disconnect(this.currentAppId)
this.provider?.destroy()
this.undoManager = null
this.doc = null
this.provider = null
this.nodesMap = null
this.edgesMap = null
this.currentAppId = null
this.reactFlowStore = null
this.cursors = {}
this.nodePanelPresence = {}
this.isUndoRedoInProgress = false
// Only reset leader status when actually disconnecting
const wasLeader = this.isLeader
this.isLeader = false
this.leaderId = null
if (wasLeader)
this.eventEmitter.emit('leaderChange', false)
this.activeConnections.clear()
this.eventEmitter.removeAllListeners()
}
isConnected(): boolean {
return this.currentAppId ? webSocketClient.isConnected(this.currentAppId) : false
}
getNodes(): Node[] {
return this.nodesMap ? Array.from(this.nodesMap.values()) : []
}
getEdges(): Edge[] {
return this.edgesMap ? Array.from(this.edgesMap.values()) : []
}
emitCursorMove(position: CursorPosition): void {
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
const socket = webSocketClient.getSocket(this.currentAppId)
if (socket) {
socket.emit('collaboration_event', {
type: 'mouseMove',
userId: socket.id,
data: { x: position.x, y: position.y },
timestamp: Date.now(),
})
}
}
emitSyncRequest(): void {
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
const socket = webSocketClient.getSocket(this.currentAppId)
if (socket) {
console.log('Emitting sync request to leader')
socket.emit('collaboration_event', {
type: 'syncRequest',
data: { timestamp: Date.now() },
timestamp: Date.now(),
})
}
}
emitWorkflowUpdate(appId: string): void {
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
const socket = webSocketClient.getSocket(this.currentAppId)
if (socket) {
console.log('Emitting Workflow update event')
socket.emit('collaboration_event', {
type: 'workflowUpdate',
data: { appId, timestamp: Date.now() },
timestamp: Date.now(),
})
}
}
emitNodePanelPresence(nodeId: string, isOpen: boolean, user: NodePanelPresenceUser): void {
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
const socket = webSocketClient.getSocket(this.currentAppId)
if (!socket || !nodeId || !user?.userId) return
const payload: NodePanelPresenceEventData = {
nodeId,
action: isOpen ? 'open' : 'close',
user,
clientId: socket.id as string,
timestamp: Date.now(),
}
socket.emit('collaboration_event', {
type: 'nodePanelPresence',
data: payload,
timestamp: payload.timestamp,
})
this.applyNodePanelPresenceUpdate(payload)
}
onSyncRequest(callback: () => void): () => void {
return this.eventEmitter.on('syncRequest', callback)
}
onStateChange(callback: (state: Partial<CollaborationState>) => void): () => void {
return this.eventEmitter.on('stateChange', callback)
}
onCursorUpdate(callback: (cursors: Record<string, CursorPosition>) => void): () => void {
return this.eventEmitter.on('cursors', callback)
}
onOnlineUsersUpdate(callback: (users: OnlineUser[]) => void): () => void {
return this.eventEmitter.on('onlineUsers', callback)
}
onWorkflowUpdate(callback: (update: { appId: string; timestamp: number }) => void): () => void {
return this.eventEmitter.on('workflowUpdate', callback)
}
onVarsAndFeaturesUpdate(callback: (update: any) => void): () => void {
return this.eventEmitter.on('varsAndFeaturesUpdate', callback)
}
onAppStateUpdate(callback: (update: any) => void): () => void {
return this.eventEmitter.on('appStateUpdate', callback)
}
onMcpServerUpdate(callback: (update: any) => void): () => void {
return this.eventEmitter.on('mcpServerUpdate', callback)
}
onNodePanelPresenceUpdate(callback: (presence: NodePanelPresenceMap) => void): () => void {
const off = this.eventEmitter.on('nodePanelPresence', callback)
callback(this.getNodePanelPresenceSnapshot())
return off
}
onLeaderChange(callback: (isLeader: boolean) => void): () => void {
return this.eventEmitter.on('leaderChange', callback)
}
onCommentsUpdate(callback: (update: { appId: string; timestamp: number }) => void): () => void {
return this.eventEmitter.on('commentsUpdate', callback)
}
emitCommentsUpdate(appId: string): void {
if (!this.currentAppId || !webSocketClient.isConnected(this.currentAppId)) return
const socket = webSocketClient.getSocket(this.currentAppId)
if (socket) {
console.log('Emitting Comments update event')
socket.emit('collaboration_event', {
type: 'commentsUpdate',
data: { appId, timestamp: Date.now() },
timestamp: Date.now(),
})
}
}
onUndoRedoStateChange(callback: (state: { canUndo: boolean; canRedo: boolean }) => void): () => void {
return this.eventEmitter.on('undoRedoStateChange', callback)
}
getLeaderId(): string | null {
return this.leaderId
}
getIsLeader(): boolean {
return this.isLeader
}
// Collaborative undo/redo methods
undo(): boolean {
if (!this.undoManager) {
console.log('UndoManager not initialized')
return false
}
const canUndo = this.undoManager.canUndo()
console.log('Can undo:', canUndo)
if (canUndo) {
this.isUndoRedoInProgress = true
const result = this.undoManager.undo()
// After undo, manually update React state from CRDT without triggering collaboration
if (result && this.reactFlowStore) {
requestAnimationFrame(() => {
// Get ReactFlow's native setters, not the collaborative ones
const state = this.reactFlowStore.getState()
const updatedNodes = Array.from(this.nodesMap.values())
const updatedEdges = Array.from(this.edgesMap.values())
console.log('Manually updating React state after undo')
// Call ReactFlow's native setters directly to avoid triggering collaboration
state.setNodes(updatedNodes)
state.setEdges(updatedEdges)
this.isUndoRedoInProgress = false
// Emit event to update UI button states
this.eventEmitter.emit('undoRedoStateChange', {
canUndo: this.undoManager?.canUndo() || false,
canRedo: this.undoManager?.canRedo() || false,
})
})
}
else {
this.isUndoRedoInProgress = false
}
console.log('Undo result:', result)
return result
}
return false
}
redo(): boolean {
if (!this.undoManager) {
console.log('RedoManager not initialized')
return false
}
const canRedo = this.undoManager.canRedo()
console.log('Can redo:', canRedo)
if (canRedo) {
this.isUndoRedoInProgress = true
const result = this.undoManager.redo()
// After redo, manually update React state from CRDT without triggering collaboration
if (result && this.reactFlowStore) {
requestAnimationFrame(() => {
// Get ReactFlow's native setters, not the collaborative ones
const state = this.reactFlowStore.getState()
const updatedNodes = Array.from(this.nodesMap.values())
const updatedEdges = Array.from(this.edgesMap.values())
console.log('Manually updating React state after redo')
// Call ReactFlow's native setters directly to avoid triggering collaboration
state.setNodes(updatedNodes)
state.setEdges(updatedEdges)
this.isUndoRedoInProgress = false
// Emit event to update UI button states
this.eventEmitter.emit('undoRedoStateChange', {
canUndo: this.undoManager?.canUndo() || false,
canRedo: this.undoManager?.canRedo() || false,
})
})
}
else {
this.isUndoRedoInProgress = false
}
console.log('Redo result:', result)
return result
}
return false
}
canUndo(): boolean {
if (!this.undoManager) return false
return this.undoManager.canUndo()
}
canRedo(): boolean {
if (!this.undoManager) return false
return this.undoManager.canRedo()
}
clearUndoStack(): void {
if (!this.undoManager) return
this.undoManager.clear()
}
debugLeaderStatus(): void {
console.log('=== Leader Status Debug ===')
console.log('Current leader status:', this.isLeader)
console.log('Current leader ID:', this.leaderId)
console.log('Active connections:', this.activeConnections.size)
console.log('Connected:', this.isConnected())
console.log('Current app ID:', this.currentAppId)
console.log('Has ReactFlow store:', !!this.reactFlowStore)
console.log('========================')
}
private syncNodes(oldNodes: Node[], newNodes: Node[]): void {
if (!this.nodesMap || !this.doc) return
const oldNodesMap = new Map(oldNodes.map(node => [node.id, node]))
const newNodesMap = new Map(newNodes.map(node => [node.id, node]))
const syncDataAllowList = new Set(['_children'])
const shouldSyncDataKey = (key: string) => (syncDataAllowList.has(key) || !key.startsWith('_')) && key !== 'selected'
// Delete removed nodes
oldNodes.forEach((oldNode) => {
if (!newNodesMap.has(oldNode.id))
this.nodesMap.delete(oldNode.id)
})
// Add or update nodes with fine-grained sync for data properties
const copyOptionalNodeProps = (source: Node, target: any) => {
const optionalProps: Array<keyof Node | keyof any> = [
'parentId',
'positionAbsolute',
'extent',
'zIndex',
'draggable',
'selectable',
'dragHandle',
'dragging',
'connectable',
'expandParent',
'focusable',
'hidden',
'style',
'className',
'ariaLabel',
'markerStart',
'markerEnd',
'resizing',
'deletable',
]
optionalProps.forEach((prop) => {
const value = (source as any)[prop]
if (value === undefined) {
if (prop in target)
delete target[prop]
return
}
if (value !== null && typeof value === 'object')
target[prop as string] = JSON.parse(JSON.stringify(value))
else
target[prop as string] = value
})
}
newNodes.forEach((newNode) => {
const oldNode = oldNodesMap.get(newNode.id)
if (!oldNode) {
// New node - create as nested structure
const nodeData: any = {
id: newNode.id,
type: newNode.type,
position: { ...newNode.position },
width: newNode.width,
height: newNode.height,
sourcePosition: newNode.sourcePosition,
targetPosition: newNode.targetPosition,
data: {},
}
copyOptionalNodeProps(newNode, nodeData)
// Clone data properties, excluding private ones
Object.entries(newNode.data).forEach(([key, value]) => {
if (shouldSyncDataKey(key) && value !== undefined)
nodeData.data[key] = JSON.parse(JSON.stringify(value))
})
this.nodesMap.set(newNode.id, nodeData)
}
else {
// Get existing node from CRDT
const existingNode = this.nodesMap.get(newNode.id)
if (existingNode) {
// Create a deep copy to modify
const updatedNode = JSON.parse(JSON.stringify(existingNode))
// Update position only if changed
if (oldNode.position.x !== newNode.position.x || oldNode.position.y !== newNode.position.y)
updatedNode.position = { ...newNode.position }
// Update dimensions only if changed
if (oldNode.width !== newNode.width)
updatedNode.width = newNode.width
if (oldNode.height !== newNode.height)
updatedNode.height = newNode.height
// Ensure optional node props stay in sync
copyOptionalNodeProps(newNode, updatedNode)
// Ensure data object exists
if (!updatedNode.data)
updatedNode.data = {}
// Fine-grained update of data properties
const oldData = oldNode.data || {}
const newData = newNode.data || {}
// Only update changed properties in data
Object.entries(newData).forEach(([key, value]) => {
if (shouldSyncDataKey(key)) {
const oldValue = (oldData as any)[key]
if (!isEqual(oldValue, value))
updatedNode.data[key] = JSON.parse(JSON.stringify(value))
}
})
// Remove deleted properties from data
Object.keys(oldData).forEach((key) => {
if (shouldSyncDataKey(key) && !(key in newData))
delete updatedNode.data[key]
})
// Only update in CRDT if something actually changed
if (!isEqual(existingNode, updatedNode))
this.nodesMap.set(newNode.id, updatedNode)
}
else {
// Node exists locally but not in CRDT yet
const nodeData: any = {
id: newNode.id,
type: newNode.type,
position: { ...newNode.position },
width: newNode.width,
height: newNode.height,
sourcePosition: newNode.sourcePosition,
targetPosition: newNode.targetPosition,
data: {},
}
copyOptionalNodeProps(newNode, nodeData)
Object.entries(newNode.data).forEach(([key, value]) => {
if (shouldSyncDataKey(key) && value !== undefined)
nodeData.data[key] = JSON.parse(JSON.stringify(value))
})
this.nodesMap.set(newNode.id, nodeData)
}
}
})
}
private syncEdges(oldEdges: Edge[], newEdges: Edge[]): void {
if (!this.edgesMap) return
const oldEdgesMap = new Map(oldEdges.map(edge => [edge.id, edge]))
const newEdgesMap = new Map(newEdges.map(edge => [edge.id, edge]))
oldEdges.forEach((oldEdge) => {
if (!newEdgesMap.has(oldEdge.id))
this.edgesMap.delete(oldEdge.id)
})
newEdges.forEach((newEdge) => {
const oldEdge = oldEdgesMap.get(newEdge.id)
if (!oldEdge) {
const clonedEdge = JSON.parse(JSON.stringify(newEdge))
this.edgesMap.set(newEdge.id, clonedEdge)
}
else if (!isEqual(oldEdge, newEdge)) {
const clonedEdge = JSON.parse(JSON.stringify(newEdge))
this.edgesMap.set(newEdge.id, clonedEdge)
}
})
}
private setupSubscriptions(): void {
this.nodesMap?.subscribe((event: any) => {
console.log('nodesMap subscription event:', event)
if (event.by === 'import' && this.reactFlowStore) {
// Don't update React nodes during undo/redo to prevent loops
if (this.isUndoRedoInProgress) {
console.log('Skipping nodes subscription update during undo/redo')
return
}
requestAnimationFrame(() => {
// Get ReactFlow's native setters, not the collaborative ones
const state = this.reactFlowStore.getState()
const previousNodes: Node[] = state.getNodes()
const selectedIds = new Set(
previousNodes
.filter(node => node.data?.selected)
.map(node => node.id),
)
const updatedNodes = Array
.from(this.nodesMap.values())
.map((node: Node) => {
const clonedNode: Node = {
...node,
data: {
...(node.data || {}),
},
}
if (selectedIds.has(clonedNode.id))
clonedNode.data.selected = true
return clonedNode
})
console.log('Updating React nodes from subscription')
// Call ReactFlow's native setter directly to avoid triggering collaboration
state.setNodes(updatedNodes)
})
}
})
this.edgesMap?.subscribe((event: any) => {
console.log('edgesMap subscription event:', event)
if (event.by === 'import' && this.reactFlowStore) {
// Don't update React edges during undo/redo to prevent loops
if (this.isUndoRedoInProgress) {
console.log('Skipping edges subscription update during undo/redo')
return
}
requestAnimationFrame(() => {
// Get ReactFlow's native setters, not the collaborative ones
const state = this.reactFlowStore.getState()
const updatedEdges = Array.from(this.edgesMap.values())
console.log('Updating React edges from subscription')
// Call ReactFlow's native setter directly to avoid triggering collaboration
state.setEdges(updatedEdges)
})
}
})
}
private setupSocketEventListeners(socket: any): void {
console.log('Setting up socket event listeners for collaboration')
socket.on('collaboration_update', (update: any) => {
if (update.type === 'mouseMove') {
// Update cursor state for this user
this.cursors[update.userId] = {
x: update.data.x,
y: update.data.y,
userId: update.userId,
timestamp: update.timestamp,
}
this.eventEmitter.emit('cursors', { ...this.cursors })
}
else if (update.type === 'varsAndFeaturesUpdate') {
console.log('Processing varsAndFeaturesUpdate event:', update)
this.eventEmitter.emit('varsAndFeaturesUpdate', update)
}
else if (update.type === 'appStateUpdate') {
console.log('Processing appStateUpdate event:', update)
this.eventEmitter.emit('appStateUpdate', update)
}
else if (update.type === 'mcpServerUpdate') {
console.log('Processing mcpServerUpdate event:', update)
this.eventEmitter.emit('mcpServerUpdate', update)
}
else if (update.type === 'workflowUpdate') {
console.log('Processing workflowUpdate event:', update)
this.eventEmitter.emit('workflowUpdate', update.data)
}
else if (update.type === 'commentsUpdate') {
console.log('Processing commentsUpdate event:', update)
this.eventEmitter.emit('commentsUpdate', update.data)
}
else if (update.type === 'nodePanelPresence') {
console.log('Processing nodePanelPresence event:', update)
this.applyNodePanelPresenceUpdate(update.data as NodePanelPresenceEventData)
}
else if (update.type === 'syncRequest') {
console.log('Received sync request from another user')
// Only process if we are the leader
if (this.isLeader) {
console.log('Leader received sync request, triggering sync')
this.eventEmitter.emit('syncRequest', {})
}
}
})
socket.on('online_users', (data: { users: OnlineUser[]; leader?: string }) => {
try {
if (!data || !Array.isArray(data.users)) {
console.warn('Invalid online_users data structure:', data)
return
}
const onlineUserIds = new Set(data.users.map((user: OnlineUser) => user.user_id))
const onlineClientIds = new Set(
data.users
.map((user: OnlineUser) => user.sid)
.filter((sid): sid is string => typeof sid === 'string' && sid.length > 0),
)
// Remove cursors for offline users
Object.keys(this.cursors).forEach((userId) => {
if (!onlineUserIds.has(userId))
delete this.cursors[userId]
})
this.cleanupNodePanelPresence(onlineClientIds, onlineUserIds)
// Update leader information
if (data.leader && typeof data.leader === 'string')
this.leaderId = data.leader
this.eventEmitter.emit('onlineUsers', data.users)
this.eventEmitter.emit('cursors', { ...this.cursors })
}
catch (error) {
console.error('Error processing online_users update:', error)
}
})
socket.on('status', (data: any) => {
try {
if (!data || typeof data.isLeader !== 'boolean') {
console.warn('Invalid status data:', data)
return
}
const wasLeader = this.isLeader
this.isLeader = data.isLeader
if (wasLeader !== this.isLeader)
this.eventEmitter.emit('leaderChange', this.isLeader)
}
catch (error) {
console.error('Error processing status update:', error)
}
})
socket.on('status', (data: { isLeader: boolean }) => {
if (this.isLeader !== data.isLeader) {
this.isLeader = data.isLeader
console.log(`Collaboration: I am now the ${this.isLeader ? 'Leader' : 'Follower'}.`)
this.eventEmitter.emit('leaderChange', this.isLeader)
}
})
socket.on('status', (data: { isLeader: boolean }) => {
if (this.isLeader !== data.isLeader) {
this.isLeader = data.isLeader
console.log(`Collaboration: I am now the ${this.isLeader ? 'Leader' : 'Follower'}.`)
this.eventEmitter.emit('leaderChange', this.isLeader)
}
})
socket.on('connect', () => {
console.log('WebSocket connected successfully')
this.eventEmitter.emit('stateChange', { isConnected: true })
})
socket.on('disconnect', (reason: string) => {
console.log('WebSocket disconnected:', reason)
this.cursors = {}
this.isLeader = false
this.leaderId = null
this.eventEmitter.emit('stateChange', { isConnected: false })
this.eventEmitter.emit('cursors', {})
})
socket.on('connect_error', (error: any) => {
console.error('WebSocket connection error:', error)
this.eventEmitter.emit('stateChange', { isConnected: false, error: error.message })
})
socket.on('error', (error: any) => {
console.error('WebSocket error:', error)
})
}
}
export const collaborationManager = new CollaborationManager()

View File

@ -0,0 +1,36 @@
import type { LoroDoc } from 'loro-crdt'
import type { Socket } from 'socket.io-client'
export class CRDTProvider {
private doc: LoroDoc
private socket: Socket
constructor(socket: Socket, doc: LoroDoc) {
this.socket = socket
this.doc = doc
this.setupEventListeners()
}
private setupEventListeners(): void {
this.doc.subscribe((event: any) => {
if (event.by === 'local') {
const update = this.doc.export({ mode: 'update' })
this.socket.emit('graph_event', update)
}
})
this.socket.on('graph_update', (updateData: Uint8Array) => {
try {
const data = new Uint8Array(updateData)
this.doc.import(data)
}
catch (error) {
console.error('Error importing graph update:', error)
}
})
}
destroy(): void {
this.socket.off('graph_update')
}
}

View File

@ -0,0 +1,49 @@
export type EventHandler<T = any> = (data: T) => void
export class EventEmitter {
private events: Map<string, Set<EventHandler>> = new Map()
on<T = any>(event: string, handler: EventHandler<T>): () => void {
if (!this.events.has(event))
this.events.set(event, new Set())
this.events.get(event)!.add(handler)
return () => this.off(event, handler)
}
off<T = any>(event: string, handler?: EventHandler<T>): void {
if (!this.events.has(event)) return
const handlers = this.events.get(event)!
if (handler)
handlers.delete(handler)
else
handlers.clear()
if (handlers.size === 0)
this.events.delete(event)
}
emit<T = any>(event: string, data: T): void {
if (!this.events.has(event)) return
const handlers = this.events.get(event)!
handlers.forEach((handler) => {
try {
handler(data)
}
catch (error) {
console.error(`Error in event handler for ${event}:`, error)
}
})
}
removeAllListeners(): void {
this.events.clear()
}
getListenerCount(event: string): number {
return this.events.get(event)?.size || 0
}
}

View File

@ -0,0 +1,125 @@
import type { Socket } from 'socket.io-client'
import { io } from 'socket.io-client'
import type { DebugInfo, WebSocketConfig } from '../types/websocket'
export class WebSocketClient {
private connections: Map<string, Socket> = new Map()
private connecting: Set<string> = new Set()
private config: WebSocketConfig
constructor(config: WebSocketConfig = {}) {
const inferUrl = () => {
if (typeof window === 'undefined')
return 'ws://localhost:5001'
const scheme = window.location.protocol === 'https:' ? 'wss:' : 'ws:'
return `${scheme}//${window.location.host}`
}
this.config = {
url: config.url || process.env.NEXT_PUBLIC_SOCKET_URL || inferUrl(),
transports: config.transports || ['websocket'],
withCredentials: config.withCredentials !== false,
...config,
}
}
connect(appId: string): Socket {
const existingSocket = this.connections.get(appId)
if (existingSocket?.connected)
return existingSocket
if (this.connecting.has(appId)) {
const pendingSocket = this.connections.get(appId)
if (pendingSocket)
return pendingSocket
}
if (existingSocket && !existingSocket.connected) {
existingSocket.disconnect()
this.connections.delete(appId)
}
this.connecting.add(appId)
const authToken = localStorage.getItem('console_token')
const socket = io(this.config.url!, {
path: '/socket.io',
transports: this.config.transports,
auth: { token: authToken },
withCredentials: this.config.withCredentials,
})
this.connections.set(appId, socket)
this.setupBaseEventListeners(socket, appId)
return socket
}
disconnect(appId?: string): void {
if (appId) {
const socket = this.connections.get(appId)
if (socket) {
socket.disconnect()
this.connections.delete(appId)
this.connecting.delete(appId)
}
}
else {
this.connections.forEach(socket => socket.disconnect())
this.connections.clear()
this.connecting.clear()
}
}
getSocket(appId: string): Socket | null {
return this.connections.get(appId) || null
}
isConnected(appId: string): boolean {
return this.connections.get(appId)?.connected || false
}
getConnectedApps(): string[] {
const connectedApps: string[] = []
this.connections.forEach((socket, appId) => {
if (socket.connected)
connectedApps.push(appId)
})
return connectedApps
}
getDebugInfo(): DebugInfo {
const info: DebugInfo = {}
this.connections.forEach((socket, appId) => {
info[appId] = {
connected: socket.connected,
connecting: this.connecting.has(appId),
socketId: socket.id,
}
})
return info
}
private setupBaseEventListeners(socket: Socket, appId: string): void {
socket.on('connect', () => {
this.connecting.delete(appId)
socket.emit('user_connect', { workflow_id: appId })
})
socket.on('disconnect', () => {
this.connecting.delete(appId)
})
socket.on('connect_error', () => {
this.connecting.delete(appId)
})
}
}
export const webSocketClient = new WebSocketClient()
export const fetchAppsOnlineUsers = async (appIds: string[]) => {
const response = await fetch(`/api/online-users?${new URLSearchParams({
app_ids: appIds.join(','),
})}`)
return response.json()
}

View File

@ -0,0 +1,92 @@
import { useEffect, useRef, useState } from 'react'
import type { ReactFlowInstance } from 'reactflow'
import { collaborationManager } from '../core/collaboration-manager'
import { CursorService } from '../services/cursor-service'
import type { CollaborationState } from '../types/collaboration'
export function useCollaboration(appId: string, reactFlowStore?: any) {
const [state, setState] = useState<Partial<CollaborationState & { isLeader: boolean }>>({
isConnected: false,
onlineUsers: [],
cursors: {},
nodePanelPresence: {},
isLeader: false,
})
const cursorServiceRef = useRef<CursorService | null>(null)
useEffect(() => {
if (!appId) return
let connectionId: string | null = null
if (!cursorServiceRef.current)
cursorServiceRef.current = new CursorService()
const initCollaboration = async () => {
connectionId = await collaborationManager.connect(appId, reactFlowStore)
setState((prev: any) => ({ ...prev, appId, isConnected: collaborationManager.isConnected() }))
}
initCollaboration()
const unsubscribeStateChange = collaborationManager.onStateChange((newState: any) => {
console.log('Collaboration state change:', newState)
setState((prev: any) => ({ ...prev, ...newState }))
})
const unsubscribeCursors = collaborationManager.onCursorUpdate((cursors: any) => {
setState((prev: any) => ({ ...prev, cursors }))
})
const unsubscribeUsers = collaborationManager.onOnlineUsersUpdate((users: any) => {
console.log('Online users update:', users)
setState((prev: any) => ({ ...prev, onlineUsers: users }))
})
const unsubscribeNodePanelPresence = collaborationManager.onNodePanelPresenceUpdate((presence) => {
setState((prev: any) => ({ ...prev, nodePanelPresence: presence }))
})
const unsubscribeLeaderChange = collaborationManager.onLeaderChange((isLeader: boolean) => {
console.log('Leader status changed:', isLeader)
setState((prev: any) => ({ ...prev, isLeader }))
})
return () => {
unsubscribeStateChange()
unsubscribeCursors()
unsubscribeUsers()
unsubscribeNodePanelPresence()
unsubscribeLeaderChange()
cursorServiceRef.current?.stopTracking()
if (connectionId)
collaborationManager.disconnect(connectionId)
}
}, [appId, reactFlowStore])
const startCursorTracking = (containerRef: React.RefObject<HTMLElement>, reactFlowInstance?: ReactFlowInstance) => {
if (cursorServiceRef.current) {
cursorServiceRef.current.startTracking(containerRef, (position) => {
collaborationManager.emitCursorMove(position)
}, reactFlowInstance)
}
}
const stopCursorTracking = () => {
cursorServiceRef.current?.stopTracking()
}
const result = {
isConnected: state.isConnected || false,
onlineUsers: state.onlineUsers || [],
cursors: state.cursors || {},
nodePanelPresence: state.nodePanelPresence || {},
isLeader: state.isLeader || false,
leaderId: collaborationManager.getLeaderId(),
startCursorTracking,
stopCursorTracking,
}
return result
}

View File

@ -0,0 +1,5 @@
export { collaborationManager } from './core/collaboration-manager'
export { webSocketClient, fetchAppsOnlineUsers } from './core/websocket-manager'
export { CursorService } from './services/cursor-service'
export { useCollaboration } from './hooks/use-collaboration'
export * from './types'

View File

@ -0,0 +1,88 @@
import type { RefObject } from 'react'
import type { CursorPosition } from '../types/collaboration'
import type { ReactFlowInstance } from 'reactflow'
const CURSOR_MIN_MOVE_DISTANCE = 10
const CURSOR_THROTTLE_MS = 500
export class CursorService {
private containerRef: RefObject<HTMLElement> | null = null
private reactFlowInstance: ReactFlowInstance | null = null
private isTracking = false
private onCursorUpdate: ((cursors: Record<string, CursorPosition>) => void) | null = null
private onEmitPosition: ((position: CursorPosition) => void) | null = null
private lastEmitTime = 0
private lastPosition: { x: number; y: number } | null = null
startTracking(
containerRef: RefObject<HTMLElement>,
onEmitPosition: (position: CursorPosition) => void,
reactFlowInstance?: ReactFlowInstance,
): void {
if (this.isTracking) this.stopTracking()
this.containerRef = containerRef
this.onEmitPosition = onEmitPosition
this.reactFlowInstance = reactFlowInstance || null
this.isTracking = true
if (containerRef.current)
containerRef.current.addEventListener('mousemove', this.handleMouseMove)
}
stopTracking(): void {
if (this.containerRef?.current)
this.containerRef.current.removeEventListener('mousemove', this.handleMouseMove)
this.containerRef = null
this.reactFlowInstance = null
this.onEmitPosition = null
this.isTracking = false
this.lastPosition = null
}
setCursorUpdateHandler(handler: (cursors: Record<string, CursorPosition>) => void): void {
this.onCursorUpdate = handler
}
updateCursors(cursors: Record<string, CursorPosition>): void {
if (this.onCursorUpdate)
this.onCursorUpdate(cursors)
}
private handleMouseMove = (event: MouseEvent): void => {
if (!this.containerRef?.current || !this.onEmitPosition) return
const rect = this.containerRef.current.getBoundingClientRect()
let x = event.clientX - rect.left
let y = event.clientY - rect.top
// Transform coordinates to ReactFlow world coordinates if ReactFlow instance is available
if (this.reactFlowInstance) {
const viewport = this.reactFlowInstance.getViewport()
// Convert screen coordinates to world coordinates
// World coordinates = (screen coordinates - viewport translation) / zoom
x = (x - viewport.x) / viewport.zoom
y = (y - viewport.y) / viewport.zoom
}
// Always emit cursor position (remove boundary check since world coordinates can be negative)
const now = Date.now()
const timeThrottled = now - this.lastEmitTime > CURSOR_THROTTLE_MS
const minDistance = CURSOR_MIN_MOVE_DISTANCE / (this.reactFlowInstance?.getZoom() || 1)
const distanceThrottled = !this.lastPosition
|| (Math.abs(x - this.lastPosition.x) > minDistance)
|| (Math.abs(y - this.lastPosition.y) > minDistance)
if (timeThrottled && distanceThrottled) {
this.lastPosition = { x, y }
this.lastEmitTime = now
this.onEmitPosition({
x,
y,
userId: '',
timestamp: now,
})
}
}
}

View File

@ -0,0 +1,57 @@
import type { Edge, Node } from '../../types'
export type OnlineUser = {
user_id: string
username: string
avatar: string
sid: string
}
export type WorkflowOnlineUsers = {
workflow_id: string
users: OnlineUser[]
}
export type OnlineUserListResponse = {
data: WorkflowOnlineUsers[]
}
export type CursorPosition = {
x: number
y: number
userId: string
timestamp: number
}
export type NodePanelPresenceUser = {
userId: string
username: string
avatar?: string | null
}
export type NodePanelPresenceInfo = NodePanelPresenceUser & {
clientId: string
timestamp: number
}
export type NodePanelPresenceMap = Record<string, Record<string, NodePanelPresenceInfo>>
export type CollaborationState = {
appId: string
isConnected: boolean
onlineUsers: OnlineUser[]
cursors: Record<string, CursorPosition>
nodePanelPresence: NodePanelPresenceMap
}
export type GraphSyncData = {
nodes: Node[]
edges: Edge[]
}
export type CollaborationUpdate = {
type: 'mouseMove' | 'graphUpdate' | 'userJoin' | 'userLeave'
userId: string
data: any
timestamp: number
}

View File

@ -0,0 +1,38 @@
export type CollaborationEvent = {
type: string
data: any
timestamp: number
}
export type GraphUpdateEvent = {
type: 'graph_update'
data: Uint8Array
} & CollaborationEvent
export type CursorMoveEvent = {
type: 'cursor_move'
data: {
x: number
y: number
userId: string
}
} & CollaborationEvent
export type UserConnectEvent = {
type: 'user_connect'
data: {
workflow_id: string
}
} & CollaborationEvent
export type OnlineUsersEvent = {
type: 'online_users'
data: {
users: Array<{
user_id: string
username: string
avatar: string
sid: string
}>
}
} & CollaborationEvent

View File

@ -0,0 +1,3 @@
export * from './websocket'
export * from './collaboration'
export * from './events'

View File

@ -0,0 +1,16 @@
export type WebSocketConfig = {
url?: string
token?: string
transports?: string[]
withCredentials?: boolean
}
export type ConnectionInfo = {
connected: boolean
connecting: boolean
socketId?: string
}
export type DebugInfo = {
[appId: string]: ConnectionInfo
}

View File

@ -0,0 +1,12 @@
/**
* Generate a consistent color for a user based on their ID
* Used for cursor colors and avatar backgrounds
*/
export const getUserColor = (id: string): string => {
const colors = ['#155AEF', '#0BA5EC', '#444CE7', '#7839EE', '#4CA30D', '#0E9384', '#DD2590', '#FF4405', '#D92D20', '#F79009', '#828DAD']
const hash = id.split('').reduce((a, b) => {
a = ((a << 5) - a) + b.charCodeAt(0)
return a & a
}, 0)
return colors[Math.abs(hash) % colors.length]
}

View File

@ -0,0 +1,31 @@
import { useEventListener } from 'ahooks'
import { useWorkflowStore } from './store'
import { useWorkflowComment } from './hooks/use-workflow-comment'
const CommentManager = () => {
const workflowStore = useWorkflowStore()
const { handleCreateComment } = useWorkflowComment()
useEventListener('click', (e) => {
const { controlMode, mousePosition } = workflowStore.getState()
if (controlMode === 'comment') {
const target = e.target as HTMLElement
const isInDropdown = target.closest('[data-mention-dropdown]')
const isInCommentInput = target.closest('[data-comment-input]')
const isOnCanvasPane = target.closest('.react-flow__pane')
// Only when clicking on the React Flow canvas pane (background),
// and not inside comment input or its dropdown
if (!isInDropdown && !isInCommentInput && isOnCanvasPane) {
e.preventDefault()
e.stopPropagation()
handleCreateComment(mousePosition)
}
}
})
return null
}
export default CommentManager

View File

@ -0,0 +1,239 @@
'use client'
import type { FC, PointerEvent as ReactPointerEvent } from 'react'
import { memo, useCallback, useMemo, useRef, useState } from 'react'
import { useReactFlow, useViewport } from 'reactflow'
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
import CommentPreview from './comment-preview'
import type { WorkflowCommentList } from '@/service/workflow-comment'
type CommentIconProps = {
comment: WorkflowCommentList
onClick: () => void
isActive?: boolean
onPositionUpdate?: (position: { x: number; y: number }) => void
}
export const CommentIcon: FC<CommentIconProps> = memo(({ comment, onClick, isActive = false, onPositionUpdate }) => {
const { flowToScreenPosition, screenToFlowPosition } = useReactFlow()
const viewport = useViewport()
const [showPreview, setShowPreview] = useState(false)
const [dragPosition, setDragPosition] = useState<{ x: number; y: number } | null>(null)
const [isDragging, setIsDragging] = useState(false)
const dragStateRef = useRef<{
offsetX: number
offsetY: number
startX: number
startY: number
hasMoved: boolean
} | null>(null)
const screenPosition = useMemo(() => {
return flowToScreenPosition({
x: comment.position_x,
y: comment.position_y,
})
}, [comment.position_x, comment.position_y, viewport.x, viewport.y, viewport.zoom, flowToScreenPosition])
const effectivePosition = dragPosition ?? screenPosition
const handlePointerDown = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
if (event.button !== 0)
return
event.stopPropagation()
event.preventDefault()
dragStateRef.current = {
offsetX: event.clientX - screenPosition.x,
offsetY: event.clientY - screenPosition.y,
startX: event.clientX,
startY: event.clientY,
hasMoved: false,
}
setDragPosition(screenPosition)
setIsDragging(false)
if (event.currentTarget.dataset.role !== 'comment-preview')
setShowPreview(false)
if (event.currentTarget.setPointerCapture)
event.currentTarget.setPointerCapture(event.pointerId)
}, [screenPosition])
const handlePointerMove = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
const dragState = dragStateRef.current
if (!dragState)
return
event.stopPropagation()
event.preventDefault()
const nextX = event.clientX - dragState.offsetX
const nextY = event.clientY - dragState.offsetY
if (!dragState.hasMoved) {
const distance = Math.hypot(event.clientX - dragState.startX, event.clientY - dragState.startY)
if (distance > 4) {
dragState.hasMoved = true
setIsDragging(true)
}
}
setDragPosition({ x: nextX, y: nextY })
}, [])
const finishDrag = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
const dragState = dragStateRef.current
if (!dragState)
return false
if (event.currentTarget.hasPointerCapture?.(event.pointerId))
event.currentTarget.releasePointerCapture(event.pointerId)
dragStateRef.current = null
setDragPosition(null)
setIsDragging(false)
return dragState.hasMoved
}, [])
const handlePointerUp = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
event.stopPropagation()
event.preventDefault()
const finalScreenPosition = dragPosition ?? screenPosition
const didDrag = finishDrag(event)
setShowPreview(false)
if (didDrag) {
if (onPositionUpdate) {
const flowPosition = screenToFlowPosition({
x: finalScreenPosition.x,
y: finalScreenPosition.y,
})
onPositionUpdate(flowPosition)
}
}
else if (!isActive) {
onClick()
}
}, [dragPosition, finishDrag, isActive, onClick, onPositionUpdate, screenPosition, screenToFlowPosition])
const handlePointerCancel = useCallback((event: ReactPointerEvent<HTMLDivElement>) => {
event.stopPropagation()
event.preventDefault()
finishDrag(event)
}, [finishDrag])
const handleMouseEnter = useCallback(() => {
if (isActive || isDragging)
return
setShowPreview(true)
}, [isActive, isDragging])
const handleMouseLeave = useCallback(() => {
setShowPreview(false)
}, [])
const participants = useMemo(() => {
const list = comment.participants ?? []
const author = comment.created_by_account
if (!author)
return [...list]
const rest = list.filter(user => user.id !== author.id)
return [author, ...rest]
}, [comment.created_by_account, comment.participants])
// Calculate dynamic width based on number of participants
const participantCount = participants.length
const maxVisible = Math.min(3, participantCount)
const showCount = participantCount > 3
const avatarSize = 24
const avatarSpacing = 4 // -space-x-1 is about 4px overlap
// Width calculation: first avatar + (additional avatars * (size - spacing)) + padding
const dynamicWidth = Math.max(40, // minimum width
8 + avatarSize + Math.max(0, (showCount ? 2 : maxVisible - 1)) * (avatarSize - avatarSpacing) + 8,
)
const pointerEventHandlers = useMemo(() => ({
onPointerDown: handlePointerDown,
onPointerMove: handlePointerMove,
onPointerUp: handlePointerUp,
onPointerCancel: handlePointerCancel,
}), [handlePointerCancel, handlePointerDown, handlePointerMove, handlePointerUp])
return (
<>
<div
className="absolute z-10"
style={{
left: effectivePosition.x,
top: effectivePosition.y,
transform: 'translate(-50%, -50%)',
}}
data-role='comment-marker'
{...pointerEventHandlers}
>
<div
className={isActive ? (isDragging ? 'cursor-grabbing' : '') : isDragging ? 'cursor-grabbing' : 'cursor-pointer'}
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
>
<div
className={'relative h-10 overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full'}
style={{ width: dynamicWidth }}
>
<div className={`absolute inset-[6px] overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full border ${
isActive
? 'border-2 border-primary-500 bg-components-panel-bg'
: 'border-components-panel-border bg-components-panel-bg'
}`}>
<div className="flex h-full w-full items-center justify-center px-1">
<UserAvatarList
users={participants}
maxVisible={3}
size={24}
/>
</div>
</div>
</div>
</div>
</div>
{/* Preview panel */}
{showPreview && !isActive && (
<div
className="absolute z-20"
style={{
left: (dragPosition ?? screenPosition).x - dynamicWidth / 2,
top: (dragPosition ?? screenPosition).y + 20,
transform: 'translateY(-100%)',
}}
data-role='comment-preview'
{...pointerEventHandlers}
onMouseEnter={() => setShowPreview(true)}
onMouseLeave={() => setShowPreview(false)}
>
<CommentPreview comment={comment} onClick={() => {
setShowPreview(false)
onClick()
}} />
</div>
)}
</>
)
}, (prevProps, nextProps) => {
return (
prevProps.comment.id === nextProps.comment.id
&& prevProps.comment.position_x === nextProps.comment.position_x
&& prevProps.comment.position_y === nextProps.comment.position_y
&& prevProps.onClick === nextProps.onClick
&& prevProps.isActive === nextProps.isActive
&& prevProps.onPositionUpdate === nextProps.onPositionUpdate
)
})
CommentIcon.displayName = 'CommentIcon'

View File

@ -0,0 +1,87 @@
import type { FC } from 'react'
import { memo, useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Avatar from '@/app/components/base/avatar'
import { useAppContext } from '@/context/app-context'
import { MentionInput } from './mention-input'
import cn from '@/utils/classnames'
type CommentInputProps = {
position: { x: number; y: number }
onSubmit: (content: string, mentionedUserIds: string[]) => void
onCancel: () => void
}
export const CommentInput: FC<CommentInputProps> = memo(({ position, onSubmit, onCancel }) => {
const [content, setContent] = useState('')
const { t } = useTranslation()
const { userProfile } = useAppContext()
useEffect(() => {
const handleGlobalKeyDown = (e: KeyboardEvent) => {
if (e.key === 'Escape') {
e.preventDefault()
e.stopPropagation()
onCancel()
}
}
document.addEventListener('keydown', handleGlobalKeyDown, true)
return () => {
document.removeEventListener('keydown', handleGlobalKeyDown, true)
}
}, [onCancel])
const handleMentionSubmit = useCallback((content: string, mentionedUserIds: string[]) => {
onSubmit(content, mentionedUserIds)
setContent('')
}, [onSubmit])
return (
<div
className="absolute z-50 w-96"
style={{
left: position.x,
top: position.y,
}}
data-comment-input
>
<div className="flex items-center gap-3">
<div className="relative shrink-0">
<div className="relative h-8 w-8 overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full bg-primary-500">
<div className="absolute inset-[2px] overflow-hidden rounded-br-full rounded-tl-full rounded-tr-full bg-white">
<div className="flex h-full w-full items-center justify-center">
<div className="h-6 w-6 overflow-hidden rounded-full">
<Avatar
avatar={userProfile.avatar_url}
name={userProfile.name}
size={24}
className="h-full w-full"
/>
</div>
</div>
</div>
</div>
</div>
<div
className={cn(
'relative z-10 flex-1 rounded-xl border border-components-chat-input-border bg-components-panel-bg-blur pb-[4px] shadow-md',
)}
>
<div className='relative px-[9px] pt-[4px]'>
<MentionInput
value={content}
onChange={setContent}
onSubmit={handleMentionSubmit}
placeholder={t('workflow.comments.placeholder.add')}
autoFocus
className="relative"
/>
</div>
</div>
</div>
</div>
)
})
CommentInput.displayName = 'CommentInput'

View File

@ -0,0 +1,52 @@
'use client'
import type { FC } from 'react'
import { memo, useMemo } from 'react'
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
import type { WorkflowCommentList } from '@/service/workflow-comment'
import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now'
type CommentPreviewProps = {
comment: WorkflowCommentList
onClick?: () => void
}
const CommentPreview: FC<CommentPreviewProps> = ({ comment, onClick }) => {
const { formatTimeFromNow } = useFormatTimeFromNow()
const participants = useMemo(() => {
const list = comment.participants ?? []
const author = comment.created_by_account
if (!author)
return [...list]
const rest = list.filter(user => user.id !== author.id)
return [author, ...rest]
}, [comment.created_by_account, comment.participants])
return (
<div
className="w-80 cursor-pointer rounded-br-xl rounded-tl-xl rounded-tr-xl border border-components-panel-border bg-components-panel-bg p-4 shadow-lg transition-colors hover:bg-components-panel-on-panel-item-bg-hover"
onClick={onClick}
>
<div className="mb-3 flex items-center justify-between">
<UserAvatarList
users={participants}
maxVisible={3}
size={24}
/>
</div>
<div className="mb-2 flex items-start">
<div className="flex min-w-0 items-center gap-2">
<div className="system-sm-medium truncate text-text-primary">{comment.created_by_account.name}</div>
<div className="system-2xs-regular shrink-0 text-text-tertiary">
{formatTimeFromNow(comment.updated_at * 1000)}
</div>
</div>
</div>
<div className="system-sm-regular break-words text-text-secondary">{comment.content}</div>
</div>
)
}
export default memo(CommentPreview)

View File

@ -0,0 +1,28 @@
import type { FC } from 'react'
import { memo } from 'react'
import { useStore } from '../store'
import { ControlMode } from '../types'
import { Comment } from '@/app/components/base/icons/src/public/other'
export const CommentCursor: FC = memo(() => {
const controlMode = useStore(s => s.controlMode)
const mousePosition = useStore(s => s.mousePosition)
if (controlMode !== ControlMode.Comment)
return null
return (
<div
className="pointer-events-none absolute z-50 flex h-6 w-6 items-center justify-center"
style={{
left: mousePosition.elementX,
top: mousePosition.elementY,
transform: 'translate(-50%, -50%)',
}}
>
<Comment className="text-text-primary" />
</div>
)
})
CommentCursor.displayName = 'CommentCursor'

View File

@ -0,0 +1,5 @@
export { CommentCursor } from './cursor'
export { CommentInput } from './comment-input'
export { CommentIcon } from './comment-icon'
export { CommentThread } from './thread'
export { MentionInput } from './mention-input'

View File

@ -0,0 +1,440 @@
'use client'
import type { FC, ReactNode } from 'react'
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { createPortal } from 'react-dom'
import { useParams } from 'next/navigation'
import { useTranslation } from 'react-i18next'
import { RiArrowUpLine, RiAtLine } from '@remixicon/react'
import Textarea from 'react-textarea-autosize'
import Button from '@/app/components/base/button'
import Avatar from '@/app/components/base/avatar'
import cn from '@/utils/classnames'
import { type UserProfile, fetchMentionableUsers } from '@/service/workflow-comment'
import { useStore, useWorkflowStore } from '../store'
type MentionInputProps = {
value: string
onChange: (value: string) => void
onSubmit: (content: string, mentionedUserIds: string[]) => void
onCancel?: () => void
placeholder?: string
disabled?: boolean
loading?: boolean
className?: string
isEditing?: boolean
autoFocus?: boolean
}
export const MentionInput: FC<MentionInputProps> = memo(({
value,
onChange,
onSubmit,
onCancel,
placeholder,
disabled = false,
loading = false,
className,
isEditing = false,
autoFocus = false,
}) => {
const params = useParams()
const { t } = useTranslation()
const appId = params.appId as string
const textareaRef = useRef<HTMLTextAreaElement>(null)
const workflowStore = useWorkflowStore()
const mentionUsersFromStore = useStore(state => (
appId ? state.mentionableUsersCache[appId] : undefined
))
const mentionUsers = mentionUsersFromStore ?? []
const [showMentionDropdown, setShowMentionDropdown] = useState(false)
const [mentionQuery, setMentionQuery] = useState('')
const [mentionPosition, setMentionPosition] = useState(0)
const [selectedMentionIndex, setSelectedMentionIndex] = useState(0)
const [mentionedUserIds, setMentionedUserIds] = useState<string[]>([])
const resolvedPlaceholder = placeholder ?? t('workflow.comments.placeholder.add')
const mentionNameList = useMemo(() => {
const names = mentionUsers
.map(user => user.name?.trim())
.filter((name): name is string => Boolean(name))
const uniqueNames = Array.from(new Set(names))
uniqueNames.sort((a, b) => b.length - a.length)
return uniqueNames
}, [mentionUsers])
const highlightedValue = useMemo<ReactNode>(() => {
if (!value)
return ''
if (mentionNameList.length === 0)
return value
const segments: ReactNode[] = []
let cursor = 0
let hasMention = false
while (cursor < value.length) {
let nextMatchStart = -1
let matchedName = ''
for (const name of mentionNameList) {
const searchStart = value.indexOf(`@${name}`, cursor)
if (searchStart === -1)
continue
const previousChar = searchStart > 0 ? value[searchStart - 1] : ''
if (searchStart > 0 && !/\s/.test(previousChar))
continue
if (
nextMatchStart === -1
|| searchStart < nextMatchStart
|| (searchStart === nextMatchStart && name.length > matchedName.length)
) {
nextMatchStart = searchStart
matchedName = name
}
}
if (nextMatchStart === -1)
break
if (nextMatchStart > cursor)
segments.push(<span key={`text-${cursor}`}>{value.slice(cursor, nextMatchStart)}</span>)
const mentionEnd = nextMatchStart + matchedName.length + 1
segments.push(
<span key={`mention-${nextMatchStart}`} className='text-primary-600'>
{value.slice(nextMatchStart, mentionEnd)}
</span>,
)
hasMention = true
cursor = mentionEnd
}
if (!hasMention)
return value
if (cursor < value.length)
segments.push(<span key={`text-${cursor}`}>{value.slice(cursor)}</span>)
return segments
}, [value, mentionNameList])
const loadMentionableUsers = useCallback(async () => {
if (!appId)
return
const state = workflowStore.getState()
if (state.mentionableUsersCache[appId] !== undefined)
return
if (state.mentionableUsersLoading[appId])
return
state.setMentionableUsersLoading(appId, true)
try {
const users = await fetchMentionableUsers(appId)
workflowStore.getState().setMentionableUsersCache(appId, users)
}
catch (error) {
console.error('Failed to load mentionable users:', error)
}
finally {
workflowStore.getState().setMentionableUsersLoading(appId, false)
}
}, [appId, workflowStore])
useEffect(() => {
loadMentionableUsers()
}, [loadMentionableUsers])
const filteredMentionUsers = useMemo(() => {
if (!mentionQuery) return mentionUsers
return mentionUsers.filter(user =>
user.name.toLowerCase().includes(mentionQuery.toLowerCase())
|| user.email.toLowerCase().includes(mentionQuery.toLowerCase()),
)
}, [mentionUsers, mentionQuery])
const dropdownPosition = useMemo(() => {
if (!showMentionDropdown || !textareaRef.current)
return { x: 0, y: 0, placement: 'bottom' as const }
const textareaRect = textareaRef.current.getBoundingClientRect()
const dropdownHeight = 160 // max-h-40 = 10rem = 160px
const viewportHeight = window.innerHeight
const spaceBelow = viewportHeight - textareaRect.bottom
const spaceAbove = textareaRect.top
const shouldPlaceAbove = spaceBelow < dropdownHeight && spaceAbove > spaceBelow
return {
x: textareaRect.left,
y: shouldPlaceAbove ? textareaRect.top - 4 : textareaRect.bottom + 4,
placement: shouldPlaceAbove ? 'top' as const : 'bottom' as const,
}
}, [showMentionDropdown])
const handleContentChange = useCallback((newValue: string) => {
onChange(newValue)
setTimeout(() => {
const cursorPosition = textareaRef.current?.selectionStart || 0
const textBeforeCursor = newValue.slice(0, cursorPosition)
const mentionMatch = textBeforeCursor.match(/@(\w*)$/)
if (mentionMatch) {
setMentionQuery(mentionMatch[1])
setMentionPosition(cursorPosition - mentionMatch[0].length)
setShowMentionDropdown(true)
setSelectedMentionIndex(0)
}
else {
setShowMentionDropdown(false)
}
}, 0)
}, [onChange])
const handleMentionButtonClick = useCallback((e: React.MouseEvent) => {
e.preventDefault()
e.stopPropagation()
const textarea = textareaRef.current
if (!textarea) return
const cursorPosition = textarea.selectionStart || 0
const newContent = `${value.slice(0, cursorPosition)}@${value.slice(cursorPosition)}`
onChange(newContent)
setTimeout(() => {
const newCursorPos = cursorPosition + 1
textarea.setSelectionRange(newCursorPos, newCursorPos)
textarea.focus()
setMentionQuery('')
setMentionPosition(cursorPosition)
setShowMentionDropdown(true)
setSelectedMentionIndex(0)
}, 0)
}, [value, onChange])
const insertMention = useCallback((user: UserProfile) => {
const textarea = textareaRef.current
if (!textarea) return
const beforeMention = value.slice(0, mentionPosition)
const afterMention = value.slice(textarea.selectionStart || 0)
const needsSpaceBefore = mentionPosition > 0 && !/\s/.test(value[mentionPosition - 1])
const prefix = needsSpaceBefore ? ' ' : ''
const newContent = `${beforeMention}${prefix}@${user.name} ${afterMention}`
onChange(newContent)
setShowMentionDropdown(false)
const newMentionedUserIds = [...mentionedUserIds, user.id]
setMentionedUserIds(newMentionedUserIds)
setTimeout(() => {
const extraSpace = needsSpaceBefore ? 1 : 0
const newCursorPos = mentionPosition + extraSpace + user.name.length + 2 // (space) + @ + name + space
textarea.setSelectionRange(newCursorPos, newCursorPos)
textarea.focus()
}, 0)
}, [value, mentionPosition, onChange, mentionedUserIds])
const handleSubmit = useCallback((e?: React.MouseEvent) => {
if (e) {
e.preventDefault()
e.stopPropagation()
}
if (value.trim()) {
onSubmit(value.trim(), mentionedUserIds)
setMentionedUserIds([])
setShowMentionDropdown(false)
}
}, [value, mentionedUserIds, onSubmit])
const handleKeyDown = useCallback((e: React.KeyboardEvent) => {
if (showMentionDropdown) {
if (e.key === 'ArrowDown') {
e.preventDefault()
setSelectedMentionIndex(prev =>
prev < filteredMentionUsers.length - 1 ? prev + 1 : 0,
)
}
else if (e.key === 'ArrowUp') {
e.preventDefault()
setSelectedMentionIndex(prev =>
prev > 0 ? prev - 1 : filteredMentionUsers.length - 1,
)
}
else if (e.key === 'Enter') {
e.preventDefault()
if (filteredMentionUsers[selectedMentionIndex])
insertMention(filteredMentionUsers[selectedMentionIndex])
return
}
else if (e.key === 'Escape') {
e.preventDefault()
setShowMentionDropdown(false)
return
}
}
if (e.key === 'Enter' && !e.shiftKey && !showMentionDropdown) {
e.preventDefault()
handleSubmit()
}
}, [showMentionDropdown, filteredMentionUsers, selectedMentionIndex, insertMention, handleSubmit])
const resetMentionState = useCallback(() => {
setMentionedUserIds([])
setShowMentionDropdown(false)
setMentionQuery('')
setMentionPosition(0)
setSelectedMentionIndex(0)
}, [])
useEffect(() => {
if (!value)
resetMentionState()
}, [value, resetMentionState])
useEffect(() => {
if (autoFocus && textareaRef.current) {
const textarea = textareaRef.current
setTimeout(() => {
textarea.focus()
const length = textarea.value.length
textarea.setSelectionRange(length, length)
}, 0)
}
}, [autoFocus])
return (
<>
<div className={cn('relative flex items-center', className)}>
<div
aria-hidden
className={cn(
'pointer-events-none absolute inset-0 z-0 overflow-hidden whitespace-pre-wrap break-words p-1 leading-6',
'body-lg-regular text-text-primary',
)}
>
{highlightedValue}
{''}
</div>
<Textarea
ref={textareaRef}
className={cn(
'body-lg-regular relative z-10 w-full resize-none bg-transparent p-1 leading-6 text-transparent caret-primary-500 outline-none',
'placeholder:text-text-tertiary',
)}
placeholder={resolvedPlaceholder}
autoFocus={autoFocus}
minRows={isEditing ? 4 : 1}
maxRows={4}
value={value}
disabled={disabled || loading}
onChange={e => handleContentChange(e.target.value)}
onKeyDown={handleKeyDown}
/>
{!isEditing && (
<div className="absolute bottom-0 right-1 z-20 flex items-end gap-1">
<div
className="z-20 flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg hover:bg-state-base-hover"
onClick={handleMentionButtonClick}
>
<RiAtLine className="h-4 w-4 text-components-button-primary-text" />
</div>
<Button
className='z-20 ml-2 w-8 px-0'
variant='primary'
disabled={!value.trim() || disabled || loading}
onClick={handleSubmit}
>
<RiArrowUpLine className='h-4 w-4 text-components-button-primary-text' />
</Button>
</div>
)}
{isEditing && (
<div className="absolute bottom-0 left-1 right-1 z-20 flex items-end justify-between">
<div
className="z-20 flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg hover:bg-state-base-hover"
onClick={handleMentionButtonClick}
>
<RiAtLine className="h-4 w-4 text-components-button-primary-text" />
</div>
<div className='flex items-center gap-2'>
<Button variant='secondary' size='small' onClick={onCancel} disabled={loading}>
{t('common.operation.cancel')}
</Button>
<Button
variant='primary'
size='small'
disabled={loading || !value.trim()}
onClick={() => handleSubmit()}
>
{t('common.operation.save')}
</Button>
</div>
</div>
)}
</div>
{showMentionDropdown && filteredMentionUsers.length > 0 && typeof document !== 'undefined' && createPortal(
<div
className="fixed z-[9999] max-h-40 w-64 overflow-y-auto rounded-lg border border-components-panel-border bg-components-panel-bg shadow-lg"
style={{
left: dropdownPosition.x,
[dropdownPosition.placement === 'top' ? 'bottom' : 'top']: dropdownPosition.placement === 'top'
? window.innerHeight - dropdownPosition.y
: dropdownPosition.y,
}}
data-mention-dropdown
>
{filteredMentionUsers.map((user, index) => (
<div
key={user.id}
className={cn(
'flex cursor-pointer items-center gap-2 p-2 hover:bg-state-base-hover',
index === selectedMentionIndex && 'bg-state-base-hover',
)}
onClick={() => insertMention(user)}
>
<Avatar
avatar={user.avatar_url || null}
name={user.name}
size={24}
className="shrink-0"
/>
<div className="min-w-0 flex-1">
<div className="truncate text-sm font-medium text-text-primary">
{user.name}
</div>
<div className="truncate text-xs text-text-tertiary">
{user.email}
</div>
</div>
</div>
))}
</div>,
document.body,
)}
</>
)
})
MentionInput.displayName = 'MentionInput'

View File

@ -0,0 +1,413 @@
'use client'
import type { FC, ReactNode } from 'react'
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useReactFlow, useViewport } from 'reactflow'
import { useTranslation } from 'react-i18next'
import { RiArrowDownSLine, RiArrowUpSLine, RiCheckboxCircleFill, RiCheckboxCircleLine, RiCloseLine, RiDeleteBinLine, RiMoreFill } from '@remixicon/react'
import Avatar from '@/app/components/base/avatar'
import Divider from '@/app/components/base/divider'
import cn from '@/utils/classnames'
import { useFormatTimeFromNow } from '@/hooks/use-format-time-from-now'
import type { WorkflowCommentDetail, WorkflowCommentDetailReply } from '@/service/workflow-comment'
import { useAppContext } from '@/context/app-context'
import { MentionInput } from './mention-input'
import { getUserColor } from '@/app/components/workflow/collaboration/utils/user-color'
type CommentThreadProps = {
comment: WorkflowCommentDetail
loading?: boolean
onClose: () => void
onDelete?: () => void
onResolve?: () => void
onPrev?: () => void
onNext?: () => void
canGoPrev?: boolean
canGoNext?: boolean
onReply?: (content: string, mentionedUserIds?: string[]) => Promise<void> | void
onReplyEdit?: (replyId: string, content: string, mentionedUserIds?: string[]) => Promise<void> | void
onReplyDelete?: (replyId: string) => void
}
const ThreadMessage: FC<{
authorId: string
authorName: string
avatarUrl?: string | null
createdAt: number
content: string
mentionedNames?: string[]
}> = ({ authorId, authorName, avatarUrl, createdAt, content, mentionedNames }) => {
const { formatTimeFromNow } = useFormatTimeFromNow()
const { userProfile } = useAppContext()
const currentUserId = userProfile?.id
const isCurrentUser = authorId === currentUserId
const userColor = isCurrentUser ? undefined : getUserColor(authorId)
const highlightedContent = useMemo<ReactNode>(() => {
if (!content)
return ''
const normalizedNames = Array.from(new Set((mentionedNames || [])
.map(name => name.trim())
.filter(Boolean)))
if (normalizedNames.length === 0)
return content
const segments: ReactNode[] = []
let hasMention = false
let cursor = 0
while (cursor < content.length) {
let nextMatchStart = -1
let matchedName = ''
for (const name of normalizedNames) {
const searchStart = content.indexOf(`@${name}`, cursor)
if (searchStart === -1)
continue
const previousChar = searchStart > 0 ? content[searchStart - 1] : ''
if (searchStart > 0 && !/\s/.test(previousChar))
continue
if (
nextMatchStart === -1
|| searchStart < nextMatchStart
|| (searchStart === nextMatchStart && name.length > matchedName.length)
) {
nextMatchStart = searchStart
matchedName = name
}
}
if (nextMatchStart === -1)
break
if (nextMatchStart > cursor)
segments.push(<span key={`text-${cursor}`}>{content.slice(cursor, nextMatchStart)}</span>)
const mentionEnd = nextMatchStart + matchedName.length + 1
segments.push(
<span key={`mention-${nextMatchStart}`} className='text-primary-600'>
{content.slice(nextMatchStart, mentionEnd)}
</span>,
)
hasMention = true
cursor = mentionEnd
}
if (!hasMention)
return content
if (cursor < content.length)
segments.push(<span key={`text-${cursor}`}>{content.slice(cursor)}</span>)
return segments
}, [content, mentionedNames])
return (
<div className={cn('flex gap-3 pt-1')}>
<div className='shrink-0'>
<Avatar
name={authorName}
avatar={avatarUrl || null}
size={24}
className={cn('h-8 w-8 rounded-full')}
backgroundColor={userColor}
/>
</div>
<div className='min-w-0 flex-1 pb-4 text-text-primary last:pb-0'>
<div className='flex flex-wrap items-center gap-x-2 gap-y-1'>
<span className='system-sm-medium text-text-primary'>{authorName}</span>
<span className='system-2xs-regular text-text-tertiary'>{formatTimeFromNow(createdAt * 1000)}</span>
</div>
<div className='system-sm-regular mt-1 whitespace-pre-wrap break-words text-text-secondary'>
{highlightedContent}
</div>
</div>
</div>
)
}
export const CommentThread: FC<CommentThreadProps> = memo(({
comment,
loading = false,
onClose,
onDelete,
onResolve,
onPrev,
onNext,
canGoPrev,
canGoNext,
onReply,
onReplyEdit,
onReplyDelete,
}) => {
const { flowToScreenPosition } = useReactFlow()
const viewport = useViewport()
const { userProfile } = useAppContext()
const { t } = useTranslation()
const [replyContent, setReplyContent] = useState('')
const [activeReplyMenuId, setActiveReplyMenuId] = useState<string | null>(null)
const [editingReply, setEditingReply] = useState<{ id: string; content: string }>({ id: '', content: '' })
useEffect(() => {
setReplyContent('')
}, [comment.id])
const handleReplySubmit = useCallback(async (content: string, mentionedUserIds: string[]) => {
if (!onReply || loading) return
try {
await onReply(content, mentionedUserIds)
setReplyContent('')
}
catch (error) {
console.error('Failed to send reply', error)
}
}, [onReply, loading])
const screenPosition = useMemo(() => {
return flowToScreenPosition({
x: comment.position_x,
y: comment.position_y,
})
}, [comment.position_x, comment.position_y, viewport.x, viewport.y, viewport.zoom, flowToScreenPosition])
const handleStartEdit = useCallback((reply: WorkflowCommentDetailReply) => {
setEditingReply({ id: reply.id, content: reply.content })
setActiveReplyMenuId(null)
}, [])
const handleCancelEdit = useCallback(() => {
setEditingReply({ id: '', content: '' })
}, [])
const handleEditSubmit = useCallback(async (content: string, mentionedUserIds: string[]) => {
if (!onReplyEdit || !editingReply) return
const trimmed = content.trim()
if (!trimmed) return
await onReplyEdit(editingReply.id, trimmed, mentionedUserIds)
setEditingReply({ id: '', content: '' })
}, [editingReply, onReplyEdit])
const replies = comment.replies || []
const messageListRef = useRef<HTMLDivElement>(null)
const previousReplyCountRef = useRef(replies.length)
const previousCommentIdRef = useRef(comment.id)
useEffect(() => {
const container = messageListRef.current
if (!container)
return
const isNewComment = comment.id !== previousCommentIdRef.current
const hasNewReply = replies.length > previousReplyCountRef.current
if (isNewComment || hasNewReply)
container.scrollTop = container.scrollHeight
previousCommentIdRef.current = comment.id
previousReplyCountRef.current = replies.length
}, [comment.id, replies.length])
const mentionsByTarget = useMemo(() => {
const map = new Map<string, string[]>()
for (const mention of comment.mentions || []) {
const name = mention.mentioned_user_account?.name?.trim()
if (!name)
continue
const key = mention.reply_id ?? 'root'
const existing = map.get(key)
if (existing) {
if (!existing.includes(name))
existing.push(name)
}
else {
map.set(key, [name])
}
}
return map
}, [comment.mentions])
return (
<div
className='absolute z-50 w-[360px] max-w-[360px]'
style={{
left: screenPosition.x + 40,
top: screenPosition.y,
transform: 'translateY(-20%)',
}}
>
<div className='relative flex h-[360px] flex-col overflow-hidden rounded-2xl border border-components-panel-border bg-components-panel-bg shadow-xl'>
<div className='flex items-center justify-between rounded-t-2xl border-b border-components-panel-border bg-components-panel-bg-blur px-4 py-3'>
<div className='font-semibold uppercase text-text-primary'>{t('workflow.comments.panelTitle')}</div>
<div className='flex items-center gap-1'>
<button
type='button'
disabled={loading}
className={cn('flex h-6 w-6 items-center justify-center rounded-lg text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary disabled:cursor-not-allowed disabled:text-text-disabled disabled:hover:bg-transparent disabled:hover:text-text-disabled')}
onClick={onDelete}
aria-label={t('workflow.comments.aria.deleteComment')}
>
<RiDeleteBinLine className='h-4 w-4' />
</button>
<button
type='button'
disabled={comment.resolved || loading}
className={cn('flex h-6 w-6 items-center justify-center rounded-lg text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary disabled:cursor-not-allowed disabled:text-text-disabled disabled:hover:bg-transparent disabled:hover:text-text-disabled')}
onClick={onResolve}
aria-label={t('workflow.comments.aria.resolveComment')}
>
{comment.resolved ? <RiCheckboxCircleFill className='h-4 w-4' /> : <RiCheckboxCircleLine className='h-4 w-4' />}
</button>
<Divider type='vertical' className='h-3.5' />
<button
type='button'
disabled={!canGoPrev || loading}
className={cn('flex h-6 w-6 items-center justify-center rounded-lg text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary disabled:cursor-not-allowed disabled:text-text-disabled disabled:hover:bg-transparent disabled:hover:text-text-disabled')}
onClick={onPrev}
aria-label={t('workflow.comments.aria.previousComment')}
>
<RiArrowUpSLine className='h-4 w-4' />
</button>
<button
type='button'
disabled={!canGoNext || loading}
className={cn('flex h-6 w-6 items-center justify-center rounded-lg text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary disabled:cursor-not-allowed disabled:text-text-disabled disabled:hover:bg-transparent disabled:hover:text-text-disabled')}
onClick={onNext}
aria-label={t('workflow.comments.aria.nextComment')}
>
<RiArrowDownSLine className='h-4 w-4' />
</button>
<button
type='button'
className='flex h-6 w-6 items-center justify-center rounded-lg text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary'
onClick={onClose}
aria-label={t('workflow.comments.aria.closeComment')}
>
<RiCloseLine className='h-4 w-4' />
</button>
</div>
</div>
<div
ref={messageListRef}
className='relative mt-2 flex-1 overflow-y-auto px-4'
>
<ThreadMessage
authorId={comment.created_by_account?.id || ''}
authorName={comment.created_by_account?.name || t('workflow.comments.fallback.user')}
avatarUrl={comment.created_by_account?.avatar_url || null}
createdAt={comment.created_at}
content={comment.content}
mentionedNames={mentionsByTarget.get('root')}
/>
{replies.length > 0 && (
<div className='mt-2 space-y-3 pt-3'>
{replies.map((reply) => {
const isReplyEditing = editingReply?.id === reply.id
const isOwnReply = reply.created_by_account?.id === userProfile?.id
return (
<div
key={reply.id}
className='group relative rounded-lg py-2 transition-colors hover:bg-components-panel-on-panel-item-bg'
>
{isOwnReply && !isReplyEditing && (
<div className='absolute right-1 top-1 hidden gap-1 group-hover:flex'>
<button
type='button'
className='flex h-6 w-6 items-center justify-center rounded-md text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary'
onClick={(e) => {
e.stopPropagation()
setActiveReplyMenuId(prev => prev === reply.id ? null : reply.id)
}}
aria-label={t('workflow.comments.aria.replyActions')}
>
<RiMoreFill className='h-4 w-4' />
</button>
{activeReplyMenuId === reply.id && (
<div className='absolute right-0 top-7 z-40 w-36 rounded-lg border border-components-panel-border bg-components-panel-bg shadow-lg'>
<button
className='flex w-full items-center justify-start px-3 py-2 text-left text-sm text-text-secondary hover:bg-state-base-hover'
onClick={() => handleStartEdit(reply)}
>
{t('workflow.comments.actions.editReply')}
</button>
<button
className='text-negative flex w-full items-center justify-start px-3 py-2 text-left text-sm text-text-secondary hover:bg-state-base-hover'
onClick={() => {
setActiveReplyMenuId(null)
onReplyDelete?.(reply.id)
}}
>
{t('workflow.comments.actions.deleteReply')}
</button>
</div>
)}
</div>
)}
{isReplyEditing ? (
<div className='rounded-lg border border-components-chat-input-border bg-components-panel-bg-blur px-3 py-2 shadow-sm'>
<MentionInput
value={editingReply?.content ?? ''}
onChange={newContent => setEditingReply(prev => prev ? { ...prev, content: newContent } : prev)}
onSubmit={handleEditSubmit}
onCancel={handleCancelEdit}
placeholder={t('workflow.comments.placeholder.editReply')}
disabled={loading}
loading={loading}
isEditing={true}
className="system-sm-regular"
autoFocus
/>
</div>
) : (
<ThreadMessage
authorId={reply.created_by_account?.id || ''}
authorName={reply.created_by_account?.name || t('workflow.comments.fallback.user')}
avatarUrl={reply.created_by_account?.avatar_url || null}
createdAt={reply.created_at}
content={reply.content}
mentionedNames={mentionsByTarget.get(reply.id)}
/>
)}
</div>
)
})}
</div>
)}
</div>
{loading && (
<div className='bg-components-panel-bg/70 absolute inset-0 z-30 flex items-center justify-center text-sm text-text-tertiary'>
{t('workflow.comments.loading')}
</div>
)}
{onReply && (
<div className='border-t border-components-panel-border px-4 py-3'>
<div className='flex items-center gap-3'>
<Avatar
avatar={userProfile?.avatar_url || null}
name={userProfile?.name || t('common.you')}
size={24}
className='h-8 w-8'
/>
<div className='flex-1 rounded-xl border border-components-chat-input-border bg-components-panel-bg-blur p-[2px] shadow-sm'>
<MentionInput
value={replyContent}
onChange={setReplyContent}
onSubmit={handleReplySubmit}
placeholder={t('workflow.comments.placeholder.reply')}
disabled={loading}
loading={loading}
/>
</div>
</div>
</div>
)}
</div>
</div>
)
})
CommentThread.displayName = 'CommentThread'

View File

@ -7,21 +7,23 @@ import { useStore } from './store'
import {
useIsChatMode,
useNodesReadOnly,
useNodesSyncDraft,
} from './hooks'
import { type CommonNodeType, type InputVar, InputVarType, type Node } from './types'
import useConfig from './nodes/start/use-config'
import type { StartNodeType } from './nodes/start/types'
import type { PromptVariable } from '@/models/debug'
import NewFeaturePanel from '@/app/components/base/features/new-feature-panel'
import { webSocketClient } from '@/app/components/workflow/collaboration/core/websocket-manager'
import { useFeaturesStore } from '@/app/components/base/features/hooks'
import { updateFeatures } from '@/service/workflow'
const Features = () => {
const setShowFeaturesPanel = useStore(s => s.setShowFeaturesPanel)
const appId = useStore(s => s.appId)
const isChatMode = useIsChatMode()
const { nodesReadOnly } = useNodesReadOnly()
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
const featuresStore = useFeaturesStore()
const nodes = useNodes<CommonNodeType>()
const startNode = nodes.find(node => node.data.type === 'start')
const { id, data } = startNode as Node<StartNodeType>
const { handleAddVariable } = useConfig(id, data)
@ -39,10 +41,45 @@ const Features = () => {
handleAddVariable(startNodeVariable)
}
const handleFeaturesChange = useCallback(() => {
handleSyncWorkflowDraft()
const handleFeaturesChange = useCallback(async () => {
if (!appId || !featuresStore) return
try {
const currentFeatures = featuresStore.getState().features
// Transform features to match the expected server format (same as doSyncWorkflowDraft)
const transformedFeatures = {
opening_statement: currentFeatures.opening?.enabled ? (currentFeatures.opening?.opening_statement || '') : '',
suggested_questions: currentFeatures.opening?.enabled ? (currentFeatures.opening?.suggested_questions || []) : [],
suggested_questions_after_answer: currentFeatures.suggested,
text_to_speech: currentFeatures.text2speech,
speech_to_text: currentFeatures.speech2text,
retriever_resource: currentFeatures.citation,
sensitive_word_avoidance: currentFeatures.moderation,
file_upload: currentFeatures.file,
}
console.log('Sending features to server:', transformedFeatures)
await updateFeatures({
appId,
features: transformedFeatures,
})
// Emit update event to other connected clients
const socket = webSocketClient.getSocket(appId)
if (socket) {
socket.emit('collaboration_event', {
type: 'varsAndFeaturesUpdate',
})
}
}
catch (error) {
console.error('Failed to update features:', error)
}
setShowFeaturesPanel(true)
}, [handleSyncWorkflowDraft, setShowFeaturesPanel])
}, [appId, featuresStore, setShowFeaturesPanel])
return (
<NewFeaturePanel

View File

@ -18,6 +18,7 @@ import RunAndHistory from './run-and-history'
import EditingTitle from './editing-title'
import EnvButton from './env-button'
import VersionHistoryButton from './version-history-button'
import OnlineUsers from './online-users'
import { useInputFieldPanel } from '@/app/components/rag-pipeline/hooks'
export type HeaderInNormalProps = {
@ -64,7 +65,9 @@ const HeaderInNormal = ({
<EditingTitle />
</div>
<div className='flex items-center gap-2'>
<OnlineUsers />
{components?.left}
<Divider type='vertical' className='mx-auto h-3.5' />
<EnvButton disabled={nodesReadOnly} />
<Divider type='vertical' className='mx-auto h-3.5' />
<RunAndHistory {...runAndHistoryProps} />

View File

@ -19,8 +19,10 @@ import RestoringTitle from './restoring-title'
import Button from '@/app/components/base/button'
import { useInvalidAllLastRun } from '@/service/use-workflow'
import { useHooksStore } from '../hooks-store'
import { useStore as useAppStore } from '@/app/components/app/store'
import useTheme from '@/hooks/use-theme'
import cn from '@/utils/classnames'
import { collaborationManager } from '../collaboration/core/collaboration-manager'
export type HeaderInRestoringProps = {
onRestoreSettled?: () => void
@ -31,6 +33,7 @@ const HeaderInRestoring = ({
const { t } = useTranslation()
const { theme } = useTheme()
const workflowStore = useWorkflowStore()
const appDetail = useAppStore.getState().appDetail
const configsMap = useHooksStore(s => s.configsMap)
const invalidAllLastRun = useInvalidAllLastRun(configsMap?.flowType, configsMap?.flowId)
const {
@ -60,6 +63,9 @@ const HeaderInRestoring = ({
type: 'success',
message: t('workflow.versionHistory.action.restoreSuccess'),
})
// Notify other collaboration clients about the workflow restore
if (appDetail)
collaborationManager.emitWorkflowUpdate(appDetail.id)
},
onError: () => {
Toast.notify({
@ -70,10 +76,10 @@ const HeaderInRestoring = ({
onSettled: () => {
onRestoreSettled?.()
},
})
}, true) // Enable forceUpload for restore operation
deleteAllInspectVars()
invalidAllLastRun()
}, [setShowWorkflowVersionHistoryPanel, workflowStore, handleSyncWorkflowDraft, deleteAllInspectVars, invalidAllLastRun, t, onRestoreSettled])
}, [setShowWorkflowVersionHistoryPanel, workflowStore, handleSyncWorkflowDraft, deleteAllInspectVars, invalidAllLastRun, t, onRestoreSettled, appDetail])
return (
<>

View File

@ -0,0 +1,195 @@
'use client'
import { useEffect, useState } from 'react'
import { useReactFlow } from 'reactflow'
import Avatar from '@/app/components/base/avatar'
import { useCollaboration } from '../collaboration/hooks/use-collaboration'
import { useStore } from '../store'
import cn from '@/utils/classnames'
import { ChevronDown } from '@/app/components/base/icons/src/vender/solid/arrows'
import { getUserColor } from '../collaboration/utils/user-color'
import Tooltip from '@/app/components/base/tooltip'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import { useAppContext } from '@/context/app-context'
import { getAvatar } from '@/service/common'
const useAvatarUrls = (users: any[]) => {
const [avatarUrls, setAvatarUrls] = useState<Record<string, string>>({})
useEffect(() => {
const fetchAvatars = async () => {
const newAvatarUrls: Record<string, string> = {}
await Promise.all(
users.map(async (user) => {
if (user.avatar) {
try {
const response = await getAvatar({ avatar: user.avatar })
newAvatarUrls[user.sid] = response.avatar_url
}
catch (error) {
console.error('Failed to fetch avatar:', error)
newAvatarUrls[user.sid] = user.avatar
}
}
}),
)
setAvatarUrls(newAvatarUrls)
}
if (users.length > 0)
fetchAvatars()
}, [users])
return avatarUrls
}
const OnlineUsers = () => {
const appId = useStore(s => s.appId)
const { onlineUsers, cursors } = useCollaboration(appId as string)
const { userProfile } = useAppContext()
const reactFlow = useReactFlow()
const [dropdownOpen, setDropdownOpen] = useState(false)
const avatarUrls = useAvatarUrls(onlineUsers || [])
const currentUserId = userProfile?.id
// Function to jump to user's cursor position
const jumpToUserCursor = (userId: string) => {
const cursor = cursors[userId]
if (!cursor) return
// Convert world coordinates to center the view on the cursor
reactFlow.setCenter(cursor.x, cursor.y, { zoom: 1, duration: 800 })
}
if (!onlineUsers || onlineUsers.length === 0)
return null
// Display logic:
// 1-3 users: show all avatars
// 4+ users: show 2 avatars + count + arrow
const shouldShowCount = onlineUsers.length >= 4
const maxVisible = shouldShowCount ? 2 : 3
const visibleUsers = onlineUsers.slice(0, maxVisible)
const remainingCount = onlineUsers.length - maxVisible
const getAvatarUrl = (user: any) => {
return avatarUrls[user.sid] || user.avatar
}
return (
<div className="flex items-center rounded-full border border-components-panel-border bg-components-panel-bg px-1 py-1">
<div className="flex items-center">
<div className="flex items-center -space-x-2">
{visibleUsers.map((user, index) => {
const isCurrentUser = user.user_id === currentUserId
const userColor = isCurrentUser ? undefined : getUserColor(user.user_id)
const displayName = isCurrentUser
? `${user.username || 'User'} (You)`
: (user.username || 'User')
return (
<Tooltip
key={`${user.sid}-${index}`}
popupContent={displayName}
position="bottom"
triggerMethod="hover"
needsDelay={false}
asChild
>
<div
className={cn(
'relative',
!isCurrentUser && 'cursor-pointer transition-transform hover:scale-110',
)}
style={{ zIndex: visibleUsers.length - index }}
onClick={() => !isCurrentUser && jumpToUserCursor(user.user_id)}
>
<Avatar
name={user.username || 'User'}
avatar={getAvatarUrl(user)}
size={28}
className="ring-2 ring-components-panel-bg"
backgroundColor={userColor}
/>
</div>
</Tooltip>
)
})}
{remainingCount > 0 && (
<PortalToFollowElem
open={dropdownOpen}
onOpenChange={setDropdownOpen}
placement="bottom-start"
>
<PortalToFollowElemTrigger
onMouseEnter={() => setDropdownOpen(true)}
onMouseLeave={() => setDropdownOpen(false)}
asChild
>
<div className="flex items-center">
<div
className={cn(
'flex h-7 w-7 items-center justify-center',
'cursor-pointer rounded-full bg-gray-300',
'text-xs font-medium text-gray-700',
'ring-2 ring-components-panel-bg',
)}
>
+{remainingCount}
</div>
<ChevronDown className="ml-1 h-3 w-3 text-gray-500" />
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent
onMouseEnter={() => setDropdownOpen(true)}
onMouseLeave={() => setDropdownOpen(false)}
className="z-[9999]"
>
<div className="mt-2 min-w-[200px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg p-1 shadow-lg">
{onlineUsers.map((user) => {
const isCurrentUser = user.user_id === currentUserId
const userColor = isCurrentUser ? undefined : getUserColor(user.user_id)
const displayName = isCurrentUser
? `${user.username || 'User'} (You)`
: (user.username || 'User')
return (
<div
key={user.sid}
className={cn(
'flex items-center gap-2 rounded-lg px-3 py-2',
!isCurrentUser && 'cursor-pointer hover:bg-components-panel-on-panel-item-bg-hover',
)}
onClick={() => !isCurrentUser && jumpToUserCursor(user.user_id)}
>
<div className="relative">
<Avatar
name={user.username || 'User'}
avatar={getAvatarUrl(user)}
size={24}
backgroundColor={userColor}
/>
</div>
<span className="text-sm text-text-secondary">
{displayName}
</span>
</div>
)
})}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)}
</div>
</div>
</div>
)
}
export default OnlineUsers

View File

@ -6,27 +6,39 @@ import {
RiArrowGoForwardFill,
} from '@remixicon/react'
import TipPopup from '../operator/tip-popup'
import { useWorkflowHistoryStore } from '../workflow-history-store'
import Divider from '../../base/divider'
import { useNodesReadOnly } from '@/app/components/workflow/hooks'
import ViewWorkflowHistory from '@/app/components/workflow/header/view-workflow-history'
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
import classNames from '@/utils/classnames'
export type UndoRedoProps = { handleUndo: () => void; handleRedo: () => void }
const UndoRedo: FC<UndoRedoProps> = ({ handleUndo, handleRedo }) => {
const { t } = useTranslation()
const { store } = useWorkflowHistoryStore()
const [buttonsDisabled, setButtonsDisabled] = useState({ undo: true, redo: true })
useEffect(() => {
const unsubscribe = store.temporal.subscribe((state) => {
// Update button states based on Loro's UndoManager
const updateButtonStates = () => {
setButtonsDisabled({
undo: state.pastStates.length === 0,
redo: state.futureStates.length === 0,
undo: !collaborationManager.canUndo(),
redo: !collaborationManager.canRedo(),
})
}
// Initial state
updateButtonStates()
// Listen for undo/redo state changes
const unsubscribe = collaborationManager.onUndoRedoStateChange((state) => {
setButtonsDisabled({
undo: !state.canUndo,
redo: !state.canRedo,
})
})
return () => unsubscribe()
}, [store])
}, [])
const { nodesReadOnly } = useNodesReadOnly()

View File

@ -33,7 +33,8 @@ export type CommonHooksFnMap = {
onSuccess?: () => void
onError?: () => void
onSettled?: () => void
}
},
forceUpload?: boolean
) => Promise<void>
syncWorkflowDraftWhenPageClose: () => void
handleRefreshWorkflowDraft: () => void

View File

@ -22,3 +22,4 @@ export * from './use-DSL'
export * from './use-inspect-vars-crud'
export * from './use-set-workflow-vars-with-value'
export * from './use-workflow-search'
export * from './use-workflow-comment'

View File

@ -116,7 +116,9 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
if (node.type === CUSTOM_NODE) {
const checkData = getCheckData(node.data)
let { errorMessage } = nodesExtraData![node.data.type].checkValid(checkData, t, moreDataForCheckValid)
// temp fix nodeMetaData is undefined
const nodeMetaData = nodesExtraData?.[node.data.type]
let { errorMessage } = nodeMetaData?.checkValid ? nodeMetaData.checkValid(checkData, t, moreDataForCheckValid) : { errorMessage: undefined }
if (!errorMessage) {
const availableVars = map[node.id].availableVars

View File

@ -0,0 +1,84 @@
import { useCallback } from 'react'
import { useStoreApi } from 'reactflow'
import type { Edge, Node } from '../types'
import { collaborationManager } from '../collaboration/core/collaboration-manager'
const sanitizeNodeForBroadcast = (node: Node): Node => {
if (!node.data)
return node
if (!Object.prototype.hasOwnProperty.call(node.data, 'selected'))
return node
const sanitizedData = { ...node.data }
delete (sanitizedData as Record<string, unknown>).selected
return {
...node,
data: sanitizedData,
}
}
const sanitizeEdgeForBroadcast = (edge: Edge): Edge => {
if (!edge.data)
return edge
if (!Object.prototype.hasOwnProperty.call(edge.data, '_connectedNodeIsSelected'))
return edge
const sanitizedData = { ...edge.data }
delete (sanitizedData as Record<string, unknown>)._connectedNodeIsSelected
return {
...edge,
data: sanitizedData,
}
}
export const useCollaborativeWorkflow = () => {
const store = useStoreApi()
const { setNodes: collabSetNodes, setEdges: collabSetEdges } = collaborationManager
const setNodes = useCallback((newNodes: Node[], shouldBroadcast: boolean = true) => {
const { getNodes, setNodes: reactFlowSetNodes } = store.getState()
if (shouldBroadcast) {
const oldNodes = getNodes()
collabSetNodes(
oldNodes.map(sanitizeNodeForBroadcast),
newNodes.map(sanitizeNodeForBroadcast),
)
}
reactFlowSetNodes(newNodes)
}, [store, collabSetNodes])
const setEdges = useCallback((newEdges: Edge[], shouldBroadcast: boolean = true) => {
const { edges, setEdges: reactFlowSetEdges } = store.getState()
if (shouldBroadcast) {
collabSetEdges(
edges.map(sanitizeEdgeForBroadcast),
newEdges.map(sanitizeEdgeForBroadcast),
)
}
reactFlowSetEdges(newEdges)
}, [store, collabSetEdges])
const collaborativeStore = useCallback(() => {
const state = store.getState()
return {
nodes: state.getNodes(),
edges: state.edges,
setNodes,
setEdges,
}
}, [store, setNodes, setEdges])
return {
getState: collaborativeStore,
setNodes,
setEdges,
}
}

View File

@ -4,9 +4,7 @@ import type {
EdgeMouseHandler,
OnEdgesChange,
} from 'reactflow'
import {
useStoreApi,
} from 'reactflow'
import type {
Node,
} from '../types'
@ -14,61 +12,55 @@ import { getNodesConnectedSourceOrTargetHandleIdsMap } from '../utils'
import { useNodesSyncDraft } from './use-nodes-sync-draft'
import { useNodesReadOnly } from './use-workflow'
import { WorkflowHistoryEvent, useWorkflowHistory } from './use-workflow-history'
import { useCollaborativeWorkflow } from './use-collaborative-workflow'
export const useEdgesInteractions = () => {
const store = useStoreApi()
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
const { getNodesReadOnly } = useNodesReadOnly()
const { saveStateToHistory } = useWorkflowHistory()
const collaborativeWorkflow = useCollaborativeWorkflow()
const handleEdgeEnter = useCallback<EdgeMouseHandler>((_, edge) => {
if (getNodesReadOnly())
return
const {
edges,
setEdges,
} = store.getState()
const { edges, setEdges } = collaborativeWorkflow.getState()
const newEdges = produce(edges, (draft) => {
const currentEdge = draft.find(e => e.id === edge.id)!
currentEdge.data._hovering = true
})
setEdges(newEdges)
}, [store, getNodesReadOnly])
setEdges(newEdges, false)
}, [collaborativeWorkflow, getNodesReadOnly])
const handleEdgeLeave = useCallback<EdgeMouseHandler>((_, edge) => {
if (getNodesReadOnly())
return
const {
edges,
setEdges,
} = store.getState()
const { edges, setEdges } = collaborativeWorkflow.getState()
const newEdges = produce(edges, (draft) => {
const currentEdge = draft.find(e => e.id === edge.id)!
currentEdge.data._hovering = false
})
setEdges(newEdges)
}, [store, getNodesReadOnly])
setEdges(newEdges, false)
}, [collaborativeWorkflow, getNodesReadOnly])
const handleEdgeDeleteByDeleteBranch = useCallback((nodeId: string, branchId: string) => {
if (getNodesReadOnly())
return
const {
getNodes,
nodes,
setNodes,
edges,
setEdges,
} = store.getState()
} = collaborativeWorkflow.getState()
const edgeWillBeDeleted = edges.filter(edge => edge.source === nodeId && edge.sourceHandle === branchId)
if (!edgeWillBeDeleted.length)
return
const nodes = getNodes()
const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap(
edgeWillBeDeleted.map(edge => ({ type: 'remove', edge })),
nodes,
@ -90,24 +82,23 @@ export const useEdgesInteractions = () => {
setEdges(newEdges)
handleSyncWorkflowDraft()
saveStateToHistory(WorkflowHistoryEvent.EdgeDeleteByDeleteBranch)
}, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory])
}, [getNodesReadOnly, collaborativeWorkflow, handleSyncWorkflowDraft, saveStateToHistory])
const handleEdgeDelete = useCallback(() => {
if (getNodesReadOnly())
return
const {
getNodes,
nodes,
setNodes,
edges,
setEdges,
} = store.getState()
} = collaborativeWorkflow.getState()
const currentEdgeIndex = edges.findIndex(edge => edge.selected)
if (currentEdgeIndex < 0)
return
const currentEdge = edges[currentEdgeIndex]
const nodes = getNodes()
const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap(
[
{ type: 'remove', edge: currentEdge },
@ -131,7 +122,7 @@ export const useEdgesInteractions = () => {
setEdges(newEdges)
handleSyncWorkflowDraft()
saveStateToHistory(WorkflowHistoryEvent.EdgeDelete)
}, [getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory])
}, [getNodesReadOnly, collaborativeWorkflow, handleSyncWorkflowDraft, saveStateToHistory])
const handleEdgesChange = useCallback<OnEdgesChange>((changes) => {
if (getNodesReadOnly())
@ -140,7 +131,7 @@ export const useEdgesInteractions = () => {
const {
edges,
setEdges,
} = store.getState()
} = collaborativeWorkflow.getState()
const newEdges = produce(edges, (draft) => {
changes.forEach((change) => {
@ -149,7 +140,7 @@ export const useEdgesInteractions = () => {
})
})
setEdges(newEdges)
}, [store, getNodesReadOnly])
}, [collaborativeWorkflow, getNodesReadOnly])
return {
handleEdgeEnter,

View File

@ -3,6 +3,7 @@ import produce from 'immer'
import { useStoreApi } from 'reactflow'
import { useNodesSyncDraft } from './use-nodes-sync-draft'
import { useNodesReadOnly } from './use-workflow'
import { useCollaborativeWorkflow } from './use-collaborative-workflow'
type NodeDataUpdatePayload = {
id: string
@ -13,13 +14,11 @@ export const useNodeDataUpdate = () => {
const store = useStoreApi()
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
const { getNodesReadOnly } = useNodesReadOnly()
const collaborativeWorkflow = useCollaborativeWorkflow()
const handleNodeDataUpdate = useCallback(({ id, data }: NodeDataUpdatePayload) => {
const {
getNodes,
setNodes,
} = store.getState()
const newNodes = produce(getNodes(), (draft) => {
const { nodes, setNodes } = collaborativeWorkflow.getState()
const newNodes = produce(nodes, (draft) => {
const currentNode = draft.find(node => node.id === id)!
if (currentNode)

View File

@ -14,7 +14,6 @@ import {
getConnectedEdges,
getOutgoers,
useReactFlow,
useStoreApi,
} from 'reactflow'
import type { ToolDefaultValue } from '../block-selector/types'
import type { Edge, Node, OnNodeAdd } from '../types'
@ -46,7 +45,7 @@ import { CUSTOM_LOOP_START_NODE } from '../nodes/loop-start/constants'
import type { VariableAssignerNodeType } from '../nodes/variable-assigner/types'
import { useNodeIterationInteractions } from '../nodes/iteration/use-interactions'
import { useNodeLoopInteractions } from '../nodes/loop/use-interactions'
import { useWorkflowHistoryStore } from '../workflow-history-store'
import { collaborationManager } from '../collaboration/core/collaboration-manager'
import { useNodesSyncDraft } from './use-nodes-sync-draft'
import { useHelpline } from './use-helpline'
import {
@ -62,13 +61,13 @@ import { useNodesMetaData } from './use-nodes-meta-data'
import type { RAGPipelineVariables } from '@/models/pipeline'
import useInspectVarsCrud from './use-inspect-vars-crud'
import { getNodeUsedVars } from '../nodes/_base/components/variable/utils'
import { useCollaborativeWorkflow } from './use-collaborative-workflow'
export const useNodesInteractions = () => {
const { t } = useTranslation()
const store = useStoreApi()
const collaborativeWorkflow = useCollaborativeWorkflow()
const workflowStore = useWorkflowStore()
const reactflow = useReactFlow()
const { store: workflowHistoryStore } = useWorkflowHistoryStore()
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
const { getAfterNodesInSameBranch } = useWorkflow()
const { getNodesReadOnly } = useNodesReadOnly()
@ -84,7 +83,7 @@ export const useNodesInteractions = () => {
})
const { nodesMap: nodesMetaDataMap } = useNodesMetaData()
const { saveStateToHistory, undo, redo } = useWorkflowHistory()
const { saveStateToHistory } = useWorkflowHistory()
const handleNodeDragStart = useCallback<NodeDragHandler>(
(_, node) => {
@ -120,19 +119,18 @@ export const useNodesInteractions = () => {
if (node.type === CUSTOM_LOOP_START_NODE) return
const { getNodes, setNodes } = store.getState()
e.stopPropagation()
const nodes = getNodes()
const { nodes, setNodes } = collaborativeWorkflow.getState()
const { restrictPosition } = handleNodeIterationChildDrag(node)
const { restrictPosition: restrictLoopPosition }
= handleNodeLoopChildDrag(node)
= handleNodeLoopChildDrag(node)
const { showHorizontalHelpLineNodes, showVerticalHelpLineNodes }
= handleSetHelpline(node)
= handleSetHelpline(node)
const showHorizontalHelpLineNodesLength
= showHorizontalHelpLineNodes.length
= showHorizontalHelpLineNodes.length
const showVerticalHelpLineNodesLength = showVerticalHelpLineNodes.length
const newNodes = produce(nodes, (draft) => {
@ -152,18 +150,11 @@ export const useNodesInteractions = () => {
currentNode.position.y = restrictPosition.y
else if (restrictLoopPosition.y !== undefined)
currentNode.position.y = restrictLoopPosition.y
else currentNode.position.y = node.position.y
else
currentNode.position.y = node.position.y
})
setNodes(newNodes)
},
[
getNodesReadOnly,
store,
handleNodeIterationChildDrag,
handleNodeLoopChildDrag,
handleSetHelpline,
],
)
}, [getNodesReadOnly, collaborativeWorkflow, handleNodeIterationChildDrag, handleNodeLoopChildDrag, handleSetHelpline])
const handleNodeDragStop = useCallback<NodeDragHandler>(
(_, node) => {
@ -210,11 +201,11 @@ export const useNodesInteractions = () => {
)
return
const { getNodes, setNodes, edges, setEdges } = store.getState()
const nodes = getNodes()
const { connectingNodePayload, setEnteringNodePayload }
= workflowStore.getState()
const { nodes, edges, setNodes, setEdges } = collaborativeWorkflow.getState()
const {
connectingNodePayload,
setEnteringNodePayload,
} = workflowStore.getState()
if (connectingNodePayload) {
if (connectingNodePayload.nodeId === node.id) return
const connectingNode: Node = nodes.find(
@ -233,25 +224,25 @@ export const useNodesInteractions = () => {
draft.forEach((n) => {
if (
n.id === node.id
&& fromType === 'source'
&& (node.data.type === BlockEnum.VariableAssigner
|| node.data.type === BlockEnum.VariableAggregator)
&& fromType === 'source'
&& (node.data.type === BlockEnum.VariableAssigner
|| node.data.type === BlockEnum.VariableAggregator)
) {
if (!node.data.advanced_settings?.group_enabled)
n.data._isEntering = true
}
if (
n.id === node.id
&& fromType === 'target'
&& (connectingNode.data.type === BlockEnum.VariableAssigner
|| connectingNode.data.type === BlockEnum.VariableAggregator)
&& node.data.type !== BlockEnum.IfElse
&& node.data.type !== BlockEnum.QuestionClassifier
&& fromType === 'target'
&& (connectingNode.data.type === BlockEnum.VariableAssigner
|| connectingNode.data.type === BlockEnum.VariableAggregator)
&& node.data.type !== BlockEnum.IfElse
&& node.data.type !== BlockEnum.QuestionClassifier
)
n.data._isEntering = true
})
})
setNodes(newNodes)
setNodes(newNodes, false)
}
}
const newEdges = produce(edges, (draft) => {
@ -262,9 +253,9 @@ export const useNodesInteractions = () => {
if (currentEdge) currentEdge.data._connectedNodeIsHovering = true
})
})
setEdges(newEdges)
setEdges(newEdges, false)
},
[store, workflowStore, getNodesReadOnly],
[collaborativeWorkflow, workflowStore, getNodesReadOnly],
)
const handleNodeLeave = useCallback<NodeMouseHandler>(
@ -285,21 +276,21 @@ export const useNodesInteractions = () => {
const { setEnteringNodePayload } = workflowStore.getState()
setEnteringNodePayload(undefined)
const { getNodes, setNodes, edges, setEdges } = store.getState()
const newNodes = produce(getNodes(), (draft) => {
const { nodes, setNodes, edges, setEdges } = collaborativeWorkflow.getState()
const newNodes = produce(nodes, (draft) => {
draft.forEach((node) => {
node.data._isEntering = false
})
})
setNodes(newNodes)
setNodes(newNodes, false)
const newEdges = produce(edges, (draft) => {
draft.forEach((edge) => {
edge.data._connectedNodeIsHovering = false
})
})
setEdges(newEdges)
setEdges(newEdges, false)
},
[store, workflowStore, getNodesReadOnly],
[collaborativeWorkflow, workflowStore, getNodesReadOnly],
)
const handleNodeSelect = useCallback(
@ -310,9 +301,7 @@ export const useNodesInteractions = () => {
) => {
if (initShowLastRunTab)
workflowStore.setState({ initShowLastRunTab: true })
const { getNodes, setNodes, edges, setEdges } = store.getState()
const nodes = getNodes()
const { nodes, setNodes, edges, setEdges } = collaborativeWorkflow.getState()
const selectedNode = nodes.find(node => node.data.selected)
if (!cancelSelection && selectedNode?.id === nodeId) return
@ -323,7 +312,7 @@ export const useNodesInteractions = () => {
else node.data.selected = false
})
})
setNodes(newNodes)
setNodes(newNodes, false)
const connectedEdges = getConnectedEdges(
[{ id: nodeId } as Node],
@ -345,12 +334,8 @@ export const useNodesInteractions = () => {
}
})
})
setEdges(newEdges)
handleSyncWorkflowDraft()
},
[store, handleSyncWorkflowDraft],
)
setEdges(newEdges, false)
}, [collaborativeWorkflow])
const handleNodeClick = useCallback<NodeMouseHandler>(
(_, node) => {
@ -367,8 +352,7 @@ export const useNodesInteractions = () => {
if (source === target) return
if (getNodesReadOnly()) return
const { getNodes, setNodes, edges, setEdges } = store.getState()
const nodes = getNodes()
const { nodes, edges, setNodes, setEdges } = collaborativeWorkflow.getState()
const targetNode = nodes.find(node => node.id === target!)
const sourceNode = nodes.find(node => node.id === source!)
@ -446,7 +430,7 @@ export const useNodesInteractions = () => {
},
[
getNodesReadOnly,
store,
collaborativeWorkflow,
workflowStore,
handleSyncWorkflowDraft,
saveStateToHistory,
@ -459,8 +443,8 @@ export const useNodesInteractions = () => {
if (nodeId && handleType) {
const { setConnectingNodePayload } = workflowStore.getState()
const { getNodes } = store.getState()
const node = getNodes().find(n => n.id === nodeId)!
const { nodes } = collaborativeWorkflow.getState()
const node = nodes.find(n => n.id === nodeId)!
if (node.type === CUSTOM_NOTE_NODE) return
@ -477,9 +461,7 @@ export const useNodesInteractions = () => {
handleId,
})
}
},
[store, workflowStore, getNodesReadOnly],
)
}, [collaborativeWorkflow, workflowStore, getNodesReadOnly])
const handleNodeConnectEnd = useCallback<OnConnectEnd>(
(e: any) => {
@ -495,8 +477,7 @@ export const useNodesInteractions = () => {
const { setShowAssignVariablePopup, hoveringAssignVariableGroupId }
= workflowStore.getState()
const { screenToFlowPosition } = reactflow
const { getNodes, setNodes } = store.getState()
const nodes = getNodes()
const { nodes, setNodes } = collaborativeWorkflow.getState()
const fromHandleType = connectingNodePayload.handleType
const fromHandleId = connectingNodePayload.handleId
const fromNode = nodes.find(
@ -553,7 +534,7 @@ export const useNodesInteractions = () => {
setConnectingNodePayload(undefined)
setEnteringNodePayload(undefined)
},
[store, handleNodeConnect, getNodesReadOnly, workflowStore, reactflow],
[collaborativeWorkflow, handleNodeConnect, getNodesReadOnly, workflowStore, reactflow],
)
const { deleteNodeInspectorVars } = useInspectVarsCrud()
@ -562,9 +543,7 @@ export const useNodesInteractions = () => {
(nodeId: string) => {
if (getNodesReadOnly()) return
const { getNodes, setNodes, edges, setEdges } = store.getState()
const nodes = getNodes()
const { nodes, setNodes, edges, setEdges } = collaborativeWorkflow.getState()
const currentNodeIndex = nodes.findIndex(node => node.id === nodeId)
const currentNode = nodes[currentNodeIndex]
@ -719,7 +698,7 @@ export const useNodesInteractions = () => {
},
[
getNodesReadOnly,
store,
collaborativeWorkflow,
handleSyncWorkflowDraft,
saveStateToHistory,
workflowStore,
@ -741,8 +720,7 @@ export const useNodesInteractions = () => {
) => {
if (getNodesReadOnly()) return
const { getNodes, setNodes, edges, setEdges } = store.getState()
const nodes = getNodes()
const { nodes, setNodes, edges, setEdges } = collaborativeWorkflow.getState()
const nodesWithSameType = nodes.filter(
node => node.data.type === nodeType,
)
@ -1272,7 +1250,7 @@ export const useNodesInteractions = () => {
},
[
getNodesReadOnly,
store,
collaborativeWorkflow,
handleSyncWorkflowDraft,
saveStateToHistory,
workflowStore,
@ -1290,8 +1268,7 @@ export const useNodesInteractions = () => {
) => {
if (getNodesReadOnly()) return
const { getNodes, setNodes, edges, setEdges } = store.getState()
const nodes = getNodes()
const { nodes, setNodes, edges, setEdges } = collaborativeWorkflow.getState()
const currentNode = nodes.find(node => node.id === currentNodeId)!
const connectedEdges = getConnectedEdges([currentNode], edges)
const nodesWithSameType = nodes.filter(
@ -1369,7 +1346,7 @@ export const useNodesInteractions = () => {
},
[
getNodesReadOnly,
store,
collaborativeWorkflow,
handleSyncWorkflowDraft,
saveStateToHistory,
nodesMetaDataMap,
@ -1377,16 +1354,14 @@ export const useNodesInteractions = () => {
)
const handleNodesCancelSelected = useCallback(() => {
const { getNodes, setNodes } = store.getState()
const nodes = getNodes()
const { nodes, setNodes } = collaborativeWorkflow.getState()
const newNodes = produce(nodes, (draft) => {
draft.forEach((node) => {
node.data.selected = false
})
})
setNodes(newNodes)
}, [store])
}, [collaborativeWorkflow])
const handleNodeContextMenu = useCallback(
(e: MouseEvent, node: Node) => {
@ -1423,9 +1398,7 @@ export const useNodesInteractions = () => {
const { setClipboardElements } = workflowStore.getState()
const { getNodes } = store.getState()
const nodes = getNodes()
const { nodes } = collaborativeWorkflow.getState()
if (nodeId) {
// If nodeId is provided, copy that specific node
@ -1464,7 +1437,7 @@ export const useNodesInteractions = () => {
if (selectedNode) setClipboardElements([selectedNode])
}
},
[getNodesReadOnly, store, workflowStore],
[getNodesReadOnly, collaborativeWorkflow, workflowStore],
)
const handleNodesPaste = useCallback(() => {
@ -1472,11 +1445,10 @@ export const useNodesInteractions = () => {
const { clipboardElements, mousePosition } = workflowStore.getState()
const { getNodes, setNodes, edges, setEdges } = store.getState()
const { nodes, setNodes, edges, setEdges } = collaborativeWorkflow.getState()
const nodesToPaste: Node[] = []
const edgesToPaste: Edge[] = []
const nodes = getNodes()
if (clipboardElements.length) {
const { x, y } = getTopLeftNodePosition(clipboardElements)
@ -1628,7 +1600,7 @@ export const useNodesInteractions = () => {
}, [
getNodesReadOnly,
workflowStore,
store,
collaborativeWorkflow,
reactflow,
saveStateToHistory,
handleSyncWorkflowDraft,
@ -1650,9 +1622,8 @@ export const useNodesInteractions = () => {
const handleNodesDelete = useCallback(() => {
if (getNodesReadOnly()) return
const { getNodes, edges } = store.getState()
const { nodes, edges } = collaborativeWorkflow.getState()
const nodes = getNodes()
const bundledNodes = nodes.filter(
node => node.data._isBundled && node.data.type !== BlockEnum.Start,
)
@ -1671,16 +1642,15 @@ export const useNodesInteractions = () => {
)
if (selectedNode) handleNodeDelete(selectedNode.id)
}, [store, getNodesReadOnly, handleNodeDelete])
}, [collaborativeWorkflow, getNodesReadOnly, handleNodeDelete])
const handleNodeResize = useCallback(
(nodeId: string, params: ResizeParamsWithDirection) => {
if (getNodesReadOnly()) return
const { getNodes, setNodes } = store.getState()
const { nodes, setNodes } = collaborativeWorkflow.getState()
const { x, y, width, height } = params
const nodes = getNodes()
const currentNode = nodes.find(n => n.id === nodeId)!
const childrenNodes = nodes.filter(n =>
currentNode.data._children?.find((c: any) => c.nodeId === n.id),
@ -1739,15 +1709,14 @@ export const useNodesInteractions = () => {
handleSyncWorkflowDraft()
saveStateToHistory(WorkflowHistoryEvent.NodeResize, { nodeId })
},
[getNodesReadOnly, store, handleSyncWorkflowDraft, saveStateToHistory],
[getNodesReadOnly, collaborativeWorkflow, handleSyncWorkflowDraft, saveStateToHistory],
)
const handleNodeDisconnect = useCallback(
(nodeId: string) => {
if (getNodesReadOnly()) return
const { getNodes, setNodes, edges, setEdges } = store.getState()
const nodes = getNodes()
const { nodes, setNodes, edges, setEdges } = collaborativeWorkflow.getState()
const currentNode = nodes.find(node => node.id === nodeId)!
const connectedEdges = getConnectedEdges([currentNode], edges)
const nodesConnectedSourceOrTargetHandleIdsMap
@ -1778,24 +1747,24 @@ export const useNodesInteractions = () => {
handleSyncWorkflowDraft()
saveStateToHistory(WorkflowHistoryEvent.EdgeDelete)
},
[store, getNodesReadOnly, handleSyncWorkflowDraft, saveStateToHistory],
[collaborativeWorkflow, getNodesReadOnly, handleSyncWorkflowDraft, saveStateToHistory],
)
const handleHistoryBack = useCallback(() => {
if (getNodesReadOnly() || getWorkflowReadOnly()) return
const { setEdges, setNodes } = store.getState()
undo()
// Use collaborative undo from Loro
const undoResult = collaborationManager.undo()
const { edges, nodes } = workflowHistoryStore.getState()
if (edges.length === 0 && nodes.length === 0) return
setEdges(edges)
setNodes(nodes)
if (undoResult) {
// The undo operation will automatically trigger subscriptions
// which will update the nodes and edges through setupSubscriptions
console.log('Collaborative undo performed')
}
else {
console.log('Nothing to undo')
}
}, [
store,
undo,
workflowHistoryStore,
getNodesReadOnly,
getWorkflowReadOnly,
])
@ -1803,18 +1772,18 @@ export const useNodesInteractions = () => {
const handleHistoryForward = useCallback(() => {
if (getNodesReadOnly() || getWorkflowReadOnly()) return
const { setEdges, setNodes } = store.getState()
redo()
// Use collaborative redo from Loro
const redoResult = collaborationManager.redo()
const { edges, nodes } = workflowHistoryStore.getState()
if (edges.length === 0 && nodes.length === 0) return
setEdges(edges)
setNodes(nodes)
if (redoResult) {
// The redo operation will automatically trigger subscriptions
// which will update the nodes and edges through setupSubscriptions
console.log('Collaborative redo performed')
}
else {
console.log('Nothing to redo')
}
}, [
redo,
store,
workflowHistoryStore,
getNodesReadOnly,
getWorkflowReadOnly,
])
@ -1823,8 +1792,7 @@ export const useNodesInteractions = () => {
/** Add opacity-30 to all nodes except the nodeId */
const dimOtherNodes = useCallback(() => {
if (isDimming) return
const { getNodes, setNodes, edges, setEdges } = store.getState()
const nodes = getNodes()
const { nodes, setNodes, edges, setEdges } = collaborativeWorkflow.getState()
const selectedNode = nodes.find(n => n.data.selected)
if (!selectedNode) return
@ -1917,12 +1885,11 @@ export const useNodesInteractions = () => {
draft.push(...tempEdges)
})
setEdges(newEdges)
}, [isDimming, store])
}, [isDimming, collaborativeWorkflow])
/** Restore all nodes to full opacity */
const undimAllNodes = useCallback(() => {
const { getNodes, setNodes, edges, setEdges } = store.getState()
const nodes = getNodes()
const { nodes, setNodes, edges, setEdges } = collaborativeWorkflow.getState()
setIsDimming(false)
const newNodes = produce(nodes, (draft) => {
@ -1942,7 +1909,7 @@ export const useNodesInteractions = () => {
},
)
setEdges(newEdges)
}, [store])
}, [collaborativeWorkflow])
return {
handleNodeDragStart,

View File

@ -21,12 +21,13 @@ export const useNodesSyncDraft = () => {
onError?: () => void
onSettled?: () => void
},
forceUpload?: boolean,
) => {
if (getNodesReadOnly())
return
if (sync)
doSyncWorkflowDraft(notRefreshWhenSyncError, callback)
doSyncWorkflowDraft(notRefreshWhenSyncError, callback, forceUpload)
else
debouncedSyncWorkflowDraft(doSyncWorkflowDraft)
}, [debouncedSyncWorkflowDraft, doSyncWorkflowDraft, getNodesReadOnly])

View File

@ -36,6 +36,7 @@ export const useShortcuts = (): void => {
const {
handleModeHand,
handleModePointer,
handleModeComment,
} = useWorkflowMoveMode()
const { handleLayout } = useWorkflowOrganize()
const { handleToggleMaximizeCanvas } = useWorkflowCanvasMaximize()
@ -144,6 +145,16 @@ export const useShortcuts = (): void => {
useCapture: true,
})
useKeyPress('c', (e) => {
if (shouldHandleShortcut(e)) {
e.preventDefault()
handleModeComment()
}
}, {
exactMatch: true,
useCapture: true,
})
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.o`, (e) => {
if (shouldHandleShortcut(e)) {
e.preventDefault()

View File

@ -0,0 +1,412 @@
import { useCallback, useEffect, useRef } from 'react'
import { useParams } from 'next/navigation'
import { useReactFlow } from 'reactflow'
import { useStore } from '../store'
import { ControlMode } from '../types'
import type { WorkflowCommentDetail, WorkflowCommentList } from '@/service/workflow-comment'
import { createWorkflowComment, createWorkflowCommentReply, deleteWorkflowComment, deleteWorkflowCommentReply, fetchWorkflowComment, fetchWorkflowComments, resolveWorkflowComment, updateWorkflowComment, updateWorkflowCommentReply } from '@/service/workflow-comment'
import { collaborationManager } from '@/app/components/workflow/collaboration'
export const useWorkflowComment = () => {
const params = useParams()
const appId = params.appId as string
const reactflow = useReactFlow()
const controlMode = useStore(s => s.controlMode)
const setControlMode = useStore(s => s.setControlMode)
const pendingComment = useStore(s => s.pendingComment)
const setPendingComment = useStore(s => s.setPendingComment)
const setActiveCommentId = useStore(s => s.setActiveCommentId)
const activeCommentId = useStore(s => s.activeCommentId)
const comments = useStore(s => s.comments)
const setComments = useStore(s => s.setComments)
const loading = useStore(s => s.commentsLoading)
const setCommentsLoading = useStore(s => s.setCommentsLoading)
const activeComment = useStore(s => s.activeCommentDetail)
const setActiveComment = useStore(s => s.setActiveCommentDetail)
const activeCommentLoading = useStore(s => s.activeCommentDetailLoading)
const setActiveCommentLoading = useStore(s => s.setActiveCommentDetailLoading)
const commentDetailCache = useStore(s => s.commentDetailCache)
const setCommentDetailCache = useStore(s => s.setCommentDetailCache)
const commentDetailCacheRef = useRef<Record<string, WorkflowCommentDetail>>(commentDetailCache)
const activeCommentIdRef = useRef<string | null>(null)
useEffect(() => {
activeCommentIdRef.current = activeCommentId ?? null
}, [activeCommentId])
useEffect(() => {
commentDetailCacheRef.current = commentDetailCache
}, [commentDetailCache])
const refreshActiveComment = useCallback(async (commentId: string) => {
if (!appId) return
const detailResponse = await fetchWorkflowComment(appId, commentId)
const detail = (detailResponse as any)?.data ?? detailResponse
commentDetailCacheRef.current = {
...commentDetailCacheRef.current,
[commentId]: detail,
}
setCommentDetailCache(commentDetailCacheRef.current)
setActiveComment(detail)
}, [appId, setActiveComment, setCommentDetailCache])
const loadComments = useCallback(async () => {
if (!appId) return
setCommentsLoading(true)
try {
const commentsData = await fetchWorkflowComments(appId)
setComments(commentsData)
}
catch (error) {
console.error('Failed to fetch comments:', error)
}
finally {
setCommentsLoading(false)
}
}, [appId, setComments, setCommentsLoading])
// Setup collaboration
useEffect(() => {
if (!appId) return
const unsubscribe = collaborationManager.onCommentsUpdate(() => {
loadComments()
if (activeCommentIdRef.current)
refreshActiveComment(activeCommentIdRef.current)
})
return unsubscribe
}, [appId, loadComments, refreshActiveComment])
useEffect(() => {
loadComments()
}, [loadComments])
const handleCommentSubmit = useCallback(async (content: string, mentionedUserIds: string[] = []) => {
if (!pendingComment) return
console.log('Submitting comment:', { appId, pendingComment, content, mentionedUserIds })
if (!appId) {
console.error('AppId is missing')
return
}
try {
// Convert screen position to flow position when submitting
const { screenToFlowPosition } = reactflow
const flowPosition = screenToFlowPosition({ x: pendingComment.x, y: pendingComment.y })
const newComment = await createWorkflowComment(appId, {
position_x: flowPosition.x,
position_y: flowPosition.y,
content,
mentioned_user_ids: mentionedUserIds,
})
console.log('Comment created successfully:', newComment)
collaborationManager.emitCommentsUpdate(appId)
await loadComments()
setPendingComment(null)
}
catch (error) {
console.error('Failed to create comment:', error)
setPendingComment(null)
}
}, [appId, pendingComment, setPendingComment, loadComments, reactflow])
const handleCommentCancel = useCallback(() => {
setPendingComment(null)
}, [setPendingComment])
useEffect(() => {
if (controlMode !== ControlMode.Comment)
setPendingComment(null)
}, [controlMode, setPendingComment])
const handleCommentIconClick = useCallback(async (comment: WorkflowCommentList) => {
setPendingComment(null)
activeCommentIdRef.current = comment.id
setActiveCommentId(comment.id)
const cachedDetail = commentDetailCacheRef.current[comment.id]
setActiveComment(cachedDetail || comment)
let horizontalOffsetPx = 220
const maxOffset = Math.max(0, (window.innerWidth / 2) - 60)
horizontalOffsetPx = Math.min(horizontalOffsetPx, maxOffset)
reactflow.setCenter(
comment.position_x + horizontalOffsetPx,
comment.position_y,
{ zoom: 1, duration: 600 },
)
if (!appId) return
setActiveCommentLoading(!cachedDetail)
try {
const detailResponse = await fetchWorkflowComment(appId, comment.id)
const detail = (detailResponse as any)?.data ?? detailResponse
commentDetailCacheRef.current = {
...commentDetailCacheRef.current,
[comment.id]: detail,
}
setCommentDetailCache(commentDetailCacheRef.current)
if (activeCommentIdRef.current === comment.id)
setActiveComment(detail)
}
catch (e) {
console.warn('Failed to load workflow comment detail', e)
}
finally {
setActiveCommentLoading(false)
}
}, [appId, reactflow, setActiveComment, setActiveCommentId, setActiveCommentLoading, setCommentDetailCache, setControlMode, setPendingComment])
const handleCommentResolve = useCallback(async (commentId: string) => {
if (!appId) return
setActiveCommentLoading(true)
try {
await resolveWorkflowComment(appId, commentId)
collaborationManager.emitCommentsUpdate(appId)
await refreshActiveComment(commentId)
await loadComments()
}
catch (error) {
console.error('Failed to resolve comment:', error)
}
finally {
setActiveCommentLoading(false)
}
}, [appId, loadComments, refreshActiveComment, setActiveCommentLoading])
const handleCommentDelete = useCallback(async (commentId: string) => {
if (!appId) return
setActiveCommentLoading(true)
try {
await deleteWorkflowComment(appId, commentId)
collaborationManager.emitCommentsUpdate(appId)
const updatedCache = { ...commentDetailCacheRef.current }
delete updatedCache[commentId]
commentDetailCacheRef.current = updatedCache
setCommentDetailCache(updatedCache)
const currentComments = comments.filter(c => c.id !== commentId)
const commentIndex = comments.findIndex(c => c.id === commentId)
const fallbackTarget = commentIndex >= 0 ? comments[commentIndex + 1] ?? comments[commentIndex - 1] : undefined
await loadComments()
if (fallbackTarget) {
handleCommentIconClick(fallbackTarget)
}
else if (currentComments.length > 0) {
const nextComment = currentComments[0]
handleCommentIconClick(nextComment)
}
else {
setActiveComment(null)
setActiveCommentId(null)
activeCommentIdRef.current = null
}
}
catch (error) {
console.error('Failed to delete comment:', error)
}
finally {
setActiveCommentLoading(false)
}
}, [appId, comments, handleCommentIconClick, loadComments, setActiveComment, setActiveCommentId, setActiveCommentLoading, setCommentDetailCache])
const handleCommentPositionUpdate = useCallback(async (commentId: string, position: { x: number; y: number }) => {
if (!appId) return
const targetComment = comments.find(c => c.id === commentId)
if (!targetComment) return
const nextPosition = {
position_x: position.x,
position_y: position.y,
}
const previousComments = comments
const updatedComments = comments.map(c =>
c.id === commentId
? { ...c, ...nextPosition }
: c,
)
setComments(updatedComments)
const cachedDetail = commentDetailCacheRef.current[commentId]
const updatedDetail = cachedDetail ? { ...cachedDetail, ...nextPosition } : null
if (updatedDetail) {
commentDetailCacheRef.current = {
...commentDetailCacheRef.current,
[commentId]: updatedDetail,
}
setCommentDetailCache(commentDetailCacheRef.current)
if (activeCommentIdRef.current === commentId)
setActiveComment(updatedDetail)
}
else if (activeComment?.id === commentId) {
setActiveComment({ ...activeComment, ...nextPosition })
}
try {
await updateWorkflowComment(appId, commentId, {
content: targetComment.content,
position_x: nextPosition.position_x,
position_y: nextPosition.position_y,
})
collaborationManager.emitCommentsUpdate(appId)
}
catch (error) {
console.error('Failed to update comment position:', error)
setComments(previousComments)
if (cachedDetail) {
commentDetailCacheRef.current = {
...commentDetailCacheRef.current,
[commentId]: cachedDetail,
}
setCommentDetailCache(commentDetailCacheRef.current)
if (activeCommentIdRef.current === commentId)
setActiveComment(cachedDetail)
}
else if (activeComment?.id === commentId) {
setActiveComment(activeComment)
}
}
}, [activeComment, appId, comments, setComments, setCommentDetailCache, setActiveComment])
const handleCommentReply = useCallback(async (commentId: string, content: string, mentionedUserIds: string[] = []) => {
if (!appId) return
const trimmed = content.trim()
if (!trimmed) return
setActiveCommentLoading(true)
try {
await createWorkflowCommentReply(appId, commentId, { content: trimmed, mentioned_user_ids: mentionedUserIds })
collaborationManager.emitCommentsUpdate(appId)
await refreshActiveComment(commentId)
await loadComments()
}
catch (error) {
console.error('Failed to create reply:', error)
}
finally {
setActiveCommentLoading(false)
}
}, [appId, loadComments, refreshActiveComment, setActiveCommentLoading])
const handleCommentReplyUpdate = useCallback(async (commentId: string, replyId: string, content: string, mentionedUserIds: string[] = []) => {
if (!appId) return
const trimmed = content.trim()
if (!trimmed) return
setActiveCommentLoading(true)
try {
await updateWorkflowCommentReply(appId, commentId, replyId, { content: trimmed, mentioned_user_ids: mentionedUserIds })
collaborationManager.emitCommentsUpdate(appId)
await refreshActiveComment(commentId)
await loadComments()
}
catch (error) {
console.error('Failed to update reply:', error)
}
finally {
setActiveCommentLoading(false)
}
}, [appId, loadComments, refreshActiveComment, setActiveCommentLoading])
const handleCommentReplyDelete = useCallback(async (commentId: string, replyId: string) => {
if (!appId) return
setActiveCommentLoading(true)
try {
await deleteWorkflowCommentReply(appId, commentId, replyId)
collaborationManager.emitCommentsUpdate(appId)
await refreshActiveComment(commentId)
await loadComments()
}
catch (error) {
console.error('Failed to delete reply:', error)
}
finally {
setActiveCommentLoading(false)
}
}, [appId, loadComments, refreshActiveComment, setActiveCommentLoading])
const handleCommentNavigate = useCallback((direction: 'prev' | 'next') => {
const currentId = activeCommentIdRef.current
if (!currentId) return
const idx = comments.findIndex(c => c.id === currentId)
if (idx === -1) return
const target = direction === 'prev' ? comments[idx - 1] : comments[idx + 1]
if (target)
handleCommentIconClick(target)
}, [comments, handleCommentIconClick])
const handleActiveCommentClose = useCallback(() => {
setActiveComment(null)
setActiveCommentLoading(false)
setActiveCommentId(null)
activeCommentIdRef.current = null
}, [setActiveComment, setActiveCommentId, setActiveCommentLoading])
const handleCreateComment = useCallback((mousePosition: { elementX: number; elementY: number }) => {
if (controlMode === ControlMode.Comment) {
console.log('Setting pending comment at screen position:', mousePosition)
setPendingComment({ x: mousePosition.elementX, y: mousePosition.elementY })
}
else {
console.log('Control mode is not Comment:', controlMode)
}
}, [controlMode, setPendingComment])
return {
comments,
loading,
pendingComment,
activeComment,
activeCommentLoading,
handleCommentSubmit,
handleCommentCancel,
handleCommentIconClick,
handleActiveCommentClose,
handleCommentResolve,
handleCommentDelete,
handleCommentNavigate,
handleCommentReply,
handleCommentReplyUpdate,
handleCommentReplyDelete,
handleCommentPositionUpdate,
refreshActiveComment,
handleCreateComment,
loadComments,
}
}

View File

@ -1,7 +1,7 @@
import {
useCallback,
} from 'react'
import { useReactFlow, useStoreApi } from 'reactflow'
import { useReactFlow } from 'reactflow'
import produce from 'immer'
import { useStore, useWorkflowStore } from '../store'
import {
@ -29,6 +29,7 @@ import { useNodesInteractionsWithoutSync } from './use-nodes-interactions-withou
import { useNodesSyncDraft } from './use-nodes-sync-draft'
import { WorkflowHistoryEvent, useWorkflowHistory } from './use-workflow-history'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useCollaborativeWorkflow } from '@/app/components/workflow/hooks/use-collaborative-workflow'
export const useWorkflowInteractions = () => {
const workflowStore = useWorkflowStore()
@ -71,31 +72,39 @@ export const useWorkflowMoveMode = () => {
handleSelectionCancel()
}, [getNodesReadOnly, setControlMode, handleSelectionCancel])
const handleModeComment = useCallback(() => {
if (getNodesReadOnly())
return
setControlMode(ControlMode.Comment)
handleSelectionCancel()
}, [getNodesReadOnly, setControlMode, handleSelectionCancel])
return {
handleModePointer,
handleModeHand,
handleModeComment,
}
}
export const useWorkflowOrganize = () => {
const workflowStore = useWorkflowStore()
const store = useStoreApi()
const reactflow = useReactFlow()
const { getNodesReadOnly } = useNodesReadOnly()
const { saveStateToHistory } = useWorkflowHistory()
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
const collaborativeWorkflow = useCollaborativeWorkflow()
const handleLayout = useCallback(async () => {
if (getNodesReadOnly())
return
workflowStore.setState({ nodeAnimation: true })
const {
getNodes,
nodes,
edges,
setNodes,
} = store.getState()
} = collaborativeWorkflow.getState()
const { setViewport } = reactflow
const nodes = getNodes()
const loopAndIterationNodes = nodes.filter(
node => (node.data.type === BlockEnum.Loop || node.data.type === BlockEnum.Iteration)
@ -232,7 +241,7 @@ export const useWorkflowOrganize = () => {
setTimeout(() => {
handleSyncWorkflowDraft()
})
}, [getNodesReadOnly, store, reactflow, workflowStore, handleSyncWorkflowDraft, saveStateToHistory])
}, [getNodesReadOnly, collaborativeWorkflow, reactflow, workflowStore, handleSyncWorkflowDraft, saveStateToHistory])
return {
handleLayout,

View File

@ -32,6 +32,7 @@ import { CUSTOM_NOTE_NODE } from '../note-node/constants'
import { findUsedVarNodes, getNodeOutputVars, updateNodeVars } from '../nodes/_base/components/variable/utils'
import { useAvailableBlocks } from './use-available-blocks'
import { useStore as useAppStore } from '@/app/components/app/store'
import { useCollaborativeWorkflow } from '@/app/components/workflow/hooks/use-collaborative-workflow'
import {
fetchAllBuiltInTools,
fetchAllCustomTools,
@ -50,25 +51,19 @@ export const useIsChatMode = () => {
}
export const useWorkflow = () => {
const store = useStoreApi()
const collaborativeWorkflow = useCollaborativeWorkflow()
const { getAvailableBlocks } = useAvailableBlocks()
const { nodesMap } = useNodesMetaData()
const store = useStoreApi()
const getNodeById = useCallback((nodeId: string) => {
const {
getNodes,
} = store.getState()
const nodes = getNodes()
const { nodes } = collaborativeWorkflow.getState()
const currentNode = nodes.find(node => node.id === nodeId)
return currentNode
}, [store])
}, [collaborativeWorkflow])
const getTreeLeafNodes = useCallback((nodeId: string) => {
const {
getNodes,
edges,
} = store.getState()
const nodes = getNodes()
const { nodes, edges } = collaborativeWorkflow.getState()
const currentNode = nodes.find(node => node.id === nodeId)
let startNodes = nodes.filter(node => nodesMap?.[node.data.type as BlockEnum]?.metaData.isStart) || []
@ -111,14 +106,11 @@ export const useWorkflow = () => {
return uniqBy(list, 'id').filter((item: Node) => {
return SUPPORT_OUTPUT_VARS_NODE.includes(item.data.type)
})
}, [store, nodesMap])
}, [collaborativeWorkflow, nodesMap])
const getBeforeNodesInSameBranch = useCallback((nodeId: string, newNodes?: Node[], newEdges?: Edge[]) => {
const {
getNodes,
edges,
} = store.getState()
const nodes = newNodes || getNodes()
const { nodes: oldNodes, edges } = collaborativeWorkflow.getState()
const nodes = newNodes || oldNodes
const currentNode = nodes.find(node => node.id === nodeId)
const list: Node[] = []
@ -161,14 +153,11 @@ export const useWorkflow = () => {
}
return []
}, [store])
}, [collaborativeWorkflow])
const getBeforeNodesInSameBranchIncludeParent = useCallback((nodeId: string, newNodes?: Node[], newEdges?: Edge[]) => {
const nodes = getBeforeNodesInSameBranch(nodeId, newNodes, newEdges)
const {
getNodes,
} = store.getState()
const allNodes = getNodes()
const { nodes: allNodes } = collaborativeWorkflow.getState()
const node = allNodes.find(n => n.id === nodeId)
const parentNodeId = node?.parentId
const parentNode = allNodes.find(n => n.id === parentNodeId)
@ -176,14 +165,10 @@ export const useWorkflow = () => {
nodes.push(parentNode)
return nodes
}, [getBeforeNodesInSameBranch, store])
}, [getBeforeNodesInSameBranch, collaborativeWorkflow])
const getAfterNodesInSameBranch = useCallback((nodeId: string) => {
const {
getNodes,
edges,
} = store.getState()
const nodes = getNodes()
const { nodes, edges } = collaborativeWorkflow.getState()
const currentNode = nodes.find(node => node.id === nodeId)!
if (!currentNode)
@ -207,40 +192,29 @@ export const useWorkflow = () => {
})
return uniqBy(list, 'id')
}, [store])
}, [collaborativeWorkflow])
const getBeforeNodeById = useCallback((nodeId: string) => {
const {
getNodes,
edges,
} = store.getState()
const nodes = getNodes()
const { nodes, edges } = collaborativeWorkflow.getState()
const node = nodes.find(node => node.id === nodeId)!
return getIncomers(node, nodes, edges)
}, [store])
}, [collaborativeWorkflow])
const getIterationNodeChildren = useCallback((nodeId: string) => {
const {
getNodes,
} = store.getState()
const nodes = getNodes()
const { nodes } = collaborativeWorkflow.getState()
return nodes.filter(node => node.parentId === nodeId)
}, [store])
}, [collaborativeWorkflow])
const getLoopNodeChildren = useCallback((nodeId: string) => {
const {
getNodes,
} = store.getState()
const nodes = getNodes()
const { nodes } = collaborativeWorkflow.getState()
return nodes.filter(node => node.parentId === nodeId)
}, [store])
}, [collaborativeWorkflow])
const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => {
const { getNodes, setNodes } = store.getState()
const allNodes = getNodes()
const { nodes: allNodes, setNodes } = collaborativeWorkflow.getState()
const affectedNodes = findUsedVarNodes(oldValeSelector, allNodes)
if (affectedNodes.length > 0) {
const newNodes = allNodes.map((node) => {
@ -251,7 +225,7 @@ export const useWorkflow = () => {
})
setNodes(newNodes)
}
}, [store])
}, [collaborativeWorkflow])
const isVarUsedInNodes = useCallback((varSelector: ValueSelector) => {
const nodeId = varSelector[0]
@ -262,11 +236,11 @@ export const useWorkflow = () => {
const removeUsedVarInNodes = useCallback((varSelector: ValueSelector) => {
const nodeId = varSelector[0]
const { getNodes, setNodes } = store.getState()
const { nodes, setNodes } = collaborativeWorkflow.getState()
const afterNodes = getAfterNodesInSameBranch(nodeId)
const effectNodes = findUsedVarNodes(varSelector, afterNodes)
if (effectNodes.length > 0) {
const newNodes = getNodes().map((node) => {
const newNodes = nodes.map((node) => {
if (effectNodes.find(n => n.id === node.id))
return updateNodeVars(node, varSelector, [])
@ -274,7 +248,7 @@ export const useWorkflow = () => {
})
setNodes(newNodes)
}
}, [getAfterNodesInSameBranch, store])
}, [getAfterNodesInSameBranch, collaborativeWorkflow])
const isNodeVarsUsedInNodes = useCallback((node: Node, isChatMode: boolean) => {
const outputVars = getNodeOutputVars(node, isChatMode)
@ -285,11 +259,7 @@ export const useWorkflow = () => {
}, [isVarUsedInNodes])
const getRootNodesById = useCallback((nodeId: string) => {
const {
getNodes,
edges,
} = store.getState()
const nodes = getNodes()
const { nodes, edges } = collaborativeWorkflow.getState()
const currentNode = nodes.find(node => node.id === nodeId)
const rootNodes: Node[] = []
@ -329,7 +299,7 @@ export const useWorkflow = () => {
return uniqBy(rootNodes, 'id')
return []
}, [store])
}, [collaborativeWorkflow])
const getStartNodes = useCallback((nodes: Node[], currentNode?: Node) => {
const { id, parentId } = currentNode || {}
@ -395,7 +365,7 @@ export const useWorkflow = () => {
}
return !hasCycle(targetNode)
}, [store, getAvailableBlocks])
}, [collaborativeWorkflow, getAvailableBlocks])
return {
getNodeById,
@ -498,13 +468,10 @@ export const useNodesReadOnly = () => {
}
export const useIsNodeInIteration = (iterationId: string) => {
const store = useStoreApi()
const collaborativeWorkflow = useCollaborativeWorkflow()
const isNodeInIteration = useCallback((nodeId: string) => {
const {
getNodes,
} = store.getState()
const nodes = getNodes()
const { nodes } = collaborativeWorkflow.getState()
const node = nodes.find(node => node.id === nodeId)
if (!node)
@ -514,20 +481,17 @@ export const useIsNodeInIteration = (iterationId: string) => {
return true
return false
}, [iterationId, store])
}, [iterationId, collaborativeWorkflow])
return {
isNodeInIteration,
}
}
export const useIsNodeInLoop = (loopId: string) => {
const store = useStoreApi()
const collaborativeWorkflow = useCollaborativeWorkflow()
const isNodeInLoop = useCallback((nodeId: string) => {
const {
getNodes,
} = store.getState()
const nodes = getNodes()
const { nodes } = collaborativeWorkflow.getState()
const node = nodes.find(node => node.id === nodeId)
if (!node)
@ -537,7 +501,7 @@ export const useIsNodeInLoop = (loopId: string) => {
return true
return false
}, [loopId, store])
}, [loopId, collaborativeWorkflow])
return {
isNodeInLoop,
}

View File

@ -2,6 +2,7 @@
import type { FC } from 'react'
import {
Fragment,
memo,
useCallback,
useEffect,
@ -9,6 +10,7 @@ import {
useRef,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import { setAutoFreeze } from 'immer'
import {
useEventListener,
@ -67,11 +69,15 @@ import CustomEdge from './custom-edge'
import CustomConnectionLine from './custom-connection-line'
import HelpLine from './help-line'
import CandidateNode from './candidate-node'
import CommentManager from './comment-manager'
import PanelContextmenu from './panel-contextmenu'
import NodeContextmenu from './node-contextmenu'
import SelectionContextmenu from './selection-contextmenu'
import SyncingDataModal from './syncing-data-modal'
import { setupScrollToNodeListener } from './utils/node-navigation'
import { CommentCursor, CommentIcon, CommentInput, CommentThread } from './comment'
import { useWorkflowComment } from './hooks/use-workflow-comment'
import UserCursors from './collaboration/components/user-cursors'
import {
useStore,
useWorkflowStore,
@ -115,6 +121,9 @@ export type WorkflowProps = {
viewport?: Viewport
children?: React.ReactNode
onWorkflowDataUpdate?: (v: any) => void
cursors?: Record<string, any>
myUserId?: string | null
onlineUsers?: any[]
}
export const Workflow: FC<WorkflowProps> = memo(({
nodes: originalNodes,
@ -122,10 +131,14 @@ export const Workflow: FC<WorkflowProps> = memo(({
viewport,
children,
onWorkflowDataUpdate,
cursors,
myUserId,
onlineUsers,
}) => {
const workflowContainerRef = useRef<HTMLDivElement>(null)
const workflowStore = useWorkflowStore()
const reactflow = useReactFlow()
const [isMouseOverCanvas, setIsMouseOverCanvas] = useState(false)
const [nodes, setNodes] = useNodesState(originalNodes)
const [edges, setEdges] = useEdgesState(originalEdges)
const controlMode = useStore(s => s.controlMode)
@ -170,6 +183,26 @@ export const Workflow: FC<WorkflowProps> = memo(({
const { workflowReadOnly } = useWorkflowReadOnly()
const { nodesReadOnly } = useNodesReadOnly()
const { eventEmitter } = useEventEmitterContextContext()
const {
comments,
pendingComment,
activeComment,
activeCommentLoading,
handleCommentSubmit,
handleCommentCancel,
handleCommentIconClick,
handleActiveCommentClose,
handleCommentResolve,
handleCommentDelete,
handleCommentNavigate,
handleCommentReply,
handleCommentReplyUpdate,
handleCommentReplyDelete,
handleCommentPositionUpdate,
} = useWorkflowComment()
const showUserComments = useStore(s => s.showUserComments)
const showUserCursors = useStore(s => s.showUserCursors)
const { t } = useTranslation()
eventEmitter?.useSubscription((v: any) => {
if (v.type === WORKFLOW_DATA_UPDATE) {
@ -210,6 +243,33 @@ export const Workflow: FC<WorkflowProps> = memo(({
setTimeout(() => handleRefreshWorkflowDraft(), 500)
}, [syncWorkflowDraftWhenPageClose, handleRefreshWorkflowDraft])
// Optimized comment deletion using showConfirm
const handleCommentDeleteClick = useCallback((commentId: string) => {
if (!showConfirm) {
setShowConfirm({
title: t('workflow.comments.confirm.deleteThreadTitle'),
desc: t('workflow.comments.confirm.deleteThreadDesc'),
onConfirm: async () => {
await handleCommentDelete(commentId)
setShowConfirm(undefined)
},
})
}
}, [showConfirm, setShowConfirm, handleCommentDelete, t])
const handleCommentReplyDeleteClick = useCallback((commentId: string, replyId: string) => {
if (!showConfirm) {
setShowConfirm({
title: t('workflow.comments.confirm.deleteReplyTitle'),
desc: t('workflow.comments.confirm.deleteReplyDesc'),
onConfirm: async () => {
await handleCommentReplyDelete(commentId, replyId)
setShowConfirm(undefined)
},
})
}
}, [showConfirm, setShowConfirm, handleCommentReplyDelete, t])
useEffect(() => {
document.addEventListener('visibilitychange', handleSyncWorkflowDraftWhenPageClose)
@ -240,6 +300,9 @@ export const Workflow: FC<WorkflowProps> = memo(({
elementY: e.clientY - containerClientRect.top,
},
})
const target = e.target as HTMLElement
const onPane = !!target?.closest('.react-flow__pane')
setIsMouseOverCanvas(onPane)
}
})
const { handleFetchAllTools } = useFetchToolsData()
@ -355,6 +418,7 @@ export const Workflow: FC<WorkflowProps> = memo(({
>
<SyncingDataModal />
<CandidateNode />
<CommentManager />
<div
className='pointer-events-none absolute left-0 top-0 z-10 flex w-12 items-center justify-center p-1 pl-2'
style={{ height: controlHeight }}
@ -366,23 +430,75 @@ export const Workflow: FC<WorkflowProps> = memo(({
<NodeContextmenu />
<SelectionContextmenu />
<HelpLine />
{
!!showConfirm && (
<Confirm
isShow
onCancel={() => setShowConfirm(undefined)}
onConfirm={showConfirm.onConfirm}
title={showConfirm.title}
content={showConfirm.desc}
{!!showConfirm && (
<Confirm
isShow
onCancel={() => setShowConfirm(undefined)}
onConfirm={showConfirm.onConfirm}
title={showConfirm.title}
content={showConfirm.desc}
/>
)}
{controlMode === ControlMode.Comment && isMouseOverCanvas && (
<CommentCursor />
)}
{pendingComment && (
<CommentInput
position={pendingComment}
onSubmit={handleCommentSubmit}
onCancel={handleCommentCancel}
/>
)}
{comments.map((comment, index) => {
const isActive = activeComment?.id === comment.id
if (isActive && activeComment) {
const canGoPrev = index > 0
const canGoNext = index < comments.length - 1
return (
<Fragment key={comment.id}>
<CommentIcon
key={`${comment.id}-icon`}
comment={comment}
onClick={() => handleCommentIconClick(comment)}
isActive={true}
onPositionUpdate={position => handleCommentPositionUpdate(comment.id, position)}
/>
<CommentThread
key={`${comment.id}-thread`}
comment={activeComment}
loading={activeCommentLoading}
onClose={handleActiveCommentClose}
onResolve={() => handleCommentResolve(comment.id)}
onDelete={() => handleCommentDeleteClick(comment.id)}
onPrev={canGoPrev ? () => handleCommentNavigate('prev') : undefined}
onNext={canGoNext ? () => handleCommentNavigate('next') : undefined}
onReply={(content, ids) => handleCommentReply(comment.id, content, ids ?? [])}
onReplyEdit={(replyId, content, ids) => handleCommentReplyUpdate(comment.id, replyId, content, ids ?? [])}
onReplyDelete={replyId => handleCommentReplyDeleteClick(comment.id, replyId)}
canGoPrev={canGoPrev}
canGoNext={canGoNext}
/>
</Fragment>
)
}
return (showUserComments || controlMode === ControlMode.Comment) ? (
<CommentIcon
key={comment.id}
comment={comment}
onClick={() => handleCommentIconClick(comment)}
onPositionUpdate={position => handleCommentPositionUpdate(comment.id, position)}
/>
)
}
) : null
})}
{children}
<ReactFlow
nodeTypes={nodeTypes}
edgeTypes={edgeTypes}
nodes={nodes}
edges={edges}
className={controlMode === ControlMode.Comment ? 'comment-mode-flow' : ''}
onNodeDragStart={handleNodeDragStart}
onNodeDrag={handleNodeDrag}
onNodeDragStop={handleNodeDragStop}
@ -428,6 +544,13 @@ export const Workflow: FC<WorkflowProps> = memo(({
className="bg-workflow-canvas-workflow-bg"
color='var(--color-workflow-canvas-workflow-dot-color)'
/>
{showUserCursors && cursors && (
<UserCursors
cursors={cursors}
myUserId={myUserId || null}
onlineUsers={onlineUsers || []}
/>
)}
</ReactFlow>
</div>
)
@ -435,14 +558,25 @@ export const Workflow: FC<WorkflowProps> = memo(({
type WorkflowWithInnerContextProps = WorkflowProps & {
hooksStore?: Partial<HooksStoreShape>
cursors?: Record<string, any>
myUserId?: string | null
onlineUsers?: any[]
}
export const WorkflowWithInnerContext = memo(({
hooksStore,
cursors,
myUserId,
onlineUsers,
...restProps
}: WorkflowWithInnerContextProps) => {
return (
<HooksStoreContextProvider {...hooksStore}>
<Workflow {...restProps} />
<Workflow
{...restProps}
cursors={cursors}
myUserId={myUserId}
onlineUsers={onlineUsers}
/>
</HooksStoreContextProvider>
)
})

View File

@ -75,6 +75,10 @@ import { DataSourceClassification } from '@/app/components/workflow/nodes/data-s
import { useModalContext } from '@/context/modal-context'
import DataSourceBeforeRunForm from '@/app/components/workflow/nodes/data-source/before-run-form'
import useInspectVarsCrud from '@/app/components/workflow/hooks/use-inspect-vars-crud'
import { useCollaboration } from '@/app/components/workflow/collaboration/hooks/use-collaboration'
import { collaborationManager } from '@/app/components/workflow/collaboration/core/collaboration-manager'
import { useAppContext } from '@/context/app-context'
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
const getCustomRunForm = (params: CustomRunFormProps): React.JSX.Element => {
const nodeType = params.payload.type
@ -97,11 +101,51 @@ const BasePanel: FC<BasePanelProps> = ({
children,
}) => {
const { t } = useTranslation()
const appId = useStore(s => s.appId)
const { userProfile } = useAppContext()
const { isConnected, nodePanelPresence } = useCollaboration(appId as string)
const { showMessageLogModal } = useAppStore(useShallow(state => ({
showMessageLogModal: state.showMessageLogModal,
})))
const isSingleRunning = data._singleRunningStatus === NodeRunningStatus.Running
const currentUserPresence = useMemo(() => {
const userId = userProfile?.id || ''
const username = userProfile?.name || userProfile?.email || 'User'
const avatar = userProfile?.avatar_url || userProfile?.avatar || null
return {
userId,
username,
avatar,
}
}, [userProfile?.avatar, userProfile?.avatar_url, userProfile?.email, userProfile?.id, userProfile?.name])
useEffect(() => {
if (!isConnected || !currentUserPresence.userId)
return
collaborationManager.emitNodePanelPresence(id, true, currentUserPresence)
return () => {
collaborationManager.emitNodePanelPresence(id, false, currentUserPresence)
}
}, [id, isConnected, currentUserPresence])
const viewingUsers = useMemo(() => {
const presence = nodePanelPresence?.[id]
if (!presence)
return []
return Object.values(presence)
.filter(viewer => viewer.userId && viewer.userId !== currentUserPresence.userId)
.map(viewer => ({
id: viewer.userId,
name: viewer.username,
avatar_url: viewer.avatar || null,
}))
}, [currentUserPresence.userId, id, nodePanelPresence])
const showSingleRunPanel = useStore(s => s.showSingleRunPanel)
const workflowCanvasWidth = useStore(s => s.workflowCanvasWidth)
const nodePanelWidth = useStore(s => s.nodePanelWidth)
@ -393,6 +437,15 @@ const BasePanel: FC<BasePanelProps> = ({
value={data.title || ''}
onBlur={handleTitleBlur}
/>
{viewingUsers.length > 0 && (
<div className='ml-3 shrink-0'>
<UserAvatarList
users={viewingUsers}
maxVisible={3}
size={24}
/>
</div>
)}
<div className='flex shrink-0 items-center text-text-tertiary'>
{
isSupportSingleRun && !nodesReadOnly && (

View File

@ -47,6 +47,10 @@ import BlockIcon from '@/app/components/workflow/block-icon'
import Tooltip from '@/app/components/base/tooltip'
import useInspectVarsCrud from '../../hooks/use-inspect-vars-crud'
import { ToolTypeEnum } from '../../block-selector/types'
import { UserAvatarList } from '@/app/components/base/user-avatar-list'
import { useAppContext } from '@/context/app-context'
import { useStore } from '@/app/components/workflow/store'
import { useCollaboration } from '@/app/components/workflow/collaboration/hooks/use-collaboration'
type BaseNodeProps = {
children: ReactElement
@ -65,6 +69,35 @@ const BaseNode: FC<BaseNodeProps> = ({
const { handleNodeIterationChildSizeChange } = useNodeIterationInteractions()
const { handleNodeLoopChildSizeChange } = useNodeLoopInteractions()
const toolIcon = useToolIcon(data)
const { userProfile } = useAppContext()
const appId = useStore(s => s.appId)
const { nodePanelPresence } = useCollaboration(appId as string)
const currentUserPresence = useMemo(() => {
const userId = userProfile?.id || ''
const username = userProfile?.name || userProfile?.email || 'User'
const avatar = userProfile?.avatar_url || userProfile?.avatar || null
return {
userId,
username,
avatar,
}
}, [userProfile?.avatar, userProfile?.avatar_url, userProfile?.email, userProfile?.id, userProfile?.name])
const viewingUsers = useMemo(() => {
const presence = nodePanelPresence?.[id]
if (!presence)
return []
return Object.values(presence)
.filter(viewer => viewer.userId && viewer.userId !== currentUserPresence.userId)
.map(viewer => ({
id: viewer.userId,
name: viewer.username,
avatar_url: viewer.avatar || null,
}))
}, [currentUserPresence.userId, id, nodePanelPresence])
useEffect(() => {
if (nodeRef.current && data.selected && data.isInIteration) {
@ -237,7 +270,7 @@ const BaseNode: FC<BaseNodeProps> = ({
/>
<div
title={data.title}
className='system-sm-semibold-uppercase mr-1 flex grow items-center truncate text-text-primary'
className='system-sm-semibold-uppercase mr-1 flex grow items-center justify-between truncate text-text-primary'
>
<div>
{data.title}
@ -258,6 +291,15 @@ const BaseNode: FC<BaseNodeProps> = ({
</Tooltip>
)
}
{viewingUsers.length > 0 && (
<div className='ml-3 shrink-0'>
<UserAvatarList
users={viewingUsers}
maxVisible={3}
size={24}
/>
</div>
)}
</div>
{
data._iterationLength && data._iterationIndex && data._runningStatus === NodeRunningStatus.Running && (

Some files were not shown because too many files have changed in this diff Show More