mirror of https://github.com/langgenius/dify.git
refactor pipeline and remove node run run_args
This commit is contained in:
parent 80b4db08dc
commit 2db67c4101
@@ -55,6 +55,19 @@ class TaskState(BaseModel):
     """
     TaskState entity
     """
+    class NodeExecutionInfo(BaseModel):
+        """
+        NodeExecutionInfo entity
+        """
+        workflow_node_execution: WorkflowNodeExecution
+        start_at: float
+
+        class Config:
+            """Configuration for this pydantic object."""
+
+            extra = Extra.forbid
+            arbitrary_types_allowed = True
+
     answer: str = ""
     metadata: dict = {}
     usage: LLMUsage
@@ -64,8 +77,8 @@ class TaskState(BaseModel):
     total_tokens: int = 0
     total_steps: int = 0
 
-    current_node_execution: Optional[WorkflowNodeExecution] = None
-    current_node_execution_start_at: Optional[float] = None
+    running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
+    latest_node_execution_info: Optional[NodeExecutionInfo] = None
 
     class Config:
         """Configuration for this pydantic object."""
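Taken together, the two hunks above swap TaskState's single current_node_execution slot for per-node bookkeeping: running executions live in a dict keyed by node_id, and latest_node_execution_info points at the most recently started one. A minimal, runnable sketch of that lifecycle (editor's illustration; WorkflowNodeExecution is stubbed, only the field and method names from the diff are real):

    import time
    from typing import Optional

    from pydantic import BaseModel, Extra


    class WorkflowNodeExecution(BaseModel):
        # stand-in for the real database model
        node_id: str
        status: str = 'running'


    class NodeExecutionInfo(BaseModel):
        workflow_node_execution: WorkflowNodeExecution
        start_at: float

        class Config:
            extra = Extra.forbid
            arbitrary_types_allowed = True


    running_node_execution_infos: dict[str, NodeExecutionInfo] = {}

    # _on_node_start: register the execution under its node_id
    info = NodeExecutionInfo(
        workflow_node_execution=WorkflowNodeExecution(node_id='llm-1'),
        start_at=time.perf_counter(),
    )
    running_node_execution_infos['llm-1'] = info
    latest_node_execution_info: Optional[NodeExecutionInfo] = info

    # _on_node_finished: pair the event with its execution by node_id, then drop it
    finished = running_node_execution_infos.pop('llm-1')
    print(time.perf_counter() - finished.start_at)  # elapsed seconds

Keying by node_id is what lets _on_node_finished match a QueueNodeSucceededEvent or QueueNodeFailedEvent to the execution it belongs to even when several nodes are in flight.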
@@ -218,7 +231,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeStartedEvent):
                 self._on_node_start(event)
-                workflow_node_execution = self._task_state.current_node_execution
+                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
 
                 response = {
                     'event': 'node_started',
@@ -237,7 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
                 self._on_node_finished(event)
-                workflow_node_execution = self._task_state.current_node_execution
+                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
 
                 if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
                     if workflow_node_execution.node_type == NodeType.LLM.value:
@@ -447,15 +460,21 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             predecessor_node_id=event.predecessor_node_id
         )
 
-        self._task_state.current_node_execution = workflow_node_execution
-        self._task_state.current_node_execution_start_at = time.perf_counter()
+        latest_node_execution_info = TaskState.NodeExecutionInfo(
+            workflow_node_execution=workflow_node_execution,
+            start_at=time.perf_counter()
+        )
+
+        self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
+        self._task_state.latest_node_execution_info = latest_node_execution_info
         self._task_state.total_steps += 1
 
     def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
+        current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
         if isinstance(event, QueueNodeSucceededEvent):
             workflow_node_execution = self._workflow_node_execution_success(
-                workflow_node_execution=self._task_state.current_node_execution,
-                start_at=self._task_state.current_node_execution_start_at,
+                workflow_node_execution=current_node_execution.workflow_node_execution,
+                start_at=current_node_execution.start_at,
                 inputs=event.inputs,
                 process_data=event.process_data,
                 outputs=event.outputs,
@@ -472,12 +491,14 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 self._task_state.metadata['usage'] = usage_dict
         else:
             workflow_node_execution = self._workflow_node_execution_failed(
-                workflow_node_execution=self._task_state.current_node_execution,
-                start_at=self._task_state.current_node_execution_start_at,
+                workflow_node_execution=current_node_execution.workflow_node_execution,
+                start_at=current_node_execution.start_at,
                 error=event.error
             )
 
-        self._task_state.current_node_execution = workflow_node_execution
+        # remove running node execution info
+        del self._task_state.running_node_execution_infos[event.node_id]
+        self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution
 
     def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
         if isinstance(event, QueueStopEvent):
@@ -504,8 +525,8 @@ class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
-                outputs=self._task_state.current_node_execution.outputs
-                if self._task_state.current_node_execution else None
+                outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs
+                if self._task_state.latest_node_execution_info else None
             )
 
             self._task_state.workflow_run = workflow_run
@@ -41,6 +41,19 @@ class TaskState(BaseModel):
     """
     TaskState entity
     """
+    class NodeExecutionInfo(BaseModel):
+        """
+        NodeExecutionInfo entity
+        """
+        workflow_node_execution: WorkflowNodeExecution
+        start_at: float
+
+        class Config:
+            """Configuration for this pydantic object."""
+
+            extra = Extra.forbid
+            arbitrary_types_allowed = True
+
     answer: str = ""
     metadata: dict = {}
 
@@ -49,8 +62,8 @@ class TaskState(BaseModel):
     total_tokens: int = 0
     total_steps: int = 0
 
-    current_node_execution: Optional[WorkflowNodeExecution] = None
-    current_node_execution_start_at: Optional[float] = None
+    running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
+    latest_node_execution_info: Optional[NodeExecutionInfo] = None
 
     class Config:
         """Configuration for this pydantic object."""
@@ -179,7 +192,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeStartedEvent):
                 self._on_node_start(event)
-                workflow_node_execution = self._task_state.current_node_execution
+                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
 
                 response = {
                     'event': 'node_started',
@@ -198,7 +211,7 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 yield self._yield_response(response)
             elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
                 self._on_node_finished(event)
-                workflow_node_execution = self._task_state.current_node_execution
+                workflow_node_execution = self._task_state.latest_node_execution_info.workflow_node_execution
 
                 response = {
                     'event': 'node_finished',
@@ -339,15 +352,22 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
             predecessor_node_id=event.predecessor_node_id
         )
 
-        self._task_state.current_node_execution = workflow_node_execution
-        self._task_state.current_node_execution_start_at = time.perf_counter()
+        latest_node_execution_info = TaskState.NodeExecutionInfo(
+            workflow_node_execution=workflow_node_execution,
+            start_at=time.perf_counter()
+        )
+
+        self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
+        self._task_state.latest_node_execution_info = latest_node_execution_info
+
         self._task_state.total_steps += 1
 
     def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> None:
+        current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
         if isinstance(event, QueueNodeSucceededEvent):
             workflow_node_execution = self._workflow_node_execution_success(
-                workflow_node_execution=self._task_state.current_node_execution,
-                start_at=self._task_state.current_node_execution_start_at,
+                workflow_node_execution=current_node_execution.workflow_node_execution,
+                start_at=current_node_execution.start_at,
                 inputs=event.inputs,
                 process_data=event.process_data,
                 outputs=event.outputs,
@@ -359,12 +379,14 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                     int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
         else:
             workflow_node_execution = self._workflow_node_execution_failed(
-                workflow_node_execution=self._task_state.current_node_execution,
-                start_at=self._task_state.current_node_execution_start_at,
+                workflow_node_execution=current_node_execution.workflow_node_execution,
+                start_at=current_node_execution.start_at,
                 error=event.error
             )
 
-        self._task_state.current_node_execution = workflow_node_execution
+        # remove running node execution info
+        del self._task_state.running_node_execution_infos[event.node_id]
+        self._task_state.latest_node_execution_info.workflow_node_execution = workflow_node_execution
 
     def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) -> None:
         if isinstance(event, QueueStopEvent):
@@ -391,8 +413,8 @@ class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
                 start_at=self._task_state.start_at,
                 total_tokens=self._task_state.total_tokens,
                 total_steps=self._task_state.total_steps,
-                outputs=self._task_state.current_node_execution.outputs
-                if self._task_state.current_node_execution else None
+                outputs=self._task_state.latest_node_execution_info.workflow_node_execution.outputs
+                if self._task_state.latest_node_execution_info else None
             )
 
             self._task_state.workflow_run = workflow_run
@@ -19,14 +19,17 @@ class ValueType(Enum):
 
 class VariablePool:
     variables_mapping = {}
+    user_inputs: dict
 
-    def __init__(self, system_variables: dict[SystemVariable, Any]) -> None:
+    def __init__(self, system_variables: dict[SystemVariable, Any],
+                 user_inputs: dict) -> None:
         # system variables
         # for example:
         # {
         #     'query': 'abc',
         #     'files': []
         # }
+        self.user_inputs = user_inputs
         for system_variable, value in system_variables.items():
             self.append_variable('sys', [system_variable.value], value)
 
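With user inputs now carried by the pool itself, construction looks roughly as below. This is an editor's sketch: the SystemVariable members mirror the example comment in the hunk, and the append_variable storage scheme is an assumption (only the constructor shape and the 'sys' prefix appear in the diff).

    from enum import Enum
    from typing import Any


    class SystemVariable(Enum):
        # stand-in for the real enum; members match the example comment
        QUERY = 'query'
        FILES = 'files'


    class VariablePool:
        def __init__(self, system_variables: dict[SystemVariable, Any],
                     user_inputs: dict) -> None:
            self.variables_mapping: dict = {}
            self.user_inputs = user_inputs
            for system_variable, value in system_variables.items():
                self.append_variable('sys', [system_variable.value], value)

        def append_variable(self, node_id: str, selector: list, value: Any) -> None:
            # assumed storage scheme: values keyed by (node_id, *selector)
            self.variables_mapping[(node_id, *selector)] = value


    pool = VariablePool(
        system_variables={SystemVariable.QUERY: 'abc', SystemVariable.FILES: []},
        user_inputs={'name': 'Dify'}
    )
    print(pool.user_inputs)  # {'name': 'Dify'}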
@@ -18,15 +18,13 @@ class WorkflowNodeAndResult:
 class WorkflowRunState:
     workflow: Workflow
     start_at: float
-    user_inputs: dict
     variable_pool: VariablePool
 
     total_tokens: int = 0
 
     workflow_nodes_and_results: list[WorkflowNodeAndResult] = []
 
-    def __init__(self, workflow: Workflow, start_at: float, user_inputs: dict, variable_pool: VariablePool):
+    def __init__(self, workflow: Workflow, start_at: float, variable_pool: VariablePool):
         self.workflow = workflow
         self.start_at = start_at
-        self.user_inputs = user_inputs
         self.variable_pool = variable_pool
@@ -28,31 +28,23 @@ class BaseNode(ABC):
         self.callbacks = callbacks or []
 
     @abstractmethod
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         raise NotImplementedError
 
-    def run(self, variable_pool: Optional[VariablePool] = None,
-            run_args: Optional[dict] = None) -> NodeRunResult:
+    def run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node entry
         :param variable_pool: variable pool
-        :param run_args: run args
        :return:
         """
-        if variable_pool is None and run_args is None:
-            raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
-
         try:
             result = self._run(
-                variable_pool=variable_pool,
-                run_args=run_args
+                variable_pool=variable_pool
             )
         except Exception as e:
             # process unhandled exception
@@ -77,6 +69,26 @@ class BaseNode(ABC):
             text=text
         )
 
+    @classmethod
+    def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict:
+        """
+        Extract variable selector to variable mapping
+        :param config: node config
+        :return:
+        """
+        node_data = cls._node_data_cls(**config.get("data", {}))
+        return cls._extract_variable_selector_to_variable_mapping(node_data)
+
+    @classmethod
+    @abstractmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """
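With both hunks applied, a node implementation must accept a concrete VariablePool and provide the new mapping hook. A sketch of a conforming subclass, patterned on the other nodes in this commit (MyNodeData, the chosen node_type, and the WorkflowNodeExecutionStatus import path are illustrative assumptions):

    from core.workflow.entities.base_node_data_entities import BaseNodeData
    from core.workflow.entities.node_entities import NodeRunResult, NodeType
    from core.workflow.entities.variable_pool import VariablePool
    from core.workflow.nodes.base_node import BaseNode
    from models.workflow import WorkflowNodeExecutionStatus  # path assumed


    class MyNodeData(BaseNodeData):
        pass


    class MyNode(BaseNode):
        _node_data_cls = MyNodeData
        node_type = NodeType.CODE  # illustrative; any NodeType member

        def _run(self, variable_pool: VariablePool) -> NodeRunResult:
            # the pool is always present now; no variable_pool/run_args None checks
            query = variable_pool.get_variable_value(variable_selector=['sys', 'query'])
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs={'query': query},
                outputs={'echo': query}
            )

        @classmethod
        def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict:
            # required by the new abstract hook; empty until single-step debugging lands
            return {}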
@@ -1,5 +1,6 @@
 from typing import Optional, Union, cast
 
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -15,6 +16,7 @@ MAX_STRING_LENGTH = 1000
 MAX_STRING_ARRAY_LENGTH = 30
 MAX_NUMBER_ARRAY_LENGTH = 1000
 
+
 class CodeNode(BaseNode):
     _node_data_cls = CodeNodeData
     node_type = NodeType.CODE
@@ -78,21 +80,15 @@ class CodeNode(BaseNode):
             }
         }
 
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run code
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         node_data = self.node_data
-        node_data: CodeNodeData = cast(self._node_data_cls, node_data)
+        node_data = cast(self._node_data_cls, node_data)
 
-        # SINGLE DEBUG NOT IMPLEMENTED YET
-        if variable_pool is None and run_args:
-            raise ValueError("Not support single step debug.")
-
         # Get code language
         code_language = node_data.code_language
         code = node_data.code
@@ -134,7 +130,6 @@ class CodeNode(BaseNode):
         Check string
         :param value: value
         :param variable: variable
-        :param max_length: max length
         :return:
         """
         if not isinstance(value, str):
@@ -142,9 +137,9 @@ class CodeNode(BaseNode):
 
         if len(value) > MAX_STRING_LENGTH:
             raise ValueError(f'{variable} in input form must be less than {MAX_STRING_LENGTH} characters')
-
+
         return value.replace('\x00', '')
-
+
     def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
         """
         Check number
@@ -157,13 +152,13 @@ class CodeNode(BaseNode):
 
         if value > MAX_NUMBER or value < MIN_NUMBER:
             raise ValueError(f'{variable} in input form is out of range.')
-
+
         if isinstance(value, float):
             value = round(value, MAX_PRECISION)
 
         return value
 
-    def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output],
+    def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output],
                           prefix: str = '',
                           depth: int = 1) -> dict:
         """
@@ -174,7 +169,7 @@ class CodeNode(BaseNode):
         """
         if depth > MAX_DEPTH:
             raise ValueError("Depth limit reached, object too deep.")
-
+
         transformed_result = {}
         for output_name, output_config in output_schema.items():
             if output_config.type == 'object':
@@ -183,7 +178,7 @@ class CodeNode(BaseNode):
                     raise ValueError(
                         f'Output {prefix}.{output_name} is not an object, got {type(result.get(output_name))} instead.'
                     )
-
+
                 transformed_result[output_name] = self._transform_result(
                     result=result[output_name],
                     output_schema=output_config.children,
@@ -208,7 +203,7 @@ class CodeNode(BaseNode):
                     raise ValueError(
                         f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.'
                     )
-
+
                 if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH:
                     raise ValueError(
                         f'{prefix}.{output_name} in input form must be less than {MAX_NUMBER_ARRAY_LENGTH} characters'
@@ -227,12 +222,12 @@ class CodeNode(BaseNode):
                     raise ValueError(
                         f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.'
                     )
-
+
                 if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
                     raise ValueError(
                         f'{prefix}.{output_name} in input form must be less than {MAX_STRING_ARRAY_LENGTH} characters'
                     )
-
+
                 transformed_result[output_name] = [
                     self._check_string(
                         value=value,
@@ -242,5 +237,15 @@ class CodeNode(BaseNode):
             ]
         else:
             raise ValueError(f'Output type {output_config.type} is not supported.')
 
-        return transformed_result
+        return transformed_result
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        # TODO extract variable selector to variable mapping for single step debugging
+        return {}
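Most of the CodeNode changes are whitespace cleanups plus the new (still stubbed) mapping hook, but _transform_result is worth pausing on: it validates nested outputs recursively with an explicit depth guard. A generic, runnable distillation of that pattern (editor's illustration with simplified types, not the CodeNode API):

    MAX_DEPTH = 5


    def transform(result: dict, schema: dict, prefix: str = '', depth: int = 1) -> dict:
        if depth > MAX_DEPTH:
            raise ValueError("Depth limit reached, object too deep.")

        transformed = {}
        for name, expected in schema.items():
            value = result.get(name)
            if isinstance(expected, dict):
                # nested object in the schema: recurse one level deeper
                if not isinstance(value, dict):
                    raise ValueError(f'Output {prefix}.{name} is not an object.')
                transformed[name] = transform(value, expected, f'{prefix}.{name}', depth + 1)
            else:
                if not isinstance(value, expected):
                    raise ValueError(f'Output {prefix}.{name} has the wrong type.')
                transformed[name] = value
        return transformed


    print(transform({'a': {'b': 1}}, {'a': {'b': int}}))  # {'a': {'b': 1}}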
@@ -1,7 +1,8 @@
 import time
-from typing import Optional, cast
+from typing import cast
 
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import ValueType, VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -13,20 +14,15 @@ class DirectAnswerNode(BaseNode):
     _node_data_cls = DirectAnswerNodeData
     node_type = NodeType.DIRECT_ANSWER
 
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         node_data = self.node_data
         node_data = cast(self._node_data_cls, node_data)
 
-        if variable_pool is None and run_args:
-            raise ValueError("Not support single step debug.")
-
         variable_values = {}
         for variable_selector in node_data.variables:
             value = variable_pool.get_variable_value(
@@ -43,7 +39,7 @@ class DirectAnswerNode(BaseNode):
         # publish answer as stream
         for word in answer:
             self.publish_text_chunk(word)
-            time.sleep(0.01)  # todo sleep 0.01
+            time.sleep(0.01)
 
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -52,3 +48,12 @@ class DirectAnswerNode(BaseNode):
                 "answer": answer
             }
         )
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        return {}
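DirectAnswerNode streams the rendered answer one character at a time with a short pause between chunks (the loop iterates a string, so each word is a single character). A stand-alone rendition with print standing in for the queue-backed publish_text_chunk:

    import time


    def publish_text_chunk(chunk: str) -> None:
        print(chunk, end='', flush=True)  # stand-in for the real queue publisher


    answer = "Hello, Dify!"
    for word in answer:  # iterates characters, matching the node's loop
        publish_text_chunk(word)
        time.sleep(0.01)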
@@ -1,5 +1,6 @@
-from typing import Optional, cast
+from typing import cast
 
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import ValueType, VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -11,50 +12,54 @@ class EndNode(BaseNode):
     _node_data_cls = EndNodeData
     node_type = NodeType.END
 
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         node_data = self.node_data
         node_data = cast(self._node_data_cls, node_data)
         outputs_config = node_data.outputs
 
-        if variable_pool is not None:
-            outputs = None
-            if outputs_config:
-                if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
-                    plain_text_selector = outputs_config.plain_text_selector
-                    if plain_text_selector:
-                        outputs = {
-                            'text': variable_pool.get_variable_value(
-                                variable_selector=plain_text_selector,
-                                target_value_type=ValueType.STRING
-                            )
-                        }
-                    else:
-                        outputs = {
-                            'text': ''
-                        }
-                elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
-                    structured_variables = outputs_config.structured_variables
-                    if structured_variables:
-                        outputs = {}
-                        for variable_selector in structured_variables:
-                            variable_value = variable_pool.get_variable_value(
-                                variable_selector=variable_selector.value_selector
-                            )
-                            outputs[variable_selector.variable] = variable_value
-                    else:
-                        outputs = {}
-        else:
-            raise ValueError("Not support single step debug.")
+        outputs = None
+        if outputs_config:
+            if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
+                plain_text_selector = outputs_config.plain_text_selector
+                if plain_text_selector:
+                    outputs = {
+                        'text': variable_pool.get_variable_value(
+                            variable_selector=plain_text_selector,
+                            target_value_type=ValueType.STRING
+                        )
+                    }
+                else:
+                    outputs = {
+                        'text': ''
+                    }
+            elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
+                structured_variables = outputs_config.structured_variables
+                if structured_variables:
+                    outputs = {}
+                    for variable_selector in structured_variables:
+                        variable_value = variable_pool.get_variable_value(
+                            variable_selector=variable_selector.value_selector
+                        )
+                        outputs[variable_selector.variable] = variable_value
+                else:
+                    outputs = {}
 
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=outputs,
             outputs=outputs
         )
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        return {}
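After dropping the single-step-debug branch, EndNode._run reduces to one dispatch on the outputs type: plain_text resolves a single selector into {'text': ...}, while structured resolves each value_selector into a named key. A condensed stand-alone rendition (editor's sketch; dict-based config and the lookup callable stand in for EndNodeDataOutputs and the pool):

    def resolve_outputs(outputs_config, get_variable_value):
        # 'plain_text' / 'structured' mirror EndNodeDataOutputs.OutputType
        if not outputs_config:
            return None
        if outputs_config['type'] == 'plain_text':
            selector = outputs_config.get('plain_text_selector')
            return {'text': get_variable_value(selector) if selector else ''}
        if outputs_config['type'] == 'structured':
            return {v['variable']: get_variable_value(v['value_selector'])
                    for v in outputs_config.get('structured_variables', [])}
        return None


    outputs = resolve_outputs(
        {'type': 'plain_text', 'plain_text_selector': ['llm-1', 'text']},
        lambda selector: 'hello'
    )
    print(outputs)  # {'text': 'hello'}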
@@ -1,5 +1,6 @@
 from typing import Optional, cast
 
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -10,12 +11,10 @@ class LLMNode(BaseNode):
     _node_data_cls = LLMNodeData
     node_type = NodeType.LLM
 
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         node_data = self.node_data
@@ -23,6 +22,17 @@ class LLMNode(BaseNode):
 
         pass
 
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        # TODO extract variable selector to variable mapping for single step debugging
+        return {}
+
+
     @classmethod
     def get_default_config(cls, filters: Optional[dict] = None) -> dict:
         """
@@ -1,6 +1,7 @@
-from typing import Optional, cast
+from typing import cast
 
 from core.app.app_config.entities import VariableEntity
+from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
@@ -12,12 +13,10 @@ class StartNode(BaseNode):
     _node_data_cls = StartNodeData
     node_type = NodeType.START
 
-    def _run(self, variable_pool: Optional[VariablePool] = None,
-             run_args: Optional[dict] = None) -> NodeRunResult:
+    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
         Run node
         :param variable_pool: variable pool
-        :param run_args: run args
         :return:
         """
         node_data = self.node_data
@@ -25,7 +24,7 @@ class StartNode(BaseNode):
         variables = node_data.variables
 
         # Get cleaned inputs
-        cleaned_inputs = self._get_cleaned_inputs(variables, run_args)
+        cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs)
 
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -68,3 +67,12 @@ class StartNode(BaseNode):
             filtered_inputs[variable] = value.replace('\x00', '') if value else None
 
         return filtered_inputs
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
+        """
+        Extract variable selector to variable mapping
+        :param node_data: node data
+        :return:
+        """
+        return {}
@@ -109,9 +109,9 @@ class WorkflowEngineManager:
         workflow_run_state = WorkflowRunState(
             workflow=workflow,
             start_at=time.perf_counter(),
-            user_inputs=user_inputs,
             variable_pool=VariablePool(
-                system_variables=system_inputs
+                system_variables=system_inputs,
+                user_inputs=user_inputs
             )
         )
 
@@ -292,9 +292,7 @@ class WorkflowEngineManager:
 
         # run node, result must have inputs, process_data, outputs, execution_metadata
         node_run_result = node.run(
-            variable_pool=workflow_run_state.variable_pool,
-            run_args=workflow_run_state.user_inputs
-            if (not predecessor_node and node.node_type == NodeType.START) else None  # only on start node
+            variable_pool=workflow_run_state.variable_pool
         )
 
         if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
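The net effect at the call site: user inputs travel inside the pool rather than as a special-cased argument, so the engine no longer needs to know which node is the start node. Condensed from the hunk above:

    # Before: run_args was threaded in only for the start node.
    # node_run_result = node.run(
    #     variable_pool=workflow_run_state.variable_pool,
    #     run_args=workflow_run_state.user_inputs
    #     if (not predecessor_node and node.node_type == NodeType.START) else None
    # )

    # After: every node, StartNode included, reads from the same pool;
    # StartNode picks its inputs off variable_pool.user_inputs itself.
    node_run_result = node.run(
        variable_pool=workflow_run_state.variable_pool
    )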