mirror of
https://github.com/langgenius/dify.git
synced 2026-05-13 08:57:28 +08:00
chore: fix the llm node memory issue
This commit is contained in:
parent
8154d0af53
commit
27de07e93d
@ -18,13 +18,14 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContentUnionTypes,
|
PromptMessageContentUnionTypes,
|
||||||
PromptMessageRole,
|
PromptMessageRole,
|
||||||
|
ToolPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
|
||||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.llm.entities import ModelConfig
|
from core.workflow.nodes.llm.entities import LLMGenerationData, ModelConfig
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
@ -214,21 +215,87 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
|
|||||||
def build_context(
|
def build_context(
|
||||||
prompt_messages: Sequence[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
assistant_response: str,
|
assistant_response: str,
|
||||||
|
generation_data: LLMGenerationData | None = None,
|
||||||
) -> list[PromptMessage]:
|
) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
Build context from prompt messages and assistant response.
|
Build context from prompt messages and assistant response.
|
||||||
Excludes system messages and includes the current LLM response.
|
Excludes system messages and includes the current LLM response.
|
||||||
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
|
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
|
||||||
|
|
||||||
|
For tool-enabled runs, reconstructs the full conversation including tool calls and results.
|
||||||
Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
|
Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_messages: Initial prompt messages (user query, etc.)
|
||||||
|
assistant_response: Final assistant response text
|
||||||
|
generation_data: Optional generation data containing trace for tool-enabled runs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
context_messages: list[PromptMessage] = [
|
context_messages: list[PromptMessage] = [
|
||||||
_truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
|
_truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
|
||||||
]
|
]
|
||||||
context_messages.append(AssistantPromptMessage(content=assistant_response))
|
|
||||||
|
# For tool-enabled runs, reconstruct messages from trace
|
||||||
|
if generation_data and generation_data.trace:
|
||||||
|
context_messages.extend(_build_messages_from_trace(generation_data, assistant_response))
|
||||||
|
else:
|
||||||
|
context_messages.append(AssistantPromptMessage(content=assistant_response))
|
||||||
|
|
||||||
return context_messages
|
return context_messages
|
||||||
|
|
||||||
|
|
||||||
|
def _build_messages_from_trace(
|
||||||
|
generation_data: LLMGenerationData,
|
||||||
|
assistant_response: str,
|
||||||
|
) -> list[PromptMessage]:
|
||||||
|
"""
|
||||||
|
Build assistant and tool messages from trace segments.
|
||||||
|
|
||||||
|
Processes trace in order to reconstruct the conversation flow:
|
||||||
|
- Model segments with tool_calls -> AssistantPromptMessage with tool_calls
|
||||||
|
- Tool segments -> ToolPromptMessage with result
|
||||||
|
- Final response -> AssistantPromptMessage with assistant_response
|
||||||
|
"""
|
||||||
|
from core.workflow.nodes.llm.entities import ModelTraceSegment, ToolTraceSegment
|
||||||
|
|
||||||
|
messages: list[PromptMessage] = []
|
||||||
|
|
||||||
|
for segment in generation_data.trace:
|
||||||
|
if segment.type == "model" and isinstance(segment.output, ModelTraceSegment):
|
||||||
|
model_output = segment.output
|
||||||
|
segment_content = model_output.text or ""
|
||||||
|
|
||||||
|
if model_output.tool_calls:
|
||||||
|
# Build tool_calls for AssistantPromptMessage
|
||||||
|
tool_calls = [
|
||||||
|
AssistantPromptMessage.ToolCall(
|
||||||
|
id=tc.id or "",
|
||||||
|
type="function",
|
||||||
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||||
|
name=tc.name or "",
|
||||||
|
arguments=tc.arguments or "",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for tc in model_output.tool_calls
|
||||||
|
]
|
||||||
|
messages.append(AssistantPromptMessage(content=segment_content, tool_calls=tool_calls))
|
||||||
|
|
||||||
|
elif segment.type == "tool" and isinstance(segment.output, ToolTraceSegment):
|
||||||
|
tool_output = segment.output
|
||||||
|
messages.append(
|
||||||
|
ToolPromptMessage(
|
||||||
|
content=tool_output.output or "",
|
||||||
|
tool_call_id=tool_output.id or "",
|
||||||
|
name=tool_output.name or "",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add final assistant response as the authoritative text
|
||||||
|
messages.append(AssistantPromptMessage(content=assistant_response))
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
|
def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
|
||||||
"""
|
"""
|
||||||
Truncate multi-modal content base64 data in a message to avoid storing large data.
|
Truncate multi-modal content base64 data in a message to avoid storing large data.
|
||||||
|
|||||||
@ -373,7 +373,7 @@ class LLMNode(Node[LLMNodeData]):
|
|||||||
"reasoning_content": reasoning_content,
|
"reasoning_content": reasoning_content,
|
||||||
"usage": jsonable_encoder(usage),
|
"usage": jsonable_encoder(usage),
|
||||||
"finish_reason": finish_reason,
|
"finish_reason": finish_reason,
|
||||||
"context": llm_utils.build_context(prompt_messages, clean_text),
|
"context": llm_utils.build_context(prompt_messages, clean_text, generation_data),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Build generation field
|
# Build generation field
|
||||||
|
|||||||
@ -171,9 +171,7 @@ class TestSandboxLayer:
|
|||||||
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
layer.on_event(GraphRunSucceededEvent(outputs={}))
|
||||||
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
|
layer.on_event(GraphRunFailedEvent(error="test error", exceptions_count=1))
|
||||||
|
|
||||||
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(
|
def test_on_graph_end_releases_sandbox_and_unregisters_from_manager(self, mock_sandbox_storage: MagicMock) -> None:
|
||||||
self, mock_sandbox_storage: MagicMock
|
|
||||||
) -> None:
|
|
||||||
sandbox_id = "test-exec-456"
|
sandbox_id = "test-exec-456"
|
||||||
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
layer = create_layer(sandbox_id=sandbox_id, sandbox_storage=mock_sandbox_storage)
|
||||||
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
mock_sandbox = MagicMock(spec=VirtualEnvironment)
|
||||||
|
|||||||
@ -123,6 +123,195 @@ class TestBuildContext:
|
|||||||
assert len(context) == 2
|
assert len(context) == 2
|
||||||
assert context[1].content == "The answer is 4."
|
assert context[1].content == "The answer is 4."
|
||||||
|
|
||||||
|
def test_builds_context_with_tool_calls_from_generation_data(self):
|
||||||
|
"""Should reconstruct full conversation including tool calls when generation_data is provided."""
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ToolPromptMessage,
|
||||||
|
)
|
||||||
|
from core.workflow.nodes.llm.entities import (
|
||||||
|
LLMGenerationData,
|
||||||
|
LLMTraceSegment,
|
||||||
|
ModelTraceSegment,
|
||||||
|
ToolCall,
|
||||||
|
ToolTraceSegment,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [UserPromptMessage(content="What's the weather in Beijing?")]
|
||||||
|
|
||||||
|
# Create trace with tool call and result
|
||||||
|
generation_data = LLMGenerationData(
|
||||||
|
text="The weather in Beijing is sunny, 25°C.",
|
||||||
|
reasoning_contents=[],
|
||||||
|
tool_calls=[],
|
||||||
|
sequence=[],
|
||||||
|
usage=LLMUsage.empty_usage(),
|
||||||
|
finish_reason="stop",
|
||||||
|
files=[],
|
||||||
|
trace=[
|
||||||
|
LLMTraceSegment(
|
||||||
|
type="model",
|
||||||
|
duration=0.5,
|
||||||
|
usage=None,
|
||||||
|
output=ModelTraceSegment(
|
||||||
|
text="Let me check the weather.",
|
||||||
|
reasoning=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
id="call_123",
|
||||||
|
name="get_weather",
|
||||||
|
arguments='{"city": "Beijing"}',
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
LLMTraceSegment(
|
||||||
|
type="tool",
|
||||||
|
duration=0.3,
|
||||||
|
usage=None,
|
||||||
|
output=ToolTraceSegment(
|
||||||
|
id="call_123",
|
||||||
|
name="get_weather",
|
||||||
|
arguments='{"city": "Beijing"}',
|
||||||
|
output="Sunny, 25°C",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
context = build_context(messages, "The weather in Beijing is sunny, 25°C.", generation_data)
|
||||||
|
|
||||||
|
# Should have: user message + assistant with tool_call + tool result + final assistant
|
||||||
|
assert len(context) == 4
|
||||||
|
assert context[0].content == "What's the weather in Beijing?"
|
||||||
|
assert isinstance(context[1], AssistantPromptMessage)
|
||||||
|
assert context[1].content == "Let me check the weather."
|
||||||
|
assert len(context[1].tool_calls) == 1
|
||||||
|
assert context[1].tool_calls[0].id == "call_123"
|
||||||
|
assert context[1].tool_calls[0].function.name == "get_weather"
|
||||||
|
assert isinstance(context[2], ToolPromptMessage)
|
||||||
|
assert context[2].content == "Sunny, 25°C"
|
||||||
|
assert context[2].tool_call_id == "call_123"
|
||||||
|
assert isinstance(context[3], AssistantPromptMessage)
|
||||||
|
assert context[3].content == "The weather in Beijing is sunny, 25°C."
|
||||||
|
|
||||||
|
def test_builds_context_with_multiple_tool_calls(self):
|
||||||
|
"""Should handle multiple tool calls in a single conversation."""
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ToolPromptMessage,
|
||||||
|
)
|
||||||
|
from core.workflow.nodes.llm.entities import (
|
||||||
|
LLMGenerationData,
|
||||||
|
LLMTraceSegment,
|
||||||
|
ModelTraceSegment,
|
||||||
|
ToolCall,
|
||||||
|
ToolTraceSegment,
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [UserPromptMessage(content="Compare weather in Beijing and Shanghai")]
|
||||||
|
|
||||||
|
generation_data = LLMGenerationData(
|
||||||
|
text="Beijing is sunny at 25°C, Shanghai is cloudy at 22°C.",
|
||||||
|
reasoning_contents=[],
|
||||||
|
tool_calls=[],
|
||||||
|
sequence=[],
|
||||||
|
usage=LLMUsage.empty_usage(),
|
||||||
|
finish_reason="stop",
|
||||||
|
files=[],
|
||||||
|
trace=[
|
||||||
|
# First model call with two tool calls
|
||||||
|
LLMTraceSegment(
|
||||||
|
type="model",
|
||||||
|
duration=0.5,
|
||||||
|
usage=None,
|
||||||
|
output=ModelTraceSegment(
|
||||||
|
text="I'll check both cities.",
|
||||||
|
reasoning=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(id="call_1", name="get_weather", arguments='{"city": "Beijing"}'),
|
||||||
|
ToolCall(id="call_2", name="get_weather", arguments='{"city": "Shanghai"}'),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
# First tool result
|
||||||
|
LLMTraceSegment(
|
||||||
|
type="tool",
|
||||||
|
duration=0.2,
|
||||||
|
usage=None,
|
||||||
|
output=ToolTraceSegment(
|
||||||
|
id="call_1",
|
||||||
|
name="get_weather",
|
||||||
|
arguments='{"city": "Beijing"}',
|
||||||
|
output="Sunny, 25°C",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
# Second tool result
|
||||||
|
LLMTraceSegment(
|
||||||
|
type="tool",
|
||||||
|
duration=0.2,
|
||||||
|
usage=None,
|
||||||
|
output=ToolTraceSegment(
|
||||||
|
id="call_2",
|
||||||
|
name="get_weather",
|
||||||
|
arguments='{"city": "Shanghai"}',
|
||||||
|
output="Cloudy, 22°C",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
context = build_context(messages, "Beijing is sunny at 25°C, Shanghai is cloudy at 22°C.", generation_data)
|
||||||
|
|
||||||
|
# Should have: user + assistant with 2 tool_calls + 2 tool results + final assistant
|
||||||
|
assert len(context) == 5
|
||||||
|
assert context[0].content == "Compare weather in Beijing and Shanghai"
|
||||||
|
assert isinstance(context[1], AssistantPromptMessage)
|
||||||
|
assert len(context[1].tool_calls) == 2
|
||||||
|
assert isinstance(context[2], ToolPromptMessage)
|
||||||
|
assert context[2].content == "Sunny, 25°C"
|
||||||
|
assert isinstance(context[3], ToolPromptMessage)
|
||||||
|
assert context[3].content == "Cloudy, 22°C"
|
||||||
|
assert isinstance(context[4], AssistantPromptMessage)
|
||||||
|
assert context[4].content == "Beijing is sunny at 25°C, Shanghai is cloudy at 22°C."
|
||||||
|
|
||||||
|
def test_builds_context_without_generation_data(self):
|
||||||
|
"""Should fallback to simple context when no generation_data is provided."""
|
||||||
|
messages = [UserPromptMessage(content="Hello!")]
|
||||||
|
|
||||||
|
context = build_context(messages, "Hi there!", generation_data=None)
|
||||||
|
|
||||||
|
assert len(context) == 2
|
||||||
|
assert context[0].content == "Hello!"
|
||||||
|
assert context[1].content == "Hi there!"
|
||||||
|
|
||||||
|
def test_builds_context_with_empty_trace(self):
|
||||||
|
"""Should fallback to simple context when trace is empty."""
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.workflow.nodes.llm.entities import LLMGenerationData
|
||||||
|
|
||||||
|
messages = [UserPromptMessage(content="Hello!")]
|
||||||
|
|
||||||
|
generation_data = LLMGenerationData(
|
||||||
|
text="Hi there!",
|
||||||
|
reasoning_contents=[],
|
||||||
|
tool_calls=[],
|
||||||
|
sequence=[],
|
||||||
|
usage=LLMUsage.empty_usage(),
|
||||||
|
finish_reason="stop",
|
||||||
|
files=[],
|
||||||
|
trace=[], # Empty trace
|
||||||
|
)
|
||||||
|
|
||||||
|
context = build_context(messages, "Hi there!", generation_data)
|
||||||
|
|
||||||
|
# Should fallback to simple context
|
||||||
|
assert len(context) == 2
|
||||||
|
assert context[0].content == "Hello!"
|
||||||
|
assert context[1].content == "Hi there!"
|
||||||
|
|
||||||
|
|
||||||
class TestRestoreMultimodalContentInMessages:
|
class TestRestoreMultimodalContentInMessages:
|
||||||
"""Tests for restore_multimodal_content_in_messages function."""
|
"""Tests for restore_multimodal_content_in_messages function."""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user