mirror of https://github.com/langgenius/dify.git
refactor workflow generate pipeline
This commit is contained in:
parent
7d28fe8ea5
commit
a1bc6b50c5
|
|
@ -21,7 +21,7 @@ from controllers.console.app.error import (
|
|||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.app.app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from controllers.console.app.error import (
|
|||
)
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from controllers.service_api.app.error import (
|
|||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from controllers.web.error import (
|
|||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ from mimetypes import guess_extension
|
|||
from typing import Optional, Union, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
from core.app.app_queue_manager import AppQueueManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AgentChatAppGenerateEntity,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ from typing import Literal, Union
|
|||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentPromptEntity, AgentScratchpadUnit
|
||||
from core.app.app_queue_manager import PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
|
|
@ -121,7 +122,9 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
)
|
||||
|
||||
if iteration_step > 1:
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# update prompt messages
|
||||
prompt_messages = self._organize_cot_prompt_messages(
|
||||
|
|
@ -163,7 +166,9 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
|
||||
# publish agent thought if it's first iteration
|
||||
if iteration_step == 1:
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, dict):
|
||||
|
|
@ -225,7 +230,9 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
llm_usage=usage_dict['usage'])
|
||||
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
if not scratchpad.action:
|
||||
# failed to extract action, return final answer directly
|
||||
|
|
@ -255,7 +262,9 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
observation=answer,
|
||||
answer=answer,
|
||||
messages_ids=[])
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
else:
|
||||
# invoke tool
|
||||
error_response = None
|
||||
|
|
@ -282,7 +291,9 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
self.variables_pool.set_file(tool_name=tool_call_name,
|
||||
value=message_file.id,
|
||||
name=save_as)
|
||||
self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
message_file_ids = [message_file.id for message_file, _ in message_files]
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
|
|
@ -318,7 +329,9 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
answer=scratchpad.agent_response,
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# update prompt tool message
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
|
|
@ -352,7 +365,7 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish_message_end(LLMResult(
|
||||
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
|
|
@ -360,7 +373,7 @@ class CotAgentRunner(BaseAgentRunner):
|
|||
),
|
||||
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
|
||||
system_fingerprint=''
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
)), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
|
||||
-> Generator[Union[str, dict], None, None]:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ from collections.abc import Generator
|
|||
from typing import Any, Union
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.app.app_queue_manager import PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
|
|
@ -135,7 +136,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
is_first_chunk = False
|
||||
# check if there is any tool call
|
||||
if self.check_tool_calls(chunk):
|
||||
|
|
@ -195,7 +198,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
if not result.message.content:
|
||||
result.message.content = ''
|
||||
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
|
|
@ -233,8 +238,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
messages_ids=[],
|
||||
llm_usage=current_llm_usage
|
||||
)
|
||||
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
final_answer += response + '\n'
|
||||
|
||||
|
|
@ -275,7 +281,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file.id)
|
||||
|
||||
|
|
@ -331,7 +339,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
answer=None,
|
||||
messages_ids=message_file_ids
|
||||
)
|
||||
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# update prompt tool
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
|
|
@ -341,15 +351,15 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish_message_end(LLMResult(
|
||||
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer,
|
||||
content=final_answer
|
||||
),
|
||||
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
|
||||
system_fingerprint=''
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
)), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,335 +0,0 @@
|
|||
import queue
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentMessageEvent,
|
||||
QueueAgentThoughtEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueLLMChunkEvent,
|
||||
QueueMessage,
|
||||
QueueMessageEndEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFinishedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFinishedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import MessageAgentThought, MessageFile
|
||||
|
||||
|
||||
class PublishFrom(Enum):
|
||||
APPLICATION_MANAGER = 1
|
||||
TASK_PIPELINE = 2
|
||||
|
||||
|
||||
class AppQueueManager:
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
conversation_id: str,
|
||||
app_mode: str,
|
||||
message_id: str) -> None:
|
||||
if not user_id:
|
||||
raise ValueError("user is required")
|
||||
|
||||
self._task_id = task_id
|
||||
self._user_id = user_id
|
||||
self._invoke_from = invoke_from
|
||||
self._conversation_id = str(conversation_id)
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
|
||||
|
||||
q = queue.Queue()
|
||||
|
||||
self._q = q
|
||||
|
||||
def listen(self) -> Generator:
|
||||
"""
|
||||
Listen to queue
|
||||
:return:
|
||||
"""
|
||||
# wait for 10 minutes to stop listen
|
||||
listen_timeout = 600
|
||||
start_time = time.time()
|
||||
last_ping_time = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = self._q.get(timeout=1)
|
||||
if message is None:
|
||||
break
|
||||
|
||||
yield message
|
||||
except queue.Empty:
|
||||
continue
|
||||
finally:
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time >= listen_timeout or self._is_stopped():
|
||||
# publish two messages to make sure the client can receive the stop signal
|
||||
# and stop listening after the stop signal processed
|
||||
self.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
self.stop_listen()
|
||||
|
||||
if elapsed_time // 10 > last_ping_time:
|
||||
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
|
||||
last_ping_time = elapsed_time // 10
|
||||
|
||||
def stop_listen(self) -> None:
|
||||
"""
|
||||
Stop listen to queue
|
||||
:return:
|
||||
"""
|
||||
self._q.put(None)
|
||||
|
||||
def publish_llm_chunk(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish llm chunk to channel
|
||||
|
||||
:param chunk: llm chunk
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueLLMChunkEvent(
|
||||
chunk=chunk
|
||||
), pub_from)
|
||||
|
||||
def publish_text_chunk(self, text: str, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish text chunk to channel
|
||||
|
||||
:param text: text
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueTextChunkEvent(
|
||||
text=text
|
||||
), pub_from)
|
||||
|
||||
def publish_agent_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish agent chunk message to channel
|
||||
|
||||
:param chunk: chunk
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueAgentMessageEvent(
|
||||
chunk=chunk
|
||||
), pub_from)
|
||||
|
||||
def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish message replace
|
||||
:param text: text
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageReplaceEvent(
|
||||
text=text
|
||||
), pub_from)
|
||||
|
||||
def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish retriever resources
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
|
||||
|
||||
def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish annotation reply
|
||||
:param message_annotation_id: message annotation id
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueAnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
|
||||
|
||||
def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish message end
|
||||
:param llm_result: llm result
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
|
||||
self.stop_listen()
|
||||
|
||||
def publish_workflow_started(self, workflow_run_id: str, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish workflow started
|
||||
:param workflow_run_id: workflow run id
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueWorkflowStartedEvent(workflow_run_id=workflow_run_id), pub_from)
|
||||
|
||||
def publish_workflow_finished(self, workflow_run_id: str, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish workflow finished
|
||||
:param workflow_run_id: workflow run id
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueWorkflowFinishedEvent(workflow_run_id=workflow_run_id), pub_from)
|
||||
|
||||
def publish_node_started(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish node started
|
||||
:param workflow_node_execution_id: workflow node execution id
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from)
|
||||
|
||||
def publish_node_finished(self, workflow_node_execution_id: str, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish node finished
|
||||
:param workflow_node_execution_id: workflow node execution id
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution_id), pub_from)
|
||||
|
||||
def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish agent thought
|
||||
:param message_agent_thought: message agent thought
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=message_agent_thought.id
|
||||
), pub_from)
|
||||
|
||||
def publish_message_file(self, message_file: MessageFile, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish agent thought
|
||||
:param message_file: message file
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file.id
|
||||
), pub_from)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish error
|
||||
:param e: error
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueErrorEvent(
|
||||
error=e
|
||||
), pub_from)
|
||||
self.stop_listen()
|
||||
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
self._check_for_sqlalchemy_models(event.dict())
|
||||
|
||||
message = QueueMessage(
|
||||
task_id=self._task_id,
|
||||
message_id=self._message_id,
|
||||
conversation_id=self._conversation_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
@classmethod
|
||||
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
|
||||
"""
|
||||
Set task stop flag
|
||||
:return:
|
||||
"""
|
||||
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
if result is None:
|
||||
return
|
||||
|
||||
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
if result.decode('utf-8') != f"{user_prefix}-{user_id}":
|
||||
return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
:return:
|
||||
"""
|
||||
stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id)
|
||||
result = redis_client.get(stopped_cache_key)
|
||||
if result is not None:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _generate_task_belong_cache_key(cls, task_id: str) -> str:
|
||||
"""
|
||||
Generate task belong cache key
|
||||
:param task_id: task id
|
||||
:return:
|
||||
"""
|
||||
return f"generate_task_belong:{task_id}"
|
||||
|
||||
@classmethod
|
||||
def _generate_stopped_cache_key(cls, task_id: str) -> str:
|
||||
"""
|
||||
Generate stopped cache key
|
||||
:param task_id: task id
|
||||
:return:
|
||||
"""
|
||||
return f"generate_task_stopped:{task_id}"
|
||||
|
||||
def _check_for_sqlalchemy_models(self, data: Any):
|
||||
# from entity to dict or list
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
self._check_for_sqlalchemy_models(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
self._check_for_sqlalchemy_models(item)
|
||||
else:
|
||||
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
|
||||
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
|
||||
"that cause thread safety issues is not allowed.")
|
||||
|
||||
|
||||
class ConversationTaskStoppedException(Exception):
|
||||
pass
|
||||
|
|
@ -8,11 +8,12 @@ from flask import Flask, current_app
|
|||
from pydantic import ValidationError
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
|
|
@ -101,7 +102,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = AppQueueManager(
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
|
|
|
|||
|
|
@ -2,14 +2,14 @@ import logging
|
|||
import time
|
||||
from typing import cast
|
||||
|
||||
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueStopEvent
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.moderation.base import ModerationException
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
|
|
@ -93,7 +93,7 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION: conversation.id,
|
||||
},
|
||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)],
|
||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
|
||||
)
|
||||
|
||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||
|
|
@ -153,9 +153,9 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
)
|
||||
|
||||
if annotation_reply:
|
||||
queue_manager.publish_annotation_reply(
|
||||
message_annotation_id=annotation_reply.id,
|
||||
pub_from=PublishFrom.APPLICATION_MANAGER
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
self._stream_output(
|
||||
|
|
@ -182,7 +182,11 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
if stream:
|
||||
index = 0
|
||||
for token in text:
|
||||
queue_manager.publish_text_chunk(token, PublishFrom.APPLICATION_MANAGER)
|
||||
queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=token
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
|
||||
|
|
@ -190,4 +194,3 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
QueueStopEvent(stopped_by=stopped_by),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
queue_manager.stop_listen()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Optional, Union
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
|
|
@ -46,6 +46,7 @@ class TaskState(BaseModel):
|
|||
"""
|
||||
answer: str = ""
|
||||
metadata: dict = {}
|
||||
usage: LLMUsage
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
|
@ -349,7 +350,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
if self._output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.answer = self._output_moderation_handler.get_final_output()
|
||||
self._queue_manager.publish_text_chunk(self._task_state.answer, PublishFrom.TASK_PIPELINE)
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=self._task_state.answer
|
||||
), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
|
|
@ -558,5 +564,5 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||
type=sensitive_word_avoidance.type,
|
||||
config=sensitive_word_avoidance.config
|
||||
),
|
||||
on_message_replace_func=self._queue_manager.publish_message_replace
|
||||
queue_manager=self._queue_manager
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,10 +9,11 @@ from pydantic import ValidationError
|
|||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
|
|
@ -119,7 +120,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = AppQueueManager(
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
|
|
|
|||
|
|
@ -4,10 +4,11 @@ from typing import cast
|
|||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
|
@ -120,10 +121,11 @@ class AgentChatAppRunner(AppRunner):
|
|||
)
|
||||
|
||||
if annotation_reply:
|
||||
queue_manager.publish_annotation_reply(
|
||||
message_annotation_id=annotation_reply.id,
|
||||
pub_from=PublishFrom.APPLICATION_MANAGER
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,181 @@
|
|||
import queue
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueErrorEvent,
|
||||
QueueMessage,
|
||||
QueueMessageEndEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class PublishFrom(Enum):
|
||||
APPLICATION_MANAGER = 1
|
||||
TASK_PIPELINE = 2
|
||||
|
||||
|
||||
class AppQueueManager:
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom) -> None:
|
||||
if not user_id:
|
||||
raise ValueError("user is required")
|
||||
|
||||
self._task_id = task_id
|
||||
self._user_id = user_id
|
||||
self._invoke_from = invoke_from
|
||||
|
||||
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}")
|
||||
|
||||
q = queue.Queue()
|
||||
|
||||
self._q = q
|
||||
|
||||
def listen(self) -> Generator:
|
||||
"""
|
||||
Listen to queue
|
||||
:return:
|
||||
"""
|
||||
# wait for 10 minutes to stop listen
|
||||
listen_timeout = 600
|
||||
start_time = time.time()
|
||||
last_ping_time = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = self._q.get(timeout=1)
|
||||
if message is None:
|
||||
break
|
||||
|
||||
yield message
|
||||
except queue.Empty:
|
||||
continue
|
||||
finally:
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time >= listen_timeout or self._is_stopped():
|
||||
# publish two messages to make sure the client can receive the stop signal
|
||||
# and stop listening after the stop signal processed
|
||||
self.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
if elapsed_time // 10 > last_ping_time:
|
||||
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
|
||||
last_ping_time = elapsed_time // 10
|
||||
|
||||
def stop_listen(self) -> None:
|
||||
"""
|
||||
Stop listen to queue
|
||||
:return:
|
||||
"""
|
||||
self._q.put(None)
|
||||
|
||||
def publish_error(self, e, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish error
|
||||
:param e: error
|
||||
:param pub_from: publish from
|
||||
:return:
|
||||
"""
|
||||
self.publish(QueueErrorEvent(
|
||||
error=e
|
||||
), pub_from)
|
||||
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
self._check_for_sqlalchemy_models(event.dict())
|
||||
|
||||
message = self.construct_queue_message(event)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
@abstractmethod
|
||||
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
|
||||
"""
|
||||
Set task stop flag
|
||||
:return:
|
||||
"""
|
||||
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
if result is None:
|
||||
return
|
||||
|
||||
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
|
||||
if result.decode('utf-8') != f"{user_prefix}-{user_id}":
|
||||
return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
:return:
|
||||
"""
|
||||
stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id)
|
||||
result = redis_client.get(stopped_cache_key)
|
||||
if result is not None:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _generate_task_belong_cache_key(cls, task_id: str) -> str:
|
||||
"""
|
||||
Generate task belong cache key
|
||||
:param task_id: task id
|
||||
:return:
|
||||
"""
|
||||
return f"generate_task_belong:{task_id}"
|
||||
|
||||
@classmethod
|
||||
def _generate_stopped_cache_key(cls, task_id: str) -> str:
|
||||
"""
|
||||
Generate stopped cache key
|
||||
:param task_id: task id
|
||||
:return:
|
||||
"""
|
||||
return f"generate_task_stopped:{task_id}"
|
||||
|
||||
def _check_for_sqlalchemy_models(self, data: Any):
|
||||
# from entity to dict or list
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
self._check_for_sqlalchemy_models(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
self._check_for_sqlalchemy_models(item)
|
||||
else:
|
||||
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
|
||||
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
|
||||
"that cause thread safety issues is not allowed.")
|
||||
|
||||
|
||||
class ConversationTaskStoppedException(Exception):
|
||||
pass
|
||||
|
|
@ -3,13 +3,14 @@ from collections.abc import Generator
|
|||
from typing import Optional, Union, cast
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AppGenerateEntity,
|
||||
EasyUIBasedAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
ModelConfigWithCredentialsEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
|
|
@ -187,25 +188,32 @@ class AppRunner:
|
|||
if stream:
|
||||
index = 0
|
||||
for token in text:
|
||||
queue_manager.publish_llm_chunk(LLMResultChunk(
|
||||
chunk = LLMResultChunk(
|
||||
model=app_generate_entity.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content=token)
|
||||
)
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
)
|
||||
|
||||
queue_manager.publish(
|
||||
QueueLLMChunkEvent(
|
||||
chunk=chunk
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
|
||||
queue_manager.publish_message_end(
|
||||
llm_result=LLMResult(
|
||||
model=app_generate_entity.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=usage if usage else LLMUsage.empty_usage()
|
||||
),
|
||||
pub_from=PublishFrom.APPLICATION_MANAGER
|
||||
queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=app_generate_entity.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=usage if usage else LLMUsage.empty_usage()
|
||||
),
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
|
||||
|
|
@ -241,9 +249,10 @@ class AppRunner:
|
|||
:param queue_manager: application queue manager
|
||||
:return:
|
||||
"""
|
||||
queue_manager.publish_message_end(
|
||||
llm_result=invoke_result,
|
||||
pub_from=PublishFrom.APPLICATION_MANAGER
|
||||
queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=invoke_result,
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def _handle_invoke_result_stream(self, invoke_result: Generator,
|
||||
|
|
@ -261,9 +270,17 @@ class AppRunner:
|
|||
usage = None
|
||||
for result in invoke_result:
|
||||
if not agent:
|
||||
queue_manager.publish_llm_chunk(result, PublishFrom.APPLICATION_MANAGER)
|
||||
queue_manager.publish(
|
||||
QueueLLMChunkEvent(
|
||||
chunk=result
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
else:
|
||||
queue_manager.publish_agent_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
|
||||
queue_manager.publish(
|
||||
QueueAgentMessageEvent(
|
||||
chunk=result
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
text += result.delta.message.content
|
||||
|
||||
|
|
@ -286,9 +303,10 @@ class AppRunner:
|
|||
usage=usage
|
||||
)
|
||||
|
||||
queue_manager.publish_message_end(
|
||||
llm_result=llm_result,
|
||||
pub_from=PublishFrom.APPLICATION_MANAGER
|
||||
queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=llm_result,
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def moderation_for_inputs(self, app_id: str,
|
||||
|
|
@ -311,7 +329,7 @@ class AppRunner:
|
|||
tenant_id=tenant_id,
|
||||
app_config=app_generate_entity.app_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
query=query if query else ''
|
||||
)
|
||||
|
||||
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||
|
|
|
|||
|
|
@ -9,10 +9,11 @@ from pydantic import ValidationError
|
|||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.chat.app_runner import ChatAppRunner
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
|
|
@ -119,7 +120,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = AppQueueManager(
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfig
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
ChatAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
|
|
@ -117,10 +118,11 @@ class ChatAppRunner(AppRunner):
|
|||
)
|
||||
|
||||
if annotation_reply:
|
||||
queue_manager.publish_annotation_reply(
|
||||
message_annotation_id=annotation_reply.id,
|
||||
pub_from=PublishFrom.APPLICATION_MANAGER
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
|
|
|
|||
|
|
@ -9,10 +9,11 @@ from pydantic import ValidationError
|
|||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from core.app.apps.completion.app_runner import CompletionAppRunner
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
|
|
@ -112,7 +113,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
) = self._init_generate_records(application_generate_entity)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = AppQueueManager(
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
|
|
@ -263,7 +264,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
) = self._init_generate_records(application_generate_entity)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = AppQueueManager(
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
from typing import cast
|
||||
|
||||
from core.app.app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfig
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Optional, Union, cast
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AgentChatAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
|
|
@ -385,14 +385,19 @@ class EasyUIBasedGenerateTaskPipeline:
|
|||
if self._output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
|
||||
self._queue_manager.publish_llm_chunk(LLMResultChunk(
|
||||
model=self._task_state.llm_result.model,
|
||||
prompt_messages=self._task_state.llm_result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
|
||||
)
|
||||
), PublishFrom.TASK_PIPELINE)
|
||||
self._queue_manager.publish(
|
||||
QueueLLMChunkEvent(
|
||||
chunk=LLMResultChunk(
|
||||
model=self._task_state.llm_result.model,
|
||||
prompt_messages=self._task_state.llm_result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
|
||||
)
|
||||
)
|
||||
), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
|
|
@ -664,5 +669,5 @@ class EasyUIBasedGenerateTaskPipeline:
|
|||
type=sensitive_word_avoidance.type,
|
||||
config=sensitive_word_avoidance.config
|
||||
),
|
||||
on_message_replace_func=self._queue_manager.publish_message_replace
|
||||
queue_manager=self._queue_manager
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ from typing import Optional, Union
|
|||
from sqlalchemy import and_
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.app_queue_manager import AppQueueManager, ConversationTaskStoppedException
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException
|
||||
from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueMessage,
|
||||
)
|
||||
|
||||
|
||||
class MessageBasedAppQueueManager(AppQueueManager):
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
conversation_id: str,
|
||||
app_mode: str,
|
||||
message_id: str) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._conversation_id = str(conversation_id)
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
|
||||
return QueueMessage(
|
||||
task_id=self._task_id,
|
||||
message_id=self._message_id,
|
||||
conversation_id=self._conversation_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
def generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
inputs = args['inputs']
|
||||
|
||||
# parse files
|
||||
files = args['files'] if 'files' in args and args['files'] else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_upload_entity = FileUploadConfigManager.convert(workflow.features_dict)
|
||||
if file_upload_entity:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_upload_entity,
|
||||
user
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
inputs=self._get_cleaned_inputs(inputs, app_config),
|
||||
files=file_objs,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
app_mode=app_model.mode
|
||||
)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
return self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:return:
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# workflow app
|
||||
runner = WorkflowAppRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError('Incorrect API key provided'),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool = False) -> Union[dict, Generator]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param stream: is stream
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
try:
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise ConversationTaskStoppedException()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueMessage,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAppQueueManager(AppQueueManager):
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
app_mode: str) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
|
||||
return QueueMessage(
|
||||
task_id=self._task_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
import logging
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AppGenerateEntity,
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
|
||||
from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import WorkflowRunTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppRunner:
|
||||
"""
|
||||
Workflow Application Runner
|
||||
"""
|
||||
|
||||
def run(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:return:
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
files = application_generate_entity.files
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
queue_manager=queue_manager,
|
||||
app_record=app_record,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs
|
||||
):
|
||||
return
|
||||
|
||||
# fetch user
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]:
|
||||
user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first()
|
||||
else:
|
||||
user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow=workflow,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
|
||||
user=user,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.FILES: files
|
||||
},
|
||||
callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
|
||||
)
|
||||
|
||||
def handle_input_moderation(self, queue_manager: AppQueueManager,
|
||||
app_record: App,
|
||||
app_generate_entity: WorkflowAppGenerateEntity,
|
||||
inputs: dict) -> bool:
|
||||
"""
|
||||
Handle input moderation
|
||||
:param queue_manager: application queue manager
|
||||
:param app_record: app record
|
||||
:param app_generate_entity: application generate entity
|
||||
:param inputs: inputs
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
moderation_feature = InputModeration()
|
||||
_, inputs, query = moderation_feature.check(
|
||||
app_id=app_record.id,
|
||||
tenant_id=app_generate_entity.app_config.tenant_id,
|
||||
app_config=app_generate_entity.app_config,
|
||||
inputs=inputs,
|
||||
query=''
|
||||
)
|
||||
except ModerationException as e:
|
||||
if app_generate_entity.stream:
|
||||
self._stream_output(
|
||||
queue_manager=queue_manager,
|
||||
text=str(e),
|
||||
)
|
||||
|
||||
queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _stream_output(self, queue_manager: AppQueueManager,
|
||||
text: str) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param queue_manager: application queue manager
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
index = 0
|
||||
for token in text:
|
||||
queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=token
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
|
||||
def moderation_for_inputs(self, app_id: str,
|
||||
tenant_id: str,
|
||||
app_generate_entity: AppGenerateEntity,
|
||||
inputs: dict) -> tuple[bool, dict, str]:
|
||||
"""
|
||||
Process sensitive_word_avoidance.
|
||||
:param app_id: app id
|
||||
:param tenant_id: tenant id
|
||||
:param app_generate_entity: app generate entity
|
||||
:param inputs: inputs
|
||||
:return:
|
||||
"""
|
||||
moderation_feature = InputModeration()
|
||||
return moderation_feature.check(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
app_config=app_generate_entity.app_config,
|
||||
inputs=inputs,
|
||||
query=''
|
||||
)
|
||||
|
|
@ -0,0 +1,408 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFinishedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFinishedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
"""
|
||||
TaskState entity
|
||||
"""
|
||||
answer: str = ""
|
||||
metadata: dict = {}
|
||||
workflow_run_id: Optional[str] = None
|
||||
|
||||
|
||||
class WorkflowAppGenerateTaskPipeline:
|
||||
"""
|
||||
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
stream: bool) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
"""
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._queue_manager = queue_manager
|
||||
self._task_state = TaskState()
|
||||
self._start_at = time.perf_counter()
|
||||
self._output_moderation_handler = self._init_output_moderation()
|
||||
self._stream = stream
|
||||
|
||||
def process(self) -> Union[dict, Generator]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
"""
|
||||
if self._stream:
|
||||
return self._process_stream_response()
|
||||
else:
|
||||
return self._process_blocking_response()
|
||||
|
||||
def _process_blocking_response(self) -> dict:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
"""
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
raise self._handle_error(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
|
||||
else:
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
|
||||
outputs = workflow_run.outputs
|
||||
self._task_state.answer = outputs.get('text', '')
|
||||
else:
|
||||
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
|
||||
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
self._output_moderation_handler.stop_thread()
|
||||
|
||||
self._task_state.answer = self._output_moderation_handler.moderation_completion(
|
||||
completion=self._task_state.answer,
|
||||
public_event=False
|
||||
)
|
||||
|
||||
response = {
|
||||
'event': 'workflow_finished',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': event.workflow_run_id,
|
||||
'data': {
|
||||
'id': workflow_run.id,
|
||||
'workflow_id': workflow_run.workflow_id,
|
||||
'status': workflow_run.status,
|
||||
'outputs': workflow_run.outputs_dict,
|
||||
'error': workflow_run.error,
|
||||
'elapsed_time': workflow_run.elapsed_time,
|
||||
'total_tokens': workflow_run.total_tokens,
|
||||
'total_steps': workflow_run.total_steps,
|
||||
'created_at': int(workflow_run.created_at.timestamp()),
|
||||
'finished_at': int(workflow_run.finished_at.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
else:
|
||||
continue
|
||||
|
||||
def _process_stream_response(self) -> Generator:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
data = self._error_to_stream_response_data(self._handle_error(event))
|
||||
yield self._yield_response(data)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
self._task_state.workflow_run_id = event.workflow_run_id
|
||||
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
response = {
|
||||
'event': 'workflow_started',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': event.workflow_run_id,
|
||||
'data': {
|
||||
'id': workflow_run.id,
|
||||
'workflow_id': workflow_run.workflow_id,
|
||||
'created_at': int(workflow_run.created_at.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
|
||||
response = {
|
||||
'event': 'node_started',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': workflow_node_execution.workflow_run_id,
|
||||
'data': {
|
||||
'id': workflow_node_execution.id,
|
||||
'node_id': workflow_node_execution.node_id,
|
||||
'index': workflow_node_execution.index,
|
||||
'predecessor_node_id': workflow_node_execution.predecessor_node_id,
|
||||
'inputs': workflow_node_execution.inputs_dict,
|
||||
'created_at': int(workflow_node_execution.created_at.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueNodeFinishedEvent):
|
||||
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
|
||||
response = {
|
||||
'event': 'node_finished',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': workflow_node_execution.workflow_run_id,
|
||||
'data': {
|
||||
'id': workflow_node_execution.id,
|
||||
'node_id': workflow_node_execution.node_id,
|
||||
'index': workflow_node_execution.index,
|
||||
'predecessor_node_id': workflow_node_execution.predecessor_node_id,
|
||||
'inputs': workflow_node_execution.inputs_dict,
|
||||
'process_data': workflow_node_execution.process_data_dict,
|
||||
'outputs': workflow_node_execution.outputs_dict,
|
||||
'status': workflow_node_execution.status,
|
||||
'error': workflow_node_execution.error,
|
||||
'elapsed_time': workflow_node_execution.elapsed_time,
|
||||
'execution_metadata': workflow_node_execution.execution_metadata_dict,
|
||||
'created_at': int(workflow_node_execution.created_at.timestamp()),
|
||||
'finished_at': int(workflow_node_execution.finished_at.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
|
||||
else:
|
||||
workflow_run = self._get_workflow_run(event.workflow_run_id)
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
|
||||
outputs = workflow_run.outputs
|
||||
self._task_state.answer = outputs.get('text', '')
|
||||
else:
|
||||
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
|
||||
data = self._error_to_stream_response_data(self._handle_error(err_event))
|
||||
yield self._yield_response(data)
|
||||
break
|
||||
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
self._output_moderation_handler.stop_thread()
|
||||
|
||||
self._task_state.answer = self._output_moderation_handler.moderation_completion(
|
||||
completion=self._task_state.answer,
|
||||
public_event=False
|
||||
)
|
||||
|
||||
self._output_moderation_handler = None
|
||||
|
||||
replace_response = {
|
||||
'event': 'text_replace',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
'data': {
|
||||
'text': self._task_state.answer
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(replace_response)
|
||||
|
||||
workflow_run_response = {
|
||||
'event': 'workflow_finished',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': event.workflow_run_id,
|
||||
'data': {
|
||||
'id': workflow_run.id,
|
||||
'workflow_id': workflow_run.workflow_id,
|
||||
'status': workflow_run.status,
|
||||
'outputs': workflow_run.outputs_dict,
|
||||
'error': workflow_run.error,
|
||||
'elapsed_time': workflow_run.elapsed_time,
|
||||
'total_tokens': workflow_run.total_tokens,
|
||||
'total_steps': workflow_run.total_steps,
|
||||
'created_at': int(workflow_run.created_at.timestamp()),
|
||||
'finished_at': int(workflow_run.finished_at.timestamp())
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(workflow_run_response)
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.chunk_text
|
||||
if delta_text is None:
|
||||
continue
|
||||
|
||||
if self._output_moderation_handler:
|
||||
if self._output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.answer = self._output_moderation_handler.get_final_output()
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=self._task_state.answer
|
||||
), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
continue
|
||||
else:
|
||||
self._output_moderation_handler.append_new_token(delta_text)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
response = self._handle_chunk(delta_text)
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
response = {
|
||||
'event': 'text_replace',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
'data': {
|
||||
'text': event.text
|
||||
}
|
||||
}
|
||||
|
||||
yield self._yield_response(response)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield "event: ping\n\n"
|
||||
else:
|
||||
continue
|
||||
|
||||
def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Get workflow run.
|
||||
:param workflow_run_id: workflow run id
|
||||
:return:
|
||||
"""
|
||||
return db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
|
||||
|
||||
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Get workflow node execution.
|
||||
:param workflow_node_execution_id: workflow node execution id
|
||||
:return:
|
||||
"""
|
||||
return db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).first()
|
||||
|
||||
def _handle_chunk(self, text: str) -> dict:
|
||||
"""
|
||||
Handle completed event.
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
response = {
|
||||
'event': 'text_chunk',
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'data': {
|
||||
'text': text
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
def _handle_error(self, event: QueueErrorEvent) -> Exception:
|
||||
"""
|
||||
Handle error event.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
logger.debug("error: %s", event.error)
|
||||
e = event.error
|
||||
|
||||
if isinstance(e, InvokeAuthorizationError):
|
||||
return InvokeAuthorizationError('Incorrect API key provided')
|
||||
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
|
||||
return e
|
||||
else:
|
||||
return Exception(e.description if getattr(e, 'description', None) is not None else str(e))
|
||||
|
||||
def _error_to_stream_response_data(self, e: Exception) -> dict:
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
error_responses = {
|
||||
ValueError: {'code': 'invalid_param', 'status': 400},
|
||||
ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
|
||||
QuotaExceededError: {
|
||||
'code': 'provider_quota_exceeded',
|
||||
'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials.",
|
||||
'status': 400
|
||||
},
|
||||
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
|
||||
InvokeError: {'code': 'completion_request_error', 'status': 400}
|
||||
}
|
||||
|
||||
# Determine the response based on the type of exception
|
||||
data = None
|
||||
for k, v in error_responses.items():
|
||||
if isinstance(e, k):
|
||||
data = v
|
||||
|
||||
if data:
|
||||
data.setdefault('message', getattr(e, 'description', str(e)))
|
||||
else:
|
||||
logging.error(e)
|
||||
data = {
|
||||
'code': 'internal_server_error',
|
||||
'message': 'Internal Server Error, please contact support.',
|
||||
'status': 500
|
||||
}
|
||||
|
||||
return {
|
||||
'event': 'error',
|
||||
'task_id': self._application_generate_entity.task_id,
|
||||
'workflow_run_id': self._task_state.workflow_run_id,
|
||||
**data
|
||||
}
|
||||
|
||||
def _yield_response(self, response: dict) -> str:
|
||||
"""
|
||||
Yield response.
|
||||
:param response: response
|
||||
:return:
|
||||
"""
|
||||
return "data: " + json.dumps(response) + "\n\n"
|
||||
|
||||
def _init_output_moderation(self) -> Optional[OutputModeration]:
|
||||
"""
|
||||
Init output moderation.
|
||||
:return:
|
||||
"""
|
||||
app_config = self._application_generate_entity.app_config
|
||||
sensitive_word_avoidance = app_config.sensitive_word_avoidance
|
||||
|
||||
if sensitive_word_avoidance:
|
||||
return OutputModeration(
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_id=app_config.app_id,
|
||||
rule=ModerationRule(
|
||||
type=sensitive_word_avoidance.type,
|
||||
config=sensitive_word_avoidance.config
|
||||
),
|
||||
queue_manager=self._queue_manager
|
||||
)
|
||||
|
|
@ -127,9 +127,9 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
|||
query: Optional[str] = None
|
||||
|
||||
|
||||
class WorkflowUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
Workflow UI Based Application Generate Entity.
|
||||
Workflow Application Generate Entity.
|
||||
"""
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetQuery, DocumentSegment
|
||||
|
|
@ -82,4 +83,7 @@ class DatasetIndexToolCallbackHandler:
|
|||
db.session.add(dataset_retriever_resource)
|
||||
db.session.commit()
|
||||
|
||||
self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER)
|
||||
self._queue_manager.publish(
|
||||
QueueRetrieverResourcesEvent(retriever_resources=resource),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,11 @@
|
|||
from core.app.app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeFinishedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFinishedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowRun
|
||||
|
||||
|
|
@ -12,43 +19,45 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish_workflow_started(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pub_from=PublishFrom.TASK_PIPELINE
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
|
||||
"""
|
||||
Workflow run finished
|
||||
"""
|
||||
self._queue_manager.publish_workflow_finished(
|
||||
workflow_run_id=workflow_run.id,
|
||||
pub_from=PublishFrom.TASK_PIPELINE
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish_node_started(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
pub_from=PublishFrom.TASK_PIPELINE
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Workflow node execute finished
|
||||
"""
|
||||
self._queue_manager.publish_node_finished(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
pub_from=PublishFrom.TASK_PIPELINE
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
|
||||
def on_text_chunk(self, text: str) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
self._queue_manager.publish_text_chunk(
|
||||
text=text,
|
||||
pub_from=PublishFrom.TASK_PIPELINE
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ from typing import Any, Optional
|
|||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.app_queue_manager import PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageReplaceEvent
|
||||
from core.moderation.base import ModerationAction, ModerationOutputsResult
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
|
|
@ -25,7 +26,7 @@ class OutputModeration(BaseModel):
|
|||
app_id: str
|
||||
|
||||
rule: ModerationRule
|
||||
on_message_replace_func: Any
|
||||
queue_manager: AppQueueManager
|
||||
|
||||
thread: Optional[threading.Thread] = None
|
||||
thread_running: bool = True
|
||||
|
|
@ -67,7 +68,12 @@ class OutputModeration(BaseModel):
|
|||
final_output = result.text
|
||||
|
||||
if public_event:
|
||||
self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
|
||||
self.queue_manager.publish(
|
||||
QueueMessageReplaceEvent(
|
||||
text=final_output
|
||||
),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
return final_output
|
||||
|
||||
|
|
@ -117,7 +123,12 @@ class OutputModeration(BaseModel):
|
|||
|
||||
# trigger replace event
|
||||
if self.thread_running:
|
||||
self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
|
||||
self.queue_manager.publish(
|
||||
QueueMessageReplaceEvent(
|
||||
text=final_output
|
||||
),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
break
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Optional, Union
|
|||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
|
|
@ -175,8 +176,24 @@ class WorkflowService:
|
|||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom) -> Union[dict, Generator]:
|
||||
# TODO
|
||||
pass
|
||||
# fetch draft workflow by app_model
|
||||
draft_workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not draft_workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
# run draft workflow
|
||||
app_generator = WorkflowAppGenerator()
|
||||
response = app_generator.generate(
|
||||
app_model=app_model,
|
||||
workflow=draft_workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=True
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def convert_to_workflow(self, app_model: App, account: Account) -> App:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue