mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 11:56:55 +08:00
add answer output parse
This commit is contained in:
parent
5a67c09b48
commit
44c4d5be72
@ -5,7 +5,6 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueTextChunkEvent,
|
|
||||||
QueueWorkflowFailedEvent,
|
QueueWorkflowFailedEvent,
|
||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
@ -20,7 +19,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||||||
|
|
||||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||||
self._queue_manager = queue_manager
|
self._queue_manager = queue_manager
|
||||||
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
|
|
||||||
|
|
||||||
def on_workflow_run_started(self) -> None:
|
def on_workflow_run_started(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -118,31 +116,4 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||||||
"""
|
"""
|
||||||
Publish text chunk
|
Publish text chunk
|
||||||
"""
|
"""
|
||||||
if node_id in self._streamable_node_ids:
|
pass
|
||||||
self._queue_manager.publish(
|
|
||||||
QueueTextChunkEvent(
|
|
||||||
text=text
|
|
||||||
), PublishFrom.APPLICATION_MANAGER
|
|
||||||
)
|
|
||||||
|
|
||||||
def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
|
|
||||||
"""
|
|
||||||
Fetch streamable node ids
|
|
||||||
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
|
||||||
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
|
||||||
|
|
||||||
:param graph: workflow graph
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
streamable_node_ids = []
|
|
||||||
end_node_ids = []
|
|
||||||
for node_config in graph.get('nodes'):
|
|
||||||
if node_config.get('data', {}).get('type') == NodeType.END.value:
|
|
||||||
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
|
|
||||||
end_node_ids.append(node_config.get('id'))
|
|
||||||
|
|
||||||
for edge_config in graph.get('edges'):
|
|
||||||
if edge_config.get('target') in end_node_ids:
|
|
||||||
streamable_node_ids.append(edge_config.get('source'))
|
|
||||||
|
|
||||||
return streamable_node_ids
|
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
import time
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
@ -32,14 +31,49 @@ class AnswerNode(BaseNode):
|
|||||||
|
|
||||||
variable_values[variable_selector.variable] = value
|
variable_values[variable_selector.variable] = value
|
||||||
|
|
||||||
|
variable_keys = list(variable_values.keys())
|
||||||
|
|
||||||
# format answer template
|
# format answer template
|
||||||
template_parser = PromptTemplateParser(node_data.answer)
|
template_parser = PromptTemplateParser(node_data.answer)
|
||||||
answer = template_parser.format(variable_values)
|
template_variable_keys = template_parser.variable_keys
|
||||||
|
|
||||||
# publish answer as stream
|
# Take the intersection of variable_keys and template_variable_keys
|
||||||
for word in answer:
|
variable_keys = list(set(variable_keys) & set(template_variable_keys))
|
||||||
self.publish_text_chunk(word)
|
|
||||||
time.sleep(10) # TODO for debug
|
template = node_data.answer
|
||||||
|
for var in variable_keys:
|
||||||
|
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
||||||
|
|
||||||
|
split_template = [
|
||||||
|
{
|
||||||
|
"type": "var" if self._is_variable(part, variable_keys) else "text",
|
||||||
|
"value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
|
||||||
|
}
|
||||||
|
for part in template.split('Ω') if part
|
||||||
|
]
|
||||||
|
|
||||||
|
answer = []
|
||||||
|
for part in split_template:
|
||||||
|
if part["type"] == "var":
|
||||||
|
value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
|
||||||
|
answer_part = {
|
||||||
|
"type": "text",
|
||||||
|
"text": value
|
||||||
|
}
|
||||||
|
# TODO File
|
||||||
|
else:
|
||||||
|
answer_part = {
|
||||||
|
"type": "text",
|
||||||
|
"text": part["value"]
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
|
||||||
|
answer[-1]["text"] += answer_part["text"]
|
||||||
|
else:
|
||||||
|
answer.append(answer_part)
|
||||||
|
|
||||||
|
if len(answer) == 1 and answer[0]["type"] == "text":
|
||||||
|
answer = answer[0]["text"]
|
||||||
|
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
@ -49,6 +83,10 @@ class AnswerNode(BaseNode):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _is_variable(self, part, variable_keys):
|
||||||
|
cleaned_part = part.replace('{{', '').replace('}}', '')
|
||||||
|
return part.startswith('{{') and cleaned_part in variable_keys
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -6,7 +6,6 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
|||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
|
||||||
|
|
||||||
|
|
||||||
class UserFrom(Enum):
|
class UserFrom(Enum):
|
||||||
@ -80,16 +79,9 @@ class BaseNode(ABC):
|
|||||||
:param variable_pool: variable pool
|
:param variable_pool: variable pool
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
try:
|
result = self._run(
|
||||||
result = self._run(
|
variable_pool=variable_pool
|
||||||
variable_pool=variable_pool
|
)
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
# process unhandled exception
|
|
||||||
result = NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
|
||||||
error=str(e)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.node_run_result = result
|
self.node_run_result = result
|
||||||
return result
|
return result
|
||||||
|
|||||||
@ -2,9 +2,9 @@ from typing import cast
|
|||||||
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import ValueType, VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.nodes.base_node import BaseNode
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs
|
from core.workflow.nodes.end.entities import EndNodeData
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
@ -20,34 +20,14 @@ class EndNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
node_data = self.node_data
|
node_data = self.node_data
|
||||||
node_data = cast(self._node_data_cls, node_data)
|
node_data = cast(self._node_data_cls, node_data)
|
||||||
outputs_config = node_data.outputs
|
output_variables = node_data.outputs
|
||||||
|
|
||||||
outputs = None
|
outputs = {}
|
||||||
if outputs_config:
|
for variable_selector in output_variables:
|
||||||
if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
|
variable_value = variable_pool.get_variable_value(
|
||||||
plain_text_selector = outputs_config.plain_text_selector
|
variable_selector=variable_selector.value_selector
|
||||||
if plain_text_selector:
|
)
|
||||||
outputs = {
|
outputs[variable_selector.variable] = variable_value
|
||||||
'text': variable_pool.get_variable_value(
|
|
||||||
variable_selector=plain_text_selector,
|
|
||||||
target_value_type=ValueType.STRING
|
|
||||||
)
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
outputs = {
|
|
||||||
'text': ''
|
|
||||||
}
|
|
||||||
elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
|
|
||||||
structured_variables = outputs_config.structured_variables
|
|
||||||
if structured_variables:
|
|
||||||
outputs = {}
|
|
||||||
for variable_selector in structured_variables:
|
|
||||||
variable_value = variable_pool.get_variable_value(
|
|
||||||
variable_selector=variable_selector.value_selector
|
|
||||||
)
|
|
||||||
outputs[variable_selector.variable] = variable_value
|
|
||||||
else:
|
|
||||||
outputs = {}
|
|
||||||
|
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
|||||||
@ -1,68 +1,9 @@
|
|||||||
from enum import Enum
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.variable_entities import VariableSelector
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
|
|
||||||
|
|
||||||
class EndNodeOutputType(Enum):
|
|
||||||
"""
|
|
||||||
END Node Output Types.
|
|
||||||
|
|
||||||
none, plain-text, structured
|
|
||||||
"""
|
|
||||||
NONE = 'none'
|
|
||||||
PLAIN_TEXT = 'plain-text'
|
|
||||||
STRUCTURED = 'structured'
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> 'OutputType':
|
|
||||||
"""
|
|
||||||
Get value of given output type.
|
|
||||||
|
|
||||||
:param value: output type value
|
|
||||||
:return: output type
|
|
||||||
"""
|
|
||||||
for output_type in cls:
|
|
||||||
if output_type.value == value:
|
|
||||||
return output_type
|
|
||||||
raise ValueError(f'invalid output type value {value}')
|
|
||||||
|
|
||||||
|
|
||||||
class EndNodeDataOutputs(BaseModel):
|
|
||||||
"""
|
|
||||||
END Node Data Outputs.
|
|
||||||
"""
|
|
||||||
class OutputType(Enum):
|
|
||||||
"""
|
|
||||||
Output Types.
|
|
||||||
"""
|
|
||||||
NONE = 'none'
|
|
||||||
PLAIN_TEXT = 'plain-text'
|
|
||||||
STRUCTURED = 'structured'
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> 'OutputType':
|
|
||||||
"""
|
|
||||||
Get value of given output type.
|
|
||||||
|
|
||||||
:param value: output type value
|
|
||||||
:return: output type
|
|
||||||
"""
|
|
||||||
for output_type in cls:
|
|
||||||
if output_type.value == value:
|
|
||||||
return output_type
|
|
||||||
raise ValueError(f'invalid output type value {value}')
|
|
||||||
|
|
||||||
type: OutputType = OutputType.NONE
|
|
||||||
plain_text_selector: Optional[list[str]] = None
|
|
||||||
structured_variables: Optional[list[VariableSelector]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class EndNodeData(BaseNodeData):
|
class EndNodeData(BaseNodeData):
|
||||||
"""
|
"""
|
||||||
END Node Data.
|
END Node Data.
|
||||||
"""
|
"""
|
||||||
outputs: Optional[EndNodeDataOutputs] = None
|
outputs: list[VariableSelector]
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -41,6 +42,8 @@ node_classes = {
|
|||||||
NodeType.VARIABLE_ASSIGNER: VariableAssignerNode,
|
NodeType.VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowEngineManager:
|
class WorkflowEngineManager:
|
||||||
def get_default_configs(self) -> list[dict]:
|
def get_default_configs(self) -> list[dict]:
|
||||||
@ -407,6 +410,7 @@ class WorkflowEngineManager:
|
|||||||
variable_pool=workflow_run_state.variable_pool
|
variable_pool=workflow_run_state.variable_pool
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
|
||||||
node_run_result = NodeRunResult(
|
node_run_result = NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
error=str(e)
|
error=str(e)
|
||||||
|
|||||||
@ -531,10 +531,10 @@ class WorkflowConverter:
|
|||||||
"data": {
|
"data": {
|
||||||
"title": "END",
|
"title": "END",
|
||||||
"type": NodeType.END.value,
|
"type": NodeType.END.value,
|
||||||
"outputs": {
|
"outputs": [{
|
||||||
"variable": "result",
|
"variable": "result",
|
||||||
"value_selector": ["llm", "text"]
|
"value_selector": ["llm", "text"]
|
||||||
}
|
}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
56
api/tests/unit_tests/core/workflow/nodes/test_answer.py
Normal file
56
api/tests/unit_tests/core/workflow/nodes/test_answer.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from core.workflow.entities.node_entities import SystemVariable
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
|
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_answer():
|
||||||
|
node = AnswerNode(
|
||||||
|
tenant_id='1',
|
||||||
|
app_id='1',
|
||||||
|
workflow_id='1',
|
||||||
|
user_id='1',
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
config={
|
||||||
|
'id': 'answer',
|
||||||
|
'data': {
|
||||||
|
'title': '123',
|
||||||
|
'type': 'answer',
|
||||||
|
'variables': [
|
||||||
|
{
|
||||||
|
'value_selector': ['llm', 'text'],
|
||||||
|
'variable': 'text'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'value_selector': ['start', 'weather'],
|
||||||
|
'variable': 'weather'
|
||||||
|
},
|
||||||
|
],
|
||||||
|
'answer': 'Today\'s weather is {{weather}}\n{{text}}\n{{img}}\nFin.'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# construct variable pool
|
||||||
|
pool = VariablePool(system_variables={
|
||||||
|
SystemVariable.FILES: [],
|
||||||
|
}, user_inputs={})
|
||||||
|
pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny')
|
||||||
|
pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.')
|
||||||
|
|
||||||
|
# Mock db.session.close()
|
||||||
|
db.session.close = MagicMock()
|
||||||
|
|
||||||
|
# execute node
|
||||||
|
result = node._run(pool)
|
||||||
|
|
||||||
|
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
|
assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
|
||||||
|
|
||||||
|
|
||||||
|
# TODO test files
|
||||||
Loading…
Reference in New Issue
Block a user