Merge branch 'feat/iteration-node' into deploy/dev

This commit is contained in:
CodingOnStar 2025-10-24 15:31:04 +08:00
commit 3d08c79c3e
20 changed files with 524 additions and 285 deletions

View File

@ -32,6 +32,7 @@ from libs.token import (
clear_csrf_token_from_cookie, clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie, clear_refresh_token_from_cookie,
extract_access_token, extract_access_token,
extract_refresh_token,
set_access_token_to_cookie, set_access_token_to_cookie,
set_csrf_token_to_cookie, set_csrf_token_to_cookie,
set_refresh_token_to_cookie, set_refresh_token_to_cookie,
@ -273,7 +274,7 @@ class EmailCodeLoginApi(Resource):
class RefreshTokenApi(Resource): class RefreshTokenApi(Resource):
def post(self): def post(self):
# Get refresh token from cookie instead of request body # Get refresh token from cookie instead of request body
refresh_token = request.cookies.get("refresh_token") refresh_token = extract_refresh_token(request)
if not refresh_token: if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401 return {"result": "fail", "message": "No refresh token provided"}, 401

View File

@ -193,15 +193,19 @@ class QuestionClassifierNode(Node):
finish_reason = event.finish_reason finish_reason = event.finish_reason
break break
category_name = node_data.classes[0].name rendered_classes = [
category_id = node_data.classes[0].id c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes
]
category_name = rendered_classes[0].name
category_id = rendered_classes[0].id
if "<think>" in result_text: if "<think>" in result_text:
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE) result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
result_text_json = parse_and_check_json_markdown(result_text, []) result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n')) # result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json: if "category_name" in result_text_json and "category_id" in result_text_json:
category_id_result = result_text_json["category_id"] category_id_result = result_text_json["category_id"]
classes = node_data.classes classes = rendered_classes
classes_map = {class_.id: class_.name for class_ in classes} classes_map = {class_.id: class_.name for class_ in classes}
category_ids = [_class.id for _class in classes] category_ids = [_class.id for _class in classes]
if category_id_result in category_ids: if category_id_result in category_ids:

View File

@ -5,6 +5,7 @@ import json
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from collections.abc import Mapping as TypingMapping from collections.abc import Mapping as TypingMapping
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Protocol from typing import Any, Protocol
from pydantic.json import pydantic_encoder from pydantic.json import pydantic_encoder
@ -106,6 +107,23 @@ class GraphProtocol(Protocol):
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ... def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
@dataclass(slots=True)
class _GraphRuntimeStateSnapshot:
"""Immutable view of a serialized runtime state snapshot."""
start_at: float
total_tokens: int
node_run_steps: int
llm_usage: LLMUsage
outputs: dict[str, Any]
variable_pool: VariablePool
has_variable_pool: bool
ready_queue_dump: str | None
graph_execution_dump: str | None
response_coordinator_dump: str | None
paused_nodes: tuple[str, ...]
class GraphRuntimeState: class GraphRuntimeState:
"""Mutable runtime state shared across graph execution components.""" """Mutable runtime state shared across graph execution components."""
@ -293,69 +311,28 @@ class GraphRuntimeState:
return json.dumps(snapshot, default=pydantic_encoder) return json.dumps(snapshot, default=pydantic_encoder)
def loads(self, data: str | Mapping[str, Any]) -> None: @classmethod
def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState:
"""Restore runtime state from a serialized snapshot.""" """Restore runtime state from a serialized snapshot."""
payload: dict[str, Any] snapshot = cls._parse_snapshot_payload(data)
if isinstance(data, str):
payload = json.loads(data)
else:
payload = dict(data)
version = payload.get("version") state = cls(
if version != "1.0": variable_pool=snapshot.variable_pool,
raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}") start_at=snapshot.start_at,
total_tokens=snapshot.total_tokens,
llm_usage=snapshot.llm_usage,
outputs=snapshot.outputs,
node_run_steps=snapshot.node_run_steps,
)
state._apply_snapshot(snapshot)
return state
self._start_at = float(payload.get("start_at", 0.0)) def loads(self, data: str | Mapping[str, Any]) -> None:
total_tokens = int(payload.get("total_tokens", 0)) """Restore runtime state from a serialized snapshot (legacy API)."""
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
self._total_tokens = total_tokens
node_run_steps = int(payload.get("node_run_steps", 0)) snapshot = self._parse_snapshot_payload(data)
if node_run_steps < 0: self._apply_snapshot(snapshot)
raise ValueError("node_run_steps must be non-negative")
self._node_run_steps = node_run_steps
llm_usage_payload = payload.get("llm_usage", {})
self._llm_usage = LLMUsage.model_validate(llm_usage_payload)
self._outputs = deepcopy(payload.get("outputs", {}))
variable_pool_payload = payload.get("variable_pool")
if variable_pool_payload is not None:
self._variable_pool = VariablePool.model_validate(variable_pool_payload)
ready_queue_payload = payload.get("ready_queue")
if ready_queue_payload is not None:
self._ready_queue = self._build_ready_queue()
self._ready_queue.loads(ready_queue_payload)
else:
self._ready_queue = None
graph_execution_payload = payload.get("graph_execution")
self._graph_execution = None
self._pending_graph_execution_workflow_id = None
if graph_execution_payload is not None:
try:
execution_payload = json.loads(graph_execution_payload)
self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
except (json.JSONDecodeError, TypeError, AttributeError):
self._pending_graph_execution_workflow_id = None
self.graph_execution.loads(graph_execution_payload)
response_payload = payload.get("response_coordinator")
if response_payload is not None:
if self._graph is not None:
self.response_coordinator.loads(response_payload)
else:
self._pending_response_coordinator_dump = response_payload
else:
self._pending_response_coordinator_dump = None
self._response_coordinator = None
paused_nodes_payload = payload.get("paused_nodes", [])
self._paused_nodes = set(map(str, paused_nodes_payload))
def register_paused_node(self, node_id: str) -> None: def register_paused_node(self, node_id: str) -> None:
"""Record a node that should resume when execution is continued.""" """Record a node that should resume when execution is continued."""
@ -391,3 +368,106 @@ class GraphRuntimeState:
module = importlib.import_module("core.workflow.graph_engine.response_coordinator") module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
coordinator_cls = module.ResponseStreamCoordinator coordinator_cls = module.ResponseStreamCoordinator
return coordinator_cls(variable_pool=self.variable_pool, graph=graph) return coordinator_cls(variable_pool=self.variable_pool, graph=graph)
# ------------------------------------------------------------------
# Snapshot helpers
# ------------------------------------------------------------------
@classmethod
def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot:
payload: dict[str, Any]
if isinstance(data, str):
payload = json.loads(data)
else:
payload = dict(data)
version = payload.get("version")
if version != "1.0":
raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")
start_at = float(payload.get("start_at", 0.0))
total_tokens = int(payload.get("total_tokens", 0))
if total_tokens < 0:
raise ValueError("total_tokens must be non-negative")
node_run_steps = int(payload.get("node_run_steps", 0))
if node_run_steps < 0:
raise ValueError("node_run_steps must be non-negative")
llm_usage_payload = payload.get("llm_usage", {})
llm_usage = LLMUsage.model_validate(llm_usage_payload)
outputs_payload = deepcopy(payload.get("outputs", {}))
variable_pool_payload = payload.get("variable_pool")
has_variable_pool = variable_pool_payload is not None
variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool()
ready_queue_payload = payload.get("ready_queue")
graph_execution_payload = payload.get("graph_execution")
response_payload = payload.get("response_coordinator")
paused_nodes_payload = payload.get("paused_nodes", [])
return _GraphRuntimeStateSnapshot(
start_at=start_at,
total_tokens=total_tokens,
node_run_steps=node_run_steps,
llm_usage=llm_usage,
outputs=outputs_payload,
variable_pool=variable_pool,
has_variable_pool=has_variable_pool,
ready_queue_dump=ready_queue_payload,
graph_execution_dump=graph_execution_payload,
response_coordinator_dump=response_payload,
paused_nodes=tuple(map(str, paused_nodes_payload)),
)
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
self._start_at = snapshot.start_at
self._total_tokens = snapshot.total_tokens
self._node_run_steps = snapshot.node_run_steps
self._llm_usage = snapshot.llm_usage.model_copy()
self._outputs = deepcopy(snapshot.outputs)
if snapshot.has_variable_pool or self._variable_pool is None:
self._variable_pool = snapshot.variable_pool
self._restore_ready_queue(snapshot.ready_queue_dump)
self._restore_graph_execution(snapshot.graph_execution_dump)
self._restore_response_coordinator(snapshot.response_coordinator_dump)
self._paused_nodes = set(snapshot.paused_nodes)
def _restore_ready_queue(self, payload: str | None) -> None:
if payload is not None:
self._ready_queue = self._build_ready_queue()
self._ready_queue.loads(payload)
else:
self._ready_queue = None
def _restore_graph_execution(self, payload: str | None) -> None:
self._graph_execution = None
self._pending_graph_execution_workflow_id = None
if payload is None:
return
try:
execution_payload = json.loads(payload)
self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
except (json.JSONDecodeError, TypeError, AttributeError):
self._pending_graph_execution_workflow_id = None
self.graph_execution.loads(payload)
def _restore_response_coordinator(self, payload: str | None) -> None:
if payload is None:
self._pending_response_coordinator_dump = None
self._response_coordinator = None
return
if self._graph is not None:
self.response_coordinator.loads(payload)
self._pending_response_coordinator_dump = None
return
self._pending_response_coordinator_dump = payload
self._response_coordinator = None

View File

@ -6,10 +6,11 @@ from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config from configs import dify_config
from constants import HEADER_NAME_APP_CODE
from dify_app import DifyApp from dify_app import DifyApp
from extensions.ext_database import db from extensions.ext_database import db
from libs.passport import PassportService from libs.passport import PassportService
from libs.token import extract_access_token from libs.token import extract_access_token, extract_webapp_passport
from models import Account, Tenant, TenantAccountJoin from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser from models.model import AppMCPServer, EndUser
from services.account_service import AccountService from services.account_service import AccountService
@ -61,14 +62,30 @@ def load_user_from_request(request_from_flask_login):
logged_in_account = AccountService.load_logged_in_account(account_id=user_id) logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account return logged_in_account
elif request.blueprint == "web": elif request.blueprint == "web":
decoded = PassportService().verify(auth_token) app_code = request.headers.get(HEADER_NAME_APP_CODE)
end_user_id = decoded.get("end_user_id") webapp_token = extract_webapp_passport(app_code, request) if app_code else None
if not end_user_id:
raise Unauthorized("Invalid Authorization token.") if webapp_token:
end_user = db.session.query(EndUser).where(EndUser.id == decoded["end_user_id"]).first() decoded = PassportService().verify(webapp_token)
if not end_user: end_user_id = decoded.get("end_user_id")
raise NotFound("End user not found.") if not end_user_id:
return end_user raise Unauthorized("Invalid Authorization token.")
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
else:
if not auth_token:
raise Unauthorized("Invalid Authorization token.")
decoded = PassportService().verify(auth_token)
end_user_id = decoded.get("end_user_id")
if end_user_id:
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
else:
raise Unauthorized("Invalid Authorization token for web API.")
elif request.blueprint == "mcp": elif request.blueprint == "mcp":
server_code = request.view_args.get("server_code") if request.view_args else None server_code = request.view_args.get("server_code") if request.view_args else None
if not server_code: if not server_code:

View File

@ -38,9 +38,6 @@ def _real_cookie_name(cookie_name: str) -> str:
def _try_extract_from_header(request: Request) -> str | None: def _try_extract_from_header(request: Request) -> str | None:
"""
Try to extract access token from header
"""
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
if auth_header: if auth_header:
if " " not in auth_header: if " " not in auth_header:
@ -55,27 +52,19 @@ def _try_extract_from_header(request: Request) -> str | None:
return None return None
def extract_refresh_token(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN))
def extract_csrf_token(request: Request) -> str | None: def extract_csrf_token(request: Request) -> str | None:
"""
Try to extract CSRF token from header or cookie.
"""
return request.headers.get(HEADER_NAME_CSRF_TOKEN) return request.headers.get(HEADER_NAME_CSRF_TOKEN)
def extract_csrf_token_from_cookie(request: Request) -> str | None: def extract_csrf_token_from_cookie(request: Request) -> str | None:
"""
Try to extract CSRF token from cookie.
"""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN)) return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
def extract_access_token(request: Request) -> str | None: def extract_access_token(request: Request) -> str | None:
"""
Try to extract access token from cookie, header or params.
Access token is either for console session or webapp passport exchange.
"""
def _try_extract_from_cookie(request: Request) -> str | None: def _try_extract_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN)) return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
@ -83,20 +72,10 @@ def extract_access_token(request: Request) -> str | None:
def extract_webapp_access_token(request: Request) -> str | None: def extract_webapp_access_token(request: Request) -> str | None:
"""
Try to extract webapp access token from cookie, then header.
"""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request) return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
def extract_webapp_passport(app_code: str, request: Request) -> str | None: def extract_webapp_passport(app_code: str, request: Request) -> str | None:
"""
Try to extract app token from header or params.
Webapp access token (part of passport) is only used for webapp session.
"""
def _try_extract_passport_token_from_cookie(request: Request) -> str | None: def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))

View File

@ -82,54 +82,51 @@ class AudioService:
message_id: str | None = None, message_id: str | None = None,
is_draft: bool = False, is_draft: bool = False,
): ):
from app import app
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False): def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
with app.app_context(): if voice is None:
if voice is None: if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if is_draft:
if is_draft: workflow = WorkflowService().get_draft_workflow(app_model=app_model)
workflow = WorkflowService().get_draft_workflow(app_model=app_model) else:
else: workflow = app_model.workflow
workflow = app_model.workflow if (
if ( workflow is None
workflow is None or "text_to_speech" not in workflow.features_dict
or "text_to_speech" not in workflow.features_dict or not workflow.features_dict["text_to_speech"].get("enabled")
or not workflow.features_dict["text_to_speech"].get("enabled") ):
): raise ValueError("TTS is not enabled")
voice = workflow.features_dict["text_to_speech"].get("voice")
else:
if not is_draft:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
if not text_to_speech_dict.get("enabled"):
raise ValueError("TTS is not enabled") raise ValueError("TTS is not enabled")
voice = workflow.features_dict["text_to_speech"].get("voice") voice = text_to_speech_dict.get("voice")
else:
if not is_draft:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
if not text_to_speech_dict.get("enabled"): model_manager = ModelManager()
raise ValueError("TTS is not enabled") model_instance = model_manager.get_default_model_instance(
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
voice = text_to_speech_dict.get("voice") )
try:
model_manager = ModelManager() if not voice:
model_instance = model_manager.get_default_model_instance( voices = model_instance.get_tts_voices()
tenant_id=app_model.tenant_id, model_type=ModelType.TTS if voices:
) voice = voices[0].get("value")
try: if not voice:
if not voice:
voices = model_instance.get_tts_voices()
if voices:
voice = voices[0].get("value")
if not voice:
raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.") raise ValueError("Sorry, no voice available.")
else:
raise ValueError("Sorry, no voice available.")
return model_instance.invoke_tts( return model_instance.invoke_tts(
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
) )
except Exception as e: except Exception as e:
raise e raise e
if message_id: if message_id:
try: try:

View File

@ -283,7 +283,7 @@ class VariableTruncator:
break break
remaining_budget = target_size - used_size remaining_budget = target_size - used_size
if item is None or isinstance(item, (str, list, dict, bool, int, float)): if item is None or isinstance(item, (str, list, dict, bool, int, float, UpdatedVariable)):
part_result = self._truncate_json_primitives(item, remaining_budget) part_result = self._truncate_json_primitives(item, remaining_budget)
else: else:
raise UnknownTypeError(f"got unknown type {type(item)} in array truncation") raise UnknownTypeError(f"got unknown type {type(item)} in array truncation")
@ -373,6 +373,11 @@ class VariableTruncator:
return _PartResult(truncated_obj, used_size, truncated) return _PartResult(truncated_obj, used_size, truncated)
@overload
def _truncate_json_primitives(
self, val: UpdatedVariable, target_size: int
) -> _PartResult[Mapping[str, object]]: ...
@overload @overload
def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ... def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ...

View File

@ -8,6 +8,18 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
class StubCoordinator:
def __init__(self) -> None:
self.state = "initial"
def dumps(self) -> str:
return json.dumps({"state": self.state})
def loads(self, data: str) -> None:
payload = json.loads(data)
self.state = payload["state"]
class TestGraphRuntimeState: class TestGraphRuntimeState:
def test_property_getters_and_setters(self): def test_property_getters_and_setters(self):
# FIXME(-LAN-): Mock VariablePool if needed # FIXME(-LAN-): Mock VariablePool if needed
@ -191,17 +203,6 @@ class TestGraphRuntimeState:
graph_execution.exceptions_count = 4 graph_execution.exceptions_count = 4
graph_execution.started = True graph_execution.started = True
class StubCoordinator:
def __init__(self) -> None:
self.state = "initial"
def dumps(self) -> str:
return json.dumps({"state": self.state})
def loads(self, data: str) -> None:
payload = json.loads(data)
self.state = payload["state"]
mock_graph = MagicMock() mock_graph = MagicMock()
stub = StubCoordinator() stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub): with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
@ -211,8 +212,7 @@ class TestGraphRuntimeState:
snapshot = state.dumps() snapshot = state.dumps()
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) restored = GraphRuntimeState.from_snapshot(snapshot)
restored.loads(snapshot)
assert restored.total_tokens == 10 assert restored.total_tokens == 10
assert restored.node_run_steps == 3 assert restored.node_run_steps == 3
@ -235,3 +235,47 @@ class TestGraphRuntimeState:
restored.attach_graph(mock_graph) restored.attach_graph(mock_graph)
assert new_stub.state == "configured" assert new_stub.state == "configured"
def test_loads_rehydrates_existing_instance(self):
variable_pool = VariablePool()
variable_pool.add(("node", "key"), "value")
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
state.total_tokens = 7
state.node_run_steps = 2
state.set_output("foo", "bar")
state.ready_queue.put("node-1")
execution = state.graph_execution
execution.workflow_id = "wf-456"
execution.started = True
mock_graph = MagicMock()
original_stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub):
state.attach_graph(mock_graph)
original_stub.state = "configured"
snapshot = state.dumps()
new_stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
restored.attach_graph(mock_graph)
restored.loads(snapshot)
assert restored.total_tokens == 7
assert restored.node_run_steps == 2
assert restored.get_output("foo") == "bar"
assert restored.ready_queue.qsize() == 1
assert restored.ready_queue.get(timeout=0.01) == "node-1"
restored_segment = restored.variable_pool.get(("node", "key"))
assert restored_segment is not None
assert restored_segment.value == "value"
restored_execution = restored.graph_execution
assert restored_execution.workflow_id == "wf-456"
assert restored_execution.started is True
assert new_stub.state == "configured"

View File

@ -265,16 +265,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
# Sets the maximum allowed duration of any statement before termination. # Sets the maximum allowed duration of any statement before termination.
# Default is 60000 milliseconds. # Default is 0 (no timeout).
# #
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT # Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
POSTGRES_STATEMENT_TIMEOUT=60000 # A value of 0 prevents the server from timing out statements.
POSTGRES_STATEMENT_TIMEOUT=0
# Sets the maximum allowed duration of any idle in-transaction session before termination. # Sets the maximum allowed duration of any idle in-transaction session before termination.
# Default is 60000 milliseconds. # Default is 0 (no timeout).
# #
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT # Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000 # A value of 0 prevents the server from terminating idle sessions.
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
# ------------------------------ # ------------------------------
# Redis Configuration # Redis Configuration

View File

@ -115,8 +115,8 @@ services:
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}' -c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}' -c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
volumes: volumes:
- ./volumes/db/data:/var/lib/postgresql/data - ./volumes/db/data:/var/lib/postgresql/data
healthcheck: healthcheck:

View File

@ -15,8 +15,8 @@ services:
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}' -c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}' -c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
volumes: volumes:
- ${PGDATA_HOST_VOLUME:-./volumes/db/data}:/var/lib/postgresql/data - ${PGDATA_HOST_VOLUME:-./volumes/db/data}:/var/lib/postgresql/data
ports: ports:

View File

@ -68,8 +68,8 @@ x-shared-env: &shared-api-worker-env
POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB} POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB}
POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB} POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}
POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB} POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}
POSTGRES_STATEMENT_TIMEOUT: ${POSTGRES_STATEMENT_TIMEOUT:-60000} POSTGRES_STATEMENT_TIMEOUT: ${POSTGRES_STATEMENT_TIMEOUT:-0}
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT: ${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000} POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT: ${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}
REDIS_HOST: ${REDIS_HOST:-redis} REDIS_HOST: ${REDIS_HOST:-redis}
REDIS_PORT: ${REDIS_PORT:-6379} REDIS_PORT: ${REDIS_PORT:-6379}
REDIS_USERNAME: ${REDIS_USERNAME:-} REDIS_USERNAME: ${REDIS_USERNAME:-}
@ -724,8 +724,8 @@ services:
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}' -c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}' -c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}' -c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}' -c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}' -c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
volumes: volumes:
- ./volumes/db/data:/var/lib/postgresql/data - ./volumes/db/data:/var/lib/postgresql/data
healthcheck: healthcheck:

View File

@ -41,16 +41,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
# Sets the maximum allowed duration of any statement before termination. # Sets the maximum allowed duration of any statement before termination.
# Default is 60000 milliseconds. # Default is 0 (no timeout).
# #
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT # Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
POSTGRES_STATEMENT_TIMEOUT=60000 # A value of 0 prevents the server from timing out statements.
POSTGRES_STATEMENT_TIMEOUT=0
# Sets the maximum allowed duration of any idle in-transaction session before termination. # Sets the maximum allowed duration of any idle in-transaction session before termination.
# Default is 60000 milliseconds. # Default is 0 (no timeout).
# #
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT # Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000 # A value of 0 prevents the server from terminating idle sessions.
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
# ----------------------------- # -----------------------------
# Environment Variables for redis Service # Environment Variables for redis Service

View File

@ -132,8 +132,6 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
importedVersion: imported_dsl_version ?? '', importedVersion: imported_dsl_version ?? '',
systemVersion: current_dsl_version ?? '', systemVersion: current_dsl_version ?? '',
}) })
if (onClose)
onClose()
setTimeout(() => { setTimeout(() => {
setShowErrorModal(true) setShowErrorModal(true)
}, 300) }, 300)

View File

@ -14,7 +14,6 @@ import timezone from 'dayjs/plugin/timezone'
import { createContext, useContext } from 'use-context-selector' import { createContext, useContext } from 'use-context-selector'
import { useShallow } from 'zustand/react/shallow' import { useShallow } from 'zustand/react/shallow'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
import type { ChatItemInTree } from '../../base/chat/types' import type { ChatItemInTree } from '../../base/chat/types'
import Indicator from '../../header/indicator' import Indicator from '../../header/indicator'
import VarPanel from './var-panel' import VarPanel from './var-panel'
@ -43,10 +42,6 @@ import cn from '@/utils/classnames'
import { noop } from 'lodash-es' import { noop } from 'lodash-es'
import PromptLogModal from '../../base/prompt-log-modal' import PromptLogModal from '../../base/prompt-log-modal'
type AppStoreState = ReturnType<typeof useAppStore.getState>
type ConversationListItem = ChatConversationGeneralDetail | CompletionConversationGeneralDetail
type ConversationSelection = ConversationListItem | { id: string; isPlaceholder?: true }
dayjs.extend(utc) dayjs.extend(utc)
dayjs.extend(timezone) dayjs.extend(timezone)
@ -206,7 +201,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
const { formatTime } = useTimestamp() const { formatTime } = useTimestamp()
const { onClose, appDetail } = useContext(DrawerContext) const { onClose, appDetail } = useContext(DrawerContext)
const { notify } = useContext(ToastContext) const { notify } = useContext(ToastContext)
const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow((state: AppStoreState) => ({ const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({
currentLogItem: state.currentLogItem, currentLogItem: state.currentLogItem,
setCurrentLogItem: state.setCurrentLogItem, setCurrentLogItem: state.setCurrentLogItem,
showMessageLogModal: state.showMessageLogModal, showMessageLogModal: state.showMessageLogModal,
@ -898,113 +893,20 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string }
const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh }) => { const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { formatTime } = useTimestamp() const { formatTime } = useTimestamp()
const router = useRouter()
const pathname = usePathname()
const searchParams = useSearchParams()
const conversationIdInUrl = searchParams.get('conversation_id') ?? undefined
const media = useBreakpoints() const media = useBreakpoints()
const isMobile = media === MediaType.mobile const isMobile = media === MediaType.mobile
const [showDrawer, setShowDrawer] = useState<boolean>(false) // Whether to display the chat details drawer const [showDrawer, setShowDrawer] = useState<boolean>(false) // Whether to display the chat details drawer
const [currentConversation, setCurrentConversation] = useState<ConversationSelection | undefined>() // Currently selected conversation const [currentConversation, setCurrentConversation] = useState<ChatConversationGeneralDetail | CompletionConversationGeneralDetail | undefined>() // Currently selected conversation
const closingConversationIdRef = useRef<string | null>(null)
const pendingConversationIdRef = useRef<string | null>(null)
const pendingConversationCacheRef = useRef<ConversationSelection | undefined>(undefined)
const isChatMode = appDetail.mode !== 'completion' // Whether the app is a chat app const isChatMode = appDetail.mode !== 'completion' // Whether the app is a chat app
const isChatflow = appDetail.mode === 'advanced-chat' // Whether the app is a chatflow app const isChatflow = appDetail.mode === 'advanced-chat' // Whether the app is a chatflow app
const { setShowPromptLogModal, setShowAgentLogModal, setShowMessageLogModal } = useAppStore(useShallow((state: AppStoreState) => ({ const { setShowPromptLogModal, setShowAgentLogModal, setShowMessageLogModal } = useAppStore(useShallow(state => ({
setShowPromptLogModal: state.setShowPromptLogModal, setShowPromptLogModal: state.setShowPromptLogModal,
setShowAgentLogModal: state.setShowAgentLogModal, setShowAgentLogModal: state.setShowAgentLogModal,
setShowMessageLogModal: state.setShowMessageLogModal, setShowMessageLogModal: state.setShowMessageLogModal,
}))) })))
const activeConversationId = conversationIdInUrl ?? pendingConversationIdRef.current ?? currentConversation?.id
const buildUrlWithConversation = useCallback((conversationId?: string) => {
const params = new URLSearchParams(searchParams.toString())
if (conversationId)
params.set('conversation_id', conversationId)
else
params.delete('conversation_id')
const queryString = params.toString()
return queryString ? `${pathname}?${queryString}` : pathname
}, [pathname, searchParams])
const handleRowClick = useCallback((log: ConversationListItem) => {
if (conversationIdInUrl === log.id) {
if (!showDrawer)
setShowDrawer(true)
if (!currentConversation || currentConversation.id !== log.id)
setCurrentConversation(log)
return
}
pendingConversationIdRef.current = log.id
pendingConversationCacheRef.current = log
if (!showDrawer)
setShowDrawer(true)
if (currentConversation?.id !== log.id)
setCurrentConversation(undefined)
router.push(buildUrlWithConversation(log.id), { scroll: false })
}, [buildUrlWithConversation, conversationIdInUrl, currentConversation, router, showDrawer])
const currentConversationId = currentConversation?.id
useEffect(() => {
if (!conversationIdInUrl) {
if (pendingConversationIdRef.current)
return
if (showDrawer || currentConversationId) {
setShowDrawer(false)
setCurrentConversation(undefined)
}
closingConversationIdRef.current = null
pendingConversationCacheRef.current = undefined
return
}
if (closingConversationIdRef.current === conversationIdInUrl)
return
if (pendingConversationIdRef.current === conversationIdInUrl)
pendingConversationIdRef.current = null
const matchedConversation = logs?.data?.find((item: ConversationListItem) => item.id === conversationIdInUrl)
const nextConversation: ConversationSelection = matchedConversation
?? pendingConversationCacheRef.current
?? { id: conversationIdInUrl, isPlaceholder: true }
if (!showDrawer)
setShowDrawer(true)
if (!currentConversation || currentConversation.id !== conversationIdInUrl || (matchedConversation && currentConversation !== matchedConversation))
setCurrentConversation(nextConversation)
if (pendingConversationCacheRef.current?.id === conversationIdInUrl || matchedConversation)
pendingConversationCacheRef.current = undefined
}, [conversationIdInUrl, currentConversation, isChatMode, logs?.data, showDrawer])
const onCloseDrawer = useCallback(() => {
onRefresh()
setShowDrawer(false)
setCurrentConversation(undefined)
setShowPromptLogModal(false)
setShowAgentLogModal(false)
setShowMessageLogModal(false)
pendingConversationIdRef.current = null
pendingConversationCacheRef.current = undefined
closingConversationIdRef.current = conversationIdInUrl ?? null
if (conversationIdInUrl)
router.replace(buildUrlWithConversation(), { scroll: false })
}, [buildUrlWithConversation, conversationIdInUrl, onRefresh, router, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal])
// Annotated data needs to be highlighted // Annotated data needs to be highlighted
const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => { const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => {
return ( return (
@ -1023,6 +925,15 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
) )
} }
const onCloseDrawer = () => {
onRefresh()
setShowDrawer(false)
setCurrentConversation(undefined)
setShowPromptLogModal(false)
setShowAgentLogModal(false)
setShowMessageLogModal(false)
}
if (!logs) if (!logs)
return <Loading /> return <Loading />
@ -1049,8 +960,11 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
const rightValue = get(log, isChatMode ? 'message_count' : 'message.answer') const rightValue = get(log, isChatMode ? 'message_count' : 'message.answer')
return <tr return <tr
key={log.id} key={log.id}
className={cn('cursor-pointer border-b border-divider-subtle hover:bg-background-default-hover', activeConversationId !== log.id ? '' : 'bg-background-default-hover')} className={cn('cursor-pointer border-b border-divider-subtle hover:bg-background-default-hover', currentConversation?.id !== log.id ? '' : 'bg-background-default-hover')}
onClick={() => handleRowClick(log)}> onClick={() => {
setShowDrawer(true)
setCurrentConversation(log)
}}>
<td className='h-4'> <td className='h-4'>
{!log.read_at && ( {!log.read_at && (
<div className='flex items-center p-3 pr-0.5'> <div className='flex items-center p-3 pr-0.5'>

View File

@ -4,6 +4,7 @@ import {
import { produce } from 'immer' import { produce } from 'immer'
import { import {
useReactFlow, useReactFlow,
useStoreApi,
useViewport, useViewport,
} from 'reactflow' } from 'reactflow'
import { useEventListener } from 'ahooks' import { useEventListener } from 'ahooks'
@ -12,15 +13,15 @@ import {
useWorkflowStore, useWorkflowStore,
} from './store' } from './store'
import { WorkflowHistoryEvent, useNodesInteractions, useWorkflowHistory } from './hooks' import { WorkflowHistoryEvent, useNodesInteractions, useWorkflowHistory } from './hooks'
import { CUSTOM_NODE } from './constants' import { CUSTOM_NODE, ITERATION_PADDING } from './constants'
import { getIterationStartNode, getLoopStartNode } from './utils' import { getIterationStartNode, getLoopStartNode } from './utils'
import CustomNode from './nodes' import CustomNode from './nodes'
import CustomNoteNode from './note-node' import CustomNoteNode from './note-node'
import { CUSTOM_NOTE_NODE } from './note-node/constants' import { CUSTOM_NOTE_NODE } from './note-node/constants'
import { BlockEnum } from './types' import { BlockEnum } from './types'
import { useCollaborativeWorkflow } from '@/app/components/workflow/hooks/use-collaborative-workflow'
const CandidateNode = () => { const CandidateNode = () => {
const store = useStoreApi()
const reactflow = useReactFlow() const reactflow = useReactFlow()
const workflowStore = useWorkflowStore() const workflowStore = useWorkflowStore()
const candidateNode = useStore(s => s.candidateNode) const candidateNode = useStore(s => s.candidateNode)
@ -28,16 +29,45 @@ const CandidateNode = () => {
const { zoom } = useViewport() const { zoom } = useViewport()
const { handleNodeSelect } = useNodesInteractions() const { handleNodeSelect } = useNodesInteractions()
const { saveStateToHistory } = useWorkflowHistory() const { saveStateToHistory } = useWorkflowHistory()
const collaborativeWorkflow = useCollaborativeWorkflow()
useEventListener('click', (e) => { useEventListener('click', (e) => {
const { candidateNode, mousePosition } = workflowStore.getState() const { candidateNode, mousePosition } = workflowStore.getState()
if (candidateNode) { if (candidateNode) {
e.preventDefault() e.preventDefault()
const { nodes, setNodes } = collaborativeWorkflow.getState() const {
getNodes,
setNodes,
} = store.getState()
const { screenToFlowPosition } = reactflow const { screenToFlowPosition } = reactflow
const { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY }) const nodes = getNodes()
// Get mouse position in flow coordinates (this is where the top-left corner should be)
let { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
// If the node has a parent (e.g., inside iteration), apply constraints and convert to relative position
if (candidateNode.parentId) {
const parentNode = nodes.find(node => node.id === candidateNode.parentId)
if (parentNode && parentNode.position) {
// Apply boundary constraints for iteration nodes
if (candidateNode.data.isInIteration) {
const nodeWidth = candidateNode.width || 0
const nodeHeight = candidateNode.height || 0
const minX = parentNode.position.x + ITERATION_PADDING.left
const maxX = parentNode.position.x + (parentNode.width || 0) - ITERATION_PADDING.right - nodeWidth
const minY = parentNode.position.y + ITERATION_PADDING.top
const maxY = parentNode.position.y + (parentNode.height || 0) - ITERATION_PADDING.bottom - nodeHeight
// Constrain position
x = Math.max(minX, Math.min(maxX, x))
y = Math.max(minY, Math.min(maxY, y))
}
// Convert to relative position
x = x - parentNode.position.x
y = y - parentNode.position.y
}
}
const newNodes = produce(nodes, (draft) => { const newNodes = produce(nodes, (draft) => {
draft.push({ draft.push({
...candidateNode, ...candidateNode,
@ -55,6 +85,20 @@ const CandidateNode = () => {
if (candidateNode.data.type === BlockEnum.Loop) if (candidateNode.data.type === BlockEnum.Loop)
draft.push(getLoopStartNode(candidateNode.id)) draft.push(getLoopStartNode(candidateNode.id))
// Update parent iteration node's _children array
if (candidateNode.parentId && candidateNode.data.isInIteration) {
const parentNode = draft.find(node => node.id === candidateNode.parentId)
if (parentNode && parentNode.data.type === BlockEnum.Iteration) {
if (!parentNode.data._children)
parentNode.data._children = []
parentNode.data._children.push({
nodeId: candidateNode.id,
nodeType: candidateNode.data.type,
})
}
}
}) })
setNodes(newNodes) setNodes(newNodes)
if (candidateNode.type === CUSTOM_NOTE_NODE) if (candidateNode.type === CUSTOM_NOTE_NODE)
@ -80,6 +124,34 @@ const CandidateNode = () => {
if (!candidateNode) if (!candidateNode)
return null return null
// Apply boundary constraints if node is inside iteration
if (candidateNode.parentId && candidateNode.data.isInIteration) {
const { getNodes } = store.getState()
const nodes = getNodes()
const parentNode = nodes.find(node => node.id === candidateNode.parentId)
if (parentNode && parentNode.position) {
const { screenToFlowPosition, flowToScreenPosition } = reactflow
// Get mouse position in flow coordinates
const flowPosition = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
// Calculate boundaries in flow coordinates
const nodeWidth = candidateNode.width || 0
const nodeHeight = candidateNode.height || 0
const minX = parentNode.position.x + ITERATION_PADDING.left
const maxX = parentNode.position.x + (parentNode.width || 0) - ITERATION_PADDING.right - nodeWidth
const minY = parentNode.position.y + ITERATION_PADDING.top
const maxY = parentNode.position.y + (parentNode.height || 0) - ITERATION_PADDING.bottom - nodeHeight
// Constrain position
const constrainedX = Math.max(minX, Math.min(maxX, flowPosition.x))
const constrainedY = Math.max(minY, Math.min(maxY, flowPosition.y))
// Convert back to screen coordinates
flowToScreenPosition({ x: constrainedX, y: constrainedY })
}
}
return ( return (
<div <div
className='absolute z-10' className='absolute z-10'

View File

@ -125,12 +125,12 @@ export const useNodesInteractions = () => {
const { restrictPosition } = handleNodeIterationChildDrag(node) const { restrictPosition } = handleNodeIterationChildDrag(node)
const { restrictPosition: restrictLoopPosition } const { restrictPosition: restrictLoopPosition }
= handleNodeLoopChildDrag(node) = handleNodeLoopChildDrag(node)
const { showHorizontalHelpLineNodes, showVerticalHelpLineNodes } const { showHorizontalHelpLineNodes, showVerticalHelpLineNodes }
= handleSetHelpline(node) = handleSetHelpline(node)
const showHorizontalHelpLineNodesLength const showHorizontalHelpLineNodesLength
= showHorizontalHelpLineNodes.length = showHorizontalHelpLineNodes.length
const showVerticalHelpLineNodesLength = showVerticalHelpLineNodes.length const showVerticalHelpLineNodesLength = showVerticalHelpLineNodes.length
const newNodes = produce(nodes, (draft) => { const newNodes = produce(nodes, (draft) => {
@ -716,7 +716,7 @@ export const useNodesInteractions = () => {
targetHandle = 'target', targetHandle = 'target',
toolDefaultValue, toolDefaultValue,
}, },
{ prevNodeId, prevNodeSourceHandle, nextNodeId, nextNodeTargetHandle }, { prevNodeId, prevNodeSourceHandle, nextNodeId, nextNodeTargetHandle, skipAutoConnect },
) => { ) => {
if (getNodesReadOnly()) return if (getNodesReadOnly()) return
@ -808,7 +808,7 @@ export const useNodesInteractions = () => {
} }
let newEdge = null let newEdge = null
if (nodeType !== BlockEnum.DataSource) { if (nodeType !== BlockEnum.DataSource && !skipAutoConnect) {
newEdge = { newEdge = {
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`, id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
type: CUSTOM_EDGE, type: CUSTOM_EDGE,
@ -948,6 +948,7 @@ export const useNodesInteractions = () => {
nodeType !== BlockEnum.IfElse nodeType !== BlockEnum.IfElse
&& nodeType !== BlockEnum.QuestionClassifier && nodeType !== BlockEnum.QuestionClassifier
&& nodeType !== BlockEnum.LoopEnd && nodeType !== BlockEnum.LoopEnd
&& !skipAutoConnect
) { ) {
newEdge = { newEdge = {
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`, id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,
@ -1097,7 +1098,7 @@ export const useNodesInteractions = () => {
) )
let newPrevEdge = null let newPrevEdge = null
if (nodeType !== BlockEnum.DataSource) { if (nodeType !== BlockEnum.DataSource && !skipAutoConnect) {
newPrevEdge = { newPrevEdge = {
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`, id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
type: CUSTOM_EDGE, type: CUSTOM_EDGE,
@ -1137,6 +1138,7 @@ export const useNodesInteractions = () => {
nodeType !== BlockEnum.IfElse nodeType !== BlockEnum.IfElse
&& nodeType !== BlockEnum.QuestionClassifier && nodeType !== BlockEnum.QuestionClassifier
&& nodeType !== BlockEnum.LoopEnd && nodeType !== BlockEnum.LoopEnd
&& !skipAutoConnect
) { ) {
newNextEdge = { newNextEdge = {
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`, id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,

View File

@ -15,7 +15,9 @@ import {
useNodesSyncDraft, useNodesSyncDraft,
} from '@/app/components/workflow/hooks' } from '@/app/components/workflow/hooks'
import ShortcutsName from '@/app/components/workflow/shortcuts-name' import ShortcutsName from '@/app/components/workflow/shortcuts-name'
import type { Node } from '@/app/components/workflow/types' import { BlockEnum, type Node } from '@/app/components/workflow/types'
import PanelAddBlock from '@/app/components/workflow/nodes/iteration/panel-add-block'
import type { IterationNodeType } from '@/app/components/workflow/nodes/iteration/types'
type PanelOperatorPopupProps = { type PanelOperatorPopupProps = {
id: string id: string
@ -51,6 +53,9 @@ const PanelOperatorPopup = ({
(showChangeBlock || canRunBySingle(data.type, isChildNode)) && ( (showChangeBlock || canRunBySingle(data.type, isChildNode)) && (
<> <>
<div className='p-1'> <div className='p-1'>
{data.type === BlockEnum.Iteration && (
<PanelAddBlock iterationNodeData={data as IterationNodeType} onClosePopup={onClosePopup}/>
)}
{ {
canRunBySingle(data.type, isChildNode) && ( canRunBySingle(data.type, isChildNode) && (
<div <div

View File

@ -0,0 +1,116 @@
import {
memo,
useCallback,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import type { OffsetOptions } from '@floating-ui/react'
import { useStoreApi } from 'reactflow'
import BlockSelector from '@/app/components/workflow/block-selector'
import type {
OnSelectBlock,
} from '@/app/components/workflow/types'
import {
BlockEnum,
} from '@/app/components/workflow/types'
import { useAvailableBlocks, useNodesMetaData, useNodesReadOnly, usePanelInteractions } from '../../hooks'
import type { IterationNodeType } from './types'
import { useWorkflowStore } from '../../store'
import { generateNewNode, getNodeCustomTypeByNodeDataType } from '../../utils'
import { ITERATION_CHILDREN_Z_INDEX } from '../../constants'
type AddBlockProps = {
renderTrigger?: (open: boolean) => React.ReactNode
offset?: OffsetOptions
iterationNodeData: IterationNodeType
onClosePopup: () => void
}
const AddBlock = ({
offset,
iterationNodeData,
onClosePopup,
}: AddBlockProps) => {
const { t } = useTranslation()
const store = useStoreApi()
const workflowStore = useWorkflowStore()
const { nodesReadOnly } = useNodesReadOnly()
const { handlePaneContextmenuCancel } = usePanelInteractions()
const [open, setOpen] = useState(false)
const { availableNextBlocks } = useAvailableBlocks(BlockEnum.Start, false)
const { nodesMap: nodesMetaDataMap } = useNodesMetaData()
const handleOpenChange = useCallback((open: boolean) => {
setOpen(open)
if (!open)
handlePaneContextmenuCancel()
}, [handlePaneContextmenuCancel])
const handleSelect = useCallback<OnSelectBlock>((type, toolDefaultValue) => {
const { getNodes } = store.getState()
const nodes = getNodes()
const nodesWithSameType = nodes.filter(node => node.data.type === type)
const { defaultValue } = nodesMetaDataMap![type]
// Find the parent iteration node
const parentIterationNode = nodes.find(node => node.data.start_node_id === iterationNodeData.start_node_id)
const { newNode } = generateNewNode({
type: getNodeCustomTypeByNodeDataType(type),
data: {
...(defaultValue as any),
title: nodesWithSameType.length > 0 ? `${defaultValue.title} ${nodesWithSameType.length + 1}` : defaultValue.title,
...toolDefaultValue,
_isCandidate: true,
// Set iteration-specific properties
isInIteration: true,
iteration_id: parentIterationNode?.id,
},
position: {
x: 0,
y: 0,
},
})
// Set parent and z-index for iteration child
if (parentIterationNode) {
newNode.parentId = parentIterationNode.id
newNode.extent = 'parent' as any
newNode.zIndex = ITERATION_CHILDREN_Z_INDEX
}
workflowStore.setState({
candidateNode: newNode,
})
onClosePopup()
}, [store, workflowStore, nodesMetaDataMap, iterationNodeData.start_node_id, onClosePopup])
const renderTrigger = () => {
return (
<div
className='flex h-8 cursor-pointer items-center justify-between rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover'
>
{t('workflow.common.addBlock')}
</div>
)
}
return (
<BlockSelector
open={open}
onOpenChange={handleOpenChange}
disabled={nodesReadOnly}
onSelect={handleSelect}
placement='right-start'
offset={offset ?? {
mainAxis: 4,
crossAxis: -8,
}}
trigger={renderTrigger}
popupClassName='!min-w-[256px]'
availableBlocksTypes={availableNextBlocks}
/>
)
}
export default memo(AddBlock)

View File

@ -380,6 +380,7 @@ export type OnNodeAdd = (
prevNodeSourceHandle?: string prevNodeSourceHandle?: string
nextNodeId?: string nextNodeId?: string
nextNodeTargetHandle?: string nextNodeTargetHandle?: string
skipAutoConnect?: boolean
}, },
) => void ) => void