mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 04:26:30 +08:00
add chatflow app event convert
This commit is contained in:
parent
0818b7b078
commit
917aacbf7f
@ -1,6 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
@ -342,7 +342,7 @@ class AppRunner:
|
|||||||
self, app_id: str,
|
self, app_id: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
app_generate_entity: AppGenerateEntity,
|
app_generate_entity: AppGenerateEntity,
|
||||||
inputs: dict,
|
inputs: Mapping[str, Any],
|
||||||
query: str,
|
query: str,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
) -> tuple[bool, dict, str]:
|
) -> tuple[bool, dict, str]:
|
||||||
|
|||||||
0
api/core/app/apps/chatflow/__init__.py
Normal file
0
api/core/app/apps/chatflow/__init__.py
Normal file
101
api/core/app/apps/chatflow/app_config_manager.py
Normal file
101
api/core/app/apps/chatflow/app_config_manager.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
|
||||||
|
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||||
|
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||||
|
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||||
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
|
from core.app.app_config.features.opening_statement.manager import OpeningStatementConfigManager
|
||||||
|
from core.app.app_config.features.retrieval_resource.manager import RetrievalResourceConfigManager
|
||||||
|
from core.app.app_config.features.speech_to_text.manager import SpeechToTextConfigManager
|
||||||
|
from core.app.app_config.features.suggested_questions_after_answer.manager import (
|
||||||
|
SuggestedQuestionsAfterAnswerConfigManager,
|
||||||
|
)
|
||||||
|
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||||
|
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
|
||||||
|
from models.model import App, AppMode
|
||||||
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
|
||||||
|
"""
|
||||||
|
Advanced Chatbot App Config Entity.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||||
|
@classmethod
|
||||||
|
def get_app_config(cls, app_model: App,
|
||||||
|
workflow: Workflow) -> AdvancedChatAppConfig:
|
||||||
|
features_dict = workflow.features_dict
|
||||||
|
|
||||||
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
app_config = AdvancedChatAppConfig(
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
|
app_mode=app_mode,
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
|
||||||
|
config=features_dict
|
||||||
|
),
|
||||||
|
variables=WorkflowVariablesConfigManager.convert(
|
||||||
|
workflow=workflow
|
||||||
|
),
|
||||||
|
additional_features=cls.convert_features(features_dict, app_mode)
|
||||||
|
)
|
||||||
|
|
||||||
|
return app_config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||||
|
"""
|
||||||
|
Validate for advanced chat app model config
|
||||||
|
|
||||||
|
:param tenant_id: tenant id
|
||||||
|
:param config: app model config args
|
||||||
|
:param only_structure_validate: if True, only structure validation will be performed
|
||||||
|
"""
|
||||||
|
related_config_keys = []
|
||||||
|
|
||||||
|
# file upload validation
|
||||||
|
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
|
||||||
|
config=config,
|
||||||
|
is_vision=False
|
||||||
|
)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# opening_statement
|
||||||
|
config, current_related_config_keys = OpeningStatementConfigManager.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# suggested_questions_after_answer
|
||||||
|
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
|
||||||
|
config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# speech_to_text
|
||||||
|
config, current_related_config_keys = SpeechToTextConfigManager.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# text_to_speech
|
||||||
|
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# return retriever resource
|
||||||
|
config, current_related_config_keys = RetrievalResourceConfigManager.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# moderation validation
|
||||||
|
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
config=config,
|
||||||
|
only_structure_validate=only_structure_validate
|
||||||
|
)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
related_config_keys = list(set(related_config_keys))
|
||||||
|
|
||||||
|
# Filter out extra parameters
|
||||||
|
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||||
|
|
||||||
|
return filtered_config
|
||||||
|
|
||||||
189
api/core/app/apps/chatflow/app_generator.py
Normal file
189
api/core/app/apps/chatflow/app_generator.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
import contexts
|
||||||
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||||
|
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||||
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||||
|
from core.app.apps.chatflow.app_runner import AdvancedChatAppRunner
|
||||||
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||||
|
from core.file.message_file_parser import MessageFileParser
|
||||||
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.account import Account
|
||||||
|
from models.model import App, Conversation, EndUser
|
||||||
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
|
def generate(
|
||||||
|
self, app_model: App,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: dict,
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
stream: bool = True,
|
||||||
|
) -> Union[dict, Generator[dict, None, None]]:
|
||||||
|
"""
|
||||||
|
Generate App response.
|
||||||
|
|
||||||
|
:param app_model: App
|
||||||
|
:param workflow: Workflow
|
||||||
|
:param user: account or end user
|
||||||
|
:param args: request args
|
||||||
|
:param invoke_from: invoke from source
|
||||||
|
:param stream: is stream
|
||||||
|
"""
|
||||||
|
if not args.get('query'):
|
||||||
|
raise ValueError('query is required')
|
||||||
|
|
||||||
|
query = args['query']
|
||||||
|
if not isinstance(query, str):
|
||||||
|
raise ValueError('query must be a string')
|
||||||
|
|
||||||
|
query = query.replace('\x00', '')
|
||||||
|
inputs = args['inputs']
|
||||||
|
|
||||||
|
extras = {
|
||||||
|
"auto_generate_conversation_name": args.get('auto_generate_name', False)
|
||||||
|
}
|
||||||
|
|
||||||
|
# get conversation
|
||||||
|
conversation = None
|
||||||
|
if args.get('conversation_id'):
|
||||||
|
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
||||||
|
|
||||||
|
# parse files
|
||||||
|
files = args['files'] if args.get('files') else []
|
||||||
|
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||||
|
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||||
|
if file_extra_config:
|
||||||
|
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||||
|
files,
|
||||||
|
file_extra_config,
|
||||||
|
user
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
file_objs = []
|
||||||
|
|
||||||
|
# convert to app config
|
||||||
|
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||||
|
app_model=app_model,
|
||||||
|
workflow=workflow
|
||||||
|
)
|
||||||
|
|
||||||
|
# get tracing instance
|
||||||
|
trace_manager = TraceQueueManager(app_id=app_model.id)
|
||||||
|
|
||||||
|
if invoke_from == InvokeFrom.DEBUGGER:
|
||||||
|
# always enable retriever resource in debugger mode
|
||||||
|
app_config.additional_features.show_retrieve_source = True
|
||||||
|
|
||||||
|
# init application generate entity
|
||||||
|
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||||
|
task_id=str(uuid.uuid4()),
|
||||||
|
app_config=app_config,
|
||||||
|
conversation_id=conversation.id if conversation else None,
|
||||||
|
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||||
|
query=query,
|
||||||
|
files=file_objs,
|
||||||
|
user_id=user.id,
|
||||||
|
stream=stream,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
extras=extras,
|
||||||
|
trace_manager=trace_manager
|
||||||
|
)
|
||||||
|
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||||
|
|
||||||
|
return self._generate(
|
||||||
|
app_model=app_model,
|
||||||
|
workflow=workflow,
|
||||||
|
user=user,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
conversation=conversation,
|
||||||
|
stream=stream
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate(self, app_model: App,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
|
conversation: Conversation = None,
|
||||||
|
stream: bool = True) \
|
||||||
|
-> Union[dict, Generator[dict, None, None]]:
|
||||||
|
is_first_conversation = False
|
||||||
|
if not conversation:
|
||||||
|
is_first_conversation = True
|
||||||
|
|
||||||
|
# init generate records
|
||||||
|
(
|
||||||
|
conversation,
|
||||||
|
message
|
||||||
|
) = self._init_generate_records(application_generate_entity, conversation)
|
||||||
|
|
||||||
|
if is_first_conversation:
|
||||||
|
# update conversation features
|
||||||
|
conversation.override_model_configs = workflow.features
|
||||||
|
db.session.commit()
|
||||||
|
db.session.refresh(conversation)
|
||||||
|
|
||||||
|
# init queue manager
|
||||||
|
queue_manager = MessageBasedAppQueueManager(
|
||||||
|
task_id=application_generate_entity.task_id,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
|
conversation_id=conversation.id,
|
||||||
|
app_mode=conversation.mode,
|
||||||
|
message_id=message.id
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# chatbot app
|
||||||
|
runner = AdvancedChatAppRunner()
|
||||||
|
response = runner.run(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
conversation=conversation,
|
||||||
|
message=message
|
||||||
|
)
|
||||||
|
except GenerateTaskStoppedException:
|
||||||
|
pass
|
||||||
|
except InvokeAuthorizationError:
|
||||||
|
raise
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.exception("Validation Error when generating")
|
||||||
|
raise e
|
||||||
|
except ValueError as e:
|
||||||
|
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||||
|
raise GenerateTaskStoppedException()
|
||||||
|
else:
|
||||||
|
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||||
|
logger.exception(e)
|
||||||
|
raise e
|
||||||
|
except InvokeError as e:
|
||||||
|
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||||
|
logger.exception("Error when generating")
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Unknown Error when generating")
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
db.session.close()
|
||||||
|
|
||||||
|
return AdvancedChatAppGenerateResponseConverter.convert(
|
||||||
|
response=response,
|
||||||
|
invoke_from=invoke_from
|
||||||
|
)
|
||||||
422
api/core/app/apps/chatflow/app_runner.py
Normal file
422
api/core/app/apps/chatflow/app_runner.py
Normal file
@ -0,0 +1,422 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections.abc import Generator, Mapping
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
|
from core.app.apps.base_app_runner import AppRunner
|
||||||
|
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
||||||
|
from core.app.entities.app_invoke_entities import (
|
||||||
|
AdvancedChatAppGenerateEntity,
|
||||||
|
InvokeFrom,
|
||||||
|
)
|
||||||
|
from core.app.entities.queue_entities import (
|
||||||
|
AppQueueEvent,
|
||||||
|
QueueAnnotationReplyEvent,
|
||||||
|
QueueIterationCompletedEvent,
|
||||||
|
QueueIterationNextEvent,
|
||||||
|
QueueIterationStartEvent,
|
||||||
|
QueueNodeFailedEvent,
|
||||||
|
QueueNodeStartedEvent,
|
||||||
|
QueueNodeSucceededEvent,
|
||||||
|
QueueParallelBranchRunFailedEvent,
|
||||||
|
QueueParallelBranchRunStartedEvent,
|
||||||
|
QueueRetrieverResourcesEvent,
|
||||||
|
QueueStopEvent,
|
||||||
|
QueueTextChunkEvent,
|
||||||
|
QueueWorkflowFailedEvent,
|
||||||
|
QueueWorkflowStartedEvent,
|
||||||
|
QueueWorkflowSucceededEvent,
|
||||||
|
)
|
||||||
|
from core.moderation.base import ModerationException
|
||||||
|
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||||
|
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
||||||
|
from core.workflow.graph_engine.entities.event import (
|
||||||
|
GraphRunFailedEvent,
|
||||||
|
GraphRunStartedEvent,
|
||||||
|
GraphRunSucceededEvent,
|
||||||
|
IterationRunFailedEvent,
|
||||||
|
IterationRunNextEvent,
|
||||||
|
IterationRunStartedEvent,
|
||||||
|
IterationRunSucceededEvent,
|
||||||
|
NodeRunFailedEvent,
|
||||||
|
NodeRunRetrieverResourceEvent,
|
||||||
|
NodeRunStartedEvent,
|
||||||
|
NodeRunStreamChunkEvent,
|
||||||
|
NodeRunSucceededEvent,
|
||||||
|
ParallelBranchRunFailedEvent,
|
||||||
|
ParallelBranchRunStartedEvent,
|
||||||
|
ParallelBranchRunSucceededEvent,
|
||||||
|
)
|
||||||
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.model import App, Conversation, EndUser, Message
|
||||||
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedChatAppRunner(AppRunner):
|
||||||
|
"""
|
||||||
|
AdvancedChat Application Runner
|
||||||
|
"""
|
||||||
|
|
||||||
|
def run(self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
conversation: Conversation,
|
||||||
|
message: Message) -> Generator[AppQueueEvent, None, None]:
|
||||||
|
"""
|
||||||
|
Run application
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: application queue manager
|
||||||
|
:param conversation: conversation
|
||||||
|
:param message: message
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
app_config = application_generate_entity.app_config
|
||||||
|
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||||
|
|
||||||
|
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||||
|
if not app_record:
|
||||||
|
raise ValueError("App not found")
|
||||||
|
|
||||||
|
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||||
|
if not workflow:
|
||||||
|
raise ValueError("Workflow not initialized")
|
||||||
|
|
||||||
|
inputs = application_generate_entity.inputs
|
||||||
|
query = application_generate_entity.query
|
||||||
|
files = application_generate_entity.files
|
||||||
|
|
||||||
|
user_id = None
|
||||||
|
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||||
|
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||||
|
if end_user:
|
||||||
|
user_id = end_user.session_id
|
||||||
|
else:
|
||||||
|
user_id = application_generate_entity.user_id
|
||||||
|
|
||||||
|
# moderation
|
||||||
|
if self.handle_input_moderation(
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
app_record=app_record,
|
||||||
|
app_generate_entity=application_generate_entity,
|
||||||
|
inputs=inputs,
|
||||||
|
query=query,
|
||||||
|
message_id=message.id
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# annotation reply
|
||||||
|
if self.handle_annotation_reply(
|
||||||
|
app_record=app_record,
|
||||||
|
message=message,
|
||||||
|
query=query,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
app_generate_entity=application_generate_entity
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
db.session.close()
|
||||||
|
|
||||||
|
workflow_callbacks: list[WorkflowCallback] = []
|
||||||
|
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
|
||||||
|
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||||
|
|
||||||
|
# RUN WORKFLOW
|
||||||
|
workflow_entry = WorkflowEntry(
|
||||||
|
workflow=workflow,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
user_from=UserFrom.ACCOUNT
|
||||||
|
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||||
|
else UserFrom.END_USER,
|
||||||
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
|
user_inputs=inputs,
|
||||||
|
system_inputs={
|
||||||
|
SystemVariable.QUERY: query,
|
||||||
|
SystemVariable.FILES: files,
|
||||||
|
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||||
|
SystemVariable.USER_ID: user_id
|
||||||
|
},
|
||||||
|
call_depth=application_generate_entity.call_depth
|
||||||
|
)
|
||||||
|
generator = workflow_entry.run(
|
||||||
|
callbacks=workflow_callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
for event in generator:
|
||||||
|
if isinstance(event, GraphRunStartedEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueWorkflowStartedEvent(),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, GraphRunSucceededEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueWorkflowSucceededEvent(),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, GraphRunFailedEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueWorkflowFailedEvent(error=event.error),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, NodeRunStartedEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueNodeStartedEvent(
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
node_data=event.node_data,
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||||
|
predecessor_node_id=event.predecessor_node_id
|
||||||
|
),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, NodeRunSucceededEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueNodeSucceededEvent(
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
node_data=event.node_data,
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
inputs=event.route_node_state.node_run_result.inputs
|
||||||
|
if event.route_node_state.node_run_result else {},
|
||||||
|
process_data=event.route_node_state.node_run_result.process_data
|
||||||
|
if event.route_node_state.node_run_result else {},
|
||||||
|
outputs=event.route_node_state.node_run_result.outputs
|
||||||
|
if event.route_node_state.node_run_result else {},
|
||||||
|
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||||
|
if event.route_node_state.node_run_result else {},
|
||||||
|
),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, NodeRunFailedEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueNodeFailedEvent(
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
node_data=event.node_data,
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
inputs=event.route_node_state.node_run_result.inputs
|
||||||
|
if event.route_node_state.node_run_result else {},
|
||||||
|
process_data=event.route_node_state.node_run_result.process_data
|
||||||
|
if event.route_node_state.node_run_result else {},
|
||||||
|
outputs=event.route_node_state.node_run_result.outputs
|
||||||
|
if event.route_node_state.node_run_result else {},
|
||||||
|
error=event.route_node_state.node_run_result.error
|
||||||
|
if event.route_node_state.node_run_result
|
||||||
|
and event.route_node_state.node_run_result.error
|
||||||
|
else "Unknown error"
|
||||||
|
),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueTextChunkEvent(
|
||||||
|
text=event.chunk_content
|
||||||
|
), PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueRetrieverResourcesEvent(
|
||||||
|
retriever_resources=event.retriever_resources
|
||||||
|
), PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueParallelBranchRunStartedEvent(
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id
|
||||||
|
),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueParallelBranchRunStartedEvent(
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id
|
||||||
|
),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueParallelBranchRunFailedEvent(
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
error=event.error
|
||||||
|
),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, IterationRunStartedEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueIterationStartEvent(
|
||||||
|
node_id=event.iteration_node_id,
|
||||||
|
node_type=event.iteration_node_type,
|
||||||
|
node_data=event.iteration_node_data,
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||||
|
inputs=event.inputs,
|
||||||
|
predecessor_node_id=event.predecessor_node_id,
|
||||||
|
metadata=event.metadata
|
||||||
|
),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, IterationRunNextEvent):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueIterationNextEvent(
|
||||||
|
node_id=event.iteration_node_id,
|
||||||
|
node_type=event.iteration_node_type,
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
index=event.index,
|
||||||
|
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||||
|
output=event.pre_iteration_output,
|
||||||
|
),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueIterationCompletedEvent(
|
||||||
|
node_id=event.iteration_node_id,
|
||||||
|
node_type=event.iteration_node_type,
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||||
|
inputs=event.inputs,
|
||||||
|
outputs=event.outputs,
|
||||||
|
metadata=event.metadata,
|
||||||
|
steps=event.steps,
|
||||||
|
error=event.error if isinstance(event, IterationRunFailedEvent) else None
|
||||||
|
),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||||
|
"""
|
||||||
|
Get workflow
|
||||||
|
"""
|
||||||
|
# fetch workflow by workflow_id
|
||||||
|
workflow = db.session.query(Workflow).filter(
|
||||||
|
Workflow.tenant_id == app_model.tenant_id,
|
||||||
|
Workflow.app_id == app_model.id,
|
||||||
|
Workflow.id == workflow_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
# return workflow
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
def handle_input_moderation(
|
||||||
|
self, queue_manager: AppQueueManager,
|
||||||
|
app_record: App,
|
||||||
|
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
|
inputs: Mapping[str, Any],
|
||||||
|
query: str,
|
||||||
|
message_id: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Handle input moderation
|
||||||
|
:param queue_manager: application queue manager
|
||||||
|
:param app_record: app record
|
||||||
|
:param app_generate_entity: application generate entity
|
||||||
|
:param inputs: inputs
|
||||||
|
:param query: query
|
||||||
|
:param message_id: message id
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# process sensitive_word_avoidance
|
||||||
|
_, inputs, query = self.moderation_for_inputs(
|
||||||
|
app_id=app_record.id,
|
||||||
|
tenant_id=app_generate_entity.app_config.tenant_id,
|
||||||
|
app_generate_entity=app_generate_entity,
|
||||||
|
inputs=inputs,
|
||||||
|
query=query,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
except ModerationException as e:
|
||||||
|
self._stream_output(
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
text=str(e),
|
||||||
|
stream=app_generate_entity.stream,
|
||||||
|
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def handle_annotation_reply(self, app_record: App,
|
||||||
|
message: Message,
|
||||||
|
query: str,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
|
||||||
|
"""
|
||||||
|
Handle annotation reply
|
||||||
|
:param app_record: app record
|
||||||
|
:param message: message
|
||||||
|
:param query: query
|
||||||
|
:param queue_manager: application queue manager
|
||||||
|
:param app_generate_entity: application generate entity
|
||||||
|
"""
|
||||||
|
# annotation reply
|
||||||
|
annotation_reply = self.query_app_annotations_to_reply(
|
||||||
|
app_record=app_record,
|
||||||
|
message=message,
|
||||||
|
query=query,
|
||||||
|
user_id=app_generate_entity.user_id,
|
||||||
|
invoke_from=app_generate_entity.invoke_from
|
||||||
|
)
|
||||||
|
|
||||||
|
if annotation_reply:
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
|
||||||
|
self._stream_output(
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
text=annotation_reply.content,
|
||||||
|
stream=app_generate_entity.stream,
|
||||||
|
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _stream_output(self, queue_manager: AppQueueManager,
|
||||||
|
text: str,
|
||||||
|
stream: bool,
|
||||||
|
stopped_by: QueueStopEvent.StopBy) -> None:
|
||||||
|
"""
|
||||||
|
Direct output
|
||||||
|
:param queue_manager: application queue manager
|
||||||
|
:param text: text
|
||||||
|
:param stream: stream
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if stream:
|
||||||
|
index = 0
|
||||||
|
for token in text:
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueTextChunkEvent(
|
||||||
|
text=token
|
||||||
|
), PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
index += 1
|
||||||
|
time.sleep(0.01)
|
||||||
|
else:
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueTextChunkEvent(
|
||||||
|
text=text
|
||||||
|
), PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
|
||||||
|
queue_manager.publish(
|
||||||
|
QueueStopEvent(stopped_by=stopped_by),
|
||||||
|
PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
450
api/core/app/apps/chatflow/generate_task_pipeline.py
Normal file
450
api/core/app/apps/chatflow/generate_task_pipeline.py
Normal file
@ -0,0 +1,450 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||||
|
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
|
from core.app.entities.app_invoke_entities import (
|
||||||
|
AdvancedChatAppGenerateEntity,
|
||||||
|
)
|
||||||
|
from core.app.entities.queue_entities import (
|
||||||
|
QueueAdvancedChatMessageEndEvent,
|
||||||
|
QueueAnnotationReplyEvent,
|
||||||
|
QueueErrorEvent,
|
||||||
|
QueueIterationCompletedEvent,
|
||||||
|
QueueIterationNextEvent,
|
||||||
|
QueueIterationStartEvent,
|
||||||
|
QueueMessageReplaceEvent,
|
||||||
|
QueueNodeFailedEvent,
|
||||||
|
QueueNodeSucceededEvent,
|
||||||
|
QueuePingEvent,
|
||||||
|
QueueRetrieverResourcesEvent,
|
||||||
|
QueueStopEvent,
|
||||||
|
QueueTextChunkEvent,
|
||||||
|
QueueWorkflowFailedEvent,
|
||||||
|
QueueWorkflowSucceededEvent,
|
||||||
|
)
|
||||||
|
from core.app.entities.task_entities import (
|
||||||
|
AdvancedChatTaskState,
|
||||||
|
ChatbotAppBlockingResponse,
|
||||||
|
ChatbotAppStreamResponse,
|
||||||
|
ErrorStreamResponse,
|
||||||
|
MessageAudioEndStreamResponse,
|
||||||
|
MessageAudioStreamResponse,
|
||||||
|
MessageEndStreamResponse,
|
||||||
|
StreamResponse,
|
||||||
|
)
|
||||||
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
|
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||||
|
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from core.workflow.entities.node_entities import NodeType, SystemVariable
|
||||||
|
from core.workflow.graph_engine.entities.event import GraphRunStartedEvent, NodeRunStartedEvent
|
||||||
|
from events.message_event import message_was_created
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.account import Account
|
||||||
|
from models.model import Conversation, EndUser, Message
|
||||||
|
from models.workflow import (
|
||||||
|
Workflow,
|
||||||
|
WorkflowRunStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage):
|
||||||
|
"""
|
||||||
|
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
|
"""
|
||||||
|
_task_state: AdvancedChatTaskState
|
||||||
|
_application_generate_entity: AdvancedChatAppGenerateEntity
|
||||||
|
_workflow: Workflow
|
||||||
|
_user: Union[Account, EndUser]
|
||||||
|
_workflow_system_variables: dict[SystemVariable, Any]
|
||||||
|
_iteration_nested_relations: dict[str, list[str]]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||||
|
workflow: Workflow,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
conversation: Conversation,
|
||||||
|
message: Message,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
stream: bool
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize AdvancedChatAppGenerateTaskPipeline.
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param workflow: workflow
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param conversation: conversation
|
||||||
|
:param message: message
|
||||||
|
:param user: user
|
||||||
|
:param stream: stream
|
||||||
|
"""
|
||||||
|
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||||
|
|
||||||
|
if isinstance(self._user, EndUser):
|
||||||
|
user_id = self._user.session_id
|
||||||
|
else:
|
||||||
|
user_id = self._user.id
|
||||||
|
|
||||||
|
self._workflow = workflow
|
||||||
|
self._conversation = conversation
|
||||||
|
self._message = message
|
||||||
|
self._workflow_system_variables = {
|
||||||
|
SystemVariable.QUERY: message.query,
|
||||||
|
SystemVariable.FILES: application_generate_entity.files,
|
||||||
|
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||||
|
SystemVariable.USER_ID: user_id
|
||||||
|
}
|
||||||
|
|
||||||
|
self._task_state = AdvancedChatTaskState(
|
||||||
|
usage=LLMUsage.empty_usage()
|
||||||
|
)
|
||||||
|
|
||||||
|
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
|
||||||
|
self._stream_generate_routes = self._get_stream_generate_routes()
|
||||||
|
self._conversation_name_generate_thread = None
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
"""
|
||||||
|
Process generate task pipeline.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
db.session.refresh(self._workflow)
|
||||||
|
db.session.refresh(self._user)
|
||||||
|
db.session.close()
|
||||||
|
|
||||||
|
# start generate conversation name thread
|
||||||
|
self._conversation_name_generate_thread = self._generate_conversation_name(
|
||||||
|
self._conversation,
|
||||||
|
self._application_generate_entity.query
|
||||||
|
)
|
||||||
|
|
||||||
|
generator = self._wrapper_process_stream_response(
|
||||||
|
trace_manager=self._application_generate_entity.trace_manager
|
||||||
|
)
|
||||||
|
if self._stream:
|
||||||
|
return self._to_stream_response(generator)
|
||||||
|
else:
|
||||||
|
return self._to_blocking_response(generator)
|
||||||
|
|
||||||
|
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
|
||||||
|
"""
|
||||||
|
Process blocking response.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for stream_response in generator:
|
||||||
|
if isinstance(stream_response, ErrorStreamResponse):
|
||||||
|
raise stream_response.err
|
||||||
|
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||||
|
extras = {}
|
||||||
|
if stream_response.metadata:
|
||||||
|
extras['metadata'] = stream_response.metadata
|
||||||
|
|
||||||
|
return ChatbotAppBlockingResponse(
|
||||||
|
task_id=stream_response.task_id,
|
||||||
|
data=ChatbotAppBlockingResponse.Data(
|
||||||
|
id=self._message.id,
|
||||||
|
mode=self._conversation.mode,
|
||||||
|
conversation_id=self._conversation.id,
|
||||||
|
message_id=self._message.id,
|
||||||
|
answer=self._task_state.answer,
|
||||||
|
created_at=int(self._message.created_at.timestamp()),
|
||||||
|
**extras
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise Exception('Queue listening stopped unexpectedly.')
|
||||||
|
|
||||||
|
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
|
||||||
|
"""
|
||||||
|
To stream response.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for stream_response in generator:
|
||||||
|
yield ChatbotAppStreamResponse(
|
||||||
|
conversation_id=self._conversation.id,
|
||||||
|
message_id=self._message.id,
|
||||||
|
created_at=int(self._message.created_at.timestamp()),
|
||||||
|
stream_response=stream_response
|
||||||
|
)
|
||||||
|
|
||||||
|
def _listenAudioMsg(self, publisher, task_id: str):
|
||||||
|
if not publisher:
|
||||||
|
return None
|
||||||
|
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||||
|
if audio_msg and audio_msg.status != "finish":
|
||||||
|
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||||
|
Generator[StreamResponse, None, None]:
|
||||||
|
|
||||||
|
publisher = None
|
||||||
|
task_id = self._application_generate_entity.task_id
|
||||||
|
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||||
|
features_dict = self._workflow.features_dict
|
||||||
|
|
||||||
|
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||||
|
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||||
|
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||||
|
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||||
|
while True:
|
||||||
|
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
|
||||||
|
if audio_response:
|
||||||
|
yield audio_response
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
yield response
|
||||||
|
|
||||||
|
start_listener_time = time.time()
|
||||||
|
# timeout
|
||||||
|
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||||
|
try:
|
||||||
|
if not publisher:
|
||||||
|
break
|
||||||
|
audio_trunk = publisher.checkAndGetAudio()
|
||||||
|
if audio_trunk is None:
|
||||||
|
# release cpu
|
||||||
|
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||||
|
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
|
||||||
|
continue
|
||||||
|
if audio_trunk.status == "finish":
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
start_listener_time = time.time()
|
||||||
|
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
break
|
||||||
|
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
|
||||||
|
|
||||||
|
def _process_stream_response(
|
||||||
|
self,
|
||||||
|
publisher: AppGeneratorTTSPublisher,
|
||||||
|
trace_manager: Optional[TraceQueueManager] = None
|
||||||
|
) -> Generator[StreamResponse, None, None]:
|
||||||
|
"""
|
||||||
|
Process stream response.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for message in self._queue_manager.listen():
|
||||||
|
if publisher:
|
||||||
|
publisher.publish(message=message)
|
||||||
|
event = message.event
|
||||||
|
|
||||||
|
if isinstance(event, QueueErrorEvent):
|
||||||
|
err = self._handle_error(event, self._message)
|
||||||
|
yield self._error_to_stream_response(err)
|
||||||
|
break
|
||||||
|
elif isinstance(event, GraphRunStartedEvent):
|
||||||
|
workflow_run = self._handle_workflow_start()
|
||||||
|
|
||||||
|
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||||
|
self._message.workflow_run_id = workflow_run.id
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
db.session.refresh(self._message)
|
||||||
|
db.session.close()
|
||||||
|
|
||||||
|
yield self._workflow_start_to_stream_response(
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run
|
||||||
|
)
|
||||||
|
elif isinstance(event, NodeRunStartedEvent):
|
||||||
|
workflow_node_execution = self._handle_node_start(event)
|
||||||
|
|
||||||
|
yield self._workflow_node_start_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_node_execution=workflow_node_execution
|
||||||
|
)
|
||||||
|
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
||||||
|
workflow_node_execution = self._handle_node_finished(event)
|
||||||
|
|
||||||
|
yield self._workflow_node_finish_to_stream_response(
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_node_execution=workflow_node_execution
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(event, QueueNodeFailedEvent):
|
||||||
|
yield from self._handle_iteration_exception(
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
error=f'Child node failed: {event.error}'
|
||||||
|
)
|
||||||
|
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
|
||||||
|
if isinstance(event, QueueIterationNextEvent):
|
||||||
|
# clear ran node execution infos of current iteration
|
||||||
|
iteration_relations = self._iteration_nested_relations.get(event.node_id)
|
||||||
|
if iteration_relations:
|
||||||
|
for node_id in iteration_relations:
|
||||||
|
self._task_state.ran_node_execution_infos.pop(node_id, None)
|
||||||
|
|
||||||
|
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
|
||||||
|
self._handle_iteration_operation(event)
|
||||||
|
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||||
|
workflow_run = self._handle_workflow_finished(
|
||||||
|
event, conversation_id=self._conversation.id, trace_manager=trace_manager
|
||||||
|
)
|
||||||
|
if workflow_run:
|
||||||
|
yield self._workflow_finish_to_stream_response(
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_run=workflow_run
|
||||||
|
)
|
||||||
|
|
||||||
|
if workflow_run.status == WorkflowRunStatus.FAILED.value:
|
||||||
|
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
|
||||||
|
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
|
||||||
|
break
|
||||||
|
|
||||||
|
if isinstance(event, QueueStopEvent):
|
||||||
|
# Save message
|
||||||
|
self._save_message()
|
||||||
|
|
||||||
|
yield self._message_end_to_stream_response()
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self._queue_manager.publish(
|
||||||
|
QueueAdvancedChatMessageEndEvent(),
|
||||||
|
PublishFrom.TASK_PIPELINE
|
||||||
|
)
|
||||||
|
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||||
|
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||||
|
if output_moderation_answer:
|
||||||
|
self._task_state.answer = output_moderation_answer
|
||||||
|
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||||
|
|
||||||
|
# Save message
|
||||||
|
self._save_message()
|
||||||
|
|
||||||
|
yield self._message_end_to_stream_response()
|
||||||
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
|
self._handle_retriever_resources(event)
|
||||||
|
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||||
|
self._handle_annotation_reply(event)
|
||||||
|
elif isinstance(event, QueueTextChunkEvent):
|
||||||
|
delta_text = event.text
|
||||||
|
if delta_text is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# handle output moderation chunk
|
||||||
|
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
|
||||||
|
if should_direct_answer:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._task_state.answer += delta_text
|
||||||
|
yield self._message_to_stream_response(delta_text, self._message.id)
|
||||||
|
elif isinstance(event, QueueMessageReplaceEvent):
|
||||||
|
yield self._message_replace_to_stream_response(answer=event.text)
|
||||||
|
elif isinstance(event, QueuePingEvent):
|
||||||
|
yield self._ping_stream_response()
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if publisher:
|
||||||
|
publisher.publish(None)
|
||||||
|
if self._conversation_name_generate_thread:
|
||||||
|
self._conversation_name_generate_thread.join()
|
||||||
|
|
||||||
|
def _save_message(self) -> None:
|
||||||
|
"""
|
||||||
|
Save message.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||||
|
|
||||||
|
self._message.answer = self._task_state.answer
|
||||||
|
self._message.provider_response_latency = time.perf_counter() - self._start_at
|
||||||
|
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||||
|
if self._task_state.metadata else None
|
||||||
|
|
||||||
|
if self._task_state.metadata and self._task_state.metadata.get('usage'):
|
||||||
|
usage = LLMUsage(**self._task_state.metadata['usage'])
|
||||||
|
|
||||||
|
self._message.message_tokens = usage.prompt_tokens
|
||||||
|
self._message.message_unit_price = usage.prompt_unit_price
|
||||||
|
self._message.message_price_unit = usage.prompt_price_unit
|
||||||
|
self._message.answer_tokens = usage.completion_tokens
|
||||||
|
self._message.answer_unit_price = usage.completion_unit_price
|
||||||
|
self._message.answer_price_unit = usage.completion_price_unit
|
||||||
|
self._message.total_price = usage.total_price
|
||||||
|
self._message.currency = usage.currency
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
message_was_created.send(
|
||||||
|
self._message,
|
||||||
|
application_generate_entity=self._application_generate_entity,
|
||||||
|
conversation=self._conversation,
|
||||||
|
is_first_message=self._application_generate_entity.conversation_id is None,
|
||||||
|
extras=self._application_generate_entity.extras
|
||||||
|
)
|
||||||
|
|
||||||
|
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
|
||||||
|
"""
|
||||||
|
Message end to stream response.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
extras = {}
|
||||||
|
if self._task_state.metadata:
|
||||||
|
extras['metadata'] = self._task_state.metadata
|
||||||
|
|
||||||
|
return MessageEndStreamResponse(
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
id=self._message.id,
|
||||||
|
**extras
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
|
||||||
|
"""
|
||||||
|
Get iteration nested relations.
|
||||||
|
:param graph: graph
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
nodes = graph.get('nodes')
|
||||||
|
|
||||||
|
iteration_ids = [node.get('id') for node in nodes
|
||||||
|
if node.get('data', {}).get('type') in [
|
||||||
|
NodeType.ITERATION.value,
|
||||||
|
NodeType.LOOP.value,
|
||||||
|
]]
|
||||||
|
|
||||||
|
return {
|
||||||
|
iteration_id: [
|
||||||
|
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
||||||
|
] for iteration_id in iteration_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||||
|
"""
|
||||||
|
Handle output moderation chunk.
|
||||||
|
:param text: text
|
||||||
|
:return: True if output moderation should direct output, otherwise False
|
||||||
|
"""
|
||||||
|
if self._output_moderation_handler:
|
||||||
|
if self._output_moderation_handler.should_direct_output():
|
||||||
|
# stop subscribe new token when output moderation should direct output
|
||||||
|
self._task_state.answer = self._output_moderation_handler.get_final_output()
|
||||||
|
self._queue_manager.publish(
|
||||||
|
QueueTextChunkEvent(
|
||||||
|
text=self._task_state.answer
|
||||||
|
), PublishFrom.TASK_PIPELINE
|
||||||
|
)
|
||||||
|
|
||||||
|
self._queue_manager.publish(
|
||||||
|
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
|
||||||
|
PublishFrom.TASK_PIPELINE
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self._output_moderation_handler.append_new_token(text)
|
||||||
|
|
||||||
|
return False
|
||||||
@ -46,7 +46,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||||||
elif isinstance(event, GraphRunSucceededEvent):
|
elif isinstance(event, GraphRunSucceededEvent):
|
||||||
self.print_text("\n[on_workflow_run_succeeded]", color='green')
|
self.print_text("\n[on_workflow_run_succeeded]", color='green')
|
||||||
elif isinstance(event, GraphRunFailedEvent):
|
elif isinstance(event, GraphRunFailedEvent):
|
||||||
self.print_text(f"\n[on_workflow_run_failed] reason: {event.reason}", color='red')
|
self.print_text(f"\n[on_workflow_run_failed] reason: {event.error}", color='red')
|
||||||
elif isinstance(event, NodeRunStartedEvent):
|
elif isinstance(event, NodeRunStartedEvent):
|
||||||
self.on_workflow_node_execute_started(
|
self.on_workflow_node_execute_started(
|
||||||
graph=graph,
|
graph=graph,
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@ -5,7 +6,7 @@ from pydantic import BaseModel, field_validator
|
|||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
|
||||||
|
|
||||||
|
|
||||||
class QueueEvent(str, Enum):
|
class QueueEvent(str, Enum):
|
||||||
@ -31,6 +32,9 @@ class QueueEvent(str, Enum):
|
|||||||
ANNOTATION_REPLY = "annotation_reply"
|
ANNOTATION_REPLY = "annotation_reply"
|
||||||
AGENT_THOUGHT = "agent_thought"
|
AGENT_THOUGHT = "agent_thought"
|
||||||
MESSAGE_FILE = "message_file"
|
MESSAGE_FILE = "message_file"
|
||||||
|
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
||||||
|
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
||||||
|
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
PING = "ping"
|
PING = "ping"
|
||||||
STOP = "stop"
|
STOP = "stop"
|
||||||
@ -38,7 +42,7 @@ class QueueEvent(str, Enum):
|
|||||||
|
|
||||||
class AppQueueEvent(BaseModel):
|
class AppQueueEvent(BaseModel):
|
||||||
"""
|
"""
|
||||||
QueueEvent entity
|
QueueEvent abstract entity
|
||||||
"""
|
"""
|
||||||
event: QueueEvent
|
event: QueueEvent
|
||||||
|
|
||||||
@ -46,6 +50,7 @@ class AppQueueEvent(BaseModel):
|
|||||||
class QueueLLMChunkEvent(AppQueueEvent):
|
class QueueLLMChunkEvent(AppQueueEvent):
|
||||||
"""
|
"""
|
||||||
QueueLLMChunkEvent entity
|
QueueLLMChunkEvent entity
|
||||||
|
Only for basic mode apps
|
||||||
"""
|
"""
|
||||||
event: QueueEvent = QueueEvent.LLM_CHUNK
|
event: QueueEvent = QueueEvent.LLM_CHUNK
|
||||||
chunk: LLMResultChunk
|
chunk: LLMResultChunk
|
||||||
@ -58,11 +63,15 @@ class QueueIterationStartEvent(AppQueueEvent):
|
|||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
node_data: BaseNodeData
|
||||||
|
parallel_id: Optional[str] = None
|
||||||
|
"""parallel id if node is in parallel"""
|
||||||
|
parallel_start_node_id: Optional[str] = None
|
||||||
|
"""parallel start node id if node is in parallel"""
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
inputs: dict = None
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: Optional[str] = None
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[Mapping[str, Any]] = None
|
||||||
|
|
||||||
class QueueIterationNextEvent(AppQueueEvent):
|
class QueueIterationNextEvent(AppQueueEvent):
|
||||||
"""
|
"""
|
||||||
@ -73,6 +82,10 @@ class QueueIterationNextEvent(AppQueueEvent):
|
|||||||
index: int
|
index: int
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
|
parallel_id: Optional[str] = None
|
||||||
|
"""parallel id if node is in parallel"""
|
||||||
|
parallel_start_node_id: Optional[str] = None
|
||||||
|
"""parallel start node id if node is in parallel"""
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
output: Optional[Any] = None # output for the current iteration
|
output: Optional[Any] = None # output for the current iteration
|
||||||
@ -93,13 +106,23 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
|||||||
"""
|
"""
|
||||||
QueueIterationCompletedEvent entity
|
QueueIterationCompletedEvent entity
|
||||||
"""
|
"""
|
||||||
event:QueueEvent = QueueEvent.ITERATION_COMPLETED
|
event: QueueEvent = QueueEvent.ITERATION_COMPLETED
|
||||||
|
|
||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
|
parallel_id: Optional[str] = None
|
||||||
|
"""parallel id if node is in parallel"""
|
||||||
|
parallel_start_node_id: Optional[str] = None
|
||||||
|
"""parallel start node id if node is in parallel"""
|
||||||
|
|
||||||
node_run_index: int
|
node_run_index: int
|
||||||
outputs: dict
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
|
outputs: Optional[Mapping[str, Any]] = None
|
||||||
|
metadata: Optional[Mapping[str, Any]] = None
|
||||||
|
steps: int = 0
|
||||||
|
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class QueueTextChunkEvent(AppQueueEvent):
|
class QueueTextChunkEvent(AppQueueEvent):
|
||||||
"""
|
"""
|
||||||
@ -190,6 +213,10 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
|||||||
node_data: BaseNodeData
|
node_data: BaseNodeData
|
||||||
node_run_index: int = 1
|
node_run_index: int = 1
|
||||||
predecessor_node_id: Optional[str] = None
|
predecessor_node_id: Optional[str] = None
|
||||||
|
parallel_id: Optional[str] = None
|
||||||
|
"""parallel id if node is in parallel"""
|
||||||
|
parallel_start_node_id: Optional[str] = None
|
||||||
|
"""parallel start node id if node is in parallel"""
|
||||||
|
|
||||||
|
|
||||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||||
@ -201,11 +228,15 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
node_data: BaseNodeData
|
||||||
|
parallel_id: Optional[str] = None
|
||||||
|
"""parallel id if node is in parallel"""
|
||||||
|
parallel_start_node_id: Optional[str] = None
|
||||||
|
"""parallel start node id if node is in parallel"""
|
||||||
|
|
||||||
inputs: Optional[dict] = None
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
process_data: Optional[dict] = None
|
process_data: Optional[Mapping[str, Any]] = None
|
||||||
outputs: Optional[dict] = None
|
outputs: Optional[Mapping[str, Any]] = None
|
||||||
execution_metadata: Optional[dict] = None
|
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||||
|
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
@ -219,10 +250,14 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
|||||||
node_id: str
|
node_id: str
|
||||||
node_type: NodeType
|
node_type: NodeType
|
||||||
node_data: BaseNodeData
|
node_data: BaseNodeData
|
||||||
|
parallel_id: Optional[str] = None
|
||||||
|
"""parallel id if node is in parallel"""
|
||||||
|
parallel_start_node_id: Optional[str] = None
|
||||||
|
"""parallel start node id if node is in parallel"""
|
||||||
|
|
||||||
inputs: Optional[dict] = None
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
outputs: Optional[dict] = None
|
process_data: Optional[Mapping[str, Any]] = None
|
||||||
process_data: Optional[dict] = None
|
outputs: Optional[Mapping[str, Any]] = None
|
||||||
|
|
||||||
error: str
|
error: str
|
||||||
|
|
||||||
@ -277,7 +312,7 @@ class QueueStopEvent(AppQueueEvent):
|
|||||||
|
|
||||||
class QueueMessage(BaseModel):
|
class QueueMessage(BaseModel):
|
||||||
"""
|
"""
|
||||||
QueueMessage entity
|
QueueMessage abstract entity
|
||||||
"""
|
"""
|
||||||
task_id: str
|
task_id: str
|
||||||
app_mode: str
|
app_mode: str
|
||||||
@ -297,3 +332,34 @@ class WorkflowQueueMessage(QueueMessage):
|
|||||||
WorkflowQueueMessage entity
|
WorkflowQueueMessage entity
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
||||||
|
"""
|
||||||
|
QueueParallelBranchRunStartedEvent entity
|
||||||
|
"""
|
||||||
|
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
|
||||||
|
|
||||||
|
parallel_id: str
|
||||||
|
parallel_start_node_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||||
|
"""
|
||||||
|
QueueParallelBranchRunSucceededEvent entity
|
||||||
|
"""
|
||||||
|
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
|
||||||
|
|
||||||
|
parallel_id: str
|
||||||
|
parallel_start_node_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||||
|
"""
|
||||||
|
QueueParallelBranchRunFailedEvent entity
|
||||||
|
"""
|
||||||
|
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
|
||||||
|
|
||||||
|
parallel_id: str
|
||||||
|
parallel_start_node_id: str
|
||||||
|
error: str
|
||||||
|
|||||||
@ -84,9 +84,9 @@ class NodeRunResult(BaseModel):
|
|||||||
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
|
||||||
|
|
||||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||||
process_data: Optional[dict] = None # process data
|
process_data: Optional[Mapping[str, Any]] = None # process data
|
||||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||||
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata
|
||||||
|
|
||||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||||
|
|
||||||
|
|||||||
@ -1,13 +1,17 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
from core.workflow.entities.node_entities import NodeType
|
||||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
|
|
||||||
|
|
||||||
class GraphEngineEvent(BaseModel):
|
class GraphEngineEvent(BaseModel):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
###########################################
|
###########################################
|
||||||
# Graph Events
|
# Graph Events
|
||||||
###########################################
|
###########################################
|
||||||
@ -26,7 +30,7 @@ class GraphRunSucceededEvent(BaseGraphEvent):
|
|||||||
|
|
||||||
|
|
||||||
class GraphRunFailedEvent(BaseGraphEvent):
|
class GraphRunFailedEvent(BaseGraphEvent):
|
||||||
reason: str = Field(..., description="failed reason")
|
error: str = Field(..., description="failed reason")
|
||||||
|
|
||||||
|
|
||||||
###########################################
|
###########################################
|
||||||
@ -35,16 +39,20 @@ class GraphRunFailedEvent(BaseGraphEvent):
|
|||||||
|
|
||||||
|
|
||||||
class BaseNodeEvent(GraphEngineEvent):
|
class BaseNodeEvent(GraphEngineEvent):
|
||||||
|
node_id: str = Field(..., description="node id")
|
||||||
|
node_type: NodeType = Field(..., description="node type")
|
||||||
|
node_data: BaseNodeData = Field(..., description="node data")
|
||||||
route_node_state: RouteNodeState = Field(..., description="route node state")
|
route_node_state: RouteNodeState = Field(..., description="route node state")
|
||||||
parallel_id: Optional[str] = None
|
parallel_id: Optional[str] = None
|
||||||
"""parallel id if node is in parallel"""
|
"""parallel id if node is in parallel"""
|
||||||
parallel_start_node_id: Optional[str] = None
|
parallel_start_node_id: Optional[str] = None
|
||||||
"""parallel start node id if node is in parallel"""
|
"""parallel start node id if node is in parallel"""
|
||||||
# iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
|
in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
|
||||||
|
|
||||||
|
|
||||||
class NodeRunStartedEvent(BaseNodeEvent):
|
class NodeRunStartedEvent(BaseNodeEvent):
|
||||||
pass
|
predecessor_node_id: Optional[str] = None
|
||||||
|
"""predecessor node id"""
|
||||||
|
|
||||||
|
|
||||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||||
@ -63,7 +71,7 @@ class NodeRunSucceededEvent(BaseNodeEvent):
|
|||||||
|
|
||||||
|
|
||||||
class NodeRunFailedEvent(BaseNodeEvent):
|
class NodeRunFailedEvent(BaseNodeEvent):
|
||||||
pass
|
error: str = Field(..., description="error")
|
||||||
|
|
||||||
|
|
||||||
###########################################
|
###########################################
|
||||||
@ -74,6 +82,7 @@ class NodeRunFailedEvent(BaseNodeEvent):
|
|||||||
class BaseParallelBranchEvent(GraphEngineEvent):
|
class BaseParallelBranchEvent(GraphEngineEvent):
|
||||||
parallel_id: str = Field(..., description="parallel id")
|
parallel_id: str = Field(..., description="parallel id")
|
||||||
parallel_start_node_id: str = Field(..., description="parallel start node id")
|
parallel_start_node_id: str = Field(..., description="parallel start node id")
|
||||||
|
in_iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration")
|
||||||
|
|
||||||
|
|
||||||
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
|
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
|
||||||
@ -85,7 +94,7 @@ class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
|
|||||||
|
|
||||||
|
|
||||||
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
|
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
|
||||||
reason: str = Field(..., description="failed reason")
|
error: str = Field(..., description="failed reason")
|
||||||
|
|
||||||
|
|
||||||
###########################################
|
###########################################
|
||||||
@ -94,11 +103,19 @@ class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
|
|||||||
|
|
||||||
|
|
||||||
class BaseIterationEvent(GraphEngineEvent):
|
class BaseIterationEvent(GraphEngineEvent):
|
||||||
iteration_id: str = Field(..., description="iteration id")
|
iteration_node_id: str = Field(..., description="iteration node id")
|
||||||
|
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
|
||||||
|
iteration_node_data: BaseNodeData = Field(..., description="node data")
|
||||||
|
parallel_id: Optional[str] = None
|
||||||
|
"""parallel id if node is in parallel"""
|
||||||
|
parallel_start_node_id: Optional[str] = None
|
||||||
|
"""parallel start node id if node is in parallel"""
|
||||||
|
|
||||||
|
|
||||||
class IterationRunStartedEvent(BaseIterationEvent):
|
class IterationRunStartedEvent(BaseIterationEvent):
|
||||||
pass
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
|
metadata: Optional[Mapping[str, Any]] = None
|
||||||
|
predecessor_node_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class IterationRunNextEvent(BaseIterationEvent):
|
class IterationRunNextEvent(BaseIterationEvent):
|
||||||
@ -107,11 +124,18 @@ class IterationRunNextEvent(BaseIterationEvent):
|
|||||||
|
|
||||||
|
|
||||||
class IterationRunSucceededEvent(BaseIterationEvent):
|
class IterationRunSucceededEvent(BaseIterationEvent):
|
||||||
pass
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
|
outputs: Optional[Mapping[str, Any]] = None
|
||||||
|
metadata: Optional[Mapping[str, Any]] = None
|
||||||
|
steps: int = 0
|
||||||
|
|
||||||
|
|
||||||
class IterationRunFailedEvent(BaseIterationEvent):
|
class IterationRunFailedEvent(BaseIterationEvent):
|
||||||
reason: str = Field(..., description="failed reason")
|
inputs: Optional[Mapping[str, Any]] = None
|
||||||
|
outputs: Optional[Mapping[str, Any]] = None
|
||||||
|
metadata: Optional[Mapping[str, Any]] = None
|
||||||
|
steps: int = 0
|
||||||
|
error: str = Field(..., description="failed reason")
|
||||||
|
|
||||||
|
|
||||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
|
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, U
|
|||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
|
BaseIterationEvent,
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
GraphRunStartedEvent,
|
GraphRunStartedEvent,
|
||||||
@ -32,6 +33,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
|||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||||
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||||
from core.workflow.nodes.node_mapping import node_classes
|
from core.workflow.nodes.node_mapping import node_classes
|
||||||
@ -84,16 +86,16 @@ class GraphEngine:
|
|||||||
yield GraphRunStartedEvent()
|
yield GraphRunStartedEvent()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
|
||||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||||
stream_processor = AnswerStreamProcessor(
|
stream_processor_cls = AnswerStreamProcessor
|
||||||
graph=self.graph,
|
|
||||||
variable_pool=self.graph_runtime_state.variable_pool
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
stream_processor = EndStreamProcessor(
|
stream_processor_cls = EndStreamProcessor
|
||||||
graph=self.graph,
|
|
||||||
variable_pool=self.graph_runtime_state.variable_pool
|
stream_processor = stream_processor_cls(
|
||||||
)
|
graph=self.graph,
|
||||||
|
variable_pool=self.graph_runtime_state.variable_pool
|
||||||
|
)
|
||||||
|
|
||||||
# run graph
|
# run graph
|
||||||
generator = stream_processor.process(
|
generator = stream_processor.process(
|
||||||
@ -104,21 +106,21 @@ class GraphEngine:
|
|||||||
try:
|
try:
|
||||||
yield item
|
yield item
|
||||||
if isinstance(item, NodeRunFailedEvent):
|
if isinstance(item, NodeRunFailedEvent):
|
||||||
yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.')
|
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Graph run failed: {str(e)}")
|
logger.exception(f"Graph run failed: {str(e)}")
|
||||||
yield GraphRunFailedEvent(reason=str(e))
|
yield GraphRunFailedEvent(error=str(e))
|
||||||
return
|
return
|
||||||
|
|
||||||
# trigger graph run success event
|
# trigger graph run success event
|
||||||
yield GraphRunSucceededEvent()
|
yield GraphRunSucceededEvent()
|
||||||
except GraphRunFailedError as e:
|
except GraphRunFailedError as e:
|
||||||
yield GraphRunFailedEvent(reason=e.error)
|
yield GraphRunFailedEvent(error=e.error)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Unknown Error when graph running")
|
logger.exception("Unknown Error when graph running")
|
||||||
yield GraphRunFailedEvent(reason=str(e))
|
yield GraphRunFailedEvent(error=str(e))
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
||||||
@ -145,11 +147,34 @@ class GraphEngine:
|
|||||||
node_id=next_node_id
|
node_id=next_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# get node config
|
||||||
|
node_id = route_node_state.node_id
|
||||||
|
node_config = self.graph.node_id_config_mapping.get(node_id)
|
||||||
|
if not node_config:
|
||||||
|
raise GraphRunFailedError(f'Node {node_id} config not found.')
|
||||||
|
|
||||||
|
# convert to specific node
|
||||||
|
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
||||||
|
node_cls = node_classes.get(node_type)
|
||||||
|
if not node_cls:
|
||||||
|
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
|
||||||
|
|
||||||
|
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
|
||||||
|
|
||||||
|
# init workflow run state
|
||||||
|
node_instance = node_cls( # type: ignore
|
||||||
|
config=node_config,
|
||||||
|
graph_init_params=self.init_params,
|
||||||
|
graph=self.graph,
|
||||||
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
|
previous_node_id=previous_node_id
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# run node
|
# run node
|
||||||
yield from self._run_node(
|
yield from self._run_node(
|
||||||
|
node_instance=node_instance,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
previous_node_id=previous_route_node_state.node_id if previous_route_node_state else None,
|
|
||||||
parallel_id=in_parallel_id,
|
parallel_id=in_parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id
|
parallel_start_node_id=parallel_start_node_id
|
||||||
)
|
)
|
||||||
@ -166,6 +191,10 @@ class GraphEngine:
|
|||||||
route_node_state.status = RouteNodeState.Status.FAILED
|
route_node_state.status = RouteNodeState.Status.FAILED
|
||||||
route_node_state.failed_reason = str(e)
|
route_node_state.failed_reason = str(e)
|
||||||
yield NodeRunFailedEvent(
|
yield NodeRunFailedEvent(
|
||||||
|
error=str(e),
|
||||||
|
node_id=next_node_id,
|
||||||
|
node_type=node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=in_parallel_id,
|
parallel_id=in_parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id
|
parallel_start_node_id=parallel_start_node_id
|
||||||
@ -241,7 +270,7 @@ class GraphEngine:
|
|||||||
# new thread
|
# new thread
|
||||||
for edge in edge_mappings:
|
for edge in edge_mappings:
|
||||||
threading.Thread(target=self._run_parallel_node, kwargs={
|
threading.Thread(target=self._run_parallel_node, kwargs={
|
||||||
'flask_app': current_app._get_current_object(),
|
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
||||||
'parallel_id': parallel_id,
|
'parallel_id': parallel_id,
|
||||||
'parallel_start_node_id': edge.target_node_id,
|
'parallel_start_node_id': edge.target_node_id,
|
||||||
'q': q
|
'q': q
|
||||||
@ -309,21 +338,21 @@ class GraphEngine:
|
|||||||
q.put(ParallelBranchRunFailedEvent(
|
q.put(ParallelBranchRunFailedEvent(
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
reason=e.error
|
error=e.error
|
||||||
))
|
))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Unknown Error when generating in parallel")
|
logger.exception("Unknown Error when generating in parallel")
|
||||||
q.put(ParallelBranchRunFailedEvent(
|
q.put(ParallelBranchRunFailedEvent(
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
reason=str(e)
|
error=str(e)
|
||||||
))
|
))
|
||||||
finally:
|
finally:
|
||||||
db.session.remove()
|
db.session.remove()
|
||||||
|
|
||||||
def _run_node(self,
|
def _run_node(self,
|
||||||
|
node_instance: BaseNode,
|
||||||
route_node_state: RouteNodeState,
|
route_node_state: RouteNodeState,
|
||||||
previous_node_id: Optional[str] = None,
|
|
||||||
parallel_id: Optional[str] = None,
|
parallel_id: Optional[str] = None,
|
||||||
parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
parallel_start_node_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
||||||
"""
|
"""
|
||||||
@ -331,46 +360,15 @@ class GraphEngine:
|
|||||||
"""
|
"""
|
||||||
# trigger node run start event
|
# trigger node run start event
|
||||||
yield NodeRunStartedEvent(
|
yield NodeRunStartedEvent(
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
node_type=node_instance.node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
|
predecessor_node_id=node_instance.previous_node_id,
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id
|
parallel_start_node_id=parallel_start_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# get node config
|
|
||||||
node_id = route_node_state.node_id
|
|
||||||
node_config = self.graph.node_id_config_mapping.get(node_id)
|
|
||||||
if not node_config:
|
|
||||||
route_node_state.status = RouteNodeState.Status.FAILED
|
|
||||||
route_node_state.failed_reason = f'Node {node_id} config not found.'
|
|
||||||
yield NodeRunFailedEvent(
|
|
||||||
route_node_state=route_node_state,
|
|
||||||
parallel_id=parallel_id,
|
|
||||||
parallel_start_node_id=parallel_start_node_id
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# convert to specific node
|
|
||||||
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
|
||||||
node_cls = node_classes.get(node_type)
|
|
||||||
if not node_cls:
|
|
||||||
route_node_state.status = RouteNodeState.Status.FAILED
|
|
||||||
route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
|
|
||||||
yield NodeRunFailedEvent(
|
|
||||||
route_node_state=route_node_state,
|
|
||||||
parallel_id=parallel_id,
|
|
||||||
parallel_start_node_id=parallel_start_node_id
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# init workflow run state
|
|
||||||
node_instance = node_cls( # type: ignore
|
|
||||||
config=node_config,
|
|
||||||
graph_init_params=self.init_params,
|
|
||||||
graph=self.graph,
|
|
||||||
graph_runtime_state=self.graph_runtime_state,
|
|
||||||
previous_node_id=previous_node_id
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
self.graph_runtime_state.node_run_steps += 1
|
self.graph_runtime_state.node_run_steps += 1
|
||||||
@ -385,6 +383,10 @@ class GraphEngine:
|
|||||||
|
|
||||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||||
yield NodeRunFailedEvent(
|
yield NodeRunFailedEvent(
|
||||||
|
error=route_node_state.failed_reason,
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
node_type=node_instance.node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id
|
parallel_start_node_id=parallel_start_node_id
|
||||||
@ -401,12 +403,15 @@ class GraphEngine:
|
|||||||
for variable_key, variable_value in run_result.outputs.items():
|
for variable_key, variable_value in run_result.outputs.items():
|
||||||
# append variables to variable pool recursively
|
# append variables to variable pool recursively
|
||||||
self._append_variables_recursively(
|
self._append_variables_recursively(
|
||||||
node_id=node_id,
|
node_id=node_instance.node_id,
|
||||||
variable_key_list=[variable_key],
|
variable_key_list=[variable_key],
|
||||||
variable_value=variable_value
|
variable_value=variable_value
|
||||||
)
|
)
|
||||||
|
|
||||||
yield NodeRunSucceededEvent(
|
yield NodeRunSucceededEvent(
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
node_type=node_instance.node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id
|
parallel_start_node_id=parallel_start_node_id
|
||||||
@ -415,6 +420,9 @@ class GraphEngine:
|
|||||||
break
|
break
|
||||||
elif isinstance(item, RunStreamChunkEvent):
|
elif isinstance(item, RunStreamChunkEvent):
|
||||||
yield NodeRunStreamChunkEvent(
|
yield NodeRunStreamChunkEvent(
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
node_type=node_instance.node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
chunk_content=item.chunk_content,
|
chunk_content=item.chunk_content,
|
||||||
from_variable_selector=item.from_variable_selector,
|
from_variable_selector=item.from_variable_selector,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
@ -423,17 +431,28 @@ class GraphEngine:
|
|||||||
)
|
)
|
||||||
elif isinstance(item, RunRetrieverResourceEvent):
|
elif isinstance(item, RunRetrieverResourceEvent):
|
||||||
yield NodeRunRetrieverResourceEvent(
|
yield NodeRunRetrieverResourceEvent(
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
node_type=node_instance.node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
retriever_resources=item.retriever_resources,
|
retriever_resources=item.retriever_resources,
|
||||||
context=item.context,
|
context=item.context,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
)
|
)
|
||||||
|
elif isinstance(item, BaseIterationEvent):
|
||||||
|
# add parallel info to iteration event
|
||||||
|
item.parallel_id = parallel_id
|
||||||
|
item.parallel_start_node_id = parallel_start_node_id
|
||||||
except GenerateTaskStoppedException:
|
except GenerateTaskStoppedException:
|
||||||
# trigger node run failed event
|
# trigger node run failed event
|
||||||
route_node_state.status = RouteNodeState.Status.FAILED
|
route_node_state.status = RouteNodeState.Status.FAILED
|
||||||
route_node_state.failed_reason = "Workflow stopped."
|
route_node_state.failed_reason = "Workflow stopped."
|
||||||
yield NodeRunFailedEvent(
|
yield NodeRunFailedEvent(
|
||||||
|
error="Workflow stopped.",
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
node_type=node_instance.node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
|
|||||||
@ -11,21 +11,20 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||||
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
|
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AnswerStreamProcessor:
|
class AnswerStreamProcessor(StreamProcessor):
|
||||||
|
|
||||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||||
self.graph = graph
|
super().__init__(graph, variable_pool)
|
||||||
self.variable_pool = variable_pool
|
|
||||||
self.generate_routes = graph.answer_stream_generate_routes
|
self.generate_routes = graph.answer_stream_generate_routes
|
||||||
self.route_position = {}
|
self.route_position = {}
|
||||||
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
||||||
self.route_position[answer_node_id] = 0
|
self.route_position[answer_node_id] = 0
|
||||||
self.rest_node_ids = graph.node_ids.copy()
|
|
||||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||||
|
|
||||||
def process(self,
|
def process(self,
|
||||||
@ -74,58 +73,6 @@ class AnswerStreamProcessor:
|
|||||||
self.rest_node_ids = self.graph.node_ids.copy()
|
self.rest_node_ids = self.graph.node_ids.copy()
|
||||||
self.current_stream_chunk_generating_node_ids = {}
|
self.current_stream_chunk_generating_node_ids = {}
|
||||||
|
|
||||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
|
||||||
finished_node_id = event.route_node_state.node_id
|
|
||||||
|
|
||||||
if finished_node_id not in self.rest_node_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
# remove finished node id
|
|
||||||
self.rest_node_ids.remove(finished_node_id)
|
|
||||||
|
|
||||||
run_result = event.route_node_state.node_run_result
|
|
||||||
if not run_result:
|
|
||||||
return
|
|
||||||
|
|
||||||
if run_result.edge_source_handle:
|
|
||||||
reachable_node_ids = []
|
|
||||||
unreachable_first_node_ids = []
|
|
||||||
for edge in self.graph.edge_mapping[finished_node_id]:
|
|
||||||
if (edge.run_condition
|
|
||||||
and edge.run_condition.branch_identify
|
|
||||||
and run_result.edge_source_handle == edge.run_condition.branch_identify):
|
|
||||||
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
unreachable_first_node_ids.append(edge.target_node_id)
|
|
||||||
|
|
||||||
for node_id in unreachable_first_node_ids:
|
|
||||||
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
|
||||||
|
|
||||||
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
|
|
||||||
node_ids = []
|
|
||||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
|
||||||
if edge.target_node_id == self.graph.root_node_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
node_ids.append(edge.target_node_id)
|
|
||||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
|
||||||
return node_ids
|
|
||||||
|
|
||||||
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
|
|
||||||
"""
|
|
||||||
remove target node ids until merge
|
|
||||||
"""
|
|
||||||
if node_id not in self.rest_node_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.rest_node_ids.remove(node_id)
|
|
||||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
|
||||||
if edge.target_node_id in reachable_node_ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
|
|
||||||
|
|
||||||
def _generate_stream_outputs_when_node_finished(self,
|
def _generate_stream_outputs_when_node_finished(self,
|
||||||
event: NodeRunSucceededEvent
|
event: NodeRunSucceededEvent
|
||||||
) -> Generator[GraphEngineEvent, None, None]:
|
) -> Generator[GraphEngineEvent, None, None]:
|
||||||
@ -138,8 +85,8 @@ class AnswerStreamProcessor:
|
|||||||
# all depends on answer node id not in rest node ids
|
# all depends on answer node id not in rest node ids
|
||||||
if (event.route_node_state.node_id != answer_node_id
|
if (event.route_node_state.node_id != answer_node_id
|
||||||
and (answer_node_id not in self.rest_node_ids
|
and (answer_node_id not in self.rest_node_ids
|
||||||
or not all(dep_id not in self.rest_node_ids
|
or not all(dep_id not in self.rest_node_ids
|
||||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
|
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
route_position = self.route_position[answer_node_id]
|
route_position = self.route_position[answer_node_id]
|
||||||
@ -149,6 +96,9 @@ class AnswerStreamProcessor:
|
|||||||
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
|
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
|
||||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||||
yield NodeRunStreamChunkEvent(
|
yield NodeRunStreamChunkEvent(
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
node_data=event.node_data,
|
||||||
chunk_content=route_chunk.text,
|
chunk_content=route_chunk.text,
|
||||||
route_node_state=event.route_node_state,
|
route_node_state=event.route_node_state,
|
||||||
parallel_id=event.parallel_id,
|
parallel_id=event.parallel_id,
|
||||||
@ -171,6 +121,9 @@ class AnswerStreamProcessor:
|
|||||||
|
|
||||||
if text:
|
if text:
|
||||||
yield NodeRunStreamChunkEvent(
|
yield NodeRunStreamChunkEvent(
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
node_data=event.node_data,
|
||||||
chunk_content=text,
|
chunk_content=text,
|
||||||
from_variable_selector=value_selector,
|
from_variable_selector=value_selector,
|
||||||
route_node_state=event.route_node_state,
|
route_node_state=event.route_node_state,
|
||||||
|
|||||||
71
api/core/workflow/nodes/answer/base_stream_processor.py
Normal file
71
api/core/workflow/nodes/answer/base_stream_processor.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
|
||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
|
||||||
|
|
||||||
|
class StreamProcessor(ABC):
|
||||||
|
|
||||||
|
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||||
|
self.graph = graph
|
||||||
|
self.variable_pool = variable_pool
|
||||||
|
self.rest_node_ids = graph.node_ids.copy()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def process(self,
|
||||||
|
generator: Generator[GraphEngineEvent, None, None]
|
||||||
|
) -> Generator[GraphEngineEvent, None, None]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
||||||
|
finished_node_id = event.route_node_state.node_id
|
||||||
|
if finished_node_id not in self.rest_node_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
# remove finished node id
|
||||||
|
self.rest_node_ids.remove(finished_node_id)
|
||||||
|
|
||||||
|
run_result = event.route_node_state.node_run_result
|
||||||
|
if not run_result:
|
||||||
|
return
|
||||||
|
|
||||||
|
if run_result.edge_source_handle:
|
||||||
|
reachable_node_ids = []
|
||||||
|
unreachable_first_node_ids = []
|
||||||
|
for edge in self.graph.edge_mapping[finished_node_id]:
|
||||||
|
if (edge.run_condition
|
||||||
|
and edge.run_condition.branch_identify
|
||||||
|
and run_result.edge_source_handle == edge.run_condition.branch_identify):
|
||||||
|
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
unreachable_first_node_ids.append(edge.target_node_id)
|
||||||
|
|
||||||
|
for node_id in unreachable_first_node_ids:
|
||||||
|
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
||||||
|
|
||||||
|
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
|
||||||
|
node_ids = []
|
||||||
|
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||||
|
if edge.target_node_id == self.graph.root_node_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
node_ids.append(edge.target_node_id)
|
||||||
|
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||||
|
return node_ids
|
||||||
|
|
||||||
|
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
|
||||||
|
"""
|
||||||
|
remove target node ids until merge
|
||||||
|
"""
|
||||||
|
if node_id not in self.rest_node_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.rest_node_ids.remove(node_id)
|
||||||
|
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||||
|
if edge.target_node_id in reachable_node_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
|
||||||
@ -4,6 +4,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
|
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
@ -42,14 +43,14 @@ class BaseNode(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _run(self) \
|
def _run(self) \
|
||||||
-> NodeRunResult | Generator[RunEvent, None, None]:
|
-> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
|
||||||
"""
|
"""
|
||||||
Run node
|
Run node
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def run(self) -> Generator[RunEvent, None, None]:
|
def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
|
||||||
"""
|
"""
|
||||||
Run node entry
|
Run node entry
|
||||||
:return:
|
:return:
|
||||||
|
|||||||
@ -9,18 +9,17 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EndStreamProcessor:
|
class EndStreamProcessor(StreamProcessor):
|
||||||
|
|
||||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||||
self.graph = graph
|
super().__init__(graph, variable_pool)
|
||||||
self.variable_pool = variable_pool
|
|
||||||
self.stream_param = graph.end_stream_param
|
self.stream_param = graph.end_stream_param
|
||||||
self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
|
self.end_streamed_variable_selectors = graph.end_stream_param.end_stream_variable_selector_mapping.copy()
|
||||||
self.rest_node_ids = graph.node_ids.copy()
|
|
||||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||||
|
|
||||||
def process(self,
|
def process(self,
|
||||||
@ -56,64 +55,10 @@ class EndStreamProcessor:
|
|||||||
yield event
|
yield event
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.end_streamed_variable_selectors = {}
|
self.end_streamed_variable_selectors = self.graph.end_stream_param.end_stream_variable_selector_mapping.copy()
|
||||||
self.end_streamed_variable_selectors: dict[str, list[str]] = {
|
|
||||||
end_node_id: [] for end_node_id in self.graph.end_stream_param.end_stream_variable_selector_mapping
|
|
||||||
}
|
|
||||||
self.rest_node_ids = self.graph.node_ids.copy()
|
self.rest_node_ids = self.graph.node_ids.copy()
|
||||||
self.current_stream_chunk_generating_node_ids = {}
|
self.current_stream_chunk_generating_node_ids = {}
|
||||||
|
|
||||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
|
||||||
finished_node_id = event.route_node_state.node_id
|
|
||||||
if finished_node_id not in self.rest_node_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
# remove finished node id
|
|
||||||
self.rest_node_ids.remove(finished_node_id)
|
|
||||||
|
|
||||||
run_result = event.route_node_state.node_run_result
|
|
||||||
if not run_result:
|
|
||||||
return
|
|
||||||
|
|
||||||
if run_result.edge_source_handle:
|
|
||||||
reachable_node_ids = []
|
|
||||||
unreachable_first_node_ids = []
|
|
||||||
for edge in self.graph.edge_mapping[finished_node_id]:
|
|
||||||
if (edge.run_condition
|
|
||||||
and edge.run_condition.branch_identify
|
|
||||||
and run_result.edge_source_handle == edge.run_condition.branch_identify):
|
|
||||||
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
unreachable_first_node_ids.append(edge.target_node_id)
|
|
||||||
|
|
||||||
for node_id in unreachable_first_node_ids:
|
|
||||||
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
|
||||||
|
|
||||||
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
|
|
||||||
node_ids = []
|
|
||||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
|
||||||
if edge.target_node_id == self.graph.root_node_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
node_ids.append(edge.target_node_id)
|
|
||||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
|
||||||
return node_ids
|
|
||||||
|
|
||||||
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
|
|
||||||
"""
|
|
||||||
remove target node ids until merge
|
|
||||||
"""
|
|
||||||
if node_id not in self.rest_node_ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.rest_node_ids.remove(node_id)
|
|
||||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
|
||||||
if edge.target_node_id in reachable_node_ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
|
|
||||||
|
|
||||||
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Is stream out support
|
Is stream out support
|
||||||
|
|||||||
@ -1,13 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Generator
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
BaseGraphEvent,
|
BaseGraphEvent,
|
||||||
|
BaseNodeEvent,
|
||||||
|
BaseParallelBranchEvent,
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
|
InNodeEvent,
|
||||||
IterationRunFailedEvent,
|
IterationRunFailedEvent,
|
||||||
IterationRunNextEvent,
|
IterationRunNextEvent,
|
||||||
IterationRunStartedEvent,
|
IterationRunStartedEvent,
|
||||||
@ -17,7 +20,7 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.event import RunCompletedEvent
|
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
||||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||||
from core.workflow.utils.condition.entities import Condition
|
from core.workflow.utils.condition.entities import Condition
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
@ -32,7 +35,7 @@ class IterationNode(BaseNode):
|
|||||||
_node_data_cls = IterationNodeData
|
_node_data_cls = IterationNodeData
|
||||||
_node_type = NodeType.ITERATION
|
_node_type = NodeType.ITERATION
|
||||||
|
|
||||||
def _run(self) -> BaseIterationState:
|
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
|
||||||
"""
|
"""
|
||||||
Run the node.
|
Run the node.
|
||||||
"""
|
"""
|
||||||
@ -42,6 +45,10 @@ class IterationNode(BaseNode):
|
|||||||
if not isinstance(iterator_list_value, list):
|
if not isinstance(iterator_list_value, list):
|
||||||
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
||||||
|
|
||||||
|
inputs = {
|
||||||
|
"iterator_selector": iterator_list_value
|
||||||
|
}
|
||||||
|
|
||||||
root_node_id = self.node_data.start_node_id
|
root_node_id = self.node_data.start_node_id
|
||||||
graph_config = self.graph_config
|
graph_config = self.graph_config
|
||||||
|
|
||||||
@ -117,21 +124,42 @@ class IterationNode(BaseNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
yield IterationRunStartedEvent(
|
yield IterationRunStartedEvent(
|
||||||
iteration_id=self.node_id,
|
iteration_node_id=self.node_id,
|
||||||
|
iteration_node_type=self.node_type,
|
||||||
|
iteration_node_data=self.node_data,
|
||||||
|
inputs=inputs,
|
||||||
|
metadata={
|
||||||
|
"iterator_length": 1
|
||||||
|
},
|
||||||
|
predecessor_node_id=self.previous_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
yield IterationRunNextEvent(
|
yield IterationRunNextEvent(
|
||||||
iteration_id=self.node_id,
|
iteration_node_id=self.node_id,
|
||||||
|
iteration_node_type=self.node_type,
|
||||||
|
iteration_node_data=self.node_data,
|
||||||
index=0,
|
index=0,
|
||||||
output=None
|
pre_iteration_output=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
outputs: list[Any] = []
|
||||||
try:
|
try:
|
||||||
# run workflow
|
# run workflow
|
||||||
rst = graph_engine.run()
|
rst = graph_engine.run()
|
||||||
outputs: list[Any] = []
|
|
||||||
for event in rst:
|
for event in rst:
|
||||||
|
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||||
|
event.in_iteration_id = self.node_id
|
||||||
|
|
||||||
if isinstance(event, NodeRunSucceededEvent):
|
if isinstance(event, NodeRunSucceededEvent):
|
||||||
|
metadata = event.route_node_state.node_run_result.metadata
|
||||||
|
if not metadata:
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||||
|
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||||
|
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
|
||||||
|
event.route_node_state.node_run_result.metadata = metadata
|
||||||
|
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
# handle iteration run result
|
# handle iteration run result
|
||||||
@ -158,22 +186,35 @@ class IterationNode(BaseNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
yield IterationRunNextEvent(
|
yield IterationRunNextEvent(
|
||||||
iteration_id=self.node_id,
|
iteration_node_id=self.node_id,
|
||||||
|
iteration_node_type=self.node_type,
|
||||||
|
iteration_node_data=self.node_data,
|
||||||
index=next_index,
|
index=next_index,
|
||||||
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None
|
pre_iteration_output=jsonable_encoder(
|
||||||
|
current_iteration_output) if current_iteration_output else None
|
||||||
)
|
)
|
||||||
elif isinstance(event, BaseGraphEvent):
|
elif isinstance(event, BaseGraphEvent):
|
||||||
if isinstance(event, GraphRunFailedEvent):
|
if isinstance(event, GraphRunFailedEvent):
|
||||||
# iteration run failed
|
# iteration run failed
|
||||||
yield IterationRunFailedEvent(
|
yield IterationRunFailedEvent(
|
||||||
iteration_id=self.node_id,
|
iteration_node_id=self.node_id,
|
||||||
reason=event.reason,
|
iteration_node_type=self.node_type,
|
||||||
|
iteration_node_data=self.node_data,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs={
|
||||||
|
"output": jsonable_encoder(outputs)
|
||||||
|
},
|
||||||
|
steps=len(iterator_list_value),
|
||||||
|
metadata={
|
||||||
|
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||||
|
},
|
||||||
|
error=event.error,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
run_result=NodeRunResult(
|
run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
error=event.reason,
|
error=event.error,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
@ -181,7 +222,17 @@ class IterationNode(BaseNode):
|
|||||||
yield event
|
yield event
|
||||||
|
|
||||||
yield IterationRunSucceededEvent(
|
yield IterationRunSucceededEvent(
|
||||||
iteration_id=self.node_id,
|
iteration_node_id=self.node_id,
|
||||||
|
iteration_node_type=self.node_type,
|
||||||
|
iteration_node_data=self.node_data,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs={
|
||||||
|
"output": jsonable_encoder(outputs)
|
||||||
|
},
|
||||||
|
steps=len(iterator_list_value),
|
||||||
|
metadata={
|
||||||
|
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
@ -196,8 +247,18 @@ class IterationNode(BaseNode):
|
|||||||
# iteration run failed
|
# iteration run failed
|
||||||
logger.exception("Iteration run failed")
|
logger.exception("Iteration run failed")
|
||||||
yield IterationRunFailedEvent(
|
yield IterationRunFailedEvent(
|
||||||
iteration_id=self.node_id,
|
iteration_node_id=self.node_id,
|
||||||
reason=str(e),
|
iteration_node_type=self.node_type,
|
||||||
|
iteration_node_data=self.node_data,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs={
|
||||||
|
"output": jsonable_encoder(outputs)
|
||||||
|
},
|
||||||
|
steps=len(iterator_list_value),
|
||||||
|
metadata={
|
||||||
|
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||||
|
},
|
||||||
|
error=str(e),
|
||||||
)
|
)
|
||||||
|
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
|||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||||
from core.workflow.nodes.llm.entities import (
|
from core.workflow.nodes.llm.entities import (
|
||||||
@ -42,11 +43,19 @@ from models.provider import Provider, ProviderType
|
|||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInvokeCompleted(BaseModel):
|
||||||
|
"""
|
||||||
|
Model invoke completed
|
||||||
|
"""
|
||||||
|
text: str
|
||||||
|
usage: LLMUsage
|
||||||
|
|
||||||
|
|
||||||
class LLMNode(BaseNode):
|
class LLMNode(BaseNode):
|
||||||
_node_data_cls = LLMNodeData
|
_node_data_cls = LLMNodeData
|
||||||
_node_type = NodeType.LLM
|
_node_type = NodeType.LLM
|
||||||
|
|
||||||
def _run(self) -> Generator[RunEvent, None, None]:
|
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
|
||||||
"""
|
"""
|
||||||
Run node
|
Run node
|
||||||
:return:
|
:return:
|
||||||
@ -167,7 +176,7 @@ class LLMNode(BaseNode):
|
|||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
stop: Optional[list[str]] = None) \
|
stop: Optional[list[str]] = None) \
|
||||||
-> Generator[RunEvent, None, None]:
|
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
:param node_data_model: node data model
|
:param node_data_model: node data model
|
||||||
@ -201,7 +210,7 @@ class LLMNode(BaseNode):
|
|||||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||||
|
|
||||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
|
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
|
||||||
-> Generator[RunEvent, None, None]:
|
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
|
||||||
"""
|
"""
|
||||||
Handle invoke result
|
Handle invoke result
|
||||||
:param invoke_result: invoke result
|
:param invoke_result: invoke result
|
||||||
@ -762,11 +771,3 @@ class LLMNode(BaseNode):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ModelInvokeCompleted(BaseModel):
|
|
||||||
"""
|
|
||||||
Model invoke completed
|
|
||||||
"""
|
|
||||||
text: str
|
|
||||||
usage: LLMUsage
|
|
||||||
|
|||||||
@ -26,24 +26,21 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class WorkflowEntry:
|
class WorkflowEntry:
|
||||||
def run(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
user_from: UserFrom,
|
user_from: UserFrom,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
user_inputs: Mapping[str, Any],
|
user_inputs: Mapping[str, Any],
|
||||||
system_inputs: Mapping[SystemVariable, Any],
|
system_inputs: Mapping[SystemVariable, Any],
|
||||||
callbacks: Sequence[WorkflowCallback],
|
|
||||||
call_depth: int = 0
|
call_depth: int = 0
|
||||||
) -> Generator[GraphEngineEvent, None, None]:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
:param workflow: Workflow instance
|
:param workflow: Workflow instance
|
||||||
:param user_id: user id
|
:param user_id: user id
|
||||||
:param user_from: user from
|
:param user_from: user from
|
||||||
:param invoke_from: invoke from service-api, web-app, debugger, explore
|
:param invoke_from: invoke from service-api, web-app, debugger, explore
|
||||||
:param callbacks: workflow callbacks
|
|
||||||
:param user_inputs: user variables inputs
|
:param user_inputs: user variables inputs
|
||||||
:param system_inputs: system inputs, like: query, files
|
:param system_inputs: system inputs, like: query, files
|
||||||
:param call_depth: call depth
|
:param call_depth: call depth
|
||||||
@ -82,7 +79,7 @@ class WorkflowEntry:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# init workflow run state
|
# init workflow run state
|
||||||
graph_engine = GraphEngine(
|
self.graph_engine = GraphEngine(
|
||||||
tenant_id=workflow.tenant_id,
|
tenant_id=workflow.tenant_id,
|
||||||
app_id=workflow.app_id,
|
app_id=workflow.app_id,
|
||||||
workflow_type=WorkflowType.value_of(workflow.type),
|
workflow_type=WorkflowType.value_of(workflow.type),
|
||||||
@ -98,6 +95,16 @@ class WorkflowEntry:
|
|||||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
callbacks: Sequence[WorkflowCallback],
|
||||||
|
) -> Generator[GraphEngineEvent, None, None]:
|
||||||
|
"""
|
||||||
|
:param callbacks: workflow callbacks
|
||||||
|
"""
|
||||||
|
graph_engine = self.graph_engine
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# run workflow
|
# run workflow
|
||||||
generator = graph_engine.run()
|
generator = graph_engine.run()
|
||||||
@ -105,7 +112,7 @@ class WorkflowEntry:
|
|||||||
if callbacks:
|
if callbacks:
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
callback.on_event(
|
callback.on_event(
|
||||||
graph=graph,
|
graph=self.graph_engine.graph,
|
||||||
graph_init_params=graph_engine.init_params,
|
graph_init_params=graph_engine.init_params,
|
||||||
graph_runtime_state=graph_engine.graph_runtime_state,
|
graph_runtime_state=graph_engine.graph_runtime_state,
|
||||||
event=event
|
event=event
|
||||||
@ -117,11 +124,11 @@ class WorkflowEntry:
|
|||||||
if callbacks:
|
if callbacks:
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
callback.on_event(
|
callback.on_event(
|
||||||
graph=graph,
|
graph=self.graph_engine.graph,
|
||||||
graph_init_params=graph_engine.init_params,
|
graph_init_params=graph_engine.init_params,
|
||||||
graph_runtime_state=graph_engine.graph_runtime_state,
|
graph_runtime_state=graph_engine.graph_runtime_state,
|
||||||
event=GraphRunFailedEvent(
|
event=GraphRunFailedEvent(
|
||||||
reason=str(e)
|
error=str(e)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -161,6 +168,9 @@ class WorkflowEntry:
|
|||||||
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
||||||
node_cls = node_classes.get(node_type)
|
node_cls = node_classes.get(node_type)
|
||||||
|
|
||||||
|
if not node_cls:
|
||||||
|
raise ValueError(f'Node class not found for node type {node_type}')
|
||||||
|
|
||||||
# init workflow run state
|
# init workflow run state
|
||||||
node_instance = node_cls(
|
node_instance = node_cls(
|
||||||
tenant_id=workflow.tenant_id,
|
tenant_id=workflow.tenant_id,
|
||||||
@ -195,7 +205,7 @@ class WorkflowEntry:
|
|||||||
node_instance=node_instance
|
node_instance=node_instance
|
||||||
)
|
)
|
||||||
|
|
||||||
# run node
|
# run node TODO
|
||||||
node_run_result = node_instance.run(
|
node_run_result = node_instance.run(
|
||||||
variable_pool=variable_pool
|
variable_pool=variable_pool
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
from core.workflow.entities.node_entities import NodeType, SystemVariable
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
@ -12,6 +12,7 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||||
|
from core.workflow.nodes.start.entities import StartNodeData
|
||||||
|
|
||||||
|
|
||||||
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
|
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
|
||||||
@ -37,7 +38,14 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
|
|||||||
parallel = graph.parallel_mapping.get(parallel_id)
|
parallel = graph.parallel_mapping.get(parallel_id)
|
||||||
parallel_start_node_id = parallel.start_from_node_id if parallel else None
|
parallel_start_node_id = parallel.start_from_node_id if parallel else None
|
||||||
|
|
||||||
|
node_config = graph.node_id_config_mapping[next_node_id]
|
||||||
|
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
|
||||||
|
mock_node_data = StartNodeData(**{"title": "demo", "variables": []})
|
||||||
|
|
||||||
yield NodeRunStartedEvent(
|
yield NodeRunStartedEvent(
|
||||||
|
node_id=next_node_id,
|
||||||
|
node_type=node_type,
|
||||||
|
node_data=mock_node_data,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=graph.node_parallel_mapping.get(next_node_id),
|
parallel_id=graph.node_parallel_mapping.get(next_node_id),
|
||||||
parallel_start_node_id=parallel_start_node_id
|
parallel_start_node_id=parallel_start_node_id
|
||||||
@ -47,6 +55,9 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
|
|||||||
length = int(next_node_id[-1])
|
length = int(next_node_id[-1])
|
||||||
for i in range(0, length):
|
for i in range(0, length):
|
||||||
yield NodeRunStreamChunkEvent(
|
yield NodeRunStreamChunkEvent(
|
||||||
|
node_id=next_node_id,
|
||||||
|
node_type=node_type,
|
||||||
|
node_data=mock_node_data,
|
||||||
chunk_content=str(i),
|
chunk_content=str(i),
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
from_variable_selector=[next_node_id, "text"],
|
from_variable_selector=[next_node_id, "text"],
|
||||||
@ -57,6 +68,9 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
|
|||||||
route_node_state.status = RouteNodeState.Status.SUCCESS
|
route_node_state.status = RouteNodeState.Status.SUCCESS
|
||||||
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
yield NodeRunSucceededEvent(
|
yield NodeRunSucceededEvent(
|
||||||
|
node_id=next_node_id,
|
||||||
|
node_type=node_type,
|
||||||
|
node_data=mock_node_data,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id
|
parallel_start_node_id=parallel_start_node_id
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user