mirror of https://github.com/langgenius/dify.git
add start, end, direct answer node
This commit is contained in:
parent
46296d777c
commit
ea883b5e48
|
|
@ -5,7 +5,5 @@ from pydantic import BaseModel
|
|||
|
||||
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
type: str
|
||||
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -46,6 +46,15 @@ class SystemVariable(Enum):
|
|||
CONVERSATION = 'conversation'
|
||||
|
||||
|
||||
class NodeRunMetadataKey(Enum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
"""
|
||||
TOTAL_TOKENS = 'total_tokens'
|
||||
TOTAL_PRICE = 'total_price'
|
||||
CURRENCY = 'currency'
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
"""
|
||||
Node Run Result.
|
||||
|
|
@ -55,7 +64,7 @@ class NodeRunResult(BaseModel):
|
|||
inputs: Optional[dict] = None # node inputs
|
||||
process_data: Optional[dict] = None # process data
|
||||
outputs: Optional[dict] = None # node outputs
|
||||
metadata: Optional[dict] = None # node metadata
|
||||
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
|
||||
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VariableSelector(BaseModel):
|
||||
"""
|
||||
Variable Selector.
|
||||
"""
|
||||
variable: str
|
||||
value_selector: list[str]
|
||||
|
|
@ -5,13 +5,18 @@ from models.workflow import WorkflowNodeExecution, WorkflowRun
|
|||
class WorkflowRunState:
|
||||
workflow_run: WorkflowRun
|
||||
start_at: float
|
||||
user_inputs: dict
|
||||
variable_pool: VariablePool
|
||||
|
||||
total_tokens: int = 0
|
||||
|
||||
workflow_node_executions: list[WorkflowNodeExecution] = []
|
||||
|
||||
def __init__(self, workflow_run: WorkflowRun, start_at: float, variable_pool: VariablePool) -> None:
|
||||
def __init__(self, workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
user_inputs: dict,
|
||||
variable_pool: VariablePool) -> None:
|
||||
self.workflow_run = workflow_run
|
||||
self.start_at = start_at
|
||||
self.user_inputs = user_inputs
|
||||
self.variable_pool = variable_pool
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
|
|
@ -8,7 +8,7 @@ from core.workflow.entities.variable_pool import VariablePool
|
|||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class BaseNode:
|
||||
class BaseNode(ABC):
|
||||
_node_data_cls: type[BaseNodeData]
|
||||
_node_type: NodeType
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,54 @@
|
|||
import time
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import ValueType, VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class DirectAnswerNode(BaseNode):
|
||||
pass
|
||||
_node_data_cls = DirectAnswerNodeData
|
||||
node_type = NodeType.DIRECT_ANSWER
|
||||
|
||||
def _run(self, variable_pool: Optional[VariablePool] = None,
|
||||
run_args: Optional[dict] = None) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:param run_args: run args
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
|
||||
if variable_pool is None and run_args:
|
||||
raise ValueError("Not support single step debug.")
|
||||
|
||||
variable_values = {}
|
||||
for variable_selector in node_data.variables:
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector,
|
||||
target_value_type=ValueType.STRING
|
||||
)
|
||||
|
||||
variable_values[variable_selector.variable] = value
|
||||
|
||||
# format answer template
|
||||
template_parser = PromptTemplateParser(node_data.answer)
|
||||
answer = template_parser.format(variable_values)
|
||||
|
||||
# publish answer as stream
|
||||
for word in answer:
|
||||
self.publish_text_chunk(word)
|
||||
time.sleep(0.01)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variable_values,
|
||||
output={
|
||||
"answer": answer
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
|
||||
class DirectAnswerNodeData(BaseNodeData):
|
||||
"""
|
||||
DirectAnswer Node Data.
|
||||
"""
|
||||
variables: list[VariableSelector] = []
|
||||
answer: str
|
||||
|
|
@ -1,5 +1,60 @@
|
|||
from typing import Optional, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import ValueType, VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class EndNode(BaseNode):
|
||||
pass
|
||||
_node_data_cls = EndNodeData
|
||||
node_type = NodeType.END
|
||||
|
||||
def _run(self, variable_pool: Optional[VariablePool] = None,
|
||||
run_args: Optional[dict] = None) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:param run_args: run args
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
outputs_config = node_data.outputs
|
||||
|
||||
if variable_pool is not None:
|
||||
outputs = None
|
||||
if outputs_config:
|
||||
if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
|
||||
plain_text_selector = outputs_config.plain_text_selector
|
||||
if plain_text_selector:
|
||||
outputs = {
|
||||
'text': variable_pool.get_variable_value(
|
||||
variable_selector=plain_text_selector,
|
||||
target_value_type=ValueType.STRING
|
||||
)
|
||||
}
|
||||
else:
|
||||
outputs = {
|
||||
'text': ''
|
||||
}
|
||||
elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
|
||||
structured_variables = outputs_config.structured_variables
|
||||
if structured_variables:
|
||||
outputs = {}
|
||||
for variable_selector in structured_variables:
|
||||
variable_value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
outputs[variable_selector.variable] = variable_value
|
||||
else:
|
||||
outputs = {}
|
||||
else:
|
||||
raise ValueError("Not support single step debug.")
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=outputs,
|
||||
outputs=outputs
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
|
||||
class EndNodeOutputType(Enum):
|
||||
|
|
@ -23,3 +29,40 @@ class EndNodeOutputType(Enum):
|
|||
if output_type.value == value:
|
||||
return output_type
|
||||
raise ValueError(f'invalid output type value {value}')
|
||||
|
||||
|
||||
class EndNodeDataOutputs(BaseModel):
|
||||
"""
|
||||
END Node Data Outputs.
|
||||
"""
|
||||
class OutputType(Enum):
|
||||
"""
|
||||
Output Types.
|
||||
"""
|
||||
NONE = 'none'
|
||||
PLAIN_TEXT = 'plain-text'
|
||||
STRUCTURED = 'structured'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'OutputType':
|
||||
"""
|
||||
Get value of given output type.
|
||||
|
||||
:param value: output type value
|
||||
:return: output type
|
||||
"""
|
||||
for output_type in cls:
|
||||
if output_type.value == value:
|
||||
return output_type
|
||||
raise ValueError(f'invalid output type value {value}')
|
||||
|
||||
type: OutputType = OutputType.NONE
|
||||
plain_text_selector: Optional[list[str]] = None
|
||||
structured_variables: Optional[list[VariableSelector]] = None
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
"""
|
||||
END Node Data.
|
||||
"""
|
||||
outputs: Optional[EndNodeDataOutputs] = None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
"""
|
||||
LLM Node Data.
|
||||
"""
|
||||
pass
|
||||
|
|
@ -1,9 +1,28 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
|
||||
|
||||
class LLMNode(BaseNode):
|
||||
_node_data_cls = LLMNodeData
|
||||
node_type = NodeType.LLM
|
||||
|
||||
def _run(self, variable_pool: Optional[VariablePool] = None,
|
||||
run_args: Optional[dict] = None) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:param run_args: run args
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,23 +1,9 @@
|
|||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
|
||||
|
||||
class StartNodeData(BaseNodeData):
|
||||
"""
|
||||
- title (string) 节点标题
|
||||
- desc (string) optional 节点描述
|
||||
- type (string) 节点类型,固定为 start
|
||||
- variables (array[object]) 表单变量列表
|
||||
- type (string) 表单变量类型,text-input, paragraph, select, number, files(文件暂不支持自定义)
|
||||
- label (string) 控件展示标签名
|
||||
- variable (string) 变量 key
|
||||
- max_length (int) 最大长度,适用于 text-input 和 paragraph
|
||||
- default (string) optional 默认值
|
||||
- required (bool) optional是否必填,默认 false
|
||||
- hint (string) optional 提示信息
|
||||
- options (array[string]) 选项值(仅 select 可用)
|
||||
Start Node Data
|
||||
"""
|
||||
type: str = NodeType.START.value
|
||||
|
||||
variables: list[VariableEntity] = []
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class StartNode(BaseNode):
|
||||
|
|
@ -11,12 +13,58 @@ class StartNode(BaseNode):
|
|||
node_type = NodeType.START
|
||||
|
||||
def _run(self, variable_pool: Optional[VariablePool] = None,
|
||||
run_args: Optional[dict] = None) -> dict:
|
||||
run_args: Optional[dict] = None) -> NodeRunResult:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:param run_args: run args
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
variables = node_data.variables
|
||||
|
||||
# Get cleaned inputs
|
||||
cleaned_inputs = self._get_cleaned_inputs(variables, run_args)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=cleaned_inputs,
|
||||
outputs=cleaned_inputs
|
||||
)
|
||||
|
||||
def _get_cleaned_inputs(self, variables: list[VariableEntity], user_inputs: dict):
|
||||
if user_inputs is None:
|
||||
user_inputs = {}
|
||||
|
||||
filtered_inputs = {}
|
||||
|
||||
for variable_config in variables:
|
||||
variable = variable_config.variable
|
||||
|
||||
if variable not in user_inputs or not user_inputs[variable]:
|
||||
if variable_config.required:
|
||||
raise ValueError(f"Input form variable {variable} is required")
|
||||
else:
|
||||
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
|
||||
continue
|
||||
|
||||
value = user_inputs[variable]
|
||||
|
||||
if value:
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"{variable} in input form must be a string")
|
||||
|
||||
if variable_config.type == VariableEntity.Type.SELECT:
|
||||
options = variable_config.options if variable_config.options is not None else []
|
||||
if value not in options:
|
||||
raise ValueError(f"{variable} in input form must be one of the following: {options}")
|
||||
else:
|
||||
if variable_config.max_length is not None:
|
||||
max_length = variable_config.max_length
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
||||
|
||||
filtered_inputs[variable] = value.replace('\x00', '') if value else None
|
||||
|
||||
return filtered_inputs
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import time
|
|||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
|
|
@ -141,6 +142,7 @@ class WorkflowEngineManager:
|
|||
workflow_run_state = WorkflowRunState(
|
||||
workflow_run=workflow_run,
|
||||
start_at=time.perf_counter(),
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=VariablePool(
|
||||
system_variables=system_inputs,
|
||||
)
|
||||
|
|
@ -399,7 +401,9 @@ class WorkflowEngineManager:
|
|||
|
||||
# run node, result must have inputs, process_data, outputs, execution_metadata
|
||||
node_run_result = node.run(
|
||||
variable_pool=workflow_run_state.variable_pool
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
run_args=workflow_run_state.user_inputs
|
||||
if (not predecessor_node and node.node_type == NodeType.START) else None # only on start node
|
||||
)
|
||||
|
||||
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
|
|
@ -492,7 +496,7 @@ class WorkflowEngineManager:
|
|||
workflow_node_execution.inputs = json.dumps(result.inputs)
|
||||
workflow_node_execution.process_data = json.dumps(result.process_data)
|
||||
workflow_node_execution.outputs = json.dumps(result.outputs)
|
||||
workflow_node_execution.execution_metadata = json.dumps(result.metadata)
|
||||
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata))
|
||||
workflow_node_execution.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
|
|
|
|||
Loading…
Reference in New Issue