mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 12:37:20 +08:00
Merge branch 'feat/iteration-node' into deploy/dev
This commit is contained in:
commit
3d08c79c3e
@ -32,6 +32,7 @@ from libs.token import (
|
|||||||
clear_csrf_token_from_cookie,
|
clear_csrf_token_from_cookie,
|
||||||
clear_refresh_token_from_cookie,
|
clear_refresh_token_from_cookie,
|
||||||
extract_access_token,
|
extract_access_token,
|
||||||
|
extract_refresh_token,
|
||||||
set_access_token_to_cookie,
|
set_access_token_to_cookie,
|
||||||
set_csrf_token_to_cookie,
|
set_csrf_token_to_cookie,
|
||||||
set_refresh_token_to_cookie,
|
set_refresh_token_to_cookie,
|
||||||
@ -273,7 +274,7 @@ class EmailCodeLoginApi(Resource):
|
|||||||
class RefreshTokenApi(Resource):
|
class RefreshTokenApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
# Get refresh token from cookie instead of request body
|
# Get refresh token from cookie instead of request body
|
||||||
refresh_token = request.cookies.get("refresh_token")
|
refresh_token = extract_refresh_token(request)
|
||||||
|
|
||||||
if not refresh_token:
|
if not refresh_token:
|
||||||
return {"result": "fail", "message": "No refresh token provided"}, 401
|
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||||
|
|||||||
@ -193,15 +193,19 @@ class QuestionClassifierNode(Node):
|
|||||||
finish_reason = event.finish_reason
|
finish_reason = event.finish_reason
|
||||||
break
|
break
|
||||||
|
|
||||||
category_name = node_data.classes[0].name
|
rendered_classes = [
|
||||||
category_id = node_data.classes[0].id
|
c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes
|
||||||
|
]
|
||||||
|
|
||||||
|
category_name = rendered_classes[0].name
|
||||||
|
category_id = rendered_classes[0].id
|
||||||
if "<think>" in result_text:
|
if "<think>" in result_text:
|
||||||
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
|
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
|
||||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||||
# result_text_json = json.loads(result_text.strip('```JSON\n'))
|
# result_text_json = json.loads(result_text.strip('```JSON\n'))
|
||||||
if "category_name" in result_text_json and "category_id" in result_text_json:
|
if "category_name" in result_text_json and "category_id" in result_text_json:
|
||||||
category_id_result = result_text_json["category_id"]
|
category_id_result = result_text_json["category_id"]
|
||||||
classes = node_data.classes
|
classes = rendered_classes
|
||||||
classes_map = {class_.id: class_.name for class_ in classes}
|
classes_map = {class_.id: class_.name for class_ in classes}
|
||||||
category_ids = [_class.id for _class in classes]
|
category_ids = [_class.id for _class in classes]
|
||||||
if category_id_result in category_ids:
|
if category_id_result in category_ids:
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import json
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from collections.abc import Mapping as TypingMapping
|
from collections.abc import Mapping as TypingMapping
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
from pydantic.json import pydantic_encoder
|
from pydantic.json import pydantic_encoder
|
||||||
@ -106,6 +107,23 @@ class GraphProtocol(Protocol):
|
|||||||
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
|
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _GraphRuntimeStateSnapshot:
|
||||||
|
"""Immutable view of a serialized runtime state snapshot."""
|
||||||
|
|
||||||
|
start_at: float
|
||||||
|
total_tokens: int
|
||||||
|
node_run_steps: int
|
||||||
|
llm_usage: LLMUsage
|
||||||
|
outputs: dict[str, Any]
|
||||||
|
variable_pool: VariablePool
|
||||||
|
has_variable_pool: bool
|
||||||
|
ready_queue_dump: str | None
|
||||||
|
graph_execution_dump: str | None
|
||||||
|
response_coordinator_dump: str | None
|
||||||
|
paused_nodes: tuple[str, ...]
|
||||||
|
|
||||||
|
|
||||||
class GraphRuntimeState:
|
class GraphRuntimeState:
|
||||||
"""Mutable runtime state shared across graph execution components."""
|
"""Mutable runtime state shared across graph execution components."""
|
||||||
|
|
||||||
@ -293,69 +311,28 @@ class GraphRuntimeState:
|
|||||||
|
|
||||||
return json.dumps(snapshot, default=pydantic_encoder)
|
return json.dumps(snapshot, default=pydantic_encoder)
|
||||||
|
|
||||||
def loads(self, data: str | Mapping[str, Any]) -> None:
|
@classmethod
|
||||||
|
def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState:
|
||||||
"""Restore runtime state from a serialized snapshot."""
|
"""Restore runtime state from a serialized snapshot."""
|
||||||
|
|
||||||
payload: dict[str, Any]
|
snapshot = cls._parse_snapshot_payload(data)
|
||||||
if isinstance(data, str):
|
|
||||||
payload = json.loads(data)
|
|
||||||
else:
|
|
||||||
payload = dict(data)
|
|
||||||
|
|
||||||
version = payload.get("version")
|
state = cls(
|
||||||
if version != "1.0":
|
variable_pool=snapshot.variable_pool,
|
||||||
raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")
|
start_at=snapshot.start_at,
|
||||||
|
total_tokens=snapshot.total_tokens,
|
||||||
|
llm_usage=snapshot.llm_usage,
|
||||||
|
outputs=snapshot.outputs,
|
||||||
|
node_run_steps=snapshot.node_run_steps,
|
||||||
|
)
|
||||||
|
state._apply_snapshot(snapshot)
|
||||||
|
return state
|
||||||
|
|
||||||
self._start_at = float(payload.get("start_at", 0.0))
|
def loads(self, data: str | Mapping[str, Any]) -> None:
|
||||||
total_tokens = int(payload.get("total_tokens", 0))
|
"""Restore runtime state from a serialized snapshot (legacy API)."""
|
||||||
if total_tokens < 0:
|
|
||||||
raise ValueError("total_tokens must be non-negative")
|
|
||||||
self._total_tokens = total_tokens
|
|
||||||
|
|
||||||
node_run_steps = int(payload.get("node_run_steps", 0))
|
snapshot = self._parse_snapshot_payload(data)
|
||||||
if node_run_steps < 0:
|
self._apply_snapshot(snapshot)
|
||||||
raise ValueError("node_run_steps must be non-negative")
|
|
||||||
self._node_run_steps = node_run_steps
|
|
||||||
|
|
||||||
llm_usage_payload = payload.get("llm_usage", {})
|
|
||||||
self._llm_usage = LLMUsage.model_validate(llm_usage_payload)
|
|
||||||
|
|
||||||
self._outputs = deepcopy(payload.get("outputs", {}))
|
|
||||||
|
|
||||||
variable_pool_payload = payload.get("variable_pool")
|
|
||||||
if variable_pool_payload is not None:
|
|
||||||
self._variable_pool = VariablePool.model_validate(variable_pool_payload)
|
|
||||||
|
|
||||||
ready_queue_payload = payload.get("ready_queue")
|
|
||||||
if ready_queue_payload is not None:
|
|
||||||
self._ready_queue = self._build_ready_queue()
|
|
||||||
self._ready_queue.loads(ready_queue_payload)
|
|
||||||
else:
|
|
||||||
self._ready_queue = None
|
|
||||||
|
|
||||||
graph_execution_payload = payload.get("graph_execution")
|
|
||||||
self._graph_execution = None
|
|
||||||
self._pending_graph_execution_workflow_id = None
|
|
||||||
if graph_execution_payload is not None:
|
|
||||||
try:
|
|
||||||
execution_payload = json.loads(graph_execution_payload)
|
|
||||||
self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
|
|
||||||
except (json.JSONDecodeError, TypeError, AttributeError):
|
|
||||||
self._pending_graph_execution_workflow_id = None
|
|
||||||
self.graph_execution.loads(graph_execution_payload)
|
|
||||||
|
|
||||||
response_payload = payload.get("response_coordinator")
|
|
||||||
if response_payload is not None:
|
|
||||||
if self._graph is not None:
|
|
||||||
self.response_coordinator.loads(response_payload)
|
|
||||||
else:
|
|
||||||
self._pending_response_coordinator_dump = response_payload
|
|
||||||
else:
|
|
||||||
self._pending_response_coordinator_dump = None
|
|
||||||
self._response_coordinator = None
|
|
||||||
|
|
||||||
paused_nodes_payload = payload.get("paused_nodes", [])
|
|
||||||
self._paused_nodes = set(map(str, paused_nodes_payload))
|
|
||||||
|
|
||||||
def register_paused_node(self, node_id: str) -> None:
|
def register_paused_node(self, node_id: str) -> None:
|
||||||
"""Record a node that should resume when execution is continued."""
|
"""Record a node that should resume when execution is continued."""
|
||||||
@ -391,3 +368,106 @@ class GraphRuntimeState:
|
|||||||
module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
|
module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
|
||||||
coordinator_cls = module.ResponseStreamCoordinator
|
coordinator_cls = module.ResponseStreamCoordinator
|
||||||
return coordinator_cls(variable_pool=self.variable_pool, graph=graph)
|
return coordinator_cls(variable_pool=self.variable_pool, graph=graph)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Snapshot helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
@classmethod
|
||||||
|
def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot:
|
||||||
|
payload: dict[str, Any]
|
||||||
|
if isinstance(data, str):
|
||||||
|
payload = json.loads(data)
|
||||||
|
else:
|
||||||
|
payload = dict(data)
|
||||||
|
|
||||||
|
version = payload.get("version")
|
||||||
|
if version != "1.0":
|
||||||
|
raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")
|
||||||
|
|
||||||
|
start_at = float(payload.get("start_at", 0.0))
|
||||||
|
|
||||||
|
total_tokens = int(payload.get("total_tokens", 0))
|
||||||
|
if total_tokens < 0:
|
||||||
|
raise ValueError("total_tokens must be non-negative")
|
||||||
|
|
||||||
|
node_run_steps = int(payload.get("node_run_steps", 0))
|
||||||
|
if node_run_steps < 0:
|
||||||
|
raise ValueError("node_run_steps must be non-negative")
|
||||||
|
|
||||||
|
llm_usage_payload = payload.get("llm_usage", {})
|
||||||
|
llm_usage = LLMUsage.model_validate(llm_usage_payload)
|
||||||
|
|
||||||
|
outputs_payload = deepcopy(payload.get("outputs", {}))
|
||||||
|
|
||||||
|
variable_pool_payload = payload.get("variable_pool")
|
||||||
|
has_variable_pool = variable_pool_payload is not None
|
||||||
|
variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool()
|
||||||
|
|
||||||
|
ready_queue_payload = payload.get("ready_queue")
|
||||||
|
graph_execution_payload = payload.get("graph_execution")
|
||||||
|
response_payload = payload.get("response_coordinator")
|
||||||
|
paused_nodes_payload = payload.get("paused_nodes", [])
|
||||||
|
|
||||||
|
return _GraphRuntimeStateSnapshot(
|
||||||
|
start_at=start_at,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
node_run_steps=node_run_steps,
|
||||||
|
llm_usage=llm_usage,
|
||||||
|
outputs=outputs_payload,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
has_variable_pool=has_variable_pool,
|
||||||
|
ready_queue_dump=ready_queue_payload,
|
||||||
|
graph_execution_dump=graph_execution_payload,
|
||||||
|
response_coordinator_dump=response_payload,
|
||||||
|
paused_nodes=tuple(map(str, paused_nodes_payload)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
|
||||||
|
self._start_at = snapshot.start_at
|
||||||
|
self._total_tokens = snapshot.total_tokens
|
||||||
|
self._node_run_steps = snapshot.node_run_steps
|
||||||
|
self._llm_usage = snapshot.llm_usage.model_copy()
|
||||||
|
self._outputs = deepcopy(snapshot.outputs)
|
||||||
|
if snapshot.has_variable_pool or self._variable_pool is None:
|
||||||
|
self._variable_pool = snapshot.variable_pool
|
||||||
|
|
||||||
|
self._restore_ready_queue(snapshot.ready_queue_dump)
|
||||||
|
self._restore_graph_execution(snapshot.graph_execution_dump)
|
||||||
|
self._restore_response_coordinator(snapshot.response_coordinator_dump)
|
||||||
|
self._paused_nodes = set(snapshot.paused_nodes)
|
||||||
|
|
||||||
|
def _restore_ready_queue(self, payload: str | None) -> None:
|
||||||
|
if payload is not None:
|
||||||
|
self._ready_queue = self._build_ready_queue()
|
||||||
|
self._ready_queue.loads(payload)
|
||||||
|
else:
|
||||||
|
self._ready_queue = None
|
||||||
|
|
||||||
|
def _restore_graph_execution(self, payload: str | None) -> None:
|
||||||
|
self._graph_execution = None
|
||||||
|
self._pending_graph_execution_workflow_id = None
|
||||||
|
|
||||||
|
if payload is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
execution_payload = json.loads(payload)
|
||||||
|
self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
|
||||||
|
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||||
|
self._pending_graph_execution_workflow_id = None
|
||||||
|
|
||||||
|
self.graph_execution.loads(payload)
|
||||||
|
|
||||||
|
def _restore_response_coordinator(self, payload: str | None) -> None:
|
||||||
|
if payload is None:
|
||||||
|
self._pending_response_coordinator_dump = None
|
||||||
|
self._response_coordinator = None
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._graph is not None:
|
||||||
|
self.response_coordinator.loads(payload)
|
||||||
|
self._pending_response_coordinator_dump = None
|
||||||
|
return
|
||||||
|
|
||||||
|
self._pending_response_coordinator_dump = payload
|
||||||
|
self._response_coordinator = None
|
||||||
|
|||||||
@ -6,10 +6,11 @@ from flask_login import user_loaded_from_request, user_logged_in
|
|||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from constants import HEADER_NAME_APP_CODE
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
from libs.token import extract_access_token
|
from libs.token import extract_access_token, extract_webapp_passport
|
||||||
from models import Account, Tenant, TenantAccountJoin
|
from models import Account, Tenant, TenantAccountJoin
|
||||||
from models.model import AppMCPServer, EndUser
|
from models.model import AppMCPServer, EndUser
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
@ -61,14 +62,30 @@ def load_user_from_request(request_from_flask_login):
|
|||||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||||
return logged_in_account
|
return logged_in_account
|
||||||
elif request.blueprint == "web":
|
elif request.blueprint == "web":
|
||||||
decoded = PassportService().verify(auth_token)
|
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||||
end_user_id = decoded.get("end_user_id")
|
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
|
||||||
if not end_user_id:
|
|
||||||
raise Unauthorized("Invalid Authorization token.")
|
if webapp_token:
|
||||||
end_user = db.session.query(EndUser).where(EndUser.id == decoded["end_user_id"]).first()
|
decoded = PassportService().verify(webapp_token)
|
||||||
if not end_user:
|
end_user_id = decoded.get("end_user_id")
|
||||||
raise NotFound("End user not found.")
|
if not end_user_id:
|
||||||
return end_user
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
|
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||||
|
if not end_user:
|
||||||
|
raise NotFound("End user not found.")
|
||||||
|
return end_user
|
||||||
|
else:
|
||||||
|
if not auth_token:
|
||||||
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
|
decoded = PassportService().verify(auth_token)
|
||||||
|
end_user_id = decoded.get("end_user_id")
|
||||||
|
if end_user_id:
|
||||||
|
end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first()
|
||||||
|
if not end_user:
|
||||||
|
raise NotFound("End user not found.")
|
||||||
|
return end_user
|
||||||
|
else:
|
||||||
|
raise Unauthorized("Invalid Authorization token for web API.")
|
||||||
elif request.blueprint == "mcp":
|
elif request.blueprint == "mcp":
|
||||||
server_code = request.view_args.get("server_code") if request.view_args else None
|
server_code = request.view_args.get("server_code") if request.view_args else None
|
||||||
if not server_code:
|
if not server_code:
|
||||||
|
|||||||
@ -38,9 +38,6 @@ def _real_cookie_name(cookie_name: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _try_extract_from_header(request: Request) -> str | None:
|
def _try_extract_from_header(request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract access token from header
|
|
||||||
"""
|
|
||||||
auth_header = request.headers.get("Authorization")
|
auth_header = request.headers.get("Authorization")
|
||||||
if auth_header:
|
if auth_header:
|
||||||
if " " not in auth_header:
|
if " " not in auth_header:
|
||||||
@ -55,27 +52,19 @@ def _try_extract_from_header(request: Request) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_refresh_token(request: Request) -> str | None:
|
||||||
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN))
|
||||||
|
|
||||||
|
|
||||||
def extract_csrf_token(request: Request) -> str | None:
|
def extract_csrf_token(request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract CSRF token from header or cookie.
|
|
||||||
"""
|
|
||||||
return request.headers.get(HEADER_NAME_CSRF_TOKEN)
|
return request.headers.get(HEADER_NAME_CSRF_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
def extract_csrf_token_from_cookie(request: Request) -> str | None:
|
def extract_csrf_token_from_cookie(request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract CSRF token from cookie.
|
|
||||||
"""
|
|
||||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
|
||||||
|
|
||||||
|
|
||||||
def extract_access_token(request: Request) -> str | None:
|
def extract_access_token(request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract access token from cookie, header or params.
|
|
||||||
|
|
||||||
Access token is either for console session or webapp passport exchange.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _try_extract_from_cookie(request: Request) -> str | None:
|
def _try_extract_from_cookie(request: Request) -> str | None:
|
||||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
||||||
|
|
||||||
@ -83,20 +72,10 @@ def extract_access_token(request: Request) -> str | None:
|
|||||||
|
|
||||||
|
|
||||||
def extract_webapp_access_token(request: Request) -> str | None:
|
def extract_webapp_access_token(request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract webapp access token from cookie, then header.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_WEBAPP_ACCESS_TOKEN)) or _try_extract_from_header(request)
|
||||||
|
|
||||||
|
|
||||||
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
|
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
|
||||||
"""
|
|
||||||
Try to extract app token from header or params.
|
|
||||||
|
|
||||||
Webapp access token (part of passport) is only used for webapp session.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
|
def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
|
||||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
|
||||||
|
|
||||||
|
|||||||
@ -82,54 +82,51 @@ class AudioService:
|
|||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
is_draft: bool = False,
|
is_draft: bool = False,
|
||||||
):
|
):
|
||||||
from app import app
|
|
||||||
|
|
||||||
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
|
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
|
||||||
with app.app_context():
|
if voice is None:
|
||||||
if voice is None:
|
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
if is_draft:
|
||||||
if is_draft:
|
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
|
||||||
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
|
else:
|
||||||
else:
|
workflow = app_model.workflow
|
||||||
workflow = app_model.workflow
|
if (
|
||||||
if (
|
workflow is None
|
||||||
workflow is None
|
or "text_to_speech" not in workflow.features_dict
|
||||||
or "text_to_speech" not in workflow.features_dict
|
or not workflow.features_dict["text_to_speech"].get("enabled")
|
||||||
or not workflow.features_dict["text_to_speech"].get("enabled")
|
):
|
||||||
):
|
raise ValueError("TTS is not enabled")
|
||||||
|
|
||||||
|
voice = workflow.features_dict["text_to_speech"].get("voice")
|
||||||
|
else:
|
||||||
|
if not is_draft:
|
||||||
|
if app_model.app_model_config is None:
|
||||||
|
raise ValueError("AppModelConfig not found")
|
||||||
|
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
|
||||||
|
|
||||||
|
if not text_to_speech_dict.get("enabled"):
|
||||||
raise ValueError("TTS is not enabled")
|
raise ValueError("TTS is not enabled")
|
||||||
|
|
||||||
voice = workflow.features_dict["text_to_speech"].get("voice")
|
voice = text_to_speech_dict.get("voice")
|
||||||
else:
|
|
||||||
if not is_draft:
|
|
||||||
if app_model.app_model_config is None:
|
|
||||||
raise ValueError("AppModelConfig not found")
|
|
||||||
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
|
|
||||||
|
|
||||||
if not text_to_speech_dict.get("enabled"):
|
model_manager = ModelManager()
|
||||||
raise ValueError("TTS is not enabled")
|
model_instance = model_manager.get_default_model_instance(
|
||||||
|
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
|
||||||
voice = text_to_speech_dict.get("voice")
|
)
|
||||||
|
try:
|
||||||
model_manager = ModelManager()
|
if not voice:
|
||||||
model_instance = model_manager.get_default_model_instance(
|
voices = model_instance.get_tts_voices()
|
||||||
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
|
if voices:
|
||||||
)
|
voice = voices[0].get("value")
|
||||||
try:
|
if not voice:
|
||||||
if not voice:
|
|
||||||
voices = model_instance.get_tts_voices()
|
|
||||||
if voices:
|
|
||||||
voice = voices[0].get("value")
|
|
||||||
if not voice:
|
|
||||||
raise ValueError("Sorry, no voice available.")
|
|
||||||
else:
|
|
||||||
raise ValueError("Sorry, no voice available.")
|
raise ValueError("Sorry, no voice available.")
|
||||||
|
else:
|
||||||
|
raise ValueError("Sorry, no voice available.")
|
||||||
|
|
||||||
return model_instance.invoke_tts(
|
return model_instance.invoke_tts(
|
||||||
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
|
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if message_id:
|
if message_id:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -283,7 +283,7 @@ class VariableTruncator:
|
|||||||
break
|
break
|
||||||
|
|
||||||
remaining_budget = target_size - used_size
|
remaining_budget = target_size - used_size
|
||||||
if item is None or isinstance(item, (str, list, dict, bool, int, float)):
|
if item is None or isinstance(item, (str, list, dict, bool, int, float, UpdatedVariable)):
|
||||||
part_result = self._truncate_json_primitives(item, remaining_budget)
|
part_result = self._truncate_json_primitives(item, remaining_budget)
|
||||||
else:
|
else:
|
||||||
raise UnknownTypeError(f"got unknown type {type(item)} in array truncation")
|
raise UnknownTypeError(f"got unknown type {type(item)} in array truncation")
|
||||||
@ -373,6 +373,11 @@ class VariableTruncator:
|
|||||||
|
|
||||||
return _PartResult(truncated_obj, used_size, truncated)
|
return _PartResult(truncated_obj, used_size, truncated)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def _truncate_json_primitives(
|
||||||
|
self, val: UpdatedVariable, target_size: int
|
||||||
|
) -> _PartResult[Mapping[str, object]]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ...
|
def _truncate_json_primitives(self, val: str, target_size: int) -> _PartResult[str]: ...
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,18 @@ from core.model_runtime.entities.llm_entities import LLMUsage
|
|||||||
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
|
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
|
||||||
|
|
||||||
|
|
||||||
|
class StubCoordinator:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.state = "initial"
|
||||||
|
|
||||||
|
def dumps(self) -> str:
|
||||||
|
return json.dumps({"state": self.state})
|
||||||
|
|
||||||
|
def loads(self, data: str) -> None:
|
||||||
|
payload = json.loads(data)
|
||||||
|
self.state = payload["state"]
|
||||||
|
|
||||||
|
|
||||||
class TestGraphRuntimeState:
|
class TestGraphRuntimeState:
|
||||||
def test_property_getters_and_setters(self):
|
def test_property_getters_and_setters(self):
|
||||||
# FIXME(-LAN-): Mock VariablePool if needed
|
# FIXME(-LAN-): Mock VariablePool if needed
|
||||||
@ -191,17 +203,6 @@ class TestGraphRuntimeState:
|
|||||||
graph_execution.exceptions_count = 4
|
graph_execution.exceptions_count = 4
|
||||||
graph_execution.started = True
|
graph_execution.started = True
|
||||||
|
|
||||||
class StubCoordinator:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.state = "initial"
|
|
||||||
|
|
||||||
def dumps(self) -> str:
|
|
||||||
return json.dumps({"state": self.state})
|
|
||||||
|
|
||||||
def loads(self, data: str) -> None:
|
|
||||||
payload = json.loads(data)
|
|
||||||
self.state = payload["state"]
|
|
||||||
|
|
||||||
mock_graph = MagicMock()
|
mock_graph = MagicMock()
|
||||||
stub = StubCoordinator()
|
stub = StubCoordinator()
|
||||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
|
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
|
||||||
@ -211,8 +212,7 @@ class TestGraphRuntimeState:
|
|||||||
|
|
||||||
snapshot = state.dumps()
|
snapshot = state.dumps()
|
||||||
|
|
||||||
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
restored = GraphRuntimeState.from_snapshot(snapshot)
|
||||||
restored.loads(snapshot)
|
|
||||||
|
|
||||||
assert restored.total_tokens == 10
|
assert restored.total_tokens == 10
|
||||||
assert restored.node_run_steps == 3
|
assert restored.node_run_steps == 3
|
||||||
@ -235,3 +235,47 @@ class TestGraphRuntimeState:
|
|||||||
restored.attach_graph(mock_graph)
|
restored.attach_graph(mock_graph)
|
||||||
|
|
||||||
assert new_stub.state == "configured"
|
assert new_stub.state == "configured"
|
||||||
|
|
||||||
|
def test_loads_rehydrates_existing_instance(self):
|
||||||
|
variable_pool = VariablePool()
|
||||||
|
variable_pool.add(("node", "key"), "value")
|
||||||
|
|
||||||
|
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||||
|
state.total_tokens = 7
|
||||||
|
state.node_run_steps = 2
|
||||||
|
state.set_output("foo", "bar")
|
||||||
|
state.ready_queue.put("node-1")
|
||||||
|
|
||||||
|
execution = state.graph_execution
|
||||||
|
execution.workflow_id = "wf-456"
|
||||||
|
execution.started = True
|
||||||
|
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
original_stub = StubCoordinator()
|
||||||
|
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub):
|
||||||
|
state.attach_graph(mock_graph)
|
||||||
|
|
||||||
|
original_stub.state = "configured"
|
||||||
|
snapshot = state.dumps()
|
||||||
|
|
||||||
|
new_stub = StubCoordinator()
|
||||||
|
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
|
||||||
|
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||||
|
restored.attach_graph(mock_graph)
|
||||||
|
restored.loads(snapshot)
|
||||||
|
|
||||||
|
assert restored.total_tokens == 7
|
||||||
|
assert restored.node_run_steps == 2
|
||||||
|
assert restored.get_output("foo") == "bar"
|
||||||
|
assert restored.ready_queue.qsize() == 1
|
||||||
|
assert restored.ready_queue.get(timeout=0.01) == "node-1"
|
||||||
|
|
||||||
|
restored_segment = restored.variable_pool.get(("node", "key"))
|
||||||
|
assert restored_segment is not None
|
||||||
|
assert restored_segment.value == "value"
|
||||||
|
|
||||||
|
restored_execution = restored.graph_execution
|
||||||
|
assert restored_execution.workflow_id == "wf-456"
|
||||||
|
assert restored_execution.started is True
|
||||||
|
|
||||||
|
assert new_stub.state == "configured"
|
||||||
|
|||||||
@ -265,16 +265,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB
|
|||||||
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
|
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
|
||||||
|
|
||||||
# Sets the maximum allowed duration of any statement before termination.
|
# Sets the maximum allowed duration of any statement before termination.
|
||||||
# Default is 60000 milliseconds.
|
# Default is 0 (no timeout).
|
||||||
#
|
#
|
||||||
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
|
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
|
||||||
POSTGRES_STATEMENT_TIMEOUT=60000
|
# A value of 0 prevents the server from timing out statements.
|
||||||
|
POSTGRES_STATEMENT_TIMEOUT=0
|
||||||
|
|
||||||
# Sets the maximum allowed duration of any idle in-transaction session before termination.
|
# Sets the maximum allowed duration of any idle in-transaction session before termination.
|
||||||
# Default is 60000 milliseconds.
|
# Default is 0 (no timeout).
|
||||||
#
|
#
|
||||||
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
|
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
|
||||||
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000
|
# A value of 0 prevents the server from terminating idle sessions.
|
||||||
|
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# Redis Configuration
|
# Redis Configuration
|
||||||
|
|||||||
@ -115,8 +115,8 @@ services:
|
|||||||
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
||||||
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
||||||
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
||||||
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}'
|
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
|
||||||
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}'
|
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
|
||||||
volumes:
|
volumes:
|
||||||
- ./volumes/db/data:/var/lib/postgresql/data
|
- ./volumes/db/data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
|
|||||||
@ -15,8 +15,8 @@ services:
|
|||||||
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
||||||
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
||||||
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
||||||
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}'
|
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
|
||||||
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}'
|
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
|
||||||
volumes:
|
volumes:
|
||||||
- ${PGDATA_HOST_VOLUME:-./volumes/db/data}:/var/lib/postgresql/data
|
- ${PGDATA_HOST_VOLUME:-./volumes/db/data}:/var/lib/postgresql/data
|
||||||
ports:
|
ports:
|
||||||
|
|||||||
@ -68,8 +68,8 @@ x-shared-env: &shared-api-worker-env
|
|||||||
POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB}
|
POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB}
|
||||||
POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}
|
POSTGRES_MAINTENANCE_WORK_MEM: ${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}
|
||||||
POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}
|
POSTGRES_EFFECTIVE_CACHE_SIZE: ${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}
|
||||||
POSTGRES_STATEMENT_TIMEOUT: ${POSTGRES_STATEMENT_TIMEOUT:-60000}
|
POSTGRES_STATEMENT_TIMEOUT: ${POSTGRES_STATEMENT_TIMEOUT:-0}
|
||||||
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT: ${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}
|
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT: ${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}
|
||||||
REDIS_HOST: ${REDIS_HOST:-redis}
|
REDIS_HOST: ${REDIS_HOST:-redis}
|
||||||
REDIS_PORT: ${REDIS_PORT:-6379}
|
REDIS_PORT: ${REDIS_PORT:-6379}
|
||||||
REDIS_USERNAME: ${REDIS_USERNAME:-}
|
REDIS_USERNAME: ${REDIS_USERNAME:-}
|
||||||
@ -724,8 +724,8 @@ services:
|
|||||||
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
-c 'work_mem=${POSTGRES_WORK_MEM:-4MB}'
|
||||||
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
-c 'maintenance_work_mem=${POSTGRES_MAINTENANCE_WORK_MEM:-64MB}'
|
||||||
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
-c 'effective_cache_size=${POSTGRES_EFFECTIVE_CACHE_SIZE:-4096MB}'
|
||||||
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-60000}'
|
-c 'statement_timeout=${POSTGRES_STATEMENT_TIMEOUT:-0}'
|
||||||
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-60000}'
|
-c 'idle_in_transaction_session_timeout=${POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT:-0}'
|
||||||
volumes:
|
volumes:
|
||||||
- ./volumes/db/data:/var/lib/postgresql/data
|
- ./volumes/db/data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
|
|||||||
@ -41,16 +41,18 @@ POSTGRES_MAINTENANCE_WORK_MEM=64MB
|
|||||||
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
|
POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB
|
||||||
|
|
||||||
# Sets the maximum allowed duration of any statement before termination.
|
# Sets the maximum allowed duration of any statement before termination.
|
||||||
# Default is 60000 milliseconds.
|
# Default is 0 (no timeout).
|
||||||
#
|
#
|
||||||
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
|
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
|
||||||
POSTGRES_STATEMENT_TIMEOUT=60000
|
# A value of 0 prevents the server from timing out statements.
|
||||||
|
POSTGRES_STATEMENT_TIMEOUT=0
|
||||||
|
|
||||||
# Sets the maximum allowed duration of any idle in-transaction session before termination.
|
# Sets the maximum allowed duration of any idle in-transaction session before termination.
|
||||||
# Default is 60000 milliseconds.
|
# Default is 0 (no timeout).
|
||||||
#
|
#
|
||||||
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
|
# Reference: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-IDLE-IN-TRANSACTION-SESSION-TIMEOUT
|
||||||
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=60000
|
# A value of 0 prevents the server from terminating idle sessions.
|
||||||
|
POSTGRES_IDLE_IN_TRANSACTION_SESSION_TIMEOUT=0
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Environment Variables for redis Service
|
# Environment Variables for redis Service
|
||||||
|
|||||||
@ -132,8 +132,6 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
|
|||||||
importedVersion: imported_dsl_version ?? '',
|
importedVersion: imported_dsl_version ?? '',
|
||||||
systemVersion: current_dsl_version ?? '',
|
systemVersion: current_dsl_version ?? '',
|
||||||
})
|
})
|
||||||
if (onClose)
|
|
||||||
onClose()
|
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
setShowErrorModal(true)
|
setShowErrorModal(true)
|
||||||
}, 300)
|
}, 300)
|
||||||
|
|||||||
@ -14,7 +14,6 @@ import timezone from 'dayjs/plugin/timezone'
|
|||||||
import { createContext, useContext } from 'use-context-selector'
|
import { createContext, useContext } from 'use-context-selector'
|
||||||
import { useShallow } from 'zustand/react/shallow'
|
import { useShallow } from 'zustand/react/shallow'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
|
|
||||||
import type { ChatItemInTree } from '../../base/chat/types'
|
import type { ChatItemInTree } from '../../base/chat/types'
|
||||||
import Indicator from '../../header/indicator'
|
import Indicator from '../../header/indicator'
|
||||||
import VarPanel from './var-panel'
|
import VarPanel from './var-panel'
|
||||||
@ -43,10 +42,6 @@ import cn from '@/utils/classnames'
|
|||||||
import { noop } from 'lodash-es'
|
import { noop } from 'lodash-es'
|
||||||
import PromptLogModal from '../../base/prompt-log-modal'
|
import PromptLogModal from '../../base/prompt-log-modal'
|
||||||
|
|
||||||
type AppStoreState = ReturnType<typeof useAppStore.getState>
|
|
||||||
type ConversationListItem = ChatConversationGeneralDetail | CompletionConversationGeneralDetail
|
|
||||||
type ConversationSelection = ConversationListItem | { id: string; isPlaceholder?: true }
|
|
||||||
|
|
||||||
dayjs.extend(utc)
|
dayjs.extend(utc)
|
||||||
dayjs.extend(timezone)
|
dayjs.extend(timezone)
|
||||||
|
|
||||||
@ -206,7 +201,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
|
|||||||
const { formatTime } = useTimestamp()
|
const { formatTime } = useTimestamp()
|
||||||
const { onClose, appDetail } = useContext(DrawerContext)
|
const { onClose, appDetail } = useContext(DrawerContext)
|
||||||
const { notify } = useContext(ToastContext)
|
const { notify } = useContext(ToastContext)
|
||||||
const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow((state: AppStoreState) => ({
|
const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({
|
||||||
currentLogItem: state.currentLogItem,
|
currentLogItem: state.currentLogItem,
|
||||||
setCurrentLogItem: state.setCurrentLogItem,
|
setCurrentLogItem: state.setCurrentLogItem,
|
||||||
showMessageLogModal: state.showMessageLogModal,
|
showMessageLogModal: state.showMessageLogModal,
|
||||||
@ -898,113 +893,20 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string }
|
|||||||
const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh }) => {
|
const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh }) => {
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const { formatTime } = useTimestamp()
|
const { formatTime } = useTimestamp()
|
||||||
const router = useRouter()
|
|
||||||
const pathname = usePathname()
|
|
||||||
const searchParams = useSearchParams()
|
|
||||||
const conversationIdInUrl = searchParams.get('conversation_id') ?? undefined
|
|
||||||
|
|
||||||
const media = useBreakpoints()
|
const media = useBreakpoints()
|
||||||
const isMobile = media === MediaType.mobile
|
const isMobile = media === MediaType.mobile
|
||||||
|
|
||||||
const [showDrawer, setShowDrawer] = useState<boolean>(false) // Whether to display the chat details drawer
|
const [showDrawer, setShowDrawer] = useState<boolean>(false) // Whether to display the chat details drawer
|
||||||
const [currentConversation, setCurrentConversation] = useState<ConversationSelection | undefined>() // Currently selected conversation
|
const [currentConversation, setCurrentConversation] = useState<ChatConversationGeneralDetail | CompletionConversationGeneralDetail | undefined>() // Currently selected conversation
|
||||||
const closingConversationIdRef = useRef<string | null>(null)
|
|
||||||
const pendingConversationIdRef = useRef<string | null>(null)
|
|
||||||
const pendingConversationCacheRef = useRef<ConversationSelection | undefined>(undefined)
|
|
||||||
const isChatMode = appDetail.mode !== 'completion' // Whether the app is a chat app
|
const isChatMode = appDetail.mode !== 'completion' // Whether the app is a chat app
|
||||||
const isChatflow = appDetail.mode === 'advanced-chat' // Whether the app is a chatflow app
|
const isChatflow = appDetail.mode === 'advanced-chat' // Whether the app is a chatflow app
|
||||||
const { setShowPromptLogModal, setShowAgentLogModal, setShowMessageLogModal } = useAppStore(useShallow((state: AppStoreState) => ({
|
const { setShowPromptLogModal, setShowAgentLogModal, setShowMessageLogModal } = useAppStore(useShallow(state => ({
|
||||||
setShowPromptLogModal: state.setShowPromptLogModal,
|
setShowPromptLogModal: state.setShowPromptLogModal,
|
||||||
setShowAgentLogModal: state.setShowAgentLogModal,
|
setShowAgentLogModal: state.setShowAgentLogModal,
|
||||||
setShowMessageLogModal: state.setShowMessageLogModal,
|
setShowMessageLogModal: state.setShowMessageLogModal,
|
||||||
})))
|
})))
|
||||||
|
|
||||||
const activeConversationId = conversationIdInUrl ?? pendingConversationIdRef.current ?? currentConversation?.id
|
|
||||||
|
|
||||||
const buildUrlWithConversation = useCallback((conversationId?: string) => {
|
|
||||||
const params = new URLSearchParams(searchParams.toString())
|
|
||||||
if (conversationId)
|
|
||||||
params.set('conversation_id', conversationId)
|
|
||||||
else
|
|
||||||
params.delete('conversation_id')
|
|
||||||
|
|
||||||
const queryString = params.toString()
|
|
||||||
return queryString ? `${pathname}?${queryString}` : pathname
|
|
||||||
}, [pathname, searchParams])
|
|
||||||
|
|
||||||
const handleRowClick = useCallback((log: ConversationListItem) => {
|
|
||||||
if (conversationIdInUrl === log.id) {
|
|
||||||
if (!showDrawer)
|
|
||||||
setShowDrawer(true)
|
|
||||||
|
|
||||||
if (!currentConversation || currentConversation.id !== log.id)
|
|
||||||
setCurrentConversation(log)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
pendingConversationIdRef.current = log.id
|
|
||||||
pendingConversationCacheRef.current = log
|
|
||||||
if (!showDrawer)
|
|
||||||
setShowDrawer(true)
|
|
||||||
|
|
||||||
if (currentConversation?.id !== log.id)
|
|
||||||
setCurrentConversation(undefined)
|
|
||||||
|
|
||||||
router.push(buildUrlWithConversation(log.id), { scroll: false })
|
|
||||||
}, [buildUrlWithConversation, conversationIdInUrl, currentConversation, router, showDrawer])
|
|
||||||
|
|
||||||
const currentConversationId = currentConversation?.id
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (!conversationIdInUrl) {
|
|
||||||
if (pendingConversationIdRef.current)
|
|
||||||
return
|
|
||||||
|
|
||||||
if (showDrawer || currentConversationId) {
|
|
||||||
setShowDrawer(false)
|
|
||||||
setCurrentConversation(undefined)
|
|
||||||
}
|
|
||||||
closingConversationIdRef.current = null
|
|
||||||
pendingConversationCacheRef.current = undefined
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if (closingConversationIdRef.current === conversationIdInUrl)
|
|
||||||
return
|
|
||||||
|
|
||||||
if (pendingConversationIdRef.current === conversationIdInUrl)
|
|
||||||
pendingConversationIdRef.current = null
|
|
||||||
|
|
||||||
const matchedConversation = logs?.data?.find((item: ConversationListItem) => item.id === conversationIdInUrl)
|
|
||||||
const nextConversation: ConversationSelection = matchedConversation
|
|
||||||
?? pendingConversationCacheRef.current
|
|
||||||
?? { id: conversationIdInUrl, isPlaceholder: true }
|
|
||||||
|
|
||||||
if (!showDrawer)
|
|
||||||
setShowDrawer(true)
|
|
||||||
|
|
||||||
if (!currentConversation || currentConversation.id !== conversationIdInUrl || (matchedConversation && currentConversation !== matchedConversation))
|
|
||||||
setCurrentConversation(nextConversation)
|
|
||||||
|
|
||||||
if (pendingConversationCacheRef.current?.id === conversationIdInUrl || matchedConversation)
|
|
||||||
pendingConversationCacheRef.current = undefined
|
|
||||||
}, [conversationIdInUrl, currentConversation, isChatMode, logs?.data, showDrawer])
|
|
||||||
|
|
||||||
const onCloseDrawer = useCallback(() => {
|
|
||||||
onRefresh()
|
|
||||||
setShowDrawer(false)
|
|
||||||
setCurrentConversation(undefined)
|
|
||||||
setShowPromptLogModal(false)
|
|
||||||
setShowAgentLogModal(false)
|
|
||||||
setShowMessageLogModal(false)
|
|
||||||
pendingConversationIdRef.current = null
|
|
||||||
pendingConversationCacheRef.current = undefined
|
|
||||||
closingConversationIdRef.current = conversationIdInUrl ?? null
|
|
||||||
|
|
||||||
if (conversationIdInUrl)
|
|
||||||
router.replace(buildUrlWithConversation(), { scroll: false })
|
|
||||||
}, [buildUrlWithConversation, conversationIdInUrl, onRefresh, router, setShowAgentLogModal, setShowMessageLogModal, setShowPromptLogModal])
|
|
||||||
|
|
||||||
// Annotated data needs to be highlighted
|
// Annotated data needs to be highlighted
|
||||||
const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => {
|
const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: LogAnnotation) => {
|
||||||
return (
|
return (
|
||||||
@ -1023,6 +925,15 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const onCloseDrawer = () => {
|
||||||
|
onRefresh()
|
||||||
|
setShowDrawer(false)
|
||||||
|
setCurrentConversation(undefined)
|
||||||
|
setShowPromptLogModal(false)
|
||||||
|
setShowAgentLogModal(false)
|
||||||
|
setShowMessageLogModal(false)
|
||||||
|
}
|
||||||
|
|
||||||
if (!logs)
|
if (!logs)
|
||||||
return <Loading />
|
return <Loading />
|
||||||
|
|
||||||
@ -1049,8 +960,11 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
|
|||||||
const rightValue = get(log, isChatMode ? 'message_count' : 'message.answer')
|
const rightValue = get(log, isChatMode ? 'message_count' : 'message.answer')
|
||||||
return <tr
|
return <tr
|
||||||
key={log.id}
|
key={log.id}
|
||||||
className={cn('cursor-pointer border-b border-divider-subtle hover:bg-background-default-hover', activeConversationId !== log.id ? '' : 'bg-background-default-hover')}
|
className={cn('cursor-pointer border-b border-divider-subtle hover:bg-background-default-hover', currentConversation?.id !== log.id ? '' : 'bg-background-default-hover')}
|
||||||
onClick={() => handleRowClick(log)}>
|
onClick={() => {
|
||||||
|
setShowDrawer(true)
|
||||||
|
setCurrentConversation(log)
|
||||||
|
}}>
|
||||||
<td className='h-4'>
|
<td className='h-4'>
|
||||||
{!log.read_at && (
|
{!log.read_at && (
|
||||||
<div className='flex items-center p-3 pr-0.5'>
|
<div className='flex items-center p-3 pr-0.5'>
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import {
|
|||||||
import { produce } from 'immer'
|
import { produce } from 'immer'
|
||||||
import {
|
import {
|
||||||
useReactFlow,
|
useReactFlow,
|
||||||
|
useStoreApi,
|
||||||
useViewport,
|
useViewport,
|
||||||
} from 'reactflow'
|
} from 'reactflow'
|
||||||
import { useEventListener } from 'ahooks'
|
import { useEventListener } from 'ahooks'
|
||||||
@ -12,15 +13,15 @@ import {
|
|||||||
useWorkflowStore,
|
useWorkflowStore,
|
||||||
} from './store'
|
} from './store'
|
||||||
import { WorkflowHistoryEvent, useNodesInteractions, useWorkflowHistory } from './hooks'
|
import { WorkflowHistoryEvent, useNodesInteractions, useWorkflowHistory } from './hooks'
|
||||||
import { CUSTOM_NODE } from './constants'
|
import { CUSTOM_NODE, ITERATION_PADDING } from './constants'
|
||||||
import { getIterationStartNode, getLoopStartNode } from './utils'
|
import { getIterationStartNode, getLoopStartNode } from './utils'
|
||||||
import CustomNode from './nodes'
|
import CustomNode from './nodes'
|
||||||
import CustomNoteNode from './note-node'
|
import CustomNoteNode from './note-node'
|
||||||
import { CUSTOM_NOTE_NODE } from './note-node/constants'
|
import { CUSTOM_NOTE_NODE } from './note-node/constants'
|
||||||
import { BlockEnum } from './types'
|
import { BlockEnum } from './types'
|
||||||
import { useCollaborativeWorkflow } from '@/app/components/workflow/hooks/use-collaborative-workflow'
|
|
||||||
|
|
||||||
const CandidateNode = () => {
|
const CandidateNode = () => {
|
||||||
|
const store = useStoreApi()
|
||||||
const reactflow = useReactFlow()
|
const reactflow = useReactFlow()
|
||||||
const workflowStore = useWorkflowStore()
|
const workflowStore = useWorkflowStore()
|
||||||
const candidateNode = useStore(s => s.candidateNode)
|
const candidateNode = useStore(s => s.candidateNode)
|
||||||
@ -28,16 +29,45 @@ const CandidateNode = () => {
|
|||||||
const { zoom } = useViewport()
|
const { zoom } = useViewport()
|
||||||
const { handleNodeSelect } = useNodesInteractions()
|
const { handleNodeSelect } = useNodesInteractions()
|
||||||
const { saveStateToHistory } = useWorkflowHistory()
|
const { saveStateToHistory } = useWorkflowHistory()
|
||||||
const collaborativeWorkflow = useCollaborativeWorkflow()
|
|
||||||
|
|
||||||
useEventListener('click', (e) => {
|
useEventListener('click', (e) => {
|
||||||
const { candidateNode, mousePosition } = workflowStore.getState()
|
const { candidateNode, mousePosition } = workflowStore.getState()
|
||||||
|
|
||||||
if (candidateNode) {
|
if (candidateNode) {
|
||||||
e.preventDefault()
|
e.preventDefault()
|
||||||
const { nodes, setNodes } = collaborativeWorkflow.getState()
|
const {
|
||||||
|
getNodes,
|
||||||
|
setNodes,
|
||||||
|
} = store.getState()
|
||||||
const { screenToFlowPosition } = reactflow
|
const { screenToFlowPosition } = reactflow
|
||||||
const { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
|
const nodes = getNodes()
|
||||||
|
// Get mouse position in flow coordinates (this is where the top-left corner should be)
|
||||||
|
let { x, y } = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
|
||||||
|
|
||||||
|
// If the node has a parent (e.g., inside iteration), apply constraints and convert to relative position
|
||||||
|
if (candidateNode.parentId) {
|
||||||
|
const parentNode = nodes.find(node => node.id === candidateNode.parentId)
|
||||||
|
if (parentNode && parentNode.position) {
|
||||||
|
// Apply boundary constraints for iteration nodes
|
||||||
|
if (candidateNode.data.isInIteration) {
|
||||||
|
const nodeWidth = candidateNode.width || 0
|
||||||
|
const nodeHeight = candidateNode.height || 0
|
||||||
|
const minX = parentNode.position.x + ITERATION_PADDING.left
|
||||||
|
const maxX = parentNode.position.x + (parentNode.width || 0) - ITERATION_PADDING.right - nodeWidth
|
||||||
|
const minY = parentNode.position.y + ITERATION_PADDING.top
|
||||||
|
const maxY = parentNode.position.y + (parentNode.height || 0) - ITERATION_PADDING.bottom - nodeHeight
|
||||||
|
|
||||||
|
// Constrain position
|
||||||
|
x = Math.max(minX, Math.min(maxX, x))
|
||||||
|
y = Math.max(minY, Math.min(maxY, y))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to relative position
|
||||||
|
x = x - parentNode.position.x
|
||||||
|
y = y - parentNode.position.y
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const newNodes = produce(nodes, (draft) => {
|
const newNodes = produce(nodes, (draft) => {
|
||||||
draft.push({
|
draft.push({
|
||||||
...candidateNode,
|
...candidateNode,
|
||||||
@ -55,6 +85,20 @@ const CandidateNode = () => {
|
|||||||
|
|
||||||
if (candidateNode.data.type === BlockEnum.Loop)
|
if (candidateNode.data.type === BlockEnum.Loop)
|
||||||
draft.push(getLoopStartNode(candidateNode.id))
|
draft.push(getLoopStartNode(candidateNode.id))
|
||||||
|
|
||||||
|
// Update parent iteration node's _children array
|
||||||
|
if (candidateNode.parentId && candidateNode.data.isInIteration) {
|
||||||
|
const parentNode = draft.find(node => node.id === candidateNode.parentId)
|
||||||
|
if (parentNode && parentNode.data.type === BlockEnum.Iteration) {
|
||||||
|
if (!parentNode.data._children)
|
||||||
|
parentNode.data._children = []
|
||||||
|
|
||||||
|
parentNode.data._children.push({
|
||||||
|
nodeId: candidateNode.id,
|
||||||
|
nodeType: candidateNode.data.type,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
setNodes(newNodes)
|
setNodes(newNodes)
|
||||||
if (candidateNode.type === CUSTOM_NOTE_NODE)
|
if (candidateNode.type === CUSTOM_NOTE_NODE)
|
||||||
@ -80,6 +124,34 @@ const CandidateNode = () => {
|
|||||||
if (!candidateNode)
|
if (!candidateNode)
|
||||||
return null
|
return null
|
||||||
|
|
||||||
|
// Apply boundary constraints if node is inside iteration
|
||||||
|
if (candidateNode.parentId && candidateNode.data.isInIteration) {
|
||||||
|
const { getNodes } = store.getState()
|
||||||
|
const nodes = getNodes()
|
||||||
|
const parentNode = nodes.find(node => node.id === candidateNode.parentId)
|
||||||
|
|
||||||
|
if (parentNode && parentNode.position) {
|
||||||
|
const { screenToFlowPosition, flowToScreenPosition } = reactflow
|
||||||
|
// Get mouse position in flow coordinates
|
||||||
|
const flowPosition = screenToFlowPosition({ x: mousePosition.pageX, y: mousePosition.pageY })
|
||||||
|
|
||||||
|
// Calculate boundaries in flow coordinates
|
||||||
|
const nodeWidth = candidateNode.width || 0
|
||||||
|
const nodeHeight = candidateNode.height || 0
|
||||||
|
const minX = parentNode.position.x + ITERATION_PADDING.left
|
||||||
|
const maxX = parentNode.position.x + (parentNode.width || 0) - ITERATION_PADDING.right - nodeWidth
|
||||||
|
const minY = parentNode.position.y + ITERATION_PADDING.top
|
||||||
|
const maxY = parentNode.position.y + (parentNode.height || 0) - ITERATION_PADDING.bottom - nodeHeight
|
||||||
|
|
||||||
|
// Constrain position
|
||||||
|
const constrainedX = Math.max(minX, Math.min(maxX, flowPosition.x))
|
||||||
|
const constrainedY = Math.max(minY, Math.min(maxY, flowPosition.y))
|
||||||
|
|
||||||
|
// Convert back to screen coordinates
|
||||||
|
flowToScreenPosition({ x: constrainedX, y: constrainedY })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className='absolute z-10'
|
className='absolute z-10'
|
||||||
|
|||||||
@ -125,12 +125,12 @@ export const useNodesInteractions = () => {
|
|||||||
|
|
||||||
const { restrictPosition } = handleNodeIterationChildDrag(node)
|
const { restrictPosition } = handleNodeIterationChildDrag(node)
|
||||||
const { restrictPosition: restrictLoopPosition }
|
const { restrictPosition: restrictLoopPosition }
|
||||||
= handleNodeLoopChildDrag(node)
|
= handleNodeLoopChildDrag(node)
|
||||||
|
|
||||||
const { showHorizontalHelpLineNodes, showVerticalHelpLineNodes }
|
const { showHorizontalHelpLineNodes, showVerticalHelpLineNodes }
|
||||||
= handleSetHelpline(node)
|
= handleSetHelpline(node)
|
||||||
const showHorizontalHelpLineNodesLength
|
const showHorizontalHelpLineNodesLength
|
||||||
= showHorizontalHelpLineNodes.length
|
= showHorizontalHelpLineNodes.length
|
||||||
const showVerticalHelpLineNodesLength = showVerticalHelpLineNodes.length
|
const showVerticalHelpLineNodesLength = showVerticalHelpLineNodes.length
|
||||||
|
|
||||||
const newNodes = produce(nodes, (draft) => {
|
const newNodes = produce(nodes, (draft) => {
|
||||||
@ -716,7 +716,7 @@ export const useNodesInteractions = () => {
|
|||||||
targetHandle = 'target',
|
targetHandle = 'target',
|
||||||
toolDefaultValue,
|
toolDefaultValue,
|
||||||
},
|
},
|
||||||
{ prevNodeId, prevNodeSourceHandle, nextNodeId, nextNodeTargetHandle },
|
{ prevNodeId, prevNodeSourceHandle, nextNodeId, nextNodeTargetHandle, skipAutoConnect },
|
||||||
) => {
|
) => {
|
||||||
if (getNodesReadOnly()) return
|
if (getNodesReadOnly()) return
|
||||||
|
|
||||||
@ -808,7 +808,7 @@ export const useNodesInteractions = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let newEdge = null
|
let newEdge = null
|
||||||
if (nodeType !== BlockEnum.DataSource) {
|
if (nodeType !== BlockEnum.DataSource && !skipAutoConnect) {
|
||||||
newEdge = {
|
newEdge = {
|
||||||
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
|
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
|
||||||
type: CUSTOM_EDGE,
|
type: CUSTOM_EDGE,
|
||||||
@ -948,6 +948,7 @@ export const useNodesInteractions = () => {
|
|||||||
nodeType !== BlockEnum.IfElse
|
nodeType !== BlockEnum.IfElse
|
||||||
&& nodeType !== BlockEnum.QuestionClassifier
|
&& nodeType !== BlockEnum.QuestionClassifier
|
||||||
&& nodeType !== BlockEnum.LoopEnd
|
&& nodeType !== BlockEnum.LoopEnd
|
||||||
|
&& !skipAutoConnect
|
||||||
) {
|
) {
|
||||||
newEdge = {
|
newEdge = {
|
||||||
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,
|
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,
|
||||||
@ -1097,7 +1098,7 @@ export const useNodesInteractions = () => {
|
|||||||
)
|
)
|
||||||
let newPrevEdge = null
|
let newPrevEdge = null
|
||||||
|
|
||||||
if (nodeType !== BlockEnum.DataSource) {
|
if (nodeType !== BlockEnum.DataSource && !skipAutoConnect) {
|
||||||
newPrevEdge = {
|
newPrevEdge = {
|
||||||
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
|
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
|
||||||
type: CUSTOM_EDGE,
|
type: CUSTOM_EDGE,
|
||||||
@ -1137,6 +1138,7 @@ export const useNodesInteractions = () => {
|
|||||||
nodeType !== BlockEnum.IfElse
|
nodeType !== BlockEnum.IfElse
|
||||||
&& nodeType !== BlockEnum.QuestionClassifier
|
&& nodeType !== BlockEnum.QuestionClassifier
|
||||||
&& nodeType !== BlockEnum.LoopEnd
|
&& nodeType !== BlockEnum.LoopEnd
|
||||||
|
&& !skipAutoConnect
|
||||||
) {
|
) {
|
||||||
newNextEdge = {
|
newNextEdge = {
|
||||||
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,
|
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,
|
||||||
|
|||||||
@ -15,7 +15,9 @@ import {
|
|||||||
useNodesSyncDraft,
|
useNodesSyncDraft,
|
||||||
} from '@/app/components/workflow/hooks'
|
} from '@/app/components/workflow/hooks'
|
||||||
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
|
import ShortcutsName from '@/app/components/workflow/shortcuts-name'
|
||||||
import type { Node } from '@/app/components/workflow/types'
|
import { BlockEnum, type Node } from '@/app/components/workflow/types'
|
||||||
|
import PanelAddBlock from '@/app/components/workflow/nodes/iteration/panel-add-block'
|
||||||
|
import type { IterationNodeType } from '@/app/components/workflow/nodes/iteration/types'
|
||||||
|
|
||||||
type PanelOperatorPopupProps = {
|
type PanelOperatorPopupProps = {
|
||||||
id: string
|
id: string
|
||||||
@ -51,6 +53,9 @@ const PanelOperatorPopup = ({
|
|||||||
(showChangeBlock || canRunBySingle(data.type, isChildNode)) && (
|
(showChangeBlock || canRunBySingle(data.type, isChildNode)) && (
|
||||||
<>
|
<>
|
||||||
<div className='p-1'>
|
<div className='p-1'>
|
||||||
|
{data.type === BlockEnum.Iteration && (
|
||||||
|
<PanelAddBlock iterationNodeData={data as IterationNodeType} onClosePopup={onClosePopup}/>
|
||||||
|
)}
|
||||||
{
|
{
|
||||||
canRunBySingle(data.type, isChildNode) && (
|
canRunBySingle(data.type, isChildNode) && (
|
||||||
<div
|
<div
|
||||||
|
|||||||
116
web/app/components/workflow/nodes/iteration/panel-add-block.tsx
Normal file
116
web/app/components/workflow/nodes/iteration/panel-add-block.tsx
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useState,
|
||||||
|
} from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import type { OffsetOptions } from '@floating-ui/react'
|
||||||
|
import { useStoreApi } from 'reactflow'
|
||||||
|
|
||||||
|
import BlockSelector from '@/app/components/workflow/block-selector'
|
||||||
|
import type {
|
||||||
|
OnSelectBlock,
|
||||||
|
} from '@/app/components/workflow/types'
|
||||||
|
import {
|
||||||
|
BlockEnum,
|
||||||
|
} from '@/app/components/workflow/types'
|
||||||
|
import { useAvailableBlocks, useNodesMetaData, useNodesReadOnly, usePanelInteractions } from '../../hooks'
|
||||||
|
import type { IterationNodeType } from './types'
|
||||||
|
import { useWorkflowStore } from '../../store'
|
||||||
|
import { generateNewNode, getNodeCustomTypeByNodeDataType } from '../../utils'
|
||||||
|
import { ITERATION_CHILDREN_Z_INDEX } from '../../constants'
|
||||||
|
|
||||||
|
type AddBlockProps = {
|
||||||
|
renderTrigger?: (open: boolean) => React.ReactNode
|
||||||
|
offset?: OffsetOptions
|
||||||
|
iterationNodeData: IterationNodeType
|
||||||
|
onClosePopup: () => void
|
||||||
|
}
|
||||||
|
const AddBlock = ({
|
||||||
|
offset,
|
||||||
|
iterationNodeData,
|
||||||
|
onClosePopup,
|
||||||
|
}: AddBlockProps) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const store = useStoreApi()
|
||||||
|
const workflowStore = useWorkflowStore()
|
||||||
|
const { nodesReadOnly } = useNodesReadOnly()
|
||||||
|
const { handlePaneContextmenuCancel } = usePanelInteractions()
|
||||||
|
const [open, setOpen] = useState(false)
|
||||||
|
const { availableNextBlocks } = useAvailableBlocks(BlockEnum.Start, false)
|
||||||
|
const { nodesMap: nodesMetaDataMap } = useNodesMetaData()
|
||||||
|
|
||||||
|
const handleOpenChange = useCallback((open: boolean) => {
|
||||||
|
setOpen(open)
|
||||||
|
if (!open)
|
||||||
|
handlePaneContextmenuCancel()
|
||||||
|
}, [handlePaneContextmenuCancel])
|
||||||
|
|
||||||
|
const handleSelect = useCallback<OnSelectBlock>((type, toolDefaultValue) => {
|
||||||
|
const { getNodes } = store.getState()
|
||||||
|
const nodes = getNodes()
|
||||||
|
const nodesWithSameType = nodes.filter(node => node.data.type === type)
|
||||||
|
const { defaultValue } = nodesMetaDataMap![type]
|
||||||
|
|
||||||
|
// Find the parent iteration node
|
||||||
|
const parentIterationNode = nodes.find(node => node.data.start_node_id === iterationNodeData.start_node_id)
|
||||||
|
|
||||||
|
const { newNode } = generateNewNode({
|
||||||
|
type: getNodeCustomTypeByNodeDataType(type),
|
||||||
|
data: {
|
||||||
|
...(defaultValue as any),
|
||||||
|
title: nodesWithSameType.length > 0 ? `${defaultValue.title} ${nodesWithSameType.length + 1}` : defaultValue.title,
|
||||||
|
...toolDefaultValue,
|
||||||
|
_isCandidate: true,
|
||||||
|
// Set iteration-specific properties
|
||||||
|
isInIteration: true,
|
||||||
|
iteration_id: parentIterationNode?.id,
|
||||||
|
},
|
||||||
|
position: {
|
||||||
|
x: 0,
|
||||||
|
y: 0,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Set parent and z-index for iteration child
|
||||||
|
if (parentIterationNode) {
|
||||||
|
newNode.parentId = parentIterationNode.id
|
||||||
|
newNode.extent = 'parent' as any
|
||||||
|
newNode.zIndex = ITERATION_CHILDREN_Z_INDEX
|
||||||
|
}
|
||||||
|
|
||||||
|
workflowStore.setState({
|
||||||
|
candidateNode: newNode,
|
||||||
|
})
|
||||||
|
onClosePopup()
|
||||||
|
}, [store, workflowStore, nodesMetaDataMap, iterationNodeData.start_node_id, onClosePopup])
|
||||||
|
|
||||||
|
const renderTrigger = () => {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className='flex h-8 cursor-pointer items-center justify-between rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover'
|
||||||
|
>
|
||||||
|
{t('workflow.common.addBlock')}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<BlockSelector
|
||||||
|
open={open}
|
||||||
|
onOpenChange={handleOpenChange}
|
||||||
|
disabled={nodesReadOnly}
|
||||||
|
onSelect={handleSelect}
|
||||||
|
placement='right-start'
|
||||||
|
offset={offset ?? {
|
||||||
|
mainAxis: 4,
|
||||||
|
crossAxis: -8,
|
||||||
|
}}
|
||||||
|
trigger={renderTrigger}
|
||||||
|
popupClassName='!min-w-[256px]'
|
||||||
|
availableBlocksTypes={availableNextBlocks}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(AddBlock)
|
||||||
@ -380,6 +380,7 @@ export type OnNodeAdd = (
|
|||||||
prevNodeSourceHandle?: string
|
prevNodeSourceHandle?: string
|
||||||
nextNodeId?: string
|
nextNodeId?: string
|
||||||
nextNodeTargetHandle?: string
|
nextNodeTargetHandle?: string
|
||||||
|
skipAutoConnect?: boolean
|
||||||
},
|
},
|
||||||
) => void
|
) => void
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user