# dify/api/tests/integration_tests/workflow/nodes/test_llm.py

import json
import time
import uuid
from collections.abc import Generator
from unittest.mock import MagicMock, patch

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_manager import ModelInstance
from core.workflow.system_variables import build_system_variables
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import StreamCompletedEvent
from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol
from graphon.nodes.protocols import HttpClientProtocol
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
def init_llm_node(config: dict) -> LLMNode:
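    """Build an LLMNode wired into a minimal start -> llm graph, with all collaborators mocked."""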
graph_config = {
"edges": [
{
"id": "start-source-next-target",
"source": "start",
"target": "llm",
},
],
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
}
# Use proper UUIDs for database compatibility
tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
init_params = build_test_graph_init_params(
workflow_id=workflow_id,
graph_config=graph_config,
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
variable_pool = VariablePool(
system_variables=build_system_variables(
user_id="aaa",
app_id=app_id,
workflow_id=workflow_id,
files=[],
query="what's the weather today?",
conversation_id="abababa",
),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
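    # Seed a fake upstream node output so the {{#abc.output#}} selector in the prompt templates resolves to "sunny".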
variable_pool.add(["abc", "output"], "sunny")
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
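    # Stand-in serializer: dump each prompt message to a plain JSON dict instead of invoking the real protocol.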
prompt_message_serializer = MagicMock(spec=PromptMessageSerializerProtocol)
prompt_message_serializer.serialize.side_effect = lambda *, model_mode, prompt_messages: [
message.model_dump(mode="json") for message in prompt_messages
]
llm_file_saver = MagicMock(spec=LLMFileSaver)
node = LLMNode(
node_id=str(uuid.uuid4()),
data=LLMNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
llm_file_saver=llm_file_saver,
prompt_message_serializer=prompt_message_serializer,
http_client=MagicMock(spec=HttpClientProtocol),
)
return node


def test_execute_llm():
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
},
)
    # Mock db.session.close() so the run cannot close a real database session
    db.session.close = MagicMock()

def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
from unittest.mock import MagicMock
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock(spec=ModelInstance)
mock_model_instance.provider = "openai"
mock_model_instance.model_name = "gpt-3.5-turbo"
mock_model_instance.credentials = {}
mock_model_instance.parameters = {}
mock_model_instance.stop = []
mock_model_instance.model_type_instance = MagicMock()
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
model_properties={},
parameter_rules=[],
features=[],
)
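        # Minimal provider bundle; "custom" marks a user-configured provider (as opposed to a system-managed one).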
mock_model_instance.provider_model_bundle = MagicMock()
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response from mock")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
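        # Return shape mirrors fetch_prompt_messages: (prompt_messages, stop_sequences); no stop sequences needed here.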
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
node._model_instance = build_mock_model_instance()
with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1):
# execute node
result = node._run()
assert isinstance(result, Generator)
for item in result:
if isinstance(item, StreamCompletedEvent):
if item.node_run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
print(f"Error: {item.node_run_result.error}")
print(f"Error type: {item.node_run_result.error_type}")
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.process_data is not None
assert item.node_run_result.outputs is not None
assert item.node_run_result.outputs.get("text") is not None
assert item.node_run_result.outputs.get("usage", {})["total_tokens"] > 0


def test_execute_llm_with_jinja2():
"""
Test execute LLM node with jinja2
"""
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_config": {
"jinja2_variables": [
{"variable": "sys_query", "value_selector": ["sys", "query"]},
{"variable": "output", "value_selector": ["abc", "output"]},
]
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
},
{
"role": "user",
"text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
},
)
    # Mock db.session.close() so the run cannot close a real database session
    db.session.close = MagicMock()

def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
from unittest.mock import MagicMock
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock(spec=ModelInstance)
mock_model_instance.provider = "openai"
mock_model_instance.model_name = "gpt-3.5-turbo"
mock_model_instance.credentials = {}
mock_model_instance.parameters = {}
mock_model_instance.stop = []
mock_model_instance.model_type_instance = MagicMock()
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
model_properties={},
parameter_rules=[],
features=[],
)
mock_model_instance.provider_model_bundle = MagicMock()
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_2(**_kwargs):
from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
node._model_instance = build_mock_model_instance()
with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2):
# execute node
result = node._run()
for item in result:
if isinstance(item, StreamCompletedEvent):
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.process_data is not None
assert "sunny" in json.dumps(item.node_run_result.process_data)
assert "what's the weather today?" in json.dumps(item.node_run_result.process_data)


def test_extract_json():
llm_texts = [
        '<think>\n\n</think>{"name": "test", "age": 123',  # reasoning model (deepseek-r1)
'{"name":"test","age":123}', # json schema model (gpt-4o)
'{\n "name": "test",\n "age": 123\n}', # small model (llama-3.2-1b)
'```json\n{"name": "test", "age": 123}\n```', # json markdown (deepseek-chat)
'{"name":"test",age:123}', # without quotes (qwen-2.5-0.5b)
]
result = {"name": "test", "age": 123}
assert all(_parse_structured_output(item) == result for item in llm_texts)