Merge branch 'main' into feat/rag-2

This commit is contained in:
twwu 2025-07-18 14:03:48 +08:00
commit 5b2c99e183
155 changed files with 8176 additions and 7681 deletions

View File

@ -54,7 +54,7 @@ REDIS_CLUSTERS_PASSWORD=
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456

View File

@ -211,7 +211,7 @@ class DatabaseConfig(BaseSettings):
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description="Backend for Celery task results. Options: 'database', 'redis'.",
default="database",
default="redis",
)
CELERY_BROKER_URL: Optional[str] = Field(

View File

@ -5,6 +5,7 @@ from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console.app.error import (
CompletionRequestError,
@ -27,7 +28,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from models.model import AppMode, Conversation, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@ -124,33 +125,16 @@ class MessageFeedbackApi(Resource):
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
args = parser.parse_args()
message_id = str(args["message_id"])
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")
feedback = message.admin_feedback
if not args["rating"] and feedback:
db.session.delete(feedback)
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=args["rating"],
from_source="admin",
from_account_id=current_user.id,
try:
MessageService.create_feedback(
app_model=app_model,
message_id=str(args["message_id"]),
user=current_user,
rating=args.get("rating"),
content=None,
)
db.session.add(feedback)
db.session.commit()
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}

View File

@ -211,10 +211,6 @@ class DatasetApi(Resource):
else:
data["embedding_available"] = True
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return data, 200
@setup_required

View File

@ -4,7 +4,7 @@ from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.website_service import WebsiteService
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
class WebsiteCrawlApi(Resource):
@ -24,10 +24,16 @@ class WebsiteCrawlApi(Resource):
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
WebsiteService.document_create_args_validate(args)
# crawl url
# Create typed request and validate
try:
result = WebsiteService.crawl_url(args)
api_request = WebsiteCrawlApiRequest.from_args(args)
except ValueError as e:
raise WebsiteCrawlError(str(e))
# Crawl URL using typed request
try:
result = WebsiteService.crawl_url(api_request)
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
@ -43,9 +49,16 @@ class WebsiteCrawlStatusApi(Resource):
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args()
# get crawl status
# Create typed request and validate
try:
result = WebsiteService.get_crawl_status(job_id, args["provider"])
api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
except ValueError as e:
raise WebsiteCrawlError(str(e))
# Get crawl status using typed request
try:
result = WebsiteService.get_crawl_status_typed(api_request)
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200

View File

@ -17,7 +17,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom

File diff suppressed because it is too large Load Diff

View File

@ -15,7 +15,8 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom

View File

@ -169,7 +169,3 @@ class AppQueueManager:
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
)
class GenerateTaskStoppedError(Exception):
pass

View File

@ -118,7 +118,7 @@ class AppRunner:
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode.value_of(model_config.mode)
model_mode = ModelMode(model_config.mode)
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template

View File

@ -11,10 +11,11 @@ from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom

View File

@ -10,10 +10,11 @@ from pydantic import ValidationError
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom

2
api/core/app/apps/exc.py Normal file
View File

@ -0,0 +1,2 @@
class GenerateTaskStoppedError(Exception):
pass

View File

@ -6,7 +6,8 @@ from typing import Optional, Union, cast
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,

View File

@ -1,4 +1,5 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,

View File

@ -13,7 +13,8 @@ import contexts
from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner

View File

@ -1,4 +1,5 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,

View File

@ -1,7 +1,8 @@
import logging
import time
from collections.abc import Generator
from typing import Optional, Union
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Any, Optional, Union
from sqlalchemy.orm import Session
@ -13,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
@ -38,11 +40,13 @@ from core.app.entities.queue_entities import (
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
@ -54,6 +58,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -246,315 +251,492 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@contextmanager
def _database_session(self):
"""Context manager for database sessions."""
with Session(db.engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
def _ensure_workflow_initialized(self) -> None:
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
return graph_runtime_state
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline._ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
yield start_resp
def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
def _handle_node_started_event(
self, event: QueueNodeStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node started events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_response:
yield node_start_response
def _handle_node_succeeded_event(
self, event: QueueNodeSucceededEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
if node_success_response:
yield node_success_response
def _handle_node_failed_events(
self,
event: Union[
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
if node_failed_response:
yield node_failed_response
def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events."""
self._ensure_workflow_initialized()
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
def _handle_parallel_branch_finished_events(
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_finish_resp
def _handle_iteration_start_event(
self, event: QueueIterationStartEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle iteration start events."""
self._ensure_workflow_initialized()
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_start_resp
def _handle_iteration_next_event(
self, event: QueueIterationNextEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle iteration next events."""
self._ensure_workflow_initialized()
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_next_resp
def _handle_iteration_completed_event(
self, event: QueueIterationCompletedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle iteration completed events."""
self._ensure_workflow_initialized()
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_finish_resp
def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle loop start events."""
self._ensure_workflow_initialized()
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_start_resp
def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle loop next events."""
self._ensure_workflow_initialized()
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_next_resp
def _handle_loop_completed_event(
self, event: QueueLoopCompletedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle loop completed events."""
self._ensure_workflow_initialized()
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_finish_resp
def _handle_workflow_succeeded_event(
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
yield workflow_finish_resp
def _handle_workflow_partial_success_event(
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
yield workflow_finish_resp
def _handle_workflow_failed_and_stop_events(
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed and stop events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowExecutionStatus.STOPPED,
error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
yield workflow_finish_resp
def _handle_text_chunk_event(
self,
event: QueueTextChunkEvent,
*,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
delta_text = event.text
if delta_text is None:
return
# only publish tts message at text chunk streaming
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
yield self._workflow_response_converter.handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
def _get_event_handlers(self) -> dict[type, Callable]:
"""Get mapping of event types to their handlers using fluent pattern."""
return {
# Basic events
QueuePingEvent: self._handle_ping_event,
QueueErrorEvent: self._handle_error_event,
QueueTextChunkEvent: self._handle_text_chunk_event,
# Workflow events
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
QueueIterationCompletedEvent: self._handle_iteration_completed_event,
# Loop events
QueueLoopStartEvent: self._handle_loop_start_event,
QueueLoopNextEvent: self._handle_loop_next_event,
QueueLoopCompletedEvent: self._handle_loop_completed_event,
# Agent events
QueueAgentLogEvent: self._handle_agent_log_event,
}
def _dispatch_event(
self,
event: Any,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
) -> Generator[StreamResponse, None, None]:
"""Dispatch events using elegant pattern matching."""
handlers = self._get_event_handlers()
event_type = type(event)
# Direct handler lookup
if handler := handlers.get(event_type):
yield from handler(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle node failure events with isinstance check
if isinstance(
event,
(
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
),
):
yield from self._handle_node_failed_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle parallel branch finished events with isinstance check
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle workflow failed and stop events with isinstance check
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# For unhandled events, we continue (original behavior)
return
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 44-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state = None
for queue_message in self._base_task_pipeline._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._base_task_pipeline._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
match event:
case QueueWorkflowStartedEvent():
graph_runtime_state = event.graph_runtime_state
yield from self._handle_workflow_started_event(event)
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
yield start_resp
elif isinstance(
event,
QueueNodeRetryEvent,
):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if response:
yield response
elif isinstance(event, QueueNodeStartedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
event=event
)
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
if node_success_response:
yield node_success_response
elif isinstance(
event,
QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
if node_failed_response:
yield node_failed_response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
parallel_start_resp = (
self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
)
yield parallel_start_resp
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
parallel_finish_resp = (
self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
)
yield parallel_finish_resp
elif isinstance(event, QueueIterationStartEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_start_resp
elif isinstance(event, QueueIterationNextEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_next_resp
elif isinstance(event, QueueIterationCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield iter_finish_resp
elif isinstance(event, QueueLoopStartEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_start_resp
elif isinstance(event, QueueLoopNextEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_next_resp
elif isinstance(event, QueueLoopCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_finish_resp
elif isinstance(event, QueueWorkflowSucceededEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
case QueueTextChunkEvent():
yield from self._handle_text_chunk_event(
event, tts_publisher=tts_publisher, queue_message=queue_message
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
case QueueErrorEvent():
yield from self._handle_error_event(event)
break
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowExecutionStatus.STOPPED,
error_message=event.error
if isinstance(event, QueueWorkflowFailedEvent)
else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(queue_message)
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._workflow_response_converter.handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
else:
continue
# Handle all other events through elegant dispatch
case _:
if responses := list(
self._dispatch_event(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
):
yield from responses
if tts_publisher:
tts_publisher.publish(None)

View File

@ -21,7 +21,7 @@ def encrypt_token(tenant_id: str, token: str):
return base64.b64encode(encrypted_token).decode()
def decrypt_token(tenant_id: str, token: str):
def decrypt_token(tenant_id: str, token: str) -> str:
return rsa.decrypt(base64.b64decode(token), tenant_id)

View File

@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum):
COMPLETION = "completion"
CHAT = "chat"
@classmethod
def value_of(cls, value: str) -> "ModelMode":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
prompt_file_contents: dict[str, Any] = {}
@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = {key: str(value) for key, value in inputs.items()}
model_mode = ModelMode.value_of(model_config.mode)
model_mode = ModelMode(model_config.mode)
if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages(
app_mode=app_mode,

View File

@ -238,9 +238,11 @@ class WordExtractor(BaseExtractor):
paragraph_content = []
for run in paragraph.runs:
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
# Process drawing type images
drawing_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
)
has_drawing = False
for drawing in drawing_elements:
blip_elements = drawing.findall(
".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip"
@ -252,6 +254,34 @@ class WordExtractor(BaseExtractor):
if embed_id:
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
# Process pict type images
shape_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
)
for shape in shape_elements:
# Find image data in VML
shape_image = shape.find(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}binData"
)
if shape_image is not None and shape_image.text:
image_id = shape_image.get(
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
image_part = doc.part.rels[image_id].target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
# Find imagedata element in VML
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
if image_data is not None:
image_id = image_data.get("id") or image_data.get(
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
image_part = doc.part.rels[image_id].target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
if run.text.strip():
paragraph_content.append(run.text.strip())

View File

@ -1137,7 +1137,7 @@ class DatasetRetrieval:
def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
):
model_mode = ModelMode.value_of(mode)
model_mode = ModelMode(mode)
input_text = query
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]

View File

@ -6,7 +6,6 @@ import json
import logging
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@ -206,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Update the in-memory cache for faster subsequent lookups
logger.debug(f"Updating cache for execution_id: {db_model.id}")
self._execution_cache[db_model.id] = db_model
def get(self, execution_id: str) -> Optional[WorkflowExecution]:
"""
Retrieve a WorkflowExecution by its ID.
First checks the in-memory cache, and if not found, queries the database.
If found in the database, adds it to the cache for future lookups.
Args:
execution_id: The workflow execution ID
Returns:
The WorkflowExecution instance if found, None otherwise
"""
# First check the cache
if execution_id in self._execution_cache:
logger.debug(f"Cache hit for execution_id: {execution_id}")
# Convert cached DB model to domain model
cached_db_model = self._execution_cache[execution_id]
return self._to_domain_model(cached_db_model)
# If not in cache, query the database
logger.debug(f"Cache miss for execution_id: {execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowRun).where(
WorkflowRun.id == execution_id,
WorkflowRun.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowRun.app_id == self._app_id)
db_model = session.scalar(stmt)
if db_model:
# Add DB model to cache
self._execution_cache[execution_id] = db_model
# Convert to domain model and return
return self._to_domain_model(db_model)
return None

View File

@ -7,7 +7,7 @@ import logging
from collections.abc import Sequence
from typing import Optional, Union
from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy import UnaryExpression, asc, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@ -218,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
self._node_execution_cache[db_model.node_execution_id] = db_model
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.
First checks the in-memory cache, and if not found, queries the database.
If found in the database, adds it to the cache for future lookups.
Args:
node_execution_id: The node execution ID
Returns:
The NodeExecution instance if found, None otherwise
"""
# First check the cache
if node_execution_id in self._node_execution_cache:
logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
# Convert cached DB model to domain model
cached_db_model = self._node_execution_cache[node_execution_id]
return self._to_domain_model(cached_db_model)
# If not in cache, query the database
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
db_model = session.scalar(stmt)
if db_model:
# Add DB model to cache
self._node_execution_cache[node_execution_id] = db_model
# Convert to domain model and return
return self._to_domain_model(db_model)
return None
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,
@ -346,68 +305,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
domain_models.append(domain_model)
return domain_models
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.
This method queries the database directly and updates the cache with any
retrieved executions that have a node_execution_id.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running NodeExecution instances
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
db_models = session.scalars(stmt).all()
domain_models = []
for model in db_models:
# Update cache if node_execution_id is present
if model.node_execution_id:
self._node_execution_cache[model.node_execution_id] = model
# Convert to domain model
domain_model = self._to_domain_model(model)
domain_models.append(domain_model)
return domain_models
def clear(self) -> None:
"""
Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
This method deletes all WorkflowNodeExecution records that match the tenant_id
and app_id (if provided) associated with this repository instance.
It also clears the in-memory cache.
"""
with self._session_factory() as session:
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
result = session.execute(stmt)
session.commit()
deleted_count = result.rowcount
logger.info(
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
+ (f" and app {self._app_id}" if self._app_id else "")
)
# Clear the in-memory cache
self._node_execution_cache.clear()
logger.info("Cleared in-memory node execution cache")

View File

@ -2,7 +2,7 @@ from core.workflow.nodes.base import BaseNode
class WorkflowNodeRunFailedError(Exception):
def __init__(self, node_instance: BaseNode, error: str):
self.node_instance = node_instance
self.error = error
super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")
def __init__(self, node: BaseNode, err_msg: str):
self._node = node
self._error = err_msg
super().__init__(f"Node {node.title} run failed: {err_msg}")

View File

@ -1,3 +1,4 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
from .graph_engine import GraphEngine
__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]

View File

@ -12,7 +12,7 @@ from typing import Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
@ -48,11 +48,9 @@ from core.workflow.nodes.agent.entities import AgentNodeData
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
@ -260,12 +258,16 @@ class GraphEngine:
# convert to specific node
node_type = NodeType(node_config.get("data", {}).get("type"))
node_version = node_config.get("data", {}).get("version", "1")
# Import here to avoid circular import
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
# init workflow run state
node_instance = node_cls( # type: ignore
node = node_cls(
id=route_node_state.id,
config=node_config,
graph_init_params=self.init_params,
@ -274,11 +276,11 @@ class GraphEngine:
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
node.init_node_data(node_config.get("data", {}))
try:
# run node
generator = self._run_node(
node_instance=node_instance,
node=node,
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
@ -306,16 +308,16 @@ class GraphEngine:
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
error=str(e),
id=node_instance.id,
id=node.id,
node_id=next_node_id,
node_type=node_type,
node_data=node_instance.node_data,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
raise e
@ -337,7 +339,7 @@ class GraphEngine:
edge = edge_mappings[0]
if (
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and node.error_strategy == ErrorStrategy.FAIL_BRANCH
and edge.run_condition is None
):
break
@ -413,8 +415,8 @@ class GraphEngine:
next_node_id = final_node_id
elif (
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and node_instance.should_continue_on_error
node.continue_on_error
and node.error_strategy == ErrorStrategy.FAIL_BRANCH
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
):
break
@ -597,7 +599,7 @@ class GraphEngine:
def _run_node(
self,
node_instance: BaseNode[BaseNodeData],
node: BaseNode,
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
@ -611,29 +613,29 @@ class GraphEngine:
# trigger node run start event
agent_strategy = (
AgentNodeStrategyInit(
name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name,
icon=cast(AgentNode, node_instance).agent_strategy_icon,
name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name,
icon=cast(AgentNode, node).agent_strategy_icon,
)
if node_instance.node_type == NodeType.AGENT
if node.type_ == NodeType.AGENT
else None
)
yield NodeRunStartedEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id,
predecessor_node_id=node.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy,
node_version=node_instance.version(),
node_version=node.version(),
)
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
max_retries = node.retry_config.max_retries
retry_interval = node.retry_config.retry_interval_seconds
retries = 0
should_continue_retry = True
while should_continue_retry and retries <= max_retries:
@ -642,7 +644,7 @@ class GraphEngine:
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
# yield control to other threads
time.sleep(0.001)
event_stream = node_instance.run()
event_stream = node.run()
for event in event_stream:
if isinstance(event, GraphEngineEvent):
# add parallel info to iteration event
@ -658,21 +660,21 @@ class GraphEngine:
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
and node_instance.node_type == NodeType.HTTP_REQUEST
and node.type_ == NodeType.HTTP_REQUEST
and run_result.outputs
and not node_instance.should_continue_on_error
and not node.continue_on_error
):
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
if node_instance.should_retry and retries < max_retries:
if node.retry and retries < max_retries:
retries += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=str(uuid.uuid4()),
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id,
predecessor_node_id=node.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
@ -680,17 +682,17 @@ class GraphEngine:
error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
node_version=node_instance.version(),
node_version=node.version(),
)
time.sleep(retry_interval)
break
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
if node.continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
node,
event.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
@ -701,44 +703,44 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
should_continue_retry = False
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if (
node_instance.should_continue_on_error
and self.graph.edge_mapping.get(node_instance.node_id)
and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
node.continue_on_error
and self.graph.edge_mapping.get(node.node_id)
and node.error_strategy is ErrorStrategy.FAIL_BRANCH
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(
@ -758,7 +760,7 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
@ -783,26 +785,26 @@ class GraphEngine:
run_result.metadata = metadata_dict
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
should_continue_retry = False
break
elif isinstance(event, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
chunk_content=event.chunk_content,
from_variable_selector=event.from_variable_selector,
route_node_state=route_node_state,
@ -810,14 +812,14 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
elif isinstance(event, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
retriever_resources=event.retriever_resources,
context=event.context,
route_node_state=route_node_state,
@ -825,7 +827,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
except GenerateTaskStoppedError:
# trigger node run failed event
@ -833,20 +835,20 @@ class GraphEngine:
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
id=node.id,
node_id=node.node_id,
node_type=node.type_,
node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
node_version=node.version(),
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
logger.exception(f"Node {node.title} run failed")
raise e
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
@ -886,22 +888,14 @@ class GraphEngine:
def _handle_continue_on_error(
self,
node_instance: BaseNode[BaseNodeData],
node: BaseNode,
error_result: NodeRunResult,
variable_pool: VariablePool,
handle_exceptions: list[str] = [],
) -> NodeRunResult:
"""
handle continue on error when self._should_continue_on_error is True
:param error_result (NodeRunResult): error run result
:param variable_pool (VariablePool): variable pool
:return: excption run result
"""
# add error message and error type to variable pool
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
variable_pool.add([node.node_id, "error_message"], error_result.error)
variable_pool.add([node.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
handle_exceptions.append(error_result.error or "")
node_error_args: dict[str, Any] = {
@ -909,21 +903,21 @@ class GraphEngine:
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy,
},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
return NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
**node.default_value_dict,
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
if self.graph.edge_mapping.get(node_instance.node_id):
elif node.error_strategy is ErrorStrategy.FAIL_BRANCH:
if self.graph.edge_mapping.get(node.node_id):
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
return NodeRunResult(
**node_error_args,

View File

@ -1,5 +1,4 @@
import json
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
@ -11,8 +10,10 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.agent.strategy.plugin import PluginAgentStrategy
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
@ -25,45 +26,75 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
AgentInputTypeError,
AgentInvocationError,
AgentMessageTransformError,
AgentVariableNotFoundError,
AgentVariableTypeError,
ToolFileNotFoundError,
)
class AgentNode(ToolNode):
class AgentNode(BaseNode):
"""
Agent Node
"""
_node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
"""
Run the agent node
"""
node_data = cast(AgentNodeData, self.node_data)
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
agent_strategy_provider_name=node_data.agent_strategy_provider_name,
agent_strategy_name=node_data.agent_strategy_name,
agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
agent_strategy_name=self._node_data.agent_strategy_name,
)
except Exception as e:
yield RunCompletedEvent(
@ -81,13 +112,13 @@ class AgentNode(ToolNode):
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
node_data=self._node_data,
strategy=strategy,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
node_data=self._node_data,
for_log=True,
strategy=strategy,
)
@ -105,59 +136,39 @@ class AgentNode(ToolNode):
credentials=credentials,
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=f"Failed to invoke agent: {str(e)}",
error=str(error),
)
)
return
try:
# convert tool messages
agent_thoughts: list = []
thought_log_message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LOG,
message=ToolInvokeMessage.LogMessage(
id=str(uuid.uuid4()),
label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
parent_id=None,
error=None,
status=ToolInvokeMessage.LogMessage.LogStatus.START,
data={
"strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"parameters": parameters_for_log,
"thought_process": "Agent strategy execution started",
},
metadata={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
},
),
)
def enhanced_message_stream():
yield thought_log_message
yield from message_stream
yield from self._transform_message(
message_stream,
{
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
},
parameters_for_log,
agent_thoughts,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_type=self.type_,
node_id=self.node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=f"Failed to transform agent message: {str(e)}",
error=str(transform_error),
)
)
@ -194,7 +205,7 @@ class AgentNode(ToolNode):
if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise ValueError(f"Variable {agent_input.value} does not exist")
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}:
# variable_pool.convert_template expects a string template,
@ -216,7 +227,7 @@ class AgentNode(ToolNode):
except json.JSONDecodeError:
parameter_value = parameter_value
else:
raise ValueError(f"Unknown agent input type '{agent_input.type}'")
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
@ -259,7 +270,7 @@ class AgentNode(ToolNode):
)
extra = tool.get("extra", {})
runtime_variable_pool = variable_pool if self.node_data.version != "1" else None
runtime_variable_pool = variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
)
@ -343,19 +354,14 @@ class AgentNode(ToolNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: BaseNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(AgentNodeData, node_data)
# Create typed NodeData from dict
typed_node_data = AgentNodeData.model_validate(node_data)
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
input = node_data.agent_parameters[parameter_name]
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
@ -380,7 +386,7 @@ class AgentNode(ToolNode):
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self.node_data).agent_strategy_provider_name
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
@ -448,3 +454,236 @@ class AgentNode(ToolNode):
return tools
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
if message.message.json_object is not None:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, File)
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
json_output.extend(json)
else:
json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

View File

@ -0,0 +1,124 @@
from typing import Optional
class AgentNodeError(Exception):
"""Base exception for all agent node errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class AgentStrategyError(AgentNodeError):
"""Exception raised when there's an error with the agent strategy."""
def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
self.strategy_name = strategy_name
self.provider_name = provider_name
super().__init__(message)
class AgentStrategyNotFoundError(AgentStrategyError):
"""Exception raised when the specified agent strategy is not found."""
def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
super().__init__(
f"Agent strategy '{strategy_name}' not found"
+ (f" for provider '{provider_name}'" if provider_name else ""),
strategy_name,
provider_name,
)
class AgentInvocationError(AgentNodeError):
"""Exception raised when there's an error invoking the agent."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
self.original_error = original_error
super().__init__(message)
class AgentParameterError(AgentNodeError):
"""Exception raised when there's an error with agent parameters."""
def __init__(self, message: str, parameter_name: Optional[str] = None):
self.parameter_name = parameter_name
super().__init__(message)
class AgentVariableError(AgentNodeError):
"""Exception raised when there's an error with variables in the agent node."""
def __init__(self, message: str, variable_name: Optional[str] = None):
self.variable_name = variable_name
super().__init__(message)
class AgentVariableNotFoundError(AgentVariableError):
"""Exception raised when a variable is not found in the variable pool."""
def __init__(self, variable_name: str):
super().__init__(f"Variable '{variable_name}' does not exist", variable_name)
class AgentInputTypeError(AgentNodeError):
"""Exception raised when an unknown agent input type is encountered."""
def __init__(self, input_type: str):
super().__init__(f"Unknown agent input type '{input_type}'")
class ToolFileError(AgentNodeError):
"""Exception raised when there's an error with a tool file."""
def __init__(self, message: str, file_id: Optional[str] = None):
self.file_id = file_id
super().__init__(message)
class ToolFileNotFoundError(ToolFileError):
"""Exception raised when a tool file is not found."""
def __init__(self, file_id: str):
super().__init__(f"Tool file '{file_id}' does not exist", file_id)
class AgentMessageTransformError(AgentNodeError):
"""Exception raised when there's an error transforming agent messages."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
self.original_error = original_error
super().__init__(message)
class AgentModelError(AgentNodeError):
"""Exception raised when there's an error with the model used by the agent."""
def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
self.model_name = model_name
self.provider = provider
super().__init__(message)
class AgentMemoryError(AgentNodeError):
"""Exception raised when there's an error with the agent's memory."""
def __init__(self, message: str, conversation_id: Optional[str] = None):
self.conversation_id = conversation_id
super().__init__(message)
class AgentVariableTypeError(AgentNodeError):
"""Exception raised when a variable has an unexpected type."""
def __init__(
self,
message: str,
variable_name: Optional[str] = None,
expected_type: Optional[str] = None,
actual_type: Optional[str] = None,
):
self.variable_name = variable_name
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
@ -12,14 +12,37 @@ from core.workflow.nodes.answer.entities import (
VarGenerateRouteChunk,
)
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData
class AnswerNode(BaseNode):
_node_type = NodeType.ANSWER
_node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -30,7 +53,7 @@ class AnswerNode(BaseNode[AnswerNodeData]):
:return:
"""
# generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)
answer = ""
files = []
@ -60,16 +83,12 @@ class AnswerNode(BaseNode[AnswerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AnswerNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
# Create typed NodeData from dict
typed_node_data = AnswerNodeData.model_validate(node_data)
variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}

View File

@ -122,13 +122,13 @@ class RetryConfig(BaseModel):
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
version: str = "1"
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self):
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}

View File

@ -1,28 +1,22 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from .entities import BaseNodeData
if TYPE_CHECKING:
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__)
GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[GenericNodeData]
class BaseNode:
_node_type: ClassVar[NodeType]
def __init__(
@ -56,8 +50,8 @@ class BaseNode(Generic[GenericNodeData]):
self.node_id = node_id
node_data = self._node_data_cls.model_validate(config.get("data", {}))
self.node_data = node_data
@abstractmethod
def init_node_data(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
@ -130,9 +124,9 @@ class BaseNode(Generic[GenericNodeData]):
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {}))
# Pass raw dict data instead of creating NodeData instance
data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
)
return data
@ -142,32 +136,16 @@ class BaseNode(Generic[GenericNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: GenericNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
Get default config of node.
:param filters: filter by node config parameters.
:return:
"""
return {}
@property
def node_type(self) -> NodeType:
"""
Get node type
:return:
"""
def type_(self) -> NodeType:
return self._node_type
@classmethod
@ -181,19 +159,68 @@ class BaseNode(Generic[GenericNodeData]):
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@property
def should_continue_on_error(self) -> bool:
"""judge if should continue on error
Returns:
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
def continue_on_error(self) -> bool:
return False
@property
def should_retry(self) -> bool:
"""judge if should retry
def retry(self) -> bool:
return False
Returns:
bool: if should retry
"""
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
# Abstract methods that subclasses must implement to provide access
# to BaseNodeData properties in a type-safe way
@abstractmethod
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
"""Get the error strategy for this node."""
...
@abstractmethod
def _get_retry_config(self) -> RetryConfig:
"""Get the retry configuration for this node."""
...
@abstractmethod
def _get_title(self) -> str:
"""Get the node title."""
...
@abstractmethod
def _get_description(self) -> Optional[str]:
"""Get the node description."""
...
@abstractmethod
def _get_default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
...
@abstractmethod
def get_base_node_data(self) -> BaseNodeData:
"""Get the BaseNodeData object for this node."""
...
# Public interface properties that delegate to abstract methods
@property
def error_strategy(self) -> Optional[ErrorStrategy]:
"""Get the error strategy for this node."""
return self._get_error_strategy()
@property
def retry_config(self) -> RetryConfig:
"""Get the retry configuration for this node."""
return self._get_retry_config()
@property
def title(self) -> str:
"""Get the node title."""
return self._get_title()
@property
def description(self) -> Optional[str]:
"""Get the node description."""
return self._get_description()
@property
def default_value_dict(self) -> dict[str, Any]:
"""Get the default values dictionary for this node."""
return self._get_default_value_dict()

View File

@ -11,8 +11,9 @@ from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .exc import (
CodeNodeError,
@ -21,10 +22,32 @@ from .exc import (
)
class CodeNode(BaseNode[CodeNodeData]):
_node_data_cls = CodeNodeData
class CodeNode(BaseNode):
_node_type = NodeType.CODE
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
@ -47,12 +70,12 @@ class CodeNode(BaseNode[CodeNodeData]):
def _run(self) -> NodeRunResult:
# Get code language
code_language = self.node_data.code_language
code = self.node_data.code
code_language = self._node_data.code_language
code = self._node_data.code
# Get variables
variables = {}
for variable_selector in self.node_data.variables:
for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment):
@ -68,7 +91,7 @@ class CodeNode(BaseNode[CodeNodeData]):
)
# Transform result
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -334,16 +357,20 @@ class CodeNode(BaseNode[CodeNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = CodeNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
for variable_selector in typed_node_data.variables
}
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -5,7 +5,7 @@ import logging
import os
import tempfile
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast
import chardet
import docx
@ -28,7 +28,8 @@ from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
@ -36,21 +37,43 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
class DocumentExtractorNode(BaseNode):
"""
Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files.
"""
_node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR
_node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = DocumentExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
variable_selector = self.node_data.variable_selector
variable_selector = self._node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None:
@ -97,16 +120,12 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DocumentExtractorNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {node_id + ".files": node_data.variable_selector}
# Create typed NodeData from dict
typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
return {node_id + ".files": typed_node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:

View File

@ -1,14 +1,40 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData
class EndNode(BaseNode):
_node_type = NodeType.END
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = EndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -18,7 +44,7 @@ class EndNode(BaseNode[EndNodeData]):
Run node
:return:
"""
output_variables = self.node_data.outputs
output_variables = self._node_data.outputs
outputs = {}
for variable_selector in output_variables:

View File

@ -37,7 +37,3 @@ class ErrorStrategy(StrEnum):
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE

View File

@ -11,7 +11,8 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser
from factories import file_factory
@ -32,10 +33,32 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
class HttpRequestNode(BaseNode[HttpRequestNodeData]):
_node_data_cls = HttpRequestNodeData
class HttpRequestNode(BaseNode):
_node_type = NodeType.HTTP_REQUEST
_node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = HttpRequestNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return {
@ -69,8 +92,8 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
process_data = {}
try:
http_executor = Executor(
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
node_data=self._node_data,
timeout=self._get_request_timeout(self._node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
@ -78,7 +101,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
if not response.response.is_success and (self.continue_on_error or self.retry):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
@ -131,15 +154,18 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: HttpRequestNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = HttpRequestNodeData.model_validate(node_data)
selectors: list[VariableSelector] = []
selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
if node_data.body:
body_type = node_data.body.type
data = node_data.body.data
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
if typed_node_data.body:
body_type = typed_node_data.body.type
data = typed_node_data.body.data
match body_type:
case "binary":
if len(data) != 1:
@ -217,3 +243,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
files.append(file)
return ArrayFileSegment(value=files)
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Any, Literal, Optional
from typing_extensions import deprecated
@ -7,16 +7,39 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData
class IfElseNode(BaseNode):
_node_type = NodeType.IF_ELSE
_node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = IfElseNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -36,8 +59,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if self.node_data.cases:
for case in self.node_data.cases:
if self._node_data.cases:
for case in self._node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions,
@ -63,8 +86,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
input_conditions, group_result, final_result = _should_not_use_old_function(
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=self.node_data.conditions or [],
operator=self.node_data.logical_operator or "and",
conditions=self._node_data.conditions or [],
operator=self._node_data.logical_operator or "and",
)
selected_case_id = "true" if final_result else "false"
@ -98,10 +121,13 @@ class IfElseNode(BaseNode[IfElseNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = IfElseNodeData.model_validate(node_data)
var_mapping: dict[str, list[str]] = {}
for case in node_data.cases or []:
for case in typed_node_data.cases or []:
for condition in case.conditions:
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
var_mapping[key] = condition.variable_selector

View File

@ -36,7 +36,8 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
@ -56,14 +57,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class IterationNode(BaseNode[IterationNodeData]):
class IterationNode(BaseNode):
"""
Iteration Node.
"""
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = IterationNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {
@ -83,10 +106,10 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
Run the node.
"""
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
@ -116,10 +139,10 @@ class IterationNode(BaseNode[IterationNodeData]):
graph_config = self.graph_config
if not self.node_data.start_node_id:
if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
root_node_id = self.node_data.start_node_id
root_node_id = self._node_data.start_node_id
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
@ -161,8 +184,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunStartedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"iterator_length": len(iterator_list_value)},
@ -172,8 +195,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=0,
pre_iteration_output=None,
duration=None,
@ -181,11 +204,11 @@ class IterationNode(BaseNode[IterationNodeData]):
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self.node_data.is_parallel:
if self._node_data.is_parallel:
futures: list[Future] = []
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(
max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
@ -242,7 +265,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
# Flatten the list of lists
@ -253,8 +276,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@ -278,8 +301,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@ -305,21 +328,17 @@ class IterationNode(BaseNode[IterationNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = IterationNodeData.model_validate(node_data)
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": node_data.iterator_selector,
f"{node_id}.input_selector": typed_node_data.iterator_selector,
}
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
@ -375,7 +394,7 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
if not isinstance(event, BaseNodeEvent):
return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
iter_metadata = {
@ -438,12 +457,12 @@ class IterationNode(BaseNode[IterationNodeData]):
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
if self.node_data.is_parallel:
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
@ -456,8 +475,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@ -478,7 +497,7 @@ class IterationNode(BaseNode[IterationNodeData]):
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@ -491,15 +510,15 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@ -512,15 +531,15 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@ -531,12 +550,12 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.remove([node_id])
# iteration run failed
if self.node_data.is_parallel:
if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
@ -549,8 +568,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@ -569,7 +588,7 @@ class IterationNode(BaseNode[IterationNodeData]):
return
yield metadata_event
current_output_segment = variable_pool.get(self.node_data.output_selector)
current_output_segment = variable_pool.get(self._node_data.output_selector)
if current_output_segment is None:
raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value
@ -588,8 +607,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None,
@ -601,8 +620,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
iteration_node_type=self.type_,
iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},

View File

@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.iteration.entities import IterationStartNodeData
class IterationStartNode(BaseNode[IterationStartNodeData]):
class IterationStartNode(BaseNode):
"""
Iteration Start Node.
"""
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = IterationStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -1,10 +1,10 @@
from collections.abc import Sequence
from typing import Any, Literal, Optional
from typing import Literal, Optional
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm.entities import VisionConfig
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
class RerankingModelConfig(BaseModel):
@ -55,14 +55,6 @@ class MultipleRetrievalConfig(BaseModel):
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
class ModelConfig(BaseModel):
provider: str
name: str
mode: str
completion_params: dict[str, Any] = {}
class SingleRetrievalConfig(BaseModel):
"""
Single Retrieval Config.
@ -125,7 +117,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
single_retrieval_config: Optional[SingleRetrievalConfig] = None
metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
metadata_model_config: Optional[ModelConfig] = None
metadata_model_config: ModelConfig
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
vision: VisionConfig = Field(default_factory=VisionConfig)

View File

@ -4,7 +4,7 @@ import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
@ -15,20 +15,31 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.entities.message_entities import (
PromptMessageRole,
)
from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelType,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
from core.variables import (
StringSegment,
)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
)
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
@ -38,7 +49,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -46,7 +58,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData, ModelConfig
from .entities import KnowledgeRetrievalNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
@ -56,6 +68,10 @@ from .exc import (
ModelQuotaExceededError,
)
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
logger = logging.getLogger(__name__)
default_retrieval_model = {
@ -67,18 +83,76 @@ default_retrieval_model = {
}
class KnowledgeRetrievalNode(LLMNode):
_node_data_cls = KnowledgeRetrievalNodeData # type: ignore
class KnowledgeRetrievalNode(BaseNode):
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
_node_data: KnowledgeRetrievalNodeData
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls):
return "1"
def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -119,7 +193,7 @@ class KnowledgeRetrievalNode(LLMNode):
# retrieve knowledge
try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -435,20 +509,15 @@ class KnowledgeRetrievalNode(LLMNode):
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
metadata_model_config = node_data.metadata_model_config
if metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self.get_model_config(metadata_model_config)
# get metadata model instance and fetch model config
model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
metadata_fields=all_metadata_fields,
query=query or "",
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
@ -458,16 +527,23 @@ class KnowledgeRetrievalNode(LLMNode):
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)
result_text = ""
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.metadata_model_config, # type: ignore
generator = LLMNode.invoke_llm(
node_data_model=node_data.metadata_model_config,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
)
for event in generator:
@ -557,17 +633,13 @@ class KnowledgeRetrievalNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: KnowledgeRetrievalNodeData, # type: ignore
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
variable_mapping = {}
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
@ -629,7 +701,7 @@ class KnowledgeRetrievalNode(LLMNode):
)
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore
model_mode = ModelMode(node_data.metadata_model_config.mode)
input_text = query
prompt_messages: list[LLMNodeChatModelMessage] = []

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Any, Literal, Union
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, Optional, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -7,16 +7,39 @@ from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData
class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ListOperatorNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -26,9 +49,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
process_data: dict[str, list] = {}
outputs: dict[str, Any] = {}
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
if variable is None:
error_message = f"Variable not found for selector: {self.node_data.variable}"
error_message = f"Variable not found for selector: {self._node_data.variable}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@ -48,7 +71,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
return NodeRunResult(
@ -64,19 +87,19 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
try:
# Filter
if self.node_data.filter_by.enabled:
if self._node_data.filter_by.enabled:
variable = self._apply_filter(variable)
# Extract
if self.node_data.extract_by.enabled:
if self._node_data.extract_by.enabled:
variable = self._extract_slice(variable)
# Order
if self.node_data.order_by.enabled:
if self._node_data.order_by.enabled:
variable = self._apply_order(variable)
# Slice
if self.node_data.limit.enabled:
if self._node_data.limit.enabled:
variable = self._apply_slice(variable)
outputs = {
@ -104,7 +127,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
filter_func: Callable[[Any], bool]
result: list[Any] = []
for condition in self.node_data.filter_by.conditions:
for condition in self._node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
@ -137,14 +160,14 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
result = _order_number(order=self.node_data.order_by.value, array=variable.value)
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
@ -152,13 +175,13 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
result = variable.value[: self.node_data.limit.size]
result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
value -= 1

View File

@ -1,4 +1,4 @@
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator
@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData):
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None
structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")

View File

@ -59,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
NodeEvent,
@ -90,17 +91,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
logger = logging.getLogger(__name__)
class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
class LLMNode(BaseNode):
_node_type = NodeType.LLM
_node_data: LLMNodeData
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
@ -138,6 +138,27 @@ class LLMNode(BaseNode[LLMNodeData]):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LLMNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -152,13 +173,13 @@ class LLMNode(BaseNode[LLMNodeData]):
try:
# init messages template
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self.node_data)
inputs = self._fetch_inputs(node_data=self._node_data)
# fetch jinja2 inputs
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
# merge inputs
inputs.update(jinja_inputs)
@ -169,9 +190,9 @@ class LLMNode(BaseNode[LLMNodeData]):
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
selector=self.node_data.vision.configs.variable_selector,
selector=self._node_data.vision.configs.variable_selector,
)
if self.node_data.vision.enabled
if self._node_data.vision.enabled
else []
)
@ -179,7 +200,7 @@ class LLMNode(BaseNode[LLMNodeData]):
node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value
generator = self._fetch_context(node_data=self.node_data)
generator = self._fetch_context(node_data=self._node_data)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
@ -189,44 +210,54 @@ class LLMNode(BaseNode[LLMNodeData]):
node_inputs["#context#"] = context
# fetch model config
model_instance, model_config = self._fetch_model_config(self.node_data.model)
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=self._node_data.model,
tenant_id=self.tenant_id,
)
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=self.node_data.memory,
node_data_memory=self._node_data.memory,
model_instance=model_instance,
)
query = None
if self.node_data.memory:
query = self.node_data.memory.query_prompt_template
if self._node_data.memory:
query = self._node_data.memory.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages(
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
prompt_template=self._node_data.prompt_template,
memory_config=self._node_data.memory,
vision_enabled=self._node_data.vision.enabled,
vision_detail=self._node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
jinja2_variables=self._node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
)
# handle invoke result
generator = self._invoke_llm(
node_data_model=self.node_data.model,
generator = LLMNode.invoke_llm(
node_data_model=self._node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
)
structured_output: LLMStructuredOutput | None = None
@ -296,12 +327,19 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)
def _invoke_llm(
self,
@staticmethod
def invoke_llm(
*,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Optional[Mapping[str, Any]] = None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
@ -309,8 +347,10 @@ class LLMNode(BaseNode[LLMNodeData]):
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")
if self.node_data.structured_output_enabled:
output_schema = self._fetch_structured_output_schema()
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
)
invoke_result = invoke_llm_with_structured_output(
provider=model_instance.provider,
model_schema=model_schema,
@ -320,7 +360,7 @@ class LLMNode(BaseNode[LLMNodeData]):
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=True,
user=self.user_id,
user=user_id,
)
else:
invoke_result = model_instance.invoke_llm(
@ -328,17 +368,31 @@ class LLMNode(BaseNode[LLMNodeData]):
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=True,
user=self.user_id,
user=user_id,
)
return self._handle_invoke_result(invoke_result=invoke_result)
return LLMNode.handle_invoke_result(
invoke_result=invoke_result,
file_saver=file_saver,
file_outputs=file_outputs,
node_id=node_id,
)
def _handle_invoke_result(
self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
@staticmethod
def handle_invoke_result(
*,
invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
file_saver: LLMFileSaver,
file_outputs: list["File"],
node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
event = self._handle_blocking_result(invoke_result=invoke_result)
event = LLMNode.handle_blocking_result(
invoke_result=invoke_result,
saver=file_saver,
file_outputs=file_outputs,
)
yield event
return
@ -356,11 +410,13 @@ class LLMNode(BaseNode[LLMNodeData]):
yield result
if isinstance(result, LLMResultChunk):
contents = result.delta.message.content
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=contents,
file_saver=file_saver,
file_outputs=file_outputs,
):
full_text_buffer.write(text_part)
yield RunStreamChunkEvent(
chunk_content=text_part, from_variable_selector=[self.node_id, "text"]
)
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])
# Update the whole metadata
if not model and result.model:
@ -378,7 +434,8 @@ class LLMNode(BaseNode[LLMNodeData]):
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
def _image_file_to_markdown(self, file: "File", /):
@staticmethod
def _image_file_to_markdown(file: "File", /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
@ -539,11 +596,14 @@ class LLMNode(BaseNode[LLMNodeData]):
return None
@staticmethod
def _fetch_model_config(
self, node_data_model: ModelConfig
*,
node_data_model: ModelConfig,
tenant_id: str,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config(
tenant_id=self.tenant_id, node_data_model=node_data_model
tenant_id=tenant_id, node_data_model=node_data_model
)
completion_params = model_config_with_cred.parameters
@ -556,8 +616,8 @@ class LLMNode(BaseNode[LLMNodeData]):
node_data_model.completion_params = completion_params
return model, model_config_with_cred
def _fetch_prompt_messages(
self,
@staticmethod
def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence["File"],
@ -570,13 +630,14 @@ class LLMNode(BaseNode[LLMNodeData]):
vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
self._handle_list_messages(
LLMNode.handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
@ -602,7 +663,7 @@ class LLMNode(BaseNode[LLMNodeData]):
edition_type="basic",
)
prompt_messages.extend(
self._handle_list_messages(
LLMNode.handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
@ -731,7 +792,7 @@ class LLMNode(BaseNode[LLMNodeData]):
)
model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.model,
@ -750,10 +811,12 @@ class LLMNode(BaseNode[LLMNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LLMNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
prompt_template = node_data.prompt_template
# Create typed NodeData from dict
typed_node_data = LLMNodeData.model_validate(node_data)
prompt_template = typed_node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list) and all(
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
@ -773,7 +836,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
memory = node_data.memory
memory = typed_node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
@ -781,16 +844,16 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
if node_data.context.enabled:
variable_mapping["#context#"] = node_data.context.variable_selector
if typed_node_data.context.enabled:
variable_mapping["#context#"] = typed_node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if typed_node_data.vision.enabled:
variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
if node_data.memory:
if typed_node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
if node_data.prompt_config:
if typed_node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, list):
@ -803,7 +866,7 @@ class LLMNode(BaseNode[LLMNodeData]):
enable_jinja = True
if enable_jinja:
for variable_selector in node_data.prompt_config.jinja2_variables or []:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@ -835,8 +898,8 @@ class LLMNode(BaseNode[LLMNodeData]):
},
}
def _handle_list_messages(
self,
@staticmethod
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
@ -849,7 +912,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
@ -897,9 +960,19 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
@staticmethod
def handle_blocking_result(
*,
invoke_result: LLMResult,
saver: LLMFileSaver,
file_outputs: list["File"],
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
contents=invoke_result.message.content,
file_saver=saver,
file_outputs=file_outputs,
):
buffer.write(text_part)
return ModelInvokeCompletedEvent(
@ -908,7 +981,12 @@ class LLMNode(BaseNode[LLMNodeData]):
finish_reason=None,
)
def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
@staticmethod
def save_multimodal_image_output(
*,
content: ImagePromptMessageContent,
file_saver: LLMFileSaver,
) -> "File":
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
@ -918,26 +996,21 @@ class LLMNode(BaseNode[LLMNodeData]):
Currently, only image files are supported.
"""
# Inject the saver somehow...
_saver = self._llm_file_saver
# If this
if content.url != "":
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
else:
saved_file = _saver.save_binary_string(
saved_file = file_saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
)
self._file_outputs.append(saved_file)
return saved_file
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
"""
Fetch model schema
"""
model_name = self.node_data.model.name
model_name = self._node_data.model.name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
@ -948,16 +1021,20 @@ class LLMNode(BaseNode[LLMNodeData]):
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema
def _fetch_structured_output_schema(self) -> dict[str, Any]:
@staticmethod
def fetch_structured_output_schema(
*,
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
"""
Fetch the structured output schema from the node data.
Returns:
dict[str, Any]: The structured output schema
"""
if not self.node_data.structured_output:
if not structured_output:
raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
@ -969,9 +1046,12 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(
self,
*,
contents: str | list[PromptMessageContentUnionTypes] | None,
file_saver: LLMFileSaver,
file_outputs: list["File"],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.
@ -994,9 +1074,12 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(item, TextPromptMessageContent):
yield item.data
elif isinstance(item, ImagePromptMessageContent):
file = self._save_multimodal_image_output(item)
self._file_outputs.append(file)
yield self._image_file_to_markdown(file)
file = LLMNode.save_multimodal_image_output(
content=item,
file_saver=file_saver,
)
file_outputs.append(file)
yield LLMNode._image_file_to_markdown(file)
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
@ -1004,6 +1087,14 @@ class LLMNode(BaseNode[LLMNodeData]):
logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents)
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
@ -1021,20 +1112,20 @@ def _combine_message_content_with_role(
def _render_jinja2_message(
*,
template: str,
jinjia2_variables: Sequence[VariableSelector],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
jinjia2_inputs = {}
for jinja2_variable in jinjia2_variables:
jinja2_inputs = {}
for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinjia2_inputs,
inputs=jinja2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
@ -1130,7 +1221,7 @@ def _handle_completion_template(
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinjia2_variables=jinja2_variables,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:

View File

@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData
class LoopEndNode(BaseNode[LoopEndNodeData]):
class LoopEndNode(BaseNode):
"""
Loop End Node.
"""
_node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopEndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -3,7 +3,7 @@ import logging
import time
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from configs import dify_config
from core.variables import (
@ -30,7 +30,8 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
@ -43,14 +44,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class LoopNode(BaseNode[LoopNodeData]):
class LoopNode(BaseNode):
"""
Loop Node.
"""
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
_node_data: LoopNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -58,17 +81,17 @@ class LoopNode(BaseNode[LoopNodeData]):
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node."""
# Get inputs
loop_count = self.node_data.loop_count
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
loop_count = self._node_data.loop_count
break_conditions = self._node_data.break_conditions
logical_operator = self._node_data.logical_operator
inputs = {"loop_count": loop_count}
if not self.node_data.start_node_id:
if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
# Initialize graph
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id)
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@ -78,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]):
# Initialize loop variables
loop_variable_selectors = {}
if self.node_data.loop_variables:
for loop_variable in self.node_data.loop_variables:
if self._node_data.loop_variables:
for loop_variable in self._node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
@ -127,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunStartedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"loop_length": loop_count},
@ -184,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunSucceededEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs=self.node_data.outputs,
outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
@ -206,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
outputs=self.node_data.outputs,
outputs=self._node_data.outputs,
inputs=inputs,
)
)
@ -217,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=loop_count,
@ -320,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@ -351,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@ -388,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]):
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
self.node_data.outputs = _outputs
self._node_data.outputs = _outputs
if check_break_result:
return {"check_break_result": True}
@ -400,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
loop_node_type=self.node_type,
loop_node_data=self.node_data,
loop_node_type=self.type_,
loop_node_data=self._node_data,
index=next_index,
pre_loop_output=self.node_data.outputs,
pre_loop_output=self._node_data.outputs,
)
return {"check_break_result": False}
@ -438,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LoopNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = LoopNodeData.model_validate(node_data)
variable_mapping = {}
# init graph
loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@ -486,7 +505,7 @@ class LoopNode(BaseNode[LoopNodeData]):
variable_mapping.update(sub_node_variable_mapping)
for loop_variable in node_data.loop_variables or []:
for loop_variable in typed_node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping

View File

@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData
class LoopStartNode(BaseNode[LoopStartNodeData]):
class LoopStartNode(BaseNode):
"""
Loop Start Node.
"""
_node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = LoopStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -29,8 +29,9 @@ from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type
@ -91,10 +92,31 @@ class ParameterExtractorNode(BaseNode):
Parameter Extractor Node.
"""
# FIXME: figure out why here is different from super class
_node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR
_node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ParameterExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
@ -119,7 +141,7 @@ class ParameterExtractorNode(BaseNode):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self.node_data)
node_data = cast(ParameterExtractorNodeData, self._node_data)
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""
@ -398,7 +420,7 @@ class ParameterExtractorNode(BaseNode):
"""
Generate prompt engineering prompt.
"""
model_mode = ModelMode.value_of(data.model.mode)
model_mode = ModelMode(data.model.mode)
if model_mode == ModelMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(
@ -694,7 +716,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode.value_of(node_data.model.mode)
model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@ -721,7 +743,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
):
model_mode = ModelMode.value_of(node_data.model.mode)
model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@ -827,19 +849,15 @@ class ParameterExtractorNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData, # type: ignore
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
# Create typed NodeData from dict
typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
if node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
if typed_node_data.instruction:
selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@ -11,8 +11,11 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
LLMNode,
@ -20,6 +23,7 @@ from core.workflow.nodes.llm import (
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
@ -35,17 +39,77 @@ from .template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData # type: ignore
class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER
_node_data: QuestionClassifierNodeData
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = QuestionClassifierNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls):
return "1"
def _run(self):
node_data = cast(QuestionClassifierNodeData, self.node_data)
node_data = cast(QuestionClassifierNodeData, self._node_data)
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
@ -53,7 +117,10 @@ class QuestionClassifierNode(LLMNode):
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
model_instance, model_config = LLMNode._fetch_model_config(
node_data_model=node_data.model,
tenant_id=self.tenant_id,
)
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@ -91,7 +158,7 @@ class QuestionClassifierNode(LLMNode):
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing model's error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = self._fetch_prompt_messages(
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
memory=memory,
@ -101,6 +168,7 @@ class QuestionClassifierNode(LLMNode):
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)
result_text = ""
@ -109,11 +177,17 @@ class QuestionClassifierNode(LLMNode):
try:
# handle invoke result
generator = self._invoke_llm(
generator = LLMNode.invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=False,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self.node_id,
)
for event in generator:
@ -183,23 +257,18 @@ class QuestionClassifierNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Any,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(QuestionClassifierNodeData, node_data)
variable_mapping = {"query": node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
# Create typed NodeData from dict
typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
variable_mapping = {"query": typed_node_data.query_variable_selector}
variable_selectors: list[VariableSelector] = []
if typed_node_data.instruction:
variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@ -265,7 +334,7 @@ class QuestionClassifierNode(LLMNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
):
model_mode = ModelMode.value_of(node_data.model.mode)
model_mode = ModelMode(node_data.model.mode)
classes = node_data.classes
categories = []
for class_ in classes:

View File

@ -1,15 +1,41 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.start.entities import StartNodeData
class StartNode(BaseNode[StartNodeData]):
_node_data_cls = StartNodeData
class StartNode(BaseNode):
_node_type = NodeType.START
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = StartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"

View File

@ -6,16 +6,39 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
_node_data_cls = TemplateTransformNodeData
class TemplateTransformNode(BaseNode):
_node_type = NodeType.TEMPLATE_TRANSFORM
_node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = TemplateTransformNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
@ -35,14 +58,14 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
def _run(self) -> NodeRunResult:
# Get variables
variables = {}
for variable_selector in self.node_data.variables:
for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
)
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
@ -60,16 +83,12 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = TemplateTransformNodeData.model_validate(node_data)
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables
for variable_selector in typed_node_data.variables
}

View File

@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -19,10 +18,10 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
@ -37,14 +36,18 @@ from .exc import (
)
class ToolNode(BaseNode[ToolNodeData]):
class ToolNode(BaseNode):
"""
Tool Node
"""
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ToolNodeData.model_validate(data)
@classmethod
def version(cls) -> str:
return "1"
@ -54,7 +57,7 @@ class ToolNode(BaseNode[ToolNodeData]):
Run the tool node
"""
node_data = cast(ToolNodeData, self.node_data)
node_data = cast(ToolNodeData, self._node_data)
# fetch tool icon
tool_info = {
@ -67,9 +70,9 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
from core.tools.tool_manager import ToolManager
variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None
variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool
self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield RunCompletedEvent(
@ -88,12 +91,12 @@ class ToolNode(BaseNode[ToolNodeData]):
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
node_data=self._node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
node_data=self._node_data,
for_log=True,
)
# get conversation id
@ -124,7 +127,14 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
yield from self._transform_message(
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self.node_id,
)
except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@ -191,7 +201,9 @@ class ToolNode(BaseNode[ToolNodeData]):
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
agent_thoughts: Optional[list] = None,
user_id: str,
tenant_id: str,
node_id: str,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
@ -199,8 +211,8 @@ class ToolNode(BaseNode[ToolNodeData]):
# transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
@ -208,9 +220,6 @@ class ToolNode(BaseNode[ToolNodeData]):
files: list[File] = []
json: list[dict] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {}
for message in message_stream:
@ -243,7 +252,7 @@ class ToolNode(BaseNode[ToolNodeData]):
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
@ -266,45 +275,36 @@ class ToolNode(BaseNode[ToolNodeData]):
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield RunStreamChunkEvent(
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
)
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
# JSON message handling for tool node
if message.message.json_object is not None:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
)
else:
variables[variable_name] = variable_value
@ -319,7 +319,7 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
@ -334,8 +334,8 @@ class ToolNode(BaseNode[ToolNodeData]):
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
self.user_id,
self.tenant_id,
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
@ -347,57 +347,10 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
id=message.message.id,
node_execution_id=self.id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=self.node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.id == agent_log.id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage)
yield RunRetrieverResourceEvent(
retriever_resources=message.message.retriever_resources,
context=message.message.context,
)
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
json_output.extend(json)
@ -409,12 +362,9 @@ class ToolNode(BaseNode[ToolNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)
@ -424,7 +374,7 @@ class ToolNode(BaseNode[ToolNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@ -433,9 +383,12 @@ class ToolNode(BaseNode[ToolNodeData]):
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = ToolNodeData.model_validate(node_data)
result = {}
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
@ -449,3 +402,29 @@ class ToolNode(BaseNode[ToolNodeData]):
result = {node_id + "." + key: value for key, value in result.items()}
return result
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@property
def continue_on_error(self) -> bool:
return self._node_data.error_strategy is not None
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled

View File

@ -1,17 +1,41 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
class VariableAggregatorNode(BaseNode):
_node_type = NodeType.VARIABLE_AGGREGATOR
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = VariableAssignerNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@classmethod
def version(cls) -> str:
return "1"
@ -21,8 +45,8 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {}
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
for selector in self._node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable}
@ -30,7 +54,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
inputs = {".".join(selector[1:]): variable.to_object()}
break
else:
for group in self.node_data.advanced_settings.groups:
for group in self._node_data.advanced_settings.groups:
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)

View File

@ -7,7 +7,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
@ -22,11 +23,33 @@ if TYPE_CHECKING:
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
class VariableAssignerNode(BaseNode[VariableAssignerData]):
_node_data_cls = VariableAssignerData
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
_node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = VariableAssignerData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def __init__(
self,
id: str,
@ -59,36 +82,39 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
mapping = {}
assigned_variable_node_id = node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.assigned_variable_selector
# Create typed NodeData from dict
typed_node_data = VariableAssignerData.model_validate(node_data)
selector_key = ".".join(node_data.input_variable_selector)
mapping = {}
assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
selector_key = ".".join(typed_node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = typed_node_data.assigned_variable_selector
selector_key = ".".join(typed_node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#"
mapping[key] = node_data.input_variable_selector
mapping[key] = typed_node_data.input_variable_selector
return mapping
def _run(self) -> NodeRunResult:
assigned_variable_selector = self.node_data.assigned_variable_selector
assigned_variable_selector = self._node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:
match self._node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
@ -101,7 +127,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any, TypeAlias, cast
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
@ -10,7 +10,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
@ -28,8 +29,6 @@ from .exc import (
VariableNotFoundError,
)
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@ -54,10 +53,32 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
mapping[key] = selector
class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()
@ -71,22 +92,25 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerNodeData,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
# Create typed NodeData from dict
typed_node_data = VariableAssignerNodeData.model_validate(node_data)
var_mapping: dict[str, Sequence[str]] = {}
for item in node_data.items:
for item in typed_node_data.items:
_target_mapping_from_item(var_mapping, node_id, item)
_source_mapping_from_item(var_mapping, node_id, item)
return var_mapping
def _run(self) -> NodeRunResult:
inputs = self.node_data.model_dump()
inputs = self._node_data.model_dump()
process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variable_selectors: list[Sequence[str]] = []
try:
for item in self.node_data.items:
for item in self._node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part

View File

@ -1,4 +1,4 @@
from typing import Optional, Protocol
from typing import Protocol
from core.workflow.entities.workflow_execution import WorkflowExecution
@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol):
execution: The WorkflowExecution instance to save or update
"""
...
def get(self, execution_id: str) -> Optional[WorkflowExecution]:
"""
Retrieve a WorkflowExecution by its ID.
Args:
execution_id: The workflow execution ID
Returns:
The WorkflowExecution instance if found, None otherwise
"""
...

View File

@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.
Args:
node_execution_id: The node execution ID
Returns:
The NodeExecution instance if found, None otherwise
"""
...
def get_by_workflow_run(
self,
workflow_run_id: str,
@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol):
A list of NodeExecution instances
"""
...
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running NodeExecution instances
"""
...
def clear(self) -> None:
"""
Clear all NodeExecution records based on implementation-specific criteria.
This method is intended to be used for bulk deletion operations, such as removing
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
"""
...

View File

@ -55,24 +55,15 @@ class WorkflowCycleManager:
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
# Initialize caches for workflow execution cycle
# These caches avoid redundant repository calls during a single workflow execution
self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
def handle_workflow_run_start(self) -> WorkflowExecution:
inputs = {**self._application_generate_entity.inputs}
inputs = self._prepare_workflow_inputs()
execution_id = self._get_or_generate_execution_id()
# Iterate over SystemVariable fields using Pydantic's model_fields
if self._workflow_system_variables:
for field_name, value in self._workflow_system_variables.to_dict().items():
if field_name == SystemVariableKey.CONVERSATION_ID:
continue
inputs[f"sys.{field_name}"] = value
# handle special values
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
execution_id = str(
self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None
) or str(uuid4())
execution = WorkflowExecution.new(
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,
@ -83,9 +74,7 @@ class WorkflowCycleManager:
started_at=datetime.now(UTC).replace(tzinfo=None),
)
self._workflow_execution_repository.save(execution)
return execution
return self._save_and_cache_workflow_execution(execution)
def handle_workflow_run_success(
self,
@ -99,23 +88,15 @@ class WorkflowCycleManager:
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
# outputs = WorkflowEntry.handle_special_values(outputs)
self._update_workflow_execution_completion(
workflow_execution,
status=WorkflowExecutionStatus.SUCCEEDED,
outputs=outputs,
total_tokens=total_tokens,
total_steps=total_steps,
)
workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
workflow_execution.outputs = outputs or {}
workflow_execution.total_tokens = total_tokens
workflow_execution.total_steps = total_steps
workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
@ -132,24 +113,17 @@ class WorkflowCycleManager:
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
# outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.outputs = outputs or {}
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
execution.exceptions_count = exceptions_count
self._update_workflow_execution_completion(
execution,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
outputs=outputs,
total_tokens=total_tokens,
total_steps=total_steps,
exceptions_count=exceptions_count,
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
self._add_trace_task_if_needed(trace_manager, execution, conversation_id)
self._workflow_execution_repository.save(execution)
return execution
@ -169,39 +143,18 @@ class WorkflowCycleManager:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
now = naive_utc_now()
workflow_execution.status = WorkflowExecutionStatus(status.value)
workflow_execution.error_message = error_message
workflow_execution.total_tokens = total_tokens
workflow_execution.total_steps = total_steps
workflow_execution.finished_at = now
workflow_execution.exceptions_count = exceptions_count
# Use the instance repository to find running executions for a workflow run
running_node_executions = self._workflow_node_execution_repository.get_running_executions(
workflow_run_id=workflow_execution.id_
self._update_workflow_execution_completion(
workflow_execution,
status=status,
total_tokens=total_tokens,
total_steps=total_steps,
error_message=error_message,
exceptions_count=exceptions_count,
finished_at=now,
)
# Update the domain models
for node_execution in running_node_executions:
if node_execution.node_execution_id:
# Update the domain model
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error_message
node_execution.finished_at = now
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
# Update the repository with the domain model
self._workflow_node_execution_repository.save(node_execution)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
self._fail_running_node_executions(workflow_execution.id_, error_message, now)
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
@ -214,8 +167,198 @@ class WorkflowCycleManager:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
# Create a domain model
created_at = datetime.now(UTC).replace(tzinfo=None)
domain_execution = self._create_node_execution_from_event(
workflow_execution=workflow_execution,
event=event,
status=WorkflowNodeExecutionStatus.RUNNING,
)
return self._save_and_cache_node_execution(domain_execution)
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
self._update_node_execution_completion(
domain_execution,
event=event,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
)
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def handle_workflow_node_execution_failed(
self,
*,
event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
:return:
"""
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
status = (
WorkflowNodeExecutionStatus.EXCEPTION
if isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.FAILED
)
self._update_node_execution_completion(
domain_execution,
event=event,
status=status,
error=event.error,
handle_special_values=True,
)
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
domain_execution = self._create_node_execution_from_event(
workflow_execution=workflow_execution,
event=event,
status=WorkflowNodeExecutionStatus.RETRY,
error=event.error,
created_at=event.start_at,
)
# Handle inputs and outputs
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = event.outputs
metadata = self._merge_event_metadata(event)
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
return self._save_and_cache_node_execution(domain_execution)
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
# Check cache first
if id in self._workflow_execution_cache:
return self._workflow_execution_cache[id]
raise WorkflowRunNotFoundError(id)
def _prepare_workflow_inputs(self) -> dict[str, Any]:
"""Prepare workflow inputs by merging application inputs with system variables."""
inputs = {**self._application_generate_entity.inputs}
if self._workflow_system_variables:
for field_name, value in self._workflow_system_variables.to_dict().items():
if field_name != SystemVariableKey.CONVERSATION_ID:
inputs[f"sys.{field_name}"] = value
return dict(WorkflowEntry.handle_special_values(inputs) or {})
def _get_or_generate_execution_id(self) -> str:
"""Get execution ID from system variables or generate a new one."""
if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
return str(self._workflow_system_variables.workflow_execution_id)
return str(uuid4())
def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
"""Save workflow execution to repository and cache it."""
self._workflow_execution_repository.save(execution)
self._workflow_execution_cache[execution.id_] = execution
return execution
def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
"""Save node execution to repository and cache it if it has an ID."""
self._workflow_node_execution_repository.save(execution)
if execution.node_execution_id:
self._node_execution_cache[execution.node_execution_id] = execution
return execution
def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
"""Get node execution from cache or raise error if not found."""
domain_execution = self._node_execution_cache.get(node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {node_execution_id}")
return domain_execution
def _update_workflow_execution_completion(
self,
execution: WorkflowExecution,
*,
status: WorkflowExecutionStatus,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
error_message: Optional[str] = None,
exceptions_count: int = 0,
finished_at: Optional[datetime] = None,
) -> None:
"""Update workflow execution with completion data."""
execution.status = status
execution.outputs = outputs or {}
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = finished_at or naive_utc_now()
execution.exceptions_count = exceptions_count
if error_message:
execution.error_message = error_message
def _add_trace_task_if_needed(
self,
trace_manager: Optional[TraceQueueManager],
workflow_execution: WorkflowExecution,
conversation_id: Optional[str],
) -> None:
"""Add trace task if trace manager is provided."""
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
def _fail_running_node_executions(
self,
workflow_execution_id: str,
error_message: str,
now: datetime,
) -> None:
"""Fail all running node executions for a workflow."""
running_node_executions = [
node_exec
for node_exec in self._node_execution_cache.values()
if node_exec.workflow_execution_id == workflow_execution_id
and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
]
for node_execution in running_node_executions:
if node_execution.node_execution_id:
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error_message
node_execution.finished_at = now
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
self._workflow_node_execution_repository.save(node_execution)
def _create_node_execution_from_event(
self,
*,
workflow_execution: WorkflowExecution,
event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent],
status: WorkflowNodeExecutionStatus,
error: Optional[str] = None,
created_at: Optional[datetime] = None,
) -> WorkflowNodeExecution:
"""Create a node execution from an event."""
now = datetime.now(UTC).replace(tzinfo=None)
created_at = created_at or now
metadata = {
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
@ -232,152 +375,76 @@ class WorkflowCycleManager:
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
status=WorkflowNodeExecutionStatus.RUNNING,
status=status,
metadata=metadata,
created_at=created_at,
error=error,
)
# Use the instance repository to save the domain model
self._workflow_node_execution_repository.save(domain_execution)
if status == WorkflowNodeExecutionStatus.RETRY:
domain_execution.finished_at = now
domain_execution.elapsed_time = (now - created_at).total_seconds()
return domain_execution
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data
inputs = event.inputs
process_data = event.process_data
outputs = event.outputs
# Convert metadata keys to strings
execution_metadata_dict = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
# Update domain model
domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
domain_execution.update_from_mapping(
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
)
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = elapsed_time
# Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def handle_workflow_node_execution_failed(
def _update_node_execution_completion(
self,
domain_execution: WorkflowNodeExecution,
*,
event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
:return:
"""
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = event.outputs
# Convert metadata keys to strings
execution_metadata_dict = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
event: Union[
QueueNodeSucceededEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
],
status: WorkflowNodeExecutionStatus,
error: Optional[str] = None,
handle_special_values: bool = False,
) -> None:
"""Update node execution with completion data."""
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
# Process data
if handle_special_values:
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
else:
inputs = event.inputs
process_data = event.process_data
outputs = event.outputs
# Convert metadata
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
if event.execution_metadata:
execution_metadata_dict.update(event.execution_metadata)
# Update domain model
domain_execution.status = (
WorkflowNodeExecutionStatus.FAILED
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION
)
domain_execution.error = event.error
domain_execution.status = status
domain_execution.update_from_mapping(
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
inputs=inputs,
process_data=process_data,
outputs=outputs,
metadata=execution_metadata_dict,
)
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = elapsed_time
# Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
if error:
domain_execution.error = error
return domain_execution
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = event.outputs
# Convert metadata keys to strings
def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
"""Merge event metadata with origin metadata."""
origin_metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
# Convert execution metadata keys to strings
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
execution_metadata_dict.update(event.execution_metadata)
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
# Create a domain model
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
status=WorkflowNodeExecutionStatus.RETRY,
created_at=created_at,
finished_at=finished_at,
elapsed_time=elapsed_time,
error=event.error,
index=event.node_run_index,
)
# Update with mappings
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)
# Use the instance repository to save the domain model
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
execution = self._workflow_execution_repository.get(id)
if not execution:
raise WorkflowRunNotFoundError(id)
return execution
return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata

View File

@ -5,7 +5,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
@ -146,7 +146,7 @@ class WorkflowEntry:
graph = Graph.init(graph_config=workflow.graph_dict)
# init workflow run state
node_instance = node_cls(
node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@ -190,17 +190,11 @@ class WorkflowEntry:
try:
# run node
generator = node_instance.run()
generator = node.run()
except Exception as e:
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
workflow.id,
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator
logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}")
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
return node, generator
@classmethod
def run_free_node(
@ -262,7 +256,7 @@ class WorkflowEntry:
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node_instance: BaseNode = node_cls(
node: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@ -297,17 +291,12 @@ class WorkflowEntry:
)
# run node
generator = node_instance.run()
generator = node.run()
return node_instance, generator
return node, generator
except Exception as e:
logger.exception(
"error while running node_instance, node_id=%s, type=%s, version=%s",
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}")
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:

View File

@ -1,4 +1,5 @@
import hashlib
from typing import Union
from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
@ -9,7 +10,7 @@ from extensions.ext_storage import storage
from libs import gmpy2_pkcs10aep_cipher
def generate_key_pair(tenant_id):
def generate_key_pair(tenant_id: str) -> str:
private_key = RSA.generate(2048)
public_key = private_key.publickey()
@ -26,7 +27,7 @@ def generate_key_pair(tenant_id):
prefix_hybrid = b"HYBRID:"
def encrypt(text, public_key):
def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
if isinstance(public_key, str):
public_key = public_key.encode()
@ -38,14 +39,14 @@ def encrypt(text, public_key):
rsa_key = RSA.import_key(public_key)
cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)
enc_aes_key = cipher_rsa.encrypt(aes_key)
enc_aes_key: bytes = cipher_rsa.encrypt(aes_key)
encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext
return prefix_hybrid + encrypted_data
def get_decrypt_decoding(tenant_id):
def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
@ -64,7 +65,7 @@ def get_decrypt_decoding(tenant_id):
return rsa_key, cipher_rsa
def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str:
if encrypted_text.startswith(prefix_hybrid):
encrypted_text = encrypted_text[len(prefix_hybrid) :]
@ -83,10 +84,10 @@ def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
return decrypted_text.decode()
def decrypt(encrypted_text, tenant_id):
def decrypt(encrypted_text: bytes, tenant_id: str) -> str:
rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)
return decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa)
return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa)
class PrivkeyNotFoundError(Exception):

View File

@ -196,7 +196,7 @@ class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name = db.Column(db.String(255), nullable=False)
encrypt_public_key = db.Column(db.Text)
plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))

View File

@ -347,21 +347,33 @@ class ToolTransformService:
)
# get tool parameters
parameters = tool.entity.parameters or []
base_parameters = tool.entity.parameters or []
# get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters()
# override parameters
current_parameters = parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
# merge parameters using a functional approach to avoid type issues
merged_parameters: list[ToolParameter] = []
# create a mapping of runtime parameters for quick lookup
runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters}
# process base parameters, replacing with runtime versions if they exist
for base_param in base_parameters:
key = (base_param.name, base_param.form)
if key in runtime_param_map:
merged_parameters.append(runtime_param_map[key])
else:
merged_parameters.append(base_param)
# add any runtime parameters that weren't in base parameters
for runtime_parameter in runtime_parameters:
if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
# check if this parameter is already in merged_parameters
already_exists = any(
p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters
)
if not already_exists:
merged_parameters.append(runtime_parameter)
return ToolApiEntity(
author=tool.entity.identity.author,
@ -369,10 +381,10 @@ class ToolTransformService:
label=tool.entity.identity.label,
description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
output_schema=tool.entity.output_schema,
parameters=current_parameters,
parameters=merged_parameters,
labels=labels or [],
)
if isinstance(tool, ApiToolBundle):
elif isinstance(tool, ApiToolBundle):
return ToolApiEntity(
author=tool.author,
name=tool.operation_id or "",
@ -381,6 +393,9 @@ class ToolTransformService:
parameters=tool.parameters,
labels=labels or [],
)
else:
# Handle WorkflowTool case
raise ValueError(f"Unsupported tool type: {type(tool)}")
@staticmethod
def convert_builtin_provider_to_credential_entity(

View File

@ -1,6 +1,7 @@
import datetime
import json
from typing import Any
from dataclasses import dataclass
from typing import Any, Optional
import requests
from flask_login import current_user
@ -13,241 +14,392 @@ from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService
class WebsiteService:
@classmethod
def document_create_args_validate(cls, args: dict):
if "url" not in args or not args["url"]:
raise ValueError("url is required")
if "options" not in args or not args["options"]:
raise ValueError("options is required")
if "limit" not in args["options"] or not args["options"]["limit"]:
raise ValueError("limit is required")
@dataclass
class CrawlOptions:
"""Options for crawling operations."""
limit: int = 1
crawl_sub_pages: bool = False
only_main_content: bool = False
includes: Optional[str] = None
excludes: Optional[str] = None
max_depth: Optional[int] = None
use_sitemap: bool = True
def get_include_paths(self) -> list[str]:
"""Get list of include paths from comma-separated string."""
return self.includes.split(",") if self.includes else []
def get_exclude_paths(self) -> list[str]:
"""Get list of exclude paths from comma-separated string."""
return self.excludes.split(",") if self.excludes else []
@dataclass
class CrawlRequest:
"""Request container for crawling operations."""
url: str
provider: str
options: CrawlOptions
@dataclass
class ScrapeRequest:
"""Request container for scraping operations."""
provider: str
url: str
tenant_id: str
only_main_content: bool
@dataclass
class WebsiteCrawlApiRequest:
"""Request container for website crawl API arguments."""
provider: str
url: str
options: dict[str, Any]
def to_crawl_request(self) -> CrawlRequest:
"""Convert API request to internal CrawlRequest."""
options = CrawlOptions(
limit=self.options.get("limit", 1),
crawl_sub_pages=self.options.get("crawl_sub_pages", False),
only_main_content=self.options.get("only_main_content", False),
includes=self.options.get("includes"),
excludes=self.options.get("excludes"),
max_depth=self.options.get("max_depth"),
use_sitemap=self.options.get("use_sitemap", True),
)
return CrawlRequest(url=self.url, provider=self.provider, options=options)
@classmethod
def crawl_url(cls, args: dict) -> dict:
provider = args.get("provider", "")
def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest":
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
url = args.get("url")
options = args.get("options", "")
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
crawl_sub_pages = options.get("crawl_sub_pages", False)
only_main_content = options.get("only_main_content", False)
if not crawl_sub_pages:
params = {
"includePaths": [],
"excludePaths": [],
"limit": 1,
"scrapeOptions": {"onlyMainContent": only_main_content},
}
else:
includes = options.get("includes").split(",") if options.get("includes") else []
excludes = options.get("excludes").split(",") if options.get("excludes") else []
params = {
"includePaths": includes,
"excludePaths": excludes,
"limit": options.get("limit", 1),
"scrapeOptions": {"onlyMainContent": only_main_content},
}
if options.get("max_depth"):
params["maxDepth"] = options.get("max_depth")
job_id = firecrawl_app.crawl_url(url, params)
website_crawl_time_cache_key = f"website_crawl_{job_id}"
time = str(datetime.datetime.now().timestamp())
redis_client.setex(website_crawl_time_cache_key, 3600, time)
return {"status": "active", "job_id": job_id}
elif provider == "watercrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options)
options = args.get("options", {})
elif provider == "jinareader":
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
crawl_sub_pages = options.get("crawl_sub_pages", False)
if not crawl_sub_pages:
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return {"status": "active", "data": response.json().get("data")}
else:
response = requests.post(
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
json={
"url": url,
"maxPages": options.get("limit", 1),
"useSitemap": options.get("use_sitemap", True),
},
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
if not provider:
raise ValueError("Provider is required")
if not url:
raise ValueError("URL is required")
if not options:
raise ValueError("Options are required")
return cls(provider=provider, url=url, options=options)
@dataclass
class WebsiteCrawlStatusApiRequest:
"""Request container for website crawl status API arguments."""
provider: str
job_id: str
@classmethod
def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
"""Create from Flask-RESTful parsed arguments."""
provider = args.get("provider")
if not provider:
raise ValueError("Provider is required")
if not job_id:
raise ValueError("Job ID is required")
return cls(provider=provider, job_id=job_id)
class WebsiteService:
"""Service class for website crawling operations using different providers."""
@classmethod
def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[dict, dict]:
"""Get and validate credentials for a provider."""
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
if not credentials or "config" not in credentials:
raise ValueError("No valid credentials found for the provider")
return credentials, credentials["config"]
@classmethod
def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
"""Decrypt and return the API key from config."""
api_key = config.get("api_key")
if not api_key:
raise ValueError("API key not found in configuration")
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
@classmethod
def document_create_args_validate(cls, args: dict) -> None:
"""Validate arguments for document creation."""
try:
WebsiteCrawlApiRequest.from_args(args)
except ValueError as e:
raise ValueError(f"Invalid arguments: {e}")
@classmethod
def crawl_url(cls, api_request: WebsiteCrawlApiRequest) -> dict[str, Any]:
"""Crawl a URL using the specified provider with typed request."""
request = api_request.to_crawl_request()
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider)
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
if request.provider == "firecrawl":
return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config)
elif request.provider == "watercrawl":
return cls._crawl_with_watercrawl(request=request, api_key=api_key, config=config)
elif request.provider == "jinareader":
return cls._crawl_with_jinareader(request=request, api_key=api_key)
else:
raise ValueError("Invalid provider")
@classmethod
def get_crawl_status(cls, job_id: str, provider: str) -> dict:
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
)
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id)
crawl_status_data = {
"status": result.get("status", "active"),
"job_id": job_id,
"total": result.get("total", 0),
"current": result.get("current", 0),
"data": result.get("data", []),
def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
if not request.options.crawl_sub_pages:
params = {
"includePaths": [],
"excludePaths": [],
"limit": 1,
"scrapeOptions": {"onlyMainContent": request.options.only_main_content},
}
if crawl_status_data["status"] == "completed":
website_crawl_time_cache_key = f"website_crawl_{job_id}"
start_time = redis_client.get(website_crawl_time_cache_key)
if start_time:
end_time = datetime.datetime.now().timestamp()
time_consuming = abs(end_time - float(start_time))
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
redis_client.delete(website_crawl_time_cache_key)
elif provider == "watercrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
else:
params = {
"includePaths": request.options.get_include_paths(),
"excludePaths": request.options.get_exclude_paths(),
"limit": request.options.limit,
"scrapeOptions": {"onlyMainContent": request.options.only_main_content},
}
if request.options.max_depth:
params["maxDepth"] = request.options.max_depth
job_id = firecrawl_app.crawl_url(request.url, params)
website_crawl_time_cache_key = f"website_crawl_{job_id}"
time = str(datetime.datetime.now().timestamp())
redis_client.setex(website_crawl_time_cache_key, 3600, time)
return {"status": "active", "job_id": job_id}
@classmethod
def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
# Convert CrawlOptions back to dict format for WaterCrawlProvider
options = {
"limit": request.options.limit,
"crawl_sub_pages": request.options.crawl_sub_pages,
"only_main_content": request.options.only_main_content,
"includes": request.options.includes,
"excludes": request.options.excludes,
"max_depth": request.options.max_depth,
"use_sitemap": request.options.use_sitemap,
}
return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
url=request.url, options=options
)
@classmethod
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
if not request.options.crawl_sub_pages:
response = requests.get(
f"https://r.jina.ai/{request.url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
crawl_status_data = WaterCrawlProvider(
api_key, credentials.get("config").get("base_url", None)
).get_crawl_status(job_id)
elif provider == "jinareader":
api_key = encrypter.decrypt_token(
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return {"status": "active", "data": response.json().get("data")}
else:
response = requests.post(
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
json={
"url": request.url,
"maxPages": request.options.limit,
"useSitemap": request.options.use_sitemap,
},
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
@classmethod
def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]:
"""Get crawl status using string parameters."""
api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id)
return cls.get_crawl_status_typed(api_request)
@classmethod
def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
"""Get crawl status using typed request."""
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
if api_request.provider == "firecrawl":
return cls._get_firecrawl_status(api_request.job_id, api_key, config)
elif api_request.provider == "watercrawl":
return cls._get_watercrawl_status(api_request.job_id, api_key, config)
elif api_request.provider == "jinareader":
return cls._get_jinareader_status(api_request.job_id, api_key)
else:
raise ValueError("Invalid provider")
@classmethod
def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
result = firecrawl_app.check_crawl_status(job_id)
crawl_status_data = {
"status": result.get("status", "active"),
"job_id": job_id,
"total": result.get("total", 0),
"current": result.get("current", 0),
"data": result.get("data", []),
}
if crawl_status_data["status"] == "completed":
website_crawl_time_cache_key = f"website_crawl_{job_id}"
start_time = redis_client.get(website_crawl_time_cache_key)
if start_time:
end_time = datetime.datetime.now().timestamp()
time_consuming = abs(end_time - float(start_time))
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
redis_client.delete(website_crawl_time_cache_key)
return crawl_status_data
@classmethod
def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)
@classmethod
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
)
data = response.json().get("data", {})
crawl_status_data = {
"status": data.get("status", "active"),
"job_id": job_id,
"total": len(data.get("urls", [])),
"current": len(data.get("processed", [])) + len(data.get("failed", [])),
"data": [],
"time_consuming": data.get("duration", 0) / 1000,
}
if crawl_status_data["status"] == "completed":
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
)
data = response.json().get("data", {})
crawl_status_data = {
"status": data.get("status", "active"),
"job_id": job_id,
"total": len(data.get("urls", [])),
"current": len(data.get("processed", [])) + len(data.get("failed", [])),
"data": [],
"time_consuming": data.get("duration", 0) / 1000,
}
if crawl_status_data["status"] == "completed":
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
)
data = response.json().get("data", {})
formatted_data = [
{
"title": item.get("data", {}).get("title"),
"source_url": item.get("data", {}).get("url"),
"description": item.get("data", {}).get("description"),
"markdown": item.get("data", {}).get("content"),
}
for item in data.get("processed", {}).values()
]
crawl_status_data["data"] = formatted_data
else:
raise ValueError("Invalid provider")
formatted_data = [
{
"title": item.get("data", {}).get("title"),
"source_url": item.get("data", {}).get("url"),
"description": item.get("data", {}).get("description"),
"markdown": item.get("data", {}).get("content"),
}
for item in data.get("processed", {}).values()
]
crawl_status_data["data"] = formatted_data
return crawl_status_data
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
# decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
_, config = cls._get_credentials_and_config(tenant_id, provider)
api_key = cls._get_decrypted_api_key(tenant_id, config)
if provider == "firecrawl":
crawl_data: list[dict[str, Any]] | None = None
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
stored_data = storage.load_once(file_key)
if stored_data:
crawl_data = json.loads(stored_data.decode("utf-8"))
else:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id)
if result.get("status") != "completed":
raise ValueError("Crawl job is not completed")
crawl_data = result.get("data")
if crawl_data:
for item in crawl_data:
if item.get("source_url") == url:
return dict(item)
return None
return cls._get_firecrawl_url_data(job_id, url, api_key, config)
elif provider == "watercrawl":
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).get_crawl_url_data(
job_id, url
)
return cls._get_watercrawl_url_data(job_id, url, api_key, config)
elif provider == "jinareader":
if not job_id:
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return dict(response.json().get("data", {}))
else:
# Get crawl status first
status_response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
)
status_data = status_response.json().get("data", {})
if status_data.get("status") != "completed":
raise ValueError("Crawl job is not completed")
# Get processed data
data_response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
)
processed_data = data_response.json().get("data", {})
for item in processed_data.get("processed", {}).values():
if item.get("data", {}).get("url") == url:
return dict(item.get("data", {}))
return None
return cls._get_jinareader_url_data(job_id, url, api_key)
else:
raise ValueError("Invalid provider")
@classmethod
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
if provider == "firecrawl":
# decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
params = {"onlyMainContent": only_main_content}
result = firecrawl_app.scrape_url(url, params)
return result
elif provider == "watercrawl":
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).scrape_url(url)
def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
crawl_data: list[dict[str, Any]] | None = None
file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key):
stored_data = storage.load_once(file_key)
if stored_data:
crawl_data = json.loads(stored_data.decode("utf-8"))
else:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
result = firecrawl_app.check_crawl_status(job_id)
if result.get("status") != "completed":
raise ValueError("Crawl job is not completed")
crawl_data = result.get("data")
if crawl_data:
for item in crawl_data:
if item.get("source_url") == url:
return dict(item)
return None
@classmethod
def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
@classmethod
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
if not job_id:
response = requests.get(
f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
if response.json().get("code") != 200:
raise ValueError("Failed to crawl")
return dict(response.json().get("data", {}))
else:
# Get crawl status first
status_response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id},
)
status_data = status_response.json().get("data", {})
if status_data.get("status") != "completed":
raise ValueError("Crawl job is not completed")
# Get processed data
data_response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
)
processed_data = data_response.json().get("data", {})
for item in processed_data.get("processed", {}).values():
if item.get("data", {}).get("url") == url:
return dict(item.get("data", {}))
return None
@classmethod
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]:
request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content)
_, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider)
api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config)
if request.provider == "firecrawl":
return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config)
elif request.provider == "watercrawl":
return cls._scrape_with_watercrawl(request=request, api_key=api_key, config=config)
else:
raise ValueError("Invalid provider")
@classmethod
def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
params = {"onlyMainContent": request.only_main_content}
return firecrawl_app.scrape_url(url=request.url, params=params)
@classmethod
def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url)

View File

@ -466,10 +466,10 @@ class WorkflowService:
node_id: str,
) -> WorkflowNodeExecution:
try:
node_instance, generator = invoke_node_fn()
node, node_events = invoke_node_fn()
node_run_result: NodeRunResult | None = None
for event in generator:
for event in node_events:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result
@ -480,18 +480,18 @@ class WorkflowService:
if not node_run_result:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error:
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
"metadata": {"error_strategy": node.error_strategy},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
**node.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
@ -510,10 +510,10 @@ class WorkflowService:
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node_instance = e.node_instance
node = e._node
run_succeeded = False
node_run_result = None
error = e.error
error = e._error
# Create a NodeExecution domain model
node_execution = WorkflowNodeExecution(
@ -521,8 +521,8 @@ class WorkflowService:
workflow_id="", # This is a single-step execution, so no workflow ID
index=1,
node_id=node_id,
node_type=node_instance.node_type,
title=node_instance.node_data.title,
node_type=node.type_,
title=node.title,
elapsed_time=time.perf_counter() - start_at,
created_at=datetime.now(UTC).replace(tzinfo=None),
finished_at=datetime.now(UTC).replace(tzinfo=None),

View File

@ -31,7 +31,7 @@ class WorkspaceService:
assert tenant_account_join is not None, "TenantAccountJoin not found"
tenant_info["role"] = tenant_account_join.role
can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
base_url = dify_config.FILES_URL

View File

@ -15,7 +15,7 @@ def get_mocked_fetch_model_config(
mode: str,
credentials: dict,
):
model_provider_factory = ModelProviderFactory(tenant_id="test_tenant")
model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(

View File

@ -66,6 +66,10 @@ def init_code_node(code_config: dict):
config=code_config,
)
# Initialize node data
if "data" in code_config:
node.init_node_data(code_config["data"])
return node
@ -234,10 +238,10 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
node.node_data = cast(CodeNodeData, node.node_data)
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@ -250,7 +254,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@ -263,7 +267,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@ -276,7 +280,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
def test_execute_code_output_object_list():
@ -330,10 +334,10 @@ def test_execute_code_output_object_list():
]
}
node.node_data = cast(CodeNodeData, node.node_data)
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@ -353,7 +357,7 @@ def test_execute_code_output_object_list():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
def test_execute_code_scientific_notation():

View File

@ -52,7 +52,7 @@ def init_http_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
return HttpRequestNode(
node = HttpRequestNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
@ -60,6 +60,12 @@ def init_http_node(config: dict):
config=config,
)
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_get(setup_http_mock):

View File

@ -2,15 +2,10 @@ import json
import time
import uuid
from collections.abc import Generator
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
@ -24,8 +19,6 @@ from models.enums import UserFrom
from models.workflow import WorkflowType
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
def init_llm_node(config: dict) -> LLMNode:
@ -84,10 +77,14 @@ def init_llm_node(config: dict) -> LLMNode:
config=config,
)
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node
def test_execute_llm(flask_req_ctx):
def test_execute_llm():
node = init_llm_node(
config={
"id": "llm",
@ -95,7 +92,7 @@ def test_execute_llm(flask_req_ctx):
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
@ -114,53 +111,62 @@ def test_execute_llm(flask_req_ctx):
},
)
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
db.session.close = MagicMock()
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
# Mock the _fetch_model_config to avoid database calls
def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response from mock")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
# execute node
result = node._run()
@ -168,6 +174,9 @@ def test_execute_llm(flask_req_ctx):
for item in result:
if isinstance(item, RunCompletedEvent):
if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
print(f"Error: {item.run_result.error}")
print(f"Error type: {item.run_result.error_type}")
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
@ -175,8 +184,7 @@ def test_execute_llm(flask_req_ctx):
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
def test_execute_llm_with_jinja2():
"""
Test execute LLM node with jinja2
"""
@ -217,53 +225,60 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
# Mock db.session.close()
db.session.close = MagicMock()
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_2(**_kwargs):
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
):
# execute node
result = node._run()

View File

@ -74,13 +74,15 @@ def init_parameter_extractor_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
return ParameterExtractorNode(
node = ParameterExtractorNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
return node
def test_function_calling_parameter_extractor(setup_model_mock):

View File

@ -76,6 +76,7 @@ def test_execute_code(setup_code_executor_mock):
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
# execute node
result = node._run()

View File

@ -50,13 +50,15 @@ def init_tool_node(config: dict):
conversation_variables=[],
)
return ToolNode(
node = ToolNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
return node
def test_tool_variable_invoke():

View File

@ -58,21 +58,26 @@ def test_execute_answer():
pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.")
node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()

View File

@ -57,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch):
),
),
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@ -90,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch):
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@ -145,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch):
),
),
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@ -178,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch):
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@ -257,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@ -291,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",

View File

@ -162,25 +162,30 @@ def test_run():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
config=node_config,
)
# Initialize node data
iteration_node.init_node_data(node_config["data"])
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -379,25 +384,30 @@ def test_run_parallel():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
config=node_config,
)
# Initialize node data
iteration_node.init_node_data(node_config["data"])
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -595,45 +605,55 @@ def test_iteration_run_in_parallel_mode():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
parallel_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}
parallel_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
config=parallel_node_config,
)
# Initialize node data
parallel_iteration_node.init_node_data(parallel_node_config["data"])
sequential_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}
sequential_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
config=sequential_node_config,
)
# Initialize node data
sequential_iteration_node.init_node_data(sequential_node_config["data"])
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -645,8 +665,8 @@ def test_iteration_run_in_parallel_mode():
# execute node
parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run()
assert parallel_iteration_node.node_data.parallel_nums == 10
assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED
assert parallel_iteration_node._node_data.parallel_nums == 10
assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0
parallel_arr = []
sequential_arr = []
@ -818,26 +838,31 @@ def test_iteration_run_error_handle():
environment_variables=[],
)
pool.add(["pe", "list_output"], ["1", "1"])
error_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
}
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
},
config=error_node_config,
)
# Initialize node data
iteration_node.init_node_data(error_node_config["data"])
# execute continue on error node
result = iteration_node._run()
result_arr = []
@ -851,7 +876,7 @@ def test_iteration_run_error_handle():
assert count == 14
# execute remove abnormal output
iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run()
count = 0
for item in result:

View File

@ -119,17 +119,20 @@ def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
}
node = LLMNode(
id="1",
config={
"id": "1",
"data": llm_node_data.model_dump(),
},
config=node_config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
# Initialize node data
node.init_node_data(node_config["data"])
return node
@ -488,7 +491,7 @@ def test_handle_list_messages_basic(llm_node):
variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
result = llm_node._handle_list_messages(
result = llm_node.handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
@ -506,17 +509,20 @@ def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
}
node = LLMNode(
id="1",
config={
"id": "1",
"data": llm_node_data.model_dump(),
},
config=node_config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
# Initialize node data
node.init_node_data(node_config["data"])
return node, mock_file_saver
@ -540,7 +546,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
# Manually append to _file_outputs since the static method doesn't do it
llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with(
@ -566,7 +577,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_remote_url.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
# Manually append to _file_outputs since the static method doesn't do it
llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
@ -582,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents="hello world", file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
@ -590,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")]
contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
@ -616,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
)
mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[
contents=[
ImagePromptMessageContent(
format="png",
base64_data=image_b64_data,
mime_type="image/png",
)
]
],
file_saver=mock_file_saver,
file_outputs=llm_node._file_outputs,
)
yielded_strs = list(gen)
assert len(yielded_strs) == 1
@ -645,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=None, file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()

View File

@ -61,21 +61,26 @@ def test_execute_answer():
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()

View File

@ -27,13 +27,17 @@ def document_extractor_node():
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
return DocumentExtractorNode(
node_config = {"id": "test_node_id", "data": node_data.model_dump()}
node = DocumentExtractorNode(
id="test_node_id",
config={"id": "test_node_id", "data": node_data.model_dump()},
config=node_config,
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node
@pytest.fixture

View File

@ -57,57 +57,62 @@ def test_execute_if_else_result_true():
pool.add(["start", "null"], None)
pool.add(["start", "not_null"], "1212")
node_config = {
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()
@ -162,33 +167,38 @@ def test_execute_if_else_result_false():
pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"])
node_config = {
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()
@ -228,17 +238,22 @@ def test_array_file_contains_file_name():
],
)
node_config = {
"id": "if-else",
"data": node_data.model_dump(),
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
config={
"id": "if-else",
"data": node_data.model_dump(),
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(

View File

@ -33,16 +33,19 @@ def list_operator_node():
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
node_config = {
"id": "test_node_id",
"data": node_data.model_dump(),
}
node = ListOperatorNode(
id="test_node_id",
config={
"id": "test_node_id",
"data": node_data.model_dump(),
},
config=node_config,
graph_init_params=MagicMock(),
graph=MagicMock(),
graph_runtime_state=MagicMock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node

View File

@ -38,12 +38,13 @@ def _create_tool_node():
system_variables=SystemVariable.empty(),
user_inputs={},
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = ToolNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@ -71,6 +72,8 @@ def _create_tool_node():
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node

View File

@ -82,23 +82,28 @@ def test_overwrite_string_variable():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
@ -178,23 +183,28 @@ def test_append_variable_to_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
@ -265,23 +275,28 @@ def test_clear_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,

View File

@ -115,28 +115,33 @@ def test_remove_first_from_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
# Print the variable before running
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
@ -202,28 +207,33 @@ def test_remove_last_from_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
@ -281,28 +291,33 @@ def test_remove_first_from_empty_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
@ -360,28 +375,33 @@ def test_remove_last_from_empty_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())

View File

@ -80,15 +80,12 @@ def real_workflow_system_variables():
@pytest.fixture
def mock_node_execution_repository():
repo = MagicMock(spec=WorkflowNodeExecutionRepository)
repo.get_by_node_execution_id.return_value = None
repo.get_running_executions.return_value = []
return repo
@pytest.fixture
def mock_workflow_execution_repository():
repo = MagicMock(spec=WorkflowExecutionRepository)
repo.get.return_value = None
return repo
@ -217,8 +214,8 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_success(
@ -251,11 +248,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Mock get_running_executions to return an empty list
workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
# No running node executions in cache (empty cache)
# Call the method
result = workflow_cycle_manager.handle_workflow_run_failed(
@ -289,8 +285,8 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Create a mock event
event = MagicMock(spec=QueueNodeStartedEvent)
@ -342,8 +338,8 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository get method to return the real execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution
# Call the method
result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
@ -351,11 +347,13 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
# Verify the result
assert result == workflow_execution
# Test error case
workflow_cycle_manager._workflow_execution_repository.get.return_value = None
# Test error case - clear cache
workflow_cycle_manager._workflow_execution_cache.clear()
# Expect an error when execution is not found
with pytest.raises(ValueError):
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
with pytest.raises(WorkflowRunNotFoundError):
workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")
@ -384,8 +382,8 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository to return the node execution
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
# Pre-populate the cache with the node execution
workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_node_execution_success(
@ -414,8 +412,8 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_partial_success(
@ -462,8 +460,8 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository to return the node execution
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
# Pre-populate the cache with the node execution
workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_node_execution_failed(

View File

@ -137,37 +137,6 @@ def test_save_with_existing_tenant_id(repository, session):
session_obj.merge.assert_called_once_with(modified_execution)
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
"""Test get_by_node_execution_id method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
# Create a properly configured mock execution
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
configure_mock_execution(mock_execution)
session_obj.scalar.return_value = mock_execution
# Create a mock domain model to be returned by _to_domain_model
mock_domain_model = mocker.MagicMock()
# Mock the _to_domain_model method to return our mock domain model
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
# Call method
result = repository.get_by_node_execution_id("test-node-execution-id")
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalar.assert_called_once_with(mock_stmt)
# Assert _to_domain_model was called with the mock execution
repository._to_domain_model.assert_called_once_with(mock_execution)
# Assert the result is our mock domain model
assert result is mock_domain_model
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
"""Test get_by_workflow_run method."""
session_obj, _ = session
@ -202,88 +171,6 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
assert result[0] is mock_domain_model
def test_get_running_executions(repository, session, mocker: MockerFixture):
"""Test get_running_executions method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
# Create a properly configured mock execution
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
configure_mock_execution(mock_execution)
session_obj.scalars.return_value.all.return_value = [mock_execution]
# Create a mock domain model to be returned by _to_domain_model
mock_domain_model = mocker.MagicMock()
# Mock the _to_domain_model method to return our mock domain model
repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)
# Call method
result = repository.get_running_executions("test-workflow-run-id")
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalars.assert_called_once_with(mock_stmt)
# Assert _to_domain_model was called with the mock execution
repository._to_domain_model.assert_called_once_with(mock_execution)
# Assert the result contains our mock domain model
assert len(result) == 1
assert result[0] is mock_domain_model
def test_update_via_save(repository, session):
"""Test updating an existing record via save method."""
session_obj, _ = session
# Create a mock execution
execution = MagicMock(spec=WorkflowNodeExecutionModel)
execution.tenant_id = None
execution.app_id = None
execution.inputs = None
execution.process_data = None
execution.outputs = None
execution.metadata = None
# Mock the to_db_model method to return the execution itself
# This simulates the behavior of setting tenant_id and app_id
repository.to_db_model = MagicMock(return_value=execution)
# Call save method to update an existing record
repository.save(execution)
# Assert to_db_model was called with the execution
repository.to_db_model.assert_called_once_with(execution)
# Assert session.merge was called (for updates)
session_obj.merge.assert_called_once_with(execution)
def test_clear(repository, session, mocker: MockerFixture):
"""Test clear method."""
session_obj, _ = session
# Set up mock
mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete")
mock_stmt = mocker.MagicMock()
mock_delete.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
# Mock the execute result with rowcount
mock_result = mocker.MagicMock()
mock_result.rowcount = 5 # Simulate 5 records deleted
session_obj.execute.return_value = mock_result
# Call method
repository.clear()
# Assert delete was called with correct parameters
mock_delete.assert_called_once_with(WorkflowNodeExecutionModel)
mock_stmt.where.assert_called()
session_obj.execute.assert_called_once_with(mock_stmt)
session_obj.commit.assert_called_once()
def test_to_db_model(repository):
"""Test to_db_model method."""
# Create a domain model

View File

@ -0,0 +1,301 @@
from unittest.mock import Mock
from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter
from services.tools.tools_transform_service import ToolTransformService
class TestToolTransformService:
"""Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""
def test_convert_tool_with_parameter_override(self):
"""Test that runtime parameters correctly override base parameters"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
base_param2 = Mock(spec=ToolParameter)
base_param2.name = "param2"
base_param2.form = ToolParameter.ToolParameterForm.FORM
base_param2.type = "string"
base_param2.label = "Base Param 2"
# Create mock runtime parameters that override base parameters
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1" # Different label to verify override
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1, base_param2]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.author == "test_author"
assert result.name == "test_tool"
assert result.parameters is not None
assert len(result.parameters) == 2
# Find the overridden parameter
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
assert overridden_param is not None
assert overridden_param.label == "Runtime Param 1" # Should be runtime version
# Find the non-overridden parameter
original_param = next((p for p in result.parameters if p.name == "param2"), None)
assert original_param is not None
assert original_param.label == "Base Param 2" # Should be base version
def test_convert_tool_with_additional_runtime_parameters(self):
"""Test that additional runtime parameters are added to the final list"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
# Create mock runtime parameters - one that overrides and one that's new
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1"
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "runtime_only"
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
runtime_param2.type = "string"
runtime_param2.label = "Runtime Only Param"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 2
# Check that both parameters are present
param_names = [p.name for p in result.parameters]
assert "param1" in param_names
assert "runtime_only" in param_names
# Verify the overridden parameter has runtime version
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
assert overridden_param is not None
assert overridden_param.label == "Runtime Param 1"
# Verify the new runtime parameter is included
new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
assert new_param is not None
assert new_param.label == "Runtime Only Param"
def test_convert_tool_with_non_form_runtime_parameters(self):
"""Test that non-FORM runtime parameters are not added as new parameters"""
# Create mock base parameters
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
# Create mock runtime parameters with different forms
runtime_param1 = Mock(spec=ToolParameter)
runtime_param1.name = "param1"
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
runtime_param1.type = "string"
runtime_param1.label = "Runtime Param 1"
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "llm_param"
runtime_param2.form = ToolParameter.ToolParameterForm.LLM
runtime_param2.type = "string"
runtime_param2.label = "LLM Param"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 1 # Only the FORM parameter should be present
# Check that only the FORM parameter is present
param_names = [p.name for p in result.parameters]
assert "param1" in param_names
assert "llm_param" not in param_names
def test_convert_tool_with_empty_parameters(self):
"""Test conversion with empty base and runtime parameters"""
# Create mock tool with no parameters
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = []
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = []
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 0
def test_convert_tool_with_none_parameters(self):
"""Test conversion when base parameters is None"""
# Create mock tool with None parameters
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = None
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = []
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 0
def test_convert_tool_parameter_order_preserved(self):
"""Test that parameter order is preserved correctly"""
# Create mock base parameters in specific order
base_param1 = Mock(spec=ToolParameter)
base_param1.name = "param1"
base_param1.form = ToolParameter.ToolParameterForm.FORM
base_param1.type = "string"
base_param1.label = "Base Param 1"
base_param2 = Mock(spec=ToolParameter)
base_param2.name = "param2"
base_param2.form = ToolParameter.ToolParameterForm.FORM
base_param2.type = "string"
base_param2.label = "Base Param 2"
base_param3 = Mock(spec=ToolParameter)
base_param3.name = "param3"
base_param3.form = ToolParameter.ToolParameterForm.FORM
base_param3.type = "string"
base_param3.label = "Base Param 3"
# Create runtime parameter that overrides middle parameter
runtime_param2 = Mock(spec=ToolParameter)
runtime_param2.name = "param2"
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
runtime_param2.type = "string"
runtime_param2.label = "Runtime Param 2"
# Create new runtime parameter
runtime_param4 = Mock(spec=ToolParameter)
runtime_param4.name = "param4"
runtime_param4.form = ToolParameter.ToolParameterForm.FORM
runtime_param4.type = "string"
runtime_param4.label = "Runtime Param 4"
# Create mock tool
mock_tool = Mock(spec=Tool)
mock_tool.entity = Mock()
mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
mock_tool.entity.identity = Mock()
mock_tool.entity.identity.author = "test_author"
mock_tool.entity.identity.name = "test_tool"
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
mock_tool.entity.description = Mock()
mock_tool.entity.description.human = I18nObject(en_US="Test description")
mock_tool.entity.output_schema = {}
mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]
# Mock fork_tool_runtime to return the same tool
mock_tool.fork_tool_runtime.return_value = mock_tool
# Call the method
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
# Verify the result
assert isinstance(result, ToolApiEntity)
assert result.parameters is not None
assert len(result.parameters) == 4
# Check that order is maintained: base parameters first, then new runtime parameters
param_names = [p.name for p in result.parameters]
assert param_names == ["param1", "param2", "param3", "param4"]
# Verify that param2 was overridden with runtime version
param2 = result.parameters[1]
assert param2.name == "param2"
assert param2.label == "Runtime Param 2"

View File

@ -289,6 +289,7 @@ REDIS_CLUSTERS_PASSWORD=
# If use Redis Sentinel, format as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>`
# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1
CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1
CELERY_BACKEND=redis
BROKER_USE_SSL=false
# If you are using Redis Sentinel for high availability, configure the following settings.

View File

@ -79,6 +79,7 @@ x-shared-env: &shared-api-worker-env
REDIS_CLUSTERS: ${REDIS_CLUSTERS:-}
REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-}
CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
CELERY_BACKEND: ${CELERY_BACKEND:-redis}
BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false}
CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-}

View File

@ -1,5 +1,3 @@
'use client'
import WorkflowApp from '@/app/components/workflow-app'
const Page = () => {

View File

@ -1,3 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 4V8M8 8V12M8 8H12M8 8H4" stroke="#6B7280" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 206 B

View File

@ -1,4 +0,0 @@
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M0.631586 8.25C0.631586 6.46656 2.04586 5 3.8158 5C5.58573 5 7.00001 6.46656 7.00001 8.25C7.00001 10.0334 5.58573 11.5 3.8158 11.5C3.45197 11.5 3.10149 11.4375 2.77474 11.3222C2.72073 11.3031 2.68723 11.2913 2.66266 11.2832C2.65821 11.2817 2.65456 11.2806 2.65164 11.2796L2.64892 11.2799C2.63177 11.2818 2.60839 11.285 2.56507 11.2909L1.06766 11.4954C0.905637 11.5175 0.743029 11.459 0.632239 11.3387C0.521449 11.2185 0.476481 11.0516 0.511825 10.8919L0.817497 9.51109C0.828118 9.46311 0.833802 9.43722 0.837453 9.41817C0.83766 9.4171 0.838022 9.41517 0.838022 9.41517C0.837114 9.412 0.835963 9.40808 0.834525 9.40332C0.826292 9.37605 0.814183 9.33888 0.794499 9.27863C0.688657 8.95463 0.631586 8.60857 0.631586 8.25Z" fill="#98A2B3"/>
<path d="M2.57377 4.1863C2.96256 4.06535 3.37698 4 3.80894 4C6.16566 4 8.00006 5.94534 8.00006 8.24999C8.00006 8.65682 7.9429 9.05245 7.8358 9.42816C8.10681 9.37948 8.36964 9.30678 8.6219 9.21229C8.65748 9.19897 8.69298 9.18534 8.72893 9.17304C8.75795 9.17641 8.78684 9.18093 8.81574 9.18517L10.4222 9.42065C10.498 9.43179 10.5841 9.44444 10.6591 9.4487C10.7422 9.45343 10.8713 9.45292 11.0081 9.39408C11.1789 9.32061 11.3164 9.18628 11.3938 9.01716C11.4558 8.88174 11.4593 8.75269 11.4564 8.66955C11.4539 8.59442 11.4433 8.5081 11.4339 8.43202L11.2309 6.78307C11.2256 6.7402 11.2229 6.71768 11.2213 6.70118C11.23 6.66505 11.2466 6.6301 11.2598 6.59546C11.4492 6.09896 11.5526 5.56093 11.5526 5C11.5526 2.51163 9.52304 0.5 7.02632 0.5C4.80843 0.5 2.95915 2.08742 2.57377 4.1863Z" fill="#98A2B3"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.6 KiB

View File

@ -1,3 +0,0 @@
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M14.1667 6.66634H15.8333C16.2754 6.66634 16.6993 6.84194 17.0118 7.1545C17.3244 7.46706 17.5 7.89098 17.5 8.33301V13.333C17.5 13.775 17.3244 14.199 17.0118 14.5115C16.6993 14.8241 16.2754 14.9997 15.8333 14.9997H14.1667V18.333L10.8333 14.9997H7.5C7.28111 14.9999 7.06433 14.9569 6.86211 14.8731C6.6599 14.7893 6.47623 14.6663 6.32167 14.5113M6.32167 14.5113L9.16667 11.6663H12.5C12.942 11.6663 13.366 11.4907 13.6785 11.1782C13.9911 10.8656 14.1667 10.4417 14.1667 9.99967V4.99967C14.1667 4.55765 13.9911 4.13372 13.6785 3.82116C13.366 3.5086 12.942 3.33301 12.5 3.33301H4.16667C3.72464 3.33301 3.30072 3.5086 2.98816 3.82116C2.67559 4.13372 2.5 4.55765 2.5 4.99967V9.99967C2.5 10.4417 2.67559 10.8656 2.98816 11.1782C3.30072 11.4907 3.72464 11.6663 4.16667 11.6663H5.83333V14.9997L6.32167 14.5113Z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 1002 B

View File

@ -1,4 +0,0 @@
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M6.5 1.00779C6.5 0.994638 6.5 0.988062 6.49943 0.976137C6.48764 0.729248 6.27052 0.51224 6.02363 0.50056C6.01171 0.499996 6.0078 0.499998 6.00001 0.5H4.37933C3.97686 0.499995 3.64468 0.49999 3.37409 0.522098C3.09304 0.545061 2.83469 0.594343 2.59202 0.717989C2.2157 0.909735 1.90973 1.2157 1.71799 1.59202C1.59434 1.83469 1.54506 2.09304 1.5221 2.37409C1.49999 2.64468 1.49999 2.97686 1.5 3.37934V8.62066C1.49999 9.02313 1.49999 9.35532 1.5221 9.62591C1.54506 9.90696 1.59434 10.1653 1.71799 10.408C1.90973 10.7843 2.2157 11.0903 2.59202 11.282C2.83469 11.4057 3.09304 11.4549 3.37409 11.4779C3.64468 11.5 3.97686 11.5 4.37934 11.5H7.62066C8.02314 11.5 8.35532 11.5 8.62591 11.4779C8.90696 11.4549 9.16531 11.4057 9.40798 11.282C9.78431 11.0903 10.0903 10.7843 10.282 10.408C10.4057 10.1653 10.4549 9.90696 10.4779 9.62591C10.5 9.35532 10.5 9.02314 10.5 8.62066V4.99997C10.5 4.9922 10.5 4.98832 10.4994 4.97641C10.4878 4.72949 10.2707 4.51236 10.0238 4.50057C10.0119 4.50001 10.0054 4.50001 9.99225 4.50001L7.78404 4.50001C7.65786 4.50002 7.53496 4.50004 7.43089 4.49153C7.31659 4.48219 7.18172 4.46016 7.04601 4.39101C6.85785 4.29514 6.70487 4.14216 6.609 3.954C6.53985 3.81828 6.51781 3.68342 6.50848 3.56912C6.49997 3.46504 6.49999 3.34215 6.5 3.21596L6.5 1.00779ZM4 6.5C3.72386 6.5 3.5 6.72386 3.5 7C3.5 7.27614 3.72386 7.5 4 7.5H8C8.27614 7.5 8.5 7.27614 8.5 7C8.5 6.72386 8.27614 6.5 8 6.5H4ZM4 8.5C3.72386 8.5 3.5 8.72386 3.5 9C3.5 9.27614 3.72386 9.5 4 9.5H7C7.27614 9.5 7.5 9.27614 7.5 9C7.5 8.72386 7.27614 8.5 7 8.5H4Z" fill="#98A2B3"/>
<path d="M9.45398 3.5C9.60079 3.5 9.67419 3.5 9.73432 3.46314C9.81925 3.41107 9.87002 3.28842 9.84674 3.19157C9.83025 3.12299 9.78238 3.07516 9.68665 2.97952L8.02049 1.31336C7.92484 1.21762 7.87701 1.16975 7.80843 1.15326C7.71158 1.12998 7.58893 1.18075 7.53687 1.26567C7.5 1.3258 7.5 1.39921 7.5 1.54602L7.5 3.09998C7.5 3.23999 7.5 3.30999 7.52725 3.36347C7.55122 3.41051 7.58946 3.44876 7.6365 3.47272C7.68998 3.49997 7.75998 3.49997 7.9 3.49998L9.45398 3.5Z" fill="#98A2B3"/>
</svg>

Before

Width:  |  Height:  |  Size: 2.1 KiB

View File

@ -1,3 +0,0 @@
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M16.25 11.875V9.6875C16.25 8.1342 14.9908 6.875 13.4375 6.875H12.1875C11.6697 6.875 11.25 6.45527 11.25 5.9375V4.6875C11.25 3.1342 9.9908 1.875 8.4375 1.875H6.875M6.875 12.5H13.125M6.875 15H10M8.75 1.875H4.6875C4.16973 1.875 3.75 2.29473 3.75 2.8125V17.1875C3.75 17.7053 4.16973 18.125 4.6875 18.125H15.3125C15.8303 18.125 16.25 17.7053 16.25 17.1875V9.375C16.25 5.23286 12.8921 1.875 8.75 1.875Z" stroke="#1F2A37" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 595 B

View File

@ -1,3 +0,0 @@
<svg width="26" height="26" viewBox="0 0 26 26" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M22.0101 4.50191C20.3529 3.74154 18.5759 3.18133 16.7179 2.86048C16.6841 2.85428 16.6503 2.86976 16.6328 2.90071C16.4043 3.30719 16.1511 3.83748 15.9738 4.25429C13.9754 3.95511 11.9873 3.95511 10.0298 4.25429C9.85253 3.82822 9.59019 3.30719 9.36062 2.90071C9.34319 2.87079 9.30939 2.85532 9.27555 2.86048C7.41857 3.18031 5.64152 3.74051 3.98335 4.50191C3.96899 4.5081 3.95669 4.51843 3.94852 4.53183C0.577841 9.56755 -0.345529 14.4795 0.107445 19.3306C0.109495 19.3543 0.122817 19.377 0.141265 19.3914C2.36514 21.0246 4.51935 22.0161 6.63355 22.6732C6.66739 22.6836 6.70324 22.6712 6.72477 22.6433C7.22489 21.9604 7.6707 21.2402 8.05293 20.4829C8.07549 20.4386 8.05396 20.386 8.00785 20.3684C7.30073 20.1002 6.6274 19.7731 5.97971 19.4017C5.92848 19.3718 5.92437 19.2985 5.9715 19.2635C6.1078 19.1613 6.24414 19.0551 6.37428 18.9478C6.39783 18.9282 6.43064 18.924 6.45833 18.9364C10.7134 20.8791 15.32 20.8791 19.5249 18.9364C19.5525 18.923 19.5854 18.9272 19.6099 18.9467C19.7401 19.054 19.8764 19.1613 20.0137 19.2635C20.0609 19.2985 20.0578 19.3718 20.0066 19.4017C19.3589 19.7804 18.6855 20.1002 17.9774 20.3674C17.9313 20.3849 17.9108 20.4386 17.9333 20.4829C18.3238 21.2392 18.7696 21.9593 19.2605 22.6423C19.281 22.6712 19.3179 22.6836 19.3517 22.6732C21.4761 22.0161 23.6303 21.0246 25.8542 19.3914C25.8737 19.377 25.886 19.3553 25.8881 19.3316C26.4302 13.7232 24.98 8.85156 22.0439 4.53286C22.0367 4.51843 22.0245 4.5081 22.0101 4.50191ZM8.68836 16.3768C7.40729 16.3768 6.35173 15.2007 6.35173 13.7563C6.35173 12.3119 7.38682 11.1358 8.68836 11.1358C10.0001 11.1358 11.0455 12.3222 11.025 13.7563C11.025 15.2007 9.98986 16.3768 8.68836 16.3768ZM17.3276 16.3768C16.0466 16.3768 14.991 15.2007 14.991 13.7563C14.991 12.3119 16.0261 11.1358 17.3276 11.1358C18.6394 11.1358 19.6847 12.3222 19.6643 13.7563C19.6643 15.2007 18.6394 16.3768 17.3276 16.3768Z" fill="#5865F2"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.9 KiB

View File

@ -1,17 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_131_1011)">
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.0003 0.5C9.15149 0.501478 6.39613 1.51046 4.22687 3.34652C2.05761 5.18259 0.615903 7.72601 0.159545 10.522C-0.296814 13.318 0.261927 16.1842 1.73587 18.6082C3.20981 21.0321 5.50284 22.8558 8.20493 23.753C8.80105 23.8636 9.0256 23.4941 9.0256 23.18C9.0256 22.8658 9.01367 21.955 9.0097 20.9592C5.6714 21.6804 4.96599 19.5505 4.96599 19.5505C4.42152 18.1674 3.63464 17.8039 3.63464 17.8039C2.54571 17.065 3.71611 17.0788 3.71611 17.0788C4.92227 17.1637 5.55616 18.3097 5.55616 18.3097C6.62521 20.1333 8.36389 19.6058 9.04745 19.2976C9.15475 18.5251 9.46673 17.9995 9.8105 17.7012C7.14383 17.4008 4.34204 16.3774 4.34204 11.8054C4.32551 10.6197 4.76802 9.47305 5.57801 8.60268C5.45481 8.30236 5.04348 7.08923 5.69524 5.44143C5.69524 5.44143 6.7027 5.12135 8.9958 6.66444C10.9627 6.12962 13.0379 6.12962 15.0047 6.66444C17.2958 5.12135 18.3013 5.44143 18.3013 5.44143C18.9551 7.08528 18.5437 8.29841 18.4205 8.60268C19.2331 9.47319 19.6765 10.6218 19.6585 11.8094C19.6585 16.3912 16.8507 17.4008 14.1801 17.6952C14.6093 18.0667 14.9928 18.7918 14.9928 19.9061C14.9928 21.5026 14.9789 22.7868 14.9789 23.18C14.9789 23.4981 15.1955 23.8695 15.8035 23.753C18.5059 22.8557 20.7992 21.0317 22.2731 18.6073C23.747 16.183 24.3055 13.3163 23.8486 10.5201C23.3917 7.7238 21.9493 5.18035 19.7793 3.34461C17.6093 1.50886 14.8533 0.500541 12.0042 0.5H12.0003Z" fill="#191717"/>
<path d="M4.54444 17.6321C4.5186 17.6914 4.42322 17.7092 4.34573 17.6677C4.26823 17.6262 4.21061 17.5491 4.23843 17.4879C4.26625 17.4266 4.35964 17.4108 4.43714 17.4523C4.51463 17.4938 4.57424 17.5729 4.54444 17.6321Z" fill="#191717"/>
<path d="M5.03123 18.1714C4.99008 18.192 4.943 18.1978 4.89805 18.1877C4.8531 18.1776 4.81308 18.1523 4.78483 18.1161C4.70734 18.0331 4.69143 17.9185 4.75104 17.8671C4.81066 17.8157 4.91797 17.8395 4.99546 17.9224C5.07296 18.0054 5.09084 18.12 5.03123 18.1714Z" fill="#191717"/>
<path d="M5.50425 18.857C5.43072 18.9084 5.30553 18.857 5.23598 18.7543C5.21675 18.7359 5.20146 18.7138 5.19101 18.6893C5.18056 18.6649 5.17517 18.6386 5.17517 18.612C5.17517 18.5855 5.18056 18.5592 5.19101 18.5347C5.20146 18.5103 5.21675 18.4882 5.23598 18.4698C5.3095 18.4204 5.4347 18.4698 5.50425 18.5705C5.57379 18.6713 5.57578 18.8057 5.50425 18.857V18.857Z" fill="#191717"/>
<path d="M6.14612 19.5207C6.08054 19.5939 5.94741 19.5741 5.83812 19.4753C5.72883 19.3765 5.70299 19.2422 5.76857 19.171C5.83414 19.0999 5.96727 19.1197 6.08054 19.2165C6.1938 19.3133 6.21566 19.4496 6.14612 19.5207V19.5207Z" fill="#191717"/>
<path d="M7.04617 19.9081C7.01637 20.001 6.88124 20.0425 6.74612 20.003C6.611 19.9635 6.52158 19.8528 6.54741 19.758C6.57325 19.6631 6.71036 19.6197 6.84747 19.6631C6.98457 19.7066 7.07201 19.8113 7.04617 19.9081Z" fill="#191717"/>
<path d="M8.02783 19.9752C8.02783 20.072 7.91656 20.155 7.77349 20.1569C7.63042 20.1589 7.51318 20.0799 7.51318 19.9831C7.51318 19.8863 7.62445 19.8033 7.76752 19.8013C7.91059 19.7993 8.02783 19.8764 8.02783 19.9752Z" fill="#191717"/>
<path d="M8.9419 19.8232C8.95978 19.92 8.86042 20.0207 8.71735 20.0445C8.57428 20.0682 8.4491 20.0109 8.43121 19.916C8.41333 19.8212 8.51666 19.7185 8.65576 19.6928C8.79485 19.6671 8.92401 19.7264 8.9419 19.8232Z" fill="#191717"/>
</g>
<defs>
<clipPath id="clip0_131_1011">
<rect width="24" height="24" fill="white"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 3.4 KiB

View File

@ -1,3 +0,0 @@
<svg width="13" height="14" viewBox="0 0 13 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.41663 3.75033H3.24996C2.96264 3.75033 2.68709 3.86446 2.48393 4.06763C2.28076 4.27079 2.16663 4.54634 2.16663 4.83366V10.2503C2.16663 10.5376 2.28076 10.8132 2.48393 11.0164C2.68709 11.2195 2.96264 11.3337 3.24996 11.3337H8.66663C8.95394 11.3337 9.22949 11.2195 9.43266 11.0164C9.63582 10.8132 9.74996 10.5376 9.74996 10.2503V8.08366M7.58329 2.66699H10.8333M10.8333 2.66699V5.91699M10.8333 2.66699L5.41663 8.08366" stroke="#9CA3AF" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 596 B

View File

@ -1,3 +0,0 @@
<svg width="13" height="14" viewBox="0 0 13 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.41663 3.75008H3.24996C2.96264 3.75008 2.68709 3.86422 2.48393 4.06738C2.28076 4.27055 2.16663 4.5461 2.16663 4.83341V10.2501C2.16663 10.5374 2.28076 10.8129 2.48393 11.0161C2.68709 11.2193 2.96264 11.3334 3.24996 11.3334H8.66663C8.95394 11.3334 9.22949 11.2193 9.43266 11.0161C9.63582 10.8129 9.74996 10.5374 9.74996 10.2501V8.08341M7.58329 2.66675H10.8333M10.8333 2.66675V5.91675M10.8333 2.66675L5.41663 8.08341" stroke="#1C64F2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 595 B

Some files were not shown because too many files have changed in this diff Show More