diff --git a/api/core/app/app_context.py b/api/core/app/app_context.py
new file mode 100644
index 0000000000..98c408a84c
--- /dev/null
+++ b/api/core/app/app_context.py
@@ -0,0 +1,40 @@
+"""Request-scoped app context for model plugin dispatch.
+
+Stores the current app_id in a ContextVar so that PluginModelRuntime
+can include it in the HTTP payload to the plugin daemon, without
+requiring changes to the graphon call chain.
+
+Usage in app runners::
+
+    from core.app.app_context import set_current_app_id
+
+    token = set_current_app_id(app_id)
+    try:
+        invoke_result = model_instance.invoke_llm(...)
+    finally:
+        reset_current_app_id(token)
+
+The value is read by PluginModelRuntime.invoke_llm() to tag the
+request with the originating Dify app.
+"""
+
+from __future__ import annotations
+
+from contextvars import ContextVar, Token
+
+_current_app_id: ContextVar[str | None] = ContextVar("_current_app_id", default=None)
+
+
+def get_current_app_id() -> str | None:
+    """Return the app_id for the current execution context, or None."""
+    return _current_app_id.get()
+
+
+def set_current_app_id(app_id: str | None) -> Token[str | None]:
+    """Set the app_id for the current execution context. Returns a token for reset."""
+    return _current_app_id.set(app_id)
+
+
+def reset_current_app_id(token: Token[str | None]) -> None:
+    """Reset the app_id to its previous value."""
+    _current_app_id.reset(token)
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index 4e57b4dedc..044cbd465c 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -92,9 +92,13 @@
 
     @trace_span(WorkflowAppRunnerHandler)
     def run(self):
+        from core.app.app_context import reset_current_app_id, set_current_app_id
+
         app_config = self.application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)
 
+        _app_id_token = set_current_app_id(app_config.app_id)
+
         system_inputs = build_system_variables(
             query=self.application_generate_entity.query,
             files=self.application_generate_entity.files,
@@ -251,6 +255,11 @@
-        for event in generator:
-            self._handle_event(workflow_entry, event)
-
+        try:
+            for event in generator:
+                self._handle_event(workflow_entry, event)
+        finally:
+            # Always restore the previous app_id, even when event handling
+            # raises, so the value cannot leak into a reused worker context.
+            reset_current_app_id(_app_id_token)
+
     def handle_input_moderation(
         self,
         app_record: App,
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index cae0eee0df..5a1c6e5c34 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -224,6 +224,10 @@
             model_instance=model_instance,
         )
 
+        from core.app.app_context import reset_current_app_id, set_current_app_id
+
+        _app_id_token = set_current_app_id(app_config.app_id)
+
         invoke_result = runner.run(
             message=message,
             query=query,
@@ -240,3 +244,7 @@
             user_id=application_generate_entity.user_id,
             tenant_id=app_config.tenant_id,
         )
+
+        # NOTE(review): this reset is skipped if the invocation above raises;
+        # wrap it in try/finally like ChatAppRunner does — confirm and fix.
+        reset_current_app_id(_app_id_token)
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index 077c5239f3..1f48ee2d7b 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -218,12 +218,18 @@
 
         db.session.close()
 
-        invoke_result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            model_parameters=application_generate_entity.model_conf.parameters,
-            stop=stop,
-            stream=application_generate_entity.stream,
-        )
+        from core.app.app_context import reset_current_app_id, set_current_app_id
+
+        _app_id_token = set_current_app_id(app_config.app_id)
+        try:
+            invoke_result = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters=application_generate_entity.model_conf.parameters,
+                stop=stop,
+                stream=application_generate_entity.stream,
+            )
+        finally:
+            reset_current_app_id(_app_id_token)
 
         # handle invoke result
         self._handle_invoke_result(
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index 6bb1ecdcb1..e49de95738 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -176,12 +176,18 @@
 
         db.session.close()
 
-        invoke_result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            model_parameters=application_generate_entity.model_conf.parameters,
-            stop=stop,
-            stream=application_generate_entity.stream,
-        )
+        from core.app.app_context import reset_current_app_id, set_current_app_id
+
+        _app_id_token = set_current_app_id(app_config.app_id)
+        try:
+            invoke_result = model_instance.invoke_llm(
+                prompt_messages=prompt_messages,
+                model_parameters=application_generate_entity.model_conf.parameters,
+                stop=stop,
+                stream=application_generate_entity.stream,
+            )
+        finally:
+            reset_current_app_id(_app_id_token)
 
         # handle invoke result
         self._handle_invoke_result(
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index cfb9208486..2ce3c5af14 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -64,8 +64,12 @@
         """
         Run application
         """
+        from core.app.app_context import reset_current_app_id, set_current_app_id
+
         app_config = self.application_generate_entity.app_config
         app_config = cast(WorkflowAppConfig, app_config)
+
+        _app_id_token = set_current_app_id(app_config.app_id)
         invoke_from = self.application_generate_entity.invoke_from
         # if only single iteration or single loop run is requested
         if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
@@ -173,3 +177,8 @@
 
-        for event in generator:
-            self._handle_event(workflow_entry, event)
+        try:
+            for event in generator:
+                self._handle_event(workflow_entry, event)
+        finally:
+            # Always restore the previous app_id, even when event handling
+            # raises, so the value cannot leak into a reused worker context.
+            reset_current_app_id(_app_id_token)
diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py
index 47608bdfa6..f923cde5da 100644
--- a/api/core/plugin/impl/model.py
+++ b/api/core/plugin/impl/model.py
@@ -23,10 +23,14 @@
 
 class PluginModelClient(BasePluginClient):
     @staticmethod
-    def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
+    def _dispatch_payload(
+        *, user_id: str | None, data: dict[str, Any], app_id: str | None = None
+    ) -> dict[str, Any]:
         payload: dict[str, Any] = {"data": data}
         if user_id is not None:
             payload["user_id"] = user_id
+        if app_id is not None:
+            payload["app_id"] = app_id
         return payload
 
     def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
@@ -162,6 +166,7 @@
         tools: list[PromptMessageTool] | None = None,
         stop: list[str] | None = None,
         stream: bool = True,
+        app_id: str | None = None,
     ) -> Generator[LLMResultChunk, None, None]:
         """
         Invoke llm
@@ -173,6 +178,7 @@
             data=jsonable_encoder(
                 self._dispatch_payload(
                     user_id=user_id,
+                    app_id=app_id,
                     data={
                         "provider": provider,
                         "model_type": "llm",
diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py
index 4e66d58b5e..06617f201d 100644
--- a/api/core/plugin/impl/model_runtime.py
+++ b/api/core/plugin/impl/model_runtime.py
@@ -207,7 +207,10 @@
         stop: Sequence[str] | None,
         stream: bool,
     ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
+        from core.app.app_context import get_current_app_id
+
         plugin_id, provider_name = self._split_provider(provider)
+        app_id = get_current_app_id()
         return self.client.invoke_llm(
             tenant_id=self.tenant_id,
             user_id=self.user_id,
@@ -220,6 +223,7 @@
             tools=tools,
             stop=list(stop) if stop else None,
             stream=stream,
+            app_id=app_id,
         )
 
     def get_llm_num_tokens(
diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_client.py b/api/tests/unit_tests/core/plugin/impl/test_model_client.py
index bcbebbb38b..262c3cf456 100644
--- a/api/tests/unit_tests/core/plugin/impl/test_model_client.py
+++ b/api/tests/unit_tests/core/plugin/impl/test_model_client.py
@@ -160,6 +160,58 @@
         assert call_kwargs["data"]["data"]["stream"] is False
         assert call_kwargs["data"]["data"]["model_parameters"] == {"temperature": 0.1}
 
+    def test_invoke_llm_with_app_id(self, mocker):
+        client = PluginModelClient()
+        stream_mock = mocker.patch.object(
+            client, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk-1"])
+        )
+
+        list(
+            client.invoke_llm(
+                tenant_id="tenant-1",
+                user_id="user-1",
+                plugin_id="org/plugin:1",
+                provider="provider-a",
+                model="gpt-test",
+                credentials={"api_key": "key"},
+                prompt_messages=[],
+                app_id="app-123",
+            )
+        )
+
+        call_kwargs = stream_mock.call_args.kwargs
+        assert call_kwargs["data"]["app_id"] == "app-123"
+
+    def test_invoke_llm_without_app_id_omits_field(self, mocker):
+        client = PluginModelClient()
+        stream_mock = mocker.patch.object(
+            client, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk-1"])
+        )
+
+        list(
+            client.invoke_llm(
+                tenant_id="tenant-1",
+                user_id="user-1",
+                plugin_id="org/plugin:1",
+                provider="provider-a",
+                model="gpt-test",
+                credentials={"api_key": "key"},
+                prompt_messages=[],
+            )
+        )
+
+        call_kwargs = stream_mock.call_args.kwargs
+        assert "app_id" not in call_kwargs["data"]
+
+    def test_dispatch_payload_includes_app_id_when_provided(self):
+        payload = PluginModelClient._dispatch_payload(user_id="u1", data={"k": "v"}, app_id="app-456")
+        assert payload["app_id"] == "app-456"
+        assert payload["user_id"] == "u1"
+
+    def test_dispatch_payload_omits_app_id_when_none(self):
+        payload = PluginModelClient._dispatch_payload(user_id="u1", data={"k": "v"})
+        assert "app_id" not in payload
+
     def test_invoke_llm_wraps_plugin_daemon_inner_error(self, mocker):
         client = PluginModelClient()