add chatflow app event convert

This commit is contained in:
takatost 2024-07-31 02:21:35 +08:00
parent 0818b7b078
commit 917aacbf7f
19 changed files with 1566 additions and 239 deletions

View File

@ -1,6 +1,6 @@
import time
from collections.abc import Generator
from typing import Optional, Union
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -342,7 +342,7 @@ class AppRunner:
self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: dict,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> tuple[bool, dict, str]:

View File

View File

@ -0,0 +1,101 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import (
SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.model import App, AppMode
from models.workflow import Workflow
class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
"""
Advanced Chatbot App Config Entity.
"""
pass
class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
)
return app_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for advanced chat app model config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: if True, only structure validation will be performed
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config,
is_vision=False
)
related_config_keys.extend(current_related_config_keys)
# opening_statement
config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# return retriever resource
config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id,
config=config,
only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@ -0,0 +1,189 @@
import logging
import os
import uuid
from collections.abc import Generator
from typing import Union
from pydantic import ValidationError
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.apps.chatflow.app_runner import AdvancedChatAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def generate(
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
query = args['query']
if not isinstance(query, str):
raise ValueError('query must be a string')
query = query.replace('\x00', '')
inputs = args['inputs']
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', False)
}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# parse files
files = args['files'] if args.get('files') else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_model.id)
if invoke_from == InvokeFrom.DEBUGGER:
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream
)
def _generate(self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
is_first_conversation = False
if not conversation:
is_first_conversation = True
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
db.session.refresh(conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
)
try:
# chatbot app
runner = AdvancedChatAppRunner()
response = runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
raise
except ValidationError as e:
logger.exception("Validation Error when generating")
raise e
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException()
else:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
logger.exception(e)
raise e
except InvokeError as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
logger.exception("Error when generating")
raise e
except Exception as e:
logger.exception("Unknown Error when generating")
raise e
finally:
db.session.close()
return AdvancedChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)

View File

@ -0,0 +1,422 @@
import logging
import os
import time
from collections.abc import Generator, Mapping
from typing import Any, Optional, cast
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAnnotationReplyEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.graph_engine.entities.event import (
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class AdvancedChatAppRunner(AppRunner):
"""
AdvancedChat Application Runner
"""
def run(self, application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> Generator[AppQueueEvent, None, None]:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
:return:
"""
app_config = application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity
):
return
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
SystemVariable.FILES: files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
call_depth=application_generate_entity.call_depth
)
generator = workflow_entry.run(
callbacks=workflow_callbacks,
)
for event in generator:
if isinstance(event, GraphRunStartedEvent):
queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, GraphRunSucceededEvent):
queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, GraphRunFailedEvent):
queue_manager.publish(
QueueWorkflowFailedEvent(error=event.error),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, NodeRunStartedEvent):
queue_manager.publish(
QueueNodeStartedEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
predecessor_node_id=event.predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, NodeRunSucceededEvent):
queue_manager.publish(
QueueNodeSucceededEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result else {},
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, NodeRunFailedEvent):
queue_manager.publish(
QueueNodeFailedEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result
and event.route_node_state.node_run_result.error
else "Unknown error"
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, NodeRunStreamChunkEvent):
queue_manager.publish(
QueueTextChunkEvent(
text=event.chunk_content
), PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
queue_manager.publish(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources
), PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, ParallelBranchRunStartedEvent):
queue_manager.publish(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
queue_manager.publish(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, ParallelBranchRunFailedEvent):
queue_manager.publish(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
error=event.error
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, IterationRunStartedEvent):
queue_manager.publish(
QueueIterationStartEvent(
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, IterationRunNextEvent):
queue_manager.publish(
QueueIterationNextEvent(
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
),
PublishFrom.APPLICATION_MANAGER
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
queue_manager.publish(
QueueIterationCompletedEvent(
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None
),
PublishFrom.APPLICATION_MANAGER
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
# return workflow
return workflow
def handle_input_moderation(
self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
) -> bool:
"""
Handle input moderation
:param queue_manager: application queue manager
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
try:
# process sensitive_word_avoidance
_, inputs, query = self.moderation_for_inputs(
app_id=app_record.id,
tenant_id=app_generate_entity.app_config.tenant_id,
app_generate_entity=app_generate_entity,
inputs=inputs,
query=query,
message_id=message_id,
)
except ModerationException as e:
self._stream_output(
queue_manager=queue_manager,
text=str(e),
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
)
return True
return False
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param queue_manager: application queue manager
:param app_generate_entity: application generate entity
"""
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
query=query,
user_id=app_generate_entity.user_id,
invoke_from=app_generate_entity.invoke_from
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
)
self._stream_output(
queue_manager=queue_manager,
text=annotation_reply.content,
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _stream_output(self, queue_manager: AppQueueManager,
text: str,
stream: bool,
stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param text: text
:param stream: stream
:return:
"""
if stream:
index = 0
for token in text:
queue_manager.publish(
QueueTextChunkEvent(
text=token
), PublishFrom.APPLICATION_MANAGER
)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(
QueueTextChunkEvent(
text=text
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(
QueueStopEvent(stopped_by=stopped_by),
PublishFrom.APPLICATION_MANAGER
)

View File

@ -0,0 +1,450 @@
import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.graph_engine.entities.event import GraphRunStartedEvent, NodeRunStartedEvent
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: AdvancedChatTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariable, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(
self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
:param user: user
:param stream: stream
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
if isinstance(self._user, EndUser):
user_id = self._user.session_id
else:
user_id = self._user.id
self._workflow = workflow
self._conversation = conversation
self._message = message
self._workflow_system_variables = {
SystemVariable.QUERY: message.query,
SystemVariable.FILES: application_generate_entity.files,
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
}
self._task_state = AdvancedChatTaskState(
usage=LLMUsage.empty_usage()
)
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
def process(self):
"""
Process generate task pipeline.
:return:
"""
db.session.refresh(self._workflow)
db.session.refresh(self._user)
db.session.close()
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation,
self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
"""
Process blocking response.
:return:
"""
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
extras['metadata'] = stream_response.metadata
return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
data=ChatbotAppBlockingResponse.Data(
id=self._message.id,
mode=self._conversation.mode,
conversation_id=self._conversation.id,
message_id=self._message.id,
answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()),
**extras
)
)
else:
continue
raise Exception('Queue listening stopped unexpectedly.')
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
"""
To stream response.
:return:
"""
for stream_response in generator:
yield ChatbotAppStreamResponse(
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response
)
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
break
yield response
start_listener_time = time.time()
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
break
audio_trunk = publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
if isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message)
yield self._error_to_stream_response(err)
break
elif isinstance(event, GraphRunStartedEvent):
workflow_run = self._handle_workflow_start()
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._message.workflow_run_id = workflow_run.id
db.session.commit()
db.session.refresh(self._message)
db.session.close()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, NodeRunStartedEvent):
workflow_node_execution = self._handle_node_start(event)
yield self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
yield self._workflow_node_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, conversation_id=self._conversation.id, trace_manager=trace_manager
)
if workflow_run:
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
if workflow_run.status == WorkflowRunStatus.FAILED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
if isinstance(event, QueueStopEvent):
# Save message
self._save_message()
yield self._message_end_to_stream_response()
break
else:
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message()
yield self._message_end_to_stream_response()
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
continue
self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self) -> None:
"""
Save message.
:return:
"""
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit
self._message.total_price = usage.total_price
self._message.currency = usage.currency
db.session.commit()
message_was_created.send(
self._message,
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""
Message end to stream response.
:return:
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message.id,
**extras
)
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
:param text: text
:return: True if output moderation should direct output, otherwise False
"""
if self._output_moderation_handler:
if self._output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
QueueTextChunkEvent(
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
)
return True
else:
self._output_moderation_handler.append_new_token(text)
return False

View File

@ -46,7 +46,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[on_workflow_run_succeeded]", color='green')
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[on_workflow_run_failed] reason: {event.reason}", color='red')
self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red')
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(
graph=graph,

View File

@ -1,3 +1,4 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
@ -5,7 +6,7 @@ from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
class QueueEvent(str, Enum):
@ -31,6 +32,9 @@ class QueueEvent(str, Enum):
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
ERROR = "error"
PING = "ping"
STOP = "stop"
@ -38,7 +42,7 @@ class QueueEvent(str, Enum):
class AppQueueEvent(BaseModel):
"""
QueueEvent entity
QueueEvent abstract entity
"""
event: QueueEvent
@ -46,6 +50,7 @@ class AppQueueEvent(BaseModel):
class QueueLLMChunkEvent(AppQueueEvent):
"""
QueueLLMChunkEvent entity
Only for basic mode apps
"""
event: QueueEvent = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
@ -58,11 +63,15 @@ class QueueIterationStartEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
node_run_index: int
inputs: dict = None
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[dict] = None
metadata: Optional[Mapping[str, Any]] = None
class QueueIterationNextEvent(AppQueueEvent):
"""
@ -73,6 +82,10 @@ class QueueIterationNextEvent(AppQueueEvent):
index: int
node_id: str
node_type: NodeType
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
@ -93,13 +106,23 @@ class QueueIterationCompletedEvent(AppQueueEvent):
"""
QueueIterationCompletedEvent entity
"""
event:QueueEvent = QueueEvent.ITERATION_COMPLETED
event: QueueEvent = QueueEvent.ITERATION_COMPLETED
node_id: str
node_type: NodeType
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
node_run_index: int
outputs: dict
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: Optional[str] = None
class QueueTextChunkEvent(AppQueueEvent):
"""
@ -190,6 +213,10 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_data: BaseNodeData
node_run_index: int = 1
predecessor_node_id: Optional[str] = None
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
class QueueNodeSucceededEvent(AppQueueEvent):
@ -201,11 +228,15 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
execution_metadata: Optional[dict] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
error: Optional[str] = None
@ -219,10 +250,14 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
inputs: Optional[dict] = None
outputs: Optional[dict] = None
process_data: Optional[dict] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
error: str
@ -277,7 +312,7 @@ class QueueStopEvent(AppQueueEvent):
class QueueMessage(BaseModel):
"""
QueueMessage entity
QueueMessage abstract entity
"""
task_id: str
app_mode: str
@ -297,3 +332,34 @@ class WorkflowQueueMessage(QueueMessage):
WorkflowQueueMessage entity
"""
pass
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
parallel_start_node_id: str
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
parallel_start_node_id: str
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str
parallel_start_node_id: str
error: str

View File

@ -84,9 +84,9 @@ class NodeRunResult(BaseModel):
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict] = None # process data
process_data: Optional[Mapping[str, Any]] = None # process data
outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches

View File

@ -1,13 +1,17 @@
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class GraphEngineEvent(BaseModel):
pass
###########################################
# Graph Events
###########################################
@ -26,7 +30,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):
class GraphRunFailedEvent(BaseGraphEvent):
reason: str = Field(..., description="failed reason")
error: str = Field(..., description="failed reason")
###########################################
@ -35,16 +39,20 @@ class GraphRunFailedEvent(BaseGraphEvent):
class BaseNodeEvent(GraphEngineEvent):
node_id: str = Field(..., description="node id")
node_type: NodeType = Field(..., description="node type")
node_data: BaseNodeData = Field(..., description="node data")
route_node_state: RouteNodeState = Field(..., description="route node state")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
# iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
class NodeRunStartedEvent(BaseNodeEvent):
pass
predecessor_node_id: Optional[str] = None
"""predecessor node id"""
class NodeRunStreamChunkEvent(BaseNodeEvent):
@ -63,7 +71,7 @@ class NodeRunSucceededEvent(BaseNodeEvent):
class NodeRunFailedEvent(BaseNodeEvent):
pass
error: str = Field(..., description="error")
###########################################
@ -74,6 +82,7 @@ class NodeRunFailedEvent(BaseNodeEvent):
class BaseParallelBranchEvent(GraphEngineEvent):
parallel_id: str = Field(..., description="parallel id")
parallel_start_node_id: str = Field(..., description="parallel start node id")
in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
@ -85,7 +94,7 @@ class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
reason: str = Field(..., description="failed reason")
error: str = Field(..., description="failed reason")
###########################################
@ -94,11 +103,19 @@ class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
class BaseIterationEvent(GraphEngineEvent):
iteration_id: str = Field(..., description="iteration id")
iteration_node_id: str = Field(..., description="iteration node id")
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
iteration_node_data: BaseNodeData = Field(..., description="node data")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
class IterationRunStartedEvent(BaseIterationEvent):
pass
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
class IterationRunNextEvent(BaseIterationEvent):
@ -107,11 +124,18 @@ class IterationRunNextEvent(BaseIterationEvent):
class IterationRunSucceededEvent(BaseIterationEvent):
pass
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
class IterationRunFailedEvent(BaseIterationEvent):
reason: str = Field(..., description="failed reason")
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent

View File

@ -14,6 +14,7 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, U
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseIterationEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
@ -32,6 +33,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
@ -84,16 +86,16 @@ class GraphEngine:
yield GraphRunStartedEvent()
try:
stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
if self.init_params.workflow_type == WorkflowType.CHAT:
stream_processor = AnswerStreamProcessor(
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
)
stream_processor_cls = AnswerStreamProcessor
else:
stream_processor = EndStreamProcessor(
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
)
stream_processor_cls = EndStreamProcessor
stream_processor = stream_processor_cls(
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
)
# run graph
generator = stream_processor.process(
@ -104,21 +106,21 @@ class GraphEngine:
try:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.')
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
return
except Exception as e:
logger.exception(f"Graph run failed: {str(e)}")
yield GraphRunFailedEvent(reason=str(e))
yield GraphRunFailedEvent(error=str(e))
return
# trigger graph run success event
yield GraphRunSucceededEvent()
except GraphRunFailedError as e:
yield GraphRunFailedEvent(reason=e.error)
yield GraphRunFailedEvent(error=e.error)
return
except Exception as e:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(reason=str(e))
yield GraphRunFailedEvent(error=str(e))
raise e
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
@ -145,11 +147,34 @@ class GraphEngine:
node_id=next_node_id
)
# get node config
node_id = route_node_state.node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(f'Node {node_id} config not found.')
# convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
# init workflow run state
node_instance = node_cls( # type: ignore
config=node_config,
graph_init_params=self.init_params,
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
previous_node_id=previous_node_id
)
try:
# run node
yield from self._run_node(
node_instance=node_instance,
route_node_state=route_node_state,
previous_node_id=previous_route_node_state.node_id if previous_route_node_state else None,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
@ -166,6 +191,10 @@ class GraphEngine:
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
error=str(e),
node_id=next_node_id,
node_type=node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
@ -241,7 +270,7 @@ class GraphEngine:
# new thread
for edge in edge_mappings:
threading.Thread(target=self._run_parallel_node, kwargs={
'flask_app': current_app._get_current_object(),
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'q': q
@ -309,21 +338,21 @@ class GraphEngine:
q.put(ParallelBranchRunFailedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
reason=e.error
error=e.error
))
except Exception as e:
logger.exception("Unknown Error when generating in parallel")
q.put(ParallelBranchRunFailedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
reason=str(e)
error=str(e)
))
finally:
db.session.remove()
def _run_node(self,
node_instance: BaseNode,
route_node_state: RouteNodeState,
previous_node_id: Optional[str] = None,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
"""
@ -331,46 +360,15 @@ class GraphEngine:
"""
# trigger node run start event
yield NodeRunStartedEvent(
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
# get node config
node_id = route_node_state.node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = f'Node {node_id} config not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
return
# convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
)
return
# init workflow run state
node_instance = node_cls( # type: ignore
config=node_config,
graph_init_params=self.init_params,
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
previous_node_id=previous_node_id
)
db.session.close()
self.graph_runtime_state.node_run_steps += 1
@ -385,6 +383,10 @@ class GraphEngine:
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
@ -401,12 +403,15 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_id,
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value
)
yield NodeRunSucceededEvent(
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id
@ -415,6 +420,9 @@ class GraphEngine:
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
@ -423,17 +431,28 @@ class GraphEngine:
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
elif isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
except GenerateTaskStoppedException:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,

View File

@ -11,21 +11,20 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
class AnswerStreamProcessor:
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {}
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
self.route_position[answer_node_id] = 0
self.rest_node_ids = graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
@ -74,58 +73,6 @@ class AnswerStreamProcessor:
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
@ -138,8 +85,8 @@ class AnswerStreamProcessor:
# all depends on answer node id not in rest node ids
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue
route_position = self.route_position[answer_node_id]
@ -149,6 +96,9 @@ class AnswerStreamProcessor:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
@ -171,6 +121,9 @@ class AnswerStreamProcessor:
if text:
yield NodeRunStreamChunkEvent(
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=value_selector,
route_node_state=event.route_node_state,

View File

@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
@abstractmethod
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)

View File

@ -4,6 +4,7 @@ from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@ -42,14 +43,14 @@ class BaseNode(ABC):
@abstractmethod
def _run(self) \
-> NodeRunResult | Generator[RunEvent, None, None]:
-> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:return:
"""
raise NotImplementedError
def run(self) -> Generator[RunEvent, None, None]:
def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node entry
:return:

View File

@ -9,18 +9,17 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
class EndStreamProcessor:
class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
super().__init__(graph, variable_pool)
self.stream_param = graph.end_stream_param
self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
self.rest_node_ids = graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
@ -56,64 +55,10 @@ class EndStreamProcessor:
yield event
def reset(self) -> None:
self.end_streamed_variable_selectors = {}
self.end_streamed_variable_selectors: dict[str, list[str]] = {
end_node_id: [] for end_node_id in self.graph.end_stream_param.end_stream_variable_selector_mapping
}
self.end_streamed_variable_selectors = self.graph.end_stream_param.end_stream_variable_selector_mapping.copy()
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Is stream out support

View File

@ -1,13 +1,16 @@
import logging
from collections.abc import Generator
from typing import Any, cast
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
GraphRunFailedEvent,
InNodeEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
@ -17,7 +20,7 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus
@ -32,7 +35,7 @@ class IterationNode(BaseNode):
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
def _run(self) -> BaseIterationState:
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
@ -42,6 +45,10 @@ class IterationNode(BaseNode):
if not isinstance(iterator_list_value, list):
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
inputs = {
"iterator_selector": iterator_list_value
}
root_node_id = self.node_data.start_node_id
graph_config = self.graph_config
@ -117,21 +124,42 @@ class IterationNode(BaseNode):
)
yield IterationRunStartedEvent(
iteration_id=self.node_id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
metadata={
"iterator_length": 1
},
predecessor_node_id=self.previous_node_id
)
yield IterationRunNextEvent(
iteration_id=self.node_id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=0,
output=None
pre_iteration_output=None
)
outputs: list[Any] = []
try:
# run workflow
rst = graph_engine.run()
outputs: list[Any] = []
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if isinstance(event, NodeRunSucceededEvent):
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
event.route_node_state.node_run_result.metadata = metadata
yield event
# handle iteration run result
@ -158,22 +186,35 @@ class IterationNode(BaseNode):
)
yield IterationRunNextEvent(
iteration_id=self.node_id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None
pre_iteration_output=jsonable_encoder(
current_iteration_output) if current_iteration_output else None
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.node_id,
reason=event.reason,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.reason,
error=event.error,
)
)
break
@ -181,7 +222,17 @@ class IterationNode(BaseNode):
yield event
yield IterationRunSucceededEvent(
iteration_id=self.node_id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
}
)
yield RunCompletedEvent(
@ -196,8 +247,18 @@ class IterationNode(BaseNode):
# iteration run failed
logger.exception("Iteration run failed")
yield IterationRunFailedEvent(
iteration_id=self.node_id,
reason=str(e),
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
error=str(e),
)
yield RunCompletedEvent(

View File

@ -27,6 +27,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
@ -42,11 +43,19 @@ from models.provider import Provider, ProviderType
from models.workflow import WorkflowNodeExecutionStatus
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage
class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self) -> Generator[RunEvent, None, None]:
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:return:
@ -167,7 +176,7 @@ class LLMNode(BaseNode):
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None) \
-> Generator[RunEvent, None, None]:
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Invoke large language model
:param node_data_model: node data model
@ -201,7 +210,7 @@ class LLMNode(BaseNode):
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
-> Generator[RunEvent, None, None]:
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Handle invoke result
:param invoke_result: invoke result
@ -762,11 +771,3 @@ class LLMNode(BaseNode):
}
}
}
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage

View File

@ -26,24 +26,21 @@ logger = logging.getLogger(__name__)
class WorkflowEntry:
def run(
def __init__(
self,
*,
workflow: Workflow,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
user_inputs: Mapping[str, Any],
system_inputs: Mapping[SystemVariable, Any],
callbacks: Sequence[WorkflowCallback],
call_depth: int = 0
) -> Generator[GraphEngineEvent, None, None]:
) -> None:
"""
:param workflow: Workflow instance
:param user_id: user id
:param user_from: user from
:param invoke_from: invoke from service-api, web-app, debugger, explore
:param callbacks: workflow callbacks
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:param call_depth: call depth
@ -82,7 +79,7 @@ class WorkflowEntry:
)
# init workflow run state
graph_engine = GraphEngine(
self.graph_engine = GraphEngine(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_type=WorkflowType.value_of(workflow.type),
@ -98,6 +95,16 @@ class WorkflowEntry:
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
def run(
self,
*,
callbacks: Sequence[WorkflowCallback],
) -> Generator[GraphEngineEvent, None, None]:
"""
:param callbacks: workflow callbacks
"""
graph_engine = self.graph_engine
try:
# run workflow
generator = graph_engine.run()
@ -105,7 +112,7 @@ class WorkflowEntry:
if callbacks:
for callback in callbacks:
callback.on_event(
graph=graph,
graph=self.graph_engine.graph,
graph_init_params=graph_engine.init_params,
graph_runtime_state=graph_engine.graph_runtime_state,
event=event
@ -117,11 +124,11 @@ class WorkflowEntry:
if callbacks:
for callback in callbacks:
callback.on_event(
graph=graph,
graph=self.graph_engine.graph,
graph_init_params=graph_engine.init_params,
graph_runtime_state=graph_engine.graph_runtime_state,
event=GraphRunFailedEvent(
reason=str(e)
error=str(e)
)
)
return
@ -161,6 +168,9 @@ class WorkflowEntry:
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
raise ValueError(f'Node class not found for node type {node_type}')
# init workflow run state
node_instance = node_cls(
tenant_id=workflow.tenant_id,
@ -195,7 +205,7 @@ class WorkflowEntry:
node_instance=node_instance
)
# run node
# run node TODO
node_run_result = node_instance.run(
variable_pool=variable_pool
)

View File

@ -1,7 +1,7 @@
from collections.abc import Generator
from datetime import datetime, timezone
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@ -12,6 +12,7 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.start.entities import StartNodeData
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
@ -37,7 +38,14 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
parallel = graph.parallel_mapping.get(parallel_id)
parallel_start_node_id = parallel.start_from_node_id if parallel else None
node_config = graph.node_id_config_mapping[next_node_id]
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
mock_node_data = StartNodeData(**{"title": "demo", "variables": []})
yield NodeRunStartedEvent(
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
route_node_state=route_node_state,
parallel_id=graph.node_parallel_mapping.get(next_node_id),
parallel_start_node_id=parallel_start_node_id
@ -47,6 +55,9 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
length = int(next_node_id[-1])
for i in range(0, length):
yield NodeRunStreamChunkEvent(
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
chunk_content=str(i),
route_node_state=route_node_state,
from_variable_selector=[next_node_id, "text"],
@ -57,6 +68,9 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
route_node_state.status = RouteNodeState.Status.SUCCESS
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
yield NodeRunSucceededEvent(
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id