mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 12:59:18 +08:00
Merge 116e5872d6 into 271019006e
This commit is contained in:
commit
3f1cdad149
40
api/core/app/app_context.py
Normal file
40
api/core/app/app_context.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""Request-scoped app context for model plugin dispatch.
|
||||
|
||||
Stores the current app_id in a ContextVar so that PluginModelRuntime
|
||||
can include it in the HTTP payload to the plugin daemon, without
|
||||
requiring changes to the graphon call chain.
|
||||
|
||||
Usage in app runners::
|
||||
|
||||
from core.app.app_context import reset_current_app_id, set_current_app_id
|
||||
|
||||
token = set_current_app_id(app_id)
|
||||
try:
|
||||
invoke_result = model_instance.invoke_llm(...)
|
||||
finally:
|
||||
reset_current_app_id(token)
|
||||
|
||||
The value is read by PluginModelRuntime.invoke_llm() to tag the
|
||||
request with the originating Dify app.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar, Token
|
||||
|
||||
_current_app_id: ContextVar[str | None] = ContextVar("_current_app_id", default=None)
|
||||
|
||||
|
||||
def get_current_app_id() -> str | None:
    """Look up the app_id bound to the current execution context.

    Returns:
        The value currently held by the module-level ``_current_app_id``
        context variable, or ``None`` (its declared default) when no value
        has been set.
    """
    current_value = _current_app_id.get()
    return current_value
|
||||
|
||||
|
||||
def set_current_app_id(app_id: str | None) -> Token[str | None]:
    """Bind *app_id* to the current execution context.

    Args:
        app_id: Identifier to store in the context variable; ``None`` is
            accepted as well.

    Returns:
        The token produced by ``ContextVar.set``. Pass it to
        ``reset_current_app_id`` to undo this binding.
    """
    reset_token = _current_app_id.set(app_id)
    return reset_token
|
||||
|
||||
|
||||
def reset_current_app_id(token: Token[str | None]) -> None:
    """Undo an earlier ``set_current_app_id`` call.

    Args:
        token: The token returned by the ``set_current_app_id`` call that
            should be reverted.
    """
    # ContextVar.reset restores the value that was in place before the
    # set() call which produced this token.
    _current_app_id.reset(token)
    return None
|
||||
@ -92,9 +92,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
from core.app.app_context import reset_current_app_id, set_current_app_id
|
||||
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
_app_id_token = set_current_app_id(app_config.app_id)
|
||||
|
||||
system_inputs = build_system_variables(
|
||||
query=self.application_generate_entity.query,
|
||||
files=self.application_generate_entity.files,
|
||||
@ -251,6 +255,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
reset_current_app_id(_app_id_token)
|
||||
|
||||
def handle_input_moderation(
|
||||
self,
|
||||
app_record: App,
|
||||
|
||||
@ -224,6 +224,10 @@ class AgentChatAppRunner(AppRunner):
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
from core.app.app_context import reset_current_app_id, set_current_app_id
|
||||
|
||||
_app_id_token = set_current_app_id(app_config.app_id)
|
||||
|
||||
invoke_result = runner.run(
|
||||
message=message,
|
||||
query=query,
|
||||
@ -240,3 +244,5 @@ class AgentChatAppRunner(AppRunner):
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
)
|
||||
|
||||
reset_current_app_id(_app_id_token)
|
||||
|
||||
@ -218,12 +218,18 @@ class ChatAppRunner(AppRunner):
|
||||
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=application_generate_entity.model_conf.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
from core.app.app_context import reset_current_app_id, set_current_app_id
|
||||
|
||||
_app_id_token = set_current_app_id(app_config.app_id)
|
||||
try:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=application_generate_entity.model_conf.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
finally:
|
||||
reset_current_app_id(_app_id_token)
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
|
||||
@ -176,12 +176,18 @@ class CompletionAppRunner(AppRunner):
|
||||
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=application_generate_entity.model_conf.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
from core.app.app_context import reset_current_app_id, set_current_app_id
|
||||
|
||||
_app_id_token = set_current_app_id(app_config.app_id)
|
||||
try:
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=application_generate_entity.model_conf.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
finally:
|
||||
reset_current_app_id(_app_id_token)
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
|
||||
@ -64,8 +64,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
Run application
|
||||
"""
|
||||
from core.app.app_context import reset_current_app_id, set_current_app_id
|
||||
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
_app_id_token = set_current_app_id(app_config.app_id)
|
||||
invoke_from = self.application_generate_entity.invoke_from
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
@ -173,3 +177,5 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
reset_current_app_id(_app_id_token)
|
||||
|
||||
@ -23,10 +23,12 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
class PluginModelClient(BasePluginClient):
|
||||
@staticmethod
|
||||
def _dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]:
|
||||
def _dispatch_payload(*, user_id: str | None, data: dict[str, Any], app_id: str | None = None) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {"data": data}
|
||||
if user_id is not None:
|
||||
payload["user_id"] = user_id
|
||||
if app_id is not None:
|
||||
payload["app_id"] = app_id
|
||||
return payload
|
||||
|
||||
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
|
||||
@ -162,6 +164,7 @@ class PluginModelClient(BasePluginClient):
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
app_id: str | None = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Invoke llm
|
||||
@ -173,6 +176,7 @@ class PluginModelClient(BasePluginClient):
|
||||
data=jsonable_encoder(
|
||||
self._dispatch_payload(
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
data={
|
||||
"provider": provider,
|
||||
"model_type": "llm",
|
||||
|
||||
@ -207,7 +207,10 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
from core.app.app_context import get_current_app_id
|
||||
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
app_id = get_current_app_id()
|
||||
return self.client.invoke_llm(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=self.user_id,
|
||||
@ -220,6 +223,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
tools=tools,
|
||||
stop=list(stop) if stop else None,
|
||||
stream=stream,
|
||||
app_id=app_id,
|
||||
)
|
||||
|
||||
def get_llm_num_tokens(
|
||||
|
||||
@ -160,6 +160,58 @@ class TestPluginModelClient:
|
||||
assert call_kwargs["data"]["data"]["stream"] is False
|
||||
assert call_kwargs["data"]["data"]["model_parameters"] == {"temperature": 0.1}
|
||||
|
||||
def test_invoke_llm_with_app_id(self, mocker):
    """An explicit ``app_id`` shows up in the ``data`` kwarg of the streaming request."""
    plugin_client = PluginModelClient()
    patched_stream = mocker.patch.object(
        plugin_client,
        "_request_with_plugin_daemon_response_stream",
        return_value=iter(["chunk-1"]),
    )
    invoke_kwargs = {
        "tenant_id": "tenant-1",
        "user_id": "user-1",
        "plugin_id": "org/plugin:1",
        "provider": "provider-a",
        "model": "gpt-test",
        "credentials": {"api_key": "key"},
        "prompt_messages": [],
        "app_id": "app-123",
    }

    # Consume the returned generator in full before inspecting the mock.
    _ = list(plugin_client.invoke_llm(**invoke_kwargs))

    sent_payload = patched_stream.call_args.kwargs["data"]
    assert sent_payload["app_id"] == "app-123"
|
||||
|
||||
def test_invoke_llm_without_app_id_omits_field(self, mocker):
    """Without an ``app_id`` argument the ``data`` kwarg carries no ``"app_id"`` key."""
    plugin_client = PluginModelClient()
    patched_stream = mocker.patch.object(
        plugin_client,
        "_request_with_plugin_daemon_response_stream",
        return_value=iter(["chunk-1"]),
    )
    invoke_kwargs = {
        "tenant_id": "tenant-1",
        "user_id": "user-1",
        "plugin_id": "org/plugin:1",
        "provider": "provider-a",
        "model": "gpt-test",
        "credentials": {"api_key": "key"},
        "prompt_messages": [],
    }

    # Consume the returned generator in full before inspecting the mock.
    _ = list(plugin_client.invoke_llm(**invoke_kwargs))

    sent_payload = patched_stream.call_args.kwargs["data"]
    assert "app_id" not in sent_payload
|
||||
|
||||
def test_dispatch_payload_includes_app_id_when_provided(self):
    """Both identifiers are present in the envelope when both are supplied."""
    envelope = PluginModelClient._dispatch_payload(
        user_id="u1",
        data={"k": "v"},
        app_id="app-456",
    )
    assert envelope["app_id"] == "app-456"
    assert envelope["user_id"] == "u1"
|
||||
|
||||
def test_dispatch_payload_omits_app_id_when_none(self):
    """Leaving ``app_id`` at its default keeps the key out of the envelope."""
    envelope = PluginModelClient._dispatch_payload(user_id="u1", data={"k": "v"})
    has_app_id = "app_id" in envelope
    assert not has_app_id
|
||||
|
||||
def test_invoke_llm_wraps_plugin_daemon_inner_error(self, mocker):
|
||||
client = PluginModelClient()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user