mirror of https://github.com/langgenius/dify.git
add chatflow app event convert
commit 917aacbf7f
parent 0818b7b078
@@ -1,6 +1,6 @@
 import time
-from collections.abc import Generator
-from typing import Optional, Union
+from collections.abc import Generator, Mapping
+from typing import Any, Optional, Union

 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -342,7 +342,7 @@ class AppRunner:
         self, app_id: str,
         tenant_id: str,
         app_generate_entity: AppGenerateEntity,
-        inputs: dict,
+        inputs: Mapping[str, Any],
         query: str,
         message_id: str,
     ) -> tuple[bool, dict, str]:
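A note on the hunk above: widening inputs from dict to Mapping[str, Any] advertises a read-only view, so a static type checker can reject accidental mutation of the caller's dict. This is a minimal runnable sketch of that idea, not code from the commit; the function name is illustrative.

from collections.abc import Mapping
from typing import Any


def check_inputs(inputs: Mapping[str, Any]) -> bool:
    # inputs["query"] = "x"  # a type checker flags this: Mapping has no __setitem__
    return all(value is not None for value in inputs.values())


print(check_inputs({"query": "hello"}))  # True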
@@ -0,0 +1,101 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
from core.app.app_config.features.suggested_questions_after_answer.manager import (
    SuggestedQuestionsAfterAnswerConfigManager,
)
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.model import App, AppMode
from models.workflow import Workflow


class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
    """
    Advanced Chatbot App Config Entity.
    """
    pass


class AdvancedChatAppConfigManager(BaseAppConfigManager):
    @classmethod
    def get_app_config(cls, app_model: App,
                       workflow: Workflow) -> AdvancedChatAppConfig:
        features_dict = workflow.features_dict

        app_mode = AppMode.value_of(app_model.mode)
        app_config = AdvancedChatAppConfig(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_mode=app_mode,
            workflow_id=workflow.id,
            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
                config=features_dict
            ),
            variables=WorkflowVariablesConfigManager.convert(
                workflow=workflow
            ),
            additional_features=cls.convert_features(features_dict, app_mode)
        )

        return app_config

    @classmethod
    def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
        """
        Validate for advanced chat app model config

        :param tenant_id: tenant id
        :param config: app model config args
        :param only_structure_validate: if True, only structure validation will be performed
        """
        related_config_keys = []

        # file upload validation
        config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
            config=config,
            is_vision=False
        )
        related_config_keys.extend(current_related_config_keys)

        # opening_statement
        config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # suggested_questions_after_answer
        config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
            config)
        related_config_keys.extend(current_related_config_keys)

        # speech_to_text
        config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # text_to_speech
        config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # return retriever resource
        config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
        related_config_keys.extend(current_related_config_keys)

        # moderation validation
        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
            tenant_id=tenant_id,
            config=config,
            only_structure_validate=only_structure_validate
        )
        related_config_keys.extend(current_related_config_keys)

        related_config_keys = list(set(related_config_keys))

        # Filter out extra parameters
        filtered_config = {key: config.get(key) for key in related_config_keys}

        return filtered_config
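The config_validate method above threads the config through a chain of feature validators, collects the keys each validator claims, then filters everything else out. A self-contained sketch of that validate-and-filter pattern follows; the two stand-in validators are illustrative only (the real ones are the *ConfigManager classes imported at the top of the file).

def _validate_opening_statement(config: dict) -> tuple[dict, list[str]]:
    # Each validator normalizes the keys it owns and reports them back.
    config.setdefault('opening_statement', '')
    return config, ['opening_statement']


def _validate_speech_to_text(config: dict) -> tuple[dict, list[str]]:
    config.setdefault('speech_to_text', {'enabled': False})
    return config, ['speech_to_text']


def config_validate(config: dict) -> dict:
    related_config_keys: list[str] = []
    for validator in (_validate_opening_statement, _validate_speech_to_text):
        config, keys = validator(config)
        related_config_keys.extend(keys)
    # Deduplicate, then drop any key no validator claimed.
    return {key: config.get(key) for key in sorted(set(related_config_keys))}


print(config_validate({'opening_statement': 'Hi', 'unknown_key': 1}))
# {'opening_statement': 'Hi', 'speech_to_text': {'enabled': False}}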
@@ -0,0 +1,189 @@
import logging
import os
import uuid
from collections.abc import Generator
from typing import Union

from pydantic import ValidationError

import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.apps.chatflow.app_runner import AdvancedChatAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser
from models.workflow import Workflow

logger = logging.getLogger(__name__)


class AdvancedChatAppGenerator(MessageBasedAppGenerator):
    def generate(
        self, app_model: App,
        workflow: Workflow,
        user: Union[Account, EndUser],
        args: dict,
        invoke_from: InvokeFrom,
        stream: bool = True,
    ) -> Union[dict, Generator[dict, None, None]]:
        """
        Generate App response.

        :param app_model: App
        :param workflow: Workflow
        :param user: account or end user
        :param args: request args
        :param invoke_from: invoke from source
        :param stream: is stream
        """
        if not args.get('query'):
            raise ValueError('query is required')

        query = args['query']
        if not isinstance(query, str):
            raise ValueError('query must be a string')

        query = query.replace('\x00', '')
        inputs = args['inputs']

        extras = {
            "auto_generate_conversation_name": args.get('auto_generate_name', False)
        }

        # get conversation
        conversation = None
        if args.get('conversation_id'):
            conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)

        # parse files
        files = args['files'] if args.get('files') else []
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
        file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
        if file_extra_config:
            file_objs = message_file_parser.validate_and_transform_files_arg(
                files,
                file_extra_config,
                user
            )
        else:
            file_objs = []

        # convert to app config
        app_config = AdvancedChatAppConfigManager.get_app_config(
            app_model=app_model,
            workflow=workflow
        )

        # get tracing instance
        trace_manager = TraceQueueManager(app_id=app_model.id)

        if invoke_from == InvokeFrom.DEBUGGER:
            # always enable retriever resource in debugger mode
            app_config.additional_features.show_retrieve_source = True

        # init application generate entity
        application_generate_entity = AdvancedChatAppGenerateEntity(
            task_id=str(uuid.uuid4()),
            app_config=app_config,
            conversation_id=conversation.id if conversation else None,
            inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
            query=query,
            files=file_objs,
            user_id=user.id,
            stream=stream,
            invoke_from=invoke_from,
            extras=extras,
            trace_manager=trace_manager
        )
        contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

        return self._generate(
            app_model=app_model,
            workflow=workflow,
            user=user,
            invoke_from=invoke_from,
            application_generate_entity=application_generate_entity,
            conversation=conversation,
            stream=stream
        )

    def _generate(self, app_model: App,
                  workflow: Workflow,
                  user: Union[Account, EndUser],
                  invoke_from: InvokeFrom,
                  application_generate_entity: AdvancedChatAppGenerateEntity,
                  conversation: Conversation = None,
                  stream: bool = True) \
            -> Union[dict, Generator[dict, None, None]]:
        is_first_conversation = False
        if not conversation:
            is_first_conversation = True

        # init generate records
        (
            conversation,
            message
        ) = self._init_generate_records(application_generate_entity, conversation)

        if is_first_conversation:
            # update conversation features
            conversation.override_model_configs = workflow.features
            db.session.commit()
            db.session.refresh(conversation)

        # init queue manager
        queue_manager = MessageBasedAppQueueManager(
            task_id=application_generate_entity.task_id,
            user_id=application_generate_entity.user_id,
            invoke_from=application_generate_entity.invoke_from,
            conversation_id=conversation.id,
            app_mode=conversation.mode,
            message_id=message.id
        )

        try:
            # chatbot app
            runner = AdvancedChatAppRunner()
            response = runner.run(
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                conversation=conversation,
                message=message
            )
        except GenerateTaskStoppedException:
            pass
        except InvokeAuthorizationError:
            raise
        except ValidationError as e:
            logger.exception("Validation Error when generating")
            raise e
        except ValueError as e:
            if e.args[0] == "I/O operation on closed file.":  # ignore this error
                raise GenerateTaskStoppedException()
            else:
                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
                    logger.exception(e)
                raise e
        except InvokeError as e:
            if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
                logger.exception("Error when generating")
            raise e
        except Exception as e:
            logger.exception("Unknown Error when generating")
            raise e
        finally:
            db.session.close()

        return AdvancedChatAppGenerateResponseConverter.convert(
            response=response,
            invoke_from=invoke_from
        )
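The argument handling at the top of generate() above is the part callers most often get wrong. This is a runnable sketch of just that step, mirroring the checks in the new file; the helper name parse_generate_args is illustrative and not part of the codebase.

def parse_generate_args(args: dict) -> tuple[str, dict, dict]:
    # Require a non-empty string query, as generate() does.
    if not args.get('query'):
        raise ValueError('query is required')
    query = args['query']
    if not isinstance(query, str):
        raise ValueError('query must be a string')
    query = query.replace('\x00', '')  # strip NUL bytes defensively
    extras = {'auto_generate_conversation_name': args.get('auto_generate_name', False)}
    return query, dict(args.get('inputs') or {}), extras


query, inputs, extras = parse_generate_args({'query': 'hello\x00world', 'inputs': {}})
assert query == 'helloworld'
assert extras == {'auto_generate_conversation_name': False}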
@@ -0,0 +1,422 @@
import logging
import os
import time
from collections.abc import Generator, Mapping
from typing import Any, Optional, cast

from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
    AdvancedChatAppGenerateEntity,
    InvokeFrom,
)
from core.app.entities.queue_entities import (
    AppQueueEvent,
    QueueAnnotationReplyEvent,
    QueueIterationCompletedEvent,
    QueueIterationNextEvent,
    QueueIterationStartEvent,
    QueueNodeFailedEvent,
    QueueNodeStartedEvent,
    QueueNodeSucceededEvent,
    QueueParallelBranchRunFailedEvent,
    QueueParallelBranchRunStartedEvent,
    QueueRetrieverResourcesEvent,
    QueueStopEvent,
    QueueTextChunkEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowStartedEvent,
    QueueWorkflowSucceededEvent,
)
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.graph_engine.entities.event import (
    GraphRunFailedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    IterationRunFailedEvent,
    IterationRunNextEvent,
    IterationRunStartedEvent,
    IterationRunSucceededEvent,
    NodeRunFailedEvent,
    NodeRunRetrieverResourceEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
    ParallelBranchRunFailedEvent,
    ParallelBranchRunStartedEvent,
    ParallelBranchRunSucceededEvent,
)
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow

logger = logging.getLogger(__name__)


class AdvancedChatAppRunner(AppRunner):
    """
    AdvancedChat Application Runner
    """

    def run(self, application_generate_entity: AdvancedChatAppGenerateEntity,
            queue_manager: AppQueueManager,
            conversation: Conversation,
            message: Message) -> Generator[AppQueueEvent, None, None]:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_config = application_generate_entity.app_config
        app_config = cast(AdvancedChatAppConfig, app_config)

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
        if not workflow:
            raise ValueError("Workflow not initialized")

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        user_id = None
        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
            if end_user:
                user_id = end_user.session_id
        else:
            user_id = application_generate_entity.user_id

        # moderation
        if self.handle_input_moderation(
            queue_manager=queue_manager,
            app_record=app_record,
            app_generate_entity=application_generate_entity,
            inputs=inputs,
            query=query,
            message_id=message.id
        ):
            return

        # annotation reply
        if self.handle_annotation_reply(
            app_record=app_record,
            message=message,
            query=query,
            queue_manager=queue_manager,
            app_generate_entity=application_generate_entity
        ):
            return

        db.session.close()

        workflow_callbacks: list[WorkflowCallback] = []
        if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
            workflow_callbacks.append(WorkflowLoggingCallback())

        # RUN WORKFLOW
        workflow_entry = WorkflowEntry(
            workflow=workflow,
            user_id=application_generate_entity.user_id,
            user_from=UserFrom.ACCOUNT
            if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
            else UserFrom.END_USER,
            invoke_from=application_generate_entity.invoke_from,
            user_inputs=inputs,
            system_inputs={
                SystemVariable.QUERY: query,
                SystemVariable.FILES: files,
                SystemVariable.CONVERSATION_ID: conversation.id,
                SystemVariable.USER_ID: user_id
            },
            call_depth=application_generate_entity.call_depth
        )
        generator = workflow_entry.run(
            callbacks=workflow_callbacks,
        )

        for event in generator:
            if isinstance(event, GraphRunStartedEvent):
                queue_manager.publish(
                    QueueWorkflowStartedEvent(),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, GraphRunSucceededEvent):
                queue_manager.publish(
                    QueueWorkflowSucceededEvent(),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, GraphRunFailedEvent):
                queue_manager.publish(
                    QueueWorkflowFailedEvent(error=event.error),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, NodeRunStartedEvent):
                queue_manager.publish(
                    QueueNodeStartedEvent(
                        node_id=event.node_id,
                        node_type=event.node_type,
                        node_data=event.node_data,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                        node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                        predecessor_node_id=event.predecessor_node_id
                    ),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, NodeRunSucceededEvent):
                queue_manager.publish(
                    QueueNodeSucceededEvent(
                        node_id=event.node_id,
                        node_type=event.node_type,
                        node_data=event.node_data,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                        inputs=event.route_node_state.node_run_result.inputs
                        if event.route_node_state.node_run_result else {},
                        process_data=event.route_node_state.node_run_result.process_data
                        if event.route_node_state.node_run_result else {},
                        outputs=event.route_node_state.node_run_result.outputs
                        if event.route_node_state.node_run_result else {},
                        execution_metadata=event.route_node_state.node_run_result.metadata
                        if event.route_node_state.node_run_result else {},
                    ),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, NodeRunFailedEvent):
                queue_manager.publish(
                    QueueNodeFailedEvent(
                        node_id=event.node_id,
                        node_type=event.node_type,
                        node_data=event.node_data,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                        inputs=event.route_node_state.node_run_result.inputs
                        if event.route_node_state.node_run_result else {},
                        process_data=event.route_node_state.node_run_result.process_data
                        if event.route_node_state.node_run_result else {},
                        outputs=event.route_node_state.node_run_result.outputs
                        if event.route_node_state.node_run_result else {},
                        error=event.route_node_state.node_run_result.error
                        if event.route_node_state.node_run_result
                        and event.route_node_state.node_run_result.error
                        else "Unknown error"
                    ),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, NodeRunStreamChunkEvent):
                queue_manager.publish(
                    QueueTextChunkEvent(
                        text=event.chunk_content
                    ), PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, NodeRunRetrieverResourceEvent):
                queue_manager.publish(
                    QueueRetrieverResourcesEvent(
                        retriever_resources=event.retriever_resources
                    ), PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, ParallelBranchRunStartedEvent):
                queue_manager.publish(
                    QueueParallelBranchRunStartedEvent(
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id
                    ),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, ParallelBranchRunSucceededEvent):
                queue_manager.publish(
                    QueueParallelBranchRunStartedEvent(
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id
                    ),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, ParallelBranchRunFailedEvent):
                queue_manager.publish(
                    QueueParallelBranchRunFailedEvent(
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                        error=event.error
                    ),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, IterationRunStartedEvent):
                queue_manager.publish(
                    QueueIterationStartEvent(
                        node_id=event.iteration_node_id,
                        node_type=event.iteration_node_type,
                        node_data=event.iteration_node_data,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                        node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                        inputs=event.inputs,
                        predecessor_node_id=event.predecessor_node_id,
                        metadata=event.metadata
                    ),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, IterationRunNextEvent):
                queue_manager.publish(
                    QueueIterationNextEvent(
                        node_id=event.iteration_node_id,
                        node_type=event.iteration_node_type,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                        index=event.index,
                        node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                        output=event.pre_iteration_output,
                    ),
                    PublishFrom.APPLICATION_MANAGER
                )
            elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
                queue_manager.publish(
                    QueueIterationCompletedEvent(
                        node_id=event.iteration_node_id,
                        node_type=event.iteration_node_type,
                        parallel_id=event.parallel_id,
                        parallel_start_node_id=event.parallel_start_node_id,
                        node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                        inputs=event.inputs,
                        outputs=event.outputs,
                        metadata=event.metadata,
                        steps=event.steps,
                        error=event.error if isinstance(event, IterationRunFailedEvent) else None
                    ),
                    PublishFrom.APPLICATION_MANAGER
                )

    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
        """
        Get workflow
        """
        # fetch workflow by workflow_id
        workflow = db.session.query(Workflow).filter(
            Workflow.tenant_id == app_model.tenant_id,
            Workflow.app_id == app_model.id,
            Workflow.id == workflow_id
        ).first()

        # return workflow
        return workflow

    def handle_input_moderation(
        self, queue_manager: AppQueueManager,
        app_record: App,
        app_generate_entity: AdvancedChatAppGenerateEntity,
        inputs: Mapping[str, Any],
        query: str,
        message_id: str
    ) -> bool:
        """
        Handle input moderation
        :param queue_manager: application queue manager
        :param app_record: app record
        :param app_generate_entity: application generate entity
        :param inputs: inputs
        :param query: query
        :param message_id: message id
        :return:
        """
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=app_generate_entity.app_config.tenant_id,
                app_generate_entity=app_generate_entity,
                inputs=inputs,
                query=query,
                message_id=message_id,
            )
        except ModerationException as e:
            self._stream_output(
                queue_manager=queue_manager,
                text=str(e),
                stream=app_generate_entity.stream,
                stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
            )
            return True

        return False

    def handle_annotation_reply(self, app_record: App,
                                message: Message,
                                query: str,
                                queue_manager: AppQueueManager,
                                app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
        """
        Handle annotation reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param queue_manager: application queue manager
        :param app_generate_entity: application generate entity
        """
        # annotation reply
        annotation_reply = self.query_app_annotations_to_reply(
            app_record=app_record,
            message=message,
            query=query,
            user_id=app_generate_entity.user_id,
            invoke_from=app_generate_entity.invoke_from
        )

        if annotation_reply:
            queue_manager.publish(
                QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
                PublishFrom.APPLICATION_MANAGER
            )

            self._stream_output(
                queue_manager=queue_manager,
                text=annotation_reply.content,
                stream=app_generate_entity.stream,
                stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
            )
            return True

        return False

    def _stream_output(self, queue_manager: AppQueueManager,
                       text: str,
                       stream: bool,
                       stopped_by: QueueStopEvent.StopBy) -> None:
        """
        Direct output
        :param queue_manager: application queue manager
        :param text: text
        :param stream: stream
        :return:
        """
        if stream:
            index = 0
            for token in text:
                queue_manager.publish(
                    QueueTextChunkEvent(
                        text=token
                    ), PublishFrom.APPLICATION_MANAGER
                )
                index += 1
                time.sleep(0.01)
        else:
            queue_manager.publish(
                QueueTextChunkEvent(
                    text=text
                ), PublishFrom.APPLICATION_MANAGER
            )

        queue_manager.publish(
            QueueStopEvent(stopped_by=stopped_by),
            PublishFrom.APPLICATION_MANAGER
        )
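The long isinstance chain in run() above is the heart of this commit: a one-way translation from graph-engine events into the queue events that downstream consumers understand. This is a self-contained sketch of that conversion shape with simplified stand-in classes, not the real entities.

from dataclasses import dataclass


@dataclass
class GraphRunStartedEvent: ...


@dataclass
class GraphRunFailedEvent:
    error: str


@dataclass
class QueueWorkflowStartedEvent: ...


@dataclass
class QueueWorkflowFailedEvent:
    error: str


def convert(event):
    # Map each engine-side event to its queue-side counterpart.
    if isinstance(event, GraphRunStartedEvent):
        return QueueWorkflowStartedEvent()
    if isinstance(event, GraphRunFailedEvent):
        # carry the failure reason through to the queue event
        return QueueWorkflowFailedEvent(error=event.error)
    return None


assert isinstance(convert(GraphRunStartedEvent()), QueueWorkflowStartedEvent)
assert convert(GraphRunFailedEvent(error='boom')).error == 'boom'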
@@ -0,0 +1,450 @@
import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union

from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
    AdvancedChatAppGenerateEntity,
)
from core.app.entities.queue_entities import (
    QueueAdvancedChatMessageEndEvent,
    QueueAnnotationReplyEvent,
    QueueErrorEvent,
    QueueIterationCompletedEvent,
    QueueIterationNextEvent,
    QueueIterationStartEvent,
    QueueMessageReplaceEvent,
    QueueNodeFailedEvent,
    QueueNodeSucceededEvent,
    QueuePingEvent,
    QueueRetrieverResourcesEvent,
    QueueStopEvent,
    QueueTextChunkEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
    AdvancedChatTaskState,
    ChatbotAppBlockingResponse,
    ChatbotAppStreamResponse,
    ErrorStreamResponse,
    MessageAudioEndStreamResponse,
    MessageAudioStreamResponse,
    MessageEndStreamResponse,
    StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.graph_engine.entities.event import GraphRunStartedEvent, NodeRunStartedEvent
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
    Workflow,
    WorkflowRunStatus,
)

logger = logging.getLogger(__name__)


class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
    """
    AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
    """
    _task_state: AdvancedChatTaskState
    _application_generate_entity: AdvancedChatAppGenerateEntity
    _workflow: Workflow
    _user: Union[Account, EndUser]
    _workflow_system_variables: dict[SystemVariable, Any]
    _iteration_nested_relations: dict[str, list[str]]

    def __init__(
        self, application_generate_entity: AdvancedChatAppGenerateEntity,
        workflow: Workflow,
        queue_manager: AppQueueManager,
        conversation: Conversation,
        message: Message,
        user: Union[Account, EndUser],
        stream: bool
    ) -> None:
        """
        Initialize AdvancedChatAppGenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param workflow: workflow
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        :param user: user
        :param stream: stream
        """
        super().__init__(application_generate_entity, queue_manager, user, stream)

        if isinstance(self._user, EndUser):
            user_id = self._user.session_id
        else:
            user_id = self._user.id

        self._workflow = workflow
        self._conversation = conversation
        self._message = message
        self._workflow_system_variables = {
            SystemVariable.QUERY: message.query,
            SystemVariable.FILES: application_generate_entity.files,
            SystemVariable.CONVERSATION_ID: conversation.id,
            SystemVariable.USER_ID: user_id
        }

        self._task_state = AdvancedChatTaskState(
            usage=LLMUsage.empty_usage()
        )

        self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
        self._stream_generate_routes = self._get_stream_generate_routes()
        self._conversation_name_generate_thread = None

    def process(self):
        """
        Process generate task pipeline.
        :return:
        """
        db.session.refresh(self._workflow)
        db.session.refresh(self._user)
        db.session.close()

        # start generate conversation name thread
        self._conversation_name_generate_thread = self._generate_conversation_name(
            self._conversation,
            self._application_generate_entity.query
        )

        generator = self._wrapper_process_stream_response(
            trace_manager=self._application_generate_entity.trace_manager
        )
        if self._stream:
            return self._to_stream_response(generator)
        else:
            return self._to_blocking_response(generator)

    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
        """
        Process blocking response.
        :return:
        """
        for stream_response in generator:
            if isinstance(stream_response, ErrorStreamResponse):
                raise stream_response.err
            elif isinstance(stream_response, MessageEndStreamResponse):
                extras = {}
                if stream_response.metadata:
                    extras['metadata'] = stream_response.metadata

                return ChatbotAppBlockingResponse(
                    task_id=stream_response.task_id,
                    data=ChatbotAppBlockingResponse.Data(
                        id=self._message.id,
                        mode=self._conversation.mode,
                        conversation_id=self._conversation.id,
                        message_id=self._message.id,
                        answer=self._task_state.answer,
                        created_at=int(self._message.created_at.timestamp()),
                        **extras
                    )
                )
            else:
                continue

        raise Exception('Queue listening stopped unexpectedly.')

    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
        """
        To stream response.
        :return:
        """
        for stream_response in generator:
            yield ChatbotAppStreamResponse(
                conversation_id=self._conversation.id,
                message_id=self._message.id,
                created_at=int(self._message.created_at.timestamp()),
                stream_response=stream_response
            )

    def _listenAudioMsg(self, publisher, task_id: str):
        if not publisher:
            return None
        audio_msg: AudioTrunk = publisher.checkAndGetAudio()
        if audio_msg and audio_msg.status != "finish":
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None

    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
            Generator[StreamResponse, None, None]:

        publisher = None
        task_id = self._application_generate_entity.task_id
        tenant_id = self._application_generate_entity.app_config.tenant_id
        features_dict = self._workflow.features_dict

        if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
                'text_to_speech'].get('autoPlay') == 'enabled':
            publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
            while True:
                audio_response = self._listenAudioMsg(publisher, task_id=task_id)
                if audio_response:
                    yield audio_response
                else:
                    break
            yield response

        start_listener_time = time.time()
        # timeout
        while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
            try:
                if not publisher:
                    break
                audio_trunk = publisher.checkAndGetAudio()
                if audio_trunk is None:
                    # release cpu
                    # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
                    time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
                    continue
                if audio_trunk.status == "finish":
                    break
                else:
                    start_listener_time = time.time()
                    yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
            except Exception as e:
                logger.error(e)
                break
        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)

    def _process_stream_response(
        self,
        publisher: AppGeneratorTTSPublisher,
        trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            if publisher:
                publisher.publish(message=message)
            event = message.event

            if isinstance(event, QueueErrorEvent):
                err = self._handle_error(event, self._message)
                yield self._error_to_stream_response(err)
                break
            elif isinstance(event, GraphRunStartedEvent):
                workflow_run = self._handle_workflow_start()

                self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
                self._message.workflow_run_id = workflow_run.id

                db.session.commit()
                db.session.refresh(self._message)
                db.session.close()

                yield self._workflow_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run
                )
            elif isinstance(event, NodeRunStartedEvent):
                workflow_node_execution = self._handle_node_start(event)

                yield self._workflow_node_start_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution
                )
            elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
                workflow_node_execution = self._handle_node_finished(event)

                yield self._workflow_node_finish_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution
                )

                if isinstance(event, QueueNodeFailedEvent):
                    yield from self._handle_iteration_exception(
                        task_id=self._application_generate_entity.task_id,
                        error=f'Child node failed: {event.error}'
                    )
            elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
                if isinstance(event, QueueIterationNextEvent):
                    # clear ran node execution infos of current iteration
                    iteration_relations = self._iteration_nested_relations.get(event.node_id)
                    if iteration_relations:
                        for node_id in iteration_relations:
                            self._task_state.ran_node_execution_infos.pop(node_id, None)

                yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
                self._handle_iteration_operation(event)
            elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
                workflow_run = self._handle_workflow_finished(
                    event, conversation_id=self._conversation.id, trace_manager=trace_manager
                )
                if workflow_run:
                    yield self._workflow_finish_to_stream_response(
                        task_id=self._application_generate_entity.task_id,
                        workflow_run=workflow_run
                    )

                    if workflow_run.status == WorkflowRunStatus.FAILED.value:
                        err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
                        yield self._error_to_stream_response(self._handle_error(err_event, self._message))
                        break

                if isinstance(event, QueueStopEvent):
                    # Save message
                    self._save_message()

                    yield self._message_end_to_stream_response()
                    break
                else:
                    self._queue_manager.publish(
                        QueueAdvancedChatMessageEndEvent(),
                        PublishFrom.TASK_PIPELINE
                    )
            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
                if output_moderation_answer:
                    self._task_state.answer = output_moderation_answer
                    yield self._message_replace_to_stream_response(answer=output_moderation_answer)

                # Save message
                self._save_message()

                yield self._message_end_to_stream_response()
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._handle_retriever_resources(event)
            elif isinstance(event, QueueAnnotationReplyEvent):
                self._handle_annotation_reply(event)
            elif isinstance(event, QueueTextChunkEvent):
                delta_text = event.text
                if delta_text is None:
                    continue

                # handle output moderation chunk
                should_direct_answer = self._handle_output_moderation_chunk(delta_text)
                if should_direct_answer:
                    continue

                self._task_state.answer += delta_text
                yield self._message_to_stream_response(delta_text, self._message.id)
            elif isinstance(event, QueueMessageReplaceEvent):
                yield self._message_replace_to_stream_response(answer=event.text)
            elif isinstance(event, QueuePingEvent):
                yield self._ping_stream_response()
            else:
                continue
        if publisher:
            publisher.publish(None)
        if self._conversation_name_generate_thread:
            self._conversation_name_generate_thread.join()

    def _save_message(self) -> None:
        """
        Save message.
        :return:
        """
        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.answer = self._task_state.answer
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
            if self._task_state.metadata else None

        if self._task_state.metadata and self._task_state.metadata.get('usage'):
            usage = LLMUsage(**self._task_state.metadata['usage'])

            self._message.message_tokens = usage.prompt_tokens
            self._message.message_unit_price = usage.prompt_unit_price
            self._message.message_price_unit = usage.prompt_price_unit
            self._message.answer_tokens = usage.completion_tokens
            self._message.answer_unit_price = usage.completion_unit_price
            self._message.answer_price_unit = usage.completion_price_unit
            self._message.total_price = usage.total_price
            self._message.currency = usage.currency

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
        """
        Message end to stream response.
        :return:
        """
        extras = {}
        if self._task_state.metadata:
            extras['metadata'] = self._task_state.metadata

        return MessageEndStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=self._message.id,
            **extras
        )

    def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
        """
        Get iteration nested relations.
        :param graph: graph
        :return:
        """
        nodes = graph.get('nodes')

        iteration_ids = [node.get('id') for node in nodes
                         if node.get('data', {}).get('type') in [
                             NodeType.ITERATION.value,
                             NodeType.LOOP.value,
                         ]]

        return {
            iteration_id: [
                node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
            ] for iteration_id in iteration_ids
        }

    def _handle_output_moderation_chunk(self, text: str) -> bool:
        """
        Handle output moderation chunk.
        :param text: text
        :return: True if output moderation should direct output, otherwise False
        """
        if self._output_moderation_handler:
            if self._output_moderation_handler.should_direct_output():
                # stop subscribe new token when output moderation should direct output
                self._task_state.answer = self._output_moderation_handler.get_final_output()
                self._queue_manager.publish(
                    QueueTextChunkEvent(
                        text=self._task_state.answer
                    ), PublishFrom.TASK_PIPELINE
                )

                self._queue_manager.publish(
                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
                    PublishFrom.TASK_PIPELINE
                )
                return True
            else:
                self._output_moderation_handler.append_new_token(text)

        return False
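_get_iteration_nested_relations above builds a map from each iteration (or loop) node to the nodes nested inside it, which _process_stream_response uses to reset per-iteration execution info on each QueueIterationNextEvent. This is a runnable sketch of the same construction, using plain strings 'iteration' and 'loop' as stand-ins for NodeType.ITERATION.value and NodeType.LOOP.value.

graph = {
    'nodes': [
        {'id': 'iter-1', 'data': {'type': 'iteration'}},
        {'id': 'llm-1', 'data': {'type': 'llm', 'iteration_id': 'iter-1'}},
        {'id': 'code-1', 'data': {'type': 'code', 'iteration_id': 'iter-1'}},
        {'id': 'end-1', 'data': {'type': 'end'}},
    ]
}
nodes = graph.get('nodes', [])
iteration_ids = [n['id'] for n in nodes
                 if n.get('data', {}).get('type') in ('iteration', 'loop')]
relations = {
    iteration_id: [n['id'] for n in nodes
                   if n.get('data', {}).get('iteration_id') == iteration_id]
    for iteration_id in iteration_ids
}
assert relations == {'iter-1': ['llm-1', 'code-1']}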
@@ -46,7 +46,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
         elif isinstance(event, GraphRunSucceededEvent):
             self.print_text("\n[on_workflow_run_succeeded]", color='green')
         elif isinstance(event, GraphRunFailedEvent):
-            self.print_text(f"\n[on_workflow_run_failed] reason: {event.reason}", color='red')
+            self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red')
         elif isinstance(event, NodeRunStartedEvent):
             self.on_workflow_node_execute_started(
                 graph=graph,
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from enum import Enum
 from typing import Any, Optional
 
@@ -5,7 +6,7 @@ from pydantic import BaseModel, field_validator
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.workflow.entities.base_node_data_entities import BaseNodeData
-from core.workflow.entities.node_entities import NodeType
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
 
 
 class QueueEvent(str, Enum):
@@ -31,6 +32,9 @@ class QueueEvent(str, Enum):
     ANNOTATION_REPLY = "annotation_reply"
     AGENT_THOUGHT = "agent_thought"
     MESSAGE_FILE = "message_file"
+    PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
+    PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
+    PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
     ERROR = "error"
     PING = "ping"
     STOP = "stop"
@@ -38,7 +42,7 @@ class QueueEvent(str, Enum):
 
 class AppQueueEvent(BaseModel):
     """
-    QueueEvent entity
+    QueueEvent abstract entity
     """
     event: QueueEvent
 
@@ -46,6 +50,7 @@ class AppQueueEvent(BaseModel):
 class QueueLLMChunkEvent(AppQueueEvent):
     """
     QueueLLMChunkEvent entity
+    Only for basic mode apps
     """
     event: QueueEvent = QueueEvent.LLM_CHUNK
     chunk: LLMResultChunk
@@ -58,11 +63,15 @@ class QueueIterationStartEvent(AppQueueEvent):
     node_id: str
     node_type: NodeType
     node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+
     node_run_index: int
-    inputs: dict = None
+    inputs: Optional[Mapping[str, Any]] = None
     predecessor_node_id: Optional[str] = None
-    metadata: Optional[dict] = None
+    metadata: Optional[Mapping[str, Any]] = None
 
 
 class QueueIterationNextEvent(AppQueueEvent):
     """
@@ -73,6 +82,10 @@ class QueueIterationNextEvent(AppQueueEvent):
     index: int
     node_id: str
     node_type: NodeType
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+
     node_run_index: int
     output: Optional[Any] = None  # output for the current iteration
@@ -93,13 +106,23 @@ class QueueIterationCompletedEvent(AppQueueEvent):
     """
     QueueIterationCompletedEvent entity
     """
-    event:QueueEvent = QueueEvent.ITERATION_COMPLETED
+    event: QueueEvent = QueueEvent.ITERATION_COMPLETED
 
     node_id: str
     node_type: NodeType
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+
     node_run_index: int
-    outputs: dict
+    inputs: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    metadata: Optional[Mapping[str, Any]] = None
+    steps: int = 0
 
+    error: Optional[str] = None
+
 
 class QueueTextChunkEvent(AppQueueEvent):
     """
@@ -190,6 +213,10 @@ class QueueNodeStartedEvent(AppQueueEvent):
     node_data: BaseNodeData
     node_run_index: int = 1
     predecessor_node_id: Optional[str] = None
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
 
 
 class QueueNodeSucceededEvent(AppQueueEvent):
@@ -201,11 +228,15 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     node_id: str
     node_type: NodeType
     node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
 
-    inputs: Optional[dict] = None
-    process_data: Optional[dict] = None
-    outputs: Optional[dict] = None
-    execution_metadata: Optional[dict] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
+    execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
 
     error: Optional[str] = None
@@ -219,10 +250,14 @@ class QueueNodeFailedEvent(AppQueueEvent):
     node_id: str
     node_type: NodeType
     node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
 
-    inputs: Optional[dict] = None
-    outputs: Optional[dict] = None
-    process_data: Optional[dict] = None
+    inputs: Optional[Mapping[str, Any]] = None
+    process_data: Optional[Mapping[str, Any]] = None
+    outputs: Optional[Mapping[str, Any]] = None
 
     error: str
@@ -277,7 +312,7 @@ class QueueStopEvent(AppQueueEvent):
 
 class QueueMessage(BaseModel):
     """
-    QueueMessage entity
+    QueueMessage abstract entity
     """
     task_id: str
     app_mode: str
@@ -297,3 +332,34 @@ class WorkflowQueueMessage(QueueMessage):
     WorkflowQueueMessage entity
     """
     pass
+
+
+class QueueParallelBranchRunStartedEvent(AppQueueEvent):
+    """
+    QueueParallelBranchRunStartedEvent entity
+    """
+    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
+
+    parallel_id: str
+    parallel_start_node_id: str
+
+
+class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
+    """
+    QueueParallelBranchRunSucceededEvent entity
+    """
+    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
+
+    parallel_id: str
+    parallel_start_node_id: str
+
+
+class QueueParallelBranchRunFailedEvent(AppQueueEvent):
+    """
+    QueueParallelBranchRunFailedEvent entity
+    """
+    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
+
+    parallel_id: str
+    parallel_start_node_id: str
+    error: str
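The new queue entities above all follow the same pattern: a string-valued enum discriminator with a per-class default, on top of a pydantic base model. A minimal runnable sketch of that pattern, assuming pydantic is installed; the classes mirror the ones added in this hunk but are trimmed down.

from enum import Enum
from pydantic import BaseModel


class QueueEvent(str, Enum):
    PARALLEL_BRANCH_RUN_STARTED = 'parallel_branch_run_started'


class AppQueueEvent(BaseModel):
    event: QueueEvent


class QueueParallelBranchRunStartedEvent(AppQueueEvent):
    # subclass pins the discriminator, so callers never set it by hand
    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
    parallel_id: str
    parallel_start_node_id: str


evt = QueueParallelBranchRunStartedEvent(parallel_id='p1', parallel_start_node_id='n1')
assert evt.event is QueueEvent.PARALLEL_BRANCH_RUN_STARTED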
@@ -84,9 +84,9 @@ class NodeRunResult(BaseModel):
     status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
 
     inputs: Optional[Mapping[str, Any]] = None  # node inputs
-    process_data: Optional[dict] = None  # process data
+    process_data: Optional[Mapping[str, Any]] = None  # process data
     outputs: Optional[Mapping[str, Any]] = None  # node outputs
-    metadata: Optional[dict[NodeRunMetadataKey, Any]] = None  # node metadata
+    metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None  # node metadata
 
     edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches
@@ -1,13 +1,17 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
 from pydantic import BaseModel, Field
 
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeType
 from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
 
 
 class GraphEngineEvent(BaseModel):
     pass
 
 
 ###########################################
 # Graph Events
 ###########################################
@@ -26,7 +30,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):
 
 class GraphRunFailedEvent(BaseGraphEvent):
-    reason: str = Field(..., description="failed reason")
+    error: str = Field(..., description="failed reason")
 
 
 ###########################################
@@ -35,16 +39,20 @@ class GraphRunFailedEvent(BaseGraphEvent):
 
 class BaseNodeEvent(GraphEngineEvent):
     node_id: str = Field(..., description="node id")
     node_type: NodeType = Field(..., description="node type")
     node_data: BaseNodeData = Field(..., description="node data")
     route_node_state: RouteNodeState = Field(..., description="route node state")
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    # iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
+    in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
 
 
 class NodeRunStartedEvent(BaseNodeEvent):
-    pass
+    predecessor_node_id: Optional[str] = None
+    """predecessor node id"""
 
 
 class NodeRunStreamChunkEvent(BaseNodeEvent):
@@ -63,7 +71,7 @@ class NodeRunSucceededEvent(BaseNodeEvent):
 
 class NodeRunFailedEvent(BaseNodeEvent):
-    pass
+    error: str = Field(..., description="error")
 
 
 ###########################################
@@ -74,6 +82,7 @@ class NodeRunFailedEvent(BaseNodeEvent):
 class BaseParallelBranchEvent(GraphEngineEvent):
     parallel_id: str = Field(..., description="parallel id")
     parallel_start_node_id: str = Field(..., description="parallel start node id")
+    in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
 
 
 class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
@@ -85,7 +94,7 @@ class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
 
 class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
-    reason: str = Field(..., description="failed reason")
+    error: str = Field(..., description="failed reason")
 
 
 ###########################################
@@ -94,11 +103,19 @@ class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
 
 class BaseIterationEvent(GraphEngineEvent):
     iteration_id: str = Field(..., description="iteration id")
     iteration_node_id: str = Field(..., description="iteration node id")
     iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
     iteration_node_data: BaseNodeData = Field(..., description="node data")
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
 
 
 class IterationRunStartedEvent(BaseIterationEvent):
-    pass
+    inputs: Optional[Mapping[str, Any]] = None
+    metadata: Optional[Mapping[str, Any]] = None
+    predecessor_node_id: Optional[str] = None
 
 
 class IterationRunNextEvent(BaseIterationEvent):
@ -107,11 +124,18 @@ class IterationRunNextEvent(BaseIterationEvent):
|
|||
|
||||
|
||||
class IterationRunSucceededEvent(BaseIterationEvent):
|
||||
pass
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class IterationRunFailedEvent(BaseIterationEvent):
|
||||
reason: str = Field(..., description="failed reason")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
|
||||
|
|
|
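
For orientation, a minimal consumer-side sketch (not part of the diff; the handler name is illustrative) of dispatching on the reworked hierarchy — failure events now expose `error` rather than `reason`, and `in_iteration_id` marks events raised inside an iteration:

from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    NodeRunFailedEvent,
    NodeRunStartedEvent,
    ParallelBranchRunFailedEvent,
)


def handle_event(event: GraphEngineEvent) -> None:
    # Failure events carry `error` instead of the old `reason` field.
    if isinstance(event, NodeRunFailedEvent):
        print(f"node {event.node_id} failed: {event.error}")
    elif isinstance(event, ParallelBranchRunFailedEvent):
        print(f"branch {event.parallel_id} failed: {event.error}")
    elif isinstance(event, NodeRunStartedEvent):
        # `in_iteration_id` is set when the node runs inside an iteration.
        scope = event.in_iteration_id or "top level"
        print(f"node {event.node_id} started ({scope})")
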
@ -14,6 +14,7 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, U

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
    BaseIterationEvent,
    GraphEngineEvent,
    GraphRunFailedEvent,
    GraphRunStartedEvent,

@ -32,6 +33,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam

from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes

@ -84,16 +86,16 @@ class GraphEngine:

        yield GraphRunStartedEvent()

        try:
            stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
            if self.init_params.workflow_type == WorkflowType.CHAT:
                stream_processor = AnswerStreamProcessor(
                    graph=self.graph,
                    variable_pool=self.graph_runtime_state.variable_pool
                )
                stream_processor_cls = AnswerStreamProcessor
            else:
                stream_processor = EndStreamProcessor(
                    graph=self.graph,
                    variable_pool=self.graph_runtime_state.variable_pool
                )
                stream_processor_cls = EndStreamProcessor

            stream_processor = stream_processor_cls(
                graph=self.graph,
                variable_pool=self.graph_runtime_state.variable_pool
            )

            # run graph
            generator = stream_processor.process(

@ -104,21 +106,21 @@ class GraphEngine:

            try:
                yield item
                if isinstance(item, NodeRunFailedEvent):
                    yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.')
                    yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
                    return
            except Exception as e:
                logger.exception(f"Graph run failed: {str(e)}")
                yield GraphRunFailedEvent(reason=str(e))
                yield GraphRunFailedEvent(error=str(e))
                return

        # trigger graph run success event
        yield GraphRunSucceededEvent()
    except GraphRunFailedError as e:
        yield GraphRunFailedEvent(reason=e.error)
        yield GraphRunFailedEvent(error=e.error)
        return
    except Exception as e:
        logger.exception("Unknown Error when graph running")
        yield GraphRunFailedEvent(reason=str(e))
        yield GraphRunFailedEvent(error=str(e))
        raise e

    def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
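
As a usage sketch, this is how the rename reaches callers of `GraphEngine.run()`; the `drain` helper and the surrounding setup are assumed, not part of the codebase:

from core.workflow.graph_engine.entities.event import (
    GraphRunFailedEvent,
    GraphRunSucceededEvent,
)


def drain(engine):
    # Collect events until the run terminates; `engine` is a constructed GraphEngine.
    events = []
    for event in engine.run():
        events.append(event)
        if isinstance(event, GraphRunFailedEvent):
            raise RuntimeError(event.error)  # `reason` is now `error`
        if isinstance(event, GraphRunSucceededEvent):
            break
    return events
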
@ -145,11 +147,34 @@ class GraphEngine:

                node_id=next_node_id
            )

            # get node config
            node_id = route_node_state.node_id
            node_config = self.graph.node_id_config_mapping.get(node_id)
            if not node_config:
                raise GraphRunFailedError(f'Node {node_id} config not found.')

            # convert to specific node
            node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
            node_cls = node_classes.get(node_type)
            if not node_cls:
                raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')

            previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None

            # init workflow run state
            node_instance = node_cls(  # type: ignore
                config=node_config,
                graph_init_params=self.init_params,
                graph=self.graph,
                graph_runtime_state=self.graph_runtime_state,
                previous_node_id=previous_node_id
            )

            try:
                # run node
                yield from self._run_node(
                    node_instance=node_instance,
                    route_node_state=route_node_state,
                    previous_node_id=previous_route_node_state.node_id if previous_route_node_state else None,
                    parallel_id=in_parallel_id,
                    parallel_start_node_id=parallel_start_node_id
                )

@ -166,6 +191,10 @@ class GraphEngine:

                route_node_state.status = RouteNodeState.Status.FAILED
                route_node_state.failed_reason = str(e)
                yield NodeRunFailedEvent(
                    error=str(e),
                    node_id=next_node_id,
                    node_type=node_type,
                    node_data=node_instance.node_data,
                    route_node_state=route_node_state,
                    parallel_id=in_parallel_id,
                    parallel_start_node_id=parallel_start_node_id

@ -241,7 +270,7 @@ class GraphEngine:

            # new thread
            for edge in edge_mappings:
                threading.Thread(target=self._run_parallel_node, kwargs={
                    'flask_app': current_app._get_current_object(),
                    'flask_app': current_app._get_current_object(),  # type: ignore[attr-defined]
                    'parallel_id': parallel_id,
                    'parallel_start_node_id': edge.target_node_id,
                    'q': q
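
The parallel-branch fan-out above follows a standard thread-plus-queue pattern. A self-contained sketch of that pattern (illustrative only; the real worker also enters the Flask app context passed via `flask_app` and pushes engine events rather than strings):

import queue
import threading


def fan_out(branch_ids: list[str]) -> list[str]:
    # Each branch runs in its own thread and reports back through
    # a shared queue; the parent drains until all branches finish.
    q: queue.Queue = queue.Queue()

    def worker(branch_id: str) -> None:
        q.put(f"branch {branch_id} finished")
        q.put(None)  # per-branch sentinel

    threads = [threading.Thread(target=worker, args=(b,)) for b in branch_ids]
    for t in threads:
        t.start()

    results, finished = [], 0
    while finished < len(threads):
        item = q.get()
        if item is None:
            finished += 1
        else:
            results.append(item)
    return results
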
@ -309,21 +338,21 @@ class GraphEngine:

                q.put(ParallelBranchRunFailedEvent(
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    reason=e.error
                    error=e.error
                ))
            except Exception as e:
                logger.exception("Unknown Error when generating in parallel")
                q.put(ParallelBranchRunFailedEvent(
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    reason=str(e)
                    error=str(e)
                ))
            finally:
                db.session.remove()

    def _run_node(self,
                  node_instance: BaseNode,
                  route_node_state: RouteNodeState,
                  previous_node_id: Optional[str] = None,
                  parallel_id: Optional[str] = None,
                  parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
        """

@ -331,46 +360,15 @@ class GraphEngine:

        """
        # trigger node run start event
        yield NodeRunStartedEvent(
            node_id=node_instance.node_id,
            node_type=node_instance.node_type,
            node_data=node_instance.node_data,
            route_node_state=route_node_state,
            predecessor_node_id=node_instance.previous_node_id,
            parallel_id=parallel_id,
            parallel_start_node_id=parallel_start_node_id
        )

        # get node config
        node_id = route_node_state.node_id
        node_config = self.graph.node_id_config_mapping.get(node_id)
        if not node_config:
            route_node_state.status = RouteNodeState.Status.FAILED
            route_node_state.failed_reason = f'Node {node_id} config not found.'
            yield NodeRunFailedEvent(
                route_node_state=route_node_state,
                parallel_id=parallel_id,
                parallel_start_node_id=parallel_start_node_id
            )
            return

        # convert to specific node
        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
        node_cls = node_classes.get(node_type)
        if not node_cls:
            route_node_state.status = RouteNodeState.Status.FAILED
            route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
            yield NodeRunFailedEvent(
                route_node_state=route_node_state,
                parallel_id=parallel_id,
                parallel_start_node_id=parallel_start_node_id
            )
            return

        # init workflow run state
        node_instance = node_cls(  # type: ignore
            config=node_config,
            graph_init_params=self.init_params,
            graph=self.graph,
            graph_runtime_state=self.graph_runtime_state,
            previous_node_id=previous_node_id
        )

        db.session.close()

        self.graph_runtime_state.node_run_steps += 1

@ -385,6 +383,10 @@ class GraphEngine:

            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                yield NodeRunFailedEvent(
                    error=route_node_state.failed_reason,
                    node_id=node_instance.node_id,
                    node_type=node_instance.node_type,
                    node_data=node_instance.node_data,
                    route_node_state=route_node_state,
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id

@ -401,12 +403,15 @@ class GraphEngine:

                for variable_key, variable_value in run_result.outputs.items():
                    # append variables to variable pool recursively
                    self._append_variables_recursively(
                        node_id=node_id,
                        node_id=node_instance.node_id,
                        variable_key_list=[variable_key],
                        variable_value=variable_value
                    )

                yield NodeRunSucceededEvent(
                    node_id=node_instance.node_id,
                    node_type=node_instance.node_type,
                    node_data=node_instance.node_data,
                    route_node_state=route_node_state,
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id

@ -415,6 +420,9 @@ class GraphEngine:

                break
            elif isinstance(item, RunStreamChunkEvent):
                yield NodeRunStreamChunkEvent(
                    node_id=node_instance.node_id,
                    node_type=node_instance.node_type,
                    node_data=node_instance.node_data,
                    chunk_content=item.chunk_content,
                    from_variable_selector=item.from_variable_selector,
                    route_node_state=route_node_state,

@ -423,17 +431,28 @@ class GraphEngine:

                )
            elif isinstance(item, RunRetrieverResourceEvent):
                yield NodeRunRetrieverResourceEvent(
                    node_id=node_instance.node_id,
                    node_type=node_instance.node_type,
                    node_data=node_instance.node_data,
                    retriever_resources=item.retriever_resources,
                    context=item.context,
                    route_node_state=route_node_state,
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                )
            elif isinstance(item, BaseIterationEvent):
                # add parallel info to iteration event
                item.parallel_id = parallel_id
                item.parallel_start_node_id = parallel_start_node_id
        except GenerateTaskStoppedException:
            # trigger node run failed event
            route_node_state.status = RouteNodeState.Status.FAILED
            route_node_state.failed_reason = "Workflow stopped."
            yield NodeRunFailedEvent(
                error="Workflow stopped.",
                node_id=node_instance.node_id,
                node_type=node_instance.node_type,
                node_data=node_instance.node_data,
                route_node_state=route_node_state,
                parallel_id=parallel_id,
                parallel_start_node_id=parallel_start_node_id,
@ -11,21 +11,20 @@ from core.workflow.graph_engine.entities.event import (

    NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk

logger = logging.getLogger(__name__)


class AnswerStreamProcessor:
class AnswerStreamProcessor(StreamProcessor):

    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        self.graph = graph
        self.variable_pool = variable_pool
        super().__init__(graph, variable_pool)
        self.generate_routes = graph.answer_stream_generate_routes
        self.route_position = {}
        for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
            self.route_position[answer_node_id] = 0
        self.rest_node_ids = graph.node_ids.copy()
        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}

    def process(self,

@ -74,58 +73,6 @@ class AnswerStreamProcessor:

        self.rest_node_ids = self.graph.node_ids.copy()
        self.current_stream_chunk_generating_node_ids = {}

    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
        finished_node_id = event.route_node_state.node_id

        if finished_node_id not in self.rest_node_ids:
            return

        # remove finished node id
        self.rest_node_ids.remove(finished_node_id)

        run_result = event.route_node_state.node_run_result
        if not run_result:
            return

        if run_result.edge_source_handle:
            reachable_node_ids = []
            unreachable_first_node_ids = []
            for edge in self.graph.edge_mapping[finished_node_id]:
                if (edge.run_condition
                        and edge.run_condition.branch_identify
                        and run_result.edge_source_handle == edge.run_condition.branch_identify):
                    reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
                    continue
                else:
                    unreachable_first_node_ids.append(edge.target_node_id)

            for node_id in unreachable_first_node_ids:
                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
        node_ids = []
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id == self.graph.root_node_id:
                continue

            node_ids.append(edge.target_node_id)
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
        return node_ids

    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
        """
        remove target node ids until merge
        """
        if node_id not in self.rest_node_ids:
            return

        self.rest_node_ids.remove(node_id)
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id in reachable_node_ids:
                continue

            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)

    def _generate_stream_outputs_when_node_finished(self,
                                                    event: NodeRunSucceededEvent
                                                    ) -> Generator[GraphEngineEvent, None, None]:

@ -138,8 +85,8 @@ class AnswerStreamProcessor:

            # all depends on answer node id not in rest node ids
            if (event.route_node_state.node_id != answer_node_id
                    and (answer_node_id not in self.rest_node_ids
                         or not all(dep_id not in self.rest_node_ids
                                    for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
                    or not all(dep_id not in self.rest_node_ids
                               for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
                continue

            route_position = self.route_position[answer_node_id]

@ -149,6 +96,9 @@ class AnswerStreamProcessor:

            if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
                yield NodeRunStreamChunkEvent(
                    node_id=event.node_id,
                    node_type=event.node_type,
                    node_data=event.node_data,
                    chunk_content=route_chunk.text,
                    route_node_state=event.route_node_state,
                    parallel_id=event.parallel_id,

@ -171,6 +121,9 @@ class AnswerStreamProcessor:

            if text:
                yield NodeRunStreamChunkEvent(
                    node_id=event.node_id,
                    node_type=event.node_type,
                    node_data=event.node_data,
                    chunk_content=text,
                    from_variable_selector=value_selector,
                    route_node_state=event.route_node_state,
@ -0,0 +1,71 @@

from abc import ABC, abstractmethod
from collections.abc import Generator

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph


class StreamProcessor(ABC):

    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        self.graph = graph
        self.variable_pool = variable_pool
        self.rest_node_ids = graph.node_ids.copy()

    @abstractmethod
    def process(self,
                generator: Generator[GraphEngineEvent, None, None]
                ) -> Generator[GraphEngineEvent, None, None]:
        raise NotImplementedError

    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
        finished_node_id = event.route_node_state.node_id
        if finished_node_id not in self.rest_node_ids:
            return

        # remove finished node id
        self.rest_node_ids.remove(finished_node_id)

        run_result = event.route_node_state.node_run_result
        if not run_result:
            return

        if run_result.edge_source_handle:
            reachable_node_ids = []
            unreachable_first_node_ids = []
            for edge in self.graph.edge_mapping[finished_node_id]:
                if (edge.run_condition
                        and edge.run_condition.branch_identify
                        and run_result.edge_source_handle == edge.run_condition.branch_identify):
                    reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
                    continue
                else:
                    unreachable_first_node_ids.append(edge.target_node_id)

            for node_id in unreachable_first_node_ids:
                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
        node_ids = []
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id == self.graph.root_node_id:
                continue

            node_ids.append(edge.target_node_id)
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
        return node_ids

    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
        """
        remove target node ids until merge
        """
        if node_id not in self.rest_node_ids:
            return

        self.rest_node_ids.remove(node_id)
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id in reachable_node_ids:
                continue

            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
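
A minimal sketch of subclassing the new StreamProcessor base (hypothetical subclass, not part of the diff; `process` here is the simplest possible pass-through that still prunes unreachable branches):

from collections.abc import Generator

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor


class PassthroughStreamProcessor(StreamProcessor):
    """Forwards every event, pruning unreachable branches as nodes finish."""

    def process(self,
                generator: Generator[GraphEngineEvent, None, None]
                ) -> Generator[GraphEngineEvent, None, None]:
        for event in generator:
            if isinstance(event, NodeRunSucceededEvent):
                # helper inherited from the StreamProcessor base class
                self._remove_unreachable_nodes(event)
            yield event
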
@ -4,6 +4,7 @@ from typing import Any, Optional

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

@ -42,14 +43,14 @@ class BaseNode(ABC):

    @abstractmethod
    def _run(self) \
            -> NodeRunResult | Generator[RunEvent, None, None]:
            -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
        """
        Run node
        :return:
        """
        raise NotImplementedError

    def run(self) -> Generator[RunEvent, None, None]:
    def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
        """
        Run node entry
        :return:
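
To illustrate the widened `_run` contract, a hypothetical generator-based node; required class attributes such as `_node_data_cls`/`_node_type` and constructor wiring are omitted for brevity, and the chunk text is invented:

from collections.abc import Generator

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunStreamChunkEvent
from models.workflow import WorkflowNodeExecutionStatus


class EchoNode(BaseNode):
    def _run(self) -> Generator[RunEvent, None, None]:
        # stream a chunk first, then report completion with a result
        yield RunStreamChunkEvent(chunk_content="hello", from_variable_selector=[self.node_id, "text"])
        yield RunCompletedEvent(
            run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={"text": "hello"}
            )
        )
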
@ -9,18 +9,17 @@ from core.workflow.graph_engine.entities.event import (

    NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor

logger = logging.getLogger(__name__)


class EndStreamProcessor:
class EndStreamProcessor(StreamProcessor):

    def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
        self.graph = graph
        self.variable_pool = variable_pool
        super().__init__(graph, variable_pool)
        self.stream_param = graph.end_stream_param
        self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
        self.rest_node_ids = graph.node_ids.copy()
        self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}

    def process(self,

@ -56,64 +55,10 @@ class EndStreamProcessor:

            yield event

    def reset(self) -> None:
        self.end_streamed_variable_selectors = {}
        self.end_streamed_variable_selectors: dict[str, list[str]] = {
            end_node_id: [] for end_node_id in self.graph.end_stream_param.end_stream_variable_selector_mapping
        }
        self.end_streamed_variable_selectors = self.graph.end_stream_param.end_stream_variable_selector_mapping.copy()
        self.rest_node_ids = self.graph.node_ids.copy()
        self.current_stream_chunk_generating_node_ids = {}

    def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
        finished_node_id = event.route_node_state.node_id
        if finished_node_id not in self.rest_node_ids:
            return

        # remove finished node id
        self.rest_node_ids.remove(finished_node_id)

        run_result = event.route_node_state.node_run_result
        if not run_result:
            return

        if run_result.edge_source_handle:
            reachable_node_ids = []
            unreachable_first_node_ids = []
            for edge in self.graph.edge_mapping[finished_node_id]:
                if (edge.run_condition
                        and edge.run_condition.branch_identify
                        and run_result.edge_source_handle == edge.run_condition.branch_identify):
                    reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
                    continue
                else:
                    unreachable_first_node_ids.append(edge.target_node_id)

            for node_id in unreachable_first_node_ids:
                self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
        node_ids = []
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id == self.graph.root_node_id:
                continue

            node_ids.append(edge.target_node_id)
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
        return node_ids

    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
        """
        remove target node ids until merge
        """
        if node_id not in self.rest_node_ids:
            return

        self.rest_node_ids.remove(node_id)
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id in reachable_node_ids:
                continue

            self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)

    def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
        """
        Is stream out support
@ -1,13 +1,16 @@

import logging
from collections.abc import Generator
from typing import Any, cast

from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import (
    BaseGraphEvent,
    BaseNodeEvent,
    BaseParallelBranchEvent,
    GraphRunFailedEvent,
    InNodeEvent,
    IterationRunFailedEvent,
    IterationRunNextEvent,
    IterationRunStartedEvent,

@ -17,7 +20,7 @@ from core.workflow.graph_engine.entities.event import (

from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus

@ -32,7 +35,7 @@ class IterationNode(BaseNode):

    _node_data_cls = IterationNodeData
    _node_type = NodeType.ITERATION

    def _run(self) -> BaseIterationState:
    def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
        """
        Run the node.
        """

@ -42,6 +45,10 @@ class IterationNode(BaseNode):

        if not isinstance(iterator_list_value, list):
            raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")

        inputs = {
            "iterator_selector": iterator_list_value
        }

        root_node_id = self.node_data.start_node_id
        graph_config = self.graph_config

@ -117,21 +124,42 @@ class IterationNode(BaseNode):

        )

        yield IterationRunStartedEvent(
            iteration_id=self.node_id,
            iteration_node_id=self.node_id,
            iteration_node_type=self.node_type,
            iteration_node_data=self.node_data,
            inputs=inputs,
            metadata={
                "iterator_length": 1
            },
            predecessor_node_id=self.previous_node_id
        )

        yield IterationRunNextEvent(
            iteration_id=self.node_id,
            iteration_node_id=self.node_id,
            iteration_node_type=self.node_type,
            iteration_node_data=self.node_data,
            index=0,
            output=None
            pre_iteration_output=None
        )

        outputs: list[Any] = []
        try:
            # run workflow
            rst = graph_engine.run()
            outputs: list[Any] = []
            for event in rst:
                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
                    event.in_iteration_id = self.node_id

                if isinstance(event, NodeRunSucceededEvent):
                    metadata = event.route_node_state.node_run_result.metadata
                    if not metadata:
                        metadata = {}

                    if NodeRunMetadataKey.ITERATION_ID not in metadata:
                        metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
                        metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
                        event.route_node_state.node_run_result.metadata = metadata

                    yield event

                # handle iteration run result

@ -158,22 +186,35 @@ class IterationNode(BaseNode):

                    )

                    yield IterationRunNextEvent(
                        iteration_id=self.node_id,
                        iteration_node_id=self.node_id,
                        iteration_node_type=self.node_type,
                        iteration_node_data=self.node_data,
                        index=next_index,
                        pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None
                        pre_iteration_output=jsonable_encoder(
                            current_iteration_output) if current_iteration_output else None
                    )
                elif isinstance(event, BaseGraphEvent):
                    if isinstance(event, GraphRunFailedEvent):
                        # iteration run failed
                        yield IterationRunFailedEvent(
                            iteration_id=self.node_id,
                            reason=event.reason,
                            iteration_node_id=self.node_id,
                            iteration_node_type=self.node_type,
                            iteration_node_data=self.node_data,
                            inputs=inputs,
                            outputs={
                                "output": jsonable_encoder(outputs)
                            },
                            steps=len(iterator_list_value),
                            metadata={
                                "total_tokens": graph_engine.graph_runtime_state.total_tokens
                            },
                            error=event.error,
                        )

                        yield RunCompletedEvent(
                            run_result=NodeRunResult(
                                status=WorkflowNodeExecutionStatus.FAILED,
                                error=event.reason,
                                error=event.error,
                            )
                        )
                        break

@ -181,7 +222,17 @@ class IterationNode(BaseNode):

                    yield event

            yield IterationRunSucceededEvent(
                iteration_id=self.node_id,
                iteration_node_id=self.node_id,
                iteration_node_type=self.node_type,
                iteration_node_data=self.node_data,
                inputs=inputs,
                outputs={
                    "output": jsonable_encoder(outputs)
                },
                steps=len(iterator_list_value),
                metadata={
                    "total_tokens": graph_engine.graph_runtime_state.total_tokens
                }
            )

            yield RunCompletedEvent(

@ -196,8 +247,18 @@ class IterationNode(BaseNode):

            # iteration run failed
            logger.exception("Iteration run failed")
            yield IterationRunFailedEvent(
                iteration_id=self.node_id,
                reason=str(e),
                iteration_node_id=self.node_id,
                iteration_node_type=self.node_type,
                iteration_node_data=self.node_data,
                inputs=inputs,
                outputs={
                    "output": jsonable_encoder(outputs)
                },
                steps=len(iterator_list_value),
                metadata={
                    "total_tokens": graph_engine.graph_runtime_state.total_tokens
                },
                error=str(e),
            )

            yield RunCompletedEvent(
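
For reference, the event order a caller can expect from one iteration run, reconstructed from the yields above (sketch; payloads abbreviated):

# Expected stream for an iteration over items a, b (simplified):
#
#   IterationRunStartedEvent(inputs={"iterator_selector": [a, b]})
#   IterationRunNextEvent(index=0, pre_iteration_output=None)
#   ... inner node/branch events, tagged with in_iteration_id ...
#   IterationRunNextEvent(index=1, pre_iteration_output=<output of a>)
#   ... inner node/branch events ...
#   IterationRunSucceededEvent(outputs={"output": [...]}, steps=2)
#   RunCompletedEvent(run_result=<aggregated NodeRunResult>)
#
# If the inner graph emits GraphRunFailedEvent, the node instead yields
# IterationRunFailedEvent(error=...) followed by a FAILED RunCompletedEvent.
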
@ -27,6 +27,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (

@ -42,11 +43,19 @@ from models.provider import Provider, ProviderType

from models.workflow import WorkflowNodeExecutionStatus


class ModelInvokeCompleted(BaseModel):
    """
    Model invoke completed
    """
    text: str
    usage: LLMUsage


class LLMNode(BaseNode):
    _node_data_cls = LLMNodeData
    _node_type = NodeType.LLM

    def _run(self) -> Generator[RunEvent, None, None]:
    def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
        """
        Run node
        :return:

@ -167,7 +176,7 @@ class LLMNode(BaseNode):

                    model_instance: ModelInstance,
                    prompt_messages: list[PromptMessage],
                    stop: Optional[list[str]] = None) \
            -> Generator[RunEvent, None, None]:
            -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
        """
        Invoke large language model
        :param node_data_model: node data model

@ -201,7 +210,7 @@ class LLMNode(BaseNode):

            self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

    def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
            -> Generator[RunEvent, None, None]:
            -> Generator[RunEvent | ModelInvokeCompleted, None, None]:
        """
        Handle invoke result
        :param invoke_result: invoke result

@ -762,11 +771,3 @@ class LLMNode(BaseNode):

            }
        }
    }


class ModelInvokeCompleted(BaseModel):
    """
    Model invoke completed
    """
    text: str
    usage: LLMUsage
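
To show why the return types gain `ModelInvokeCompleted`, a sketch of a caller separating streamed events from the terminal completion marker (helper name invented; the import of `ModelInvokeCompleted` from the LLM node module is elided, as its exact path is not shown in this diff):

def collect_completion(generator):
    """Drain the output of _handle_invoke_result(...) (sketch).

    Returns (streamed_events, completed), where `completed` is the
    terminal ModelInvokeCompleted or None if the stream ended early.
    """
    events, completed = [], None
    for item in generator:
        if isinstance(item, ModelInvokeCompleted):
            completed = item  # carries the full text plus LLMUsage
        else:
            events.append(item)  # a RunEvent, e.g. RunStreamChunkEvent
    return events, completed
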
@ -26,24 +26,21 @@ logger = logging.getLogger(__name__)


class WorkflowEntry:
    def run(
    def __init__(
        self,
        *,
        workflow: Workflow,
        user_id: str,
        user_from: UserFrom,
        invoke_from: InvokeFrom,
        user_inputs: Mapping[str, Any],
        system_inputs: Mapping[SystemVariable, Any],
        callbacks: Sequence[WorkflowCallback],
        call_depth: int = 0
    ) -> Generator[GraphEngineEvent, None, None]:
    ) -> None:
        """
        :param workflow: Workflow instance
        :param user_id: user id
        :param user_from: user from
        :param invoke_from: invoke from service-api, web-app, debugger, explore
        :param callbacks: workflow callbacks
        :param user_inputs: user variables inputs
        :param system_inputs: system inputs, like: query, files
        :param call_depth: call depth

@ -82,7 +79,7 @@ class WorkflowEntry:

        )

        # init workflow run state
        graph_engine = GraphEngine(
        self.graph_engine = GraphEngine(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_type=WorkflowType.value_of(workflow.type),

@ -98,6 +95,16 @@ class WorkflowEntry:

            max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
        )

    def run(
        self,
        *,
        callbacks: Sequence[WorkflowCallback],
    ) -> Generator[GraphEngineEvent, None, None]:
        """
        :param callbacks: workflow callbacks
        """
        graph_engine = self.graph_engine

        try:
            # run workflow
            generator = graph_engine.run()

@ -105,7 +112,7 @@ class WorkflowEntry:

                if callbacks:
                    for callback in callbacks:
                        callback.on_event(
                            graph=graph,
                            graph=self.graph_engine.graph,
                            graph_init_params=graph_engine.init_params,
                            graph_runtime_state=graph_engine.graph_runtime_state,
                            event=event

@ -117,11 +124,11 @@ class WorkflowEntry:

            if callbacks:
                for callback in callbacks:
                    callback.on_event(
                        graph=graph,
                        graph=self.graph_engine.graph,
                        graph_init_params=graph_engine.init_params,
                        graph_runtime_state=graph_engine.graph_runtime_state,
                        event=GraphRunFailedEvent(
                            reason=str(e)
                            error=str(e)
                        )
                    )
            return

@ -161,6 +168,9 @@ class WorkflowEntry:

        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
        node_cls = node_classes.get(node_type)

        if not node_cls:
            raise ValueError(f'Node class not found for node type {node_type}')

        # init workflow run state
        node_instance = node_cls(
            tenant_id=workflow.tenant_id,

@ -195,7 +205,7 @@ class WorkflowEntry:

            node_instance=node_instance
        )

        # run node
        # run node TODO
        node_run_result = node_instance.run(
            variable_pool=variable_pool
        )
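
With the `__init__`/`run` split above, construction and execution become two steps. A usage sketch (argument values and the specific enum members are assumed for illustration; imports elided):

workflow_entry = WorkflowEntry(
    workflow=workflow,                      # a models.workflow.Workflow row
    user_id="user-1",
    user_from=UserFrom.ACCOUNT,             # assumed enum member
    invoke_from=InvokeFrom.DEBUGGER,        # assumed enum member
    user_inputs={"query": "hello"},
    system_inputs={SystemVariable.QUERY: "hello"},
    callbacks=[],
)

# The engine is built in __init__; run() only drives it.
for event in workflow_entry.run(callbacks=[]):
    ...  # handle the GraphEngineEvent stream
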
@ -1,7 +1,7 @@

from collections.abc import Generator
from datetime import datetime, timezone

from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,

@ -12,6 +12,7 @@ from core.workflow.graph_engine.entities.event import (

from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.start.entities import StartNodeData


def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:

@ -37,7 +38,14 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve

    parallel = graph.parallel_mapping.get(parallel_id)
    parallel_start_node_id = parallel.start_from_node_id if parallel else None

    node_config = graph.node_id_config_mapping[next_node_id]
    node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
    mock_node_data = StartNodeData(**{"title": "demo", "variables": []})

    yield NodeRunStartedEvent(
        node_id=next_node_id,
        node_type=node_type,
        node_data=mock_node_data,
        route_node_state=route_node_state,
        parallel_id=graph.node_parallel_mapping.get(next_node_id),
        parallel_start_node_id=parallel_start_node_id

@ -47,6 +55,9 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve

    length = int(next_node_id[-1])
    for i in range(0, length):
        yield NodeRunStreamChunkEvent(
            node_id=next_node_id,
            node_type=node_type,
            node_data=mock_node_data,
            chunk_content=str(i),
            route_node_state=route_node_state,
            from_variable_selector=[next_node_id, "text"],

@ -57,6 +68,9 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve

    route_node_state.status = RouteNodeState.Status.SUCCESS
    route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
    yield NodeRunSucceededEvent(
        node_id=next_node_id,
        node_type=node_type,
        node_data=mock_node_data,
        route_node_state=route_node_state,
        parallel_id=parallel_id,
        parallel_start_node_id=parallel_start_node_id