Merge branch 'main' into feat/rag-2

@@ -54,7 +54,7 @@ REDIS_CLUSTERS_PASSWORD=
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1

+CELERY_BACKEND=redis
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456

@@ -211,7 +211,7 @@ class DatabaseConfig(BaseSettings):
class CeleryConfig(DatabaseConfig):
    CELERY_BACKEND: str = Field(
        description="Backend for Celery task results. Options: 'database', 'redis'.",
-        default="database",
+        default="redis",
    )

    CELERY_BROKER_URL: Optional[str] = Field(

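The two hunks above switch the default Celery result backend from `database` to `redis`, keeping `database` as a supported option. A minimal sketch of how such a Pydantic settings field resolves (the class name here is illustrative, not Dify's actual config module):

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class CelerySettings(BaseSettings):
    # Same shape as the CeleryConfig field in the hunk above: the declared
    # default applies only when the environment variable is absent.
    CELERY_BACKEND: str = Field(
        description="Backend for Celery task results. Options: 'database', 'redis'.",
        default="redis",
    )


os.environ.pop("CELERY_BACKEND", None)
print(CelerySettings().CELERY_BACKEND)  # -> "redis" (falls back to the new default)

os.environ["CELERY_BACKEND"] = "database"
print(CelerySettings().CELERY_BACKEND)  # -> "database" (an explicit env var still wins)
```
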
@@ -5,6 +5,7 @@ from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

+import services
from controllers.console import api
from controllers.console.app.error import (
    CompletionRequestError,

@@ -27,7 +28,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required
-from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
+from models.model import AppMode, Conversation, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError

@@ -124,33 +125,16 @@ class MessageFeedbackApi(Resource):
        parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
        args = parser.parse_args()

-        message_id = str(args["message_id"])
-
-        message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
-
-        if not message:
-            raise NotFound("Message Not Exists.")
-
-        feedback = message.admin_feedback
-
-        if not args["rating"] and feedback:
-            db.session.delete(feedback)
-        elif args["rating"] and feedback:
-            feedback.rating = args["rating"]
-        elif not args["rating"] and not feedback:
-            raise ValueError("rating cannot be None when feedback not exists")
-        else:
-            feedback = MessageFeedback(
-                app_id=app_model.id,
-                conversation_id=message.conversation_id,
-                message_id=message.id,
-                rating=args["rating"],
-                from_source="admin",
-                from_account_id=current_user.id,
-            )
-            db.session.add(feedback)
-
-        db.session.commit()
+        try:
+            MessageService.create_feedback(
+                app_model=app_model,
+                message_id=str(args["message_id"]),
+                user=current_user,
+                rating=args.get("rating"),
+                content=None,
+            )
+        except services.errors.message.MessageNotExistsError:
+            raise NotFound("Message Not Exists.")

        return {"result": "success"}

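The feedback hunk moves inline ORM manipulation out of the controller and into a single `MessageService.create_feedback` call, translating the service's domain error back into an HTTP 404. A self-contained miniature of that separation, with the moved branching logic mirrored over an in-memory store (the class and store are stand-ins, not the actual `MessageService`):

```python
class MessageNotExistsError(Exception):
    pass


class MessageFeedbackService:  # illustrative stand-in for MessageService
    def __init__(self, messages: dict[str, dict]):
        self._messages = messages  # message_id -> {"feedback": str | None}

    def create_feedback(self, *, message_id: str, rating: str | None) -> None:
        message = self._messages.get(message_id)
        if not message:
            # Domain error; the controller above maps it to NotFound.
            raise MessageNotExistsError()
        feedback = message.get("feedback")
        if not rating and feedback:
            message["feedback"] = None    # clear an existing rating
        elif rating:
            message["feedback"] = rating  # create or update a rating
        else:
            raise ValueError("rating cannot be None when feedback not exists")


svc = MessageFeedbackService({"m1": {"feedback": None}})
svc.create_feedback(message_id="m1", rating="like")
assert svc._messages["m1"]["feedback"] == "like"
```

The design point is that the service raises domain exceptions and knows nothing about HTTP; only the controller decides that `MessageNotExistsError` becomes a 404.
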
@@ -211,10 +211,6 @@ class DatasetApi(Resource):
        else:
            data["embedding_available"] = True

-        if data.get("permission") == "partial_members":
-            part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
-            data.update({"partial_member_list": part_users_list})
-
        return data, 200

    @setup_required

@@ -4,7 +4,7 @@ from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
-from services.website_service import WebsiteService
+from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService


class WebsiteCrawlApi(Resource):

@@ -24,10 +24,16 @@ class WebsiteCrawlApi(Resource):
        parser.add_argument("url", type=str, required=True, nullable=True, location="json")
        parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
        args = parser.parse_args()
-        WebsiteService.document_create_args_validate(args)
-        # crawl url
-        try:
-            result = WebsiteService.crawl_url(args)
-        except Exception as e:
-            raise WebsiteCrawlError(str(e))
+
+        # Create typed request and validate
+        try:
+            api_request = WebsiteCrawlApiRequest.from_args(args)
+        except ValueError as e:
+            raise WebsiteCrawlError(str(e))
+
+        # Crawl URL using typed request
+        try:
+            result = WebsiteService.crawl_url(api_request)
+        except Exception as e:
+            raise WebsiteCrawlError(str(e))
        return result, 200

@@ -43,9 +49,16 @@ class WebsiteCrawlStatusApi(Resource):
            "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
        )
        args = parser.parse_args()
-        # get crawl status
-        try:
-            result = WebsiteService.get_crawl_status(job_id, args["provider"])
-        except Exception as e:
-            raise WebsiteCrawlError(str(e))
+
+        # Create typed request and validate
+        try:
+            api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
+        except ValueError as e:
+            raise WebsiteCrawlError(str(e))
+
+        # Get crawl status using typed request
+        try:
+            result = WebsiteService.get_crawl_status_typed(api_request)
+        except Exception as e:
+            raise WebsiteCrawlError(str(e))
        return result, 200

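Both crawl endpoints now parse the raw `reqparse` dict into a typed request object whose `from_args` constructor raises `ValueError` on bad input, so validation failures and crawl failures surface through the same `WebsiteCrawlError`. A minimal sketch of the pattern, assuming a dataclass-style implementation (the real `WebsiteCrawlApiRequest` may differ):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class CrawlRequest:  # illustrative stand-in for WebsiteCrawlApiRequest
    provider: str
    url: str
    options: dict[str, Any]

    @classmethod
    def from_args(cls, args: dict[str, Any]) -> "CrawlRequest":
        # Validation lives with the type, not scattered through controllers.
        url = args.get("url")
        if not url or not isinstance(url, str):
            raise ValueError("url is required and must be a string")
        options = args.get("options") or {}
        if not isinstance(options, dict):
            raise ValueError("options must be an object")
        return cls(provider=str(args.get("provider", "")), url=url, options=options)


# Controller-side usage mirrors the hunk: ValueError -> WebsiteCrawlError.
try:
    request = CrawlRequest.from_args({"url": "https://example.com", "options": {}})
except ValueError as e:
    print(f"would be re-raised as WebsiteCrawlError: {e}")
```
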
@@ -17,7 +17,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom

@@ -15,7 +15,8 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom

@@ -169,7 +169,3 @@ class AppQueueManager:
            raise TypeError(
                "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
            )
-
-
-class GenerateTaskStoppedError(Exception):
-    pass

@@ -118,7 +118,7 @@ class AppRunner:
        else:
            memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

-        model_mode = ModelMode.value_of(model_config.mode)
+        model_mode = ModelMode(model_config.mode)
        prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
        if model_mode == ModelMode.COMPLETION:
            advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template

@@ -11,10 +11,11 @@ from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom

@@ -10,10 +10,11 @@ from pydantic import ValidationError
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom

@@ -0,0 +1,2 @@
+class GenerateTaskStoppedError(Exception):
+    pass

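`GenerateTaskStoppedError` moves out of `base_app_queue_manager` into this new two-line module (`core/app/apps/exc.py`), and every import site in the hunks above is rewritten in matching pairs. Keeping exceptions in a leaf module that imports nothing from its own package is a common way to break import cycles; a sketch of the idea with hypothetical module names:

```python
# exc.py — a leaf module: defines exceptions, imports nothing from the package.
class GenerateTaskStoppedError(Exception):
    pass


# queue_manager.py — free to import heavy runtime modules:
#     from .exc import GenerateTaskStoppedError
#
# generator.py — needs only the exception, not the queue manager, so importing
# it from exc.py avoids pulling in queue_manager (and any cycle that
# queue_manager's own imports would otherwise create):
#     from .exc import GenerateTaskStoppedError
```
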
@@ -6,7 +6,8 @@ from typing import Optional, Union, cast

from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
    AdvancedChatAppGenerateEntity,
    AgentChatAppGenerateEntity,

@@ -1,4 +1,5 @@
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
    AppQueueEvent,

@@ -13,7 +13,8 @@ import contexts
from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner

@@ -1,4 +1,5 @@
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
    AppQueueEvent,

@@ -1,7 +1,8 @@
import logging
import time
-from collections.abc import Generator
-from typing import Optional, Union
+from collections.abc import Callable, Generator
+from contextlib import contextmanager
+from typing import Any, Optional, Union

from sqlalchemy.orm import Session

@@ -13,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
    WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
+    MessageQueueMessage,
    QueueAgentLogEvent,
    QueueErrorEvent,
    QueueIterationCompletedEvent,

@@ -38,11 +40,13 @@ from core.app.entities.queue_entities import (
    QueueWorkflowPartialSuccessEvent,
    QueueWorkflowStartedEvent,
    QueueWorkflowSucceededEvent,
+    WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
    ErrorStreamResponse,
    MessageAudioEndStreamResponse,
    MessageAudioStreamResponse,
+    PingStreamResponse,
    StreamResponse,
    TextChunkStreamResponse,
    WorkflowAppBlockingResponse,

@@ -54,6 +58,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository

@@ -246,315 +251,492 @@ class WorkflowAppGenerateTaskPipeline:
        if tts_publisher:
            yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

+    @contextmanager
+    def _database_session(self):
+        """Context manager for database sessions."""
+        with Session(db.engine, expire_on_commit=False) as session:
+            try:
+                yield session
+                session.commit()
+            except Exception:
+                session.rollback()
+                raise
+
+    def _ensure_workflow_initialized(self) -> None:
+        """Fluent validation for workflow state."""
+        if not self._workflow_run_id:
+            raise ValueError("workflow run not initialized.")
+
+    def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
+        """Fluent validation for graph runtime state."""
+        if not graph_runtime_state:
+            raise ValueError("graph runtime state not initialized.")
+        return graph_runtime_state
+
+    def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
+        """Handle ping events."""
+        yield self._base_task_pipeline._ping_stream_response()
+
+    def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
+        """Handle error events."""
+        err = self._base_task_pipeline._handle_error(event=event)
+        yield self._base_task_pipeline._error_to_stream_response(err)
+
+    def _handle_workflow_started_event(
+        self, event: QueueWorkflowStartedEvent, **kwargs
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle workflow started events."""
+        # init workflow run
+        workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
+        self._workflow_run_id = workflow_execution.id_
+        start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_execution=workflow_execution,
+        )
+        yield start_resp
+
+    def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+        """Handle node retry events."""
+        self._ensure_workflow_initialized()
+
+        with self._database_session() as session:
+            workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
+                workflow_execution_id=self._workflow_run_id,
+                event=event,
+            )
+            response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
+                event=event,
+                task_id=self._application_generate_entity.task_id,
+                workflow_node_execution=workflow_node_execution,
+            )
+
+        if response:
+            yield response
+
+    def _handle_node_started_event(
+        self, event: QueueNodeStartedEvent, **kwargs
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle node started events."""
+        self._ensure_workflow_initialized()
+
+        workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
+            workflow_execution_id=self._workflow_run_id, event=event
+        )
+        node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
+            event=event,
+            task_id=self._application_generate_entity.task_id,
+            workflow_node_execution=workflow_node_execution,
+        )
+
+        if node_start_response:
+            yield node_start_response
+
+    def _handle_node_succeeded_event(
+        self, event: QueueNodeSucceededEvent, **kwargs
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle node succeeded events."""
+        workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
+        node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
+            event=event,
+            task_id=self._application_generate_entity.task_id,
+            workflow_node_execution=workflow_node_execution,
+        )
+
+        self._save_output_for_event(event, workflow_node_execution.id)
+
+        if node_success_response:
+            yield node_success_response
+
+    def _handle_node_failed_events(
+        self,
+        event: Union[
+            QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
+        ],
+        **kwargs,
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle various node failure events."""
+        workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
+            event=event,
+        )
+        node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
+            event=event,
+            task_id=self._application_generate_entity.task_id,
+            workflow_node_execution=workflow_node_execution,
+        )
+
+        if isinstance(event, QueueNodeExceptionEvent):
+            self._save_output_for_event(event, workflow_node_execution.id)
+
+        if node_failed_response:
+            yield node_failed_response
+
+    def _handle_parallel_branch_started_event(
+        self, event: QueueParallelBranchRunStartedEvent, **kwargs
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle parallel branch started events."""
+        self._ensure_workflow_initialized()
+
+        parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_execution_id=self._workflow_run_id,
+            event=event,
+        )
+        yield parallel_start_resp
+
+    def _handle_parallel_branch_finished_events(
+        self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle parallel branch finished events."""
+        self._ensure_workflow_initialized()
+
+        parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_execution_id=self._workflow_run_id,
+            event=event,
+        )
+        yield parallel_finish_resp
+
+    def _handle_iteration_start_event(
+        self, event: QueueIterationStartEvent, **kwargs
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle iteration start events."""
+        self._ensure_workflow_initialized()
+
+        iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_execution_id=self._workflow_run_id,
+            event=event,
+        )
+        yield iter_start_resp
+
+    def _handle_iteration_next_event(
+        self, event: QueueIterationNextEvent, **kwargs
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle iteration next events."""
+        self._ensure_workflow_initialized()
+
+        iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_execution_id=self._workflow_run_id,
+            event=event,
+        )
+        yield iter_next_resp
+
+    def _handle_iteration_completed_event(
+        self, event: QueueIterationCompletedEvent, **kwargs
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle iteration completed events."""
+        self._ensure_workflow_initialized()
+
+        iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_execution_id=self._workflow_run_id,
+            event=event,
+        )
+        yield iter_finish_resp
+
+    def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+        """Handle loop start events."""
+        self._ensure_workflow_initialized()
+
+        loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_execution_id=self._workflow_run_id,
+            event=event,
+        )
+        yield loop_start_resp
+
+    def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+        """Handle loop next events."""
+        self._ensure_workflow_initialized()
+
+        loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_execution_id=self._workflow_run_id,
+            event=event,
+        )
+        yield loop_next_resp
+
+    def _handle_loop_completed_event(
+        self, event: QueueLoopCompletedEvent, **kwargs
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle loop completed events."""
+        self._ensure_workflow_initialized()
+
+        loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
+            task_id=self._application_generate_entity.task_id,
+            workflow_execution_id=self._workflow_run_id,
+            event=event,
+        )
+        yield loop_finish_resp
+
+    def _handle_workflow_succeeded_event(
+        self,
+        event: QueueWorkflowSucceededEvent,
+        *,
+        graph_runtime_state: Optional[GraphRuntimeState] = None,
+        trace_manager: Optional[TraceQueueManager] = None,
+        **kwargs,
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle workflow succeeded events."""
+        self._ensure_workflow_initialized()
+        validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+        with self._database_session() as session:
+            workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
+                workflow_run_id=self._workflow_run_id,
+                total_tokens=validated_state.total_tokens,
+                total_steps=validated_state.node_run_steps,
+                outputs=event.outputs,
+                conversation_id=None,
+                trace_manager=trace_manager,
+            )
+
+            # save workflow app log
+            self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
+
+            workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+                session=session,
+                task_id=self._application_generate_entity.task_id,
+                workflow_execution=workflow_execution,
+            )
+
+        yield workflow_finish_resp
+
+    def _handle_workflow_partial_success_event(
+        self,
+        event: QueueWorkflowPartialSuccessEvent,
+        *,
+        graph_runtime_state: Optional[GraphRuntimeState] = None,
+        trace_manager: Optional[TraceQueueManager] = None,
+        **kwargs,
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle workflow partial success events."""
+        self._ensure_workflow_initialized()
+        validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+        with self._database_session() as session:
+            workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
+                workflow_run_id=self._workflow_run_id,
+                total_tokens=validated_state.total_tokens,
+                total_steps=validated_state.node_run_steps,
+                outputs=event.outputs,
+                exceptions_count=event.exceptions_count,
+                conversation_id=None,
+                trace_manager=trace_manager,
+            )
+
+            # save workflow app log
+            self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
+
+            workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+                session=session,
+                task_id=self._application_generate_entity.task_id,
+                workflow_execution=workflow_execution,
+            )
+
+        yield workflow_finish_resp
+
+    def _handle_workflow_failed_and_stop_events(
+        self,
+        event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
+        *,
+        graph_runtime_state: Optional[GraphRuntimeState] = None,
+        trace_manager: Optional[TraceQueueManager] = None,
+        **kwargs,
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle workflow failed and stop events."""
+        self._ensure_workflow_initialized()
+        validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+        with self._database_session() as session:
+            workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
+                workflow_run_id=self._workflow_run_id,
+                total_tokens=validated_state.total_tokens,
+                total_steps=validated_state.node_run_steps,
+                status=WorkflowExecutionStatus.FAILED
+                if isinstance(event, QueueWorkflowFailedEvent)
+                else WorkflowExecutionStatus.STOPPED,
+                error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+                conversation_id=None,
+                trace_manager=trace_manager,
+                exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
+            )
+
+            # save workflow app log
+            self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
+
+            workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+                session=session,
+                task_id=self._application_generate_entity.task_id,
+                workflow_execution=workflow_execution,
+            )
+
+        yield workflow_finish_resp
+
+    def _handle_text_chunk_event(
+        self,
+        event: QueueTextChunkEvent,
+        *,
+        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+        queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+        **kwargs,
+    ) -> Generator[StreamResponse, None, None]:
+        """Handle text chunk events."""
+        delta_text = event.text
+        if delta_text is None:
+            return
+
+        # only publish tts message at text chunk streaming
+        if tts_publisher and queue_message:
+            tts_publisher.publish(queue_message)
+
+        yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
+
+    def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+        """Handle agent log events."""
+        yield self._workflow_response_converter.handle_agent_log(
+            task_id=self._application_generate_entity.task_id, event=event
+        )
+
+    def _get_event_handlers(self) -> dict[type, Callable]:
+        """Get mapping of event types to their handlers using fluent pattern."""
+        return {
+            # Basic events
+            QueuePingEvent: self._handle_ping_event,
+            QueueErrorEvent: self._handle_error_event,
+            QueueTextChunkEvent: self._handle_text_chunk_event,
+            # Workflow events
+            QueueWorkflowStartedEvent: self._handle_workflow_started_event,
+            QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
+            QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
+            # Node events
+            QueueNodeRetryEvent: self._handle_node_retry_event,
+            QueueNodeStartedEvent: self._handle_node_started_event,
+            QueueNodeSucceededEvent: self._handle_node_succeeded_event,
+            # Parallel branch events
+            QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
+            # Iteration events
+            QueueIterationStartEvent: self._handle_iteration_start_event,
+            QueueIterationNextEvent: self._handle_iteration_next_event,
+            QueueIterationCompletedEvent: self._handle_iteration_completed_event,
+            # Loop events
+            QueueLoopStartEvent: self._handle_loop_start_event,
+            QueueLoopNextEvent: self._handle_loop_next_event,
+            QueueLoopCompletedEvent: self._handle_loop_completed_event,
+            # Agent events
+            QueueAgentLogEvent: self._handle_agent_log_event,
+        }
+
+    def _dispatch_event(
+        self,
+        event: Any,
+        *,
+        graph_runtime_state: Optional[GraphRuntimeState] = None,
+        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+        trace_manager: Optional[TraceQueueManager] = None,
+        queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+    ) -> Generator[StreamResponse, None, None]:
+        """Dispatch events using elegant pattern matching."""
+        handlers = self._get_event_handlers()
+        event_type = type(event)
+
+        # Direct handler lookup
+        if handler := handlers.get(event_type):
+            yield from handler(
+                event,
+                graph_runtime_state=graph_runtime_state,
+                tts_publisher=tts_publisher,
+                trace_manager=trace_manager,
+                queue_message=queue_message,
+            )
+            return
+
+        # Handle node failure events with isinstance check
+        if isinstance(
+            event,
+            (
+                QueueNodeFailedEvent,
+                QueueNodeInIterationFailedEvent,
+                QueueNodeInLoopFailedEvent,
+                QueueNodeExceptionEvent,
+            ),
+        ):
+            yield from self._handle_node_failed_events(
+                event,
+                graph_runtime_state=graph_runtime_state,
+                tts_publisher=tts_publisher,
+                trace_manager=trace_manager,
+                queue_message=queue_message,
+            )
+            return
+
+        # Handle parallel branch finished events with isinstance check
+        if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
+            yield from self._handle_parallel_branch_finished_events(
+                event,
+                graph_runtime_state=graph_runtime_state,
+                tts_publisher=tts_publisher,
+                trace_manager=trace_manager,
+                queue_message=queue_message,
+            )
+            return
+
+        # Handle workflow failed and stop events with isinstance check
+        if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
+            yield from self._handle_workflow_failed_and_stop_events(
+                event,
+                graph_runtime_state=graph_runtime_state,
+                tts_publisher=tts_publisher,
+                trace_manager=trace_manager,
+                queue_message=queue_message,
+            )
+            return
+
+        # For unhandled events, we continue (original behavior)
+        return
+
    def _process_stream_response(
        self,
        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> Generator[StreamResponse, None, None]:
        """
-        Process stream response.
-        :return:
+        Process stream response using elegant Fluent Python patterns.
+        Maintains exact same functionality as original 44-if-statement version.
        """
+        # Initialize graph runtime state
        graph_runtime_state = None

        for queue_message in self._base_task_pipeline._queue_manager.listen():
            event = queue_message.event

-            if isinstance(event, QueuePingEvent):
-                yield self._base_task_pipeline._ping_stream_response()
-            elif isinstance(event, QueueErrorEvent):
-                err = self._base_task_pipeline._handle_error(event=event)
-                yield self._base_task_pipeline._error_to_stream_response(err)
-                break
-            elif isinstance(event, QueueWorkflowStartedEvent):
-                # override graph runtime state
-                graph_runtime_state = event.graph_runtime_state
-
-                # init workflow run
-                workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
-                self._workflow_run_id = workflow_execution.id_
-                start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_execution=workflow_execution,
-                )
-
-                yield start_resp
-            elif isinstance(
-                event,
-                QueueNodeRetryEvent,
-            ):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-                with Session(db.engine, expire_on_commit=False) as session:
-                    workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
-                        workflow_execution_id=self._workflow_run_id,
-                        event=event,
-                    )
-                    response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
-                        event=event,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_node_execution=workflow_node_execution,
-                    )
-                    session.commit()
-
-                if response:
-                    yield response
-            elif isinstance(event, QueueNodeStartedEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-
-                workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
-                    workflow_execution_id=self._workflow_run_id, event=event
-                )
-                node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
-
-                if node_start_response:
-                    yield node_start_response
-            elif isinstance(event, QueueNodeSucceededEvent):
-                workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
-                    event=event
-                )
-                node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
-
-                self._save_output_for_event(event, workflow_node_execution.id)
-
-                if node_success_response:
-                    yield node_success_response
-            elif isinstance(
-                event,
-                QueueNodeFailedEvent
-                | QueueNodeInIterationFailedEvent
-                | QueueNodeInLoopFailedEvent
-                | QueueNodeExceptionEvent,
-            ):
-                workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
-                    event=event,
-                )
-                node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
-                    event=event,
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_node_execution=workflow_node_execution,
-                )
-                if isinstance(event, QueueNodeExceptionEvent):
-                    self._save_output_for_event(event, workflow_node_execution.id)
-
-                if node_failed_response:
-                    yield node_failed_response
-
-            elif isinstance(event, QueueParallelBranchRunStartedEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-
-                parallel_start_resp = (
-                    self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_execution_id=self._workflow_run_id,
-                        event=event,
-                    )
-                )
-
-                yield parallel_start_resp
-
-            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-
-                parallel_finish_resp = (
-                    self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_execution_id=self._workflow_run_id,
-                        event=event,
-                    )
-                )
-
-                yield parallel_finish_resp
-
-            elif isinstance(event, QueueIterationStartEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-
-                iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_execution_id=self._workflow_run_id,
-                    event=event,
-                )
-
-                yield iter_start_resp
-
-            elif isinstance(event, QueueIterationNextEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-
-                iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_execution_id=self._workflow_run_id,
-                    event=event,
-                )
-
-                yield iter_next_resp
-
-            elif isinstance(event, QueueIterationCompletedEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-
-                iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_execution_id=self._workflow_run_id,
-                    event=event,
-                )
-
-                yield iter_finish_resp
-
-            elif isinstance(event, QueueLoopStartEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-
-                loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_execution_id=self._workflow_run_id,
-                    event=event,
-                )
-
-                yield loop_start_resp
-
-            elif isinstance(event, QueueLoopNextEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-
-                loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_execution_id=self._workflow_run_id,
-                    event=event,
-                )
-
-                yield loop_next_resp
-
-            elif isinstance(event, QueueLoopCompletedEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-
-                loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
-                    task_id=self._application_generate_entity.task_id,
-                    workflow_execution_id=self._workflow_run_id,
-                    event=event,
-                )
-
-                yield loop_finish_resp
-
-            elif isinstance(event, QueueWorkflowSucceededEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-                if not graph_runtime_state:
-                    raise ValueError("graph runtime state not initialized.")
-
-                with Session(db.engine, expire_on_commit=False) as session:
-                    workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
-                        workflow_run_id=self._workflow_run_id,
-                        total_tokens=graph_runtime_state.total_tokens,
-                        total_steps=graph_runtime_state.node_run_steps,
-                        outputs=event.outputs,
-                        conversation_id=None,
-                        trace_manager=trace_manager,
-                    )
-
-                    # save workflow app log
-                    self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
-
-                    workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_execution=workflow_execution,
-                    )
-                    session.commit()
-
-                yield workflow_finish_resp
-            elif isinstance(event, QueueWorkflowPartialSuccessEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-                if not graph_runtime_state:
-                    raise ValueError("graph runtime state not initialized.")
-
-                with Session(db.engine, expire_on_commit=False) as session:
-                    workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
-                        workflow_run_id=self._workflow_run_id,
-                        total_tokens=graph_runtime_state.total_tokens,
-                        total_steps=graph_runtime_state.node_run_steps,
-                        outputs=event.outputs,
-                        exceptions_count=event.exceptions_count,
-                        conversation_id=None,
-                        trace_manager=trace_manager,
-                    )
-
-                    # save workflow app log
-                    self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
-
-                    workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_execution=workflow_execution,
-                    )
-                    session.commit()
-
-                yield workflow_finish_resp
-            elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
-                if not self._workflow_run_id:
-                    raise ValueError("workflow run not initialized.")
-                if not graph_runtime_state:
-                    raise ValueError("graph runtime state not initialized.")
-
-                with Session(db.engine, expire_on_commit=False) as session:
-                    workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
-                        workflow_run_id=self._workflow_run_id,
-                        total_tokens=graph_runtime_state.total_tokens,
-                        total_steps=graph_runtime_state.node_run_steps,
-                        status=WorkflowExecutionStatus.FAILED
-                        if isinstance(event, QueueWorkflowFailedEvent)
-                        else WorkflowExecutionStatus.STOPPED,
-                        error_message=event.error
-                        if isinstance(event, QueueWorkflowFailedEvent)
-                        else event.get_stop_reason(),
-                        conversation_id=None,
-                        trace_manager=trace_manager,
-                        exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
-                    )
-
-                    # save workflow app log
-                    self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
-
-                    workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
-                        session=session,
-                        task_id=self._application_generate_entity.task_id,
-                        workflow_execution=workflow_execution,
-                    )
-                    session.commit()
-
-                yield workflow_finish_resp
-            elif isinstance(event, QueueTextChunkEvent):
-                delta_text = event.text
-                if delta_text is None:
-                    continue
-
-                # only publish tts message at text chunk streaming
-                if tts_publisher:
-                    tts_publisher.publish(queue_message)
-
-                yield self._text_chunk_to_stream_response(
-                    delta_text, from_variable_selector=event.from_variable_selector
-                )
-            elif isinstance(event, QueueAgentLogEvent):
-                yield self._workflow_response_converter.handle_agent_log(
-                    task_id=self._application_generate_entity.task_id, event=event
-                )
-            else:
-                continue
+            match event:
+                case QueueWorkflowStartedEvent():
+                    graph_runtime_state = event.graph_runtime_state
+                    yield from self._handle_workflow_started_event(event)
+                case QueueTextChunkEvent():
+                    yield from self._handle_text_chunk_event(
+                        event, tts_publisher=tts_publisher, queue_message=queue_message
+                    )
+                case QueueErrorEvent():
+                    yield from self._handle_error_event(event)
+                    break
+                # Handle all other events through elegant dispatch
+                case _:
+                    if responses := list(
+                        self._dispatch_event(
+                            event,
+                            graph_runtime_state=graph_runtime_state,
+                            tts_publisher=tts_publisher,
+                            trace_manager=trace_manager,
+                            queue_message=queue_message,
+                        )
+                    ):
+                        yield from responses

        if tts_publisher:
            tts_publisher.publish(None)

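The pipeline rewrite above replaces one long if/elif chain with small per-event generator methods, a `dict` that maps event types to handlers, and a `match` statement for the few cases that need loop control (`break`) or extra wiring. The dispatch shape in miniature (the event classes here are stand-ins, not Dify's queue entities):

```python
from collections.abc import Callable, Generator


class PingEvent: ...


class TextEvent:
    def __init__(self, text: str) -> None:
        self.text = text


class MiniPipeline:
    def _handle_ping(self, event: PingEvent, **kwargs) -> Generator[str, None, None]:
        yield "ping"

    def _handle_text(self, event: TextEvent, **kwargs) -> Generator[str, None, None]:
        if event.text:
            yield f"chunk:{event.text}"

    def _handlers(self) -> dict[type, Callable]:
        # Exact-type lookup, like _get_event_handlers in the hunk above;
        # isinstance-based event families need a separate fallback path.
        return {PingEvent: self._handle_ping, TextEvent: self._handle_text}

    def dispatch(self, event) -> Generator[str, None, None]:
        if handler := self._handlers().get(type(event)):
            yield from handler(event)


pipe = MiniPipeline()
print(list(pipe.dispatch(TextEvent("hi"))))  # ['chunk:hi']
print(list(pipe.dispatch(PingEvent())))      # ['ping']
```

One design consequence worth noting: `dict` lookup by `type(event)` does not match subclasses, which is why the diff keeps explicit `isinstance` fallbacks for the node-failure and parallel-branch event families.
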
@@ -21,7 +21,7 @@ def encrypt_token(tenant_id: str, token: str):
    return base64.b64encode(encrypted_token).decode()


-def decrypt_token(tenant_id: str, token: str):
+def decrypt_token(tenant_id: str, token: str) -> str:
    return rsa.decrypt(base64.b64decode(token), tenant_id)


@@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum):
    COMPLETION = "completion"
    CHAT = "chat"

-    @classmethod
-    def value_of(cls, value: str) -> "ModelMode":
-        """
-        Get value of given mode.
-
-        :param value: mode value
-        :return: mode
-        """
-        for mode in cls:
-            if mode.value == value:
-                return mode
-        raise ValueError(f"invalid mode value {value}")
-

prompt_file_contents: dict[str, Any] = {}

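The hand-rolled `value_of` classmethod is redundant: calling an `enum.StrEnum` (or any `Enum`) with a value already performs the same lookup and raises `ValueError` on a miss, which is exactly what the call sites in the surrounding hunks switch to:

```python
import enum


class ModelMode(enum.StrEnum):
    COMPLETION = "completion"
    CHAT = "chat"


# The constructor does a value lookup, replacing ModelMode.value_of("chat").
assert ModelMode("chat") is ModelMode.CHAT

try:
    ModelMode("speech")
except ValueError as e:
    print(e)  # 'speech' is not a valid ModelMode
```
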
@@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        inputs = {key: str(value) for key, value in inputs.items()}

-        model_mode = ModelMode.value_of(model_config.mode)
+        model_mode = ModelMode(model_config.mode)
        if model_mode == ModelMode.CHAT:
            prompt_messages, stops = self._get_chat_model_prompt_messages(
                app_mode=app_mode,

@@ -238,9 +238,11 @@ class WordExtractor(BaseExtractor):
            paragraph_content = []
            for run in paragraph.runs:
                if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
+                    # Process drawing type images
                    drawing_elements = run.element.findall(
                        ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
                    )
+                    has_drawing = False
                    for drawing in drawing_elements:
                        blip_elements = drawing.findall(
                            ".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip"

@@ -252,6 +254,34 @@ class WordExtractor(BaseExtractor):
                            if embed_id:
                                image_part = doc.part.related_parts.get(embed_id)
                                if image_part in image_map:
+                                    has_drawing = True
                                    paragraph_content.append(image_map[image_part])
+                    # Process pict type images
+                    shape_elements = run.element.findall(
+                        ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
+                    )
+                    for shape in shape_elements:
+                        # Find image data in VML
+                        shape_image = shape.find(
+                            ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}binData"
+                        )
+                        if shape_image is not None and shape_image.text:
+                            image_id = shape_image.get(
+                                "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
+                            )
+                            if image_id and image_id in doc.part.rels:
+                                image_part = doc.part.rels[image_id].target_part
+                                if image_part in image_map and not has_drawing:
+                                    paragraph_content.append(image_map[image_part])
+                        # Find imagedata element in VML
+                        image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
+                        if image_data is not None:
+                            image_id = image_data.get("id") or image_data.get(
+                                "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
+                            )
+                            if image_id and image_id in doc.part.rels:
+                                image_part = doc.part.rels[image_id].target_part
+                                if image_part in image_map and not has_drawing:
+                                    paragraph_content.append(image_map[image_part])
                    if run.text.strip():
                        paragraph_content.append(run.text.strip())

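The WordExtractor change walks both image representations inside a run: DrawingML (`w:drawing` resolved through the relationship id) and legacy VML (`w:pict` with `v:imagedata`), using a `has_drawing` flag so the same picture is not appended twice. The `findall` calls use Clark notation, where the namespace URI is spelled out in braces; a compact, runnable sketch with a hypothetical XML fragment:

```python
from xml.etree import ElementTree as ET

W = "{http://schemas.openxmlformats.org/wordprocessingml/2006/main}"
A = "{http://schemas.openxmlformats.org/drawingml/2006/main}"

run_xml = """
<r xmlns="http://schemas.openxmlformats.org/wordprocessingml/2006/main"
   xmlns:a="http://schemas.openxmlformats.org/drawingml/2006/main">
  <drawing><a:blip/></drawing>
</r>
"""
run = ET.fromstring(run_xml)

# Same pattern as run.element.findall(...) in the hunk:
# './/' plus a Clark-notation tag matches at any depth.
drawings = run.findall(f".//{W}drawing")
blips = [b for d in drawings for b in d.findall(f".//{A}blip")]
print(len(drawings), len(blips))  # 1 1
```
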
@@ -1137,7 +1137,7 @@ class DatasetRetrieval:
    def _get_prompt_template(
        self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
    ):
-        model_mode = ModelMode.value_of(mode)
+        model_mode = ModelMode(mode)
        input_text = query

        prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]

@@ -6,7 +6,6 @@ import json
import logging
from typing import Optional, Union

-from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

@@ -206,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
        # Update the in-memory cache for faster subsequent lookups
        logger.debug(f"Updating cache for execution_id: {db_model.id}")
        self._execution_cache[db_model.id] = db_model
-
-    def get(self, execution_id: str) -> Optional[WorkflowExecution]:
-        """
-        Retrieve a WorkflowExecution by its ID.
-
-        First checks the in-memory cache, and if not found, queries the database.
-        If found in the database, adds it to the cache for future lookups.
-
-        Args:
-            execution_id: The workflow execution ID
-
-        Returns:
-            The WorkflowExecution instance if found, None otherwise
-        """
-        # First check the cache
-        if execution_id in self._execution_cache:
-            logger.debug(f"Cache hit for execution_id: {execution_id}")
-            # Convert cached DB model to domain model
-            cached_db_model = self._execution_cache[execution_id]
-            return self._to_domain_model(cached_db_model)
-
-        # If not in cache, query the database
-        logger.debug(f"Cache miss for execution_id: {execution_id}, querying database")
-        with self._session_factory() as session:
-            stmt = select(WorkflowRun).where(
-                WorkflowRun.id == execution_id,
-                WorkflowRun.tenant_id == self._tenant_id,
-            )
-
-            if self._app_id:
-                stmt = stmt.where(WorkflowRun.app_id == self._app_id)
-
-            db_model = session.scalar(stmt)
-            if db_model:
-                # Add DB model to cache
-                self._execution_cache[execution_id] = db_model
-
-                # Convert to domain model and return
-                return self._to_domain_model(db_model)
-
-            return None

@@ -7,7 +7,7 @@ import logging
from collections.abc import Sequence
from typing import Optional, Union

-from sqlalchemy import UnaryExpression, asc, delete, desc, select
+from sqlalchemy import UnaryExpression, asc, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

@@ -218,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
        logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
        self._node_execution_cache[db_model.node_execution_id] = db_model
-
-    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
-        """
-        Retrieve a NodeExecution by its node_execution_id.
-
-        First checks the in-memory cache, and if not found, queries the database.
-        If found in the database, adds it to the cache for future lookups.
-
-        Args:
-            node_execution_id: The node execution ID
-
-        Returns:
-            The NodeExecution instance if found, None otherwise
-        """
-        # First check the cache
-        if node_execution_id in self._node_execution_cache:
-            logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
-            # Convert cached DB model to domain model
-            cached_db_model = self._node_execution_cache[node_execution_id]
-            return self._to_domain_model(cached_db_model)
-
-        # If not in cache, query the database
-        logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
-        with self._session_factory() as session:
-            stmt = select(WorkflowNodeExecutionModel).where(
-                WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
-                WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
-            )
-
-            if self._app_id:
-                stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
-
-            db_model = session.scalar(stmt)
-            if db_model:
-                # Add DB model to cache
-                self._node_execution_cache[node_execution_id] = db_model
-
-                # Convert to domain model and return
-                return self._to_domain_model(db_model)
-
-            return None

    def get_db_models_by_workflow_run(
        self,
        workflow_run_id: str,

@ -346,68 +305,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
domain_models.append(domain_model)
|
||||
|
||||
return domain_models
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running NodeExecution instances for a specific workflow run.
|
||||
|
||||
This method queries the database directly and updates the cache with any
|
||||
retrieved executions that have a node_execution_id.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
A list of running NodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
db_models = session.scalars(stmt).all()
|
||||
domain_models = []
|
||||
|
||||
for model in db_models:
|
||||
# Update cache if node_execution_id is present
|
||||
if model.node_execution_id:
|
||||
self._node_execution_cache[model.node_execution_id] = model
|
||||
|
||||
# Convert to domain model
|
||||
domain_model = self._to_domain_model(model)
|
||||
domain_models.append(domain_model)
|
||||
|
||||
return domain_models
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
|
||||
|
||||
This method deletes all WorkflowNodeExecution records that match the tenant_id
|
||||
and app_id (if provided) associated with this repository instance.
|
||||
It also clears the in-memory cache.
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
|
||||
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info(
|
||||
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
|
||||
+ (f" and app {self._app_id}" if self._app_id else "")
|
||||
)
|
||||
|
||||
# Clear the in-memory cache
|
||||
self._node_execution_cache.clear()
|
||||
logger.info("Cleared in-memory node execution cache")
|
||||
|
|
|
|||
|
|
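An illustrative standalone sketch of the cache-aside pattern the repository methods above follow (names simplified; `fetch_from_db` is a hypothetical stand-in for the SQLAlchemy query):

    # Minimal cache-aside sketch, assuming a plain dict cache and a hypothetical DB fetch.
    _cache: dict[str, object] = {}

    def get(node_execution_id: str, fetch_from_db) -> object | None:
        if node_execution_id in _cache:            # 1. serve hits from memory
            return _cache[node_execution_id]
        record = fetch_from_db(node_execution_id)  # 2. fall back to the database
        if record is not None:
            _cache[node_execution_id] = record     # 3. populate the cache for next time
        return record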
@ -2,7 +2,7 @@ from core.workflow.nodes.base import BaseNode


class WorkflowNodeRunFailedError(Exception):
    def __init__(self, node_instance: BaseNode, error: str):
        self.node_instance = node_instance
        self.error = error
        super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")
    def __init__(self, node: BaseNode, err_msg: str):
        self._node = node
        self._error = err_msg
        super().__init__(f"Node {node.title} run failed: {err_msg}")
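A quick usage sketch of the renamed constructor arguments (the node object is faked with a stub, since constructing a real BaseNode needs graph wiring):

    class _StubNode:  # hypothetical stand-in exposing only the title property
        title = "HTTP Request"

    err = WorkflowNodeRunFailedError(node=_StubNode(), err_msg="connection timed out")
    print(str(err))  # Node HTTP Request run failed: connection timed out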
@ -1,3 +1,4 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
from .graph_engine import GraphEngine

__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
@ -12,7 +12,7 @@ from typing import Any, Optional, cast
from flask import Flask, current_app

from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue

@ -48,11 +48,9 @@ from core.workflow.nodes.agent.entities import AgentNodeData
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom

@ -260,12 +258,16 @@ class GraphEngine:
            # convert to specific node
            node_type = NodeType(node_config.get("data", {}).get("type"))
            node_version = node_config.get("data", {}).get("version", "1")

            # Import here to avoid circular import
            from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING

            node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]

            previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None

            # init workflow run state
            node_instance = node_cls( # type: ignore
            node = node_cls(
                id=route_node_state.id,
                config=node_config,
                graph_init_params=self.init_params,

@ -274,11 +276,11 @@ class GraphEngine:
                previous_node_id=previous_node_id,
                thread_pool_id=self.thread_pool_id,
            )
            node_instance = cast(BaseNode[BaseNodeData], node_instance)
            node.init_node_data(node_config.get("data", {}))
            try:
                # run node
                generator = self._run_node(
                    node_instance=node_instance,
                    node=node,
                    route_node_state=route_node_state,
                    parallel_id=in_parallel_id,
                    parallel_start_node_id=parallel_start_node_id,

@ -306,16 +308,16 @@ class GraphEngine:
                route_node_state.failed_reason = str(e)
                yield NodeRunFailedEvent(
                    error=str(e),
                    id=node_instance.id,
                    id=node.id,
                    node_id=next_node_id,
                    node_type=node_type,
                    node_data=node_instance.node_data,
                    node_data=node.get_base_node_data(),
                    route_node_state=route_node_state,
                    parallel_id=in_parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                    node_version=node_instance.version(),
                    node_version=node.version(),
                )
                raise e

@ -337,7 +339,7 @@ class GraphEngine:
                edge = edge_mappings[0]
                if (
                    previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
                    and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
                    and node.error_strategy == ErrorStrategy.FAIL_BRANCH
                    and edge.run_condition is None
                ):
                    break

@ -413,8 +415,8 @@ class GraphEngine:

                    next_node_id = final_node_id
                elif (
                    node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
                    and node_instance.should_continue_on_error
                    node.continue_on_error
                    and node.error_strategy == ErrorStrategy.FAIL_BRANCH
                    and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
                ):
                    break

@ -597,7 +599,7 @@ class GraphEngine:

    def _run_node(
        self,
        node_instance: BaseNode[BaseNodeData],
        node: BaseNode,
        route_node_state: RouteNodeState,
        parallel_id: Optional[str] = None,
        parallel_start_node_id: Optional[str] = None,

@ -611,29 +613,29 @@ class GraphEngine:
        # trigger node run start event
        agent_strategy = (
            AgentNodeStrategyInit(
                name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name,
                icon=cast(AgentNode, node_instance).agent_strategy_icon,
                name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name,
                icon=cast(AgentNode, node).agent_strategy_icon,
            )
            if node_instance.node_type == NodeType.AGENT
            if node.type_ == NodeType.AGENT
            else None
        )
        yield NodeRunStartedEvent(
            id=node_instance.id,
            node_id=node_instance.node_id,
            node_type=node_instance.node_type,
            node_data=node_instance.node_data,
            id=node.id,
            node_id=node.node_id,
            node_type=node.type_,
            node_data=node.get_base_node_data(),
            route_node_state=route_node_state,
            predecessor_node_id=node_instance.previous_node_id,
            predecessor_node_id=node.previous_node_id,
            parallel_id=parallel_id,
            parallel_start_node_id=parallel_start_node_id,
            parent_parallel_id=parent_parallel_id,
            parent_parallel_start_node_id=parent_parallel_start_node_id,
            agent_strategy=agent_strategy,
            node_version=node_instance.version(),
            node_version=node.version(),
        )

        max_retries = node_instance.node_data.retry_config.max_retries
        retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
        max_retries = node.retry_config.max_retries
        retry_interval = node.retry_config.retry_interval_seconds
        retries = 0
        should_continue_retry = True
        while should_continue_retry and retries <= max_retries:

@ -642,7 +644,7 @@ class GraphEngine:
                retry_start_at = datetime.now(UTC).replace(tzinfo=None)
                # yield control to other threads
                time.sleep(0.001)
                event_stream = node_instance.run()
                event_stream = node.run()
                for event in event_stream:
                    if isinstance(event, GraphEngineEvent):
                        # add parallel info to iteration event

@ -658,21 +660,21 @@ class GraphEngine:
                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                                if (
                                    retries == max_retries
                                    and node_instance.node_type == NodeType.HTTP_REQUEST
                                    and node.type_ == NodeType.HTTP_REQUEST
                                    and run_result.outputs
                                    and not node_instance.should_continue_on_error
                                    and not node.continue_on_error
                                ):
                                    run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
                                if node_instance.should_retry and retries < max_retries:
                                if node.retry and retries < max_retries:
                                    retries += 1
                                    route_node_state.node_run_result = run_result
                                    yield NodeRunRetryEvent(
                                        id=str(uuid.uuid4()),
                                        node_id=node_instance.node_id,
                                        node_type=node_instance.node_type,
                                        node_data=node_instance.node_data,
                                        node_id=node.node_id,
                                        node_type=node.type_,
                                        node_data=node.get_base_node_data(),
                                        route_node_state=route_node_state,
                                        predecessor_node_id=node_instance.previous_node_id,
                                        predecessor_node_id=node.previous_node_id,
                                        parallel_id=parallel_id,
                                        parallel_start_node_id=parallel_start_node_id,
                                        parent_parallel_id=parent_parallel_id,

@ -680,17 +682,17 @@ class GraphEngine:
                                        error=run_result.error or "Unknown error",
                                        retry_index=retries,
                                        start_at=retry_start_at,
                                        node_version=node_instance.version(),
                                        node_version=node.version(),
                                    )
                                    time.sleep(retry_interval)
                                    break
                            route_node_state.set_finished(run_result=run_result)

                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                                if node_instance.should_continue_on_error:
                                if node.continue_on_error:
                                    # if run failed, handle error
                                    run_result = self._handle_continue_on_error(
                                        node_instance,
                                        node,
                                        event.run_result,
                                        self.graph_runtime_state.variable_pool,
                                        handle_exceptions=handle_exceptions,

@ -701,44 +703,44 @@ class GraphEngine:
                                    for variable_key, variable_value in run_result.outputs.items():
                                        # append variables to variable pool recursively
                                        self._append_variables_recursively(
                                            node_id=node_instance.node_id,
                                            node_id=node.node_id,
                                            variable_key_list=[variable_key],
                                            variable_value=variable_value,
                                        )
                                    yield NodeRunExceptionEvent(
                                        error=run_result.error or "System Error",
                                        id=node_instance.id,
                                        node_id=node_instance.node_id,
                                        node_type=node_instance.node_type,
                                        node_data=node_instance.node_data,
                                        id=node.id,
                                        node_id=node.node_id,
                                        node_type=node.type_,
                                        node_data=node.get_base_node_data(),
                                        route_node_state=route_node_state,
                                        parallel_id=parallel_id,
                                        parallel_start_node_id=parallel_start_node_id,
                                        parent_parallel_id=parent_parallel_id,
                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
                                        node_version=node_instance.version(),
                                        node_version=node.version(),
                                    )
                                    should_continue_retry = False
                                else:
                                    yield NodeRunFailedEvent(
                                        error=route_node_state.failed_reason or "Unknown error.",
                                        id=node_instance.id,
                                        node_id=node_instance.node_id,
                                        node_type=node_instance.node_type,
                                        node_data=node_instance.node_data,
                                        id=node.id,
                                        node_id=node.node_id,
                                        node_type=node.type_,
                                        node_data=node.get_base_node_data(),
                                        route_node_state=route_node_state,
                                        parallel_id=parallel_id,
                                        parallel_start_node_id=parallel_start_node_id,
                                        parent_parallel_id=parent_parallel_id,
                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
                                        node_version=node_instance.version(),
                                        node_version=node.version(),
                                    )
                                    should_continue_retry = False
                            elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                                if (
                                    node_instance.should_continue_on_error
                                    and self.graph.edge_mapping.get(node_instance.node_id)
                                    and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
                                    node.continue_on_error
                                    and self.graph.edge_mapping.get(node.node_id)
                                    and node.error_strategy is ErrorStrategy.FAIL_BRANCH
                                ):
                                    run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
                                if run_result.metadata and run_result.metadata.get(

@ -758,7 +760,7 @@ class GraphEngine:
                                    for variable_key, variable_value in run_result.outputs.items():
                                        # append variables to variable pool recursively
                                        self._append_variables_recursively(
                                            node_id=node_instance.node_id,
                                            node_id=node.node_id,
                                            variable_key_list=[variable_key],
                                            variable_value=variable_value,
                                        )

@ -783,26 +785,26 @@ class GraphEngine:
                                    run_result.metadata = metadata_dict

                                yield NodeRunSucceededEvent(
                                    id=node_instance.id,
                                    node_id=node_instance.node_id,
                                    node_type=node_instance.node_type,
                                    node_data=node_instance.node_data,
                                    id=node.id,
                                    node_id=node.node_id,
                                    node_type=node.type_,
                                    node_data=node.get_base_node_data(),
                                    route_node_state=route_node_state,
                                    parallel_id=parallel_id,
                                    parallel_start_node_id=parallel_start_node_id,
                                    parent_parallel_id=parent_parallel_id,
                                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                                    node_version=node_instance.version(),
                                    node_version=node.version(),
                                )
                                should_continue_retry = False

                            break
                        elif isinstance(event, RunStreamChunkEvent):
                            yield NodeRunStreamChunkEvent(
                                id=node_instance.id,
                                node_id=node_instance.node_id,
                                node_type=node_instance.node_type,
                                node_data=node_instance.node_data,
                                id=node.id,
                                node_id=node.node_id,
                                node_type=node.type_,
                                node_data=node.get_base_node_data(),
                                chunk_content=event.chunk_content,
                                from_variable_selector=event.from_variable_selector,
                                route_node_state=route_node_state,

@ -810,14 +812,14 @@ class GraphEngine:
                                parallel_start_node_id=parallel_start_node_id,
                                parent_parallel_id=parent_parallel_id,
                                parent_parallel_start_node_id=parent_parallel_start_node_id,
                                node_version=node_instance.version(),
                                node_version=node.version(),
                            )
                        elif isinstance(event, RunRetrieverResourceEvent):
                            yield NodeRunRetrieverResourceEvent(
                                id=node_instance.id,
                                node_id=node_instance.node_id,
                                node_type=node_instance.node_type,
                                node_data=node_instance.node_data,
                                id=node.id,
                                node_id=node.node_id,
                                node_type=node.type_,
                                node_data=node.get_base_node_data(),
                                retriever_resources=event.retriever_resources,
                                context=event.context,
                                route_node_state=route_node_state,

@ -825,7 +827,7 @@ class GraphEngine:
                                parallel_start_node_id=parallel_start_node_id,
                                parent_parallel_id=parent_parallel_id,
                                parent_parallel_start_node_id=parent_parallel_start_node_id,
                                node_version=node_instance.version(),
                                node_version=node.version(),
                            )
            except GenerateTaskStoppedError:
                # trigger node run failed event

@ -833,20 +835,20 @@ class GraphEngine:
                route_node_state.failed_reason = "Workflow stopped."
                yield NodeRunFailedEvent(
                    error="Workflow stopped.",
                    id=node_instance.id,
                    node_id=node_instance.node_id,
                    node_type=node_instance.node_type,
                    node_data=node_instance.node_data,
                    id=node.id,
                    node_id=node.node_id,
                    node_type=node.type_,
                    node_data=node.get_base_node_data(),
                    route_node_state=route_node_state,
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                    node_version=node_instance.version(),
                    node_version=node.version(),
                )
                return
            except Exception as e:
                logger.exception(f"Node {node_instance.node_data.title} run failed")
                logger.exception(f"Node {node.title} run failed")
                raise e

    def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):

@ -886,22 +888,14 @@ class GraphEngine:

    def _handle_continue_on_error(
        self,
        node_instance: BaseNode[BaseNodeData],
        node: BaseNode,
        error_result: NodeRunResult,
        variable_pool: VariablePool,
        handle_exceptions: list[str] = [],
    ) -> NodeRunResult:
        """
        handle continue on error when self._should_continue_on_error is True

        :param error_result (NodeRunResult): error run result
        :param variable_pool (VariablePool): variable pool
        :return: exception run result
        """
        # add error message and error type to variable pool
        variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
        variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
        variable_pool.add([node.node_id, "error_message"], error_result.error)
        variable_pool.add([node.node_id, "error_type"], error_result.error_type)
        # add error message to handle_exceptions
        handle_exceptions.append(error_result.error or "")
        node_error_args: dict[str, Any] = {

@ -909,21 +903,21 @@ class GraphEngine:
            "error": error_result.error,
            "inputs": error_result.inputs,
            "metadata": {
                WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
                WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy,
            },
        }

        if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
        if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
            return NodeRunResult(
                **node_error_args,
                outputs={
                    **node_instance.node_data.default_value_dict,
                    **node.default_value_dict,
                    "error_message": error_result.error,
                    "error_type": error_result.error_type,
                },
            )
        elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
            if self.graph.edge_mapping.get(node_instance.node_id):
        elif node.error_strategy is ErrorStrategy.FAIL_BRANCH:
            if self.graph.edge_mapping.get(node.node_id):
                node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
            return NodeRunResult(
                **node_error_args,
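A standalone sketch of the error-strategy dispatch that `_handle_continue_on_error` implements above (simplified result shape; the real engine builds a NodeRunResult and reuses the real ErrorStrategy enum):

    from enum import Enum

    class ErrorStrategy(Enum):  # simplified stand-in for the real enum
        DEFAULT_VALUE = "default-value"
        FAIL_BRANCH = "fail-branch"

    def handle_error(strategy: ErrorStrategy, error: str, defaults: dict) -> dict:
        # DEFAULT_VALUE: succeed with the configured fallback outputs plus error details.
        if strategy is ErrorStrategy.DEFAULT_VALUE:
            return {"outputs": {**defaults, "error_message": error}, "edge": None}
        # FAIL_BRANCH: keep the workflow going, but route execution down the failure edge.
        return {"outputs": {"error_message": error}, "edge": "fail-branch"}

    print(handle_error(ErrorStrategy.FAIL_BRANCH, "boom", {}))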
@ -1,5 +1,4 @@
import json
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast

@ -11,8 +10,10 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.agent.strategy.plugin import PluginAgentStrategy
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError

@ -25,45 +26,75 @@ from core.tools.entities.tool_entities import (
    ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService

from .exc import (
    AgentInputTypeError,
    AgentInvocationError,
    AgentMessageTransformError,
    AgentVariableNotFoundError,
    AgentVariableTypeError,
    ToolFileNotFoundError,
)


class AgentNode(ToolNode):
class AgentNode(BaseNode):
    """
    Agent Node
    """

    _node_data_cls = AgentNodeData # type: ignore
    _node_type = NodeType.AGENT
    _node_data: AgentNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = AgentNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

    def _run(self) -> Generator:
        """
        Run the agent node
        """
        node_data = cast(AgentNodeData, self.node_data)

        try:
            strategy = get_plugin_agent_strategy(
                tenant_id=self.tenant_id,
                agent_strategy_provider_name=node_data.agent_strategy_provider_name,
                agent_strategy_name=node_data.agent_strategy_name,
                agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
                agent_strategy_name=self._node_data.agent_strategy_name,
            )
        except Exception as e:
            yield RunCompletedEvent(

@ -81,13 +112,13 @@ class AgentNode(ToolNode):
        parameters = self._generate_agent_parameters(
            agent_parameters=agent_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=node_data,
            node_data=self._node_data,
            strategy=strategy,
        )
        parameters_for_log = self._generate_agent_parameters(
            agent_parameters=agent_parameters,
            variable_pool=self.graph_runtime_state.variable_pool,
            node_data=node_data,
            node_data=self._node_data,
            for_log=True,
            strategy=strategy,
        )

@ -105,59 +136,39 @@ class AgentNode(ToolNode):
                credentials=credentials,
            )
        except Exception as e:
            error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    error=f"Failed to invoke agent: {str(e)}",
                    error=str(error),
                )
            )
            return

        try:
            # convert tool messages
            agent_thoughts: list = []

            thought_log_message = ToolInvokeMessage(
                type=ToolInvokeMessage.MessageType.LOG,
                message=ToolInvokeMessage.LogMessage(
                    id=str(uuid.uuid4()),
                    label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
                    parent_id=None,
                    error=None,
                    status=ToolInvokeMessage.LogMessage.LogStatus.START,
                    data={
                        "strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
                        "parameters": parameters_for_log,
                        "thought_process": "Agent strategy execution started",
                    },
                    metadata={
                        "icon": self.agent_strategy_icon,
                        "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
                    },
                ),
            )

            def enhanced_message_stream():
                yield thought_log_message

                yield from message_stream

            yield from self._transform_message(
                message_stream,
                {
                messages=message_stream,
                tool_info={
                    "icon": self.agent_strategy_icon,
                    "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
                    "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
                },
                parameters_for_log,
                agent_thoughts,
                parameters_for_log=parameters_for_log,
                user_id=self.user_id,
                tenant_id=self.tenant_id,
                node_type=self.type_,
                node_id=self.node_id,
                node_execution_id=self.id,
            )
        except PluginDaemonClientSideError as e:
            transform_error = AgentMessageTransformError(
                f"Failed to transform agent message: {str(e)}", original_error=e
            )
            yield RunCompletedEvent(
                run_result=NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=parameters_for_log,
                    error=f"Failed to transform agent message: {str(e)}",
                    error=str(transform_error),
                )
            )

@ -194,7 +205,7 @@ class AgentNode(ToolNode):
            if agent_input.type == "variable":
                variable = variable_pool.get(agent_input.value) # type: ignore
                if variable is None:
                    raise ValueError(f"Variable {agent_input.value} does not exist")
                    raise AgentVariableNotFoundError(str(agent_input.value))
                parameter_value = variable.value
            elif agent_input.type in {"mixed", "constant"}:
                # variable_pool.convert_template expects a string template,

@ -216,7 +227,7 @@ class AgentNode(ToolNode):
                except json.JSONDecodeError:
                    parameter_value = parameter_value
            else:
                raise ValueError(f"Unknown agent input type '{agent_input.type}'")
                raise AgentInputTypeError(agent_input.type)
            value = parameter_value
            if parameter.type == "array[tools]":
                value = cast(list[dict[str, Any]], value)

@ -259,7 +270,7 @@ class AgentNode(ToolNode):
                )

                extra = tool.get("extra", {})
                runtime_variable_pool = variable_pool if self.node_data.version != "1" else None
                runtime_variable_pool = variable_pool if self._node_data.version != "1" else None
                tool_runtime = ToolManager.get_agent_tool_runtime(
                    self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
                )

@ -343,19 +354,14 @@ class AgentNode(ToolNode):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: BaseNodeData,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return:
        """
        node_data = cast(AgentNodeData, node_data)
        # Create typed NodeData from dict
        typed_node_data = AgentNodeData.model_validate(node_data)

        result: dict[str, Any] = {}
        for parameter_name in node_data.agent_parameters:
            input = node_data.agent_parameters[parameter_name]
        for parameter_name in typed_node_data.agent_parameters:
            input = typed_node_data.agent_parameters[parameter_name]
            if input.type in ["mixed", "constant"]:
                selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
                for selector in selectors:

@ -380,7 +386,7 @@ class AgentNode(ToolNode):
                plugin
                for plugin in plugins
                if f"{plugin.plugin_id}/{plugin.name}"
                == cast(AgentNodeData, self.node_data).agent_strategy_provider_name
                == cast(AgentNodeData, self._node_data).agent_strategy_provider_name
            )
            icon = current_plugin.declaration.icon
        except StopIteration:

@ -448,3 +454,236 @@ class AgentNode(ToolNode):
            return tools
        else:
            return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]

    def _transform_message(
        self,
        messages: Generator[ToolInvokeMessage, None, None],
        tool_info: Mapping[str, Any],
        parameters_for_log: dict[str, Any],
        user_id: str,
        tenant_id: str,
        node_type: NodeType,
        node_id: str,
        node_execution_id: str,
    ) -> Generator:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        """
        # transform message and handle file storage
        message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
            messages=messages,
            user_id=user_id,
            tenant_id=tenant_id,
            conversation_id=None,
        )

        text = ""
        files: list[File] = []
        json: list[dict] = []

        agent_logs: list[AgentLogEvent] = []
        agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
        llm_usage: LLMUsage | None = None
        variables: dict[str, Any] = {}

        for message in message_stream:
            if message.type in {
                ToolInvokeMessage.MessageType.IMAGE_LINK,
                ToolInvokeMessage.MessageType.BINARY_LINK,
                ToolInvokeMessage.MessageType.IMAGE,
            }:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)

                url = message.message.text
                if message.meta:
                    transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
                else:
                    transfer_method = FileTransferMethod.TOOL_FILE

                tool_file_id = str(url).split("/")[-1].split(".")[0]

                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
                    if tool_file is None:
                        raise ToolFileNotFoundError(tool_file_id)

                mapping = {
                    "tool_file_id": tool_file_id,
                    "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
                    "transfer_method": transfer_method,
                    "url": url,
                }
                file = file_factory.build_from_mapping(
                    mapping=mapping,
                    tenant_id=tenant_id,
                )
                files.append(file)
            elif message.type == ToolInvokeMessage.MessageType.BLOB:
                # get tool file id
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                assert message.meta

                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                    tool_file = session.scalar(stmt)
                    if tool_file is None:
                        raise ToolFileNotFoundError(tool_file_id)

                mapping = {
                    "tool_file_id": tool_file_id,
                    "transfer_method": FileTransferMethod.TOOL_FILE,
                }

                files.append(
                    file_factory.build_from_mapping(
                        mapping=mapping,
                        tenant_id=tenant_id,
                    )
                )
            elif message.type == ToolInvokeMessage.MessageType.TEXT:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                text += message.message.text
                yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
            elif message.type == ToolInvokeMessage.MessageType.JSON:
                assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
                if node_type == NodeType.AGENT:
                    msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
                    llm_usage = LLMUsage.from_metadata(msg_metadata)
                    agent_execution_metadata = {
                        WorkflowNodeExecutionMetadataKey(key): value
                        for key, value in msg_metadata.items()
                        if key in WorkflowNodeExecutionMetadataKey.__members__.values()
                    }
                if message.message.json_object is not None:
                    json.append(message.message.json_object)
            elif message.type == ToolInvokeMessage.MessageType.LINK:
                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                stream_text = f"Link: {message.message.text}\n"
                text += stream_text
                yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
                variable_name = message.message.variable_name
                variable_value = message.message.variable_value
                if message.message.stream:
                    if not isinstance(variable_value, str):
                        raise AgentVariableTypeError(
                            "When 'stream' is True, 'variable_value' must be a string.",
                            variable_name=variable_name,
                            expected_type="str",
                            actual_type=type(variable_value).__name__,
                        )
                    if variable_name not in variables:
                        variables[variable_name] = ""
                    variables[variable_name] += variable_value

                    yield RunStreamChunkEvent(
                        chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
                    )
                else:
                    variables[variable_name] = variable_value
            elif message.type == ToolInvokeMessage.MessageType.FILE:
                assert message.meta is not None
                assert isinstance(message.meta, File)
                files.append(message.meta["file"])
            elif message.type == ToolInvokeMessage.MessageType.LOG:
                assert isinstance(message.message, ToolInvokeMessage.LogMessage)
                if message.message.metadata:
                    icon = tool_info.get("icon", "")
                    dict_metadata = dict(message.message.metadata)
                    if dict_metadata.get("provider"):
                        manager = PluginInstaller()
                        plugins = manager.list_plugins(tenant_id)
                        try:
                            current_plugin = next(
                                plugin
                                for plugin in plugins
                                if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
                            )
                            icon = current_plugin.declaration.icon
                        except StopIteration:
                            pass
                        icon_dark = None
                        try:
                            builtin_tool = next(
                                provider
                                for provider in BuiltinToolManageService.list_builtin_tools(
                                    user_id,
                                    tenant_id,
                                )
                                if provider.name == dict_metadata["provider"]
                            )
                            icon = builtin_tool.icon
                            icon_dark = builtin_tool.icon_dark
                        except StopIteration:
                            pass

                        dict_metadata["icon"] = icon
                        dict_metadata["icon_dark"] = icon_dark
                        message.message.metadata = dict_metadata
                agent_log = AgentLogEvent(
                    id=message.message.id,
                    node_execution_id=node_execution_id,
                    parent_id=message.message.parent_id,
                    error=message.message.error,
                    status=message.message.status.value,
                    data=message.message.data,
                    label=message.message.label,
                    metadata=message.message.metadata,
                    node_id=node_id,
                )

                # check if the agent log is already in the list
                for log in agent_logs:
                    if log.id == agent_log.id:
                        # update the log
                        log.data = agent_log.data
                        log.status = agent_log.status
                        log.error = agent_log.error
                        log.label = agent_log.label
                        log.metadata = agent_log.metadata
                        break
                else:
                    agent_logs.append(agent_log)

                yield agent_log

        # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
        json_output: list[dict[str, Any]] = []

        # Step 1: append each agent log as its own dict.
        if agent_logs:
            for log in agent_logs:
                json_output.append(
                    {
                        "id": log.id,
                        "parent_id": log.parent_id,
                        "error": log.error,
                        "status": log.status,
                        "data": log.data,
                        "label": log.label,
                        "metadata": log.metadata,
                        "node_id": log.node_id,
                    }
                )
        # Step 2: append the collected JSON objects, or a {"data": []} placeholder when none exist.
        if json:
            json_output.extend(json)
        else:
            json_output.append({"data": []})

        yield RunCompletedEvent(
            run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
                metadata={
                    **agent_execution_metadata,
                    WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
                    WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
                },
                inputs=parameters_for_log,
                llm_usage=llm_usage,
            )
        )
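One detail worth calling out from `_transform_message` above: the log de-duplication relies on Python's `for ... else` idiom, where the `else` branch runs only if the loop never hit `break`. A self-contained sketch of that update-or-append pattern:

    logs: list[dict] = [{"id": "a", "status": "start"}]
    incoming = {"id": "a", "status": "success"}

    for log in logs:
        if log["id"] == incoming["id"]:
            log.update(incoming)   # update the existing entry in place
            break
    else:
        logs.append(incoming)      # only reached when no matching id was found

    print(logs)  # [{'id': 'a', 'status': 'success'}]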
@ -0,0 +1,124 @@
from typing import Optional


class AgentNodeError(Exception):
    """Base exception for all agent node errors."""

    def __init__(self, message: str):
        self.message = message
        super().__init__(self.message)


class AgentStrategyError(AgentNodeError):
    """Exception raised when there's an error with the agent strategy."""

    def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
        self.strategy_name = strategy_name
        self.provider_name = provider_name
        super().__init__(message)


class AgentStrategyNotFoundError(AgentStrategyError):
    """Exception raised when the specified agent strategy is not found."""

    def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
        super().__init__(
            f"Agent strategy '{strategy_name}' not found"
            + (f" for provider '{provider_name}'" if provider_name else ""),
            strategy_name,
            provider_name,
        )


class AgentInvocationError(AgentNodeError):
    """Exception raised when there's an error invoking the agent."""

    def __init__(self, message: str, original_error: Optional[Exception] = None):
        self.original_error = original_error
        super().__init__(message)


class AgentParameterError(AgentNodeError):
    """Exception raised when there's an error with agent parameters."""

    def __init__(self, message: str, parameter_name: Optional[str] = None):
        self.parameter_name = parameter_name
        super().__init__(message)


class AgentVariableError(AgentNodeError):
    """Exception raised when there's an error with variables in the agent node."""

    def __init__(self, message: str, variable_name: Optional[str] = None):
        self.variable_name = variable_name
        super().__init__(message)


class AgentVariableNotFoundError(AgentVariableError):
    """Exception raised when a variable is not found in the variable pool."""

    def __init__(self, variable_name: str):
        super().__init__(f"Variable '{variable_name}' does not exist", variable_name)


class AgentInputTypeError(AgentNodeError):
    """Exception raised when an unknown agent input type is encountered."""

    def __init__(self, input_type: str):
        super().__init__(f"Unknown agent input type '{input_type}'")


class ToolFileError(AgentNodeError):
    """Exception raised when there's an error with a tool file."""

    def __init__(self, message: str, file_id: Optional[str] = None):
        self.file_id = file_id
        super().__init__(message)


class ToolFileNotFoundError(ToolFileError):
    """Exception raised when a tool file is not found."""

    def __init__(self, file_id: str):
        super().__init__(f"Tool file '{file_id}' does not exist", file_id)


class AgentMessageTransformError(AgentNodeError):
    """Exception raised when there's an error transforming agent messages."""

    def __init__(self, message: str, original_error: Optional[Exception] = None):
        self.original_error = original_error
        super().__init__(message)


class AgentModelError(AgentNodeError):
    """Exception raised when there's an error with the model used by the agent."""

    def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
        self.model_name = model_name
        self.provider = provider
        super().__init__(message)


class AgentMemoryError(AgentNodeError):
    """Exception raised when there's an error with the agent's memory."""

    def __init__(self, message: str, conversation_id: Optional[str] = None):
        self.conversation_id = conversation_id
        super().__init__(message)


class AgentVariableTypeError(AgentNodeError):
    """Exception raised when a variable has an unexpected type."""

    def __init__(
        self,
        message: str,
        variable_name: Optional[str] = None,
        expected_type: Optional[str] = None,
        actual_type: Optional[str] = None,
    ):
        self.variable_name = variable_name
        self.expected_type = expected_type
        self.actual_type = actual_type
        super().__init__(message)
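Because every class above derives from AgentNodeError, callers can catch narrowly first and fall back to the base type. A brief usage sketch (`run_agent_step` is a hypothetical stand-in for any agent-node operation):

    def run_agent_step():
        raise AgentVariableNotFoundError("query")  # simulate a failure

    try:
        run_agent_step()
    except AgentVariableNotFoundError as e:
        print(f"missing variable: {e.variable_name}")
    except AgentNodeError as e:
        # the base class catches every other agent-node failure
        print(f"agent node failed: {e.message}")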
@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast

from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult

@ -12,14 +12,37 @@ from core.workflow.nodes.answer.entities import (
    VarGenerateRouteChunk,
)
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser


class AnswerNode(BaseNode[AnswerNodeData]):
    _node_data_cls = AnswerNodeData
class AnswerNode(BaseNode):
    _node_type = NodeType.ANSWER

    _node_data: AnswerNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = AnswerNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@ -30,7 +53,7 @@ class AnswerNode(BaseNode[AnswerNodeData]):
        :return:
        """
        # generate routes
        generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
        generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)

        answer = ""
        files = []

@ -60,16 +83,12 @@ class AnswerNode(BaseNode[AnswerNodeData]):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: AnswerNodeData,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return:
        """
        variable_template_parser = VariableTemplateParser(template=node_data.answer)
        # Create typed NodeData from dict
        typed_node_data = AnswerNodeData.model_validate(node_data)

        variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
        variable_selectors = variable_template_parser.extract_variable_selectors()

        variable_mapping = {}
@ -122,13 +122,13 @@ class RetryConfig(BaseModel):
class BaseNodeData(ABC, BaseModel):
    title: str
    desc: Optional[str] = None
    version: str = "1"
    error_strategy: Optional[ErrorStrategy] = None
    default_value: Optional[list[DefaultValue]] = None
    version: str = "1"
    retry_config: RetryConfig = RetryConfig()

    @property
    def default_value_dict(self):
    def default_value_dict(self) -> dict[str, Any]:
        if self.default_value:
            return {item.key: item.value for item in self.default_value}
        return {}
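To make the `default_value_dict` contract concrete, a minimal standalone sketch (a simplified stand-in for the real pydantic DefaultValue model):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class DefaultValue:  # simplified stand-in
        key: str
        value: Any

    def default_value_dict(default_value: Optional[list[DefaultValue]]) -> dict[str, Any]:
        # mirrors the property above: key/value pairs become a dict, None becomes {}
        if default_value:
            return {item.key: item.value for item in default_value}
        return {}

    assert default_value_dict([DefaultValue("text", "fallback")]) == {"text": "fallback"}
    assert default_value_dict(None) == {}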
@ -1,28 +1,22 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent

from .entities import BaseNodeData

if TYPE_CHECKING:
    from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
    from core.workflow.graph_engine.entities.event import InNodeEvent
    from core.workflow.graph_engine.entities.graph import Graph
    from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
    from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

logger = logging.getLogger(__name__)

GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)


class BaseNode(Generic[GenericNodeData]):
    _node_data_cls: type[GenericNodeData]
class BaseNode:
    _node_type: ClassVar[NodeType]

    def __init__(

@ -56,8 +50,8 @@ class BaseNode(Generic[GenericNodeData]):

        self.node_id = node_id

        node_data = self._node_data_cls.model_validate(config.get("data", {}))
        self.node_data = node_data
    @abstractmethod
    def init_node_data(self, data: Mapping[str, Any]) -> None: ...

    @abstractmethod
    def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:

@ -130,9 +124,9 @@ class BaseNode(Generic[GenericNodeData]):
        if not node_id:
            raise ValueError("Node ID is required when extracting variable selector to variable mapping.")

        node_data = cls._node_data_cls(**config.get("data", {}))
        # Pass raw dict data instead of creating NodeData instance
        data = cls._extract_variable_selector_to_variable_mapping(
            graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
            graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
        )
        return data

@ -142,32 +136,16 @@ class BaseNode(Generic[GenericNodeData]):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: GenericNodeData,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return:
        """
        return {}

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters.
        :return:
        """
        return {}

    @property
    def node_type(self) -> NodeType:
        """
        Get node type
        :return:
        """
    def type_(self) -> NodeType:
        return self._node_type

    @classmethod

@ -181,19 +159,68 @@ class BaseNode(Generic[GenericNodeData]):
        raise NotImplementedError("subclasses of BaseNode must implement `version` method.")

    @property
    def should_continue_on_error(self) -> bool:
        """judge if should continue on error

        Returns:
            bool: if should continue on error
        """
        return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
    def continue_on_error(self) -> bool:
        return False

    @property
    def should_retry(self) -> bool:
        """judge if should retry
    def retry(self) -> bool:
        return False

        Returns:
            bool: if should retry
        """
        return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
    # Abstract methods that subclasses must implement to provide access
    # to BaseNodeData properties in a type-safe way

    @abstractmethod
    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        """Get the error strategy for this node."""
        ...

    @abstractmethod
    def _get_retry_config(self) -> RetryConfig:
        """Get the retry configuration for this node."""
        ...

    @abstractmethod
    def _get_title(self) -> str:
        """Get the node title."""
        ...

    @abstractmethod
    def _get_description(self) -> Optional[str]:
        """Get the node description."""
        ...

    @abstractmethod
    def _get_default_value_dict(self) -> dict[str, Any]:
        """Get the default values dictionary for this node."""
        ...

    @abstractmethod
    def get_base_node_data(self) -> BaseNodeData:
        """Get the BaseNodeData object for this node."""
        ...

    # Public interface properties that delegate to abstract methods
    @property
    def error_strategy(self) -> Optional[ErrorStrategy]:
        """Get the error strategy for this node."""
        return self._get_error_strategy()

    @property
    def retry_config(self) -> RetryConfig:
        """Get the retry configuration for this node."""
        return self._get_retry_config()

    @property
    def title(self) -> str:
        """Get the node title."""
        return self._get_title()

    @property
    def description(self) -> Optional[str]:
        """Get the node description."""
        return self._get_description()

    @property
    def default_value_dict(self) -> dict[str, Any]:
        """Get the default values dictionary for this node."""
        return self._get_default_value_dict()
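To see what the de-generified contract asks of subclasses, a hypothetical minimal node (constructor and graph wiring omitted; `EchoNodeData` and `EchoNode` are illustrative names, not from the diff, and the snippet assumes the imports of the module above):

    class EchoNodeData(BaseNodeData):  # hypothetical concrete data model
        pass

    class EchoNode(BaseNode):
        _node_type = NodeType.CODE  # illustrative; any registered NodeType

        _node_data: EchoNodeData

        def init_node_data(self, data: Mapping[str, Any]) -> None:
            self._node_data = EchoNodeData.model_validate(data)

        def _get_error_strategy(self) -> Optional[ErrorStrategy]:
            return self._node_data.error_strategy

        def _get_retry_config(self) -> RetryConfig:
            return self._node_data.retry_config

        def _get_title(self) -> str:
            return self._node_data.title

        def _get_description(self) -> Optional[str]:
            return self._node_data.desc

        def _get_default_value_dict(self) -> dict[str, Any]:
            return self._node_data.default_value_dict

        def get_base_node_data(self) -> BaseNodeData:
            return self._node_data

        @classmethod
        def version(cls) -> str:
            return "1"

        def _run(self) -> NodeRunResult:
            return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={})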
@ -11,8 +11,9 @@ from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType

from .exc import (
    CodeNodeError,

@ -21,10 +22,32 @@ from .exc import (
)


class CodeNode(BaseNode[CodeNodeData]):
    _node_data_cls = CodeNodeData
class CodeNode(BaseNode):
    _node_type = NodeType.CODE

    _node_data: CodeNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = CodeNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """

@ -47,12 +70,12 @@ class CodeNode(BaseNode[CodeNodeData]):

    def _run(self) -> NodeRunResult:
        # Get code language
        code_language = self.node_data.code_language
        code = self.node_data.code
        code_language = self._node_data.code_language
        code = self._node_data.code

        # Get variables
        variables = {}
        for variable_selector in self.node_data.variables:
        for variable_selector in self._node_data.variables:
            variable_name = variable_selector.variable
            variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
            if isinstance(variable, ArrayFileSegment):

@ -68,7 +91,7 @@ class CodeNode(BaseNode[CodeNodeData]):
            )

            # Transform result
            result = self._transform_result(result=result, output_schema=self.node_data.outputs)
            result = self._transform_result(result=result, output_schema=self._node_data.outputs)
        except (CodeExecutionError, CodeNodeError) as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__

@ -334,16 +357,20 @@ class CodeNode(BaseNode[CodeNodeData]):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: CodeNodeData,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return:
        """
        # Create typed NodeData from dict
        typed_node_data = CodeNodeData.model_validate(node_data)

        return {
            node_id + "." + variable_selector.variable: variable_selector.value_selector
            for variable_selector in node_data.variables
            for variable_selector in typed_node_data.variables
        }

    @property
    def continue_on_error(self) -> bool:
        return self._node_data.error_strategy is not None

    @property
    def retry(self) -> bool:
        return self._node_data.retry_config.retry_enabled
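For intuition, a worked sketch of the selector mapping `_extract_variable_selector_to_variable_mapping` builds above (hypothetical node data, plain dicts standing in for the pydantic model):

    node_id = "code_1"
    variables = [
        {"variable": "x", "value_selector": ["start_1", "query"]},
        {"variable": "y", "value_selector": ["llm_1", "text"]},
    ]
    mapping = {node_id + "." + v["variable"]: v["value_selector"] for v in variables}
    # -> {'code_1.x': ['start_1', 'query'], 'code_1.y': ['llm_1', 'text']}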
@@ -5,7 +5,7 @@ import logging
 import os
 import tempfile
 from collections.abc import Mapping, Sequence
-from typing import Any, cast
+from typing import Any, Optional, cast

 import chardet
 import docx

@@ -28,7 +28,8 @@ from core.variables.segments import ArrayStringSegment, FileSegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType

 from .entities import DocumentExtractorNodeData
 from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError

@@ -36,21 +37,43 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
 logger = logging.getLogger(__name__)


-class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
+class DocumentExtractorNode(BaseNode):
     """
     Extracts text content from various file types.
     Supports plain text, PDF, and DOC/DOCX files.
     """

-    _node_data_cls = DocumentExtractorNodeData
     _node_type = NodeType.DOCUMENT_EXTRACTOR

+    _node_data: DocumentExtractorNodeData
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = DocumentExtractorNodeData.model_validate(data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
+    @classmethod
+    def version(cls) -> str:
+        return "1"
+
     def _run(self):
-        variable_selector = self.node_data.variable_selector
+        variable_selector = self._node_data.variable_selector
         variable = self.graph_runtime_state.variable_pool.get(variable_selector)

         if variable is None:

@@ -97,16 +120,12 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: DocumentExtractorNodeData,
+        node_data: Mapping[str, Any],
     ) -> Mapping[str, Sequence[str]]:
-        """
-        Extract variable selector to variable mapping
-        :param graph_config: graph config
-        :param node_id: node id
-        :param node_data: node data
-        :return:
-        """
-        return {node_id + ".files": node_data.variable_selector}
+        # Create typed NodeData from dict
+        typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
+
+        return {node_id + ".files": typed_node_data.variable_selector}


 def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
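One detail worth calling out in the refactor above: `_node_data: DocumentExtractorNodeData` is a bare class-level annotation, so no attribute actually exists until `init_node_data` assigns one. A quick self-contained check of that Python behavior:

```python
class Demo:
    _node_data: dict  # annotation only -- no class attribute is created


d = Demo()
print(hasattr(d, "_node_data"))  # False until an init_node_data-style setter runs
print(Demo.__annotations__)      # {'_node_data': <class 'dict'>}
```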
@@ -1,14 +1,40 @@
 from collections.abc import Mapping
 from typing import Any, Optional

 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
 from core.workflow.nodes.end.entities import EndNodeData
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.enums import ErrorStrategy, NodeType


-class EndNode(BaseNode[EndNodeData]):
-    _node_data_cls = EndNodeData
+class EndNode(BaseNode):
     _node_type = NodeType.END

+    _node_data: EndNodeData
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = EndNodeData(**data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
+    @classmethod
+    def version(cls) -> str:
+        return "1"
+

@@ -18,7 +44,7 @@ class EndNode(BaseNode[EndNodeData]):
         Run node
         :return:
         """
-        output_variables = self.node_data.outputs
+        output_variables = self._node_data.outputs

         outputs = {}
         for variable_selector in output_variables:
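Small inconsistency worth noting rather than fixing here: `EndNode.init_node_data` builds its data with `EndNodeData(**data)`, while most other nodes in this PR use `model_validate(data)`. For a plain dict payload the two spellings agree in pydantic v2 -- both run full validation -- as a stand-in model shows:

```python
from pydantic import BaseModel


class DemoData(BaseModel):
    title: str = ""
    outputs: list[str] = []


payload = {"title": "end", "outputs": ["x"]}

a = DemoData(**payload)               # kwargs unpacking; fields are validated
b = DemoData.model_validate(payload)  # takes the mapping directly
print(a == b)  # True -- for plain dicts both spellings produce the same model
```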
@@ -37,7 +37,3 @@ class ErrorStrategy(StrEnum):
 class FailBranchSourceHandle(StrEnum):
     FAILED = "fail-branch"
     SUCCESS = "success-branch"
-
-
-CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
-RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE
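With these module-level allow-lists gone, continue-on-error and retry eligibility is decided by each node instance from its own configuration (see the `continue_on_error` / `retry` properties added to CodeNode above and HttpRequestNode below). A minimal sketch of the new shape, with stand-in types:

```python
from typing import Optional


class DemoNode:
    def __init__(self, error_strategy: Optional[str], retry_enabled: bool) -> None:
        self._error_strategy = error_strategy
        self._retry_enabled = retry_enabled

    @property
    def continue_on_error(self) -> bool:
        # Any configured error strategy opts the node in; no global list needed.
        return self._error_strategy is not None

    @property
    def retry(self) -> bool:
        return self._retry_enabled


print(DemoNode("fail-branch", retry_enabled=False).continue_on_error)  # True
print(DemoNode(None, retry_enabled=True).retry)                        # True
```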
@@ -11,7 +11,8 @@ from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_entities import VariableSelector
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.http_request.executor import Executor
 from core.workflow.utils import variable_template_parser
 from factories import file_factory

@@ -32,10 +33,32 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
 logger = logging.getLogger(__name__)


-class HttpRequestNode(BaseNode[HttpRequestNodeData]):
-    _node_data_cls = HttpRequestNodeData
+class HttpRequestNode(BaseNode):
     _node_type = NodeType.HTTP_REQUEST

+    _node_data: HttpRequestNodeData
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = HttpRequestNodeData.model_validate(data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
         return {

@@ -69,8 +92,8 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
         process_data = {}
         try:
             http_executor = Executor(
-                node_data=self.node_data,
-                timeout=self._get_request_timeout(self.node_data),
+                node_data=self._node_data,
+                timeout=self._get_request_timeout(self._node_data),
                 variable_pool=self.graph_runtime_state.variable_pool,
                 max_retries=0,
             )

@@ -78,7 +101,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):

             response = http_executor.invoke()
             files = self.extract_files(url=http_executor.url, response=response)
-            if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
+            if not response.response.is_success and (self.continue_on_error or self.retry):
                 return NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     outputs={

@@ -131,15 +154,18 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: HttpRequestNodeData,
+        node_data: Mapping[str, Any],
     ) -> Mapping[str, Sequence[str]]:
+        # Create typed NodeData from dict
+        typed_node_data = HttpRequestNodeData.model_validate(node_data)
+
         selectors: list[VariableSelector] = []
-        selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
-        selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
-        selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
-        if node_data.body:
-            body_type = node_data.body.type
-            data = node_data.body.data
+        selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
+        selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
+        selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
+        if typed_node_data.body:
+            body_type = typed_node_data.body.type
+            data = typed_node_data.body.data
             match body_type:
                 case "binary":
                     if len(data) != 1:

@@ -217,3 +243,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
             files.append(file)

         return ArrayFileSegment(value=files)
+
+    @property
+    def continue_on_error(self) -> bool:
+        return self._node_data.error_strategy is not None
+
+    @property
+    def retry(self) -> bool:
+        return self._node_data.retry_config.retry_enabled
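The mapping hunk above dispatches on the request body type with a `match` statement. A self-contained sketch of that dispatch (body-type names other than "binary" are illustrative, not taken from this diff):

```python
def classify_body(body_type: str, data: list[str]) -> str:
    match body_type:
        case "binary":
            # A binary body must reference exactly one file selector.
            if len(data) != 1:
                raise ValueError("binary body expects exactly one selector")
            return "file"
        case "json" | "raw-text":
            return "text"
        case _:
            return "other"


print(classify_body("binary", ["upload"]))  # file
```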
@@ -1,5 +1,5 @@
 from collections.abc import Mapping, Sequence
-from typing import Any, Literal
+from typing import Any, Literal, Optional

 from typing_extensions import deprecated

@@ -7,16 +7,39 @@ from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.if_else.entities import IfElseNodeData
 from core.workflow.utils.condition.entities import Condition
 from core.workflow.utils.condition.processor import ConditionProcessor


-class IfElseNode(BaseNode[IfElseNodeData]):
-    _node_data_cls = IfElseNodeData
+class IfElseNode(BaseNode):
     _node_type = NodeType.IF_ELSE

+    _node_data: IfElseNodeData
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = IfElseNodeData.model_validate(data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
+    @classmethod
+    def version(cls) -> str:
+        return "1"
+

@@ -36,8 +59,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
         condition_processor = ConditionProcessor()
         try:
             # Check if the new cases structure is used
-            if self.node_data.cases:
-                for case in self.node_data.cases:
+            if self._node_data.cases:
+                for case in self._node_data.cases:
                     input_conditions, group_result, final_result = condition_processor.process_conditions(
                         variable_pool=self.graph_runtime_state.variable_pool,
                         conditions=case.conditions,

@@ -63,8 +86,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
                 input_conditions, group_result, final_result = _should_not_use_old_function(
                     condition_processor=condition_processor,
                     variable_pool=self.graph_runtime_state.variable_pool,
-                    conditions=self.node_data.conditions or [],
-                    operator=self.node_data.logical_operator or "and",
+                    conditions=self._node_data.conditions or [],
+                    operator=self._node_data.logical_operator or "and",
                 )

             selected_case_id = "true" if final_result else "false"

@@ -98,10 +121,13 @@ class IfElseNode(BaseNode[IfElseNodeData]):
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: IfElseNodeData,
+        node_data: Mapping[str, Any],
     ) -> Mapping[str, Sequence[str]]:
+        # Create typed NodeData from dict
+        typed_node_data = IfElseNodeData.model_validate(node_data)
+
         var_mapping: dict[str, list[str]] = {}
-        for case in node_data.cases or []:
+        for case in typed_node_data.cases or []:
             for condition in case.conditions:
                 key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
                 var_mapping[key] = condition.variable_selector
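For context on the legacy branch touched above: when no `cases` are configured, the node falls back to a single condition group combined by one logical operator. The fallback semantics, reduced to a standalone sketch:

```python
from typing import Literal


def combine(results: list[bool], operator: Literal["and", "or"]) -> bool:
    # One operator applied across every condition result in the group.
    return all(results) if operator == "and" else any(results)


print(combine([True, False], "or"))   # True  -> selected_case_id would be "true"
print(combine([True, False], "and"))  # False -> selected_case_id would be "false"
```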
@@ -36,7 +36,8 @@ from core.workflow.graph_engine.entities.event import (
 )
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
 from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
 from factories.variable_factory import build_segment

@@ -56,14 +57,36 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


-class IterationNode(BaseNode[IterationNodeData]):
+class IterationNode(BaseNode):
     """
     Iteration Node.
     """

-    _node_data_cls = IterationNodeData
     _node_type = NodeType.ITERATION

+    _node_data: IterationNodeData
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = IterationNodeData.model_validate(data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         return {

@@ -83,10 +106,10 @@ class IterationNode(BaseNode[IterationNodeData]):
         """
         Run the node.
         """
-        variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
+        variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)

         if not variable:
-            raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
+            raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")

         if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
             raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")

@@ -116,10 +139,10 @@ class IterationNode(BaseNode[IterationNodeData]):

         graph_config = self.graph_config

-        if not self.node_data.start_node_id:
+        if not self._node_data.start_node_id:
             raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")

-        root_node_id = self.node_data.start_node_id
+        root_node_id = self._node_data.start_node_id

         # init graph
         iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)

@@ -161,8 +184,8 @@ class IterationNode(BaseNode[IterationNodeData]):
         yield IterationRunStartedEvent(
             iteration_id=self.id,
             iteration_node_id=self.node_id,
-            iteration_node_type=self.node_type,
-            iteration_node_data=self.node_data,
+            iteration_node_type=self.type_,
+            iteration_node_data=self._node_data,
             start_at=start_at,
             inputs=inputs,
             metadata={"iterator_length": len(iterator_list_value)},

@@ -172,8 +195,8 @@ class IterationNode(BaseNode[IterationNodeData]):
         yield IterationRunNextEvent(
             iteration_id=self.id,
             iteration_node_id=self.node_id,
-            iteration_node_type=self.node_type,
-            iteration_node_data=self.node_data,
+            iteration_node_type=self.type_,
+            iteration_node_data=self._node_data,
             index=0,
             pre_iteration_output=None,
             duration=None,

@@ -181,11 +204,11 @@ class IterationNode(BaseNode[IterationNodeData]):
         iter_run_map: dict[str, float] = {}
         outputs: list[Any] = [None] * len(iterator_list_value)
         try:
-            if self.node_data.is_parallel:
+            if self._node_data.is_parallel:
                 futures: list[Future] = []
                 q: Queue = Queue()
                 thread_pool = GraphEngineThreadPool(
-                    max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
+                    max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
                 )
                 for index, item in enumerate(iterator_list_value):
                     future: Future = thread_pool.submit(

@@ -242,7 +265,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                     iteration_graph=iteration_graph,
                     iter_run_map=iter_run_map,
                 )
-                if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
+                if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
                     outputs = [output for output in outputs if output is not None]

             # Flatten the list of lists

@@ -253,8 +276,8 @@ class IterationNode(BaseNode[IterationNodeData]):
             yield IterationRunSucceededEvent(
                 iteration_id=self.id,
                 iteration_node_id=self.node_id,
-                iteration_node_type=self.node_type,
-                iteration_node_data=self.node_data,
+                iteration_node_type=self.type_,
+                iteration_node_data=self._node_data,
                 start_at=start_at,
                 inputs=inputs,
                 outputs={"output": outputs},

@@ -278,8 +301,8 @@ class IterationNode(BaseNode[IterationNodeData]):
             yield IterationRunFailedEvent(
                 iteration_id=self.id,
                 iteration_node_id=self.node_id,
-                iteration_node_type=self.node_type,
-                iteration_node_data=self.node_data,
+                iteration_node_type=self.type_,
+                iteration_node_data=self._node_data,
                 start_at=start_at,
                 inputs=inputs,
                 outputs={"output": outputs},

@@ -305,21 +328,17 @@ class IterationNode(BaseNode[IterationNodeData]):
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: IterationNodeData,
+        node_data: Mapping[str, Any],
     ) -> Mapping[str, Sequence[str]]:
-        """
-        Extract variable selector to variable mapping
-        :param graph_config: graph config
-        :param node_id: node id
-        :param node_data: node data
-        :return:
-        """
+        # Create typed NodeData from dict
+        typed_node_data = IterationNodeData.model_validate(node_data)
+
         variable_mapping: dict[str, Sequence[str]] = {
-            f"{node_id}.input_selector": node_data.iterator_selector,
+            f"{node_id}.input_selector": typed_node_data.iterator_selector,
         }

         # init graph
-        iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
+        iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)

         if not iteration_graph:
             raise IterationGraphNotFoundError("iteration graph not found")

@@ -375,7 +394,7 @@ class IterationNode(BaseNode[IterationNodeData]):
         """
         if not isinstance(event, BaseNodeEvent):
             return event
-        if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
+        if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
            event.parallel_mode_run_id = parallel_mode_run_id

         iter_metadata = {

@@ -438,12 +457,12 @@ class IterationNode(BaseNode[IterationNodeData]):
         elif isinstance(event, BaseGraphEvent):
             if isinstance(event, GraphRunFailedEvent):
                 # iteration run failed
-                if self.node_data.is_parallel:
+                if self._node_data.is_parallel:
                     yield IterationRunFailedEvent(
                         iteration_id=self.id,
                         iteration_node_id=self.node_id,
-                        iteration_node_type=self.node_type,
-                        iteration_node_data=self.node_data,
+                        iteration_node_type=self.type_,
+                        iteration_node_data=self._node_data,
                         parallel_mode_run_id=parallel_mode_run_id,
                         start_at=start_at,
                         inputs=inputs,

@@ -456,8 +475,8 @@ class IterationNode(BaseNode[IterationNodeData]):
                     yield IterationRunFailedEvent(
                         iteration_id=self.id,
                         iteration_node_id=self.node_id,
-                        iteration_node_type=self.node_type,
-                        iteration_node_data=self.node_data,
+                        iteration_node_type=self.type_,
+                        iteration_node_data=self._node_data,
                         start_at=start_at,
                         inputs=inputs,
                         outputs={"output": outputs},

@@ -478,7 +497,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
             )
             if isinstance(event, NodeRunFailedEvent):
-                if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
+                if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
                     yield NodeInIterationFailedEvent(
                         **metadata_event.model_dump(),
                     )

@@ -491,15 +510,15 @@ class IterationNode(BaseNode[IterationNodeData]):
                     yield IterationRunNextEvent(
                         iteration_id=self.id,
                         iteration_node_id=self.node_id,
-                        iteration_node_type=self.node_type,
-                        iteration_node_data=self.node_data,
+                        iteration_node_type=self.type_,
+                        iteration_node_data=self._node_data,
                         index=next_index,
                         parallel_mode_run_id=parallel_mode_run_id,
                         pre_iteration_output=None,
                         duration=duration,
                     )
                     return
-                elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
+                elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
                     yield NodeInIterationFailedEvent(
                         **metadata_event.model_dump(),
                     )

@@ -512,15 +531,15 @@ class IterationNode(BaseNode[IterationNodeData]):
                     yield IterationRunNextEvent(
                         iteration_id=self.id,
                         iteration_node_id=self.node_id,
-                        iteration_node_type=self.node_type,
-                        iteration_node_data=self.node_data,
+                        iteration_node_type=self.type_,
+                        iteration_node_data=self._node_data,
                         index=next_index,
                         parallel_mode_run_id=parallel_mode_run_id,
                         pre_iteration_output=None,
                         duration=duration,
                     )
                     return
-                elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
+                elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
                     yield NodeInIterationFailedEvent(
                         **metadata_event.model_dump(),
                     )

@@ -531,12 +550,12 @@ class IterationNode(BaseNode[IterationNodeData]):
             variable_pool.remove([node_id])

             # iteration run failed
-            if self.node_data.is_parallel:
+            if self._node_data.is_parallel:
                 yield IterationRunFailedEvent(
                     iteration_id=self.id,
                     iteration_node_id=self.node_id,
-                    iteration_node_type=self.node_type,
-                    iteration_node_data=self.node_data,
+                    iteration_node_type=self.type_,
+                    iteration_node_data=self._node_data,
                     parallel_mode_run_id=parallel_mode_run_id,
                     start_at=start_at,
                     inputs=inputs,

@@ -549,8 +568,8 @@ class IterationNode(BaseNode[IterationNodeData]):
                 yield IterationRunFailedEvent(
                     iteration_id=self.id,
                     iteration_node_id=self.node_id,
-                    iteration_node_type=self.node_type,
-                    iteration_node_data=self.node_data,
+                    iteration_node_type=self.type_,
+                    iteration_node_data=self._node_data,
                     start_at=start_at,
                     inputs=inputs,
                     outputs={"output": outputs},

@@ -569,7 +588,7 @@ class IterationNode(BaseNode[IterationNodeData]):
                 return
             yield metadata_event

-            current_output_segment = variable_pool.get(self.node_data.output_selector)
+            current_output_segment = variable_pool.get(self._node_data.output_selector)
             if current_output_segment is None:
                 raise IterationNodeError("iteration output selector not found")
             current_iteration_output = current_output_segment.value

@@ -588,8 +607,8 @@ class IterationNode(BaseNode[IterationNodeData]):
             yield IterationRunNextEvent(
                 iteration_id=self.id,
                 iteration_node_id=self.node_id,
-                iteration_node_type=self.node_type,
-                iteration_node_data=self.node_data,
+                iteration_node_type=self.type_,
+                iteration_node_data=self._node_data,
                 index=next_index,
                 parallel_mode_run_id=parallel_mode_run_id,
                 pre_iteration_output=current_iteration_output or None,

@@ -601,8 +620,8 @@ class IterationNode(BaseNode[IterationNodeData]):
             yield IterationRunFailedEvent(
                 iteration_id=self.id,
                 iteration_node_id=self.node_id,
-                iteration_node_type=self.node_type,
-                iteration_node_data=self.node_data,
+                iteration_node_type=self.type_,
+                iteration_node_data=self._node_data,
                 start_at=start_at,
                 inputs=inputs,
                 outputs={"output": None},
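The parallel branch above preallocates `outputs` so each worker writes to its own index, keeping results in iterator order regardless of completion order. A generic sketch of that pattern (the real code routes results through a queue and GraphEngineThreadPool, elided here):

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any

items = ["a", "b", "c"]
outputs: list[Any] = [None] * len(items)  # one slot per iteration, order preserved


def run_one(index: int, item: str) -> None:
    outputs[index] = item.upper()


with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(run_one, i, item) for i, item in enumerate(items)]
    for future in futures:
        future.result()  # surface worker exceptions

print(outputs)  # ['A', 'B', 'C']
```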
@@ -1,18 +1,44 @@
 from collections.abc import Mapping
 from typing import Any, Optional

 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.iteration.entities import IterationStartNodeData


-class IterationStartNode(BaseNode[IterationStartNodeData]):
+class IterationStartNode(BaseNode):
     """
     Iteration Start Node.
     """

-    _node_data_cls = IterationStartNodeData
     _node_type = NodeType.ITERATION_START

+    _node_data: IterationStartNodeData
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = IterationStartNodeData(**data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
+    @classmethod
+    def version(cls) -> str:
+        return "1"
@@ -1,10 +1,10 @@
 from collections.abc import Sequence
-from typing import Any, Literal, Optional
+from typing import Literal, Optional

 from pydantic import BaseModel, Field

 from core.workflow.nodes.base import BaseNodeData
-from core.workflow.nodes.llm.entities import VisionConfig
+from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig


 class RerankingModelConfig(BaseModel):

@@ -55,14 +55,6 @@ class MultipleRetrievalConfig(BaseModel):
     reranking_model: Optional[RerankingModelConfig] = None
     weights: Optional[WeightedScoreConfig] = None


-class ModelConfig(BaseModel):
-    provider: str
-    name: str
-    mode: str
-    completion_params: dict[str, Any] = {}
-
-
 class SingleRetrievalConfig(BaseModel):
     """
     Single Retrieval Config.

@@ -125,7 +117,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData):
     multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
     single_retrieval_config: Optional[SingleRetrievalConfig] = None
     metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled"
-    metadata_model_config: Optional[ModelConfig] = None
+    metadata_model_config: ModelConfig
     metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None
     vision: VisionConfig = Field(default_factory=VisionConfig)

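Two things happen here: the local `ModelConfig` is dropped in favor of the one in `core.workflow.nodes.llm.entities`, and `metadata_model_config` becomes required. Making the field required is what lets the node code below delete its None-check -- pydantic now rejects the payload up front. A stand-in model demonstrates:

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class ModelConfig(BaseModel):
    provider: str
    name: str
    mode: str


class Before(BaseModel):
    metadata_model_config: Optional[ModelConfig] = None  # old: may be absent


class After(BaseModel):
    metadata_model_config: ModelConfig  # new: must be supplied


Before()  # fine
try:
    After()
except ValidationError as exc:
    print(exc.error_count())  # 1 -- the field is now mandatory
```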
@@ -4,7 +4,7 @@ import re
 import time
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast

 from sqlalchemy import Float, and_, func, or_, text
 from sqlalchemy import cast as sqlalchemy_cast

@@ -15,20 +15,31 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageRole
-from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.model_runtime.entities.message_entities import (
+    PromptMessageRole,
+)
+from core.model_runtime.entities.model_entities import (
+    ModelFeature,
+    ModelType,
+)
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.simple_prompt_transform import ModelMode
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.entities.metadata_entities import Condition, MetadataCondition
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.variables import StringSegment
+from core.variables import (
+    StringSegment,
+)
 from core.variables.segments import ArrayObjectSegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
+from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
+from core.workflow.nodes.event import (
+    ModelInvokeCompletedEvent,
+)
 from core.workflow.nodes.knowledge_retrieval.template_prompts import (
     METADATA_FILTER_ASSISTANT_PROMPT_1,
     METADATA_FILTER_ASSISTANT_PROMPT_2,

@@ -38,7 +49,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
     METADATA_FILTER_USER_PROMPT_2,
     METADATA_FILTER_USER_PROMPT_3,
 )
-from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
+from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
 from core.workflow.nodes.llm.node import LLMNode
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client

@@ -46,7 +58,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
 from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
 from services.feature_service import FeatureService

-from .entities import KnowledgeRetrievalNodeData, ModelConfig
+from .entities import KnowledgeRetrievalNodeData
 from .exc import (
     InvalidModelTypeError,
     KnowledgeRetrievalNodeError,

@@ -56,6 +68,10 @@ from .exc import (
     ModelQuotaExceededError,
 )

+if TYPE_CHECKING:
+    from core.file.models import File
+    from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
+
 logger = logging.getLogger(__name__)

 default_retrieval_model = {

@@ -67,18 +83,76 @@ default_retrieval_model = {
 }


-class KnowledgeRetrievalNode(LLMNode):
-    _node_data_cls = KnowledgeRetrievalNodeData  # type: ignore
+class KnowledgeRetrievalNode(BaseNode):
     _node_type = NodeType.KNOWLEDGE_RETRIEVAL

+    _node_data: KnowledgeRetrievalNodeData
+
+    # Instance attributes specific to LLMNode.
+    # Output variable for file
+    _file_outputs: list["File"]
+
+    _llm_file_saver: LLMFileSaver
+
+    def __init__(
+        self,
+        id: str,
+        config: Mapping[str, Any],
+        graph_init_params: "GraphInitParams",
+        graph: "Graph",
+        graph_runtime_state: "GraphRuntimeState",
+        previous_node_id: Optional[str] = None,
+        thread_pool_id: Optional[str] = None,
+        *,
+        llm_file_saver: LLMFileSaver | None = None,
+    ) -> None:
+        super().__init__(
+            id=id,
+            config=config,
+            graph_init_params=graph_init_params,
+            graph=graph,
+            graph_runtime_state=graph_runtime_state,
+            previous_node_id=previous_node_id,
+            thread_pool_id=thread_pool_id,
+        )
+        # LLM file outputs, used for MultiModal outputs.
+        self._file_outputs: list[File] = []
+
+        if llm_file_saver is None:
+            llm_file_saver = FileSaverImpl(
+                user_id=graph_init_params.user_id,
+                tenant_id=graph_init_params.tenant_id,
+            )
+        self._llm_file_saver = llm_file_saver
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
+    @classmethod
+    def version(cls):
+        return "1"
+
     def _run(self) -> NodeRunResult:  # type: ignore
-        node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
         # extract variables
-        variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
+        variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
         if not isinstance(variable, StringSegment):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,

@@ -119,7 +193,7 @@ class KnowledgeRetrievalNode(LLMNode):

         # retrieve knowledge
         try:
-            results = self._fetch_dataset_retriever(node_data=node_data, query=query)
+            results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
             outputs = {"result": ArrayObjectSegment(value=results)}
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,

@@ -435,20 +509,15 @@ class KnowledgeRetrievalNode(LLMNode):
         # get all metadata field
         metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
         all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
-        # get metadata model config
-        metadata_model_config = node_data.metadata_model_config
-        if metadata_model_config is None:
-            raise ValueError("metadata_model_config is required")
-        # get metadata model instance
-        # fetch model config
-        model_instance, model_config = self.get_model_config(metadata_model_config)
+        # get metadata model instance and fetch model config
+        model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
         # fetch prompt messages
         prompt_template = self._get_prompt_template(
             node_data=node_data,
             metadata_fields=all_metadata_fields,
             query=query or "",
         )
-        prompt_messages, stop = self._fetch_prompt_messages(
+        prompt_messages, stop = LLMNode.fetch_prompt_messages(
             prompt_template=prompt_template,
             sys_query=query,
             memory=None,

@@ -458,16 +527,23 @@ class KnowledgeRetrievalNode(LLMNode):
             vision_detail=node_data.vision.configs.detail,
             variable_pool=self.graph_runtime_state.variable_pool,
             jinja2_variables=[],
+            tenant_id=self.tenant_id,
         )

         result_text = ""
         try:
             # handle invoke result
-            generator = self._invoke_llm(
-                node_data_model=node_data.metadata_model_config,  # type: ignore
+            generator = LLMNode.invoke_llm(
+                node_data_model=node_data.metadata_model_config,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
                 stop=stop,
+                user_id=self.user_id,
+                structured_output_enabled=self._node_data.structured_output_enabled,
+                structured_output=None,
+                file_saver=self._llm_file_saver,
+                file_outputs=self._file_outputs,
+                node_id=self.node_id,
             )

             for event in generator:

@@ -557,17 +633,13 @@ class KnowledgeRetrievalNode(LLMNode):
         *,
         graph_config: Mapping[str, Any],
         node_id: str,
-        node_data: KnowledgeRetrievalNodeData,  # type: ignore
+        node_data: Mapping[str, Any],
     ) -> Mapping[str, Sequence[str]]:
-        """
-        Extract variable selector to variable mapping
-        :param graph_config: graph config
-        :param node_id: node id
-        :param node_data: node data
-        :return:
-        """
+        # Create typed NodeData from dict
+        typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
+
         variable_mapping = {}
-        variable_mapping[node_id + ".query"] = node_data.query_variable_selector
+        variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
         return variable_mapping

     def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:

@@ -629,7 +701,7 @@ class KnowledgeRetrievalNode(LLMNode):
         )

     def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
-        model_mode = ModelMode.value_of(node_data.metadata_model_config.mode)  # type: ignore
+        model_mode = ModelMode(node_data.metadata_model_config.mode)
         input_text = query

         prompt_messages: list[LLMNodeChatModelMessage] = []
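The biggest structural change in this file: `KnowledgeRetrievalNode` no longer inherits from `LLMNode`; it extends `BaseNode` and reuses the LLM machinery through the new keyword-only static helpers (`LLMNode.fetch_prompt_messages`, `LLMNode.invoke_llm`). The shape of that composition-over-inheritance move, reduced to a stand-in example:

```python
class Helper:
    @staticmethod
    def invoke(*, model: str, user_id: str) -> str:
        # Everything the helper needs arrives as explicit arguments,
        # so callers do not have to inherit any Helper state.
        return f"invoked {model} for {user_id}"


class Caller:  # deliberately NOT a Helper subclass
    def run(self) -> str:
        return Helper.invoke(model="demo-model", user_id="u1")


print(Caller().run())  # invoked demo-model for u1
```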
@@ -1,5 +1,5 @@
-from collections.abc import Callable, Sequence
-from typing import Any, Literal, Union
+from collections.abc import Callable, Mapping, Sequence
+from typing import Any, Literal, Optional, Union

 from core.file import File
 from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment

@@ -7,16 +7,39 @@ from core.variables.segments import ArrayAnySegment, ArraySegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType

 from .entities import ListOperatorNodeData
 from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError


-class ListOperatorNode(BaseNode[ListOperatorNodeData]):
-    _node_data_cls = ListOperatorNodeData
+class ListOperatorNode(BaseNode):
     _node_type = NodeType.LIST_OPERATOR

+    _node_data: ListOperatorNodeData
+
+    def init_node_data(self, data: Mapping[str, Any]) -> None:
+        self._node_data = ListOperatorNodeData(**data)
+
+    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+        return self._node_data.error_strategy
+
+    def _get_retry_config(self) -> RetryConfig:
+        return self._node_data.retry_config
+
+    def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+        return self._node_data.desc
+
+    def _get_default_value_dict(self) -> dict[str, Any]:
+        return self._node_data.default_value_dict
+
+    def get_base_node_data(self) -> BaseNodeData:
+        return self._node_data
+
+    @classmethod
+    def version(cls) -> str:
+        return "1"
+

@@ -26,9 +49,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
         process_data: dict[str, list] = {}
         outputs: dict[str, Any] = {}

-        variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
+        variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
         if variable is None:
-            error_message = f"Variable not found for selector: {self.node_data.variable}"
+            error_message = f"Variable not found for selector: {self._node_data.variable}"
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
             )

@@ -48,7 +71,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
             )
         if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
             error_message = (
-                f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
+                f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
                 "or ArrayStringSegment"
             )
             return NodeRunResult(

@@ -64,19 +87,19 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):

         try:
             # Filter
-            if self.node_data.filter_by.enabled:
+            if self._node_data.filter_by.enabled:
                 variable = self._apply_filter(variable)

             # Extract
-            if self.node_data.extract_by.enabled:
+            if self._node_data.extract_by.enabled:
                 variable = self._extract_slice(variable)

             # Order
-            if self.node_data.order_by.enabled:
+            if self._node_data.order_by.enabled:
                 variable = self._apply_order(variable)

             # Slice
-            if self.node_data.limit.enabled:
+            if self._node_data.limit.enabled:
                 variable = self._apply_slice(variable)

             outputs = {

@@ -104,7 +127,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
     ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
         filter_func: Callable[[Any], bool]
         result: list[Any] = []
-        for condition in self.node_data.filter_by.conditions:
+        for condition in self._node_data.filter_by.conditions:
            if isinstance(variable, ArrayStringSegment):
                 if not isinstance(condition.value, str):
                     raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")

@@ -137,14 +160,14 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
         self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
     ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
         if isinstance(variable, ArrayStringSegment):
-            result = _order_string(order=self.node_data.order_by.value, array=variable.value)
+            result = _order_string(order=self._node_data.order_by.value, array=variable.value)
             variable = variable.model_copy(update={"value": result})
         elif isinstance(variable, ArrayNumberSegment):
-            result = _order_number(order=self.node_data.order_by.value, array=variable.value)
+            result = _order_number(order=self._node_data.order_by.value, array=variable.value)
             variable = variable.model_copy(update={"value": result})
         elif isinstance(variable, ArrayFileSegment):
             result = _order_file(
-                order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
+                order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
             )
             variable = variable.model_copy(update={"value": result})
         return variable

@@ -152,13 +175,13 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
     def _apply_slice(
         self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
     ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
-        result = variable.value[: self.node_data.limit.size]
+        result = variable.value[: self._node_data.limit.size]
         return variable.model_copy(update={"value": result})

     def _extract_slice(
         self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
     ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
-        value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
+        value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
         if value < 1:
             raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
         value -= 1
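`_extract_slice` above treats `extract_by.serial` as a 1-based index coming from the workflow config, hence the `value < 1` guard and the `value -= 1` conversion. The same logic in isolation:

```python
def extract_serial(values: list[str], serial: int) -> str:
    if serial < 1:
        raise ValueError(f"Invalid serial index: must be >= 1, got {serial}")
    return values[serial - 1]  # convert the 1-based config value to a list index


print(extract_serial(["a", "b", "c"], 2))  # b
```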
@@ -1,4 +1,4 @@
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from typing import Any, Optional

 from pydantic import BaseModel, Field, field_validator

@@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData):
     memory: Optional[MemoryConfig] = None
     context: ContextConfig
     vision: VisionConfig = Field(default_factory=VisionConfig)
-    structured_output: dict | None = None
+    structured_output: Mapping[str, Any] | None = None
     # We used 'structured_output_enabled' in the past, but it's not a good name.
     structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")

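Typing the schema as `Mapping[str, Any]` rather than `dict` signals that consumers should treat it as read-only; `Mapping` exposes no mutators, so type checkers flag accidental writes. For example:

```python
from collections.abc import Mapping
from typing import Any


def count_keys(schema: Mapping[str, Any] | None) -> int:
    # schema["type"] = "object"  # a type checker rejects this: Mapping has no __setitem__
    return 0 if schema is None else len(schema)


print(count_keys({"type": "object", "properties": {}}))  # 2
```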
@ -59,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
|
|||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import (
|
||||
ModelInvokeCompletedEvent,
|
||||
NodeEvent,
|
||||
|
|
@ -90,17 +91,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMNode(BaseNode[LLMNodeData]):
|
||||
_node_data_cls = LLMNodeData
|
||||
class LLMNode(BaseNode):
|
||||
_node_type = NodeType.LLM
|
||||
|
||||
_node_data: LLMNodeData
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
|
|
@ -138,6 +138,27 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = LLMNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
@ -152,13 +173,13 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
|
||||
try:
|
||||
# init messages template
|
||||
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
|
||||
self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
|
||||
|
||||
# fetch variables and fetch values from variable pool
|
||||
inputs = self._fetch_inputs(node_data=self.node_data)
|
||||
inputs = self._fetch_inputs(node_data=self._node_data)
|
||||
|
||||
# fetch jinja2 inputs
|
||||
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
|
||||
jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
|
||||
|
||||
# merge inputs
|
||||
inputs.update(jinja_inputs)
|
||||
|
|
@ -169,9 +190,9 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
files = (
|
||||
llm_utils.fetch_files(
|
||||
variable_pool=variable_pool,
|
||||
selector=self.node_data.vision.configs.variable_selector,
|
||||
selector=self._node_data.vision.configs.variable_selector,
|
||||
)
|
||||
if self.node_data.vision.enabled
|
||||
if self._node_data.vision.enabled
|
||||
else []
|
||||
)
|
||||
|
||||
|
|
@ -179,7 +200,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
node_inputs["#files#"] = [file.to_dict() for file in files]
|
||||
|
||||
# fetch context value
|
||||
generator = self._fetch_context(node_data=self.node_data)
|
||||
generator = self._fetch_context(node_data=self._node_data)
|
||||
context = None
|
||||
for event in generator:
|
||||
if isinstance(event, RunRetrieverResourceEvent):
|
||||
|
|
@ -189,44 +210,54 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
node_inputs["#context#"] = context
|
||||
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(self.node_data.model)
|
||||
model_instance, model_config = LLMNode._fetch_model_config(
|
||||
node_data_model=self._node_data.model,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
node_data_memory=self.node_data.memory,
|
||||
node_data_memory=self._node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
query = None
|
||||
if self.node_data.memory:
|
||||
query = self.node_data.memory.query_prompt_template
|
||||
if self._node_data.memory:
|
||||
query = self._node_data.memory.query_prompt_template
|
||||
if not query and (
|
||||
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||
):
|
||||
query = query_variable.text
|
||||
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
sys_query=query,
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
prompt_template=self.node_data.prompt_template,
|
||||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
prompt_template=self._node_data.prompt_template,
|
||||
memory_config=self._node_data.memory,
|
||||
vision_enabled=self._node_data.vision.enabled,
|
||||
vision_detail=self._node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
jinja2_variables=self._node_data.prompt_config.jinja2_variables,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
generator = self._invoke_llm(
|
||||
node_data_model=self.node_data.model,
|
||||
generator = LLMNode.invoke_llm(
|
||||
node_data_model=self._node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=self._node_data.structured_output_enabled,
|
||||
structured_output=self._node_data.structured_output,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self.node_id,
|
||||
)
|
||||
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
|
@@ -296,12 +327,19 @@ class LLMNode(BaseNode[LLMNodeData]):
            )
        )

    def _invoke_llm(
        self,
    @staticmethod
    def invoke_llm(
        *,
        node_data_model: ModelConfig,
        model_instance: ModelInstance,
        prompt_messages: Sequence[PromptMessage],
        stop: Optional[Sequence[str]] = None,
        user_id: str,
        structured_output_enabled: bool,
        structured_output: Optional[Mapping[str, Any]] = None,
        file_saver: LLMFileSaver,
        file_outputs: list["File"],
        node_id: str,
    ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
        model_schema = model_instance.model_type_instance.get_model_schema(
            node_data_model.name, model_instance.credentials
@@ -309,8 +347,10 @@ class LLMNode(BaseNode[LLMNodeData]):
        if not model_schema:
            raise ValueError(f"Model schema not found for {node_data_model.name}")

        if self.node_data.structured_output_enabled:
            output_schema = self._fetch_structured_output_schema()
        if structured_output_enabled:
            output_schema = LLMNode.fetch_structured_output_schema(
                structured_output=structured_output or {},
            )
            invoke_result = invoke_llm_with_structured_output(
                provider=model_instance.provider,
                model_schema=model_schema,
@@ -320,7 +360,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                model_parameters=node_data_model.completion_params,
                stop=list(stop or []),
                stream=True,
                user=self.user_id,
                user=user_id,
            )
        else:
            invoke_result = model_instance.invoke_llm(
@@ -328,17 +368,31 @@ class LLMNode(BaseNode[LLMNodeData]):
                model_parameters=node_data_model.completion_params,
                stop=list(stop or []),
                stream=True,
                user=self.user_id,
                user=user_id,
            )

        return self._handle_invoke_result(invoke_result=invoke_result)
        return LLMNode.handle_invoke_result(
            invoke_result=invoke_result,
            file_saver=file_saver,
            file_outputs=file_outputs,
            node_id=node_id,
        )

    def _handle_invoke_result(
        self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
    @staticmethod
    def handle_invoke_result(
        *,
        invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
        file_saver: LLMFileSaver,
        file_outputs: list["File"],
        node_id: str,
    ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
        # For blocking mode
        if isinstance(invoke_result, LLMResult):
            event = self._handle_blocking_result(invoke_result=invoke_result)
            event = LLMNode.handle_blocking_result(
                invoke_result=invoke_result,
                saver=file_saver,
                file_outputs=file_outputs,
            )
            yield event
            return

@@ -356,11 +410,13 @@ class LLMNode(BaseNode[LLMNodeData]):
                yield result
                if isinstance(result, LLMResultChunk):
                    contents = result.delta.message.content
                    for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
                    for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
                        contents=contents,
                        file_saver=file_saver,
                        file_outputs=file_outputs,
                    ):
                        full_text_buffer.write(text_part)
                        yield RunStreamChunkEvent(
                            chunk_content=text_part, from_variable_selector=[self.node_id, "text"]
                        )
                        yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])

                # Update the whole metadata
                if not model and result.model:
@@ -378,7 +434,8 @@ class LLMNode(BaseNode[LLMNodeData]):

        yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)

    def _image_file_to_markdown(self, file: "File", /):
    @staticmethod
    def _image_file_to_markdown(file: "File", /):
        text_chunk = f"![]({file.generate_url()})"
        return text_chunk

@@ -539,11 +596,14 @@ class LLMNode(BaseNode[LLMNodeData]):

        return None

    @staticmethod
    def _fetch_model_config(
        self, node_data_model: ModelConfig
        *,
        node_data_model: ModelConfig,
        tenant_id: str,
    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        model, model_config_with_cred = llm_utils.fetch_model_config(
            tenant_id=self.tenant_id, node_data_model=node_data_model
            tenant_id=tenant_id, node_data_model=node_data_model
        )
        completion_params = model_config_with_cred.parameters

@@ -556,8 +616,8 @@ class LLMNode(BaseNode[LLMNodeData]):
        node_data_model.completion_params = completion_params
        return model, model_config_with_cred

    def _fetch_prompt_messages(
        self,
    @staticmethod
    def fetch_prompt_messages(
        *,
        sys_query: str | None = None,
        sys_files: Sequence["File"],
@@ -570,13 +630,14 @@ class LLMNode(BaseNode[LLMNodeData]):
        vision_detail: ImagePromptMessageContent.DETAIL,
        variable_pool: VariablePool,
        jinja2_variables: Sequence[VariableSelector],
        tenant_id: str,
    ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
        prompt_messages: list[PromptMessage] = []

        if isinstance(prompt_template, list):
            # For chat model
            prompt_messages.extend(
                self._handle_list_messages(
                LLMNode.handle_list_messages(
                    messages=prompt_template,
                    context=context,
                    jinja2_variables=jinja2_variables,
@@ -602,7 +663,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                edition_type="basic",
            )
            prompt_messages.extend(
                self._handle_list_messages(
                LLMNode.handle_list_messages(
                    messages=[message],
                    context="",
                    jinja2_variables=[],
@@ -731,7 +792,7 @@ class LLMNode(BaseNode[LLMNodeData]):
        )

        model = ModelManager().get_model_instance(
            tenant_id=self.tenant_id,
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=model_config.provider,
            model=model_config.model,
@@ -750,10 +811,12 @@ class LLMNode(BaseNode[LLMNodeData]):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: LLMNodeData,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        prompt_template = node_data.prompt_template
        # Create typed NodeData from dict
        typed_node_data = LLMNodeData.model_validate(node_data)

        prompt_template = typed_node_data.prompt_template
        variable_selectors = []
        if isinstance(prompt_template, list) and all(
            isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
@@ -773,7 +836,7 @@ class LLMNode(BaseNode[LLMNodeData]):
        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector

        memory = node_data.memory
        memory = typed_node_data.memory
        if memory and memory.query_prompt_template:
            query_variable_selectors = VariableTemplateParser(
                template=memory.query_prompt_template
@@ -781,16 +844,16 @@ class LLMNode(BaseNode[LLMNodeData]):
        for variable_selector in query_variable_selectors:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector

        if node_data.context.enabled:
            variable_mapping["#context#"] = node_data.context.variable_selector
        if typed_node_data.context.enabled:
            variable_mapping["#context#"] = typed_node_data.context.variable_selector

        if node_data.vision.enabled:
            variable_mapping["#files#"] = node_data.vision.configs.variable_selector
        if typed_node_data.vision.enabled:
            variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector

        if node_data.memory:
        if typed_node_data.memory:
            variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]

        if node_data.prompt_config:
        if typed_node_data.prompt_config:
            enable_jinja = False

            if isinstance(prompt_template, list):
@@ -803,7 +866,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                        enable_jinja = True

            if enable_jinja:
                for variable_selector in node_data.prompt_config.jinja2_variables or []:
                for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
                    variable_mapping[variable_selector.variable] = variable_selector.value_selector

        variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@@ -835,8 +898,8 @@ class LLMNode(BaseNode[LLMNodeData]):
            },
        }

    def _handle_list_messages(
        self,
    @staticmethod
    def handle_list_messages(
        *,
        messages: Sequence[LLMNodeChatModelMessage],
        context: Optional[str],
@@ -849,7 +912,7 @@ class LLMNode(BaseNode[LLMNodeData]):
            if message.edition_type == "jinja2":
                result_text = _render_jinja2_message(
                    template=message.jinja2_text or "",
                    jinjia2_variables=jinja2_variables,
                    jinja2_variables=jinja2_variables,
                    variable_pool=variable_pool,
                )
                prompt_message = _combine_message_content_with_role(
@@ -897,9 +960,19 @@ class LLMNode(BaseNode[LLMNodeData]):

        return prompt_messages

    def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
    @staticmethod
    def handle_blocking_result(
        *,
        invoke_result: LLMResult,
        saver: LLMFileSaver,
        file_outputs: list["File"],
    ) -> ModelInvokeCompletedEvent:
        buffer = io.StringIO()
        for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
        for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
            contents=invoke_result.message.content,
            file_saver=saver,
            file_outputs=file_outputs,
        ):
            buffer.write(text_part)

        return ModelInvokeCompletedEvent(
@@ -908,7 +981,12 @@ class LLMNode(BaseNode[LLMNodeData]):
            finish_reason=None,
        )

    def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
    @staticmethod
    def save_multimodal_image_output(
        *,
        content: ImagePromptMessageContent,
        file_saver: LLMFileSaver,
    ) -> "File":
        """_save_multimodal_output saves multi-modal contents generated by LLM plugins.

        There are two kinds of multimodal outputs:
@@ -918,26 +996,21 @@ class LLMNode(BaseNode[LLMNodeData]):

        Currently, only image files are supported.
        """
        # Inject the saver somehow...
        _saver = self._llm_file_saver

        # If this
        if content.url != "":
            saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
            saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
        else:
            saved_file = _saver.save_binary_string(
            saved_file = file_saver.save_binary_string(
                data=base64.b64decode(content.base64_data),
                mime_type=content.mime_type,
                file_type=FileType.IMAGE,
            )
        self._file_outputs.append(saved_file)
        return saved_file

    def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
        """
        Fetch model schema
        """
        model_name = self.node_data.model.name
        model_name = self._node_data.model.name
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
@@ -948,16 +1021,20 @@ class LLMNode(BaseNode[LLMNodeData]):
        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
        return model_schema

    def _fetch_structured_output_schema(self) -> dict[str, Any]:
    @staticmethod
    def fetch_structured_output_schema(
        *,
        structured_output: Mapping[str, Any],
    ) -> dict[str, Any]:
        """
        Fetch the structured output schema from the node data.

        Returns:
            dict[str, Any]: The structured output schema
        """
        if not self.node_data.structured_output:
        if not structured_output:
            raise LLMNodeError("Please provide a valid structured output schema")
        structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
        structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
        if not structured_output_schema:
            raise LLMNodeError("Please provide a valid structured output schema")

@@ -969,9 +1046,12 @@ class LLMNode(BaseNode[LLMNodeData]):
        except json.JSONDecodeError:
            raise LLMNodeError("structured_output_schema is not valid JSON format")

    @staticmethod
    def _save_multimodal_output_and_convert_result_to_markdown(
        self,
        *,
        contents: str | list[PromptMessageContentUnionTypes] | None,
        file_saver: LLMFileSaver,
        file_outputs: list["File"],
    ) -> Generator[str, None, None]:
        """Convert intermediate prompt messages into strings and yield them to the caller.

@@ -994,9 +1074,12 @@ class LLMNode(BaseNode[LLMNodeData]):
                if isinstance(item, TextPromptMessageContent):
                    yield item.data
                elif isinstance(item, ImagePromptMessageContent):
                    file = self._save_multimodal_image_output(item)
                    self._file_outputs.append(file)
                    yield self._image_file_to_markdown(file)
                    file = LLMNode.save_multimodal_image_output(
                        content=item,
                        file_saver=file_saver,
                    )
                    file_outputs.append(file)
                    yield LLMNode._image_file_to_markdown(file)
                else:
                    logger.warning("unknown item type encountered, type=%s", type(item))
                    yield str(item)
@@ -1004,6 +1087,14 @@ class LLMNode(BaseNode[LLMNodeData]):
            logger.warning("unknown contents type encountered, type=%s", type(contents))
            yield str(contents)

    @property
    def continue_on_error(self) -> bool:
        return self._node_data.error_strategy is not None

    @property
    def retry(self) -> bool:
        return self._node_data.retry_config.retry_enabled


def _combine_message_content_with_role(
    *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
@@ -1021,20 +1112,20 @@ def _combine_message_content_with_role(
def _render_jinja2_message(
    *,
    template: str,
    jinjia2_variables: Sequence[VariableSelector],
    jinja2_variables: Sequence[VariableSelector],
    variable_pool: VariablePool,
):
    if not template:
        return ""

    jinjia2_inputs = {}
    for jinja2_variable in jinjia2_variables:
    jinja2_inputs = {}
    for jinja2_variable in jinja2_variables:
        variable = variable_pool.get(jinja2_variable.value_selector)
        jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
        jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
    code_execute_resp = CodeExecutor.execute_workflow_code_template(
        language=CodeLanguage.JINJA2,
        code=template,
        inputs=jinjia2_inputs,
        inputs=jinja2_inputs,
    )
    result_text = code_execute_resp["result"]
    return result_text
@@ -1130,7 +1221,7 @@ def _handle_completion_template(
    if template.edition_type == "jinja2":
        result_text = _render_jinja2_message(
            template=template.jinja2_text or "",
            jinjia2_variables=jinja2_variables,
            jinja2_variables=jinja2_variables,
            variable_pool=variable_pool,
        )
    else:
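
Note on the refactor above: with `invoke_llm` now a keyword-only `@staticmethod`, a sibling node can drive the LLM without subclassing LLMNode. A minimal sketch of a call site, assuming `model_instance`, `prompt_messages`, and `stop` were produced by the fetch helpers shown above (the surrounding variable names are illustrative, not part of this diff):

    generator = LLMNode.invoke_llm(
        node_data_model=self._node_data.model,
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        stop=stop,
        user_id=self.user_id,
        structured_output_enabled=False,  # plain-text mode in this sketch
        structured_output=None,
        file_saver=self._llm_file_saver,  # persists multimodal outputs
        file_outputs=self._file_outputs,  # collects generated File objects
        node_id=self.node_id,
    )
    for event in generator:
        # RunStreamChunkEvent while streaming, then ModelInvokeCompletedEvent
        ...
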
@@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData


class LoopEndNode(BaseNode[LoopEndNodeData]):
class LoopEndNode(BaseNode):
    """
    Loop End Node.
    """

    _node_data_cls = LoopEndNodeData
    _node_type = NodeType.LOOP_END

    _node_data: LoopEndNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = LoopEndNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"
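
The lifecycle boilerplate above (`init_node_data` plus the `_get_*` hooks) repeats for every node in this diff. A condensed sketch of how the base class presumably consumes it, assuming `config` carries the raw "data" mapping from the graph definition (variable names here are placeholders):

    node = LoopEndNode(...)              # usual BaseNode constructor arguments
    node.init_node_data(config["data"])  # validates the raw dict into LoopEndNodeData
    # BaseNode can now read title, retry config, etc. through the hooks:
    base_data: BaseNodeData = node.get_base_node_data()
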
@@ -3,7 +3,7 @@ import logging
import time
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

from configs import dify_config
from core.variables import (
@@ -30,7 +30,8 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
@@ -43,14 +44,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


class LoopNode(BaseNode[LoopNodeData]):
class LoopNode(BaseNode):
    """
    Loop Node.
    """

    _node_data_cls = LoopNodeData
    _node_type = NodeType.LOOP

    _node_data: LoopNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = LoopNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"
@@ -58,17 +81,17 @@ class LoopNode(BaseNode[LoopNodeData]):
    def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
        """Run the node."""
        # Get inputs
        loop_count = self.node_data.loop_count
        break_conditions = self.node_data.break_conditions
        logical_operator = self.node_data.logical_operator
        loop_count = self._node_data.loop_count
        break_conditions = self._node_data.break_conditions
        logical_operator = self._node_data.logical_operator

        inputs = {"loop_count": loop_count}

        if not self.node_data.start_node_id:
        if not self._node_data.start_node_id:
            raise ValueError(f"field start_node_id in loop {self.node_id} not found")

        # Initialize graph
        loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id)
        loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
        if not loop_graph:
            raise ValueError("loop graph not found")

@@ -78,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]):

        # Initialize loop variables
        loop_variable_selectors = {}
        if self.node_data.loop_variables:
            for loop_variable in self.node_data.loop_variables:
        if self._node_data.loop_variables:
            for loop_variable in self._node_data.loop_variables:
                value_processor = {
                    "constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
                    "variable": lambda var=loop_variable: variable_pool.get(var.value),
@@ -127,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]):
        yield LoopRunStartedEvent(
            loop_id=self.id,
            loop_node_id=self.node_id,
            loop_node_type=self.node_type,
            loop_node_data=self.node_data,
            loop_node_type=self.type_,
            loop_node_data=self._node_data,
            start_at=start_at,
            inputs=inputs,
            metadata={"loop_length": loop_count},
@@ -184,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]):
                yield LoopRunSucceededEvent(
                    loop_id=self.id,
                    loop_node_id=self.node_id,
                    loop_node_type=self.node_type,
                    loop_node_data=self.node_data,
                    loop_node_type=self.type_,
                    loop_node_data=self._node_data,
                    start_at=start_at,
                    inputs=inputs,
                    outputs=self.node_data.outputs,
                    outputs=self._node_data.outputs,
                    steps=loop_count,
                    metadata={
                        WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
@@ -206,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]):
                            WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
                            WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
                        },
                        outputs=self.node_data.outputs,
                        outputs=self._node_data.outputs,
                        inputs=inputs,
                    )
                )
@@ -217,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]):
                yield LoopRunFailedEvent(
                    loop_id=self.id,
                    loop_node_id=self.node_id,
                    loop_node_type=self.node_type,
                    loop_node_data=self.node_data,
                    loop_node_type=self.type_,
                    loop_node_data=self._node_data,
                    start_at=start_at,
                    inputs=inputs,
                    steps=loop_count,
@@ -320,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]):
            yield LoopRunFailedEvent(
                loop_id=self.id,
                loop_node_id=self.node_id,
                loop_node_type=self.node_type,
                loop_node_data=self.node_data,
                loop_node_type=self.type_,
                loop_node_data=self._node_data,
                start_at=start_at,
                inputs=inputs,
                steps=current_index,
@@ -351,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]):
            yield LoopRunFailedEvent(
                loop_id=self.id,
                loop_node_id=self.node_id,
                loop_node_type=self.node_type,
                loop_node_data=self.node_data,
                loop_node_type=self.type_,
                loop_node_data=self._node_data,
                start_at=start_at,
                inputs=inputs,
                steps=current_index,
@@ -388,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]):
                _outputs[loop_variable_key] = None

        _outputs["loop_round"] = current_index + 1
        self.node_data.outputs = _outputs
        self._node_data.outputs = _outputs

        if check_break_result:
            return {"check_break_result": True}
@@ -400,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]):
        yield LoopRunNextEvent(
            loop_id=self.id,
            loop_node_id=self.node_id,
            loop_node_type=self.node_type,
            loop_node_data=self.node_data,
            loop_node_type=self.type_,
            loop_node_data=self._node_data,
            index=next_index,
            pre_loop_output=self.node_data.outputs,
            pre_loop_output=self._node_data.outputs,
        )

        return {"check_break_result": False}
@@ -438,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: LoopNodeData,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return:
        """
        # Create typed NodeData from dict
        typed_node_data = LoopNodeData.model_validate(node_data)

        variable_mapping = {}

        # init graph
        loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
        loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)

        if not loop_graph:
            raise ValueError("loop graph not found")
@@ -486,7 +505,7 @@ class LoopNode(BaseNode[LoopNodeData]):

            variable_mapping.update(sub_node_variable_mapping)

        for loop_variable in node_data.loop_variables or []:
        for loop_variable in typed_node_data.loop_variables or []:
            if loop_variable.value_type == "variable":
                assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
                # add loop variable to variable mapping
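
The `_extract_variable_selector_to_variable_mapping` hooks now take the raw node dict and validate it themselves via pydantic's `model_validate`. A small illustration of that pattern in isolation, with an abbreviated, hypothetical payload (the real LoopNodeData has more required fields):

    raw: Mapping[str, Any] = {
        "title": "loop",
        "loop_count": 3,
        "start_node_id": "loop-start-1",
        # ... remaining required fields elided for brevity
    }
    typed = LoopNodeData.model_validate(raw)  # raises pydantic.ValidationError on bad input
    assert typed.start_node_id == "loop-start-1"
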
@@ -1,18 +1,44 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData


class LoopStartNode(BaseNode[LoopStartNodeData]):
class LoopStartNode(BaseNode):
    """
    Loop Start Node.
    """

    _node_data_cls = LoopStartNodeData
    _node_type = NodeType.LOOP_START

    _node_data: LoopStartNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = LoopStartNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"
@@ -29,8 +29,9 @@ from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type
@@ -91,10 +92,31 @@ class ParameterExtractorNode(BaseNode):
    Parameter Extractor Node.
    """

    # FIXME: figure out why here is different from super class
    _node_data_cls = ParameterExtractorNodeData  # type: ignore
    _node_type = NodeType.PARAMETER_EXTRACTOR

    _node_data: ParameterExtractorNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = ParameterExtractorNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    _model_instance: Optional[ModelInstance] = None
    _model_config: Optional[ModelConfigWithCredentialsEntity] = None

@@ -119,7 +141,7 @@ class ParameterExtractorNode(BaseNode):
        """
        Run the node.
        """
        node_data = cast(ParameterExtractorNodeData, self.node_data)
        node_data = cast(ParameterExtractorNodeData, self._node_data)
        variable = self.graph_runtime_state.variable_pool.get(node_data.query)
        query = variable.text if variable else ""

@@ -398,7 +420,7 @@ class ParameterExtractorNode(BaseNode):
        """
        Generate prompt engineering prompt.
        """
        model_mode = ModelMode.value_of(data.model.mode)
        model_mode = ModelMode(data.model.mode)

        if model_mode == ModelMode.COMPLETION:
            return self._generate_prompt_engineering_completion_prompt(
@@ -694,7 +716,7 @@ class ParameterExtractorNode(BaseNode):
        memory: Optional[TokenBufferMemory],
        max_token_limit: int = 2000,
    ) -> list[ChatModelMessage]:
        model_mode = ModelMode.value_of(node_data.model.mode)
        model_mode = ModelMode(node_data.model.mode)
        input_text = query
        memory_str = ""
        instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -721,7 +743,7 @@ class ParameterExtractorNode(BaseNode):
        memory: Optional[TokenBufferMemory],
        max_token_limit: int = 2000,
    ):
        model_mode = ModelMode.value_of(node_data.model.mode)
        model_mode = ModelMode(node_data.model.mode)
        input_text = query
        memory_str = ""
        instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -827,19 +849,15 @@ class ParameterExtractorNode(BaseNode):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: ParameterExtractorNodeData,  # type: ignore
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return:
        """
        variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
        # Create typed NodeData from dict
        typed_node_data = ParameterExtractorNodeData.model_validate(node_data)

        if node_data.instruction:
            selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
        variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}

        if typed_node_data.instruction:
            selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
            for selector in selectors:
                variable_mapping[selector.variable] = selector.value_selector

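
The `ModelMode.value_of(...)` calls above become plain enum construction. The stdlib `Enum` constructor already performs lookup-by-value and raises on unknown input, so the custom helper is redundant; a simplified stand-in sketch (not the real class body from core.prompt.simple_prompt_transform):

    from enum import Enum

    class ModelMode(str, Enum):  # stand-in; here modeled as a str-backed Enum
        COMPLETION = "completion"
        CHAT = "chat"

    assert ModelMode("chat") is ModelMode.CHAT
    # ModelMode("bogus") raises ValueError, comparable to value_of()'s error behavior
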
@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -11,8 +11,11 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
    LLMNode,
@@ -20,6 +23,7 @@ from core.workflow.nodes.llm import (
    LLMNodeCompletionModelPromptTemplate,
    llm_utils,
)
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown

@@ -35,17 +39,77 @@ from .template_prompts import (
    QUESTION_CLASSIFIER_USER_PROMPT_3,
)

if TYPE_CHECKING:
    from core.file.models import File
    from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState

class QuestionClassifierNode(LLMNode):
    _node_data_cls = QuestionClassifierNodeData  # type: ignore

class QuestionClassifierNode(BaseNode):
    _node_type = NodeType.QUESTION_CLASSIFIER

    _node_data: QuestionClassifierNodeData

    _file_outputs: list["File"]
    _llm_file_saver: LLMFileSaver

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph: "Graph",
        graph_runtime_state: "GraphRuntimeState",
        previous_node_id: Optional[str] = None,
        thread_pool_id: Optional[str] = None,
        *,
        llm_file_saver: LLMFileSaver | None = None,
    ) -> None:
        super().__init__(
            id=id,
            config=config,
            graph_init_params=graph_init_params,
            graph=graph,
            graph_runtime_state=graph_runtime_state,
            previous_node_id=previous_node_id,
            thread_pool_id=thread_pool_id,
        )
        # LLM file outputs, used for MultiModal outputs.
        self._file_outputs: list[File] = []

        if llm_file_saver is None:
            llm_file_saver = FileSaverImpl(
                user_id=graph_init_params.user_id,
                tenant_id=graph_init_params.tenant_id,
            )
        self._llm_file_saver = llm_file_saver

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = QuestionClassifierNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls):
        return "1"

    def _run(self):
        node_data = cast(QuestionClassifierNodeData, self.node_data)
        node_data = cast(QuestionClassifierNodeData, self._node_data)
        variable_pool = self.graph_runtime_state.variable_pool

        # extract variables
@@ -53,7 +117,10 @@ class QuestionClassifierNode(LLMNode):
        query = variable.value if variable else None
        variables = {"query": query}
        # fetch model config
        model_instance, model_config = self._fetch_model_config(node_data.model)
        model_instance, model_config = LLMNode._fetch_model_config(
            node_data_model=node_data.model,
            tenant_id=self.tenant_id,
        )
        # fetch memory
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
@@ -91,7 +158,7 @@ class QuestionClassifierNode(LLMNode):
        # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
        # two consecutive user prompts will be generated, causing model's error.
        # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
        prompt_messages, stop = self._fetch_prompt_messages(
        prompt_messages, stop = LLMNode.fetch_prompt_messages(
            prompt_template=prompt_template,
            sys_query="",
            memory=memory,
@@ -101,6 +168,7 @@ class QuestionClassifierNode(LLMNode):
            vision_detail=node_data.vision.configs.detail,
            variable_pool=variable_pool,
            jinja2_variables=[],
            tenant_id=self.tenant_id,
        )

        result_text = ""
@@ -109,11 +177,17 @@ class QuestionClassifierNode(LLMNode):

        try:
            # handle invoke result
            generator = self._invoke_llm(
            generator = LLMNode.invoke_llm(
                node_data_model=node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                user_id=self.user_id,
                structured_output_enabled=False,
                structured_output=None,
                file_saver=self._llm_file_saver,
                file_outputs=self._file_outputs,
                node_id=self.node_id,
            )

            for event in generator:
@@ -183,23 +257,18 @@ class QuestionClassifierNode(LLMNode):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: Any,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return:
        """
        node_data = cast(QuestionClassifierNodeData, node_data)
        variable_mapping = {"query": node_data.query_variable_selector}
        variable_selectors = []
        if node_data.instruction:
            variable_template_parser = VariableTemplateParser(template=node_data.instruction)
        # Create typed NodeData from dict
        typed_node_data = QuestionClassifierNodeData.model_validate(node_data)

        variable_mapping = {"query": typed_node_data.query_variable_selector}
        variable_selectors: list[VariableSelector] = []
        if typed_node_data.instruction:
            variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
            variable_selectors.extend(variable_template_parser.extract_variable_selectors())
        for variable_selector in variable_selectors:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector
            variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)

        variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}

@@ -265,7 +334,7 @@ class QuestionClassifierNode(LLMNode):
        memory: Optional[TokenBufferMemory],
        max_token_limit: int = 2000,
    ):
        model_mode = ModelMode.value_of(node_data.model.mode)
        model_mode = ModelMode(node_data.model.mode)
        classes = node_data.classes
        categories = []
        for class_ in classes:
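
The new `llm_file_saver` keyword above gives tests a seam to swap the file saver. A sketch, assuming an in-memory LLMFileSaver test double (`InMemoryFileSaver` is hypothetical, not part of this diff):

    node = QuestionClassifierNode(
        id="qc-1",
        config=node_config,
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=runtime_state,
        llm_file_saver=InMemoryFileSaver(),  # any LLMFileSaver implementation
    )
    # When omitted, __init__ falls back to FileSaverImpl(user_id=..., tenant_id=...).
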
@@ -1,15 +1,41 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.start.entities import StartNodeData


class StartNode(BaseNode[StartNodeData]):
    _node_data_cls = StartNodeData
class StartNode(BaseNode):
    _node_type = NodeType.START

    _node_data: StartNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = StartNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"
@@ -6,16 +6,39 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData

MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))


class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
    _node_data_cls = TemplateTransformNodeData
class TemplateTransformNode(BaseNode):
    _node_type = NodeType.TEMPLATE_TRANSFORM

    _node_data: TemplateTransformNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = TemplateTransformNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
@@ -35,14 +58,14 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
    def _run(self) -> NodeRunResult:
        # Get variables
        variables = {}
        for variable_selector in self.node_data.variables:
        for variable_selector in self._node_data.variables:
            variable_name = variable_selector.variable
            value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
            variables[variable_name] = value.to_object() if value else None
        # Run code
        try:
            result = CodeExecutor.execute_workflow_code_template(
                language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
                language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
            )
        except CodeExecutionError as e:
            return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

@@ -60,16 +83,12 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
        cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping
        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return:
        """
        # Create typed NodeData from dict
        typed_node_data = TemplateTransformNodeData.model_validate(node_data)

        return {
            node_id + "." + variable_selector.variable: variable_selector.value_selector
            for variable_selector in node_data.variables
            for variable_selector in typed_node_data.variables
        }
@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
|
|||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
|
|
@ -19,10 +18,10 @@ from core.workflow.entities.node_entities import NodeRunResult
|
|||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
|
|
@ -37,14 +36,18 @@ from .exc import (
|
|||
)
|
||||
|
||||
|
||||
class ToolNode(BaseNode[ToolNodeData]):
|
||||
class ToolNode(BaseNode):
|
||||
"""
|
||||
Tool Node
|
||||
"""
|
||||
|
||||
_node_data_cls = ToolNodeData
|
||||
_node_type = NodeType.TOOL
|
||||
|
||||
_node_data: ToolNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = ToolNodeData.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
@ -54,7 +57,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
Run the tool node
|
||||
"""
|
||||
|
||||
node_data = cast(ToolNodeData, self.node_data)
|
||||
node_data = cast(ToolNodeData, self._node_data)
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
|
|
@ -67,9 +70,9 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
try:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None
|
||||
variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool
|
||||
self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
|
|
@ -88,12 +91,12 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
node_data=self._node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
node_data=self._node_data,
|
||||
for_log=True,
|
||||
)
|
||||
# get conversation id
|
||||
|
|
@ -124,7 +127,14 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
|
||||
try:
|
||||
# convert tool messages
|
||||
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
|
||||
yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
tool_info=tool_info,
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_id=self.node_id,
|
||||
)
|
||||
except (PluginDaemonClientSideError, ToolInvokeError) as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
|
|
@ -191,7 +201,9 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
agent_thoughts: Optional[list] = None,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
|
|
@ -199,8 +211,8 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
# transform message and handle file storage
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
|
|
@ -208,9 +220,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage: LLMUsage | None = None
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
|
|
@ -243,7 +252,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
|
|
@ -266,45 +275,36 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
|
||||
)
|
||||
yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if self.node_type == NodeType.AGENT:
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(msg_metadata)
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
# JSON message handling for tool node
|
||||
if message.message.json_object is not None:
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
|
||||
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
|
||||
chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
|
|
@ -319,7 +319,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(self.tenant_id)
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
|
|
@ -334,8 +334,8 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
self.user_id,
|
||||
self.tenant_id,
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
|
|
@@ -347,57 +347,10 @@ class ToolNode(BaseNode[ToolNodeData]):
                dict_metadata["icon"] = icon
                dict_metadata["icon_dark"] = icon_dark
                message.message.metadata = dict_metadata
                agent_log = AgentLogEvent(
                    id=message.message.id,
                    node_execution_id=self.id,
                    parent_id=message.message.parent_id,
                    error=message.message.error,
                    status=message.message.status.value,
                    data=message.message.data,
                    label=message.message.label,
                    metadata=message.message.metadata,
                    node_id=self.node_id,
                )

                # check if the agent log is already in the list
                for log in agent_logs:
                    if log.id == agent_log.id:
                        # update the log
                        log.data = agent_log.data
                        log.status = agent_log.status
                        log.error = agent_log.error
                        log.label = agent_log.label
                        log.metadata = agent_log.metadata
                        break
                else:
                    agent_logs.append(agent_log)

                yield agent_log
            elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
                assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage)
                yield RunRetrieverResourceEvent(
                    retriever_resources=message.message.retriever_resources,
                    context=message.message.context,
                )

        # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
        json_output: list[dict[str, Any]] = []

        # Step 1: append each agent log as its own dict.
        if agent_logs:
            for log in agent_logs:
                json_output.append(
                    {
                        "id": log.id,
                        "parent_id": log.parent_id,
                        "error": log.error,
                        "status": log.status,
                        "data": log.data,
                        "label": log.label,
                        "metadata": log.metadata,
                        "node_id": log.node_id,
                    }
                )
        # Step 2: extend with the tool's own JSON output after the agent logs.
        if json:
            json_output.extend(json)

@@ -409,12 +362,9 @@ class ToolNode(BaseNode[ToolNodeData]):
                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
                    outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
                    metadata={
                        **agent_execution_metadata,
                        WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
                        WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
                    },
                    inputs=parameters_for_log,
                    llm_usage=llm_usage,
                )
            )

@@ -424,7 +374,7 @@ class ToolNode(BaseNode[ToolNodeData]):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: ToolNodeData,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping

@@ -433,9 +383,12 @@ class ToolNode(BaseNode[ToolNodeData]):
        :param node_data: node data
        :return:
        """
        # Create typed NodeData from dict
        typed_node_data = ToolNodeData.model_validate(node_data)

        result = {}
        for parameter_name in node_data.tool_parameters:
            input = node_data.tool_parameters[parameter_name]
        for parameter_name in typed_node_data.tool_parameters:
            input = typed_node_data.tool_parameters[parameter_name]
            if input.type == "mixed":
                assert isinstance(input.value, str)
                selectors = VariableTemplateParser(input.value).extract_variable_selectors()

@@ -449,3 +402,29 @@ class ToolNode(BaseNode[ToolNodeData]):
        result = {node_id + "." + key: value for key, value in result.items()}

        return result

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @property
    def continue_on_error(self) -> bool:
        return self._node_data.error_strategy is not None

    @property
    def retry(self) -> bool:
        return self._node_data.retry_config.retry_enabled
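The recurring shape in the ToolNode hunks above, and in the variable-aggregator and variable-assigner hunks that follow, is one refactor: the generic `BaseNode[SomeNodeData]` parameter and the `_node_data_cls` class attribute give way to an explicit `init_node_data` hook plus small `_get_*` accessors, and class methods now receive a plain `Mapping[str, Any]` that they validate themselves. A minimal sketch of the pattern, with `MyNodeData`/`MyNode` as illustrative stand-ins rather than names from this diff:

    from collections.abc import Mapping
    from typing import Any, Optional

    from pydantic import BaseModel


    class MyNodeData(BaseModel):
        """Stand-in for a BaseNodeData subclass."""

        title: str = "My Node"
        desc: Optional[str] = None


    class MyNode:
        # No generic parameter on the base class any more; the node owns a
        # typed _node_data field that is populated from a raw dict exactly once.
        _node_data: MyNodeData

        def init_node_data(self, data: Mapping[str, Any]) -> None:
            # Pydantic validates the untyped mapping at the boundary.
            self._node_data = MyNodeData.model_validate(data)

        def _get_title(self) -> str:
            return self._node_data.title


    node = MyNode()
    node.init_node_data({"title": "Search"})
    assert node._get_title() == "Search"

The same move shows up in `_extract_variable_selector_to_variable_mapping`, which now takes `node_data: Mapping[str, Any]` and calls `ToolNodeData.model_validate(node_data)` instead of trusting a pre-typed argument.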
@@ -1,17 +1,41 @@
from collections.abc import Mapping
from typing import Any, Optional

from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData


class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
    _node_data_cls = VariableAssignerNodeData
class VariableAggregatorNode(BaseNode):
    _node_type = NodeType.VARIABLE_AGGREGATOR

    _node_data: VariableAssignerNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = VariableAssignerNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        return "1"

@@ -21,8 +45,8 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
        outputs: dict[str, Segment | Mapping[str, Segment]] = {}
        inputs = {}

        if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
            for selector in self.node_data.variables:
        if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
            for selector in self._node_data.variables:
                variable = self.graph_runtime_state.variable_pool.get(selector)
                if variable is not None:
                    outputs = {"output": variable}

@@ -30,7 +54,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
                    inputs = {".".join(selector[1:]): variable.to_object()}
                    break
        else:
            for group in self.node_data.advanced_settings.groups:
            for group in self._node_data.advanced_settings.groups:
                for selector in group.variables:
                    variable = self.graph_runtime_state.variable_pool.get(selector)
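The aggregator's ungrouped branch reduces to "take the first selector that resolves to a non-None value". A runnable sketch, with a plain dict standing in for the graph runtime's variable pool:

    # `pool` is an illustrative stand-in for self.graph_runtime_state.variable_pool.
    pool = {("node_a", "text"): None, ("node_b", "text"): "hello"}

    outputs = {}
    for selector in [("node_a", "text"), ("node_b", "text")]:
        variable = pool.get(selector)
        if variable is not None:
            outputs = {"output": variable}
            break

    assert outputs == {"output": "hello"}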
@@ -7,7 +7,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory

@@ -22,11 +23,33 @@ if TYPE_CHECKING:
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]


class VariableAssignerNode(BaseNode[VariableAssignerData]):
    _node_data_cls = VariableAssignerData
class VariableAssignerNode(BaseNode):
    _node_type = NodeType.VARIABLE_ASSIGNER
    _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY

    _node_data: VariableAssignerData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = VariableAssignerData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    def __init__(
        self,
        id: str,

@@ -59,36 +82,39 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: VariableAssignerData,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        mapping = {}
        assigned_variable_node_id = node_data.assigned_variable_selector[0]
        if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
            selector_key = ".".join(node_data.assigned_variable_selector)
            key = f"{node_id}.#{selector_key}#"
            mapping[key] = node_data.assigned_variable_selector
        # Create typed NodeData from dict
        typed_node_data = VariableAssignerData.model_validate(node_data)

        selector_key = ".".join(node_data.input_variable_selector)
        mapping = {}
        assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
        if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
            selector_key = ".".join(typed_node_data.assigned_variable_selector)
            key = f"{node_id}.#{selector_key}#"
            mapping[key] = typed_node_data.assigned_variable_selector

        selector_key = ".".join(typed_node_data.input_variable_selector)
        key = f"{node_id}.#{selector_key}#"
        mapping[key] = node_data.input_variable_selector
        mapping[key] = typed_node_data.input_variable_selector
        return mapping

    def _run(self) -> NodeRunResult:
        assigned_variable_selector = self.node_data.assigned_variable_selector
        assigned_variable_selector = self._node_data.assigned_variable_selector
        # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
        original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
        if not isinstance(original_variable, Variable):
            raise VariableOperatorNodeError("assigned variable not found")

        match self.node_data.write_mode:
        match self._node_data.write_mode:
            case WriteMode.OVER_WRITE:
                income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
                income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
                if not income_value:
                    raise VariableOperatorNodeError("input value not found")
                updated_variable = original_variable.model_copy(update={"value": income_value.value})

            case WriteMode.APPEND:
                income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
                income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
                if not income_value:
                    raise VariableOperatorNodeError("input value not found")
                updated_value = original_variable.value + [income_value.value]

@@ -101,7 +127,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
                updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})

            case _:
                raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
                raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")

        # Over write the variable.
        self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
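The `match` on `write_mode` is the heart of this node: over-write replaces the stored value, append extends a list-valued variable. A sketch of those two cases with plain Python values (`WriteMode` and the variable pool are simplified stand-ins for the real entities):

    from enum import Enum


    class WriteMode(Enum):
        OVER_WRITE = "over-write"
        APPEND = "append"


    def assign(original, income, mode: WriteMode):
        match mode:
            case WriteMode.OVER_WRITE:
                return income                # replace wholesale
            case WriteMode.APPEND:
                return original + [income]  # list-valued variables only
            case _:
                raise ValueError(f"unsupported write mode: {mode}")


    assert assign("old", "new", WriteMode.OVER_WRITE) == "new"
    assert assign([1, 2], 3, WriteMode.APPEND) == [1, 2, 3]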
@@ -1,6 +1,6 @@
import json
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any, TypeAlias, cast
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Optional, cast

from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable

@@ -10,7 +10,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory

@@ -28,8 +29,6 @@ from .exc import (
    VariableNotFoundError,
)

_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]


def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
    selector_node_id = item.variable_selector[0]

@@ -54,10 +53,32 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
    mapping[key] = selector


class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
    _node_data_cls = VariableAssignerNodeData
class VariableAssignerNode(BaseNode):
    _node_type = NodeType.VARIABLE_ASSIGNER

    _node_data: VariableAssignerNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        self._node_data = VariableAssignerNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> Optional[str]:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
        return conversation_variable_updater_factory()

@@ -71,22 +92,25 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: VariableAssignerNodeData,
        node_data: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        # Create typed NodeData from dict
        typed_node_data = VariableAssignerNodeData.model_validate(node_data)

        var_mapping: dict[str, Sequence[str]] = {}
        for item in node_data.items:
        for item in typed_node_data.items:
            _target_mapping_from_item(var_mapping, node_id, item)
            _source_mapping_from_item(var_mapping, node_id, item)
        return var_mapping

    def _run(self) -> NodeRunResult:
        inputs = self.node_data.model_dump()
        inputs = self._node_data.model_dump()
        process_data: dict[str, Any] = {}
        # NOTE: This node has no outputs
        updated_variable_selectors: list[Sequence[str]] = []

        try:
            for item in self.node_data.items:
            for item in self._node_data.items:
                variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)

                # ==================== Validation Part

@@ -1,4 +1,4 @@
from typing import Optional, Protocol
from typing import Protocol

from core.workflow.entities.workflow_execution import WorkflowExecution

@@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol):
            execution: The WorkflowExecution instance to save or update
        """
        ...

    def get(self, execution_id: str) -> Optional[WorkflowExecution]:
        """
        Retrieve a WorkflowExecution by its ID.

        Args:
            execution_id: The workflow execution ID

        Returns:
            The WorkflowExecution instance if found, None otherwise
        """
        ...

@@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol):
        """
        ...

    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
        """
        Retrieve a NodeExecution by its node_execution_id.

        Args:
            node_execution_id: The node execution ID

        Returns:
            The NodeExecution instance if found, None otherwise
        """
        ...

    def get_by_workflow_run(
        self,
        workflow_run_id: str,

@@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol):
            A list of NodeExecution instances
        """
        ...

    def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
        """
        Retrieve all running NodeExecution instances for a specific workflow run.

        Args:
            workflow_run_id: The workflow run ID

        Returns:
            A list of running NodeExecution instances
        """
        ...

    def clear(self) -> None:
        """
        Clear all NodeExecution records based on implementation-specific criteria.

        This method is intended to be used for bulk deletion operations, such as removing
        all records associated with a specific app_id and tenant_id in multi-tenant implementations.
        """
        ...
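Both repository Protocols are trimmed here to a write-only surface; the read paths (`get`, `get_by_node_execution_id`, `get_running_executions`) and `clear` leave the contract because, as the next file shows, the cycle manager now answers reads from its own caches. Since `typing.Protocol` is structural, implementations shrink with the contract; anything exposing a matching `save` still type-checks. A self-contained sketch with stand-in types:

    from typing import Protocol


    class Execution:
        """Illustrative stand-in for the domain model."""


    class ExecutionRepository(Protocol):
        # After the trim, persisting is the whole contract.
        def save(self, execution: Execution) -> None: ...


    class InMemoryRepository:
        """Satisfies the Protocol structurally; no inheritance required."""

        def __init__(self) -> None:
            self._items: list[Execution] = []

        def save(self, execution: Execution) -> None:
            self._items.append(execution)


    def persist(repo: ExecutionRepository, execution: Execution) -> None:
        repo.save(execution)


    persist(InMemoryRepository(), Execution())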
@@ -55,24 +55,15 @@ class WorkflowCycleManager:
        self._workflow_execution_repository = workflow_execution_repository
        self._workflow_node_execution_repository = workflow_node_execution_repository

        # Initialize caches for workflow execution cycle
        # These caches avoid redundant repository calls during a single workflow execution
        self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
        self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}

    def handle_workflow_run_start(self) -> WorkflowExecution:
        inputs = {**self._application_generate_entity.inputs}
        inputs = self._prepare_workflow_inputs()
        execution_id = self._get_or_generate_execution_id()

        # Iterate over SystemVariable fields using Pydantic's model_fields
        if self._workflow_system_variables:
            for field_name, value in self._workflow_system_variables.to_dict().items():
                if field_name == SystemVariableKey.CONVERSATION_ID:
                    continue
                inputs[f"sys.{field_name}"] = value

        # handle special values
        inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})

        # init workflow run
        # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
        execution_id = str(
            self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None
        ) or str(uuid4())
        execution = WorkflowExecution.new(
            id_=execution_id,
            workflow_id=self._workflow_info.workflow_id,

@@ -83,9 +74,7 @@ class WorkflowCycleManager:
            started_at=datetime.now(UTC).replace(tzinfo=None),
        )

        self._workflow_execution_repository.save(execution)

        return execution
        return self._save_and_cache_workflow_execution(execution)

    def handle_workflow_run_success(
        self,

@@ -99,23 +88,15 @@ class WorkflowCycleManager:
    ) -> WorkflowExecution:
        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)

        # outputs = WorkflowEntry.handle_special_values(outputs)
        self._update_workflow_execution_completion(
            workflow_execution,
            status=WorkflowExecutionStatus.SUCCEEDED,
            outputs=outputs,
            total_tokens=total_tokens,
            total_steps=total_steps,
        )

        workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
        workflow_execution.outputs = outputs or {}
        workflow_execution.total_tokens = total_tokens
        workflow_execution.total_steps = total_steps
        workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)

        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.WORKFLOW_TRACE,
                    workflow_execution=workflow_execution,
                    conversation_id=conversation_id,
                    user_id=trace_manager.user_id,
                )
            )
        self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)

        self._workflow_execution_repository.save(workflow_execution)
        return workflow_execution

@@ -132,24 +113,17 @@ class WorkflowCycleManager:
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> WorkflowExecution:
        execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
        # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)

        execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
        execution.outputs = outputs or {}
        execution.total_tokens = total_tokens
        execution.total_steps = total_steps
        execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
        execution.exceptions_count = exceptions_count
        self._update_workflow_execution_completion(
            execution,
            status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
            outputs=outputs,
            total_tokens=total_tokens,
            total_steps=total_steps,
            exceptions_count=exceptions_count,
        )

        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.WORKFLOW_TRACE,
                    workflow_execution=execution,
                    conversation_id=conversation_id,
                    user_id=trace_manager.user_id,
                )
            )
        self._add_trace_task_if_needed(trace_manager, execution, conversation_id)

        self._workflow_execution_repository.save(execution)
        return execution

@@ -169,39 +143,18 @@ class WorkflowCycleManager:
        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
        now = naive_utc_now()

        workflow_execution.status = WorkflowExecutionStatus(status.value)
        workflow_execution.error_message = error_message
        workflow_execution.total_tokens = total_tokens
        workflow_execution.total_steps = total_steps
        workflow_execution.finished_at = now
        workflow_execution.exceptions_count = exceptions_count

        # Use the instance repository to find running executions for a workflow run
        running_node_executions = self._workflow_node_execution_repository.get_running_executions(
            workflow_run_id=workflow_execution.id_
        self._update_workflow_execution_completion(
            workflow_execution,
            status=status,
            total_tokens=total_tokens,
            total_steps=total_steps,
            error_message=error_message,
            exceptions_count=exceptions_count,
            finished_at=now,
        )

        # Update the domain models
        for node_execution in running_node_executions:
            if node_execution.node_execution_id:
                # Update the domain model
                node_execution.status = WorkflowNodeExecutionStatus.FAILED
                node_execution.error = error_message
                node_execution.finished_at = now
                node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()

                # Update the repository with the domain model
                self._workflow_node_execution_repository.save(node_execution)

        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.WORKFLOW_TRACE,
                    workflow_execution=workflow_execution,
                    conversation_id=conversation_id,
                    user_id=trace_manager.user_id,
                )
            )
        self._fail_running_node_executions(workflow_execution.id_, error_message, now)
        self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)

        self._workflow_execution_repository.save(workflow_execution)
        return workflow_execution

@@ -214,8 +167,198 @@ class WorkflowCycleManager:
    ) -> WorkflowNodeExecution:
        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)

        # Create a domain model
        created_at = datetime.now(UTC).replace(tzinfo=None)
        domain_execution = self._create_node_execution_from_event(
            workflow_execution=workflow_execution,
            event=event,
            status=WorkflowNodeExecutionStatus.RUNNING,
        )

        return self._save_and_cache_node_execution(domain_execution)

    def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
        domain_execution = self._get_node_execution_from_cache(event.node_execution_id)

        self._update_node_execution_completion(
            domain_execution,
            event=event,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
        )

        self._workflow_node_execution_repository.save(domain_execution)
        return domain_execution

    def handle_workflow_node_execution_failed(
        self,
        *,
        event: QueueNodeFailedEvent
        | QueueNodeInIterationFailedEvent
        | QueueNodeInLoopFailedEvent
        | QueueNodeExceptionEvent,
    ) -> WorkflowNodeExecution:
        """
        Workflow node execution failed
        :param event: queue node failed event
        :return:
        """
        domain_execution = self._get_node_execution_from_cache(event.node_execution_id)

        status = (
            WorkflowNodeExecutionStatus.EXCEPTION
            if isinstance(event, QueueNodeExceptionEvent)
            else WorkflowNodeExecutionStatus.FAILED
        )

        self._update_node_execution_completion(
            domain_execution,
            event=event,
            status=status,
            error=event.error,
            handle_special_values=True,
        )

        self._workflow_node_execution_repository.save(domain_execution)
        return domain_execution

    def handle_workflow_node_execution_retried(
        self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
    ) -> WorkflowNodeExecution:
        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)

        domain_execution = self._create_node_execution_from_event(
            workflow_execution=workflow_execution,
            event=event,
            status=WorkflowNodeExecutionStatus.RETRY,
            error=event.error,
            created_at=event.start_at,
        )

        # Handle inputs and outputs
        inputs = WorkflowEntry.handle_special_values(event.inputs)
        outputs = event.outputs
        metadata = self._merge_event_metadata(event)

        domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)

        return self._save_and_cache_node_execution(domain_execution)

    def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
        # Check cache first
        if id in self._workflow_execution_cache:
            return self._workflow_execution_cache[id]

        raise WorkflowRunNotFoundError(id)

    def _prepare_workflow_inputs(self) -> dict[str, Any]:
        """Prepare workflow inputs by merging application inputs with system variables."""
        inputs = {**self._application_generate_entity.inputs}

        if self._workflow_system_variables:
            for field_name, value in self._workflow_system_variables.to_dict().items():
                if field_name != SystemVariableKey.CONVERSATION_ID:
                    inputs[f"sys.{field_name}"] = value

        return dict(WorkflowEntry.handle_special_values(inputs) or {})

    def _get_or_generate_execution_id(self) -> str:
        """Get execution ID from system variables or generate a new one."""
        if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
            return str(self._workflow_system_variables.workflow_execution_id)
        return str(uuid4())

    def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
        """Save workflow execution to repository and cache it."""
        self._workflow_execution_repository.save(execution)
        self._workflow_execution_cache[execution.id_] = execution
        return execution

    def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
        """Save node execution to repository and cache it if it has an ID."""
        self._workflow_node_execution_repository.save(execution)
        if execution.node_execution_id:
            self._node_execution_cache[execution.node_execution_id] = execution
        return execution

    def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
        """Get node execution from cache or raise error if not found."""
        domain_execution = self._node_execution_cache.get(node_execution_id)
        if not domain_execution:
            raise ValueError(f"Domain node execution not found: {node_execution_id}")
        return domain_execution

    def _update_workflow_execution_completion(
        self,
        execution: WorkflowExecution,
        *,
        status: WorkflowExecutionStatus,
        total_tokens: int,
        total_steps: int,
        outputs: Mapping[str, Any] | None = None,
        error_message: Optional[str] = None,
        exceptions_count: int = 0,
        finished_at: Optional[datetime] = None,
    ) -> None:
        """Update workflow execution with completion data."""
        execution.status = status
        execution.outputs = outputs or {}
        execution.total_tokens = total_tokens
        execution.total_steps = total_steps
        execution.finished_at = finished_at or naive_utc_now()
        execution.exceptions_count = exceptions_count
        if error_message:
            execution.error_message = error_message

    def _add_trace_task_if_needed(
        self,
        trace_manager: Optional[TraceQueueManager],
        workflow_execution: WorkflowExecution,
        conversation_id: Optional[str],
    ) -> None:
        """Add trace task if trace manager is provided."""
        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.WORKFLOW_TRACE,
                    workflow_execution=workflow_execution,
                    conversation_id=conversation_id,
                    user_id=trace_manager.user_id,
                )
            )

    def _fail_running_node_executions(
        self,
        workflow_execution_id: str,
        error_message: str,
        now: datetime,
    ) -> None:
        """Fail all running node executions for a workflow."""
        running_node_executions = [
            node_exec
            for node_exec in self._node_execution_cache.values()
            if node_exec.workflow_execution_id == workflow_execution_id
            and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
        ]

        for node_execution in running_node_executions:
            if node_execution.node_execution_id:
                node_execution.status = WorkflowNodeExecutionStatus.FAILED
                node_execution.error = error_message
                node_execution.finished_at = now
                node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
                self._workflow_node_execution_repository.save(node_execution)

    def _create_node_execution_from_event(
        self,
        *,
        workflow_execution: WorkflowExecution,
        event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent],
        status: WorkflowNodeExecutionStatus,
        error: Optional[str] = None,
        created_at: Optional[datetime] = None,
    ) -> WorkflowNodeExecution:
        """Create a node execution from an event."""
        now = datetime.now(UTC).replace(tzinfo=None)
        created_at = created_at or now

        metadata = {
            WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,

@@ -232,152 +375,76 @@ class WorkflowCycleManager:
            node_id=event.node_id,
            node_type=event.node_type,
            title=event.node_data.title,
            status=WorkflowNodeExecutionStatus.RUNNING,
            status=status,
            metadata=metadata,
            created_at=created_at,
            error=error,
        )

        # Use the instance repository to save the domain model
        self._workflow_node_execution_repository.save(domain_execution)
        if status == WorkflowNodeExecutionStatus.RETRY:
            domain_execution.finished_at = now
            domain_execution.elapsed_time = (now - created_at).total_seconds()

        return domain_execution

    def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
        # Get the domain model from repository
        domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
        if not domain_execution:
            raise ValueError(f"Domain node execution not found: {event.node_execution_id}")

        # Process data
        inputs = event.inputs
        process_data = event.process_data
        outputs = event.outputs

        # Convert metadata keys to strings
        execution_metadata_dict = {}
        if event.execution_metadata:
            for key, value in event.execution_metadata.items():
                execution_metadata_dict[key] = value

        finished_at = datetime.now(UTC).replace(tzinfo=None)
        elapsed_time = (finished_at - event.start_at).total_seconds()

        # Update domain model
        domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
        domain_execution.update_from_mapping(
            inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
        )
        domain_execution.finished_at = finished_at
        domain_execution.elapsed_time = elapsed_time

        # Update the repository with the domain model
        self._workflow_node_execution_repository.save(domain_execution)

        return domain_execution

    def handle_workflow_node_execution_failed(
    def _update_node_execution_completion(
        self,
        domain_execution: WorkflowNodeExecution,
        *,
        event: QueueNodeFailedEvent
        | QueueNodeInIterationFailedEvent
        | QueueNodeInLoopFailedEvent
        | QueueNodeExceptionEvent,
    ) -> WorkflowNodeExecution:
        """
        Workflow node execution failed
        :param event: queue node failed event
        :return:
        """
        # Get the domain model from repository
        domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
        if not domain_execution:
            raise ValueError(f"Domain node execution not found: {event.node_execution_id}")

        # Process data
        inputs = WorkflowEntry.handle_special_values(event.inputs)
        process_data = WorkflowEntry.handle_special_values(event.process_data)
        outputs = event.outputs

        # Convert metadata keys to strings
        execution_metadata_dict = {}
        if event.execution_metadata:
            for key, value in event.execution_metadata.items():
                execution_metadata_dict[key] = value

        event: Union[
            QueueNodeSucceededEvent,
            QueueNodeFailedEvent,
            QueueNodeInIterationFailedEvent,
            QueueNodeInLoopFailedEvent,
            QueueNodeExceptionEvent,
        ],
        status: WorkflowNodeExecutionStatus,
        error: Optional[str] = None,
        handle_special_values: bool = False,
    ) -> None:
        """Update node execution with completion data."""
        finished_at = datetime.now(UTC).replace(tzinfo=None)
        elapsed_time = (finished_at - event.start_at).total_seconds()

        # Process data
        if handle_special_values:
            inputs = WorkflowEntry.handle_special_values(event.inputs)
            process_data = WorkflowEntry.handle_special_values(event.process_data)
        else:
            inputs = event.inputs
            process_data = event.process_data

        outputs = event.outputs

        # Convert metadata
        execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
        if event.execution_metadata:
            execution_metadata_dict.update(event.execution_metadata)

        # Update domain model
        domain_execution.status = (
            WorkflowNodeExecutionStatus.FAILED
            if not isinstance(event, QueueNodeExceptionEvent)
            else WorkflowNodeExecutionStatus.EXCEPTION
        )
        domain_execution.error = event.error
        domain_execution.status = status
        domain_execution.update_from_mapping(
            inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
            inputs=inputs,
            process_data=process_data,
            outputs=outputs,
            metadata=execution_metadata_dict,
        )
        domain_execution.finished_at = finished_at
        domain_execution.elapsed_time = elapsed_time

        # Update the repository with the domain model
        self._workflow_node_execution_repository.save(domain_execution)
        if error:
            domain_execution.error = error

        return domain_execution

    def handle_workflow_node_execution_retried(
        self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
    ) -> WorkflowNodeExecution:
        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
        created_at = event.start_at
        finished_at = datetime.now(UTC).replace(tzinfo=None)
        elapsed_time = (finished_at - created_at).total_seconds()
        inputs = WorkflowEntry.handle_special_values(event.inputs)
        outputs = event.outputs

        # Convert metadata keys to strings
    def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
        """Merge event metadata with origin metadata."""
        origin_metadata = {
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
            WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
            WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
        }

        # Convert execution metadata keys to strings
        execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
        if event.execution_metadata:
            for key, value in event.execution_metadata.items():
                execution_metadata_dict[key] = value
            execution_metadata_dict.update(event.execution_metadata)

        merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata

        # Create a domain model
        domain_execution = WorkflowNodeExecution(
            id=str(uuid4()),
            workflow_id=workflow_execution.workflow_id,
            workflow_execution_id=workflow_execution.id_,
            predecessor_node_id=event.predecessor_node_id,
            node_execution_id=event.node_execution_id,
            node_id=event.node_id,
            node_type=event.node_type,
            title=event.node_data.title,
            status=WorkflowNodeExecutionStatus.RETRY,
            created_at=created_at,
            finished_at=finished_at,
            elapsed_time=elapsed_time,
            error=event.error,
            index=event.node_run_index,
        )

        # Update with mappings
        domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)

        # Use the instance repository to save the domain model
        self._workflow_node_execution_repository.save(domain_execution)

        return domain_execution

    def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
        execution = self._workflow_execution_repository.get(id)
        if not execution:
            raise WorkflowRunNotFoundError(id)
        return execution
        return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
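The mechanical change running through this file: every save also populates an in-memory cache, and every lookup that used to be a repository `get` now reads only from that cache. A condensed sketch of the write-through/read-from-cache discipline, with stand-in types:

    class Execution:
        def __init__(self, id_: str) -> None:
            self.id_ = id_


    class Repository:
        def save(self, execution: Execution) -> None:
            pass  # persist somewhere durable


    class CachingManager:
        def __init__(self, repository: Repository) -> None:
            self._repository = repository
            self._cache: dict[str, Execution] = {}

        def save_and_cache(self, execution: Execution) -> Execution:
            self._repository.save(execution)        # write-through
            self._cache[execution.id_] = execution  # hot copy for this run
            return execution

        def get_or_raise(self, execution_id: str) -> Execution:
            # Reads never touch the repository: anything relevant to the
            # current execution cycle was cached when it was saved.
            if execution_id in self._cache:
                return self._cache[execution_id]
            raise KeyError(execution_id)

One consequence worth noting: `_get_workflow_execution_or_raise_error` now raises `WorkflowRunNotFoundError` for any execution this manager instance did not itself start, since the cache is the only source it consults.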
@@ -5,7 +5,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast

from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback

@@ -146,7 +146,7 @@ class WorkflowEntry:
        graph = Graph.init(graph_config=workflow.graph_dict)

        # init workflow run state
        node_instance = node_cls(
        node = node_cls(
            id=str(uuid.uuid4()),
            config=node_config,
            graph_init_params=GraphInitParams(

@@ -190,17 +190,11 @@ class WorkflowEntry:

        try:
            # run node
            generator = node_instance.run()
            generator = node.run()
        except Exception as e:
            logger.exception(
                "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
                workflow.id,
                node_instance.id,
                node_instance.node_type,
                node_instance.version(),
            )
            raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
        return node_instance, generator
            logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}")
            raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
        return node, generator

    @classmethod
    def run_free_node(

@@ -262,7 +256,7 @@ class WorkflowEntry:

        node_cls = cast(type[BaseNode], node_cls)
        # init workflow run state
        node_instance: BaseNode = node_cls(
        node: BaseNode = node_cls(
            id=str(uuid.uuid4()),
            config=node_config,
            graph_init_params=GraphInitParams(

@@ -297,17 +291,12 @@ class WorkflowEntry:
        )

        # run node
        generator = node_instance.run()
        generator = node.run()

        return node_instance, generator
        return node, generator
    except Exception as e:
        logger.exception(
            "error while running node_instance, node_id=%s, type=%s, version=%s",
            node_instance.id,
            node_instance.node_type,
            node_instance.version(),
        )
        raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
        logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}")
        raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))

    @staticmethod
    def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
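The rewritten log lines lean on f-string self-documenting expressions (the `{expr=}` form added in Python 3.8), which render both the expression text and its repr:

    node_id = "abc123"
    print(f"error while running node, {node_id=}")
    # error while running node, node_id='abc123'

The trade-off against the old `%s`-style call is that an f-string is evaluated eagerly rather than lazily by the logger; on an exception-only path like this one, the cost is negligible.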
@@ -1,4 +1,5 @@
import hashlib
from typing import Union

from Crypto.Cipher import AES
from Crypto.PublicKey import RSA

@@ -9,7 +10,7 @@ from extensions.ext_storage import storage
from libs import gmpy2_pkcs10aep_cipher


def generate_key_pair(tenant_id):
def generate_key_pair(tenant_id: str) -> str:
    private_key = RSA.generate(2048)
    public_key = private_key.publickey()

@@ -26,7 +27,7 @@ def generate_key_pair(tenant_id):
prefix_hybrid = b"HYBRID:"


def encrypt(text, public_key):
def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
    if isinstance(public_key, str):
        public_key = public_key.encode()

@@ -38,14 +39,14 @@ def encrypt(text, public_key):
    rsa_key = RSA.import_key(public_key)
    cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)

    enc_aes_key = cipher_rsa.encrypt(aes_key)
    enc_aes_key: bytes = cipher_rsa.encrypt(aes_key)

    encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext

    return prefix_hybrid + encrypted_data


def get_decrypt_decoding(tenant_id):
def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
    filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"

    cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())

@@ -64,7 +65,7 @@ def get_decrypt_decoding(tenant_id):
    return rsa_key, cipher_rsa


def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str:
    if encrypted_text.startswith(prefix_hybrid):
        encrypted_text = encrypted_text[len(prefix_hybrid) :]

@@ -83,10 +84,10 @@ def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
    return decrypted_text.decode()


def decrypt(encrypted_text, tenant_id):
def decrypt(encrypted_text: bytes, tenant_id: str) -> str:
    rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)

    return decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa)
    return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa)


class PrivkeyNotFoundError(Exception):
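For context, the `HYBRID:` blob this module builds is a standard envelope: a fresh AES session key is RSA-encrypted, then the GCM nonce, tag, and ciphertext are appended. A runnable round-trip sketch using pycryptodome's stock `PKCS1_OAEP` in place of the repo's gmpy2-accelerated cipher (byte offsets assume a 2048-bit key and pycryptodome's default 16-byte GCM nonce):

    from Crypto.Cipher import AES, PKCS1_OAEP
    from Crypto.PublicKey import RSA
    from Crypto.Random import get_random_bytes

    key = RSA.generate(2048)
    aes_key = get_random_bytes(16)

    cipher_aes = AES.new(aes_key, AES.MODE_GCM)
    ciphertext, tag = cipher_aes.encrypt_and_digest(b"secret token")
    enc_aes_key = PKCS1_OAEP.new(key.publickey()).encrypt(aes_key)
    blob = b"HYBRID:" + enc_aes_key + cipher_aes.nonce + tag + ciphertext

    # Decryption peels the layers off in the same order: 256 bytes of
    # RSA-wrapped session key, a 16-byte nonce, a 16-byte tag, then ciphertext.
    payload = blob[len(b"HYBRID:"):]
    session_key = PKCS1_OAEP.new(key).decrypt(payload[:256])
    nonce, tag2, ct = payload[256:272], payload[272:288], payload[288:]
    plain = AES.new(session_key, AES.MODE_GCM, nonce=nonce).decrypt_and_verify(ct, tag2)
    assert plain == b"secret token"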
@@ -196,7 +196,7 @@ class Tenant(Base):
    __tablename__ = "tenants"
    __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    name = db.Column(db.String(255), nullable=False)
    encrypt_public_key = db.Column(db.Text)
    plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
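Annotating `id` with `Mapped[str]` gives type checkers a concrete Python type for the column while leaving the legacy `db.Column` call in place. For comparison, a fully SQLAlchemy 2.0-style declaration (illustrative only, not what this hunk does) would pair the annotation with `mapped_column`:

    from sqlalchemy import String
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class Tenant(Base):
        __tablename__ = "tenants"

        # The ORM derives the Python type (and nullability) from Mapped[...].
        id: Mapped[str] = mapped_column(String(36), primary_key=True)
        name: Mapped[str] = mapped_column(String(255))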
@@ -347,21 +347,33 @@ class ToolTransformService:
        )

        # get tool parameters
        parameters = tool.entity.parameters or []
        base_parameters = tool.entity.parameters or []
        # get tool runtime parameters
        runtime_parameters = tool.get_runtime_parameters()
        # override parameters
        current_parameters = parameters.copy()
        for runtime_parameter in runtime_parameters:
            found = False
            for index, parameter in enumerate(current_parameters):
                if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
                    current_parameters[index] = runtime_parameter
                    found = True
                    break

            if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
                current_parameters.append(runtime_parameter)
        # merge parameters using a functional approach to avoid type issues
        merged_parameters: list[ToolParameter] = []

        # create a mapping of runtime parameters for quick lookup
        runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters}

        # process base parameters, replacing with runtime versions if they exist
        for base_param in base_parameters:
            key = (base_param.name, base_param.form)
            if key in runtime_param_map:
                merged_parameters.append(runtime_param_map[key])
            else:
                merged_parameters.append(base_param)

        # add any runtime parameters that weren't in base parameters
        for runtime_parameter in runtime_parameters:
            if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
                # check if this parameter is already in merged_parameters
                already_exists = any(
                    p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters
                )
                if not already_exists:
                    merged_parameters.append(runtime_parameter)

        return ToolApiEntity(
            author=tool.entity.identity.author,

@@ -369,10 +381,10 @@ class ToolTransformService:
            label=tool.entity.identity.label,
            description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
            output_schema=tool.entity.output_schema,
            parameters=current_parameters,
            parameters=merged_parameters,
            labels=labels or [],
        )
        if isinstance(tool, ApiToolBundle):
        elif isinstance(tool, ApiToolBundle):
            return ToolApiEntity(
                author=tool.author,
                name=tool.operation_id or "",

@@ -381,6 +393,9 @@ class ToolTransformService:
                parameters=tool.parameters,
                labels=labels or [],
            )
        else:
            # Handle WorkflowTool case
            raise ValueError(f"Unsupported tool type: {type(tool)}")

    @staticmethod
    def convert_builtin_provider_to_credential_entity(
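The rewritten merge is easier to verify outside the service: parameters are keyed by `(name, form)`, runtime versions win over base ones, and runtime-only parameters are appended once, but only when their form is the form type. A distilled, runnable version with `Param` as a small stand-in for `ToolParameter`:

    from dataclasses import dataclass


    @dataclass(frozen=True)
    class Param:
        name: str
        form: str  # "form" or "llm"
        default: str = ""


    def merge_parameters(base: list[Param], runtime: list[Param]) -> list[Param]:
        runtime_map = {(p.name, p.form): p for p in runtime}
        # Base order is preserved; runtime versions override by (name, form).
        merged = [runtime_map.get((p.name, p.form), p) for p in base]
        seen = {(p.name, p.form) for p in merged}
        # Runtime-only parameters are appended, restricted to form-type ones.
        merged += [p for p in runtime if p.form == "form" and (p.name, p.form) not in seen]
        return merged


    base = [Param("query", "llm"), Param("limit", "form", "10")]
    runtime = [Param("limit", "form", "25"), Param("api_key", "form", "***")]
    assert merge_parameters(base, runtime) == [
        Param("query", "llm"),
        Param("limit", "form", "25"),
        Param("api_key", "form", "***"),
    ]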
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
import json
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
from flask_login import current_user
|
||||
|
|
@ -13,241 +14,392 @@ from extensions.ext_storage import storage
|
|||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
|
||||
class WebsiteService:
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
if "url" not in args or not args["url"]:
|
||||
raise ValueError("url is required")
|
||||
if "options" not in args or not args["options"]:
|
||||
raise ValueError("options is required")
|
||||
if "limit" not in args["options"] or not args["options"]["limit"]:
|
||||
raise ValueError("limit is required")
|
||||
@dataclass
|
||||
class CrawlOptions:
|
||||
"""Options for crawling operations."""
|
||||
|
||||
limit: int = 1
|
||||
crawl_sub_pages: bool = False
|
||||
only_main_content: bool = False
|
||||
includes: Optional[str] = None
|
||||
excludes: Optional[str] = None
|
||||
max_depth: Optional[int] = None
|
||||
use_sitemap: bool = True
|
||||
|
||||
def get_include_paths(self) -> list[str]:
|
||||
"""Get list of include paths from comma-separated string."""
|
||||
return self.includes.split(",") if self.includes else []
|
||||
|
||||
def get_exclude_paths(self) -> list[str]:
|
||||
"""Get list of exclude paths from comma-separated string."""
|
||||
return self.excludes.split(",") if self.excludes else []
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrawlRequest:
|
||||
"""Request container for crawling operations."""
|
||||
|
||||
url: str
|
||||
provider: str
|
||||
options: CrawlOptions
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScrapeRequest:
|
||||
"""Request container for scraping operations."""
|
||||
|
||||
provider: str
|
||||
url: str
|
||||
tenant_id: str
|
||||
only_main_content: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebsiteCrawlApiRequest:
|
||||
"""Request container for website crawl API arguments."""
|
||||
|
||||
provider: str
|
||||
url: str
|
||||
options: dict[str, Any]
|
||||
|
||||
def to_crawl_request(self) -> CrawlRequest:
|
||||
"""Convert API request to internal CrawlRequest."""
|
||||
options = CrawlOptions(
|
||||
limit=self.options.get("limit", 1),
|
||||
crawl_sub_pages=self.options.get("crawl_sub_pages", False),
|
||||
only_main_content=self.options.get("only_main_content", False),
|
||||
includes=self.options.get("includes"),
|
||||
excludes=self.options.get("excludes"),
|
||||
max_depth=self.options.get("max_depth"),
|
||||
use_sitemap=self.options.get("use_sitemap", True),
|
||||
)
|
||||
return CrawlRequest(url=self.url, provider=self.provider, options=options)
|
||||
|
||||
@classmethod
|
||||
def crawl_url(cls, args: dict) -> dict:
|
||||
provider = args.get("provider", "")
|
||||
def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest":
|
||||
"""Create from Flask-RESTful parsed arguments."""
|
||||
provider = args.get("provider")
|
||||
url = args.get("url")
|
||||
options = args.get("options", "")
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
crawl_sub_pages = options.get("crawl_sub_pages", False)
|
||||
only_main_content = options.get("only_main_content", False)
|
||||
if not crawl_sub_pages:
|
||||
params = {
|
||||
"includePaths": [],
|
||||
"excludePaths": [],
|
||||
"limit": 1,
|
||||
"scrapeOptions": {"onlyMainContent": only_main_content},
|
||||
}
|
||||
else:
|
||||
includes = options.get("includes").split(",") if options.get("includes") else []
|
||||
excludes = options.get("excludes").split(",") if options.get("excludes") else []
|
||||
params = {
|
||||
"includePaths": includes,
|
||||
"excludePaths": excludes,
|
||||
"limit": options.get("limit", 1),
|
||||
"scrapeOptions": {"onlyMainContent": only_main_content},
|
||||
}
|
||||
if options.get("max_depth"):
|
||||
params["maxDepth"] = options.get("max_depth")
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
time = str(datetime.datetime.now().timestamp())
|
||||
redis_client.setex(website_crawl_time_cache_key, 3600, time)
|
||||
return {"status": "active", "job_id": job_id}
|
||||
elif provider == "watercrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options)
|
||||
options = args.get("options", {})
|
||||
|
||||
elif provider == "jinareader":
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
crawl_sub_pages = options.get("crawl_sub_pages", False)
|
||||
if not crawl_sub_pages:
|
||||
response = requests.get(
|
||||
f"https://r.jina.ai/{url}",
|
||||
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
|
||||
)
|
||||
if response.json().get("code") != 200:
|
||||
raise ValueError("Failed to crawl")
|
||||
return {"status": "active", "data": response.json().get("data")}
|
||||
else:
|
||||
response = requests.post(
|
||||
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
|
||||
json={
|
||||
"url": url,
|
||||
"maxPages": options.get("limit", 1),
|
||||
"useSitemap": options.get("use_sitemap", True),
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
},
|
||||
)
|
||||
if response.json().get("code") != 200:
|
||||
raise ValueError("Failed to crawl")
|
||||
return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
|
||||
if not provider:
|
||||
raise ValueError("Provider is required")
|
||||
if not url:
|
||||
raise ValueError("URL is required")
|
||||
if not options:
|
||||
raise ValueError("Options are required")
|
||||
|
||||
return cls(provider=provider, url=url, options=options)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebsiteCrawlStatusApiRequest:
|
||||
"""Request container for website crawl status API arguments."""
|
||||
|
||||
provider: str
|
||||
job_id: str
|
||||
|
||||
@classmethod
|
||||
def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
|
||||
"""Create from Flask-RESTful parsed arguments."""
|
||||
provider = args.get("provider")
|
||||
|
||||
if not provider:
|
||||
raise ValueError("Provider is required")
|
||||
if not job_id:
|
||||
raise ValueError("Job ID is required")
|
||||
|
||||
return cls(provider=provider, job_id=job_id)
|
||||
|
||||
|
||||
class WebsiteService:
|
||||
"""Service class for website crawling operations using different providers."""
|
||||
|
||||
@classmethod
|
||||
def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[dict, dict]:
|
||||
"""Get and validate credentials for a provider."""
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
||||
if not credentials or "config" not in credentials:
|
||||
raise ValueError("No valid credentials found for the provider")
|
||||
return credentials, credentials["config"]
|
||||
|
||||
@classmethod
|
||||
def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
|
||||
"""Decrypt and return the API key from config."""
|
||||
api_key = config.get("api_key")
|
||||
if not api_key:
|
||||
raise ValueError("API key not found in configuration")
|
||||
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict) -> None:
|
||||
"""Validate arguments for document creation."""
|
||||
try:
|
||||
WebsiteCrawlApiRequest.from_args(args)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid arguments: {e}")
|
||||
|
||||
@classmethod
|
||||
def crawl_url(cls, api_request: WebsiteCrawlApiRequest) -> dict[str, Any]:
|
||||
"""Crawl a URL using the specified provider with typed request."""
|
||||
request = api_request.to_crawl_request()
|
||||
|
||||
_, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider)
|
||||
api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
|
||||
|
||||
if request.provider == "firecrawl":
|
||||
return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config)
|
||||
elif request.provider == "watercrawl":
|
||||
return cls._crawl_with_watercrawl(request=request, api_key=api_key, config=config)
|
||||
elif request.provider == "jinareader":
|
||||
return cls._crawl_with_jinareader(request=request, api_key=api_key)
|
||||
else:
|
||||
raise ValueError("Invalid provider")

    @classmethod
    def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
        firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))

        if not request.options.crawl_sub_pages:
            params = {
                "includePaths": [],
                "excludePaths": [],
                "limit": 1,
                "scrapeOptions": {"onlyMainContent": request.options.only_main_content},
            }
        else:
            params = {
                "includePaths": request.options.get_include_paths(),
                "excludePaths": request.options.get_exclude_paths(),
                "limit": request.options.limit,
                "scrapeOptions": {"onlyMainContent": request.options.only_main_content},
            }
            if request.options.max_depth:
                params["maxDepth"] = request.options.max_depth

        job_id = firecrawl_app.crawl_url(request.url, params)
        website_crawl_time_cache_key = f"website_crawl_{job_id}"
        time = str(datetime.datetime.now().timestamp())
        redis_client.setex(website_crawl_time_cache_key, 3600, time)
        return {"status": "active", "job_id": job_id}

    @classmethod
    def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
        # Convert CrawlOptions back to dict format for WaterCrawlProvider
        options = {
            "limit": request.options.limit,
            "crawl_sub_pages": request.options.crawl_sub_pages,
            "only_main_content": request.options.only_main_content,
            "includes": request.options.includes,
            "excludes": request.options.excludes,
            "max_depth": request.options.max_depth,
            "use_sitemap": request.options.use_sitemap,
        }
        return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
            url=request.url, options=options
        )

    @classmethod
    def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
        if not request.options.crawl_sub_pages:
            response = requests.get(
                f"https://r.jina.ai/{request.url}",
                headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
            )
            if response.json().get("code") != 200:
                raise ValueError("Failed to crawl")
            return {"status": "active", "data": response.json().get("data")}
        else:
            response = requests.post(
                "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
                json={
                    "url": request.url,
                    "maxPages": request.options.limit,
                    "useSitemap": request.options.use_sitemap,
                },
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {api_key}",
                },
            )
            if response.json().get("code") != 200:
                raise ValueError("Failed to crawl")
            return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}

    @classmethod
    def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]:
        """Get crawl status using string parameters."""
        api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id)
        return cls.get_crawl_status_typed(api_request)

    @classmethod
    def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
        """Get crawl status using typed request."""
        _, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
        api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)

        if api_request.provider == "firecrawl":
            return cls._get_firecrawl_status(api_request.job_id, api_key, config)
        elif api_request.provider == "watercrawl":
            return cls._get_watercrawl_status(api_request.job_id, api_key, config)
        elif api_request.provider == "jinareader":
            return cls._get_jinareader_status(api_request.job_id, api_key)
        else:
            raise ValueError("Invalid provider")

    @classmethod
    def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
        firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
        result = firecrawl_app.check_crawl_status(job_id)
        crawl_status_data = {
            "status": result.get("status", "active"),
            "job_id": job_id,
            "total": result.get("total", 0),
            "current": result.get("current", 0),
            "data": result.get("data", []),
        }
        if crawl_status_data["status"] == "completed":
            website_crawl_time_cache_key = f"website_crawl_{job_id}"
            start_time = redis_client.get(website_crawl_time_cache_key)
            if start_time:
                end_time = datetime.datetime.now().timestamp()
                time_consuming = abs(end_time - float(start_time))
                crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
                redis_client.delete(website_crawl_time_cache_key)
        return crawl_status_data

    @classmethod
    def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
        return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)

    @classmethod
    def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
        response = requests.post(
            "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
            headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
            json={"taskId": job_id},
        )
        data = response.json().get("data", {})
        crawl_status_data = {
            "status": data.get("status", "active"),
            "job_id": job_id,
            "total": len(data.get("urls", [])),
            "current": len(data.get("processed", [])) + len(data.get("failed", [])),
            "data": [],
            "time_consuming": data.get("duration", 0) / 1000,
        }

        if crawl_status_data["status"] == "completed":
            response = requests.post(
                "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
                headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
                json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
            )
            data = response.json().get("data", {})
            formatted_data = [
                {
                    "title": item.get("data", {}).get("title"),
                    "source_url": item.get("data", {}).get("url"),
                    "description": item.get("data", {}).get("description"),
                    "markdown": item.get("data", {}).get("content"),
                }
                for item in data.get("processed", {}).values()
            ]
            crawl_status_data["data"] = formatted_data

        return crawl_status_data

    @classmethod
    def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
        _, config = cls._get_credentials_and_config(tenant_id, provider)
        api_key = cls._get_decrypted_api_key(tenant_id, config)

        if provider == "firecrawl":
            return cls._get_firecrawl_url_data(job_id, url, api_key, config)
        elif provider == "watercrawl":
            return cls._get_watercrawl_url_data(job_id, url, api_key, config)
        elif provider == "jinareader":
            return cls._get_jinareader_url_data(job_id, url, api_key)
        else:
            raise ValueError("Invalid provider")

    @classmethod
    def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
        crawl_data: list[dict[str, Any]] | None = None
        file_key = "website_files/" + job_id + ".txt"
        if storage.exists(file_key):
            stored_data = storage.load_once(file_key)
            if stored_data:
                crawl_data = json.loads(stored_data.decode("utf-8"))
        else:
            firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
            result = firecrawl_app.check_crawl_status(job_id)
            if result.get("status") != "completed":
                raise ValueError("Crawl job is not completed")
            crawl_data = result.get("data")

        if crawl_data:
            for item in crawl_data:
                if item.get("source_url") == url:
                    return dict(item)
        return None

    @classmethod
    def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
        return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)

    @classmethod
    def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
        if not job_id:
            response = requests.get(
                f"https://r.jina.ai/{url}",
                headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
            )
            if response.json().get("code") != 200:
                raise ValueError("Failed to crawl")
            return dict(response.json().get("data", {}))
        else:
            # Get crawl status first
            status_response = requests.post(
                "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
                headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
                json={"taskId": job_id},
            )
            status_data = status_response.json().get("data", {})
            if status_data.get("status") != "completed":
                raise ValueError("Crawl job is not completed")

            # Get processed data
            data_response = requests.post(
                "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
                headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
                json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
            )
            processed_data = data_response.json().get("data", {})
            for item in processed_data.get("processed", {}).values():
                if item.get("data", {}).get("url") == url:
                    return dict(item.get("data", {}))
            return None

    @classmethod
    def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]:
        request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content)

        _, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider)
        api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config)

        if request.provider == "firecrawl":
            return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config)
        elif request.provider == "watercrawl":
            return cls._scrape_with_watercrawl(request=request, api_key=api_key, config=config)
        else:
            raise ValueError("Invalid provider")

    @classmethod
    def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
        firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
        params = {"onlyMainContent": request.only_main_content}
        return firecrawl_app.scrape_url(url=request.url, params=params)

    @classmethod
    def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
        return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url)
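
# --- Illustrative sketch (not part of this module) ---------------------------
# How the crawl-timing bookkeeping in _crawl_with_firecrawl/_get_firecrawl_status
# fits together: a start timestamp is cached under website_crawl_<job_id> with a
# one-hour TTL, then read back, formatted, and deleted once the job completes.
# The same pattern with a plain dict standing in for redis_client:
import time as _time

_fake_redis: dict[str, str] = {}


def _start_crawl_timer(job_id: str) -> None:
    # Mirrors: redis_client.setex(f"website_crawl_{job_id}", 3600, now)
    _fake_redis[f"website_crawl_{job_id}"] = str(_time.time())


def _finish_crawl_timer(job_id: str) -> str | None:
    # Mirrors the completed-status branch: read, format, delete.
    start = _fake_redis.pop(f"website_crawl_{job_id}", None)
    if start is None:
        return None  # key expired or job unknown: no timing reported
    return f"{abs(_time.time() - float(start)):.2f}"


_start_crawl_timer("job-1")
assert _finish_crawl_timer("job-1") is not None
# ------------------------------------------------------------------------------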
@@ -466,10 +466,10 @@ class WorkflowService:
         node_id: str,
     ) -> WorkflowNodeExecution:
         try:
-            node_instance, generator = invoke_node_fn()
+            node, node_events = invoke_node_fn()
 
             node_run_result: NodeRunResult | None = None
-            for event in generator:
+            for event in node_events:
                 if isinstance(event, RunCompletedEvent):
                     node_run_result = event.run_result

@@ -480,18 +480,18 @@ class WorkflowService:
         if not node_run_result:
             raise ValueError("Node run failed with no run result")
         # single step debug mode error handling return
-        if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
+        if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error:
             node_error_args: dict[str, Any] = {
                 "status": WorkflowNodeExecutionStatus.EXCEPTION,
                 "error": node_run_result.error,
                 "inputs": node_run_result.inputs,
-                "metadata": {"error_strategy": node_instance.node_data.error_strategy},
+                "metadata": {"error_strategy": node.error_strategy},
             }
-            if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+            if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                 node_run_result = NodeRunResult(
                     **node_error_args,
                     outputs={
-                        **node_instance.node_data.default_value_dict,
+                        **node.default_value_dict,
                         "error_message": node_run_result.error,
                         "error_type": node_run_result.error_type,
                     },
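The hunk above is where single-step debugging applies a node's error strategy: a FAILED result on a node configured to continue on error is downgraded to an EXCEPTION result, and under ErrorStrategy.DEFAULT_VALUE its outputs become the node's configured defaults plus the error details. A minimal self-contained sketch of that transformation (the enum and result class below are simplified stand-ins, not the real Dify types):

from dataclasses import dataclass, field
from enum import Enum
from typing import Any


class Status(Enum):  # stand-in for WorkflowNodeExecutionStatus
    FAILED = "failed"
    EXCEPTION = "exception"


@dataclass
class RunResult:  # stand-in for NodeRunResult
    status: Status
    error: str | None = None
    error_type: str | None = None
    outputs: dict[str, Any] = field(default_factory=dict)


def apply_default_value_strategy(result: RunResult, defaults: dict[str, Any]) -> RunResult:
    # FAILED + continue_on_error + DEFAULT_VALUE => EXCEPTION carrying defaults plus error info.
    if result.status is not Status.FAILED:
        return result
    return RunResult(
        status=Status.EXCEPTION,
        error=result.error,
        error_type=result.error_type,
        outputs={**defaults, "error_message": result.error, "error_type": result.error_type},
    )


failed = RunResult(status=Status.FAILED, error="boom", error_type="ValueError")
print(apply_default_value_strategy(failed, {"answer": "n/a"}).outputs)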
@@ -510,10 +510,10 @@ class WorkflowService:
             )
             error = node_run_result.error if not run_succeeded else None
         except WorkflowNodeRunFailedError as e:
-            node_instance = e.node_instance
+            node = e._node
             run_succeeded = False
             node_run_result = None
-            error = e.error
+            error = e._error
 
         # Create a NodeExecution domain model
         node_execution = WorkflowNodeExecution(

@@ -521,8 +521,8 @@ class WorkflowService:
             workflow_id="",  # This is a single-step execution, so no workflow ID
             index=1,
             node_id=node_id,
-            node_type=node_instance.node_type,
-            title=node_instance.node_data.title,
+            node_type=node.type_,
+            title=node.title,
             elapsed_time=time.perf_counter() - start_at,
             created_at=datetime.now(UTC).replace(tzinfo=None),
             finished_at=datetime.now(UTC).replace(tzinfo=None),
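Worth noting in the hunk above: elapsed_time is derived from time.perf_counter() rather than wall-clock time, which is the right tool for durations because perf_counter is monotonic and unaffected by system clock changes. A tiny standalone sketch of the pattern:

import time

start_at = time.perf_counter()
# ... the single-step node execution would run here ...
elapsed_time = time.perf_counter() - start_at  # seconds, monotonic
print(f"elapsed: {elapsed_time:.6f}s")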
@@ -31,7 +31,7 @@ class WorkspaceService:
         assert tenant_account_join is not None, "TenantAccountJoin not found"
         tenant_info["role"] = tenant_account_join.role
 
-        can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
+        can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo
 
         if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
             base_url = dify_config.FILES_URL

@@ -15,7 +15,7 @@ def get_mocked_fetch_model_config(
     mode: str,
     credentials: dict,
 ):
-    model_provider_factory = ModelProviderFactory(tenant_id="test_tenant")
+    model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
     model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
     provider_model_bundle = ProviderModelBundle(
         configuration=ProviderConfiguration(

@@ -66,6 +66,10 @@ def init_code_node(code_config: dict):
         config=code_config,
     )
 
+    # Initialize node data
+    if "data" in code_config:
+        node.init_node_data(code_config["data"])
+
     return node
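This hunk introduces the construction pattern that repeats through the remaining test fixtures: the node is built first, then its typed data is bound separately via init_node_data(config["data"]) instead of being parsed inside __init__. A minimal stand-in showing the two-phase shape (NodeSketch is an illustration, not a Dify class):

from typing import Any


class NodeSketch:
    # Simplified stand-in for the workflow node classes touched in these hunks.
    def __init__(self, id: str, config: dict[str, Any]):
        self.id = id
        self._config = config
        self._node_data: dict[str, Any] | None = None  # bound in a second phase

    def init_node_data(self, data: dict[str, Any]) -> None:
        self._node_data = data  # the real nodes validate into typed models here


code_config = {"id": "code-1", "data": {"title": "123", "type": "code"}}
node = NodeSketch(id="code-1", config=code_config)
if "data" in code_config:  # mirrors the guard added above
    node.init_node_data(code_config["data"])
assert node._node_data is not None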
@@ -234,10 +238,10 @@ def test_execute_code_output_validator_depth():
         "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
     }
 
-    node.node_data = cast(CodeNodeData, node.node_data)
+    node._node_data = cast(CodeNodeData, node._node_data)
 
     # validate
-    node._transform_result(result, node.node_data.outputs)
+    node._transform_result(result, node._node_data.outputs)
 
     # construct result
     result = {

@@ -250,7 +254,7 @@ def test_execute_code_output_validator_depth():
 
     # validate
     with pytest.raises(ValueError):
-        node._transform_result(result, node.node_data.outputs)
+        node._transform_result(result, node._node_data.outputs)
 
     # construct result
     result = {

@@ -263,7 +267,7 @@ def test_execute_code_output_validator_depth():
 
     # validate
     with pytest.raises(ValueError):
-        node._transform_result(result, node.node_data.outputs)
+        node._transform_result(result, node._node_data.outputs)
 
     # construct result
     result = {

@@ -276,7 +280,7 @@ def test_execute_code_output_validator_depth():
 
     # validate
     with pytest.raises(ValueError):
-        node._transform_result(result, node.node_data.outputs)
+        node._transform_result(result, node._node_data.outputs)
 
 
 def test_execute_code_output_object_list():

@@ -330,10 +334,10 @@ def test_execute_code_output_object_list():
         ]
     }
 
-    node.node_data = cast(CodeNodeData, node.node_data)
+    node._node_data = cast(CodeNodeData, node._node_data)
 
     # validate
-    node._transform_result(result, node.node_data.outputs)
+    node._transform_result(result, node._node_data.outputs)
 
     # construct result
     result = {

@@ -353,7 +357,7 @@ def test_execute_code_output_object_list():
 
     # validate
     with pytest.raises(ValueError):
-        node._transform_result(result, node.node_data.outputs)
+        node._transform_result(result, node._node_data.outputs)
 
 
 def test_execute_code_scientific_notation():

@@ -52,7 +52,7 @@ def init_http_node(config: dict):
     variable_pool.add(["a", "b123", "args1"], 1)
     variable_pool.add(["a", "b123", "args2"], 2)
 
-    return HttpRequestNode(
+    node = HttpRequestNode(
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,

@@ -60,6 +60,12 @@ def init_http_node(config: dict):
         config=config,
     )
 
+    # Initialize node data
+    if "data" in config:
+        node.init_node_data(config["data"])
+
+    return node
+
 
 @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
 def test_get(setup_http_mock):

@@ -2,15 +2,10 @@ import json
 import time
 import uuid
-from collections.abc import Generator
 from decimal import Decimal
 from unittest.mock import MagicMock, patch
 
-import pytest
-
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.llm_generator.output_parser.structured_output import _parse_structured_output
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.message_entities import AssistantPromptMessage
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.graph_engine.entities.graph import Graph

@@ -24,8 +19,6 @@ from models.enums import UserFrom
 from models.workflow import WorkflowType
 
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
-from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
-from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
 
 
 def init_llm_node(config: dict) -> LLMNode:

@@ -84,10 +77,14 @@ def init_llm_node(config: dict) -> LLMNode:
         config=config,
     )
 
+    # Initialize node data
+    if "data" in config:
+        node.init_node_data(config["data"])
+
     return node
 
 
-def test_execute_llm(flask_req_ctx):
+def test_execute_llm():
     node = init_llm_node(
         config={
             "id": "llm",

@@ -95,7 +92,7 @@ def test_execute_llm(flask_req_ctx):
             "title": "123",
             "type": "llm",
             "model": {
-                "provider": "langgenius/openai/openai",
+                "provider": "openai",
                 "name": "gpt-3.5-turbo",
                 "mode": "chat",
                 "completion_params": {},

@@ -114,53 +111,62 @@ def test_execute_llm(flask_req_ctx):
         },
     )
 
-    # Create a proper LLM result with real entities
-    mock_usage = LLMUsage(
-        prompt_tokens=30,
-        prompt_unit_price=Decimal("0.001"),
-        prompt_price_unit=Decimal(1000),
-        prompt_price=Decimal("0.00003"),
-        completion_tokens=20,
-        completion_unit_price=Decimal("0.002"),
-        completion_price_unit=Decimal(1000),
-        completion_price=Decimal("0.00004"),
-        total_tokens=50,
-        total_price=Decimal("0.00007"),
-        currency="USD",
-        latency=0.5,
-    )
-
-    mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
-
-    mock_llm_result = LLMResult(
-        model="gpt-3.5-turbo",
-        prompt_messages=[],
-        message=mock_message,
-        usage=mock_usage,
-    )
-
-    # Create a simple mock model instance that doesn't call real providers
-    mock_model_instance = MagicMock()
-    mock_model_instance.invoke_llm.return_value = mock_llm_result
-
-    # Create a simple mock model config with required attributes
-    mock_model_config = MagicMock()
-    mock_model_config.mode = "chat"
-    mock_model_config.provider = "langgenius/openai/openai"
-    mock_model_config.model = "gpt-3.5-turbo"
-    mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
-
-    # Mock the _fetch_model_config method
-    def mock_fetch_model_config_func(_node_data_model):
-        return mock_model_instance, mock_model_config
-
-    # Also mock ModelManager.get_model_instance to avoid database calls
-    def mock_get_model_instance(_self, **kwargs):
-        return mock_model_instance
+    db.session.close = MagicMock()
+
+    # Mock the _fetch_model_config to avoid database calls
+    def mock_fetch_model_config(**_kwargs):
+        from decimal import Decimal
+        from unittest.mock import MagicMock
+
+        from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+        from core.model_runtime.entities.message_entities import AssistantPromptMessage
+
+        # Create mock model instance
+        mock_model_instance = MagicMock()
+        mock_usage = LLMUsage(
+            prompt_tokens=30,
+            prompt_unit_price=Decimal("0.001"),
+            prompt_price_unit=Decimal(1000),
+            prompt_price=Decimal("0.00003"),
+            completion_tokens=20,
+            completion_unit_price=Decimal("0.002"),
+            completion_price_unit=Decimal(1000),
+            completion_price=Decimal("0.00004"),
+            total_tokens=50,
+            total_price=Decimal("0.00007"),
+            currency="USD",
+            latency=0.5,
+        )
+        mock_message = AssistantPromptMessage(content="Test response from mock")
+        mock_llm_result = LLMResult(
+            model="gpt-3.5-turbo",
+            prompt_messages=[],
+            message=mock_message,
+            usage=mock_usage,
+        )
+        mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+        # Create mock model config
+        mock_model_config = MagicMock()
+        mock_model_config.mode = "chat"
+        mock_model_config.provider = "openai"
+        mock_model_config.model = "gpt-3.5-turbo"
+        mock_model_config.parameters = {}
+
+        return mock_model_instance, mock_model_config
+
+    # Mock fetch_prompt_messages to avoid database calls
+    def mock_fetch_prompt_messages_1(**_kwargs):
+        from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
+
+        return [
+            SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
+            UserPromptMessage(content="what's the weather today?"),
+        ], []
 
     with (
-        patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
-        patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+        patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
+        patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
     ):
         # execute node
        result = node._run()
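The rewritten test above swaps per-instance monkeypatching for unittest.mock.patch.object on the class, so every lookup of _fetch_model_config or fetch_prompt_messages during the run resolves to the stub, whichever instance calls it. A self-contained sketch of that mocking style (class and method names are stand-ins):

from unittest.mock import patch


class FetcherSketch:  # stand-in for LLMNode
    def fetch_config(self) -> str:
        raise RuntimeError("would hit the database")

    def run(self) -> str:
        return f"got: {self.fetch_config()}"


def _stub_fetch_config(self) -> str:  # same signature as the method it replaces
    return "mocked"


with patch.object(FetcherSketch, "fetch_config", _stub_fetch_config):
    assert FetcherSketch().run() == "got: mocked"  # class-level patch covers all instances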
@@ -168,6 +174,9 @@ def test_execute_llm(flask_req_ctx):
 
         for item in result:
             if isinstance(item, RunCompletedEvent):
+                if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
+                    print(f"Error: {item.run_result.error}")
+                    print(f"Error type: {item.run_result.error_type}")
                 assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
                 assert item.run_result.process_data is not None
                 assert item.run_result.outputs is not None

@@ -175,8 +184,7 @@ def test_execute_llm(flask_req_ctx):
                 assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
 
 
-@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
-def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
+def test_execute_llm_with_jinja2():
     """
     Test execute LLM node with jinja2
     """

@@ -217,53 +225,60 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
     # Mock db.session.close()
     db.session.close = MagicMock()
 
-    # Create a proper LLM result with real entities
-    mock_usage = LLMUsage(
-        prompt_tokens=30,
-        prompt_unit_price=Decimal("0.001"),
-        prompt_price_unit=Decimal(1000),
-        prompt_price=Decimal("0.00003"),
-        completion_tokens=20,
-        completion_unit_price=Decimal("0.002"),
-        completion_price_unit=Decimal(1000),
-        completion_price=Decimal("0.00004"),
-        total_tokens=50,
-        total_price=Decimal("0.00007"),
-        currency="USD",
-        latency=0.5,
-    )
-
-    mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
-
-    mock_llm_result = LLMResult(
-        model="gpt-3.5-turbo",
-        prompt_messages=[],
-        message=mock_message,
-        usage=mock_usage,
-    )
-
-    # Create a simple mock model instance that doesn't call real providers
-    mock_model_instance = MagicMock()
-    mock_model_instance.invoke_llm.return_value = mock_llm_result
-
-    # Create a simple mock model config with required attributes
-    mock_model_config = MagicMock()
-    mock_model_config.mode = "chat"
-    mock_model_config.provider = "openai"
-    mock_model_config.model = "gpt-3.5-turbo"
-    mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
-
-    # Mock the _fetch_model_config method
-    def mock_fetch_model_config_func(_node_data_model):
-        return mock_model_instance, mock_model_config
-
-    # Also mock ModelManager.get_model_instance to avoid database calls
-    def mock_get_model_instance(_self, **kwargs):
-        return mock_model_instance
+    def mock_fetch_model_config(**_kwargs):
+        from decimal import Decimal
+        from unittest.mock import MagicMock
+
+        from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+        from core.model_runtime.entities.message_entities import AssistantPromptMessage
+
+        # Create mock model instance
+        mock_model_instance = MagicMock()
+        mock_usage = LLMUsage(
+            prompt_tokens=30,
+            prompt_unit_price=Decimal("0.001"),
+            prompt_price_unit=Decimal(1000),
+            prompt_price=Decimal("0.00003"),
+            completion_tokens=20,
+            completion_unit_price=Decimal("0.002"),
+            completion_price_unit=Decimal(1000),
+            completion_price=Decimal("0.00004"),
+            total_tokens=50,
+            total_price=Decimal("0.00007"),
+            currency="USD",
+            latency=0.5,
+        )
+        mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
+        mock_llm_result = LLMResult(
+            model="gpt-3.5-turbo",
+            prompt_messages=[],
+            message=mock_message,
+            usage=mock_usage,
+        )
+        mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+        # Create mock model config
+        mock_model_config = MagicMock()
+        mock_model_config.mode = "chat"
+        mock_model_config.provider = "openai"
+        mock_model_config.model = "gpt-3.5-turbo"
+        mock_model_config.parameters = {}
+
+        return mock_model_instance, mock_model_config
+
+    # Mock fetch_prompt_messages to avoid database calls
+    def mock_fetch_prompt_messages_2(**_kwargs):
+        from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
+
+        return [
+            SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
+            UserPromptMessage(content="what's the weather today?"),
+        ], []
 
     with (
-        patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
-        patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+        patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
+        patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
     ):
         # execute node
         result = node._run()
@@ -74,13 +74,15 @@ def init_parameter_extractor_node(config: dict):
     variable_pool.add(["a", "b123", "args1"], 1)
     variable_pool.add(["a", "b123", "args2"], 2)
 
-    return ParameterExtractorNode(
+    node = ParameterExtractorNode(
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
         graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
         config=config,
     )
+    node.init_node_data(config.get("data", {}))
+    return node
 
 
 def test_function_calling_parameter_extractor(setup_model_mock):

@@ -76,6 +76,7 @@ def test_execute_code(setup_code_executor_mock):
         graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
         config=config,
     )
+    node.init_node_data(config.get("data", {}))
 
     # execute node
     result = node._run()
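The fixtures above seed a VariablePool with selector paths such as ["a", "b123", "args1"] before the node runs. Conceptually the pool maps a selector path to a value; a dict-backed sketch of that lookup shape (illustrative only, the real VariablePool does far more):

from typing import Any


class VariablePoolSketch:
    def __init__(self) -> None:
        self._store: dict[tuple[str, ...], Any] = {}

    def add(self, selector: list[str], value: Any) -> None:
        self._store[tuple(selector)] = value  # paths are hashable as tuples

    def get(self, selector: list[str]) -> Any:
        return self._store.get(tuple(selector))


pool = VariablePoolSketch()
pool.add(["a", "b123", "args1"], 1)
pool.add(["a", "b123", "args2"], 2)
assert pool.get(["a", "b123", "args1"]) == 1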
@@ -58,21 +58,26 @@ def test_execute_answer():
     pool.add(["start", "weather"], "sunny")
     pool.add(["llm", "text"], "You are a helpful AI.")
 
+    node_config = {
+        "id": "answer",
+        "data": {
+            "title": "123",
+            "type": "answer",
+            "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+        },
+    }
+
     node = AnswerNode(
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
         graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
-        config={
-            "id": "answer",
-            "data": {
-                "title": "123",
-                "type": "answer",
-                "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
-            },
-        },
+        config=node_config,
     )
 
+    # Initialize node data
+    node.init_node_data(node_config["data"])
+
     # Mock db.session.close()
     db.session.close = MagicMock()

@@ -57,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch):
             ),
         ),
     )
 
+    node_config = {
+        "id": "1",
+        "data": data.model_dump(),
+    }
+
     node = HttpRequestNode(
         id="1",
-        config={
-            "id": "1",
-            "data": data.model_dump(),
-        },
+        config=node_config,
         graph_init_params=GraphInitParams(
             tenant_id="1",
             app_id="1",

@@ -90,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch):
             start_at=0,
         ),
     )
+
+    # Initialize node data
+    node.init_node_data(node_config["data"])
     monkeypatch.setattr(
         "core.workflow.nodes.http_request.executor.file_manager.download",
         lambda *args, **kwargs: b"test",

@@ -145,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch):
             ),
         ),
     )
 
+    node_config = {
+        "id": "1",
+        "data": data.model_dump(),
+    }
+
     node = HttpRequestNode(
         id="1",
-        config={
-            "id": "1",
-            "data": data.model_dump(),
-        },
+        config=node_config,
         graph_init_params=GraphInitParams(
             tenant_id="1",
             app_id="1",

@@ -178,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch):
             start_at=0,
         ),
     )
+
+    # Initialize node data
+    node.init_node_data(node_config["data"])
+
     monkeypatch.setattr(
         "core.workflow.nodes.http_request.executor.file_manager.download",
         lambda *args, **kwargs: b"test",

@@ -257,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
             ),
     )
 
+    node_config = {
+        "id": "1",
+        "data": data.model_dump(),
+    }
+
     node = HttpRequestNode(
         id="1",
-        config={
-            "id": "1",
-            "data": data.model_dump(),
-        },
+        config=node_config,
         graph_init_params=GraphInitParams(
             tenant_id="1",
             app_id="1",

@@ -291,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
         ),
     )
 
+    # Initialize node data
+    node.init_node_data(node_config["data"])
+
     monkeypatch.setattr(
         "core.workflow.nodes.http_request.executor.file_manager.download",
         lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",

@@ -50,13 +50,15 @@ def init_tool_node(config: dict):
         conversation_variables=[],
     )
 
-    return ToolNode(
+    node = ToolNode(
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
         graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
         config=config,
     )
+    node.init_node_data(config.get("data", {}))
+    return node
 
 
 def test_tool_variable_invoke():
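The HTTP-node tests above stub the file download with monkeypatch.setattr(..., lambda *args, **kwargs: b"test"), replacing the module-level function for the duration of one test. The same idea in a self-contained pytest sketch (the module and function here are local stand-ins for file_manager.download):

import sys

import pytest


def download(url: str) -> bytes:  # stand-in for file_manager.download
    raise RuntimeError("would hit the network")


def fetch_preview(url: str) -> int:
    return len(download(url))


def test_fetch_preview(monkeypatch: pytest.MonkeyPatch) -> None:
    # Replace this module's download for the duration of the test only;
    # monkeypatch undoes the change automatically afterwards.
    monkeypatch.setattr(sys.modules[__name__], "download", lambda *args, **kwargs: b"test")
    assert fetch_preview("https://example.com") == 4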
@@ -162,25 +162,30 @@ def test_run():
     )
     pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
 
+    node_config = {
+        "data": {
+            "iterator_selector": ["pe", "list_output"],
+            "output_selector": ["tt", "output"],
+            "output_type": "array[string]",
+            "startNodeType": "template-transform",
+            "start_node_id": "tt",
+            "title": "迭代",
+            "type": "iteration",
+        },
+        "id": "iteration-1",
+    }
+
     iteration_node = IterationNode(
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
         graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
-        config={
-            "data": {
-                "iterator_selector": ["pe", "list_output"],
-                "output_selector": ["tt", "output"],
-                "output_type": "array[string]",
-                "startNodeType": "template-transform",
-                "start_node_id": "tt",
-                "title": "迭代",
-                "type": "iteration",
-            },
-            "id": "iteration-1",
-        },
+        config=node_config,
     )
 
+    # Initialize node data
+    iteration_node.init_node_data(node_config["data"])
+
     def tt_generator(self):
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,

@@ -379,25 +384,30 @@ def test_run_parallel():
     )
     pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
 
+    node_config = {
+        "data": {
+            "iterator_selector": ["pe", "list_output"],
+            "output_selector": ["tt", "output"],
+            "output_type": "array[string]",
+            "startNodeType": "template-transform",
+            "start_node_id": "iteration-start",
+            "title": "迭代",
+            "type": "iteration",
+        },
+        "id": "iteration-1",
+    }
+
     iteration_node = IterationNode(
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
         graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
-        config={
-            "data": {
-                "iterator_selector": ["pe", "list_output"],
-                "output_selector": ["tt", "output"],
-                "output_type": "array[string]",
-                "startNodeType": "template-transform",
-                "start_node_id": "iteration-start",
-                "title": "迭代",
-                "type": "iteration",
-            },
-            "id": "iteration-1",
-        },
+        config=node_config,
     )
 
+    # Initialize node data
+    iteration_node.init_node_data(node_config["data"])
+
     def tt_generator(self):
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,

@@ -595,45 +605,55 @@ def test_iteration_run_in_parallel_mode():
     )
     pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
 
+    parallel_node_config = {
+        "data": {
+            "iterator_selector": ["pe", "list_output"],
+            "output_selector": ["tt", "output"],
+            "output_type": "array[string]",
+            "startNodeType": "template-transform",
+            "start_node_id": "iteration-start",
+            "title": "迭代",
+            "type": "iteration",
+            "is_parallel": True,
+        },
+        "id": "iteration-1",
+    }
+
     parallel_iteration_node = IterationNode(
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
         graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
-        config={
-            "data": {
-                "iterator_selector": ["pe", "list_output"],
-                "output_selector": ["tt", "output"],
-                "output_type": "array[string]",
-                "startNodeType": "template-transform",
-                "start_node_id": "iteration-start",
-                "title": "迭代",
-                "type": "iteration",
-                "is_parallel": True,
-            },
-            "id": "iteration-1",
-        },
+        config=parallel_node_config,
    )
 
+    # Initialize node data
+    parallel_iteration_node.init_node_data(parallel_node_config["data"])
+    sequential_node_config = {
+        "data": {
+            "iterator_selector": ["pe", "list_output"],
+            "output_selector": ["tt", "output"],
+            "output_type": "array[string]",
+            "startNodeType": "template-transform",
+            "start_node_id": "iteration-start",
+            "title": "迭代",
+            "type": "iteration",
+            "is_parallel": True,
+        },
+        "id": "iteration-1",
+    }
+
     sequential_iteration_node = IterationNode(
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
         graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
-        config={
-            "data": {
-                "iterator_selector": ["pe", "list_output"],
-                "output_selector": ["tt", "output"],
-                "output_type": "array[string]",
-                "startNodeType": "template-transform",
-                "start_node_id": "iteration-start",
-                "title": "迭代",
-                "type": "iteration",
-                "is_parallel": True,
-            },
-            "id": "iteration-1",
-        },
+        config=sequential_node_config,
     )
 
+    # Initialize node data
+    sequential_iteration_node.init_node_data(sequential_node_config["data"])
+
     def tt_generator(self):
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,

@@ -645,8 +665,8 @@ def test_iteration_run_in_parallel_mode():
     # execute node
     parallel_result = parallel_iteration_node._run()
     sequential_result = sequential_iteration_node._run()
-    assert parallel_iteration_node.node_data.parallel_nums == 10
-    assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED
+    assert parallel_iteration_node._node_data.parallel_nums == 10
+    assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
     count = 0
     parallel_arr = []
     sequential_arr = []

@@ -818,26 +838,31 @@ def test_iteration_run_error_handle():
         environment_variables=[],
     )
     pool.add(["pe", "list_output"], ["1", "1"])
+    error_node_config = {
+        "data": {
+            "iterator_selector": ["pe", "list_output"],
+            "output_selector": ["tt", "output"],
+            "output_type": "array[string]",
+            "startNodeType": "template-transform",
+            "start_node_id": "iteration-start",
+            "title": "iteration",
+            "type": "iteration",
+            "is_parallel": True,
+            "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
+        },
+        "id": "iteration-1",
+    }
 
     iteration_node = IterationNode(
         id=str(uuid.uuid4()),
         graph_init_params=init_params,
         graph=graph,
         graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
-        config={
-            "data": {
-                "iterator_selector": ["pe", "list_output"],
-                "output_selector": ["tt", "output"],
-                "output_type": "array[string]",
-                "startNodeType": "template-transform",
-                "start_node_id": "iteration-start",
-                "title": "iteration",
-                "type": "iteration",
-                "is_parallel": True,
-                "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
-            },
-            "id": "iteration-1",
-        },
+        config=error_node_config,
     )
 
+    # Initialize node data
+    iteration_node.init_node_data(error_node_config["data"])
     # execute continue on error node
     result = iteration_node._run()
     result_arr = []

@@ -851,7 +876,7 @@ def test_iteration_run_error_handle():
 
     assert count == 14
     # execute remove abnormal output
-    iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
+    iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
     result = iteration_node._run()
     count = 0
     for item in result:
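The two iteration error modes exercised above differ in what the output list looks like: CONTINUE_ON_ERROR keeps a placeholder for each failed iteration, while REMOVE_ABNORMAL_OUTPUT drops failed iterations entirely. A toy sketch of that difference (assumption: deliberately simplified semantics; the real IterationNode streams events rather than returning lists):

from enum import Enum


class ErrorHandleModeSketch(Enum):
    CONTINUE_ON_ERROR = "continue-on-error"
    REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"


def collect_outputs(results: list[tuple[bool, str | None]], mode: ErrorHandleModeSketch) -> list[str | None]:
    if mode is ErrorHandleModeSketch.CONTINUE_ON_ERROR:
        return [value if ok else None for ok, value in results]  # keep a None placeholder
    return [value for ok, value in results if ok]  # drop abnormal outputs entirely


runs = [(True, "out-1"), (False, None), (True, "out-3")]
assert collect_outputs(runs, ErrorHandleModeSketch.CONTINUE_ON_ERROR) == ["out-1", None, "out-3"]
assert collect_outputs(runs, ErrorHandleModeSketch.REMOVE_ABNORMAL_OUTPUT) == ["out-1", "out-3"]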
@@ -119,17 +119,20 @@ def llm_node(
     llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
 ) -> LLMNode:
     mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
+    node_config = {
+        "id": "1",
+        "data": llm_node_data.model_dump(),
+    }
     node = LLMNode(
         id="1",
-        config={
-            "id": "1",
-            "data": llm_node_data.model_dump(),
-        },
+        config=node_config,
         graph_init_params=graph_init_params,
         graph=graph,
         graph_runtime_state=graph_runtime_state,
         llm_file_saver=mock_file_saver,
     )
+    # Initialize node data
+    node.init_node_data(node_config["data"])
     return node

@@ -488,7 +491,7 @@ def test_handle_list_messages_basic(llm_node):
     variable_pool = llm_node.graph_runtime_state.variable_pool
     vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
 
-    result = llm_node._handle_list_messages(
+    result = llm_node.handle_list_messages(
         messages=messages,
         context=context,
         jinja2_variables=jinja2_variables,

@@ -506,17 +509,20 @@ def llm_node_for_multimodal(
     llm_node_data, graph_init_params, graph, graph_runtime_state
 ) -> tuple[LLMNode, LLMFileSaver]:
     mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
+    node_config = {
+        "id": "1",
+        "data": llm_node_data.model_dump(),
+    }
     node = LLMNode(
         id="1",
-        config={
-            "id": "1",
-            "data": llm_node_data.model_dump(),
-        },
+        config=node_config,
         graph_init_params=graph_init_params,
         graph=graph,
         graph_runtime_state=graph_runtime_state,
         llm_file_saver=mock_file_saver,
     )
+    # Initialize node data
+    node.init_node_data(node_config["data"])
     return node, mock_file_saver

@@ -540,7 +546,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
             size=9,
         )
         mock_file_saver.save_binary_string.return_value = mock_file
-        file = llm_node._save_multimodal_image_output(content=content)
+        file = llm_node.save_multimodal_image_output(
+            content=content,
+            file_saver=mock_file_saver,
+        )
+        # Manually append to _file_outputs since the static method doesn't do it
+        llm_node._file_outputs.append(file)
         assert llm_node._file_outputs == [mock_file]
         assert file == mock_file
         mock_file_saver.save_binary_string.assert_called_once_with(

@@ -566,7 +577,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
             size=9,
         )
         mock_file_saver.save_remote_url.return_value = mock_file
-        file = llm_node._save_multimodal_image_output(content=content)
+        file = llm_node.save_multimodal_image_output(
+            content=content,
+            file_saver=mock_file_saver,
+        )
+        # Manually append to _file_outputs since the static method doesn't do it
+        llm_node._file_outputs.append(file)
         assert llm_node._file_outputs == [mock_file]
         assert file == mock_file
         mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)

@@ -582,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
 class TestSaveMultimodalOutputAndConvertResultToMarkdown:
     def test_str_content(self, llm_node_for_multimodal):
         llm_node, mock_file_saver = llm_node_for_multimodal
-        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
+        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+            contents="hello world", file_saver=mock_file_saver, file_outputs=[]
+        )
         assert list(gen) == ["hello world"]
         mock_file_saver.save_binary_string.assert_not_called()
         mock_file_saver.save_remote_url.assert_not_called()

@@ -590,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
     def test_text_prompt_message_content(self, llm_node_for_multimodal):
         llm_node, mock_file_saver = llm_node_for_multimodal
         gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
-            [TextPromptMessageContent(data="hello world")]
+            contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, file_outputs=[]
         )
         assert list(gen) == ["hello world"]
         mock_file_saver.save_binary_string.assert_not_called()
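Judging by the assertions in these tests, _save_multimodal_output_and_convert_result_to_markdown is a generator: strings and text contents pass through as-is, None yields nothing, and unknown types fall back to str(). A simplified stand-in that satisfies the same assertions (the real method additionally saves image contents through the file saver):

from collections.abc import Generator
from typing import Any


def to_markdown(contents: Any) -> Generator[str, None, None]:
    # Simplified mirror of the generator behavior exercised by the tests above.
    if contents is None:
        return
    if isinstance(contents, str):
        yield contents
        return
    if isinstance(contents, (list, tuple)):
        for item in contents:
            yield item if isinstance(item, str) else str(item)
        return
    yield str(contents)  # unknown content type: fall back to its repr


assert list(to_markdown("hello world")) == ["hello world"]
assert list(to_markdown(None)) == []
assert list(to_markdown(frozenset(["hello world"]))) == ["frozenset({'hello world'})"]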
@ -616,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
|
|||
)
|
||||
mock_file_saver.save_binary_string.return_value = mock_saved_file
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
|
||||
[
|
||||
contents=[
|
||||
ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data=image_b64_data,
|
||||
mime_type="image/png",
|
||||
)
|
||||
]
|
||||
],
|
||||
file_saver=mock_file_saver,
|
||||
file_outputs=llm_node._file_outputs,
|
||||
)
|
||||
yielded_strs = list(gen)
|
||||
assert len(yielded_strs) == 1
|
||||
|
|
@ -645,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
|
|||
|
||||
def test_unknown_content_type(self, llm_node_for_multimodal):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
|
||||
contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[]
|
||||
)
|
||||
assert list(gen) == ["frozenset({'hello world'})"]
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
mock_file_saver.save_remote_url.assert_not_called()
|
||||
|
||||
def test_unknown_item_type(self, llm_node_for_multimodal):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
|
||||
contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[]
|
||||
)
|
||||
assert list(gen) == ["frozenset({'hello world'})"]
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
mock_file_saver.save_remote_url.assert_not_called()
|
||||
|
||||
def test_none_content(self, llm_node_for_multimodal):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
|
||||
contents=None, file_saver=mock_file_saver, file_outputs=[]
|
||||
)
|
||||
assert list(gen) == []
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
mock_file_saver.save_remote_url.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -61,21 +61,26 @@ def test_execute_answer():
|
|||
variable_pool.add(["start", "weather"], "sunny")
|
||||
variable_pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
|
||||
node_config = {
|
||||
"id": "answer",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "answer",
|
||||
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
},
|
||||
}
|
||||
|
||||
node = AnswerNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"id": "answer",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "answer",
|
||||
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
},
|
||||
},
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
|
|
|
|||
|
|
@ -27,13 +27,17 @@ def document_extractor_node():
|
|||
title="Test Document Extractor",
|
||||
variable_selector=["node_id", "variable_name"],
|
||||
)
|
||||
return DocumentExtractorNode(
|
||||
node_config = {"id": "test_node_id", "data": node_data.model_dump()}
|
||||
node = DocumentExtractorNode(
|
||||
id="test_node_id",
|
||||
config={"id": "test_node_id", "data": node_data.model_dump()},
|
||||
config=node_config,
|
||||
graph_init_params=Mock(),
|
||||
graph=Mock(),
|
||||
graph_runtime_state=Mock(),
|
||||
)
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -57,57 +57,62 @@ def test_execute_if_else_result_true():
    pool.add(["start", "null"], None)
    pool.add(["start", "not_null"], "1212")

    node_config = {
        "id": "if-else",
        "data": {
            "title": "123",
            "type": "if-else",
            "logical_operator": "and",
            "conditions": [
                {
                    "comparison_operator": "contains",
                    "variable_selector": ["start", "array_contains"],
                    "value": "ab",
                },
                {
                    "comparison_operator": "not contains",
                    "variable_selector": ["start", "array_not_contains"],
                    "value": "ab",
                },
                {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
                {
                    "comparison_operator": "not contains",
                    "variable_selector": ["start", "not_contains"],
                    "value": "ab",
                },
                {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
                {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
                {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
                {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
                {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
                {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
                {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
                {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"},
                {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
                {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
                {
                    "comparison_operator": "≥",
                    "variable_selector": ["start", "greater_than_or_equal"],
                    "value": "22",
                },
                {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
                {"comparison_operator": "null", "variable_selector": ["start", "null"]},
                {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
            ],
        },
    }

    node = IfElseNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
        config={
            "id": "if-else",
            "data": {
                "title": "123",
                "type": "if-else",
                "logical_operator": "and",
                "conditions": [
                    {
                        "comparison_operator": "contains",
                        "variable_selector": ["start", "array_contains"],
                        "value": "ab",
                    },
                    {
                        "comparison_operator": "not contains",
                        "variable_selector": ["start", "array_not_contains"],
                        "value": "ab",
                    },
                    {"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
                    {
                        "comparison_operator": "not contains",
                        "variable_selector": ["start", "not_contains"],
                        "value": "ab",
                    },
                    {"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
                    {"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
                    {"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
                    {"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
                    {"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
                    {"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
                    {"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
                    {"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"},
                    {"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
                    {"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
                    {
                        "comparison_operator": "≥",
                        "variable_selector": ["start", "greater_than_or_equal"],
                        "value": "22",
                    },
                    {"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
                    {"comparison_operator": "null", "variable_selector": ["start", "null"]},
                    {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
                ],
            },
        },
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Mock db.session.close()
    db.session.close = MagicMock()
@ -162,33 +167,38 @@ def test_execute_if_else_result_false():
    pool.add(["start", "array_contains"], ["1ab", "def"])
    pool.add(["start", "array_not_contains"], ["ab", "def"])

    node_config = {
        "id": "if-else",
        "data": {
            "title": "123",
            "type": "if-else",
            "logical_operator": "or",
            "conditions": [
                {
                    "comparison_operator": "contains",
                    "variable_selector": ["start", "array_contains"],
                    "value": "ab",
                },
                {
                    "comparison_operator": "not contains",
                    "variable_selector": ["start", "array_not_contains"],
                    "value": "ab",
                },
            ],
        },
    }

    node = IfElseNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
        config={
            "id": "if-else",
            "data": {
                "title": "123",
                "type": "if-else",
                "logical_operator": "or",
                "conditions": [
                    {
                        "comparison_operator": "contains",
                        "variable_selector": ["start", "array_contains"],
                        "value": "ab",
                    },
                    {
                        "comparison_operator": "not contains",
                        "variable_selector": ["start", "array_not_contains"],
                        "value": "ab",
                    },
                ],
            },
        },
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Mock db.session.close()
    db.session.close = MagicMock()
@ -228,17 +238,22 @@ def test_array_file_contains_file_name():
        ],
    )

    node_config = {
        "id": "if-else",
        "data": node_data.model_dump(),
    }

    node = IfElseNode(
        id=str(uuid.uuid4()),
        graph_init_params=Mock(),
        graph=Mock(),
        graph_runtime_state=Mock(),
        config={
            "id": "if-else",
            "data": node_data.model_dump(),
        },
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
        value=[
            File(
@ -33,16 +33,19 @@ def list_operator_node():
        "title": "Test Title",
    }
    node_data = ListOperatorNodeData(**config)
    node_config = {
        "id": "test_node_id",
        "data": node_data.model_dump(),
    }
    node = ListOperatorNode(
        id="test_node_id",
        config={
            "id": "test_node_id",
            "data": node_data.model_dump(),
        },
        config=node_config,
        graph_init_params=MagicMock(),
        graph=MagicMock(),
        graph_runtime_state=MagicMock(),
    )
    # Initialize node data
    node.init_node_data(node_config["data"])
    node.graph_runtime_state = MagicMock()
    node.graph_runtime_state.variable_pool = MagicMock()
    return node
@ -38,12 +38,13 @@ def _create_tool_node():
        system_variables=SystemVariable.empty(),
        user_inputs={},
    )
    node_config = {
        "id": "1",
        "data": data.model_dump(),
    }
    node = ToolNode(
        id="1",
        config={
            "id": "1",
            "data": data.model_dump(),
        },
        config=node_config,
        graph_init_params=GraphInitParams(
            tenant_id="1",
            app_id="1",

@ -71,6 +72,8 @@ def _create_tool_node():
            start_at=0,
        ),
    )
    # Initialize node data
    node.init_node_data(node_config["data"])
    return node
@ -82,23 +82,28 @@ def test_overwrite_string_variable():
    mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
    mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

    node_config = {
        "id": "node_id",
        "data": {
            "title": "test",
            "assigned_variable_selector": ["conversation", conversation_variable.name],
            "write_mode": WriteMode.OVER_WRITE.value,
            "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
        },
    }

    node = VariableAssignerNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config={
            "id": "node_id",
            "data": {
                "title": "test",
                "assigned_variable_selector": ["conversation", conversation_variable.name],
                "write_mode": WriteMode.OVER_WRITE.value,
                "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
            },
        },
        config=node_config,
        conv_var_updater_factory=mock_conv_var_updater_factory,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    list(node.run())
    expected_var = StringVariable(
        id=conversation_variable.id,
@ -178,23 +183,28 @@ def test_append_variable_to_array():
    mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
    mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

    node_config = {
        "id": "node_id",
        "data": {
            "title": "test",
            "assigned_variable_selector": ["conversation", conversation_variable.name],
            "write_mode": WriteMode.APPEND.value,
            "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
        },
    }

    node = VariableAssignerNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config={
            "id": "node_id",
            "data": {
                "title": "test",
                "assigned_variable_selector": ["conversation", conversation_variable.name],
                "write_mode": WriteMode.APPEND.value,
                "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
            },
        },
        config=node_config,
        conv_var_updater_factory=mock_conv_var_updater_factory,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    list(node.run())
    expected_value = list(conversation_variable.value)
    expected_value.append(input_variable.value)
@ -265,23 +275,28 @@ def test_clear_array():
    mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
    mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)

    node_config = {
        "id": "node_id",
        "data": {
            "title": "test",
            "assigned_variable_selector": ["conversation", conversation_variable.name],
            "write_mode": WriteMode.CLEAR.value,
            "input_variable_selector": [],
        },
    }

    node = VariableAssignerNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config={
            "id": "node_id",
            "data": {
                "title": "test",
                "assigned_variable_selector": ["conversation", conversation_variable.name],
                "write_mode": WriteMode.CLEAR.value,
                "input_variable_selector": [],
            },
        },
        config=node_config,
        conv_var_updater_factory=mock_conv_var_updater_factory,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    list(node.run())
    expected_var = ArrayStringVariable(
        id=conversation_variable.id,
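Taken together, the three tests above cover the assigner's v1 write modes. A minimal sketch of those semantics, assuming overwrite replaces the stored value, append extends a list value, and clear resets it (illustrative only; the enum string values here are placeholders, and the real node operates on typed variable segments):

from enum import Enum
from typing import Any


class WriteMode(Enum):
    OVER_WRITE = "over-write"  # placeholder values, not the real ones
    APPEND = "append"
    CLEAR = "clear"


def assign(store: dict[str, Any], name: str, mode: WriteMode, value: Any = None) -> None:
    if mode is WriteMode.OVER_WRITE:
        store[name] = value
    elif mode is WriteMode.APPEND:
        store[name] = [*store.get(name, []), value]
    elif mode is WriteMode.CLEAR:
        store[name] = []


store: dict[str, Any] = {"memo": ["a"]}
assign(store, "memo", WriteMode.APPEND, "b")
assert store["memo"] == ["a", "b"]
assign(store, "memo", WriteMode.OVER_WRITE, "c")
assert store["memo"] == "c"
assign(store, "memo", WriteMode.CLEAR)
assert store["memo"] == []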
@ -115,28 +115,33 @@ def test_remove_first_from_array():
        conversation_variables=[conversation_variable],
    )

    node_config = {
        "id": "node_id",
        "data": {
            "title": "test",
            "version": "2",
            "items": [
                {
                    "variable_selector": ["conversation", conversation_variable.name],
                    "input_type": InputType.VARIABLE,
                    "operation": Operation.REMOVE_FIRST,
                    "value": None,
                }
            ],
        },
    }

    node = VariableAssignerNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config={
            "id": "node_id",
            "data": {
                "title": "test",
                "version": "2",
                "items": [
                    {
                        "variable_selector": ["conversation", conversation_variable.name],
                        "input_type": InputType.VARIABLE,
                        "operation": Operation.REMOVE_FIRST,
                        "value": None,
                    }
                ],
            },
        },
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Skip the mock assertion since we're in a test environment
    # Print the variable before running
    print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
@ -202,28 +207,33 @@ def test_remove_last_from_array():
        conversation_variables=[conversation_variable],
    )

    node_config = {
        "id": "node_id",
        "data": {
            "title": "test",
            "version": "2",
            "items": [
                {
                    "variable_selector": ["conversation", conversation_variable.name],
                    "input_type": InputType.VARIABLE,
                    "operation": Operation.REMOVE_LAST,
                    "value": None,
                }
            ],
        },
    }

    node = VariableAssignerNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config={
            "id": "node_id",
            "data": {
                "title": "test",
                "version": "2",
                "items": [
                    {
                        "variable_selector": ["conversation", conversation_variable.name],
                        "input_type": InputType.VARIABLE,
                        "operation": Operation.REMOVE_LAST,
                        "value": None,
                    }
                ],
            },
        },
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Skip the mock assertion since we're in a test environment
    list(node.run())
@ -281,28 +291,33 @@ def test_remove_first_from_empty_array():
        conversation_variables=[conversation_variable],
    )

    node_config = {
        "id": "node_id",
        "data": {
            "title": "test",
            "version": "2",
            "items": [
                {
                    "variable_selector": ["conversation", conversation_variable.name],
                    "input_type": InputType.VARIABLE,
                    "operation": Operation.REMOVE_FIRST,
                    "value": None,
                }
            ],
        },
    }

    node = VariableAssignerNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config={
            "id": "node_id",
            "data": {
                "title": "test",
                "version": "2",
                "items": [
                    {
                        "variable_selector": ["conversation", conversation_variable.name],
                        "input_type": InputType.VARIABLE,
                        "operation": Operation.REMOVE_FIRST,
                        "value": None,
                    }
                ],
            },
        },
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Skip the mock assertion since we're in a test environment
    list(node.run())
@ -360,28 +375,33 @@ def test_remove_last_from_empty_array():
        conversation_variables=[conversation_variable],
    )

    node_config = {
        "id": "node_id",
        "data": {
            "title": "test",
            "version": "2",
            "items": [
                {
                    "variable_selector": ["conversation", conversation_variable.name],
                    "input_type": InputType.VARIABLE,
                    "operation": Operation.REMOVE_LAST,
                    "value": None,
                }
            ],
        },
    }

    node = VariableAssignerNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config={
            "id": "node_id",
            "data": {
                "title": "test",
                "version": "2",
                "items": [
                    {
                        "variable_selector": ["conversation", conversation_variable.name],
                        "input_type": InputType.VARIABLE,
                        "operation": Operation.REMOVE_LAST,
                        "value": None,
                    }
                ],
            },
        },
        config=node_config,
    )

    # Initialize node data
    node.init_node_data(node_config["data"])

    # Skip the mock assertion since we're in a test environment
    list(node.run())
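The four tests above exercise the version-2, items-based form of the assigner, including the edge case that removing from an empty array must not raise. A sketch of that semantics, assuming REMOVE_FIRST and REMOVE_LAST are simply no-ops on an empty list (illustrative; the enum values are placeholders and the real node dispatches on typed Operation members):

from enum import Enum


class Operation(Enum):
    REMOVE_FIRST = "remove-first"  # placeholder values
    REMOVE_LAST = "remove-last"


def apply_operation(values: list, operation: Operation) -> list:
    # Removing from an empty array is a no-op rather than an error, which is
    # exactly what the two "empty array" tests assert by running the node
    # and expecting no exception.
    if not values:
        return values
    if operation is Operation.REMOVE_FIRST:
        return values[1:]
    if operation is Operation.REMOVE_LAST:
        return values[:-1]
    return values


assert apply_operation([], Operation.REMOVE_FIRST) == []
assert apply_operation(["a", "b"], Operation.REMOVE_LAST) == ["a"]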
@ -80,15 +80,12 @@ def real_workflow_system_variables():
@pytest.fixture
def mock_node_execution_repository():
    repo = MagicMock(spec=WorkflowNodeExecutionRepository)
    repo.get_by_node_execution_id.return_value = None
    repo.get_running_executions.return_value = []
    return repo


@pytest.fixture
def mock_workflow_execution_repository():
    repo = MagicMock(spec=WorkflowExecutionRepository)
    repo.get.return_value = None
    return repo
@ -217,8 +214,8 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
        started_at=datetime.now(UTC).replace(tzinfo=None),
    )

    # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
    workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
    # Pre-populate the cache with the workflow execution
    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution

    # Call the method
    result = workflow_cycle_manager.handle_workflow_run_success(

@ -251,11 +248,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
        started_at=datetime.now(UTC).replace(tzinfo=None),
    )

    # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
    workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
    # Pre-populate the cache with the workflow execution
    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution

    # Mock get_running_executions to return an empty list
    workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
    # No running node executions in cache (empty cache)

    # Call the method
    result = workflow_cycle_manager.handle_workflow_run_failed(

@ -289,8 +285,8 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
        started_at=datetime.now(UTC).replace(tzinfo=None),
    )

    # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
    workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
    # Pre-populate the cache with the workflow execution
    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution

    # Create a mock event
    event = MagicMock(spec=QueueNodeStartedEvent)

@ -342,8 +338,8 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
        started_at=datetime.now(UTC).replace(tzinfo=None),
    )

    # Mock the repository get method to return the real execution
    workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
    # Pre-populate the cache with the workflow execution
    workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution

    # Call the method
    result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")

@ -351,11 +347,13 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
    # Verify the result
    assert result == workflow_execution

    # Test error case
    workflow_cycle_manager._workflow_execution_repository.get.return_value = None
    # Test error case - clear cache
    workflow_cycle_manager._workflow_execution_cache.clear()

    # Expect an error when execution is not found
    with pytest.raises(ValueError):
    from core.app.task_pipeline.exc import WorkflowRunNotFoundError

    with pytest.raises(WorkflowRunNotFoundError):
        workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")

@ -384,8 +382,8 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
        created_at=datetime.now(UTC).replace(tzinfo=None),
    )

    # Mock the repository to return the node execution
    workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
    # Pre-populate the cache with the node execution
    workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution

    # Call the method
    result = workflow_cycle_manager.handle_workflow_node_execution_success(

@ -414,8 +412,8 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
        started_at=datetime.now(UTC).replace(tzinfo=None),
    )

    # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
    workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
    # Pre-populate the cache with the workflow execution
    workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution

    # Call the method
    result = workflow_cycle_manager.handle_workflow_run_partial_success(

@ -462,8 +460,8 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
        created_at=datetime.now(UTC).replace(tzinfo=None),
    )

    # Mock the repository to return the node execution
    workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
    # Pre-populate the cache with the node execution
    workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution

    # Call the method
    result = workflow_cycle_manager.handle_workflow_node_execution_failed(
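These hunks all make the same change: instead of stubbing the repositories, the tests pre-populate the manager's in-memory caches, and a missing entry now raises WorkflowRunNotFoundError rather than a bare ValueError. A minimal sketch of the cache-first lookup, assuming a plain dict cache; the exception class body here is an assumption standing in for core.app.task_pipeline.exc.WorkflowRunNotFoundError:

class WorkflowRunNotFoundError(Exception):
    """Stand-in for core.app.task_pipeline.exc.WorkflowRunNotFoundError."""


class WorkflowCycleManager:
    def __init__(self) -> None:
        # Keyed by execution id; replaces repeated repository .get() calls
        # for executions that are already in flight.
        self._workflow_execution_cache: dict[str, object] = {}
        self._node_execution_cache: dict[str, object] = {}

    def _get_workflow_execution_or_raise_error(self, execution_id: str) -> object:
        execution = self._workflow_execution_cache.get(execution_id)
        if execution is None:
            raise WorkflowRunNotFoundError(f"Workflow run not found: {execution_id}")
        return execution


manager = WorkflowCycleManager()
manager._workflow_execution_cache["run-1"] = object()  # what the tests pre-populate
manager._get_workflow_execution_or_raise_error("run-1")  # cache hit
# A lookup for an id that was never cached would raise WorkflowRunNotFoundError.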
@ -137,37 +137,6 @@ def test_save_with_existing_tenant_id(repository, session):
    session_obj.merge.assert_called_once_with(modified_execution)


def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
    """Test get_by_node_execution_id method."""
    session_obj, _ = session
    # Set up mock
    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
    mock_stmt = mocker.MagicMock()
    mock_select.return_value = mock_stmt
    mock_stmt.where.return_value = mock_stmt

    # Create a properly configured mock execution
    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
    configure_mock_execution(mock_execution)
    session_obj.scalar.return_value = mock_execution

    # Create a mock domain model to be returned by _to_domain_model
    mock_domain_model = mocker.MagicMock()
    # Mock the _to_domain_model method to return our mock domain model
    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)

    # Call method
    result = repository.get_by_node_execution_id("test-node-execution-id")

    # Assert select was called with correct parameters
    mock_select.assert_called_once()
    session_obj.scalar.assert_called_once_with(mock_stmt)
    # Assert _to_domain_model was called with the mock execution
    repository._to_domain_model.assert_called_once_with(mock_execution)
    # Assert the result is our mock domain model
    assert result is mock_domain_model


def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
    """Test get_by_workflow_run method."""
    session_obj, _ = session

@ -202,88 +171,6 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
    assert result[0] is mock_domain_model


def test_get_running_executions(repository, session, mocker: MockerFixture):
    """Test get_running_executions method."""
    session_obj, _ = session
    # Set up mock
    mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
    mock_stmt = mocker.MagicMock()
    mock_select.return_value = mock_stmt
    mock_stmt.where.return_value = mock_stmt

    # Create a properly configured mock execution
    mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
    configure_mock_execution(mock_execution)
    session_obj.scalars.return_value.all.return_value = [mock_execution]

    # Create a mock domain model to be returned by _to_domain_model
    mock_domain_model = mocker.MagicMock()
    # Mock the _to_domain_model method to return our mock domain model
    repository._to_domain_model = mocker.MagicMock(return_value=mock_domain_model)

    # Call method
    result = repository.get_running_executions("test-workflow-run-id")

    # Assert select was called with correct parameters
    mock_select.assert_called_once()
    session_obj.scalars.assert_called_once_with(mock_stmt)
    # Assert _to_domain_model was called with the mock execution
    repository._to_domain_model.assert_called_once_with(mock_execution)
    # Assert the result contains our mock domain model
    assert len(result) == 1
    assert result[0] is mock_domain_model


def test_update_via_save(repository, session):
    """Test updating an existing record via save method."""
    session_obj, _ = session
    # Create a mock execution
    execution = MagicMock(spec=WorkflowNodeExecutionModel)
    execution.tenant_id = None
    execution.app_id = None
    execution.inputs = None
    execution.process_data = None
    execution.outputs = None
    execution.metadata = None

    # Mock the to_db_model method to return the execution itself
    # This simulates the behavior of setting tenant_id and app_id
    repository.to_db_model = MagicMock(return_value=execution)

    # Call save method to update an existing record
    repository.save(execution)

    # Assert to_db_model was called with the execution
    repository.to_db_model.assert_called_once_with(execution)

    # Assert session.merge was called (for updates)
    session_obj.merge.assert_called_once_with(execution)


def test_clear(repository, session, mocker: MockerFixture):
    """Test clear method."""
    session_obj, _ = session
    # Set up mock
    mock_delete = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.delete")
    mock_stmt = mocker.MagicMock()
    mock_delete.return_value = mock_stmt
    mock_stmt.where.return_value = mock_stmt

    # Mock the execute result with rowcount
    mock_result = mocker.MagicMock()
    mock_result.rowcount = 5  # Simulate 5 records deleted
    session_obj.execute.return_value = mock_result

    # Call method
    repository.clear()

    # Assert delete was called with correct parameters
    mock_delete.assert_called_once_with(WorkflowNodeExecutionModel)
    mock_stmt.where.assert_called()
    session_obj.execute.assert_called_once_with(mock_stmt)
    session_obj.commit.assert_called_once()


def test_to_db_model(repository):
    """Test to_db_model method."""
    # Create a domain model
@ -0,0 +1,301 @@
from unittest.mock import Mock

from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolParameter
from services.tools.tools_transform_service import ToolTransformService


class TestToolTransformService:
    """Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""

    def test_convert_tool_with_parameter_override(self):
        """Test that runtime parameters correctly override base parameters"""
        # Create mock base parameters
        base_param1 = Mock(spec=ToolParameter)
        base_param1.name = "param1"
        base_param1.form = ToolParameter.ToolParameterForm.FORM
        base_param1.type = "string"
        base_param1.label = "Base Param 1"

        base_param2 = Mock(spec=ToolParameter)
        base_param2.name = "param2"
        base_param2.form = ToolParameter.ToolParameterForm.FORM
        base_param2.type = "string"
        base_param2.label = "Base Param 2"

        # Create mock runtime parameters that override base parameters
        runtime_param1 = Mock(spec=ToolParameter)
        runtime_param1.name = "param1"
        runtime_param1.form = ToolParameter.ToolParameterForm.FORM
        runtime_param1.type = "string"
        runtime_param1.label = "Runtime Param 1"  # Different label to verify override

        # Create mock tool
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = [base_param1, base_param2]
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = [runtime_param1]

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.author == "test_author"
        assert result.name == "test_tool"
        assert result.parameters is not None
        assert len(result.parameters) == 2

        # Find the overridden parameter
        overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
        assert overridden_param is not None
        assert overridden_param.label == "Runtime Param 1"  # Should be runtime version

        # Find the non-overridden parameter
        original_param = next((p for p in result.parameters if p.name == "param2"), None)
        assert original_param is not None
        assert original_param.label == "Base Param 2"  # Should be base version

    def test_convert_tool_with_additional_runtime_parameters(self):
        """Test that additional runtime parameters are added to the final list"""
        # Create mock base parameters
        base_param1 = Mock(spec=ToolParameter)
        base_param1.name = "param1"
        base_param1.form = ToolParameter.ToolParameterForm.FORM
        base_param1.type = "string"
        base_param1.label = "Base Param 1"

        # Create mock runtime parameters - one that overrides and one that's new
        runtime_param1 = Mock(spec=ToolParameter)
        runtime_param1.name = "param1"
        runtime_param1.form = ToolParameter.ToolParameterForm.FORM
        runtime_param1.type = "string"
        runtime_param1.label = "Runtime Param 1"

        runtime_param2 = Mock(spec=ToolParameter)
        runtime_param2.name = "runtime_only"
        runtime_param2.form = ToolParameter.ToolParameterForm.FORM
        runtime_param2.type = "string"
        runtime_param2.label = "Runtime Only Param"

        # Create mock tool
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = [base_param1]
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 2

        # Check that both parameters are present
        param_names = [p.name for p in result.parameters]
        assert "param1" in param_names
        assert "runtime_only" in param_names

        # Verify the overridden parameter has runtime version
        overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
        assert overridden_param is not None
        assert overridden_param.label == "Runtime Param 1"

        # Verify the new runtime parameter is included
        new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
        assert new_param is not None
        assert new_param.label == "Runtime Only Param"

    def test_convert_tool_with_non_form_runtime_parameters(self):
        """Test that non-FORM runtime parameters are not added as new parameters"""
        # Create mock base parameters
        base_param1 = Mock(spec=ToolParameter)
        base_param1.name = "param1"
        base_param1.form = ToolParameter.ToolParameterForm.FORM
        base_param1.type = "string"
        base_param1.label = "Base Param 1"

        # Create mock runtime parameters with different forms
        runtime_param1 = Mock(spec=ToolParameter)
        runtime_param1.name = "param1"
        runtime_param1.form = ToolParameter.ToolParameterForm.FORM
        runtime_param1.type = "string"
        runtime_param1.label = "Runtime Param 1"

        runtime_param2 = Mock(spec=ToolParameter)
        runtime_param2.name = "llm_param"
        runtime_param2.form = ToolParameter.ToolParameterForm.LLM
        runtime_param2.type = "string"
        runtime_param2.label = "LLM Param"

        # Create mock tool
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = [base_param1]
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 1  # Only the FORM parameter should be present

        # Check that only the FORM parameter is present
        param_names = [p.name for p in result.parameters]
        assert "param1" in param_names
        assert "llm_param" not in param_names

    def test_convert_tool_with_empty_parameters(self):
        """Test conversion with empty base and runtime parameters"""
        # Create mock tool with no parameters
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = []
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = []

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 0

    def test_convert_tool_with_none_parameters(self):
        """Test conversion when base parameters is None"""
        # Create mock tool with None parameters
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = None
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = []

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 0

    def test_convert_tool_parameter_order_preserved(self):
        """Test that parameter order is preserved correctly"""
        # Create mock base parameters in specific order
        base_param1 = Mock(spec=ToolParameter)
        base_param1.name = "param1"
        base_param1.form = ToolParameter.ToolParameterForm.FORM
        base_param1.type = "string"
        base_param1.label = "Base Param 1"

        base_param2 = Mock(spec=ToolParameter)
        base_param2.name = "param2"
        base_param2.form = ToolParameter.ToolParameterForm.FORM
        base_param2.type = "string"
        base_param2.label = "Base Param 2"

        base_param3 = Mock(spec=ToolParameter)
        base_param3.name = "param3"
        base_param3.form = ToolParameter.ToolParameterForm.FORM
        base_param3.type = "string"
        base_param3.label = "Base Param 3"

        # Create runtime parameter that overrides middle parameter
        runtime_param2 = Mock(spec=ToolParameter)
        runtime_param2.name = "param2"
        runtime_param2.form = ToolParameter.ToolParameterForm.FORM
        runtime_param2.type = "string"
        runtime_param2.label = "Runtime Param 2"

        # Create new runtime parameter
        runtime_param4 = Mock(spec=ToolParameter)
        runtime_param4.name = "param4"
        runtime_param4.form = ToolParameter.ToolParameterForm.FORM
        runtime_param4.type = "string"
        runtime_param4.label = "Runtime Param 4"

        # Create mock tool
        mock_tool = Mock(spec=Tool)
        mock_tool.entity = Mock()
        mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
        mock_tool.entity.identity = Mock()
        mock_tool.entity.identity.author = "test_author"
        mock_tool.entity.identity.name = "test_tool"
        mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
        mock_tool.entity.description = Mock()
        mock_tool.entity.description.human = I18nObject(en_US="Test description")
        mock_tool.entity.output_schema = {}
        mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]

        # Mock fork_tool_runtime to return the same tool
        mock_tool.fork_tool_runtime.return_value = mock_tool

        # Call the method
        result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)

        # Verify the result
        assert isinstance(result, ToolApiEntity)
        assert result.parameters is not None
        assert len(result.parameters) == 4

        # Check that order is maintained: base parameters first, then new runtime parameters
        param_names = [p.name for p in result.parameters]
        assert param_names == ["param1", "param2", "param3", "param4"]

        # Verify that param2 was overridden with runtime version
        param2 = result.parameters[1]
        assert param2.name == "param2"
        assert param2.label == "Runtime Param 2"
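The merge behavior these five tests pin down can be summarized in a few lines: FORM runtime parameters override same-named base parameters in place, new FORM runtime parameters are appended after the base ones, and non-FORM (e.g. LLM) runtime parameters never introduce new entries. A self-contained sketch of that logic (illustrative; the real service also converts entities and labels):

from dataclasses import dataclass


@dataclass
class Param:
    name: str
    form: str  # "form" or "llm" in this sketch
    label: str


def merge_parameters(base: list[Param] | None, runtime: list[Param]) -> list[Param]:
    runtime_by_name = {p.name: p for p in runtime if p.form == "form"}
    # Base order is preserved; same-named FORM runtime params override in place.
    merged = [runtime_by_name.pop(p.name, p) for p in (base or [])]
    # Remaining FORM runtime params are appended after the base parameters.
    merged.extend(runtime_by_name.values())
    return merged


base = [Param("param1", "form", "Base 1"), Param("param2", "form", "Base 2")]
runtime = [
    Param("param2", "form", "Runtime 2"),
    Param("param4", "form", "Runtime 4"),
    Param("llm_param", "llm", "LLM Param"),  # non-FORM: never added
]
assert [p.label for p in merge_parameters(base, runtime)] == ["Base 1", "Runtime 2", "Runtime 4"]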
@ -289,6 +289,7 @@ REDIS_CLUSTERS_PASSWORD=
# If use Redis Sentinel, format as follows: `sentinel://<sentinel_username>:<sentinel_password>@<sentinel_host>:<sentinel_port>/<redis_database>`
# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1
CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1
CELERY_BACKEND=redis
BROKER_USE_SSL=false

# If you are using Redis Sentinel for high availability, configure the following settings.
@ -79,6 +79,7 @@ x-shared-env: &shared-api-worker-env
  REDIS_CLUSTERS: ${REDIS_CLUSTERS:-}
  REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-}
  CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1}
  CELERY_BACKEND: ${CELERY_BACKEND:-redis}
  BROKER_USE_SSL: ${BROKER_USE_SSL:-false}
  CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false}
  CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-}
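Both the .env example and the compose defaults now set CELERY_BACKEND=redis explicitly rather than relying on the old database default. A hedged sketch of how such a setting typically maps onto Celery's result backend (the URL shapes follow common Celery conventions; the exact wiring in the API code may differ, and the database URL below is only a placeholder built from the example credentials):

import os


def build_result_backend() -> str:
    backend = os.environ.get("CELERY_BACKEND", "redis")
    if backend == "redis":
        # Reuse the broker-style redis URL for task results.
        return os.environ.get("CELERY_BROKER_URL", "redis://:difyai123456@redis:6379/1")
    # "database" would fall back to SQLAlchemy result storage via a db+ URL
    # (placeholder; the real application assembles this from its DB settings).
    return "db+postgresql://postgres:difyai123456@localhost:5432/dify"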
@ -1,5 +1,3 @@
'use client'

import WorkflowApp from '@/app/components/workflow-app'

const Page = () => {
@ -1,3 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8 4V8M8 8V12M8 8H12M8 8H4" stroke="#6B7280" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
Before Width: | Height: | Size: 206 B |
@ -1,4 +0,0 @@
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M0.631586 8.25C0.631586 6.46656 2.04586 5 3.8158 5C5.58573 5 7.00001 6.46656 7.00001 8.25C7.00001 10.0334 5.58573 11.5 3.8158 11.5C3.45197 11.5 3.10149 11.4375 2.77474 11.3222C2.72073 11.3031 2.68723 11.2913 2.66266 11.2832C2.65821 11.2817 2.65456 11.2806 2.65164 11.2796L2.64892 11.2799C2.63177 11.2818 2.60839 11.285 2.56507 11.2909L1.06766 11.4954C0.905637 11.5175 0.743029 11.459 0.632239 11.3387C0.521449 11.2185 0.476481 11.0516 0.511825 10.8919L0.817497 9.51109C0.828118 9.46311 0.833802 9.43722 0.837453 9.41817C0.83766 9.4171 0.838022 9.41517 0.838022 9.41517C0.837114 9.412 0.835963 9.40808 0.834525 9.40332C0.826292 9.37605 0.814183 9.33888 0.794499 9.27863C0.688657 8.95463 0.631586 8.60857 0.631586 8.25Z" fill="#98A2B3"/>
<path d="M2.57377 4.1863C2.96256 4.06535 3.37698 4 3.80894 4C6.16566 4 8.00006 5.94534 8.00006 8.24999C8.00006 8.65682 7.9429 9.05245 7.8358 9.42816C8.10681 9.37948 8.36964 9.30678 8.6219 9.21229C8.65748 9.19897 8.69298 9.18534 8.72893 9.17304C8.75795 9.17641 8.78684 9.18093 8.81574 9.18517L10.4222 9.42065C10.498 9.43179 10.5841 9.44444 10.6591 9.4487C10.7422 9.45343 10.8713 9.45292 11.0081 9.39408C11.1789 9.32061 11.3164 9.18628 11.3938 9.01716C11.4558 8.88174 11.4593 8.75269 11.4564 8.66955C11.4539 8.59442 11.4433 8.5081 11.4339 8.43202L11.2309 6.78307C11.2256 6.7402 11.2229 6.71768 11.2213 6.70118C11.23 6.66505 11.2466 6.6301 11.2598 6.59546C11.4492 6.09896 11.5526 5.56093 11.5526 5C11.5526 2.51163 9.52304 0.5 7.02632 0.5C4.80843 0.5 2.95915 2.08742 2.57377 4.1863Z" fill="#98A2B3"/>
</svg>
Before Width: | Height: | Size: 1.6 KiB |
@ -1,3 +0,0 @@
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M14.1667 6.66634H15.8333C16.2754 6.66634 16.6993 6.84194 17.0118 7.1545C17.3244 7.46706 17.5 7.89098 17.5 8.33301V13.333C17.5 13.775 17.3244 14.199 17.0118 14.5115C16.6993 14.8241 16.2754 14.9997 15.8333 14.9997H14.1667V18.333L10.8333 14.9997H7.5C7.28111 14.9999 7.06433 14.9569 6.86211 14.8731C6.6599 14.7893 6.47623 14.6663 6.32167 14.5113M6.32167 14.5113L9.16667 11.6663H12.5C12.942 11.6663 13.366 11.4907 13.6785 11.1782C13.9911 10.8656 14.1667 10.4417 14.1667 9.99967V4.99967C14.1667 4.55765 13.9911 4.13372 13.6785 3.82116C13.366 3.5086 12.942 3.33301 12.5 3.33301H4.16667C3.72464 3.33301 3.30072 3.5086 2.98816 3.82116C2.67559 4.13372 2.5 4.55765 2.5 4.99967V9.99967C2.5 10.4417 2.67559 10.8656 2.98816 11.1782C3.30072 11.4907 3.72464 11.6663 4.16667 11.6663H5.83333V14.9997L6.32167 14.5113Z" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
Before Width: | Height: | Size: 1002 B |
@ -1,4 +0,0 @@
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M6.5 1.00779C6.5 0.994638 6.5 0.988062 6.49943 0.976137C6.48764 0.729248 6.27052 0.51224 6.02363 0.50056C6.01171 0.499996 6.0078 0.499998 6.00001 0.5H4.37933C3.97686 0.499995 3.64468 0.49999 3.37409 0.522098C3.09304 0.545061 2.83469 0.594343 2.59202 0.717989C2.2157 0.909735 1.90973 1.2157 1.71799 1.59202C1.59434 1.83469 1.54506 2.09304 1.5221 2.37409C1.49999 2.64468 1.49999 2.97686 1.5 3.37934V8.62066C1.49999 9.02313 1.49999 9.35532 1.5221 9.62591C1.54506 9.90696 1.59434 10.1653 1.71799 10.408C1.90973 10.7843 2.2157 11.0903 2.59202 11.282C2.83469 11.4057 3.09304 11.4549 3.37409 11.4779C3.64468 11.5 3.97686 11.5 4.37934 11.5H7.62066C8.02314 11.5 8.35532 11.5 8.62591 11.4779C8.90696 11.4549 9.16531 11.4057 9.40798 11.282C9.78431 11.0903 10.0903 10.7843 10.282 10.408C10.4057 10.1653 10.4549 9.90696 10.4779 9.62591C10.5 9.35532 10.5 9.02314 10.5 8.62066V4.99997C10.5 4.9922 10.5 4.98832 10.4994 4.97641C10.4878 4.72949 10.2707 4.51236 10.0238 4.50057C10.0119 4.50001 10.0054 4.50001 9.99225 4.50001L7.78404 4.50001C7.65786 4.50002 7.53496 4.50004 7.43089 4.49153C7.31659 4.48219 7.18172 4.46016 7.04601 4.39101C6.85785 4.29514 6.70487 4.14216 6.609 3.954C6.53985 3.81828 6.51781 3.68342 6.50848 3.56912C6.49997 3.46504 6.49999 3.34215 6.5 3.21596L6.5 1.00779ZM4 6.5C3.72386 6.5 3.5 6.72386 3.5 7C3.5 7.27614 3.72386 7.5 4 7.5H8C8.27614 7.5 8.5 7.27614 8.5 7C8.5 6.72386 8.27614 6.5 8 6.5H4ZM4 8.5C3.72386 8.5 3.5 8.72386 3.5 9C3.5 9.27614 3.72386 9.5 4 9.5H7C7.27614 9.5 7.5 9.27614 7.5 9C7.5 8.72386 7.27614 8.5 7 8.5H4Z" fill="#98A2B3"/>
<path d="M9.45398 3.5C9.60079 3.5 9.67419 3.5 9.73432 3.46314C9.81925 3.41107 9.87002 3.28842 9.84674 3.19157C9.83025 3.12299 9.78238 3.07516 9.68665 2.97952L8.02049 1.31336C7.92484 1.21762 7.87701 1.16975 7.80843 1.15326C7.71158 1.12998 7.58893 1.18075 7.53687 1.26567C7.5 1.3258 7.5 1.39921 7.5 1.54602L7.5 3.09998C7.5 3.23999 7.5 3.30999 7.52725 3.36347C7.55122 3.41051 7.58946 3.44876 7.6365 3.47272C7.68998 3.49997 7.75998 3.49997 7.9 3.49998L9.45398 3.5Z" fill="#98A2B3"/>
</svg>
Before Width: | Height: | Size: 2.1 KiB |
@ -1,3 +0,0 @@
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M16.25 11.875V9.6875C16.25 8.1342 14.9908 6.875 13.4375 6.875H12.1875C11.6697 6.875 11.25 6.45527 11.25 5.9375V4.6875C11.25 3.1342 9.9908 1.875 8.4375 1.875H6.875M6.875 12.5H13.125M6.875 15H10M8.75 1.875H4.6875C4.16973 1.875 3.75 2.29473 3.75 2.8125V17.1875C3.75 17.7053 4.16973 18.125 4.6875 18.125H15.3125C15.8303 18.125 16.25 17.7053 16.25 17.1875V9.375C16.25 5.23286 12.8921 1.875 8.75 1.875Z" stroke="#1F2A37" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
Before Width: | Height: | Size: 595 B |
@ -1,3 +0,0 @@
<svg width="26" height="26" viewBox="0 0 26 26" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M22.0101 4.50191C20.3529 3.74154 18.5759 3.18133 16.7179 2.86048C16.6841 2.85428 16.6503 2.86976 16.6328 2.90071C16.4043 3.30719 16.1511 3.83748 15.9738 4.25429C13.9754 3.95511 11.9873 3.95511 10.0298 4.25429C9.85253 3.82822 9.59019 3.30719 9.36062 2.90071C9.34319 2.87079 9.30939 2.85532 9.27555 2.86048C7.41857 3.18031 5.64152 3.74051 3.98335 4.50191C3.96899 4.5081 3.95669 4.51843 3.94852 4.53183C0.577841 9.56755 -0.345529 14.4795 0.107445 19.3306C0.109495 19.3543 0.122817 19.377 0.141265 19.3914C2.36514 21.0246 4.51935 22.0161 6.63355 22.6732C6.66739 22.6836 6.70324 22.6712 6.72477 22.6433C7.22489 21.9604 7.6707 21.2402 8.05293 20.4829C8.07549 20.4386 8.05396 20.386 8.00785 20.3684C7.30073 20.1002 6.6274 19.7731 5.97971 19.4017C5.92848 19.3718 5.92437 19.2985 5.9715 19.2635C6.1078 19.1613 6.24414 19.0551 6.37428 18.9478C6.39783 18.9282 6.43064 18.924 6.45833 18.9364C10.7134 20.8791 15.32 20.8791 19.5249 18.9364C19.5525 18.923 19.5854 18.9272 19.6099 18.9467C19.7401 19.054 19.8764 19.1613 20.0137 19.2635C20.0609 19.2985 20.0578 19.3718 20.0066 19.4017C19.3589 19.7804 18.6855 20.1002 17.9774 20.3674C17.9313 20.3849 17.9108 20.4386 17.9333 20.4829C18.3238 21.2392 18.7696 21.9593 19.2605 22.6423C19.281 22.6712 19.3179 22.6836 19.3517 22.6732C21.4761 22.0161 23.6303 21.0246 25.8542 19.3914C25.8737 19.377 25.886 19.3553 25.8881 19.3316C26.4302 13.7232 24.98 8.85156 22.0439 4.53286C22.0367 4.51843 22.0245 4.5081 22.0101 4.50191ZM8.68836 16.3768C7.40729 16.3768 6.35173 15.2007 6.35173 13.7563C6.35173 12.3119 7.38682 11.1358 8.68836 11.1358C10.0001 11.1358 11.0455 12.3222 11.025 13.7563C11.025 15.2007 9.98986 16.3768 8.68836 16.3768ZM17.3276 16.3768C16.0466 16.3768 14.991 15.2007 14.991 13.7563C14.991 12.3119 16.0261 11.1358 17.3276 11.1358C18.6394 11.1358 19.6847 12.3222 19.6643 13.7563C19.6643 15.2007 18.6394 16.3768 17.3276 16.3768Z" fill="#5865F2"/>
</svg>
Before Width: | Height: | Size: 1.9 KiB |
@ -1,17 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_131_1011)">
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.0003 0.5C9.15149 0.501478 6.39613 1.51046 4.22687 3.34652C2.05761 5.18259 0.615903 7.72601 0.159545 10.522C-0.296814 13.318 0.261927 16.1842 1.73587 18.6082C3.20981 21.0321 5.50284 22.8558 8.20493 23.753C8.80105 23.8636 9.0256 23.4941 9.0256 23.18C9.0256 22.8658 9.01367 21.955 9.0097 20.9592C5.6714 21.6804 4.96599 19.5505 4.96599 19.5505C4.42152 18.1674 3.63464 17.8039 3.63464 17.8039C2.54571 17.065 3.71611 17.0788 3.71611 17.0788C4.92227 17.1637 5.55616 18.3097 5.55616 18.3097C6.62521 20.1333 8.36389 19.6058 9.04745 19.2976C9.15475 18.5251 9.46673 17.9995 9.8105 17.7012C7.14383 17.4008 4.34204 16.3774 4.34204 11.8054C4.32551 10.6197 4.76802 9.47305 5.57801 8.60268C5.45481 8.30236 5.04348 7.08923 5.69524 5.44143C5.69524 5.44143 6.7027 5.12135 8.9958 6.66444C10.9627 6.12962 13.0379 6.12962 15.0047 6.66444C17.2958 5.12135 18.3013 5.44143 18.3013 5.44143C18.9551 7.08528 18.5437 8.29841 18.4205 8.60268C19.2331 9.47319 19.6765 10.6218 19.6585 11.8094C19.6585 16.3912 16.8507 17.4008 14.1801 17.6952C14.6093 18.0667 14.9928 18.7918 14.9928 19.9061C14.9928 21.5026 14.9789 22.7868 14.9789 23.18C14.9789 23.4981 15.1955 23.8695 15.8035 23.753C18.5059 22.8557 20.7992 21.0317 22.2731 18.6073C23.747 16.183 24.3055 13.3163 23.8486 10.5201C23.3917 7.7238 21.9493 5.18035 19.7793 3.34461C17.6093 1.50886 14.8533 0.500541 12.0042 0.5H12.0003Z" fill="#191717"/>
<path d="M4.54444 17.6321C4.5186 17.6914 4.42322 17.7092 4.34573 17.6677C4.26823 17.6262 4.21061 17.5491 4.23843 17.4879C4.26625 17.4266 4.35964 17.4108 4.43714 17.4523C4.51463 17.4938 4.57424 17.5729 4.54444 17.6321Z" fill="#191717"/>
<path d="M5.03123 18.1714C4.99008 18.192 4.943 18.1978 4.89805 18.1877C4.8531 18.1776 4.81308 18.1523 4.78483 18.1161C4.70734 18.0331 4.69143 17.9185 4.75104 17.8671C4.81066 17.8157 4.91797 17.8395 4.99546 17.9224C5.07296 18.0054 5.09084 18.12 5.03123 18.1714Z" fill="#191717"/>
<path d="M5.50425 18.857C5.43072 18.9084 5.30553 18.857 5.23598 18.7543C5.21675 18.7359 5.20146 18.7138 5.19101 18.6893C5.18056 18.6649 5.17517 18.6386 5.17517 18.612C5.17517 18.5855 5.18056 18.5592 5.19101 18.5347C5.20146 18.5103 5.21675 18.4882 5.23598 18.4698C5.3095 18.4204 5.4347 18.4698 5.50425 18.5705C5.57379 18.6713 5.57578 18.8057 5.50425 18.857V18.857Z" fill="#191717"/>
<path d="M6.14612 19.5207C6.08054 19.5939 5.94741 19.5741 5.83812 19.4753C5.72883 19.3765 5.70299 19.2422 5.76857 19.171C5.83414 19.0999 5.96727 19.1197 6.08054 19.2165C6.1938 19.3133 6.21566 19.4496 6.14612 19.5207V19.5207Z" fill="#191717"/>
<path d="M7.04617 19.9081C7.01637 20.001 6.88124 20.0425 6.74612 20.003C6.611 19.9635 6.52158 19.8528 6.54741 19.758C6.57325 19.6631 6.71036 19.6197 6.84747 19.6631C6.98457 19.7066 7.07201 19.8113 7.04617 19.9081Z" fill="#191717"/>
<path d="M8.02783 19.9752C8.02783 20.072 7.91656 20.155 7.77349 20.1569C7.63042 20.1589 7.51318 20.0799 7.51318 19.9831C7.51318 19.8863 7.62445 19.8033 7.76752 19.8013C7.91059 19.7993 8.02783 19.8764 8.02783 19.9752Z" fill="#191717"/>
<path d="M8.9419 19.8232C8.95978 19.92 8.86042 20.0207 8.71735 20.0445C8.57428 20.0682 8.4491 20.0109 8.43121 19.916C8.41333 19.8212 8.51666 19.7185 8.65576 19.6928C8.79485 19.6671 8.92401 19.7264 8.9419 19.8232Z" fill="#191717"/>
</g>
<defs>
<clipPath id="clip0_131_1011">
<rect width="24" height="24" fill="white"/>
</clipPath>
</defs>
</svg>
Before Width: | Height: | Size: 3.4 KiB |
@ -1,3 +0,0 @@
<svg width="13" height="14" viewBox="0 0 13 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.41663 3.75033H3.24996C2.96264 3.75033 2.68709 3.86446 2.48393 4.06763C2.28076 4.27079 2.16663 4.54634 2.16663 4.83366V10.2503C2.16663 10.5376 2.28076 10.8132 2.48393 11.0164C2.68709 11.2195 2.96264 11.3337 3.24996 11.3337H8.66663C8.95394 11.3337 9.22949 11.2195 9.43266 11.0164C9.63582 10.8132 9.74996 10.5376 9.74996 10.2503V8.08366M7.58329 2.66699H10.8333M10.8333 2.66699V5.91699M10.8333 2.66699L5.41663 8.08366" stroke="#9CA3AF" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
Before Width: | Height: | Size: 596 B |
@ -1,3 +0,0 @@
<svg width="13" height="14" viewBox="0 0 13 14" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.41663 3.75008H3.24996C2.96264 3.75008 2.68709 3.86422 2.48393 4.06738C2.28076 4.27055 2.16663 4.5461 2.16663 4.83341V10.2501C2.16663 10.5374 2.28076 10.8129 2.48393 11.0161C2.68709 11.2193 2.96264 11.3334 3.24996 11.3334H8.66663C8.95394 11.3334 9.22949 11.2193 9.43266 11.0161C9.63582 10.8129 9.74996 10.5374 9.74996 10.2501V8.08341M7.58329 2.66675H10.8333M10.8333 2.66675V5.91675M10.8333 2.66675L5.41663 8.08341" stroke="#1C64F2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
Before Width: | Height: | Size: 595 B |