mirror of https://github.com/langgenius/dify.git
Refactor workflow nodes to use generic node_data (#28782)
This commit is contained in:
parent
002d8769b0
commit
8b761319f6
|
|
@ -70,7 +70,6 @@ class AgentNode(Node[AgentNodeData]):
|
|||
"""
|
||||
|
||||
node_type = NodeType.AGENT
|
||||
_node_data: AgentNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
|
|
@ -82,8 +81,8 @@ class AgentNode(Node[AgentNodeData]):
|
|||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=self.tenant_id,
|
||||
agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self._node_data.agent_strategy_name,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
|
|
@ -101,13 +100,13 @@ class AgentNode(Node[AgentNodeData]):
|
|||
parameters = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
)
|
||||
parameters_for_log = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
|
|
@ -140,7 +139,7 @@ class AgentNode(Node[AgentNodeData]):
|
|||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": self._node_data.agent_strategy_name,
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
|
|
@ -387,7 +386,7 @@ class AgentNode(Node[AgentNodeData]):
|
|||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
|
|
|
|||
|
|
@ -14,14 +14,12 @@ class AnswerNode(Node[AnswerNodeData]):
|
|||
node_type = NodeType.ANSWER
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
_node_data: AnswerNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
segments = self.graph_runtime_state.variable_pool.convert_template(self._node_data.answer)
|
||||
segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer)
|
||||
files = self._extract_files_from_segments(segments.value)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
|
@ -71,4 +69,4 @@ class AnswerNode(Node[AnswerNodeData]):
|
|||
Returns:
|
||||
Template instance for this Answer node
|
||||
"""
|
||||
return Template.from_answer_template(self._node_data.answer)
|
||||
return Template.from_answer_template(self.node_data.answer)
|
||||
|
|
|
|||
|
|
@ -24,8 +24,6 @@ from .exc import (
|
|||
class CodeNode(Node[CodeNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
|
||||
_node_data: CodeNodeData
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
|
|
@ -48,12 +46,12 @@ class CodeNode(Node[CodeNodeData]):
|
|||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# Get code language
|
||||
code_language = self._node_data.code_language
|
||||
code = self._node_data.code
|
||||
code_language = self.node_data.code_language
|
||||
code = self.node_data.code
|
||||
|
||||
# Get variables
|
||||
variables = {}
|
||||
for variable_selector in self._node_data.variables:
|
||||
for variable_selector in self.node_data.variables:
|
||||
variable_name = variable_selector.variable
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
|
|
@ -69,7 +67,7 @@ class CodeNode(Node[CodeNodeData]):
|
|||
)
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result=result, output_schema=self._node_data.outputs)
|
||||
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
|
|
@ -406,7 +404,7 @@ class CodeNode(Node[CodeNodeData]):
|
|||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
return self.node_data.retry_config.retry_enabled
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
|||
Datasource Node
|
||||
"""
|
||||
|
||||
_node_data: DatasourceNodeData
|
||||
node_type = NodeType.DATASOURCE
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
|
|
@ -51,7 +50,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
|||
Run the datasource node
|
||||
"""
|
||||
|
||||
node_data = self._node_data
|
||||
node_data = self.node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
if not datasource_type_segement:
|
||||
|
|
|
|||
|
|
@ -43,14 +43,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
|||
|
||||
node_type = NodeType.DOCUMENT_EXTRACTOR
|
||||
|
||||
_node_data: DocumentExtractorNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
variable_selector = self._node_data.variable_selector
|
||||
variable_selector = self.node_data.variable_selector
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
|
||||
|
||||
if variable is None:
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ class EndNode(Node[EndNodeData]):
|
|||
node_type = NodeType.END
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
_node_data: EndNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
@ -22,7 +20,7 @@ class EndNode(Node[EndNodeData]):
|
|||
This method runs after streaming is complete (if streaming was enabled).
|
||||
It collects all output variables and returns them.
|
||||
"""
|
||||
output_variables = self._node_data.outputs
|
||||
output_variables = self.node_data.outputs
|
||||
|
||||
outputs = {}
|
||||
for variable_selector in output_variables:
|
||||
|
|
@ -44,6 +42,6 @@ class EndNode(Node[EndNodeData]):
|
|||
Template instance for this End node
|
||||
"""
|
||||
outputs_config = [
|
||||
{"variable": output.variable, "value_selector": output.value_selector} for output in self._node_data.outputs
|
||||
{"variable": output.variable, "value_selector": output.value_selector} for output in self.node_data.outputs
|
||||
]
|
||||
return Template.from_end_outputs(outputs_config)
|
||||
|
|
|
|||
|
|
@ -34,8 +34,6 @@ logger = logging.getLogger(__name__)
|
|||
class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
node_type = NodeType.HTTP_REQUEST
|
||||
|
||||
_node_data: HttpRequestNodeData
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
|
|
@ -69,8 +67,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
|||
process_data = {}
|
||||
try:
|
||||
http_executor = Executor(
|
||||
node_data=self._node_data,
|
||||
timeout=self._get_request_timeout(self._node_data),
|
||||
node_data=self.node_data,
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
max_retries=0,
|
||||
)
|
||||
|
|
@ -225,4 +223,4 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
|||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
return self.node_data.retry_config.retry_enabled
|
||||
|
|
|
|||
|
|
@ -25,8 +25,6 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
|||
"handle",
|
||||
)
|
||||
|
||||
_node_data: HumanInputNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
@ -49,12 +47,12 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
|||
def _is_completion_ready(self) -> bool:
|
||||
"""Determine whether all required inputs are satisfied."""
|
||||
|
||||
if not self._node_data.required_variables:
|
||||
if not self.node_data.required_variables:
|
||||
return False
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
for selector_str in self._node_data.required_variables:
|
||||
for selector_str in self.node_data.required_variables:
|
||||
parts = selector_str.split(".")
|
||||
if len(parts) != 2:
|
||||
return False
|
||||
|
|
@ -74,7 +72,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
|||
if handle:
|
||||
return handle
|
||||
|
||||
default_values = self._node_data.default_value_dict
|
||||
default_values = self.node_data.default_value_dict
|
||||
for key in self._BRANCH_SELECTION_KEYS:
|
||||
handle = self._normalize_branch_value(default_values.get(key))
|
||||
if handle:
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@ class IfElseNode(Node[IfElseNodeData]):
|
|||
node_type = NodeType.IF_ELSE
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
_node_data: IfElseNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
@ -37,8 +35,8 @@ class IfElseNode(Node[IfElseNodeData]):
|
|||
condition_processor = ConditionProcessor()
|
||||
try:
|
||||
# Check if the new cases structure is used
|
||||
if self._node_data.cases:
|
||||
for case in self._node_data.cases:
|
||||
if self.node_data.cases:
|
||||
for case in self.node_data.cases:
|
||||
input_conditions, group_result, final_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=case.conditions,
|
||||
|
|
@ -64,8 +62,8 @@ class IfElseNode(Node[IfElseNodeData]):
|
|||
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated]
|
||||
condition_processor=condition_processor,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=self._node_data.conditions or [],
|
||||
operator=self._node_data.logical_operator or "and",
|
||||
conditions=self.node_data.conditions or [],
|
||||
operator=self.node_data.logical_operator or "and",
|
||||
)
|
||||
|
||||
selected_case_id = "true" if final_result else "false"
|
||||
|
|
|
|||
|
|
@ -65,7 +65,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
|
||||
node_type = NodeType.ITERATION
|
||||
execution_type = NodeExecutionType.CONTAINER
|
||||
_node_data: IterationNodeData
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
|
|
@ -136,10 +135,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
)
|
||||
|
||||
def _get_iterator_variable(self) -> ArraySegment | NoneSegment:
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
|
||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
|
||||
|
||||
if not variable:
|
||||
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
|
||||
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
|
||||
|
||||
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
|
||||
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
|
||||
|
|
@ -174,7 +173,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
return cast(list[object], iterator_list_value)
|
||||
|
||||
def _validate_start_node(self) -> None:
|
||||
if not self._node_data.start_node_id:
|
||||
if not self.node_data.start_node_id:
|
||||
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found")
|
||||
|
||||
def _execute_iterations(
|
||||
|
|
@ -184,7 +183,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
iter_run_map: dict[str, float],
|
||||
usage_accumulator: list[LLMUsage],
|
||||
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
|
||||
if self._node_data.is_parallel:
|
||||
if self.node_data.is_parallel:
|
||||
# Parallel mode execution
|
||||
yield from self._execute_parallel_iterations(
|
||||
iterator_list_value=iterator_list_value,
|
||||
|
|
@ -231,7 +230,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
outputs.extend([None] * len(iterator_list_value))
|
||||
|
||||
# Determine the number of parallel workers
|
||||
max_workers = min(self._node_data.parallel_nums, len(iterator_list_value))
|
||||
max_workers = min(self.node_data.parallel_nums, len(iterator_list_value))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all iteration tasks
|
||||
|
|
@ -287,7 +286,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
|
||||
except Exception as e:
|
||||
# Handle errors based on error_handle_mode
|
||||
match self._node_data.error_handle_mode:
|
||||
match self.node_data.error_handle_mode:
|
||||
case ErrorHandleMode.TERMINATED:
|
||||
# Cancel remaining futures and re-raise
|
||||
for f in future_to_index:
|
||||
|
|
@ -300,7 +299,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
outputs[index] = None # Will be filtered later
|
||||
|
||||
# Remove None values if in REMOVE_ABNORMAL_OUTPUT mode
|
||||
if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
outputs[:] = [output for output in outputs if output is not None]
|
||||
|
||||
def _execute_single_iteration_parallel(
|
||||
|
|
@ -389,7 +388,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
If flatten_output is True (default), flattens the list if all elements are lists.
|
||||
"""
|
||||
# If flatten_output is disabled, return outputs as-is
|
||||
if not self._node_data.flatten_output:
|
||||
if not self.node_data.flatten_output:
|
||||
return outputs
|
||||
|
||||
if not outputs:
|
||||
|
|
@ -569,14 +568,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
self._append_iteration_info_to_event(event=event, iter_run_index=current_index)
|
||||
yield event
|
||||
elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)):
|
||||
result = variable_pool.get(self._node_data.output_selector)
|
||||
result = variable_pool.get(self.node_data.output_selector)
|
||||
if result is None:
|
||||
outputs.append(None)
|
||||
else:
|
||||
outputs.append(result.to_object())
|
||||
return
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
match self._node_data.error_handle_mode:
|
||||
match self.node_data.error_handle_mode:
|
||||
case ErrorHandleMode.TERMINATED:
|
||||
raise IterationNodeError(event.error)
|
||||
case ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
|
|
@ -627,7 +626,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
|||
|
||||
# Initialize the iteration graph with the new node factory
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self._node_data.start_node_id
|
||||
graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id
|
||||
)
|
||||
|
||||
if not iteration_graph:
|
||||
|
|
|
|||
|
|
@ -11,8 +11,6 @@ class IterationStartNode(Node[IterationStartNodeData]):
|
|||
|
||||
node_type = NodeType.ITERATION_START
|
||||
|
||||
_node_data: IterationStartNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
|
|||
|
|
@ -35,12 +35,11 @@ default_retrieval_model = {
|
|||
|
||||
|
||||
class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
_node_data: KnowledgeIndexNodeData
|
||||
node_type = NodeType.KNOWLEDGE_INDEX
|
||||
execution_type = NodeExecutionType.RESPONSE
|
||||
|
||||
def _run(self) -> NodeRunResult: # type: ignore
|
||||
node_data = self._node_data
|
||||
node_data = self.node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
||||
if not dataset_id:
|
||||
|
|
|
|||
|
|
@ -83,8 +83,6 @@ default_retrieval_model = {
|
|||
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
|
||||
node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
_node_data: KnowledgeRetrievalNodeData
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
|
|
@ -122,7 +120,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
|
||||
def _run(self) -> NodeRunResult:
|
||||
# extract variables
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
|
||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector)
|
||||
if not isinstance(variable, StringSegment):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
|
@ -163,7 +161,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
# retrieve knowledge
|
||||
usage = LLMUsage.empty_usage()
|
||||
try:
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
|
||||
outputs = {"result": ArrayObjectSegment(value=results)}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
|
@ -536,7 +534,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
|||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=self._node_data.structured_output_enabled,
|
||||
structured_output_enabled=self.node_data.structured_output_enabled,
|
||||
structured_output=None,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
|
|
|
|||
|
|
@ -37,8 +37,6 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
|
|||
class ListOperatorNode(Node[ListOperatorNodeData]):
|
||||
node_type = NodeType.LIST_OPERATOR
|
||||
|
||||
_node_data: ListOperatorNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
@ -48,9 +46,9 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
|
|||
process_data: dict[str, Sequence[object]] = {}
|
||||
outputs: dict[str, Any] = {}
|
||||
|
||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
|
||||
variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
|
||||
if variable is None:
|
||||
error_message = f"Variable not found for selector: {self._node_data.variable}"
|
||||
error_message = f"Variable not found for selector: {self.node_data.variable}"
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
|
|
@ -69,7 +67,7 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
|
|||
outputs=outputs,
|
||||
)
|
||||
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
|
||||
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
|
||||
error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}"
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
|
|
@ -83,19 +81,19 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
|
|||
|
||||
try:
|
||||
# Filter
|
||||
if self._node_data.filter_by.enabled:
|
||||
if self.node_data.filter_by.enabled:
|
||||
variable = self._apply_filter(variable)
|
||||
|
||||
# Extract
|
||||
if self._node_data.extract_by.enabled:
|
||||
if self.node_data.extract_by.enabled:
|
||||
variable = self._extract_slice(variable)
|
||||
|
||||
# Order
|
||||
if self._node_data.order_by.enabled:
|
||||
if self.node_data.order_by.enabled:
|
||||
variable = self._apply_order(variable)
|
||||
|
||||
# Slice
|
||||
if self._node_data.limit.enabled:
|
||||
if self.node_data.limit.enabled:
|
||||
variable = self._apply_slice(variable)
|
||||
|
||||
outputs = {
|
||||
|
|
@ -121,7 +119,7 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
|
|||
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
filter_func: Callable[[Any], bool]
|
||||
result: list[Any] = []
|
||||
for condition in self._node_data.filter_by.conditions:
|
||||
for condition in self.node_data.filter_by.conditions:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
if not isinstance(condition.value, str):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
|
|
@ -160,22 +158,22 @@ class ListOperatorNode(Node[ListOperatorNodeData]):
|
|||
|
||||
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by.value == Order.DESC)
|
||||
result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
result = _order_file(
|
||||
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
|
||||
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
|
||||
return variable
|
||||
|
||||
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
result = variable.value[: self._node_data.limit.size]
|
||||
result = variable.value[: self.node_data.limit.size]
|
||||
return variable.model_copy(update={"value": result})
|
||||
|
||||
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
|
||||
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
|
||||
if value < 1:
|
||||
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
|
||||
if value > len(variable.value):
|
||||
|
|
|
|||
|
|
@ -102,8 +102,6 @@ logger = logging.getLogger(__name__)
|
|||
class LLMNode(Node[LLMNodeData]):
|
||||
node_type = NodeType.LLM
|
||||
|
||||
_node_data: LLMNodeData
|
||||
|
||||
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
|
||||
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
|
|
@ -154,13 +152,13 @@ class LLMNode(Node[LLMNodeData]):
|
|||
|
||||
try:
|
||||
# init messages template
|
||||
self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
|
||||
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
|
||||
|
||||
# fetch variables and fetch values from variable pool
|
||||
inputs = self._fetch_inputs(node_data=self._node_data)
|
||||
inputs = self._fetch_inputs(node_data=self.node_data)
|
||||
|
||||
# fetch jinja2 inputs
|
||||
jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
|
||||
jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
|
||||
|
||||
# merge inputs
|
||||
inputs.update(jinja_inputs)
|
||||
|
|
@ -169,9 +167,9 @@ class LLMNode(Node[LLMNodeData]):
|
|||
files = (
|
||||
llm_utils.fetch_files(
|
||||
variable_pool=variable_pool,
|
||||
selector=self._node_data.vision.configs.variable_selector,
|
||||
selector=self.node_data.vision.configs.variable_selector,
|
||||
)
|
||||
if self._node_data.vision.enabled
|
||||
if self.node_data.vision.enabled
|
||||
else []
|
||||
)
|
||||
|
||||
|
|
@ -179,7 +177,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||
node_inputs["#files#"] = [file.to_dict() for file in files]
|
||||
|
||||
# fetch context value
|
||||
generator = self._fetch_context(node_data=self._node_data)
|
||||
generator = self._fetch_context(node_data=self.node_data)
|
||||
context = None
|
||||
for event in generator:
|
||||
context = event.context
|
||||
|
|
@ -189,7 +187,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||
|
||||
# fetch model config
|
||||
model_instance, model_config = LLMNode._fetch_model_config(
|
||||
node_data_model=self._node_data.model,
|
||||
node_data_model=self.node_data.model,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
|
|
@ -197,13 +195,13 @@ class LLMNode(Node[LLMNodeData]):
|
|||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
node_data_memory=self._node_data.memory,
|
||||
node_data_memory=self.node_data.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
query: str | None = None
|
||||
if self._node_data.memory:
|
||||
query = self._node_data.memory.query_prompt_template
|
||||
if self.node_data.memory:
|
||||
query = self.node_data.memory.query_prompt_template
|
||||
if not query and (
|
||||
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||
):
|
||||
|
|
@ -215,29 +213,29 @@ class LLMNode(Node[LLMNodeData]):
|
|||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
prompt_template=self._node_data.prompt_template,
|
||||
memory_config=self._node_data.memory,
|
||||
vision_enabled=self._node_data.vision.enabled,
|
||||
vision_detail=self._node_data.vision.configs.detail,
|
||||
prompt_template=self.node_data.prompt_template,
|
||||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self._node_data.prompt_config.jinja2_variables,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
generator = LLMNode.invoke_llm(
|
||||
node_data_model=self._node_data.model,
|
||||
node_data_model=self.node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=self._node_data.structured_output_enabled,
|
||||
structured_output=self._node_data.structured_output,
|
||||
structured_output_enabled=self.node_data.structured_output_enabled,
|
||||
structured_output=self.node_data.structured_output,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
reasoning_format=self._node_data.reasoning_format,
|
||||
reasoning_format=self.node_data.reasoning_format,
|
||||
)
|
||||
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
|
@ -253,12 +251,12 @@ class LLMNode(Node[LLMNodeData]):
|
|||
reasoning_content = event.reasoning_content or ""
|
||||
|
||||
# For downstream nodes, determine clean text based on reasoning_format
|
||||
if self._node_data.reasoning_format == "tagged":
|
||||
if self.node_data.reasoning_format == "tagged":
|
||||
# Keep <think> tags for backward compatibility
|
||||
clean_text = result_text
|
||||
else:
|
||||
# Extract clean text from <think> tags
|
||||
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
|
||||
clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
|
||||
|
||||
# Process structured output if available from the event.
|
||||
structured_output = (
|
||||
|
|
@ -1204,7 +1202,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
return self.node_data.retry_config.retry_enabled
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
|
|
|
|||
|
|
@ -11,8 +11,6 @@ class LoopEndNode(Node[LoopEndNodeData]):
|
|||
|
||||
node_type = NodeType.LOOP_END
|
||||
|
||||
_node_data: LoopEndNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
|
|||
|
|
@ -46,7 +46,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
"""
|
||||
|
||||
node_type = NodeType.LOOP
|
||||
_node_data: LoopNodeData
|
||||
execution_type = NodeExecutionType.CONTAINER
|
||||
|
||||
@classmethod
|
||||
|
|
@ -56,27 +55,27 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
def _run(self) -> Generator:
|
||||
"""Run the node."""
|
||||
# Get inputs
|
||||
loop_count = self._node_data.loop_count
|
||||
break_conditions = self._node_data.break_conditions
|
||||
logical_operator = self._node_data.logical_operator
|
||||
loop_count = self.node_data.loop_count
|
||||
break_conditions = self.node_data.break_conditions
|
||||
logical_operator = self.node_data.logical_operator
|
||||
|
||||
inputs = {"loop_count": loop_count}
|
||||
|
||||
if not self._node_data.start_node_id:
|
||||
if not self.node_data.start_node_id:
|
||||
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
||||
|
||||
root_node_id = self._node_data.start_node_id
|
||||
root_node_id = self.node_data.start_node_id
|
||||
|
||||
# Initialize loop variables in the original variable pool
|
||||
loop_variable_selectors = {}
|
||||
if self._node_data.loop_variables:
|
||||
if self.node_data.loop_variables:
|
||||
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
|
||||
if isinstance(var.value, list)
|
||||
else None,
|
||||
}
|
||||
for loop_variable in self._node_data.loop_variables:
|
||||
for loop_variable in self.node_data.loop_variables:
|
||||
if loop_variable.value_type not in value_processor:
|
||||
raise ValueError(
|
||||
f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}"
|
||||
|
|
@ -164,7 +163,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
|
||||
yield LoopNextEvent(
|
||||
index=i + 1,
|
||||
pre_loop_output=self._node_data.outputs,
|
||||
pre_loop_output=self.node_data.outputs,
|
||||
)
|
||||
|
||||
self._accumulate_usage(loop_usage)
|
||||
|
|
@ -172,7 +171,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
yield LoopSucceededEvent(
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs=self._node_data.outputs,
|
||||
outputs=self.node_data.outputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens,
|
||||
|
|
@ -194,7 +193,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
|
||||
},
|
||||
outputs=self._node_data.outputs,
|
||||
outputs=self.node_data.outputs,
|
||||
inputs=inputs,
|
||||
llm_usage=loop_usage,
|
||||
)
|
||||
|
|
@ -252,11 +251,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
|||
if isinstance(event, GraphRunFailedEvent):
|
||||
raise Exception(event.error)
|
||||
|
||||
for loop_var in self._node_data.loop_variables or []:
|
||||
for loop_var in self.node_data.loop_variables or []:
|
||||
key, sel = loop_var.label, [self._node_id, loop_var.label]
|
||||
segment = self.graph_runtime_state.variable_pool.get(sel)
|
||||
self._node_data.outputs[key] = segment.value if segment else None
|
||||
self._node_data.outputs["loop_round"] = current_index + 1
|
||||
self.node_data.outputs[key] = segment.value if segment else None
|
||||
self.node_data.outputs["loop_round"] = current_index + 1
|
||||
|
||||
return reach_break_node
|
||||
|
||||
|
|
|
|||
|
|
@ -11,8 +11,6 @@ class LoopStartNode(Node[LoopStartNodeData]):
|
|||
|
||||
node_type = NodeType.LOOP_START
|
||||
|
||||
_node_data: LoopStartNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
|
|||
|
|
@ -90,8 +90,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
|||
|
||||
node_type = NodeType.PARAMETER_EXTRACTOR
|
||||
|
||||
_node_data: ParameterExtractorNodeData
|
||||
|
||||
_model_instance: ModelInstance | None = None
|
||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||
|
||||
|
|
@ -116,7 +114,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
|||
"""
|
||||
Run the node.
|
||||
"""
|
||||
node_data = self._node_data
|
||||
node_data = self.node_data
|
||||
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
|
||||
query = variable.text if variable else ""
|
||||
|
||||
|
|
|
|||
|
|
@ -47,8 +47,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
|||
node_type = NodeType.QUESTION_CLASSIFIER
|
||||
execution_type = NodeExecutionType.BRANCH
|
||||
|
||||
_node_data: QuestionClassifierNodeData
|
||||
|
||||
_file_outputs: list["File"]
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
|
|
@ -82,7 +80,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
|||
return "1"
|
||||
|
||||
def _run(self):
|
||||
node_data = self._node_data
|
||||
node_data = self.node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# extract variables
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ class StartNode(Node[StartNodeData]):
|
|||
node_type = NodeType.START
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
_node_data: StartNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
|||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
|
||||
_node_data: TemplateTransformNodeData
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
|
|
@ -35,14 +33,14 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
|||
def _run(self) -> NodeRunResult:
|
||||
# Get variables
|
||||
variables: dict[str, Any] = {}
|
||||
for variable_selector in self._node_data.variables:
|
||||
for variable_selector in self.node_data.variables:
|
||||
variable_name = variable_selector.variable
|
||||
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
variables[variable_name] = value.to_object() if value else None
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
|
||||
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
|
||||
)
|
||||
except CodeExecutionError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
|
|
|||
|
|
@ -47,8 +47,6 @@ class ToolNode(Node[ToolNodeData]):
|
|||
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
_node_data: ToolNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
@ -59,13 +57,11 @@ class ToolNode(Node[ToolNodeData]):
|
|||
"""
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
|
||||
node_data = self._node_data
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
"provider_type": node_data.provider_type.value,
|
||||
"provider_id": node_data.provider_id,
|
||||
"plugin_unique_identifier": node_data.plugin_unique_identifier,
|
||||
"provider_type": self.node_data.provider_type.value,
|
||||
"provider_id": self.node_data.provider_id,
|
||||
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
|
||||
}
|
||||
|
||||
# get tool runtime
|
||||
|
|
@ -77,10 +73,10 @@ class ToolNode(Node[ToolNodeData]):
|
|||
# But for backward compatibility with historical data
|
||||
# this version field judgment is still preserved here.
|
||||
variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
if self.node_data.version != "1" or self.node_data.tool_node_version is not None:
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
|
||||
self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
|
|
@ -99,12 +95,12 @@ class ToolNode(Node[ToolNodeData]):
|
|||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
node_data=self.node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
)
|
||||
# get conversation id
|
||||
|
|
@ -149,7 +145,7 @@ class ToolNode(Node[ToolNodeData]):
|
|||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}",
|
||||
error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
|
@ -159,7 +155,7 @@ class ToolNode(Node[ToolNodeData]):
|
|||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=e.to_user_friendly_error(plugin_name=node_data.provider_name),
|
||||
error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
|
@ -495,4 +491,4 @@ class ToolNode(Node[ToolNodeData]):
|
|||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
return self.node_data.retry_config.retry_enabled
|
||||
|
|
|
|||
|
|
@ -43,9 +43,9 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
|||
# Get trigger data passed when workflow was triggered
|
||||
metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||
"provider_id": self._node_data.provider_id,
|
||||
"event_name": self._node_data.event_name,
|
||||
"plugin_unique_identifier": self._node_data.plugin_unique_identifier,
|
||||
"provider_id": self.node_data.provider_id,
|
||||
"event_name": self.node_data.event_name,
|
||||
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
|
||||
},
|
||||
}
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class TriggerWebhookNode(Node[WebhookData]):
|
|||
webhook_headers = webhook_data.get("headers", {})
|
||||
webhook_headers_lower = {k.lower(): v for k, v in webhook_headers.items()}
|
||||
|
||||
for header in self._node_data.headers:
|
||||
for header in self.node_data.headers:
|
||||
header_name = header.name
|
||||
value = _get_normalized(webhook_headers, header_name)
|
||||
if value is None:
|
||||
|
|
@ -93,20 +93,20 @@ class TriggerWebhookNode(Node[WebhookData]):
|
|||
outputs[sanitized_name] = value
|
||||
|
||||
# Extract configured query parameters
|
||||
for param in self._node_data.params:
|
||||
for param in self.node_data.params:
|
||||
param_name = param.name
|
||||
outputs[param_name] = webhook_data.get("query_params", {}).get(param_name)
|
||||
|
||||
# Extract configured body parameters
|
||||
for body_param in self._node_data.body:
|
||||
for body_param in self.node_data.body:
|
||||
param_name = body_param.name
|
||||
param_type = body_param.type
|
||||
|
||||
if self._node_data.content_type == ContentType.TEXT:
|
||||
if self.node_data.content_type == ContentType.TEXT:
|
||||
# For text/plain, the entire body is a single string parameter
|
||||
outputs[param_name] = str(webhook_data.get("body", {}).get("raw", ""))
|
||||
continue
|
||||
elif self._node_data.content_type == ContentType.BINARY:
|
||||
elif self.node_data.content_type == ContentType.BINARY:
|
||||
outputs[param_name] = webhook_data.get("body", {}).get("raw", b"")
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,6 @@ from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorN
|
|||
class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
|
||||
node_type = NodeType.VARIABLE_AGGREGATOR
|
||||
|
||||
_node_data: VariableAggregatorNodeData
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
|
@ -21,8 +19,8 @@ class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
|
|||
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
|
||||
inputs = {}
|
||||
|
||||
if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
|
||||
for selector in self._node_data.variables:
|
||||
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
|
||||
for selector in self.node_data.variables:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if variable is not None:
|
||||
outputs = {"output": variable}
|
||||
|
|
@ -30,7 +28,7 @@ class VariableAggregatorNode(Node[VariableAggregatorNodeData]):
|
|||
inputs = {".".join(selector[1:]): variable.to_object()}
|
||||
break
|
||||
else:
|
||||
for group in self._node_data.advanced_settings.groups:
|
||||
for group in self.node_data.advanced_settings.groups:
|
||||
for selector in group.variables:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,8 +25,6 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
|||
node_type = NodeType.VARIABLE_ASSIGNER
|
||||
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
|
||||
|
||||
_node_data: VariableAssignerData
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
|
|
@ -71,21 +69,21 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
|||
return mapping
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
assigned_variable_selector = self._node_data.assigned_variable_selector
|
||||
assigned_variable_selector = self.node_data.assigned_variable_selector
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
match self._node_data.write_mode:
|
||||
match self.node_data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
|
|
|
|||
|
|
@ -53,8 +53,6 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
|
|||
class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
_node_data: VariableAssignerNodeData
|
||||
|
||||
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
|
||||
"""
|
||||
Check if this Variable Assigner node blocks the output of specific variables.
|
||||
|
|
@ -62,7 +60,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
|||
Returns True if this node updates any of the requested conversation variables.
|
||||
"""
|
||||
# Check each item in this Variable Assigner node
|
||||
for item in self._node_data.items:
|
||||
for item in self.node_data.items:
|
||||
# Convert the item's variable_selector to tuple for comparison
|
||||
item_selector_tuple = tuple(item.variable_selector)
|
||||
|
||||
|
|
@ -97,13 +95,13 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
|||
return var_mapping
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
inputs = self._node_data.model_dump()
|
||||
inputs = self.node_data.model_dump()
|
||||
process_data: dict[str, Any] = {}
|
||||
# NOTE: This node has no outputs
|
||||
updated_variable_selectors: list[Sequence[str]] = []
|
||||
|
||||
try:
|
||||
for item in self._node_data.items:
|
||||
for item in self.node_data.items:
|
||||
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
|
||||
|
||||
# ==================== Validation Part
|
||||
|
|
|
|||
Loading…
Reference in New Issue