From 6ca07c4a4e272f868a877e08738c96bae70a7ac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Fri, 24 Apr 2026 22:48:07 +0800 Subject: [PATCH] feat(agent): add dify llm adapter --- dify-agent/examples/run_pydantic_ai_agent.py | 78 ++ dify-agent/src/dify_agent/__init__.py | 5 + .../src/dify_agent/adapters/__init__.py | 1 + .../src/dify_agent/adapters/llm/__init__.py | 6 + .../src/dify_agent/adapters/llm/model.py | 798 ++++++++++++++++++ .../src/dify_agent/adapters/llm/provider.py | 252 ++++++ .../unit/dify_agent/adapters/__init__.py | 0 .../unit/dify_agent/adapters/llm/__init__.py | 0 .../dify_agent/adapters/llm/_test_support.py | 91 ++ .../dify_agent/adapters/llm/test_model.py | 374 ++++++++ 10 files changed, 1605 insertions(+) create mode 100644 dify-agent/examples/run_pydantic_ai_agent.py create mode 100644 dify-agent/src/dify_agent/__init__.py create mode 100644 dify-agent/src/dify_agent/adapters/__init__.py create mode 100644 dify-agent/src/dify_agent/adapters/llm/__init__.py create mode 100644 dify-agent/src/dify_agent/adapters/llm/model.py create mode 100644 dify-agent/src/dify_agent/adapters/llm/provider.py create mode 100644 dify-agent/tests/unit/dify_agent/adapters/__init__.py create mode 100644 dify-agent/tests/unit/dify_agent/adapters/llm/__init__.py create mode 100644 dify-agent/tests/unit/dify_agent/adapters/llm/_test_support.py create mode 100644 dify-agent/tests/unit/dify_agent/adapters/llm/test_model.py diff --git a/dify-agent/examples/run_pydantic_ai_agent.py b/dify-agent/examples/run_pydantic_ai_agent.py new file mode 100644 index 0000000000..0fdec33fd8 --- /dev/null +++ b/dify-agent/examples/run_pydantic_ai_agent.py @@ -0,0 +1,78 @@ +"""Run a Pydantic AI agent through the Dify plugin-daemon adapter. + +Prerequisites: +- Start the plugin daemon from `dify-aio/dify/docker/docker-compose.middleware.yaml`. +- Run the Dify API with `dify-aio/dify/api/.env` so the daemon can resolve tenants/plugins. +- Fill `dify-agent/.env` with a real tenant, plugin, provider, model, and provider credentials. 
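+
+Illustrative `.env` sketch (the key names are exactly the ones this script reads; the
+plugin/provider values mirror the example values used in the unit tests, everything else
+is a placeholder, and DIFY_AGENT_MODEL_CREDENTIALS_JSON must hold a JSON object):
+
+    DIFY_AGENT_TENANT_ID=<tenant-id>
+    DIFY_AGENT_PLUGIN_ID=langgenius/openai
+    DIFY_AGENT_PROVIDER=openai
+    DIFY_AGENT_MODEL_NAME=<model-name>
+    DIFY_AGENT_MODEL_CREDENTIALS_JSON={"api_key": "<provider-api-key>"}
+    PLUGIN_DAEMON_URL=<plugin-daemon-base-url>
+    PLUGIN_DAEMON_KEY=<plugin-daemon-api-key>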
+ +Example: + uv run --project dify-agent python examples/run_pydantic_ai_agent.py +""" + +from __future__ import annotations + +import asyncio +import json +import os +from pathlib import Path +from typing import Any + +from pydantic_ai import Agent + +from dify_agent import DifyLLMAdapterModel, DifyPluginDaemonProvider + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] + + +def load_env_file(path: Path) -> None: + """Load simple KEY=VALUE lines without adding a dotenv dependency.""" + if not path.exists(): + return + + for raw_line in path.read_text().splitlines(): + line = raw_line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, value = line.split("=", 1) + os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'")) + + +def required_env(name: str) -> str: + value = os.environ.get(name) + if value: + return value + raise RuntimeError(f"Missing required environment variable: {name}") + + +def load_credentials() -> dict[str, Any]: + raw_credentials = required_env("DIFY_AGENT_MODEL_CREDENTIALS_JSON") + credentials = json.loads(raw_credentials) + if not isinstance(credentials, dict): + raise RuntimeError("DIFY_AGENT_MODEL_CREDENTIALS_JSON must be a JSON object") + return credentials + + +async def main() -> None: + load_env_file(PROJECT_ROOT / ".env") + + model = DifyLLMAdapterModel( + required_env("DIFY_AGENT_MODEL_NAME"), + DifyPluginDaemonProvider( + tenant_id=required_env("DIFY_AGENT_TENANT_ID"), + plugin_id=required_env("DIFY_AGENT_PLUGIN_ID"), + plugin_provider=required_env("DIFY_AGENT_PROVIDER"), + plugin_daemon_url=required_env("PLUGIN_DAEMON_URL"), + plugin_daemon_api_key=required_env("PLUGIN_DAEMON_KEY"), + ), + credentials=load_credentials(), + ) + agent = Agent(model=model) + async with agent.run_stream("Explain the theory of relativity") as run: + async for piece in run.stream_output(): + print(piece, end="", flush=True) + print(run.usage()) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dify-agent/src/dify_agent/__init__.py b/dify-agent/src/dify_agent/__init__.py new file mode 100644 index 0000000000..bb11ce70c8 --- /dev/null +++ b/dify-agent/src/dify_agent/__init__.py @@ -0,0 +1,5 @@ +"""Adapters for using Dify components inside the local agent package.""" + +from .adapters.llm import DifyLLMAdapterModel, DifyPluginDaemonProvider + +__all__ = ["DifyLLMAdapterModel", "DifyPluginDaemonProvider"] diff --git a/dify-agent/src/dify_agent/adapters/__init__.py b/dify-agent/src/dify_agent/adapters/__init__.py new file mode 100644 index 0000000000..ac7a1ae47a --- /dev/null +++ b/dify-agent/src/dify_agent/adapters/__init__.py @@ -0,0 +1 @@ +"""Adapter integrations for Dify agent components.""" diff --git a/dify-agent/src/dify_agent/adapters/llm/__init__.py b/dify-agent/src/dify_agent/adapters/llm/__init__.py new file mode 100644 index 0000000000..b63771ed61 --- /dev/null +++ b/dify-agent/src/dify_agent/adapters/llm/__init__.py @@ -0,0 +1,6 @@ +"""LLM adapters for Dify plugin-daemon integrations.""" + +from .model import DifyLLMAdapterModel +from .provider import DifyPluginDaemonProvider + +__all__ = ["DifyLLMAdapterModel", "DifyPluginDaemonProvider"] diff --git a/dify-agent/src/dify_agent/adapters/llm/model.py b/dify-agent/src/dify_agent/adapters/llm/model.py new file mode 100644 index 0000000000..5723253177 --- /dev/null +++ b/dify-agent/src/dify_agent/adapters/llm/model.py @@ -0,0 +1,798 @@ +"""Bridge Dify plugin-daemon LLM invocations into Pydantic AI's model interface. 
+
+The API and agent layers are clients of the plugin daemon, not direct hosts of provider SDK
+implementations. This adapter therefore targets the plugin-daemon dispatch protocol and maps
+Pydantic AI messages into the daemon's Graphon-compatible request and stream response schema.
+"""
+
+from __future__ import annotations
+
+import base64
+import re
+from collections.abc import AsyncGenerator, AsyncIterator, Mapping, Sequence
+from contextlib import asynccontextmanager
+from dataclasses import KW_ONLY, InitVar, dataclass, field
+from datetime import datetime, timezone
+from typing import cast
+
+from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMUsage
+from graphon.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContentUnionTypes,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
+    VideoPromptMessageContent,
+)
+from typing_extensions import assert_never, override
+
+from pydantic_ai._parts_manager import ModelResponsePartsManager
+from pydantic_ai.exceptions import UnexpectedModelBehavior
+from pydantic_ai.messages import (
+    AudioUrl,
+    BinaryContent,
+    BuiltinToolCallPart,
+    BuiltinToolReturnPart,
+    CachePoint,
+    CompactionPart,
+    DocumentUrl,
+    FilePart,
+    FinishReason,
+    ImageUrl,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    ModelResponseStreamEvent,
+    MultiModalContent,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextContent,
+    TextPart,
+    ThinkingPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UploadedFile,
+    UserContent,
+    UserPromptPart,
+    VideoUrl,
+)
+from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
+from pydantic_ai.profiles import ModelProfileSpec
+from pydantic_ai.settings import ModelSettings
+from pydantic_ai.usage import RequestUsage
+
+from .provider import DifyPluginDaemonLLMClient, DifyPluginDaemonProvider
+
+_THINK_START = "<think>\n"
+_THINK_END = "\n</think>"
+_THINK_OPEN_TAG = "<think>"
+_THINK_CLOSE_TAG = "</think>"
+_THINK_TAG_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
+_DETAIL_HIGH = "high"
+
+
+@dataclass(slots=True)
+class _DifyRequestInput:
+    credentials: dict[str, object]
+    prompt_messages: list[PromptMessage]
+    model_parameters: dict[str, object]
+    tools: list[PromptMessageTool] | None
+    stop_sequences: list[str] | None
+
+
+@dataclass(slots=True)
+class DifyLLMAdapterModel(Model[DifyPluginDaemonLLMClient]):
+    """Use a Dify plugin-daemon LLM provider as a Pydantic AI model."""
+
+    model: str
+    daemon_provider: DifyPluginDaemonProvider
+    _: KW_ONLY
+    credentials: dict[str, object] = field(default_factory=dict, repr=False)
+    model_profile: InitVar[ModelProfileSpec | None] = None
+    model_settings: InitVar[ModelSettings | None] = None
+
+    def __post_init__(
+        self,
+        model_profile: ModelProfileSpec | None,
+        model_settings: ModelSettings | None,
+    ) -> None:
+        Model.__init__(
+            self,
+            settings=model_settings,
+            profile=model_profile or self.daemon_provider.model_profile(self.model),
+        )
+
+    @property
+    @override
+    def provider(self) -> DifyPluginDaemonProvider:
+        return self.daemon_provider
+
+    @property
+    @override
+    def model_name(self) -> str:
+        return self.model
+
+    @property
+    @override
+    def system(self) -> str:
+        return self.daemon_provider.name
+
+    @override
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters:
ModelRequestParameters, + ) -> ModelResponse: + prepared_settings, prepared_params = self.prepare_request( + model_settings, model_request_parameters + ) + request_input = self._build_request_input( + messages, prepared_settings, prepared_params + ) + + response = DifyStreamedResponse( + model_request_parameters=prepared_params, + chunks=self.daemon_provider.client.iter_llm_result_chunks( + model=self.model_name, + credentials=request_input.credentials, + prompt_messages=request_input.prompt_messages, + model_parameters=request_input.model_parameters, + tools=request_input.tools, + stop=request_input.stop_sequences, + stream=False, + ), + response_model_name=self.model_name, + provider_name_value=self.system, + ) + async for _event in response: + pass + return response.get() + + @asynccontextmanager + @override + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + run_context: object | None = None, + ) -> AsyncGenerator[StreamedResponse, None]: + del run_context + prepared_settings, prepared_params = self.prepare_request( + model_settings, model_request_parameters + ) + request_input = self._build_request_input( + messages, prepared_settings, prepared_params + ) + + yield DifyStreamedResponse( + model_request_parameters=prepared_params, + chunks=self.daemon_provider.client.iter_llm_result_chunks( + model=self.model_name, + credentials=request_input.credentials, + prompt_messages=request_input.prompt_messages, + model_parameters=request_input.model_parameters, + tools=request_input.tools, + stop=request_input.stop_sequences, + stream=True, + ), + response_model_name=self.model_name, + provider_name_value=self.system, + ) + + def _build_request_input( + self, + messages: Sequence[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> _DifyRequestInput: + return _DifyRequestInput( + credentials=dict(self.credentials), + prompt_messages=_map_messages_to_prompt_messages( + messages, model_request_parameters + ), + model_parameters=_map_model_settings_to_parameters(model_settings), + tools=_map_tool_definitions_to_prompt_tools(model_request_parameters), + stop_sequences=_get_stop_sequences(model_settings), + ) + + +@dataclass +class DifyStreamedResponse(StreamedResponse): + chunks: AsyncIterator[LLMResultChunk] + response_model_name: str + provider_name_value: str + _timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + _embedded_thinking_parser: "_EmbeddedThinkingParser" = field( + default_factory=lambda: _EmbeddedThinkingParser() + ) + + @override + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async for chunk in self.chunks: + if chunk.delta.usage is not None: + self._usage: RequestUsage = _map_usage(chunk.delta.usage) + if chunk.delta.finish_reason is not None: + self.finish_reason: FinishReason | None = _normalize_finish_reason( + chunk.delta.finish_reason + ) + + for event in _chunk_to_stream_events( + self._parts_manager, + chunk, + self.provider_name_value, + self._embedded_thinking_parser, + ): + yield event + + for event in self._embedded_thinking_parser.flush( + self._parts_manager, self.provider_name_value + ): + yield event + + @property + @override + def model_name(self) -> str: + return self.response_model_name + + @property + @override + def provider_name(self) -> str: + return self.provider_name_value + + @property + @override + def provider_url(self) 
-> None: + return None + + @property + @override + def timestamp(self) -> datetime: + return self._timestamp + + +def _map_messages_to_prompt_messages( + messages: Sequence[ModelMessage], + model_request_parameters: ModelRequestParameters, +) -> list[PromptMessage]: + prompt_messages: list[PromptMessage] = [] + + for message in messages: + if isinstance(message, ModelRequest): + prompt_messages.extend(_map_model_request_to_prompt_messages(message)) + elif isinstance(message, ModelResponse): + assistant_message = _map_model_response_to_prompt_message(message) + if assistant_message is not None: + prompt_messages.append(assistant_message) + else: + assert_never(message) + + instruction_messages = [ + SystemPromptMessage(content=part.content) + for part in ( + Model._get_instruction_parts(messages, model_request_parameters) or [] + ) + if part.content.strip() + ] + if instruction_messages: + insert_at = next( + ( + index + for index, message in enumerate(prompt_messages) + if not isinstance(message, SystemPromptMessage) + ), + len(prompt_messages), + ) + prompt_messages[insert_at:insert_at] = instruction_messages + + return prompt_messages + + +def _map_model_request_to_prompt_messages(message: ModelRequest) -> list[PromptMessage]: + prompt_messages: list[PromptMessage] = [] + + for part in message.parts: + if isinstance(part, SystemPromptPart): + prompt_messages.append(SystemPromptMessage(content=part.content)) + elif isinstance(part, UserPromptPart): + prompt_messages.append( + UserPromptMessage(content=_map_user_prompt_content(part.content)) + ) + elif isinstance(part, ToolReturnPart): + prompt_messages.append(_map_tool_return_part_to_prompt_message(part)) + elif isinstance(part, RetryPromptPart): + if part.tool_name is None: + prompt_messages.append(UserPromptMessage(content=part.model_response())) + else: + prompt_messages.append( + ToolPromptMessage( + content=part.model_response(), + tool_call_id=part.tool_call_id, + name=part.tool_name, + ) + ) + else: + assert_never(part) + + return prompt_messages + + +def _map_tool_return_part_to_prompt_message(part: ToolReturnPart) -> ToolPromptMessage: + items = part.content_items(mode="str") + if len(items) == 1 and isinstance(items[0], str): + content: str | list[PromptMessageContentUnionTypes] | None = items[0] + else: + content_items: list[PromptMessageContentUnionTypes] = [] + for item in items: + if isinstance(item, str): + content_items.append(TextPromptMessageContent(data=item)) + elif isinstance(item, CachePoint): + continue + elif _is_multi_modal_content(item): + content_items.append(_map_multi_modal_user_content(item)) + else: + raise UnexpectedModelBehavior( + f"Unsupported daemon tool message content: {type(item).__name__}" + ) + content = content_items or None + + return ToolPromptMessage( + content=content, tool_call_id=part.tool_call_id, name=part.tool_name + ) + + +def _map_model_response_to_prompt_message( + message: ModelResponse, +) -> AssistantPromptMessage | None: + content_parts: list[PromptMessageContentUnionTypes] = [] + tool_calls: list[AssistantPromptMessage.ToolCall] = [] + + for part in message.parts: + if isinstance(part, TextPart): + if part.content: + content_parts.append(TextPromptMessageContent(data=part.content)) + elif isinstance(part, ThinkingPart): + if part.content: + content_parts.append( + TextPromptMessageContent( + data=f"{_THINK_START}{part.content}{_THINK_END}" + ) + ) + elif isinstance(part, FilePart): + content_parts.append(_map_binary_content_to_prompt_content(part.content)) + elif isinstance(part, 
ToolCallPart): + tool_calls.append( + AssistantPromptMessage.ToolCall( + id=part.tool_call_id or f"tool-call-{part.tool_name}", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=part.tool_name, + arguments=part.args_as_json_str(), + ), + ) + ) + elif isinstance( + part, BuiltinToolCallPart | BuiltinToolReturnPart | CompactionPart + ): + raise UnexpectedModelBehavior( + f"Unsupported response part for daemon adapter: {type(part).__name__}" + ) + else: + assert_never(part) + + content = _normalize_prompt_content(content_parts) + if content is None and not tool_calls: + return None + + return AssistantPromptMessage(content=content, tool_calls=tool_calls) + + +def _map_user_prompt_content( + content: str | Sequence[UserContent], +) -> str | list[PromptMessageContentUnionTypes] | None: + if isinstance(content, str): + return content + + prompt_content: list[PromptMessageContentUnionTypes] = [] + for item in content: + if isinstance(item, CachePoint): + continue + if isinstance(item, str): + prompt_content.append(TextPromptMessageContent(data=item)) + elif isinstance(item, TextContent): + prompt_content.append(TextPromptMessageContent(data=item.content)) + elif _is_multi_modal_content(item): + prompt_content.append(_map_multi_modal_user_content(item)) + else: + raise UnexpectedModelBehavior(f"Unsupported user prompt content: {type(item).__name__}") + return _normalize_prompt_content(prompt_content) + + +def _is_multi_modal_content(item: object) -> bool: + return isinstance( + item, + ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent | UploadedFile, + ) + + +def _map_multi_modal_user_content( + item: MultiModalContent, +) -> PromptMessageContentUnionTypes: + if isinstance(item, ImageUrl): + detail = ( + ImagePromptMessageContent.DETAIL.HIGH + if _get_detail(item) == _DETAIL_HIGH + else ImagePromptMessageContent.DETAIL.LOW + ) + return ImagePromptMessageContent( + url=item.url, + mime_type=item.media_type, + format=item.format, + filename=f"{item.identifier}.{item.format}", + detail=detail, + ) + if isinstance(item, AudioUrl): + return AudioPromptMessageContent( + url=item.url, + mime_type=item.media_type, + format=item.format, + filename=f"{item.identifier}.{item.format}", + ) + if isinstance(item, VideoUrl): + return VideoPromptMessageContent( + url=item.url, + mime_type=item.media_type, + format=item.format, + filename=f"{item.identifier}.{item.format}", + ) + if isinstance(item, DocumentUrl): + return DocumentPromptMessageContent( + url=item.url, + mime_type=item.media_type, + format=item.format, + filename=f"{item.identifier}.{item.format}", + ) + if isinstance(item, BinaryContent): + return _map_binary_content_to_prompt_content(item) + if isinstance(item, UploadedFile): + raise UnexpectedModelBehavior( + "UploadedFile content is not supported by the daemon adapter" + ) + assert_never(item) + + +def _map_binary_content_to_prompt_content( + item: BinaryContent, +) -> PromptMessageContentUnionTypes: + filename = f"{item.identifier}.{item.format}" + if item.is_image: + detail = ( + ImagePromptMessageContent.DETAIL.HIGH + if _get_detail(item) == _DETAIL_HIGH + else ImagePromptMessageContent.DETAIL.LOW + ) + return ImagePromptMessageContent( + base64_data=item.base64, + mime_type=item.media_type, + format=item.format, + filename=filename, + detail=detail, + ) + if item.is_audio: + return AudioPromptMessageContent( + base64_data=item.base64, + mime_type=item.media_type, + format=item.format, + filename=filename, + ) + if item.is_video: + return 
VideoPromptMessageContent( + base64_data=item.base64, + mime_type=item.media_type, + format=item.format, + filename=filename, + ) + if item.is_document: + return DocumentPromptMessageContent( + base64_data=item.base64, + mime_type=item.media_type, + format=item.format, + filename=filename, + ) + raise UnexpectedModelBehavior( + f"Unsupported binary media type for daemon adapter: {item.media_type}" + ) + + +def _normalize_prompt_content( + content: list[PromptMessageContentUnionTypes], +) -> str | list[PromptMessageContentUnionTypes] | None: + if not content: + return None + if len(content) == 1 and isinstance(content[0], TextPromptMessageContent): + return content[0].data + return content + + +def _map_tool_definitions_to_prompt_tools( + model_request_parameters: ModelRequestParameters, +) -> list[PromptMessageTool] | None: + tool_definitions = [ + *model_request_parameters.function_tools, + *model_request_parameters.output_tools, + ] + if not tool_definitions: + return None + + return [ + PromptMessageTool( + name=tool_definition.name, + description=tool_definition.description or "", + parameters=cast(dict[str, object], tool_definition.parameters_json_schema), + ) + for tool_definition in tool_definitions + ] + + +def _map_model_settings_to_parameters(model_settings: ModelSettings | None) -> dict[str, object]: + if not model_settings: + return {} + + parameters: dict[str, object] = { + key: value + for key, value in model_settings.items() + if value is not None and key not in {"extra_body", "stop_sequences"} + } + + extra_body = model_settings.get("extra_body") + if isinstance(extra_body, Mapping): + parameters.update(cast(Mapping[str, object], extra_body)) + + return parameters + + +def _get_stop_sequences(model_settings: ModelSettings | None) -> list[str] | None: + if not model_settings: + return None + return list(model_settings.get("stop_sequences") or []) or None + + +def _map_usage(usage: LLMUsage) -> RequestUsage: + return RequestUsage( + input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens + ) + + +def _normalize_finish_reason(finish_reason: str) -> FinishReason: + lowered = finish_reason.lower() + if lowered in {"stop", "length", "content_filter", "error", "tool_call"}: + return cast(FinishReason, lowered) + if lowered in {"tool_calls", "function_call", "function_calls"}: + return "tool_call" + return "error" + + +def _chunk_to_stream_events( + parts_manager: ModelResponsePartsManager, + chunk: LLMResultChunk, + provider_name: str, + embedded_thinking_parser: "_EmbeddedThinkingParser", +) -> list[ModelResponseStreamEvent]: + events: list[ModelResponseStreamEvent] = [] + message = chunk.delta.message + + if isinstance(message.content, str): + if message.content: + events.extend( + embedded_thinking_parser.parse( + parts_manager, message.content, provider_name + ) + ) + elif isinstance(message.content, list): + for part in _map_assistant_content_to_response_parts(message.content): + if isinstance(part, TextPart): + events.extend( + parts_manager.handle_text_delta( + vendor_part_id=None, + content=part.content, + provider_name=provider_name, + ) + ) + else: + events.append(parts_manager.handle_part(vendor_part_id=None, part=part)) + + for index, tool_call in enumerate(message.tool_calls): + vendor_id = tool_call.id or f"chunk-{chunk.delta.index}-tool-{index}" + events.append( + parts_manager.handle_tool_call_part( + vendor_part_id=vendor_id, + tool_name=tool_call.function.name, + args=tool_call.function.arguments, + tool_call_id=tool_call.id, + 
provider_name=provider_name, + ) + ) + + return events + + +def _map_assistant_content_to_response_parts( + content: Sequence[PromptMessageContentUnionTypes], +) -> list[ModelResponsePart]: + response_parts: list[ModelResponsePart] = [] + + for item in content: + if isinstance(item, TextPromptMessageContent): + if item.data: + response_parts.extend(_parse_assistant_text_parts(item.data)) + elif isinstance( + item, + ImagePromptMessageContent + | AudioPromptMessageContent + | VideoPromptMessageContent + | DocumentPromptMessageContent, + ): + if item.url: + raise UnexpectedModelBehavior( + "URL-based assistant multimodal output is not supported by the daemon adapter" + ) + if not item.base64_data: + continue + response_parts.append( + FilePart( + content=BinaryContent( + data=base64.b64decode(item.base64_data), + media_type=item.mime_type, + ), + provider_name=None, + ) + ) + else: + assert_never(item) + + return response_parts + + +def _get_detail(item: ImageUrl | BinaryContent) -> str | None: + metadata = item.vendor_metadata or {} + detail = metadata.get("detail") + return detail if isinstance(detail, str) else None + + +def _parse_assistant_text_parts(content: str) -> list[ModelResponsePart]: + response_parts: list[ModelResponsePart] = [] + cursor = 0 + + for match in _THINK_TAG_PATTERN.finditer(content): + if match.start() > cursor: + response_parts.append( + TextPart(content=content[cursor : match.start()], provider_name=None) + ) + + thinking_content = match.group(1).strip("\n") + if thinking_content: + response_parts.append( + ThinkingPart(content=thinking_content, provider_name=None) + ) + cursor = match.end() + + if cursor < len(content): + response_parts.append(TextPart(content=content[cursor:], provider_name=None)) + + if response_parts: + return response_parts + return [TextPart(content=content, provider_name=None)] + + +@dataclass(slots=True) +class _EmbeddedThinkingParser: + _pending: str = "" + _inside_thinking: bool = False + + def parse( + self, + parts_manager: ModelResponsePartsManager, + content: str, + provider_name: str, + ) -> list[ModelResponseStreamEvent]: + events: list[ModelResponseStreamEvent] = [] + buffer = self._pending + content + self._pending = "" + + while buffer: + if self._inside_thinking: + end_index = buffer.find(_THINK_CLOSE_TAG) + if end_index >= 0: + if end_index > 0: + events.extend( + parts_manager.handle_thinking_delta( + vendor_part_id=None, + content=buffer[:end_index], + provider_name=provider_name, + ) + ) + buffer = buffer[end_index + len(_THINK_CLOSE_TAG) :] + self._inside_thinking = False + continue + + safe_content, self._pending = _split_incomplete_tag_suffix( + buffer, _THINK_CLOSE_TAG + ) + if safe_content: + events.extend( + parts_manager.handle_thinking_delta( + vendor_part_id=None, + content=safe_content, + provider_name=provider_name, + ) + ) + break + + start_index = buffer.find(_THINK_OPEN_TAG) + if start_index >= 0: + if start_index > 0: + events.extend( + parts_manager.handle_text_delta( + vendor_part_id=None, + content=buffer[:start_index], + provider_name=provider_name, + ) + ) + buffer = buffer[start_index + len(_THINK_OPEN_TAG) :] + self._inside_thinking = True + continue + + safe_content, self._pending = _split_incomplete_tag_suffix( + buffer, _THINK_OPEN_TAG + ) + if safe_content: + events.extend( + parts_manager.handle_text_delta( + vendor_part_id=None, + content=safe_content, + provider_name=provider_name, + ) + ) + break + + return events + + def flush( + self, + parts_manager: ModelResponsePartsManager, + 
provider_name: str, + ) -> list[ModelResponseStreamEvent]: + if not self._pending: + return [] + + pending = self._pending + self._pending = "" + if self._inside_thinking: + return list( + parts_manager.handle_thinking_delta( + vendor_part_id=None, + content=pending, + provider_name=provider_name, + ) + ) + return list( + parts_manager.handle_text_delta( + vendor_part_id=None, + content=pending, + provider_name=provider_name, + ) + ) + + +def _split_incomplete_tag_suffix(content: str, tag: str) -> tuple[str, str]: + for suffix_length in range(len(tag) - 1, 0, -1): + if content.endswith(tag[:suffix_length]): + return content[:-suffix_length], content[-suffix_length:] + return content, "" diff --git a/dify-agent/src/dify_agent/adapters/llm/provider.py b/dify-agent/src/dify_agent/adapters/llm/provider.py new file mode 100644 index 0000000000..9b2bf35e73 --- /dev/null +++ b/dify-agent/src/dify_agent/adapters/llm/provider.py @@ -0,0 +1,252 @@ +"""Dify plugin-daemon provider for Pydantic AI LLM adapters.""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator, Callable, Mapping +from dataclasses import dataclass, field +from typing import NoReturn + +import httpx +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from pydantic import BaseModel +from typing_extensions import override + +from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UnexpectedModelBehavior, UserError +from pydantic_ai.providers import Provider + +_DEFAULT_DAEMON_TIMEOUT: float | httpx.Timeout | None = 600.0 + + +class PluginDaemonBasicResponse(BaseModel): + code: int + message: str + data: object | None = None + + +@dataclass(slots=True) +class DifyPluginDaemonLLMClient: + plugin_daemon_url: str + plugin_daemon_api_key: str + tenant_id: str + plugin_id: str + provider: str + user_id: str | None + http_client: httpx.AsyncClient = field(repr=False) + + def __post_init__(self) -> None: + self.plugin_daemon_url = self.plugin_daemon_url.rstrip("/") + + async def iter_llm_result_chunks( + self, + *, + model: str, + credentials: dict[str, object], + prompt_messages: list[PromptMessage], + model_parameters: dict[str, object], + tools: list[PromptMessageTool] | None, + stop: list[str] | None, + stream: bool, + ) -> AsyncIterator[LLMResultChunk]: + async for item in self._iter_stream_response( + model_name=model, + path=f"plugin/{self.tenant_id}/dispatch/llm/invoke", + request_data={ + "provider": self.provider, + "model_type": "llm", + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + }, + response_model=LLMResultChunk, + ): + yield item + + async def _iter_stream_response[T: BaseModel]( + self, + *, + model_name: str, + path: str, + request_data: Mapping[str, object], + response_model: type[T], + ) -> AsyncIterator[T]: + payload: dict[str, object] = {"data": _to_jsonable(request_data)} + if self.user_id is not None: + payload["user_id"] = self.user_id + + headers = { + "X-Api-Key": self.plugin_daemon_api_key, + "X-Plugin-ID": self.plugin_id, + "Content-Type": "application/json", + } + url = f"{self.plugin_daemon_url}/{path}" + + async with self.http_client.stream("POST", url, headers=headers, json=payload) as response: + if response.is_error: + body = (await response.aread()).decode("utf-8", errors="replace") + error = 
_decode_plugin_daemon_error_payload(body) + if error is not None: + _raise_plugin_daemon_error( + model_name=model_name, + error_type=error["error_type"], + message=error["message"], + status_code=response.status_code, + body=error, + ) + raise ModelHTTPError(response.status_code, model_name, body or None) + + async for raw_line in response.aiter_lines(): + line = raw_line.strip() + if not line: + continue + if line.startswith("data:"): + line = line[5:].strip() + + wrapped = PluginDaemonBasicResponse.model_validate_json(line) + if wrapped.code != 0: + error = _decode_plugin_daemon_error_payload(wrapped.message) + if error is not None: + _raise_plugin_daemon_error( + model_name=model_name, + error_type=error["error_type"], + message=error["message"], + body=error, + ) + raise ModelAPIError( + model_name, + f"Plugin daemon returned error code {wrapped.code}: {wrapped.message}", + ) + if wrapped.data is None: + raise UnexpectedModelBehavior("Plugin daemon returned an empty stream item") + yield response_model.model_validate(wrapped.data) + + +@dataclass(slots=True, kw_only=True) +class DifyPluginDaemonProvider(Provider[DifyPluginDaemonLLMClient]): + """Pydantic AI provider for Dify plugin-daemon dispatch requests.""" + + tenant_id: str + plugin_id: str + plugin_provider: str + plugin_daemon_url: str + plugin_daemon_api_key: str = field(repr=False) + user_id: str | None = None + timeout: float | httpx.Timeout | None = _DEFAULT_DAEMON_TIMEOUT + _client: DifyPluginDaemonLLMClient = field(init=False, repr=False) + _own_http_client: httpx.AsyncClient | None = field(init=False, default=None, repr=False) + _http_client_factory: Callable[[], httpx.AsyncClient] | None = field(init=False, default=None, repr=False) + + def __post_init__(self) -> None: + self.plugin_daemon_url = self.plugin_daemon_url.rstrip("/") + self._http_client_factory = self._make_http_client + http_client = self._make_http_client() + self._own_http_client = http_client + self._client = DifyPluginDaemonLLMClient( + plugin_daemon_url=self.plugin_daemon_url, + plugin_daemon_api_key=self.plugin_daemon_api_key, + tenant_id=self.tenant_id, + plugin_id=self.plugin_id, + provider=self.plugin_provider, + user_id=self.user_id, + http_client=http_client, + ) + + def _make_http_client(self) -> httpx.AsyncClient: + return httpx.AsyncClient(timeout=self.timeout, trust_env=False) + + @override + def _set_http_client(self, http_client: httpx.AsyncClient) -> None: + self._client.http_client = http_client + + @property + @override + def name(self) -> str: + return f"DifyPlugin/{self.plugin_provider}" + + @property + @override + def base_url(self) -> str: + return self.plugin_daemon_url + + @property + @override + def client(self) -> DifyPluginDaemonLLMClient: + return self._client + + +def _to_jsonable(value: object) -> object: + if isinstance(value, BaseModel): + return value.model_dump(mode="json") + if isinstance(value, dict): + return {key: _to_jsonable(item) for key, item in value.items()} + if isinstance(value, list | tuple): + return [_to_jsonable(item) for item in value] + return value + + +def _decode_plugin_daemon_error_payload(raw_message: str) -> dict[str, str] | None: + try: + parsed = json.loads(raw_message) + except json.JSONDecodeError: + return None + + if not isinstance(parsed, dict): + return None + + error_type = parsed.get("error_type") + message = parsed.get("message") + if not isinstance(error_type, str) or not isinstance(message, str): + return None + return {"error_type": error_type, "message": message} + + +def 
_raise_plugin_daemon_error( + *, + model_name: str, + error_type: str, + message: str, + status_code: int | None = None, + body: object | None = None, +) -> NoReturn: + http_error_body = body or {"error_type": error_type, "message": message} + + match error_type: + case "PluginInvokeError": + nested_error = _decode_plugin_daemon_error_payload(message) + if nested_error is not None: + _raise_plugin_daemon_error( + model_name=model_name, + error_type=nested_error["error_type"], + message=nested_error["message"], + status_code=status_code, + body=nested_error, + ) + raise ModelAPIError(model_name, message) + case "PluginDaemonUnauthorizedError" | "InvokeAuthorizationError": + raise ModelHTTPError(status_code or 401, model_name, http_error_body) + case "PluginPermissionDeniedError": + raise ModelHTTPError(status_code or 403, model_name, http_error_body) + case ( + "PluginDaemonBadRequestError" + | "InvokeBadRequestError" + | "CredentialsValidateFailedError" + | "PluginUniqueIdentifierError" + ): + raise ModelHTTPError(status_code or 400, model_name, http_error_body) + case "EndpointSetupFailedError" | "TriggerProviderCredentialValidationError": + raise UserError(message) + case "PluginDaemonNotFoundError" | "PluginNotFoundError": + raise ModelHTTPError(status_code or 404, model_name, http_error_body) + case "InvokeRateLimitError": + raise ModelHTTPError(status_code or 429, model_name, http_error_body) + case "PluginDaemonInternalServerError" | "PluginDaemonInnerError": + raise ModelHTTPError(status_code or 500, model_name, http_error_body) + case "InvokeConnectionError" | "InvokeServerUnavailableError": + raise ModelHTTPError(status_code or 503, model_name, http_error_body) + case _: + raise ModelAPIError(model_name, f"{error_type}: {message}") diff --git a/dify-agent/tests/unit/dify_agent/adapters/__init__.py b/dify-agent/tests/unit/dify_agent/adapters/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dify-agent/tests/unit/dify_agent/adapters/llm/__init__.py b/dify-agent/tests/unit/dify_agent/adapters/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dify-agent/tests/unit/dify_agent/adapters/llm/_test_support.py b/dify-agent/tests/unit/dify_agent/adapters/llm/_test_support.py new file mode 100644 index 0000000000..b37232c08c --- /dev/null +++ b/dify-agent/tests/unit/dify_agent/adapters/llm/_test_support.py @@ -0,0 +1,91 @@ +import json +from decimal import Decimal + +import httpx +from graphon.model_runtime.entities.llm_entities import ( + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, +) +from pydantic import BaseModel + + +def make_usage(prompt_tokens: int = 3, completion_tokens: int = 5) -> LLMUsage: + return LLMUsage( + prompt_tokens=prompt_tokens, + prompt_unit_price=Decimal("0"), + prompt_price_unit=Decimal("0"), + prompt_price=Decimal("0"), + completion_tokens=completion_tokens, + completion_unit_price=Decimal("0"), + completion_price_unit=Decimal("0"), + completion_price=Decimal("0"), + total_tokens=prompt_tokens + completion_tokens, + total_price=Decimal("0"), + currency="USD", + latency=0.1, + ) + + +def single_text_chunk( + text: str, + *, + prompt_tokens: int = 3, + completion_tokens: int = 5, +) -> list[LLMResultChunk]: + return [ + LLMResultChunk( + model="demo-model", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=text, tool_calls=[]), + usage=make_usage( + prompt_tokens=prompt_tokens, 
completion_tokens=completion_tokens + ), + ), + ) + ] + + +def wrap_plugin_daemon_stream_item(item: object) -> str: + if isinstance(item, BaseModel): + data = item.model_dump(mode="json") + else: + data = item + return f"data: {json.dumps({'code': 0, 'message': '', 'data': data})}\n\n" + + +def build_stream_response(*items: object, status_code: int = 200) -> httpx.Response: + body = "".join(wrap_plugin_daemon_stream_item(item) for item in items) + return httpx.Response( + status_code=status_code, + headers={"content-type": "text/event-stream"}, + content=body.encode("utf-8"), + ) + + +def build_error_response( + error_type: str, message: str, *, status_code: int +) -> httpx.Response: + return httpx.Response( + status_code=status_code, + headers={"content-type": "application/json"}, + content=json.dumps({"error_type": error_type, "message": message}).encode( + "utf-8" + ), + ) + + +def build_stream_error( + error_type: str, message: str, *, code: int = -500 +) -> httpx.Response: + return httpx.Response( + status_code=200, + headers={"content-type": "text/event-stream"}, + content=( + f"data: {json.dumps({'code': code, 'message': json.dumps({'error_type': error_type, 'message': message}), 'data': None})}\n\n" + ).encode("utf-8"), + ) diff --git a/dify-agent/tests/unit/dify_agent/adapters/llm/test_model.py b/dify-agent/tests/unit/dify_agent/adapters/llm/test_model.py new file mode 100644 index 0000000000..e0595139e8 --- /dev/null +++ b/dify-agent/tests/unit/dify_agent/adapters/llm/test_model.py @@ -0,0 +1,374 @@ +import json +import unittest +from contextlib import asynccontextmanager +from typing import cast +from unittest.mock import patch + +import httpx +from pydantic_ai.exceptions import ModelHTTPError, UserError +from pydantic_ai.messages import ( + InstructionPart, + ModelRequest, + ModelResponse, + RetryPromptPart, + SystemPromptPart, + TextPart, + ThinkingPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, +) +from pydantic_ai.models import ModelRequestParameters +from pydantic_ai.tools import ToolDefinition + +from dify_agent.adapters.llm import DifyLLMAdapterModel, DifyPluginDaemonProvider + +from ._test_support import ( + AssistantPromptMessage, + LLMResultChunk, + LLMResultChunkDelta, + build_error_response, + build_stream_error, + build_stream_response, + make_usage, + single_text_chunk, +) + + +class DifyLLMAdapterModelTests(unittest.IsolatedAsyncioTestCase): + def make_provider( + self, + *, + user_id: str | None = None, + ) -> DifyPluginDaemonProvider: + return DifyPluginDaemonProvider( + tenant_id="tenant-1", + plugin_id="langgenius/openai", + plugin_provider="openai", + plugin_daemon_url="http://plugin-daemon", + plugin_daemon_api_key="daemon-secret", + user_id=user_id, + ) + + @asynccontextmanager + async def mock_daemon_stream(self, handler: httpx.MockTransport): + @asynccontextmanager + async def mock_stream( + client: httpx.AsyncClient, + method: str, + url: str, + **kwargs: object, + ): + request = client.build_request( + method, + url, + headers=cast(dict[str, str] | None, kwargs.get("headers")), + json=kwargs.get("json"), + ) + yield handler.handle_request(request) + + with patch.object(httpx.AsyncClient, "stream", new=mock_stream): + yield + + async def test_request_uses_plugin_daemon_dispatch_contract(self) -> None: + messages = [ + ModelRequest( + parts=[ + SystemPromptPart("request system"), + UserPromptPart("hello"), + ToolReturnPart( + tool_name="lookup", + content={"city": "Paris"}, + tool_call_id="tool-1", + ), + RetryPromptPart( + content="try again", 
tool_name="lookup", tool_call_id="tool-1" + ), + ] + ), + ModelResponse( + parts=[ + TextPart(content="previous answer"), + ToolCallPart( + tool_name="lookup", + args='{"city":"Paris"}', + tool_call_id="tool-1", + ), + ] + ), + ] + request_parameters = ModelRequestParameters( + function_tools=[ + ToolDefinition( + name="weather", + description="Look up the weather", + parameters_json_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + ) + ], + instruction_parts=[InstructionPart(content="be concise")], + ) + + def handler(request: httpx.Request) -> httpx.Response: + self.assertEqual(request.method, "POST") + self.assertEqual(request.url.path, "/plugin/tenant-1/dispatch/llm/invoke") + self.assertEqual(request.headers["X-Api-Key"], "daemon-secret") + self.assertEqual(request.headers["X-Plugin-ID"], "langgenius/openai") + + payload = json.loads(request.content.decode("utf-8")) + self.assertEqual(payload["user_id"], "user-123") + data = payload["data"] + self.assertEqual(data["provider"], "openai") + self.assertEqual(data["model_type"], "llm") + self.assertEqual(data["model"], "demo-model") + self.assertEqual(data["credentials"], {"api_key": "secret"}) + self.assertEqual( + data["model_parameters"], + {"temperature": 0.2, "max_tokens": 128, "logit_bias": {"1": 2}}, + ) + self.assertEqual(data["stop"], ["END"]) + self.assertFalse(data["stream"]) + self.assertEqual(data["tools"][0]["name"], "weather") + self.assertEqual(data["prompt_messages"][0]["role"], "system") + self.assertEqual(data["prompt_messages"][0]["content"], "request system") + self.assertEqual(data["prompt_messages"][1]["content"], "be concise") + self.assertEqual(data["prompt_messages"][2]["content"], "hello") + self.assertEqual(data["prompt_messages"][3]["role"], "tool") + self.assertEqual(data["prompt_messages"][4]["role"], "tool") + self.assertEqual(data["prompt_messages"][5]["role"], "assistant") + return build_stream_response( + LLMResultChunk( + model="demo-model", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content="adapter response", tool_calls=[] + ), + usage=make_usage(prompt_tokens=11, completion_tokens=7), + ), + ) + ) + + async with self.mock_daemon_stream(httpx.MockTransport(handler)): + adapter = DifyLLMAdapterModel( + "demo-model", + self.make_provider(user_id="user-123"), + credentials={"api_key": "secret"}, + model_settings={"temperature": 0.2, "stop_sequences": ["DEFAULT_STOP"]}, + ) + + response = await adapter.request( + messages, + model_settings={"max_tokens": 128, "logit_bias": {"1": 2}, "stop_sequences": ["END"]}, + model_request_parameters=request_parameters, + ) + + self.assertEqual(response.model_name, "demo-model") + self.assertEqual(response.provider_name, "DifyPlugin/openai") + self.assertEqual(response.usage.input_tokens, 11) + self.assertEqual(response.usage.output_tokens, 7) + self.assertEqual(response.parts[0].part_kind, "text") + self.assertEqual(cast(TextPart, response.parts[0]).content, "adapter response") + + async def test_request_returns_a_response(self) -> None: + def handler(_request: httpx.Request) -> httpx.Response: + return build_stream_response( + *single_text_chunk( + "adapter response", prompt_tokens=11, completion_tokens=7 + ) + ) + + async with self.mock_daemon_stream(httpx.MockTransport(handler)): + adapter = DifyLLMAdapterModel( + "demo-model", + self.make_provider(), + credentials={"api_key": "secret"}, + ) + + response = await adapter.request( + [ModelRequest(parts=[UserPromptPart("hello")])], + model_settings=None, 
+                model_request_parameters=ModelRequestParameters(),
+            )
+
+        self.assertEqual(response.model_name, "demo-model")
+        self.assertEqual(response.parts[0].part_kind, "text")
+        self.assertEqual(cast(TextPart, response.parts[0]).content, "adapter response")
+        self.assertEqual(response.usage.input_tokens, 11)
+        self.assertEqual(response.usage.output_tokens, 7)
+
+    async def test_request_stream_yields_response_parts_and_usage(self) -> None:
+        def handler(_request: httpx.Request) -> httpx.Response:
+            return build_stream_response(
+                LLMResultChunk(
+                    model="demo-model",
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=AssistantPromptMessage(content="hello ", tool_calls=[]),
+                    ),
+                ),
+                LLMResultChunk(
+                    model="demo-model",
+                    delta=LLMResultChunkDelta(
+                        index=1,
+                        message=AssistantPromptMessage(
+                            content="",
+                            tool_calls=[
+                                AssistantPromptMessage.ToolCall(
+                                    id="call-1",
+                                    type="function",
+                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                        name="weather",
+                                        arguments='{"city":"Paris"}',
+                                    ),
+                                )
+                            ],
+                        ),
+                    ),
+                ),
+                LLMResultChunk(
+                    model="demo-model",
+                    delta=LLMResultChunkDelta(
+                        index=2,
+                        message=AssistantPromptMessage(content="world", tool_calls=[]),
+                        usage=make_usage(prompt_tokens=6, completion_tokens=4),
+                        finish_reason="tool_calls",
+                    ),
+                ),
+            )
+
+        async with self.mock_daemon_stream(httpx.MockTransport(handler)):
+            adapter = DifyLLMAdapterModel(
+                "demo-model",
+                self.make_provider(),
+                credentials={"api_key": "secret"},
+            )
+
+            async with adapter.request_stream(
+                [ModelRequest(parts=[UserPromptPart("hello")])],
+                model_settings=None,
+                model_request_parameters=ModelRequestParameters(),
+            ) as stream:
+                events = [event async for event in stream]
+                response = stream.get()
+
+        self.assertTrue(events)
+        self.assertEqual(response.usage.input_tokens, 6)
+        self.assertEqual(response.usage.output_tokens, 4)
+        self.assertEqual(response.finish_reason, "tool_call")
+        self.assertEqual(response.parts[0].part_kind, "text")
+        self.assertEqual(cast(TextPart, response.parts[0]).content, "hello ")
+        self.assertEqual(response.parts[1].part_kind, "tool-call")
+        self.assertEqual(cast(ToolCallPart, response.parts[1]).tool_name, "weather")
+        self.assertEqual(response.parts[2].part_kind, "text")
+        self.assertEqual(cast(TextPart, response.parts[2]).content, "world")
+
+    async def test_request_splits_embedded_thinking_tags_into_parts(self) -> None:
+        def handler(_request: httpx.Request) -> httpx.Response:
+            return build_stream_response(
+                *single_text_chunk("before<think>reasoning</think>after")
+            )
+
+        async with self.mock_daemon_stream(httpx.MockTransport(handler)):
+            adapter = DifyLLMAdapterModel(
+                "demo-model",
+                self.make_provider(),
+                credentials={"api_key": "secret"},
+            )
+
+            response = await adapter.request(
+                [ModelRequest(parts=[UserPromptPart("hello")])],
+                model_settings=None,
+                model_request_parameters=ModelRequestParameters(),
+            )
+
+        self.assertEqual(
+            [part.part_kind for part in response.parts], ["text", "thinking", "text"]
+        )
+        self.assertEqual(cast(TextPart, response.parts[0]).content, "before")
+        self.assertEqual(cast(ThinkingPart, response.parts[1]).content, "reasoning")
+        self.assertEqual(cast(TextPart, response.parts[2]).content, "after")
+
+    async def test_request_maps_stream_envelope_rate_limit_error_to_http_error(
+        self,
+    ) -> None:
+        def handler(_request: httpx.Request) -> httpx.Response:
+            return build_stream_error(
+                "PluginInvokeError",
+                json.dumps(
+                    {"error_type": "InvokeRateLimitError", "message": "too many"}
+                ),
+            )
+
+        async with
self.mock_daemon_stream(httpx.MockTransport(handler)): + adapter = DifyLLMAdapterModel( + "demo-model", + self.make_provider(), + credentials={"api_key": "secret"}, + ) + + with self.assertRaises(ModelHTTPError) as context: + await adapter.request( + [ModelRequest(parts=[UserPromptPart("hello")])], + model_settings=None, + model_request_parameters=ModelRequestParameters(), + ) + + self.assertEqual(context.exception.status_code, 429) + self.assertEqual( + context.exception.body, + {"error_type": "InvokeRateLimitError", "message": "too many"}, + ) + + async def test_request_maps_http_error_payload_to_http_error(self) -> None: + def handler(_request: httpx.Request) -> httpx.Response: + return build_error_response( + "PluginDaemonUnauthorizedError", "invalid api key", status_code=401 + ) + + async with self.mock_daemon_stream(httpx.MockTransport(handler)): + adapter = DifyLLMAdapterModel( + "demo-model", + self.make_provider(), + credentials={"api_key": "secret"}, + ) + + with self.assertRaises(ModelHTTPError) as context: + await adapter.request( + [ModelRequest(parts=[UserPromptPart("hello")])], + model_settings=None, + model_request_parameters=ModelRequestParameters(), + ) + + self.assertEqual(context.exception.status_code, 401) + self.assertEqual( + context.exception.body, + { + "error_type": "PluginDaemonUnauthorizedError", + "message": "invalid api key", + }, + ) + + async def test_request_maps_endpoint_setup_error_to_user_error(self) -> None: + def handler(_request: httpx.Request) -> httpx.Response: + return build_stream_error( + "EndpointSetupFailedError", "missing endpoint config" + ) + + async with self.mock_daemon_stream(httpx.MockTransport(handler)): + adapter = DifyLLMAdapterModel( + "demo-model", + self.make_provider(), + credentials={"api_key": "secret"}, + ) + + with self.assertRaises(UserError) as context: + await adapter.request( + [ModelRequest(parts=[UserPromptPart("hello")])], + model_settings=None, + model_request_parameters=ModelRequestParameters(), + ) + + self.assertEqual(str(context.exception), "missing endpoint config")
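
For quick reference, a minimal sketch of driving the adapter directly, without wrapping it in an `Agent`. It reuses the call shape exercised by the unit tests above; the provider arguments, credentials, and model name are placeholders rather than values shipped in this patch.

import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters

from dify_agent import DifyLLMAdapterModel, DifyPluginDaemonProvider


async def demo() -> None:
    # Placeholder tenant/plugin/daemon values; fill in a real deployment's settings.
    model = DifyLLMAdapterModel(
        "<model-name>",
        DifyPluginDaemonProvider(
            tenant_id="<tenant-id>",
            plugin_id="langgenius/openai",
            plugin_provider="openai",
            plugin_daemon_url="<plugin-daemon-base-url>",
            plugin_daemon_api_key="<plugin-daemon-api-key>",
        ),
        credentials={"api_key": "<provider-api-key>"},
    )
    # Same request contract the tests use: one user prompt, default parameters.
    response = await model.request(
        [ModelRequest(parts=[UserPromptPart("hello")])],
        model_settings=None,
        model_request_parameters=ModelRequestParameters(),
    )
    print(response.parts)
    print(response.usage)


asyncio.run(demo())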