From 289b59208a5763a5355aec1042d5d0ff6955e8c2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Yanli=20=E7=9B=90=E7=B2=92?=
Date: Sun, 25 Jan 2026 23:00:13 +0800
Subject: [PATCH] make model_runtime support reading and parsing the
 opaque_body from plugin LLM calls (and fix tool call parsing in streaming
 mode)

---
 .../__base/large_language_model.py.md         |  20 +++
 .../__base/test_llm_invoke_opaque_body.py.md  |  12 ++
 .../__base/large_language_model.py            |  16 +-
 .../__base/test_llm_invoke_opaque_body.py     | 145 ++++++++++++++++++
 4 files changed, 192 insertions(+), 1 deletion(-)
 create mode 100644 api/agent-notes/core/model_runtime/model_providers/__base/large_language_model.py.md
 create mode 100644 api/agent-notes/tests/unit_tests/core/model_runtime/__base/test_llm_invoke_opaque_body.py.md
 create mode 100644 api/tests/unit_tests/core/model_runtime/__base/test_llm_invoke_opaque_body.py

diff --git a/api/agent-notes/core/model_runtime/model_providers/__base/large_language_model.py.md b/api/agent-notes/core/model_runtime/model_providers/__base/large_language_model.py.md
new file mode 100644
index 0000000000..1a05d4ae69
--- /dev/null
+++ b/api/agent-notes/core/model_runtime/model_providers/__base/large_language_model.py.md
@@ -0,0 +1,20 @@
+## Purpose
+
+`core/model_runtime/model_providers/__base/large_language_model.py` defines the base `LargeLanguageModel` interface used
+by model providers, including plugin-backed providers via `PluginModelClient`.
+
+## Plugin invocation flow
+
+- For plugin-based providers, `invoke()` delegates to `PluginModelClient.invoke_llm(...)`, which streams
+  `LLMResultChunk` objects from the plugin daemon.
+- Dify yields chunks to callers and also aggregates them to fire `after_invoke` callbacks (and to construct a
+  blocking `LLMResult` when `stream=False`).
+
+## Key invariants / edge cases
+
+- When aggregating chunks into an `LLMResult`, preserve provider-specific fields on the assistant message:
+  - `AssistantPromptMessage.opaque_body` (pass-through, uninterpreted JSON).
+  - Incremental `tool_calls` (merge deltas via `_increase_tool_call`).
+- Chunk `.prompt_messages` may be empty for plugin responses (compat layer for the plugin daemon); Dify re-attaches the
+  original request `prompt_messages` for downstream consumers.
+
diff --git a/api/agent-notes/tests/unit_tests/core/model_runtime/__base/test_llm_invoke_opaque_body.py.md b/api/agent-notes/tests/unit_tests/core/model_runtime/__base/test_llm_invoke_opaque_body.py.md
new file mode 100644
index 0000000000..b5ba0ac7a0
--- /dev/null
+++ b/api/agent-notes/tests/unit_tests/core/model_runtime/__base/test_llm_invoke_opaque_body.py.md
@@ -0,0 +1,12 @@
+## Purpose
+
+Unit tests for plugin-backed `LargeLanguageModel.invoke()` behavior around preserving provider pass-through data.
+
+## What it covers
+
+- `AssistantPromptMessage.opaque_body` from plugin `LLMResultChunk` deltas is preserved:
+  - On the returned `LLMResult` in blocking (`stream=False`) mode.
+  - On the aggregated `LLMResult` passed to `on_after_invoke` callbacks in streaming mode.
+- Streaming mode also verifies that empty `chunk.prompt_messages` are replaced with the original request prompt
+  messages.
+- Streaming aggregation merges incremental `tool_calls` across chunks (see the sketch below).
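+
+A minimal sketch of the aggregation rule these tests pin down, assuming simplified stand-in types (`_Delta`,
+`_aggregate`, and `tool_call_args` are illustrative, not the production API):
+
+```python
+from dataclasses import dataclass
+
+
+@dataclass
+class _Delta:
+    """Stand-in for the fields of LLMResultChunkDelta.message exercised here."""
+
+    content: str = ""
+    opaque_body: dict | None = None
+    tool_call_args: str = ""  # fragment of a single tool call's JSON arguments
+
+
+def _aggregate(deltas: list[_Delta]) -> tuple[str, dict | None, str]:
+    content, opaque_body, args = "", None, ""
+    for d in deltas:
+        content += d.content  # text deltas are concatenated
+        args += d.tool_call_args  # argument fragments merge into one tool call
+        if opaque_body is None and d.opaque_body is not None:
+            opaque_body = d.opaque_body  # first non-None opaque_body wins
+    return content, opaque_body, args
+
+
+assert _aggregate(
+    [
+        _Delta("h", {"provider_message_id": "msg_123"}, '{"arg1": '),
+        _Delta("i", None, '"value"}'),
+    ]
+) == ("hi", {"provider_message_id": "msg_123"}, '{"arg1": "value"}')
+```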
+ diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index c0f4c504d9..5c65b97cdb 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -164,6 +164,7 @@ class LargeLanguageModel(AIModel): usage = LLMUsage.empty_usage() system_fingerprint = None tools_calls: list[AssistantPromptMessage.ToolCall] = [] + assistant_opaque_body = None for chunk in result: if isinstance(chunk.delta.message.content, str): @@ -172,6 +173,8 @@ class LargeLanguageModel(AIModel): content_list.extend(chunk.delta.message.content) if chunk.delta.message.tool_calls: _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) + if assistant_opaque_body is None and chunk.delta.message.opaque_body is not None: + assistant_opaque_body = chunk.delta.message.opaque_body usage = chunk.delta.usage or LLMUsage.empty_usage() system_fingerprint = chunk.system_fingerprint @@ -183,6 +186,7 @@ class LargeLanguageModel(AIModel): message=AssistantPromptMessage( content=content or content_list, tool_calls=tools_calls, + opaque_body=assistant_opaque_body, ), usage=usage, system_fingerprint=system_fingerprint, @@ -261,6 +265,8 @@ class LargeLanguageModel(AIModel): usage = None system_fingerprint = None real_model = model + assistant_opaque_body = None + tools_calls: list[AssistantPromptMessage.ToolCall] = [] def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None): if not content: @@ -294,6 +300,10 @@ class LargeLanguageModel(AIModel): ) _update_message_content(chunk.delta.message.content) + if chunk.delta.message.tool_calls: + _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) + if assistant_opaque_body is None and chunk.delta.message.opaque_body is not None: + assistant_opaque_body = chunk.delta.message.opaque_body real_model = chunk.model if chunk.delta.usage: @@ -304,7 +314,11 @@ class LargeLanguageModel(AIModel): except Exception as e: raise self._transform_invoke_error(e) - assistant_message = AssistantPromptMessage(content=message_content) + assistant_message = AssistantPromptMessage( + content=message_content, + tool_calls=tools_calls, + opaque_body=assistant_opaque_body, + ) self._trigger_after_invoke_callbacks( model=model, result=LLMResult( diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_llm_invoke_opaque_body.py b/api/tests/unit_tests/core/model_runtime/__base/test_llm_invoke_opaque_body.py new file mode 100644 index 0000000000..4b593ceddd --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/__base/test_llm_invoke_opaque_body.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any +from unittest.mock import patch + +from core.model_runtime.callbacks.base_callback import Callback +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity + + +class 
_CaptureAfterInvokeCallback(Callback): + after_result: LLMResult | None + + def __init__(self) -> None: + self.after_result = None + + def on_before_invoke(self, **kwargs: Any) -> None: # noqa: ANN401 + return None + + def on_new_chunk(self, **kwargs: Any) -> None: # noqa: ANN401 + return None + + def on_after_invoke(self, result: LLMResult, **kwargs: Any) -> None: # noqa: ANN401 + self.after_result = result + + def on_invoke_error(self, **kwargs: Any) -> None: # noqa: ANN401 + return None + + +def _build_llm_instance() -> LargeLanguageModel: + declaration = ProviderEntity( + provider="test", + label=I18nObject(en_US="test"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + plugin_model_provider = PluginModelProviderEntity( + id="pmp_1", + created_at=datetime.now(), + updated_at=datetime.now(), + provider="test", + tenant_id="tenant_1", + plugin_unique_identifier="test/plugin", + plugin_id="test/plugin", + declaration=declaration, + ) + + return LargeLanguageModel( + tenant_id="tenant_1", + plugin_id="test/plugin", + provider_name="test", + plugin_model_provider=plugin_model_provider, + ) + + +def test_invoke_non_stream_preserves_assistant_opaque_body() -> None: + llm = _build_llm_instance() + prompt_messages: list[PromptMessage] = [UserPromptMessage(content="hi")] + + chunk = LLMResultChunk( + model="gpt-test", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content="hello", opaque_body={"provider_message_id": "msg_123"}), + ), + ) + + def _mock_invoke_llm(self, **kwargs: Any): # noqa: ANN001, ANN401 + yield chunk + + with patch("core.plugin.impl.model.PluginModelClient.invoke_llm", new=_mock_invoke_llm): + result = llm.invoke( + model="gpt-test", + credentials={}, + prompt_messages=prompt_messages, + model_parameters={}, + stream=False, + ) + + assert isinstance(result, LLMResult) + assert result.message.opaque_body == {"provider_message_id": "msg_123"} + assert list(result.prompt_messages) == prompt_messages + + +def test_invoke_stream_preserves_assistant_opaque_body_in_after_callback() -> None: + llm = _build_llm_instance() + prompt_messages: list[PromptMessage] = [UserPromptMessage(content="hi")] + callback = _CaptureAfterInvokeCallback() + + tool_call_1 = AssistantPromptMessage.ToolCall( + id="1", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": '), + ) + tool_call_2 = AssistantPromptMessage.ToolCall( + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments='"value"}'), + ) + + chunk1 = LLMResultChunk( + model="gpt-test", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content="h", tool_calls=[tool_call_1], opaque_body={"provider_message_id": "msg_123"}), + ), + ) + chunk2 = LLMResultChunk( + model="gpt-test", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content="i", tool_calls=[tool_call_2]), + ), + ) + + def _mock_invoke_llm(self, **kwargs: Any): # noqa: ANN001, ANN401 + yield chunk1 + yield chunk2 + + with patch("core.plugin.impl.model.PluginModelClient.invoke_llm", new=_mock_invoke_llm): + gen = llm.invoke( + model="gpt-test", + credentials={}, + prompt_messages=prompt_messages, + model_parameters={}, + stream=True, + callbacks=[callback], + ) + chunks = list(gen) + + assert chunks[0].prompt_messages == prompt_messages + assert callback.after_result is not None + assert 
callback.after_result.message.opaque_body == {"provider_message_id": "msg_123"} + assert len(callback.after_result.message.tool_calls) == 1 + assert callback.after_result.message.tool_calls[0].function.arguments == '{"arg1": "value"}' +