Mirror of https://github.com/langgenius/dify.git, synced 2026-05-09 12:59:18 +08:00
Adapt the backend Graphon integration to the v0.3.0 breaking changes. Migrate provider factory and runtime usage, switch workflow node construction to the new data payload API, and refresh backend tests for the updated VariablePool and node behaviors.
319 lines
12 KiB
Python
import json
import time
import uuid
from collections.abc import Generator
from unittest.mock import MagicMock, patch

from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_manager import ModelInstance
from core.workflow.system_variables import build_system_variables
from extensions.ext_database import db
from graphon.enums import WorkflowNodeExecutionStatus
from graphon.node_events import StreamCompletedEvent
from graphon.nodes.llm.entities import LLMNodeData
from graphon.nodes.llm.file_saver import LLMFileSaver
from graphon.nodes.llm.node import LLMNode
from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory
from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol
from graphon.nodes.protocols import HttpClientProtocol
from graphon.runtime import GraphRuntimeState, VariablePool
from tests.workflow_test_utils import build_test_graph_init_params

"""FOR MOCK FIXTURES, DO NOT REMOVE"""


def init_llm_node(config: dict) -> LLMNode:
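    # Minimal two-node graph: a bare "start" node wired into the LLM node under test.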
    graph_config = {
        "edges": [
            {
                "id": "start-source-next-target",
                "source": "start",
                "target": "llm",
            },
        ],
        "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
    }

    # Use proper UUIDs for database compatibility
    tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
    app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
    workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
    user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"

    init_params = build_test_graph_init_params(
        workflow_id=workflow_id,
        graph_config=graph_config,
        tenant_id=tenant_id,
        app_id=app_id,
        user_id=user_id,
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    variable_pool = VariablePool(
        system_variables=build_system_variables(
            user_id="aaa",
            app_id=app_id,
            workflow_id=workflow_id,
            files=[],
            query="what's the weather today?",
            conversation_id="abababa",
        ),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
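    # Seed the upstream output that the prompt templates reference via {{#abc.output#}}.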
    variable_pool.add(["abc", "output"], "sunny")

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
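    # Serializer mock mirrors the keyword-only serialize() signature, dumping each
    # prompt message to a JSON-safe dict.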
    prompt_message_serializer = MagicMock(spec=PromptMessageSerializerProtocol)
    prompt_message_serializer.serialize.side_effect = lambda *, model_mode, prompt_messages: [
        message.model_dump(mode="json") for message in prompt_messages
    ]
    llm_file_saver = MagicMock(spec=LLMFileSaver)

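    # v0.3.0 construction: the node takes a validated LLMNodeData payload plus injected protocol mocks.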
    node = LLMNode(
        node_id=str(uuid.uuid4()),
        data=LLMNodeData.model_validate(config["data"]),
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
        credentials_provider=MagicMock(spec=CredentialsProvider),
        model_factory=MagicMock(spec=ModelFactory),
        model_instance=MagicMock(spec=ModelInstance),
        llm_file_saver=llm_file_saver,
        prompt_message_serializer=prompt_message_serializer,
        http_client=MagicMock(spec=HttpClientProtocol),
    )

    return node


def test_execute_llm():
    node = init_llm_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "llm",
                "model": {
                    "provider": "openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {},
                },
                "prompt_template": [
                    {
                        "role": "system",
                        "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
                    },
                    {"role": "user", "text": "{{#sys.query#}}"},
                ],
                "memory": None,
                "context": {"enabled": False},
                "vision": {"enabled": False},
            },
        },
    )

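    # Stub out db.session.close() so the run never closes a real database session.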
    db.session.close = MagicMock()

    def build_mock_model_instance() -> MagicMock:
        from decimal import Decimal
        from unittest.mock import MagicMock

        from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
        from graphon.model_runtime.entities.message_entities import AssistantPromptMessage

        # Create mock model instance
        mock_model_instance = MagicMock(spec=ModelInstance)
        mock_model_instance.provider = "openai"
        mock_model_instance.model_name = "gpt-3.5-turbo"
        mock_model_instance.credentials = {}
        mock_model_instance.parameters = {}
        mock_model_instance.stop = []
        mock_model_instance.model_type_instance = MagicMock()
        mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
            model_properties={},
            parameter_rules=[],
            features=[],
        )
        mock_model_instance.provider_model_bundle = MagicMock()
        mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
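        # Usage numbers are arbitrary but self-consistent: 30 + 20 = 50 tokens,
        # 0.00003 + 0.00004 = 0.00007 USD.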
        mock_usage = LLMUsage(
            prompt_tokens=30,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal(1000),
            prompt_price=Decimal("0.00003"),
            completion_tokens=20,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal(1000),
            completion_price=Decimal("0.00004"),
            total_tokens=50,
            total_price=Decimal("0.00007"),
            currency="USD",
            latency=0.5,
        )
        mock_message = AssistantPromptMessage(content="Test response from mock")
        mock_llm_result = LLMResult(
            model="gpt-3.5-turbo",
            prompt_messages=[],
            message=mock_message,
            usage=mock_usage,
        )
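        # A plain LLMResult (rather than a generator) models a non-streaming completion.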
        mock_model_instance.invoke_llm.return_value = mock_llm_result

        return mock_model_instance

    # Mock fetch_prompt_messages to avoid database calls
    def mock_fetch_prompt_messages_1(**_kwargs):
        from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage

        return [
            SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
            UserPromptMessage(content="what's the weather today?"),
        ], []

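    # Replace the constructor's placeholder model instance with the fully configured mock.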
    node._model_instance = build_mock_model_instance()

    with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1):
        # execute node
        result = node._run()
        assert isinstance(result, Generator)

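        # The run yields streaming events; StreamCompletedEvent carries the final node_run_result.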
        for item in result:
            if isinstance(item, StreamCompletedEvent):
                if item.node_run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
                    print(f"Error: {item.node_run_result.error}")
                    print(f"Error type: {item.node_run_result.error_type}")
                assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
                assert item.node_run_result.process_data is not None
                assert item.node_run_result.outputs is not None
                assert item.node_run_result.outputs.get("text") is not None
                assert item.node_run_result.outputs.get("usage", {})["total_tokens"] > 0


def test_execute_llm_with_jinja2():
    """Test executing the LLM node with jinja2-rendered prompt templates."""
    node = init_llm_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "llm",
                "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
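                # jinja2_variables bind template variable names to variable-pool selectors.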
"prompt_config": {
|
|
"jinja2_variables": [
|
|
{"variable": "sys_query", "value_selector": ["sys", "query"]},
|
|
{"variable": "output", "value_selector": ["abc", "output"]},
|
|
]
|
|
},
|
|
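                # edition_type selects the template per message: "jinja2" renders
                # jinja2_text, while "basic" uses the plain text field.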
"prompt_template": [
|
|
{
|
|
"role": "system",
|
|
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
|
|
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
|
|
"edition_type": "jinja2",
|
|
},
|
|
{
|
|
"role": "user",
|
|
"text": "{{#sys.query#}}",
|
|
"jinja2_text": "{{sys_query}}",
|
|
"edition_type": "basic",
|
|
},
|
|
],
|
|
"memory": None,
|
|
"context": {"enabled": False},
|
|
"vision": {"enabled": False},
|
|
},
|
|
},
|
|
)
|
|
|
|
    # Mock db.session.close()
    db.session.close = MagicMock()

    def build_mock_model_instance() -> MagicMock:
        from decimal import Decimal
        from unittest.mock import MagicMock

        from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
        from graphon.model_runtime.entities.message_entities import AssistantPromptMessage

        # Create mock model instance
        mock_model_instance = MagicMock(spec=ModelInstance)
        mock_model_instance.provider = "openai"
        mock_model_instance.model_name = "gpt-3.5-turbo"
        mock_model_instance.credentials = {}
        mock_model_instance.parameters = {}
        mock_model_instance.stop = []
        mock_model_instance.model_type_instance = MagicMock()
        mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
            model_properties={},
            parameter_rules=[],
            features=[],
        )
        mock_model_instance.provider_model_bundle = MagicMock()
        mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
        mock_usage = LLMUsage(
            prompt_tokens=30,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal(1000),
            prompt_price=Decimal("0.00003"),
            completion_tokens=20,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal(1000),
            completion_price=Decimal("0.00004"),
            total_tokens=50,
            total_price=Decimal("0.00007"),
            currency="USD",
            latency=0.5,
        )
        mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
        mock_llm_result = LLMResult(
            model="gpt-3.5-turbo",
            prompt_messages=[],
            message=mock_message,
            usage=mock_usage,
        )
        mock_model_instance.invoke_llm.return_value = mock_llm_result

        return mock_model_instance

    # Mock fetch_prompt_messages to avoid database calls
    def mock_fetch_prompt_messages_2(**_kwargs):
        from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage

        return [
            SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
            UserPromptMessage(content="what's the weather today?"),
        ], []

    node._model_instance = build_mock_model_instance()

    with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2):
        # execute node
        result = node._run()

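        # Both jinja2-rendered values should surface in the serialized process_data.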
        for item in result:
            if isinstance(item, StreamCompletedEvent):
                assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
                assert item.node_run_result.process_data is not None
                assert "sunny" in json.dumps(item.node_run_result.process_data)
                assert "what's the weather today?" in json.dumps(item.node_run_result.process_data)


def test_extract_json():
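    # Representative raw LLM outputs that _parse_structured_output must normalize to the same dict.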
    llm_texts = [
        '<think>\n\n</think>{"name": "test", "age": 123',  # reasoning model (deepseek-r1)
        '{"name":"test","age":123}',  # json schema model (gpt-4o)
        '{\n "name": "test",\n "age": 123\n}',  # small model (llama-3.2-1b)
        '```json\n{"name": "test", "age": 123}\n```',  # json markdown (deepseek-chat)
        '{"name":"test",age:123}',  # without quotes (qwen-2.5-0.5b)
    ]
    result = {"name": "test", "age": 123}
    assert all(_parse_structured_output(item) == result for item in llm_texts)