mirror of https://github.com/langgenius/dify.git

Commit: 75f1355d4c
Parent: 1a86e79d4a

add few workflow run codes
@@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair
 from models.account import Tenant
 from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
 from models.dataset import Document as DatasetDocument
-from models.model import Account, App, AppMode, AppModelConfig, AppAnnotationSetting, Conversation, MessageAnnotation
+from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
 from models.provider import Provider, ProviderModel

@@ -112,6 +112,7 @@ class VariableEntity(BaseModel):
     max_length: Optional[int] = None
     options: Optional[list[str]] = None
+    default: Optional[str] = None
     hint: Optional[str] = None


 class ExternalDataVariableEntity(BaseModel):

@@ -10,12 +10,14 @@ from core.app.entities.app_invoke_entities import (
     InvokeFrom,
 )
 from core.app.entities.queue_entities import QueueStopEvent
+from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
 from core.moderation.base import ModerationException
 from core.workflow.entities.node_entities import SystemVariable
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
+from models.workflow import WorkflowRunTriggeredFrom

 logger = logging.getLogger(__name__)

@@ -83,13 +85,16 @@ class AdvancedChatAppRunner(AppRunner):
         result_generator = workflow_engine_manager.run_workflow(
             app_model=app_record,
             workflow=workflow,
+            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
+            if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
             user=user,
             user_inputs=inputs,
             system_inputs={
                 SystemVariable.QUERY: query,
                 SystemVariable.FILES: files,
                 SystemVariable.CONVERSATION: conversation.id,
-            }
+            },
+            callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)]
         )

         for result in result_generator:

@@ -1,157 +0,0 @@
-import os
-import sys
-from typing import Any, Optional, Union
-
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult
-
-
-class DifyStdOutCallbackHandler(BaseCallbackHandler):
-    """Callback Handler that prints to std out."""
-
-    def __init__(self, color: Optional[str] = None) -> None:
-        """Initialize callback handler."""
-        self.color = color
-
-    def on_chat_model_start(
-            self,
-            serialized: dict[str, Any],
-            messages: list[list[BaseMessage]],
-            **kwargs: Any
-    ) -> Any:
-        print_text("\n[on_chat_model_start]\n", color='blue')
-        for sub_messages in messages:
-            for sub_message in sub_messages:
-                print_text(str(sub_message) + "\n", color='blue')
-
-    def on_llm_start(
-        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
-    ) -> None:
-        """Print out the prompts."""
-        print_text("\n[on_llm_start]\n", color='blue')
-        print_text(prompts[0] + "\n", color='blue')
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        """Do nothing."""
-        print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str(
-            response.llm_output) + "\n", color='blue')
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Do nothing."""
-        pass
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
-
-    def on_chain_start(
-        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        chain_type = serialized['id'][-1]
-        print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
-
-    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink')
-
-    def on_tool_start(
-        self,
-        serialized: dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        """Do nothing."""
-        print_text("\n[on_tool_start] " + str(serialized), color='yellow')
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        """Run on agent action."""
-        tool = action.tool
-        tool_input = action.tool_input
-        try:
-            action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
-            thought = action.log[:action_name_position].strip() if action.log else ''
-        except ValueError:
-            thought = ''
-
-        log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
-        print_text("\n[on_agent_action]\n" + log + "\n", color='green')
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        """If not the final action, print out observation."""
-        print_text("\n[on_tool_end]\n", color='yellow')
-        if observation_prefix:
-            print_text(f"\n{observation_prefix}")
-        print_text(output, color='yellow')
-        if llm_prefix:
-            print_text(f"\n{llm_prefix}")
-        print_text("\n")
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow')
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run when agent ends."""
-        print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end)
-
-    def on_agent_finish(
-        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        """Run on agent end."""
-        print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
-
-    @property
-    def ignore_llm(self) -> bool:
-        """Whether to ignore LLM callbacks."""
-        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
-
-    @property
-    def ignore_chain(self) -> bool:
-        """Whether to ignore chain callbacks."""
-        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
-
-    @property
-    def ignore_agent(self) -> bool:
-        """Whether to ignore agent callbacks."""
-        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
-
-    @property
-    def ignore_chat_model(self) -> bool:
-        """Whether to ignore chat model callbacks."""
-        return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
-
-
-class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
-    """Callback handler for streaming. Only works with LLMs that support streaming."""
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
-        """Run on new LLM token. Only available when streaming is enabled."""
-        sys.stdout.write(token)
-        sys.stdout.flush()

@@ -0,0 +1,45 @@
+from core.app.app_queue_manager import AppQueueManager, PublishFrom
+from core.workflow.callbacks.base_callback import BaseWorkflowCallback
+from models.workflow import WorkflowRun, WorkflowNodeExecution
+
+
+class WorkflowEventTriggerCallback(BaseWorkflowCallback):
+
+    def __init__(self, queue_manager: AppQueueManager):
+        self._queue_manager = queue_manager
+
+    def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
+        """
+        Workflow run started
+        """
+        self._queue_manager.publish_workflow_started(
+            workflow_run_id=workflow_run.id,
+            pub_from=PublishFrom.TASK_PIPELINE
+        )
+
+    def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
+        """
+        Workflow run finished
+        """
+        self._queue_manager.publish_workflow_finished(
+            workflow_run_id=workflow_run.id,
+            pub_from=PublishFrom.TASK_PIPELINE
+        )
+
+    def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
+        """
+        Workflow node execute started
+        """
+        self._queue_manager.publish_node_started(
+            workflow_node_execution_id=workflow_node_execution.id,
+            pub_from=PublishFrom.TASK_PIPELINE
+        )
+
+    def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
+        """
+        Workflow node execute finished
+        """
+        self._queue_manager.publish_node_finished(
+            workflow_node_execution_id=workflow_node_execution.id,
+            pub_from=PublishFrom.TASK_PIPELINE
+        )

@@ -0,0 +1,33 @@
+from abc import abstractmethod
+
+from models.workflow import WorkflowRun, WorkflowNodeExecution
+
+
+class BaseWorkflowCallback:
+    @abstractmethod
+    def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
+        """
+        Workflow run started
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
+        """
+        Workflow run finished
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
+        """
+        Workflow node execute started
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
+        """
+        Workflow node execute finished
+        """
+        raise NotImplementedError

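The abstract base above only fixes the four hook points; callers supply their own subclass, and this commit adds WorkflowEventTriggerCallback as the first one. As a hedged illustration that is not part of the commit, a minimal callback that merely logs run and node transitions could look like this:

import logging

from core.workflow.callbacks.base_callback import BaseWorkflowCallback
from models.workflow import WorkflowRun, WorkflowNodeExecution


class LoggingWorkflowCallback(BaseWorkflowCallback):
    """Hypothetical callback that only logs, useful when no queue manager is involved."""

    def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
        logging.info("workflow run %s started", workflow_run.id)

    def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
        logging.info("workflow run %s finished", workflow_run.id)

    def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
        logging.info("node execution %s started", workflow_node_execution.id)

    def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
        logging.info("node execution %s finished", workflow_node_execution.id)

Since BaseWorkflowCallback does not inherit from ABC, instantiation is not blocked, but any hook left unoverridden raises NotImplementedError when the engine invokes it.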
@@ -0,0 +1,7 @@
+from abc import ABC
+
+from pydantic import BaseModel
+
+
+class BaseNodeData(ABC, BaseModel):
+    pass

@@ -1,32 +1,21 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import Optional, Type

+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeType
 from core.workflow.entities.variable_pool import VariablePool


 class BaseNode:
-
-    def __int__(self, node_config: dict) -> None:
-        self._node_config = node_config
-
-    @abstractmethod
-    def run(self, variable_pool: Optional[VariablePool] = None,
-            run_args: Optional[dict] = None) -> dict:
-        """
-        Run node
-        :param variable_pool: variable pool
-        :param run_args: run args
-        :return:
-        """
-        if variable_pool is None and run_args is None:
-            raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
-
-        return self._run(
-            variable_pool=variable_pool,
-            run_args=run_args
-        )
+    _node_type: NodeType
+    _node_data_cls: Type[BaseNodeData]
+
+    def __init__(self, config: dict) -> None:
+        self._node_id = config.get("id")
+        if not self._node_id:
+            raise ValueError("Node ID is required.")
+
+        self._node_data = self._node_data_cls(**config.get("data", {}))

     @abstractmethod
     def _run(self, variable_pool: Optional[VariablePool] = None,

@@ -39,6 +28,22 @@ class BaseNode:
         """
         raise NotImplementedError

+    def run(self, variable_pool: Optional[VariablePool] = None,
+            run_args: Optional[dict] = None) -> dict:
+        """
+        Run node entry
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        if variable_pool is None and run_args is None:
+            raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
+
+        return self._run(
+            variable_pool=variable_pool,
+            run_args=run_args
+        )
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """

@@ -0,0 +1,27 @@
+from typing import Optional
+
+from core.app.app_config.entities import VariableEntity
+from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.node_entities import NodeType
+
+
+class StartNodeData(BaseNodeData):
+    """
+    - title (string) Node title
+    - desc (string) optional Node description
+    - type (string) Node type, fixed to "start"
+    - variables (array[object]) List of form variables
+      - type (string) Form variable type: text-input, paragraph, select, number, files (custom files not supported yet)
+      - label (string) Display label of the control
+      - variable (string) Variable key
+      - max_length (int) Maximum length, applies to text-input and paragraph
+      - default (string) optional Default value
+      - required (bool) optional Whether the field is required, defaults to false
+      - hint (string) optional Hint text
+      - options (array[string]) Option values (select only)
+    """
+    type: str = NodeType.START.value
+
+    title: str
+    desc: Optional[str] = None
+    variables: list[VariableEntity] = []

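For orientation, the field list above corresponds to a node "data" payload roughly like the following. The concrete values are hypothetical, and the sketch assumes VariableEntity accepts these keys as described in the docstring:

from core.workflow.nodes.start.entities import StartNodeData

start_node_data = StartNodeData(
    title="Start",
    desc="Collect user inputs",  # optional description
    variables=[
        {
            "type": "text-input",
            "label": "Customer name",
            "variable": "customer_name",
            "max_length": 48,
            "required": True,
        },
        {
            "type": "select",
            "label": "Language",
            "variable": "language",
            "options": ["en-US", "zh-Hans"],
            "default": "en-US",
        },
    ],
)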
@@ -1,5 +1,22 @@
+from typing import Type, Optional
+
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.start.entities import StartNodeData


 class StartNode(BaseNode):
-    pass
+    _node_type = NodeType.START
+    _node_data_cls = StartNodeData
+
+    def _run(self, variable_pool: Optional[VariablePool] = None,
+             run_args: Optional[dict] = None) -> dict:
+        """
+        Run node
+        :param variable_pool: variable pool
+        :param run_args: run args
+        :return:
+        """
+        pass

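Because StartNode._run is still a stub, the node can already be constructed and driven through the BaseNode.run entry point added above, but it produces no output yet. A hedged sketch, not part of the commit (the import path is assumed by analogy with the code_node and direct_answer_node imports, and the config values are made up):

from core.workflow.nodes.start.start_node import StartNode  # assumed module path

node = StartNode(config={
    "id": "start-node-1",  # BaseNode.__init__ requires an "id"
    "data": {
        "type": "start",
        "title": "Start",
        "variables": [],
    },
})

result = node.run(run_args={"customer_name": "Alice"})
# _run currently ends in `pass`, so result is None until the node body is implemented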
@@ -1,6 +1,8 @@
+import json
 from collections.abc import Generator
 from typing import Optional, Union

+from core.workflow.callbacks.base_callback import BaseWorkflowCallback
 from core.workflow.entities.node_entities import NodeType
 from core.workflow.nodes.code.code_node import CodeNode
 from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode

@@ -17,7 +19,7 @@ from core.workflow.nodes.variable_assigner.variable_assigner_node import Variabl
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, EndUser
-from models.workflow import Workflow
+from models.workflow import Workflow, WorkflowRunTriggeredFrom, WorkflowRun, WorkflowRunStatus, CreatedByRole

 node_classes = {
     NodeType.START: StartNode,

@@ -108,17 +110,103 @@ class WorkflowEngineManager:

     def run_workflow(self, app_model: App,
                      workflow: Workflow,
+                     triggered_from: WorkflowRunTriggeredFrom,
                      user: Union[Account, EndUser],
                      user_inputs: dict,
-                     system_inputs: Optional[dict] = None) -> Generator:
+                     system_inputs: Optional[dict] = None,
+                     callbacks: list[BaseWorkflowCallback] = None) -> Generator:
         """
         Run workflow
         :param app_model: App instance
         :param workflow: Workflow instance
+        :param triggered_from: triggered from
         :param user: account or end user
         :param user_inputs: user variables inputs
         :param system_inputs: system inputs, like: query, files
+        :param callbacks: workflow callbacks
         :return:
         """
-        # TODO
-        pass
+        # fetch workflow graph
+        graph = workflow.graph_dict
+        if not graph:
+            raise ValueError('workflow graph not found')
+
+        # init workflow run
+        workflow_run = self._init_workflow_run(
+            workflow=workflow,
+            triggered_from=triggered_from,
+            user=user,
+            user_inputs=user_inputs,
+            system_inputs=system_inputs
+        )
+
+        if callbacks:
+            for callback in callbacks:
+                callback.on_workflow_run_started(workflow_run)
+
+        pass
+
+    def _init_workflow_run(self, workflow: Workflow,
+                           triggered_from: WorkflowRunTriggeredFrom,
+                           user: Union[Account, EndUser],
+                           user_inputs: dict,
+                           system_inputs: Optional[dict] = None) -> WorkflowRun:
+        """
+        Init workflow run
+        :param workflow: Workflow instance
+        :param triggered_from: triggered from
+        :param user: account or end user
+        :param user_inputs: user variables inputs
+        :param system_inputs: system inputs, like: query, files
+        :return:
+        """
+        try:
+            db.session.begin()
+
+            max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
+                .filter(WorkflowRun.tenant_id == workflow.tenant_id) \
+                .filter(WorkflowRun.app_id == workflow.app_id) \
+                .for_update() \
+                .scalar() or 0
+            new_sequence_number = max_sequence + 1
+
+            # init workflow run
+            workflow_run = WorkflowRun(
+                tenant_id=workflow.tenant_id,
+                app_id=workflow.app_id,
+                sequence_number=new_sequence_number,
+                workflow_id=workflow.id,
+                type=workflow.type,
+                triggered_from=triggered_from.value,
+                version=workflow.version,
+                graph=workflow.graph,
+                inputs=json.dumps({**user_inputs, **system_inputs}),
+                status=WorkflowRunStatus.RUNNING.value,
+                created_by_role=(CreatedByRole.ACCOUNT.value
+                                 if isinstance(user, Account) else CreatedByRole.END_USER.value),
+                created_by_id=user.id
+            )
+
+            db.session.add(workflow_run)
+            db.session.commit()
+        except:
+            db.session.rollback()
+            raise
+
+        return workflow_run
+
+    def _get_entry_node(self, graph: dict) -> Optional[StartNode]:
+        """
+        Get entry node
+        :param graph: workflow graph
+        :return:
+        """
+        nodes = graph.get('nodes')
+        if not nodes:
+            return None
+
+        for node_config in nodes.items():
+            if node_config.get('type') == NodeType.START.value:
+                return StartNode(config=node_config)
+
+        return None

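Taken together with the app-runner hunk earlier, the call path added by this commit can be sketched as follows. This is illustrative rather than part of the commit; app_model, workflow, end_user and queue_manager are assumed to be supplied by the caller:

from core.callback_handler.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from models.workflow import WorkflowRunTriggeredFrom

engine = WorkflowEngineManager()

engine.run_workflow(
    app_model=app_model,                              # App row loaded by the caller
    workflow=workflow,                                # Workflow row with a non-empty graph
    triggered_from=WorkflowRunTriggeredFrom.APP_RUN,  # or DEBUGGING when invoked from the debugger
    user=end_user,                                    # Account or EndUser
    user_inputs={"customer_name": "Alice"},
    system_inputs={SystemVariable.QUERY: "Hi"},
    callbacks=[WorkflowEventTriggerCallback(queue_manager=queue_manager)],
)

# In this commit run_workflow creates the WorkflowRun row, fires
# on_workflow_run_started on each callback, and then ends in `pass`,
# so no node execution or streaming output happens yet.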