mirror of https://github.com/langgenius/dify.git
add answer output parse
This commit is contained in:
parent
5a67c09b48
commit
44c4d5be72
|
|
@ -5,7 +5,6 @@ from core.app.entities.queue_entities import (
|
|||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
|
|
@ -20,7 +19,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
|
||||
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
|
|
@ -118,31 +116,4 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
|||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
if node_id in self._streamable_node_ids:
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
|
||||
"""
|
||||
Fetch streamable node ids
|
||||
When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
|
||||
When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
|
||||
|
||||
:param graph: workflow graph
|
||||
:return:
|
||||
"""
|
||||
streamable_node_ids = []
|
||||
end_node_ids = []
|
||||
for node_config in graph.get('nodes'):
|
||||
if node_config.get('data', {}).get('type') == NodeType.END.value:
|
||||
if node_config.get('data', {}).get('outputs', {}).get('type', '') == 'plain-text':
|
||||
end_node_ids.append(node_config.get('id'))
|
||||
|
||||
for edge_config in graph.get('edges'):
|
||||
if edge_config.get('target') in end_node_ids:
|
||||
streamable_node_ids.append(edge_config.get('source'))
|
||||
|
||||
return streamable_node_ids
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import time
|
||||
from typing import cast
|
||||
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
|
|
@ -32,14 +31,49 @@ class AnswerNode(BaseNode):
|
|||
|
||||
variable_values[variable_selector.variable] = value
|
||||
|
||||
variable_keys = list(variable_values.keys())
|
||||
|
||||
# format answer template
|
||||
template_parser = PromptTemplateParser(node_data.answer)
|
||||
answer = template_parser.format(variable_values)
|
||||
template_variable_keys = template_parser.variable_keys
|
||||
|
||||
# publish answer as stream
|
||||
for word in answer:
|
||||
self.publish_text_chunk(word)
|
||||
time.sleep(10) # TODO for debug
|
||||
# Take the intersection of variable_keys and template_variable_keys
|
||||
variable_keys = list(set(variable_keys) & set(template_variable_keys))
|
||||
|
||||
template = node_data.answer
|
||||
for var in variable_keys:
|
||||
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
||||
|
||||
split_template = [
|
||||
{
|
||||
"type": "var" if self._is_variable(part, variable_keys) else "text",
|
||||
"value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
|
||||
}
|
||||
for part in template.split('Ω') if part
|
||||
]
|
||||
|
||||
answer = []
|
||||
for part in split_template:
|
||||
if part["type"] == "var":
|
||||
value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
|
||||
answer_part = {
|
||||
"type": "text",
|
||||
"text": value
|
||||
}
|
||||
# TODO File
|
||||
else:
|
||||
answer_part = {
|
||||
"type": "text",
|
||||
"text": part["value"]
|
||||
}
|
||||
|
||||
if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
|
||||
answer[-1]["text"] += answer_part["text"]
|
||||
else:
|
||||
answer.append(answer_part)
|
||||
|
||||
if len(answer) == 1 and answer[0]["type"] == "text":
|
||||
answer = answer[0]["text"]
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
|
@ -49,6 +83,10 @@ class AnswerNode(BaseNode):
|
|||
}
|
||||
)
|
||||
|
||||
def _is_variable(self, part, variable_keys):
|
||||
cleaned_part = part.replace('{{', '').replace('}}', '')
|
||||
return part.startswith('{{') and cleaned_part in variable_keys
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
|||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class UserFrom(Enum):
|
||||
|
|
@ -80,16 +79,9 @@ class BaseNode(ABC):
|
|||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
result = self._run(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
except Exception as e:
|
||||
# process unhandled exception
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e)
|
||||
)
|
||||
result = self._run(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
|
||||
self.node_run_result = result
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ from typing import cast
|
|||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import ValueType, VariablePool
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData, EndNodeDataOutputs
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
|
|
@ -20,34 +20,14 @@ class EndNode(BaseNode):
|
|||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
outputs_config = node_data.outputs
|
||||
output_variables = node_data.outputs
|
||||
|
||||
outputs = None
|
||||
if outputs_config:
|
||||
if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
|
||||
plain_text_selector = outputs_config.plain_text_selector
|
||||
if plain_text_selector:
|
||||
outputs = {
|
||||
'text': variable_pool.get_variable_value(
|
||||
variable_selector=plain_text_selector,
|
||||
target_value_type=ValueType.STRING
|
||||
)
|
||||
}
|
||||
else:
|
||||
outputs = {
|
||||
'text': ''
|
||||
}
|
||||
elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
|
||||
structured_variables = outputs_config.structured_variables
|
||||
if structured_variables:
|
||||
outputs = {}
|
||||
for variable_selector in structured_variables:
|
||||
variable_value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
outputs[variable_selector.variable] = variable_value
|
||||
else:
|
||||
outputs = {}
|
||||
outputs = {}
|
||||
for variable_selector in output_variables:
|
||||
variable_value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
)
|
||||
outputs[variable_selector.variable] = variable_value
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
|
|
|
|||
|
|
@ -1,68 +1,9 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
|
||||
class EndNodeOutputType(Enum):
|
||||
"""
|
||||
END Node Output Types.
|
||||
|
||||
none, plain-text, structured
|
||||
"""
|
||||
NONE = 'none'
|
||||
PLAIN_TEXT = 'plain-text'
|
||||
STRUCTURED = 'structured'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'OutputType':
|
||||
"""
|
||||
Get value of given output type.
|
||||
|
||||
:param value: output type value
|
||||
:return: output type
|
||||
"""
|
||||
for output_type in cls:
|
||||
if output_type.value == value:
|
||||
return output_type
|
||||
raise ValueError(f'invalid output type value {value}')
|
||||
|
||||
|
||||
class EndNodeDataOutputs(BaseModel):
|
||||
"""
|
||||
END Node Data Outputs.
|
||||
"""
|
||||
class OutputType(Enum):
|
||||
"""
|
||||
Output Types.
|
||||
"""
|
||||
NONE = 'none'
|
||||
PLAIN_TEXT = 'plain-text'
|
||||
STRUCTURED = 'structured'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'OutputType':
|
||||
"""
|
||||
Get value of given output type.
|
||||
|
||||
:param value: output type value
|
||||
:return: output type
|
||||
"""
|
||||
for output_type in cls:
|
||||
if output_type.value == value:
|
||||
return output_type
|
||||
raise ValueError(f'invalid output type value {value}')
|
||||
|
||||
type: OutputType = OutputType.NONE
|
||||
plain_text_selector: Optional[list[str]] = None
|
||||
structured_variables: Optional[list[VariableSelector]] = None
|
||||
|
||||
|
||||
class EndNodeData(BaseNodeData):
|
||||
"""
|
||||
END Node Data.
|
||||
"""
|
||||
outputs: Optional[EndNodeDataOutputs] = None
|
||||
outputs: list[VariableSelector]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -41,6 +42,8 @@ node_classes = {
|
|||
NodeType.VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowEngineManager:
|
||||
def get_default_configs(self) -> list[dict]:
|
||||
|
|
@ -407,6 +410,7 @@ class WorkflowEngineManager:
|
|||
variable_pool=workflow_run_state.variable_pool
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
|
||||
node_run_result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e)
|
||||
|
|
|
|||
|
|
@ -531,10 +531,10 @@ class WorkflowConverter:
|
|||
"data": {
|
||||
"title": "END",
|
||||
"type": NodeType.END.value,
|
||||
"outputs": {
|
||||
"outputs": [{
|
||||
"variable": "result",
|
||||
"value_selector": ["llm", "text"]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,56 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
def test_execute_answer():
|
||||
node = AnswerNode(
|
||||
tenant_id='1',
|
||||
app_id='1',
|
||||
workflow_id='1',
|
||||
user_id='1',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
config={
|
||||
'id': 'answer',
|
||||
'data': {
|
||||
'title': '123',
|
||||
'type': 'answer',
|
||||
'variables': [
|
||||
{
|
||||
'value_selector': ['llm', 'text'],
|
||||
'variable': 'text'
|
||||
},
|
||||
{
|
||||
'value_selector': ['start', 'weather'],
|
||||
'variable': 'weather'
|
||||
},
|
||||
],
|
||||
'answer': 'Today\'s weather is {{weather}}\n{{text}}\n{{img}}\nFin.'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
}, user_inputs={})
|
||||
pool.append_variable(node_id='start', variable_key_list=['weather'], value='sunny')
|
||||
pool.append_variable(node_id='llm', variable_key_list=['text'], value='You are a helpful AI.')
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run(pool)
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
|
||||
|
||||
|
||||
# TODO test files
|
||||
Loading…
Reference in New Issue