add chatflow app event convert

This commit is contained in:
takatost 2024-07-31 02:21:35 +08:00
parent 0818b7b078
commit 917aacbf7f
19 changed files with 1566 additions and 239 deletions

View File

@ -1,6 +1,6 @@
import time import time
from collections.abc import Generator from collections.abc import Generator, Mapping
from typing import Optional, Union from typing import Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -342,7 +342,7 @@ class AppRunner:
self, app_id: str, self, app_id: str,
tenant_id: str, tenant_id: str,
app_generate_entity: AppGenerateEntity, app_generate_entity: AppGenerateEntity,
inputs: dict, inputs: Mapping[str, Any],
query: str, query: str,
message_id: str, message_id: str,
) -> tuple[bool, dict, str]: ) -> tuple[bool, dict, str]:

View File

View File

@ -0,0 +1,101 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import (
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.model import App, AppMode
from models.workflow import Workflow
class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
"""
Advanced Chatbot App Config Entity.
"""
pass
class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
)
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for advanced chat app model config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: if True, only structure validation will be performed
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config,
is_vision=False
)
related_config_keys.extend(current_related_config_keys)
# opening_statement
config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# return retriever resource
config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id,
config=config,
only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@ -0,0 +1,189 @@
import logging
import os
import uuid
from collections.abc import Generator
from typing import Union
from pydantic import ValidationError
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.apps.chatflow.app_runner import AdvancedChatAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def generate(
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
query = args['query']
if not isinstance(query, str):
raise ValueError('query must be a string')
query = query.replace('\x00', '')
inputs = args['inputs']
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', False)
}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# parse files
files = args['files'] if args.get('files') else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_model.id)
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream
)
def _generate(self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
is_first_conversation = False
if not conversation:
is_first_conversation = True
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
try:
# chatbot app
runner = AdvancedChatAppRunner()
response = runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
raise
except ValidationError as e:
logger.exception("Validation Error when generating")
raise e
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException()
else:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
logger.exception(e)
raise e
except InvokeError as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
logger.exception("Error when generating")
raise e
except Exception as e:
logger.exception("Unknown Error when generating")
raise e
finally:
db.session.close()
return AdvancedChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)

View File

@ -0,0 +1,422 @@
import logging
import os
import time
from collections.abc import Generator, Mapping
from typing import Any, Optional, cast
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAnnotationReplyEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.graph_engine.entities.event import (
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class AdvancedChatAppRunner(AppRunner):
"""
AdvancedChat Application Runner
"""
def run(self, application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> Generator[AppQueueEvent, None, None]:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
:return:
"""
app_config = application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity
):
return
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
call_depth=application_generate_entity.call_depth
)
generator = workflow_entry.run(
callbacks=workflow_callbacks,
)
for event in generator:
if isinstance(event, GraphRunStartedEvent):
queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, GraphRunSucceededEvent):
queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, GraphRunFailedEvent):
queue_manager.publish(
QueueWorkflowFailedEvent(error=event.error),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, NodeRunStartedEvent):
queue_manager.publish(
QueueNodeStartedEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
predecessor_node_id=event.predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, NodeRunSucceededEvent):
queue_manager.publish(
QueueNodeSucceededEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result else {},
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, NodeRunFailedEvent):
queue_manager.publish(
QueueNodeFailedEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result
and event.route_node_state.node_run_result.error
else "Unknown error"
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, NodeRunStreamChunkEvent):
queue_manager.publish(
QueueTextChunkEvent(
text=event.chunk_content
), PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
queue_manager.publish(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources
), PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, ParallelBranchRunStartedEvent):
queue_manager.publish(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
queue_manager.publish(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, ParallelBranchRunFailedEvent):
queue_manager.publish(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
error=event.error
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, IterationRunStartedEvent):
queue_manager.publish(
QueueIterationStartEvent(
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, IterationRunNextEvent):
queue_manager.publish(
QueueIterationNextEvent(
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
queue_manager.publish(
QueueIterationCompletedEvent(
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None
),
PublishFrom.APPLICATION_MANAGER
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
# return workflow
return workflow
def handle_input_moderation(
self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
) -> bool:
"""
Handle input moderation
:param queue_manager: application queue manager
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
app_id=app_record.id,
tenant_id=app_generate_entity.app_config.tenant_id,
app_generate_entity=app_generate_entity,
inputs=inputs,
query=query,
message_id=message_id,
)
except ModerationException as e:
self._stream_output(
queue_manager=queue_manager,
text=str(e),
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
)
return True
return False
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param queue_manager: application queue manager
:param app_generate_entity: application generate entity
"""
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
query=query,
user_id=app_generate_entity.user_id,
invoke_from=app_generate_entity.invoke_from
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
)
self._stream_output(
queue_manager=queue_manager,
text=annotation_reply.content,
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _stream_output(self, queue_manager: AppQueueManager,
text: str,
stream: bool,
stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param text: text
:param stream: stream
:return:
"""
if stream:
index = 0
for token in text:
queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.APPLICATION_MANAGER
)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(
QueueStopEvent(stopped_by=stopped_by),
PublishFrom.APPLICATION_MANAGER
)

View File

@ -0,0 +1,450 @@
import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.graph_engine.entities.event import GraphRunStartedEvent, NodeRunStartedEvent
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: AdvancedChatTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariable, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(
self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
:param user: user
:param stream: stream
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
if isinstance(self._user, EndUser):
user_id = self._user.session_id
else:
user_id = self._user.id
self._workflow = workflow
self._conversation = conversation
self._message = message
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
}
self._task_state = AdvancedChatTaskState(
usage=LLMUsage.empty_usage()
)
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
def process(self):
"""
Process generate task pipeline.
:return:
"""
db.session.refresh(self._workflow)
db.session.refresh(self._user)
db.session.close()
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation,
self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
"""
Process blocking response.
:return:
"""
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
extras['metadata'] = stream_response.metadata
return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
data=ChatbotAppBlockingResponse.Data(
id=self._message.id,
mode=self._conversation.mode,
conversation_id=self._conversation.id,
message_id=self._message.id,
answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()),
**extras
)
)
else:
continue
raise Exception('Queue listening stopped unexpectedly.')
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
"""
To stream response.
:return:
"""
for stream_response in generator:
yield ChatbotAppStreamResponse(
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response
)
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
break
audio_trunk = publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
if isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message)
yield self._error_to_stream_response(err)
break
elif isinstance(event, GraphRunStartedEvent):
workflow_run = self._handle_workflow_start()
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._message.workflow_run_id = workflow_run.id
db.session.commit()
db.session.refresh(self._message)
db.session.close()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, NodeRunStartedEvent):
workflow_node_execution = self._handle_node_start(event)
yield self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
yield self._workflow_node_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, conversation_id=self._conversation.id, trace_manager=trace_manager
)
if workflow_run:
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
if workflow_run.status == WorkflowRunStatus.FAILED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
if isinstance(event, QueueStopEvent):
# Save message
self._save_message()
yield self._message_end_to_stream_response()
break
else:
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message()
yield self._message_end_to_stream_response()
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
continue
self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self) -> None:
"""
Save message.
:return:
"""
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit
self._message.total_price = usage.total_price
self._message.currency = usage.currency
db.session.commit()
message_was_created.send(
self._message,
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""
Message end to stream response.
:return:
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message.id,
**extras
)
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
QueueTextChunkEvent(
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
)
return True
else:
self._output_moderation_handler.append_new_token(text)
return False

View File

@ -46,7 +46,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
elif isinstance(event, GraphRunSucceededEvent): elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[on_workflow_run_succeeded]", color='green') self.print_text("\n[on_workflow_run_succeeded]", color='green')
elif isinstance(event, GraphRunFailedEvent): elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[on_workflow_run_failed] reason: {event.reason}", color='red') self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red')
elif isinstance(event, NodeRunStartedEvent): elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started( self.on_workflow_node_execute_started(
graph=graph, graph=graph,

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
@ -5,7 +6,7 @@ from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
class QueueEvent(str, Enum): class QueueEvent(str, Enum):
@ -31,6 +32,9 @@ class QueueEvent(str, Enum):
ANNOTATION_REPLY = "annotation_reply" ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought" AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file" MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
ERROR = "error" ERROR = "error"
PING = "ping" PING = "ping"
STOP = "stop" STOP = "stop"
@ -38,7 +42,7 @@ class QueueEvent(str, Enum):
class AppQueueEvent(BaseModel): class AppQueueEvent(BaseModel):
""" """
QueueEvent entity QueueEvent abstract entity
""" """
event: QueueEvent event: QueueEvent
@ -46,6 +50,7 @@ class AppQueueEvent(BaseModel):
class QueueLLMChunkEvent(AppQueueEvent): class QueueLLMChunkEvent(AppQueueEvent):
""" """
QueueLLMChunkEvent entity QueueLLMChunkEvent entity
Only for basic mode apps
""" """
event: QueueEvent = QueueEvent.LLM_CHUNK event: QueueEvent = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk chunk: LLMResultChunk
@ -58,11 +63,15 @@ class QueueIterationStartEvent(AppQueueEvent):
node_id: str node_id: str
node_type: NodeType node_type: NodeType
node_data: BaseNodeData node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
node_run_index: int node_run_index: int
inputs: dict = None inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
metadata: Optional[dict] = None metadata: Optional[Mapping[str, Any]] = None
class QueueIterationNextEvent(AppQueueEvent): class QueueIterationNextEvent(AppQueueEvent):
""" """
@ -73,6 +82,10 @@ class QueueIterationNextEvent(AppQueueEvent):
index: int index: int
node_id: str node_id: str
node_type: NodeType node_type: NodeType
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
node_run_index: int node_run_index: int
output: Optional[Any] = None # output for the current iteration output: Optional[Any] = None # output for the current iteration
@ -93,13 +106,23 @@ class QueueIterationCompletedEvent(AppQueueEvent):
""" """
QueueIterationCompletedEvent entity QueueIterationCompletedEvent entity
""" """
event:QueueEvent = QueueEvent.ITERATION_COMPLETED event: QueueEvent = QueueEvent.ITERATION_COMPLETED
node_id: str node_id: str
node_type: NodeType node_type: NodeType
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
node_run_index: int node_run_index: int
outputs: dict inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: Optional[str] = None
class QueueTextChunkEvent(AppQueueEvent): class QueueTextChunkEvent(AppQueueEvent):
""" """
@ -190,6 +213,10 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_data: BaseNodeData node_data: BaseNodeData
node_run_index: int = 1 node_run_index: int = 1
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
class QueueNodeSucceededEvent(AppQueueEvent): class QueueNodeSucceededEvent(AppQueueEvent):
@ -201,11 +228,15 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_id: str node_id: str
node_type: NodeType node_type: NodeType
node_data: BaseNodeData node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
inputs: Optional[dict] = None inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[dict] = None process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[dict] = None outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[dict] = None execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: Optional[str] = None error: Optional[str] = None
@ -219,10 +250,14 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_id: str node_id: str
node_type: NodeType node_type: NodeType
node_data: BaseNodeData node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
inputs: Optional[dict] = None inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[dict] = None process_data: Optional[Mapping[str, Any]] = None
process_data: Optional[dict] = None outputs: Optional[Mapping[str, Any]] = None
error: str error: str
@ -277,7 +312,7 @@ class QueueStopEvent(AppQueueEvent):
class QueueMessage(BaseModel): class QueueMessage(BaseModel):
""" """
QueueMessage entity QueueMessage abstract entity
""" """
task_id: str task_id: str
app_mode: str app_mode: str
@ -297,3 +332,34 @@ class WorkflowQueueMessage(QueueMessage):
WorkflowQueueMessage entity WorkflowQueueMessage entity
""" """
pass pass
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
parallel_start_node_id: str
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
parallel_start_node_id: str
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str
parallel_start_node_id: str
error: str

View File

@ -84,9 +84,9 @@ class NodeRunResult(BaseModel):
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[Mapping[str, Any]] = None # node inputs inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict] = None # process data process_data: Optional[Mapping[str, Any]] = None # process data
outputs: Optional[Mapping[str, Any]] = None # node outputs outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches edge_source_handle: Optional[str] = None # source handle id of node with multiple branches

View File

@ -1,13 +1,17 @@
from collections.abc import Mapping
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class GraphEngineEvent(BaseModel): class GraphEngineEvent(BaseModel):
pass pass
########################################### ###########################################
# Graph Events # Graph Events
########################################### ###########################################
@ -26,7 +30,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):
class GraphRunFailedEvent(BaseGraphEvent): class GraphRunFailedEvent(BaseGraphEvent):
reason: str = Field(..., description="failed reason") error: str = Field(..., description="failed reason")
########################################### ###########################################
@ -35,16 +39,20 @@ class GraphRunFailedEvent(BaseGraphEvent):
class BaseNodeEvent(GraphEngineEvent): class BaseNodeEvent(GraphEngineEvent):
node_id: str = Field(..., description="node id")
node_type: NodeType = Field(..., description="node type")
node_data: BaseNodeData = Field(..., description="node data")
route_node_state: RouteNodeState = Field(..., description="route node state") route_node_state: RouteNodeState = Field(..., description="route node state")
parallel_id: Optional[str] = None parallel_id: Optional[str] = None
"""parallel id if node is in parallel""" """parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel""" """parallel start node id if node is in parallel"""
# iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration") in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
class NodeRunStartedEvent(BaseNodeEvent): class NodeRunStartedEvent(BaseNodeEvent):
pass predecessor_node_id: Optional[str] = None
"""predecessor node id"""
class NodeRunStreamChunkEvent(BaseNodeEvent): class NodeRunStreamChunkEvent(BaseNodeEvent):
@ -63,7 +71,7 @@ class NodeRunSucceededEvent(BaseNodeEvent):
class NodeRunFailedEvent(BaseNodeEvent): class NodeRunFailedEvent(BaseNodeEvent):
pass error: str = Field(..., description="error")
########################################### ###########################################
@ -74,6 +82,7 @@ class NodeRunFailedEvent(BaseNodeEvent):
class BaseParallelBranchEvent(GraphEngineEvent): class BaseParallelBranchEvent(GraphEngineEvent):
parallel_id: str = Field(..., description="parallel id") parallel_id: str = Field(..., description="parallel id")
parallel_start_node_id: str = Field(..., description="parallel start node id") parallel_start_node_id: str = Field(..., description="parallel start node id")
in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent): class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
@ -85,7 +94,7 @@ class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent): class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
reason: str = Field(..., description="failed reason") error: str = Field(..., description="failed reason")
########################################### ###########################################
@ -94,11 +103,19 @@ class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
class BaseIterationEvent(GraphEngineEvent): class BaseIterationEvent(GraphEngineEvent):
iteration_id: str = Field(..., description="iteration id") iteration_node_id: str = Field(..., description="iteration node id")
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
iteration_node_data: BaseNodeData = Field(..., description="node data")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
class IterationRunStartedEvent(BaseIterationEvent): class IterationRunStartedEvent(BaseIterationEvent):
pass inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
class IterationRunNextEvent(BaseIterationEvent): class IterationRunNextEvent(BaseIterationEvent):
@ -107,11 +124,18 @@ class IterationRunNextEvent(BaseIterationEvent):
class IterationRunSucceededEvent(BaseIterationEvent): class IterationRunSucceededEvent(BaseIterationEvent):
pass inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
class IterationRunFailedEvent(BaseIterationEvent): class IterationRunFailedEvent(BaseIterationEvent):
reason: str = Field(..., description="failed reason") inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent

View File

@ -14,6 +14,7 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, U
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
BaseIterationEvent,
GraphEngineEvent, GraphEngineEvent,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunStartedEvent, GraphRunStartedEvent,
@ -32,6 +33,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes from core.workflow.nodes.node_mapping import node_classes
@ -84,16 +86,16 @@ class GraphEngine:
yield GraphRunStartedEvent() yield GraphRunStartedEvent()
try: try:
stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
if self.init_params.workflow_type == WorkflowType.CHAT: if self.init_params.workflow_type == WorkflowType.CHAT:
stream_processor = AnswerStreamProcessor( stream_processor_cls = AnswerStreamProcessor
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
)
else: else:
stream_processor = EndStreamProcessor( stream_processor_cls = EndStreamProcessor
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool stream_processor = stream_processor_cls(
) graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
)
# run graph # run graph
generator = stream_processor.process( generator = stream_processor.process(
@ -104,21 +106,21 @@ class GraphEngine:
try: try:
yield item yield item
if isinstance(item, NodeRunFailedEvent): if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.') yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
return return
except Exception as e: except Exception as e:
logger.exception(f"Graph run failed: {str(e)}") logger.exception(f"Graph run failed: {str(e)}")
yield GraphRunFailedEvent(reason=str(e)) yield GraphRunFailedEvent(error=str(e))
return return
# trigger graph run success event # trigger graph run success event
yield GraphRunSucceededEvent() yield GraphRunSucceededEvent()
except GraphRunFailedError as e: except GraphRunFailedError as e:
yield GraphRunFailedEvent(reason=e.error) yield GraphRunFailedEvent(error=e.error)
return return
except Exception as e: except Exception as e:
logger.exception("Unknown Error when graph running") logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(reason=str(e)) yield GraphRunFailedEvent(error=str(e))
raise e raise e
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
@ -145,11 +147,34 @@ class GraphEngine:
node_id=next_node_id node_id=next_node_id
) )
# get node config
node_id = route_node_state.node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(f'Node {node_id} config not found.')
# convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
# init workflow run state
node_instance = node_cls( # type: ignore
config=node_config,
graph_init_params=self.init_params,
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
previous_node_id=previous_node_id
)
try: try:
# run node # run node
yield from self._run_node( yield from self._run_node(
node_instance=node_instance,
route_node_state=route_node_state, route_node_state=route_node_state,
previous_node_id=previous_route_node_state.node_id if previous_route_node_state else None,
parallel_id=in_parallel_id, parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
) )
@ -166,6 +191,10 @@ class GraphEngine:
route_node_state.status = RouteNodeState.Status.FAILED route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = str(e) route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent( yield NodeRunFailedEvent(
error=str(e),
node_id=next_node_id,
node_type=node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=in_parallel_id, parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
@ -241,7 +270,7 @@ class GraphEngine:
# new thread # new thread
for edge in edge_mappings: for edge in edge_mappings:
threading.Thread(target=self._run_parallel_node, kwargs={ threading.Thread(target=self._run_parallel_node, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'parallel_id': parallel_id, 'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id, 'parallel_start_node_id': edge.target_node_id,
'q': q 'q': q
@ -309,21 +338,21 @@ class GraphEngine:
q.put(ParallelBranchRunFailedEvent( q.put(ParallelBranchRunFailedEvent(
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
reason=e.error error=e.error
)) ))
except Exception as e: except Exception as e:
logger.exception("Unknown Error when generating in parallel") logger.exception("Unknown Error when generating in parallel")
q.put(ParallelBranchRunFailedEvent( q.put(ParallelBranchRunFailedEvent(
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
reason=str(e) error=str(e)
)) ))
finally: finally:
db.session.remove() db.session.remove()
def _run_node(self, def _run_node(self,
node_instance: BaseNode,
route_node_state: RouteNodeState, route_node_state: RouteNodeState,
previous_node_id: Optional[str] = None,
parallel_id: Optional[str] = None, parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
""" """
@ -331,46 +360,15 @@ class GraphEngine:
""" """
# trigger node run start event # trigger node run start event
yield NodeRunStartedEvent( yield NodeRunStartedEvent(
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state, route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
) )
# get node config
node_id = route_node_state.node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = f'Node {node_id} config not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
return
# convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
return
# init workflow run state
node_instance = node_cls( # type: ignore
config=node_config,
graph_init_params=self.init_params,
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
previous_node_id=previous_node_id
)
db.session.close() db.session.close()
self.graph_runtime_state.node_run_steps += 1 self.graph_runtime_state.node_run_steps += 1
@ -385,6 +383,10 @@ class GraphEngine:
if run_result.status == WorkflowNodeExecutionStatus.FAILED: if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent( yield NodeRunFailedEvent(
error=route_node_state.failed_reason,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
@ -401,12 +403,15 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items(): for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively # append variables to variable pool recursively
self._append_variables_recursively( self._append_variables_recursively(
node_id=node_id, node_id=node_instance.node_id,
variable_key_list=[variable_key], variable_key_list=[variable_key],
variable_value=variable_value variable_value=variable_value
) )
yield NodeRunSucceededEvent( yield NodeRunSucceededEvent(
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
@ -415,6 +420,9 @@ class GraphEngine:
break break
elif isinstance(item, RunStreamChunkEvent): elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent( yield NodeRunStreamChunkEvent(
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content, chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector, from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state, route_node_state=route_node_state,
@ -423,17 +431,28 @@ class GraphEngine:
) )
elif isinstance(item, RunRetrieverResourceEvent): elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent( yield NodeRunRetrieverResourceEvent(
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources, retriever_resources=item.retriever_resources,
context=item.context, context=item.context,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
) )
elif isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
except GenerateTaskStoppedException: except GenerateTaskStoppedException:
# trigger node run failed event # trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped." route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent( yield NodeRunFailedEvent(
error="Workflow stopped.",
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,

View File

@ -11,21 +11,20 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AnswerStreamProcessor: class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph super().__init__(graph, variable_pool)
self.variable_pool = variable_pool
self.generate_routes = graph.answer_stream_generate_routes self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {} self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0 self.route_position[answer_node_id] = 0
self.rest_node_ids = graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self, def process(self,
@ -74,58 +73,6 @@ class AnswerStreamProcessor:
self.rest_node_ids = self.graph.node_ids.copy() self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {} self.current_stream_chunk_generating_node_ids = {}
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
def _generate_stream_outputs_when_node_finished(self, def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]: ) -> Generator[GraphEngineEvent, None, None]:
@ -138,8 +85,8 @@ class AnswerStreamProcessor:
# all depends on answer node id not in rest node ids # all depends on answer node id not in rest node ids
if (event.route_node_state.node_id != answer_node_id if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))): for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue continue
route_position = self.route_position[answer_node_id] route_position = self.route_position[answer_node_id]
@ -149,6 +96,9 @@ class AnswerStreamProcessor:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT: if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk) route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent( yield NodeRunStreamChunkEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=route_chunk.text, chunk_content=route_chunk.text,
route_node_state=event.route_node_state, route_node_state=event.route_node_state,
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
@ -171,6 +121,9 @@ class AnswerStreamProcessor:
if text: if text:
yield NodeRunStreamChunkEvent( yield NodeRunStreamChunkEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text, chunk_content=text,
from_variable_selector=value_selector, from_variable_selector=value_selector,
route_node_state=event.route_node_state, route_node_state=event.route_node_state,

View File

@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
@abstractmethod
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)

View File

@ -4,6 +4,7 @@ from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@ -42,14 +43,14 @@ class BaseNode(ABC):
@abstractmethod @abstractmethod
def _run(self) \ def _run(self) \
-> NodeRunResult | Generator[RunEvent, None, None]: -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
""" """
Run node Run node
:return: :return:
""" """
raise NotImplementedError raise NotImplementedError
def run(self) -> Generator[RunEvent, None, None]: def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
""" """
Run node entry Run node entry
:return: :return:

View File

@ -9,18 +9,17 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EndStreamProcessor: class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph super().__init__(graph, variable_pool)
self.variable_pool = variable_pool
self.stream_param = graph.end_stream_param self.stream_param = graph.end_stream_param
self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy() self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
self.rest_node_ids = graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self, def process(self,
@ -56,64 +55,10 @@ class EndStreamProcessor:
yield event yield event
def reset(self) -> None: def reset(self) -> None:
self.end_streamed_variable_selectors = {} self.end_streamed_variable_selectors = self.graph.end_stream_param.end_stream_variable_selector_mapping.copy()
self.end_streamed_variable_selectors: dict[str, list[str]] = {
end_node_id: [] for end_node_id in self.graph.end_stream_param.end_stream_variable_selector_mapping
}
self.rest_node_ids = self.graph.node_ids.copy() self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {} self.current_stream_chunk_generating_node_ids = {}
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
""" """
Is stream out support Is stream out support

View File

@ -1,13 +1,16 @@
import logging import logging
from collections.abc import Generator
from typing import Any, cast from typing import Any, cast
from configs import dify_config from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
BaseGraphEvent, BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
GraphRunFailedEvent, GraphRunFailedEvent,
InNodeEvent,
IterationRunFailedEvent, IterationRunFailedEvent,
IterationRunNextEvent, IterationRunNextEvent,
IterationRunStartedEvent, IterationRunStartedEvent,
@ -17,7 +20,7 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -32,7 +35,7 @@ class IterationNode(BaseNode):
_node_data_cls = IterationNodeData _node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION _node_type = NodeType.ITERATION
def _run(self) -> BaseIterationState: def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
""" """
Run the node. Run the node.
""" """
@ -42,6 +45,10 @@ class IterationNode(BaseNode):
if not isinstance(iterator_list_value, list): if not isinstance(iterator_list_value, list):
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
inputs = {
"iterator_selector": iterator_list_value
}
root_node_id = self.node_data.start_node_id root_node_id = self.node_data.start_node_id
graph_config = self.graph_config graph_config = self.graph_config
@ -117,21 +124,42 @@ class IterationNode(BaseNode):
) )
yield IterationRunStartedEvent( yield IterationRunStartedEvent(
iteration_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
metadata={
"iterator_length": 1
},
predecessor_node_id=self.previous_node_id
) )
yield IterationRunNextEvent( yield IterationRunNextEvent(
iteration_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=0, index=0,
output=None pre_iteration_output=None
) )
outputs: list[Any] = []
try: try:
# run workflow # run workflow
rst = graph_engine.run() rst = graph_engine.run()
outputs: list[Any] = []
for event in rst: for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if isinstance(event, NodeRunSucceededEvent): if isinstance(event, NodeRunSucceededEvent):
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
event.route_node_state.node_run_result.metadata = metadata
yield event yield event
# handle iteration run result # handle iteration run result
@ -158,22 +186,35 @@ class IterationNode(BaseNode):
) )
yield IterationRunNextEvent( yield IterationRunNextEvent(
iteration_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index, index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None pre_iteration_output=jsonable_encoder(
current_iteration_output) if current_iteration_output else None
) )
elif isinstance(event, BaseGraphEvent): elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent): if isinstance(event, GraphRunFailedEvent):
# iteration run failed # iteration run failed
yield IterationRunFailedEvent( yield IterationRunFailedEvent(
iteration_id=self.node_id, iteration_node_id=self.node_id,
reason=event.reason, iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
error=event.error,
) )
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
error=event.reason, error=event.error,
) )
) )
break break
@ -181,7 +222,17 @@ class IterationNode(BaseNode):
yield event yield event
yield IterationRunSucceededEvent( yield IterationRunSucceededEvent(
iteration_id=self.node_id, iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
}
) )
yield RunCompletedEvent( yield RunCompletedEvent(
@ -196,8 +247,18 @@ class IterationNode(BaseNode):
# iteration run failed # iteration run failed
logger.exception("Iteration run failed") logger.exception("Iteration run failed")
yield IterationRunFailedEvent( yield IterationRunFailedEvent(
iteration_id=self.node_id, iteration_node_id=self.node_id,
reason=str(e), iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
error=str(e),
) )
yield RunCompletedEvent( yield RunCompletedEvent(

View File

@ -27,6 +27,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import ( from core.workflow.nodes.llm.entities import (
@ -42,11 +43,19 @@ from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage
class LLMNode(BaseNode): class LLMNode(BaseNode):
_node_data_cls = LLMNodeData _node_data_cls = LLMNodeData
_node_type = NodeType.LLM _node_type = NodeType.LLM
def _run(self) -> Generator[RunEvent, None, None]: def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
""" """
Run node Run node
:return: :return:
@ -167,7 +176,7 @@ class LLMNode(BaseNode):
model_instance: ModelInstance, model_instance: ModelInstance,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None) \ stop: Optional[list[str]] = None) \
-> Generator[RunEvent, None, None]: -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
""" """
Invoke large language model Invoke large language model
:param node_data_model: node data model :param node_data_model: node data model
@ -201,7 +210,7 @@ class LLMNode(BaseNode):
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
-> Generator[RunEvent, None, None]: -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result
@ -762,11 +771,3 @@ class LLMNode(BaseNode):
} }
} }
} }
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage

View File

@ -26,24 +26,21 @@ logger = logging.getLogger(__name__)
class WorkflowEntry: class WorkflowEntry:
def run( def __init__(
self, self,
*,
workflow: Workflow, workflow: Workflow,
user_id: str, user_id: str,
user_from: UserFrom, user_from: UserFrom,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
user_inputs: Mapping[str, Any], user_inputs: Mapping[str, Any],
system_inputs: Mapping[SystemVariable, Any], system_inputs: Mapping[SystemVariable, Any],
callbacks: Sequence[WorkflowCallback],
call_depth: int = 0 call_depth: int = 0
) -> Generator[GraphEngineEvent, None, None]: ) -> None:
""" """
:param workflow: Workflow instance :param workflow: Workflow instance
:param user_id: user id :param user_id: user id
:param user_from: user from :param user_from: user from
:param invoke_from: invoke from service-api, web-app, debugger, explore :param invoke_from: invoke from service-api, web-app, debugger, explore
:param callbacks: workflow callbacks
:param user_inputs: user variables inputs :param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files :param system_inputs: system inputs, like: query, files
:param call_depth: call depth :param call_depth: call depth
@ -82,7 +79,7 @@ class WorkflowEntry:
) )
# init workflow run state # init workflow run state
graph_engine = GraphEngine( self.graph_engine = GraphEngine(
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
app_id=workflow.app_id, app_id=workflow.app_id,
workflow_type=WorkflowType.value_of(workflow.type), workflow_type=WorkflowType.value_of(workflow.type),
@ -98,6 +95,16 @@ class WorkflowEntry:
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
) )
def run(
self,
*,
callbacks: Sequence[WorkflowCallback],
) -> Generator[GraphEngineEvent, None, None]:
"""
:param callbacks: workflow callbacks
"""
graph_engine = self.graph_engine
try: try:
# run workflow # run workflow
generator = graph_engine.run() generator = graph_engine.run()
@ -105,7 +112,7 @@ class WorkflowEntry:
if callbacks: if callbacks:
for callback in callbacks: for callback in callbacks:
callback.on_event( callback.on_event(
graph=graph, graph=self.graph_engine.graph,
graph_init_params=graph_engine.init_params, graph_init_params=graph_engine.init_params,
graph_runtime_state=graph_engine.graph_runtime_state, graph_runtime_state=graph_engine.graph_runtime_state,
event=event event=event
@ -117,11 +124,11 @@ class WorkflowEntry:
if callbacks: if callbacks:
for callback in callbacks: for callback in callbacks:
callback.on_event( callback.on_event(
graph=graph, graph=self.graph_engine.graph,
graph_init_params=graph_engine.init_params, graph_init_params=graph_engine.init_params,
graph_runtime_state=graph_engine.graph_runtime_state, graph_runtime_state=graph_engine.graph_runtime_state,
event=GraphRunFailedEvent( event=GraphRunFailedEvent(
reason=str(e) error=str(e)
) )
) )
return return
@ -161,6 +168,9 @@ class WorkflowEntry:
node_type = NodeType.value_of(node_config.get('data', {}).get('type')) node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type) node_cls = node_classes.get(node_type)
if not node_cls:
raise ValueError(f'Node class not found for node type {node_type}')
# init workflow run state # init workflow run state
node_instance = node_cls( node_instance = node_cls(
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
@ -195,7 +205,7 @@ class WorkflowEntry:
node_instance=node_instance node_instance=node_instance
) )
# run node # run node TODO
node_run_result = node_instance.run( node_run_result = node_instance.run(
variable_pool=variable_pool variable_pool=variable_pool
) )

View File

@ -1,7 +1,7 @@
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime, timezone from datetime import datetime, timezone
from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
GraphEngineEvent, GraphEngineEvent,
@ -12,6 +12,7 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.start.entities import StartNodeData
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
@ -37,7 +38,14 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
parallel = graph.parallel_mapping.get(parallel_id) parallel = graph.parallel_mapping.get(parallel_id)
parallel_start_node_id = parallel.start_from_node_id if parallel else None parallel_start_node_id = parallel.start_from_node_id if parallel else None
node_config = graph.node_id_config_mapping[next_node_id]
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
mock_node_data = StartNodeData(**{"title": "demo", "variables": []})
yield NodeRunStartedEvent( yield NodeRunStartedEvent(
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=graph.node_parallel_mapping.get(next_node_id), parallel_id=graph.node_parallel_mapping.get(next_node_id),
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id
@ -47,6 +55,9 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
length = int(next_node_id[-1]) length = int(next_node_id[-1])
for i in range(0, length): for i in range(0, length):
yield NodeRunStreamChunkEvent( yield NodeRunStreamChunkEvent(
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
chunk_content=str(i), chunk_content=str(i),
route_node_state=route_node_state, route_node_state=route_node_state,
from_variable_selector=[next_node_id, "text"], from_variable_selector=[next_node_id, "text"],
@ -57,6 +68,9 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
route_node_state.status = RouteNodeState.Status.SUCCESS route_node_state.status = RouteNodeState.Status.SUCCESS
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
yield NodeRunSucceededEvent( yield NodeRunSucceededEvent(
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id parallel_start_node_id=parallel_start_node_id