feat(api): pass app_id to model plugins for provider-side cost attribution
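
Thread the originating app_id from the app runners to the plugin daemon so
providers can attribute token costs per app. Runners stash the app_id in a
request-scoped ContextVar (core.app.app_context); PluginModelRuntime reads it
back at invoke time, and PluginModelClient._dispatch_payload() adds the field
to the top-level dispatch payload only when it is set, so plugins that ignore
unknown fields are unaffected.

A sketch of the resulting dispatch payload, with illustrative values (the
real "data" block carries the full invocation arguments):

    {
        "user_id": "user-1",        # values here are examples only
        "app_id": "app-123",
        "data": {
            "provider": "provider-a",
            "model_type": "llm",
            # ... model, credentials, prompt_messages, stream, etc.
        },
    }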

Ryuta KOBAYASHI 2026-05-06 22:10:03 +09:00 committed by ryuta-kobayashi-15
parent 3708e3eef1
commit 46b1f5c0a0
9 changed files with 145 additions and 13 deletions

View File

@@ -0,0 +1,40 @@
+"""Request-scoped app context for model plugin dispatch.
+
+Stores the current app_id in a ContextVar so that PluginModelRuntime
+can include it in the HTTP payload to the plugin daemon, without
+requiring changes to the graphon call chain.
+
+Usage in app runners::
+
+    from core.app.app_context import set_current_app_id
+
+    token = set_current_app_id(app_id)
+    try:
+        invoke_result = model_instance.invoke_llm(...)
+    finally:
+        reset_current_app_id(token)
+
+The value is read by PluginModelRuntime.invoke_llm() to tag the
+request with the originating Dify app.
+"""
+
+from __future__ import annotations
+
+from contextvars import ContextVar, Token
+
+_current_app_id: ContextVar[str | None] = ContextVar("_current_app_id", default=None)
+
+
+def get_current_app_id() -> str | None:
+    """Return the app_id for the current execution context, or None."""
+    return _current_app_id.get()
+
+
+def set_current_app_id(app_id: str | None) -> Token[str | None]:
+    """Set the app_id for the current execution context. Returns a token for reset."""
+    return _current_app_id.set(app_id)
+
+
+def reset_current_app_id(token: Token[str | None]) -> None:
+    """Reset the app_id to its previous value."""
+    _current_app_id.reset(token)

View File

@@ -92,9 +92,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
 
     @trace_span(WorkflowAppRunnerHandler)
     def run(self):
+        from core.app.app_context import reset_current_app_id, set_current_app_id
+
         app_config = self.application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)
 
+        _app_id_token = set_current_app_id(app_config.app_id)
+
         system_inputs = build_system_variables(
             query=self.application_generate_entity.query,
             files=self.application_generate_entity.files,
@@ -251,6 +255,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             for event in generator:
                 self._handle_event(workflow_entry, event)
 
+        reset_current_app_id(_app_id_token)
+
     def handle_input_moderation(
         self,
         app_record: App,

View File

@@ -224,6 +224,10 @@ class AgentChatAppRunner(AppRunner):
             model_instance=model_instance,
         )
 
+        from core.app.app_context import reset_current_app_id, set_current_app_id
+
+        _app_id_token = set_current_app_id(app_config.app_id)
+
         invoke_result = runner.run(
             message=message,
             query=query,
@@ -240,3 +244,5 @@
             user_id=application_generate_entity.user_id,
             tenant_id=app_config.tenant_id,
         )
+
+        reset_current_app_id(_app_id_token)

View File

@@ -218,12 +218,18 @@ class ChatAppRunner(AppRunner):
 
         db.session.close()
 
-        invoke_result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            model_parameters=application_generate_entity.model_conf.parameters,
-            stop=stop,
-            stream=application_generate_entity.stream,
-        )
+        from core.app.app_context import reset_current_app_id, set_current_app_id
+
+        _app_id_token = set_current_app_id(app_config.app_id)
+        try:
+            invoke_result = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters=application_generate_entity.model_conf.parameters,
+                stop=stop,
+                stream=application_generate_entity.stream,
+            )
+        finally:
+            reset_current_app_id(_app_id_token)
 
         # handle invoke result
         self._handle_invoke_result(

View File

@@ -176,12 +176,18 @@ class CompletionAppRunner(AppRunner):
 
         db.session.close()
 
-        invoke_result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            model_parameters=application_generate_entity.model_conf.parameters,
-            stop=stop,
-            stream=application_generate_entity.stream,
-        )
+        from core.app.app_context import reset_current_app_id, set_current_app_id
+
+        _app_id_token = set_current_app_id(app_config.app_id)
+        try:
+            invoke_result = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters=application_generate_entity.model_conf.parameters,
+                stop=stop,
+                stream=application_generate_entity.stream,
+            )
+        finally:
+            reset_current_app_id(_app_id_token)
 
         # handle invoke result
        self._handle_invoke_result(

View File

@@ -64,8 +64,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         """
         Run application
         """
+        from core.app.app_context import reset_current_app_id, set_current_app_id
+
         app_config = self.application_generate_entity.app_config
         app_config = cast(WorkflowAppConfig, app_config)
+        _app_id_token = set_current_app_id(app_config.app_id)
+
         invoke_from = self.application_generate_entity.invoke_from
 
         # if only single iteration or single loop run is requested
         if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
@@ -173,3 +177,5 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             for event in generator:
                 self._handle_event(workflow_entry, event)
+
+        reset_current_app_id(_app_id_token)

View File

@@ -23,10 +23,14 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
 
 class PluginModelClient(BasePluginClient):
     @staticmethod
-    def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
+    def _dispatch_payload(
+        *, user_id: str | None, data: dict[str, Any], app_id: str | None = None
+    ) -> dict[str, Any]:
         payload: dict[str, Any] = {"data": data}
         if user_id is not None:
             payload["user_id"] = user_id
+        if app_id is not None:
+            payload["app_id"] = app_id
         return payload
 
     def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
@@ -162,6 +166,7 @@ class PluginModelClient(BasePluginClient):
         tools: list[PromptMessageTool] | None = None,
         stop: list[str] | None = None,
         stream: bool = True,
+        app_id: str | None = None,
     ) -> Generator[LLMResultChunk, None, None]:
         """
         Invoke llm
@@ -173,6 +178,7 @@
             data=jsonable_encoder(
                 self._dispatch_payload(
                     user_id=user_id,
+                    app_id=app_id,
                     data={
                         "provider": provider,
                         "model_type": "llm",

View File

@@ -207,7 +207,10 @@ class PluginModelRuntime(ModelRuntime):
         stop: Sequence[str] | None,
         stream: bool,
     ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
+        from core.app.app_context import get_current_app_id
+
         plugin_id, provider_name = self._split_provider(provider)
+        app_id = get_current_app_id()
         return self.client.invoke_llm(
             tenant_id=self.tenant_id,
             user_id=self.user_id,
@@ -220,6 +223,7 @@
             tools=tools,
             stop=list(stop) if stop else None,
             stream=stream,
+            app_id=app_id,
         )
 
     def get_llm_num_tokens(

View File

@@ -160,6 +160,58 @@ class TestPluginModelClient:
         assert call_kwargs["data"]["data"]["stream"] is False
         assert call_kwargs["data"]["data"]["model_parameters"] == {"temperature": 0.1}
 
+    def test_invoke_llm_with_app_id(self, mocker):
+        client = PluginModelClient()
+        stream_mock = mocker.patch.object(
+            client, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk-1"])
+        )
+
+        list(
+            client.invoke_llm(
+                tenant_id="tenant-1",
+                user_id="user-1",
+                plugin_id="org/plugin:1",
+                provider="provider-a",
+                model="gpt-test",
+                credentials={"api_key": "key"},
+                prompt_messages=[],
+                app_id="app-123",
+            )
+        )
+
+        call_kwargs = stream_mock.call_args.kwargs
+        assert call_kwargs["data"]["app_id"] == "app-123"
+
+    def test_invoke_llm_without_app_id_omits_field(self, mocker):
+        client = PluginModelClient()
+        stream_mock = mocker.patch.object(
+            client, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk-1"])
+        )
+
+        list(
+            client.invoke_llm(
+                tenant_id="tenant-1",
+                user_id="user-1",
+                plugin_id="org/plugin:1",
+                provider="provider-a",
+                model="gpt-test",
+                credentials={"api_key": "key"},
+                prompt_messages=[],
+            )
+        )
+
+        call_kwargs = stream_mock.call_args.kwargs
+        assert "app_id" not in call_kwargs["data"]
+
+    def test_dispatch_payload_includes_app_id_when_provided(self):
+        payload = PluginModelClient._dispatch_payload(user_id="u1", data={"k": "v"}, app_id="app-456")
+        assert payload["app_id"] == "app-456"
+        assert payload["user_id"] == "u1"
+
+    def test_dispatch_payload_omits_app_id_when_none(self):
+        payload = PluginModelClient._dispatch_payload(user_id="u1", data={"k": "v"})
+        assert "app_id" not in payload
+
     def test_invoke_llm_wraps_plugin_daemon_inner_error(self, mocker):
         client = PluginModelClient()