feat(agent): add dify llm adapter

盐粒 Yanli 2026-04-24 22:48:07 +08:00
parent 70bd5439d0
commit 6ca07c4a4e
10 changed files with 1605 additions and 0 deletions


@@ -0,0 +1,78 @@
"""Run a Pydantic AI agent through the Dify plugin-daemon adapter.
Prerequisites:
- Start the plugin daemon from `dify-aio/dify/docker/docker-compose.middleware.yaml`.
- Run the Dify API with `dify-aio/dify/api/.env` so the daemon can resolve tenants/plugins.
- Fill `dify-agent/.env` with a real tenant, plugin, provider, model, and provider credentials.
Example:
uv run --project dify-agent python examples/run_pydantic_ai_agent.py
"""
from __future__ import annotations
import asyncio
import json
import os
from pathlib import Path
from typing import Any
from pydantic_ai import Agent
from dify_agent import DifyLLMAdapterModel, DifyPluginDaemonProvider
PROJECT_ROOT = Path(__file__).resolve().parents[1]
def load_env_file(path: Path) -> None:
"""Load simple KEY=VALUE lines without adding a dotenv dependency."""
if not path.exists():
return
for raw_line in path.read_text().splitlines():
line = raw_line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, value = line.split("=", 1)
os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'"))
def required_env(name: str) -> str:
value = os.environ.get(name)
if value:
return value
raise RuntimeError(f"Missing required environment variable: {name}")
def load_credentials() -> dict[str, Any]:
raw_credentials = required_env("DIFY_AGENT_MODEL_CREDENTIALS_JSON")
credentials = json.loads(raw_credentials)
if not isinstance(credentials, dict):
raise RuntimeError("DIFY_AGENT_MODEL_CREDENTIALS_JSON must be a JSON object")
return credentials
async def main() -> None:
load_env_file(PROJECT_ROOT / ".env")
model = DifyLLMAdapterModel(
required_env("DIFY_AGENT_MODEL_NAME"),
DifyPluginDaemonProvider(
tenant_id=required_env("DIFY_AGENT_TENANT_ID"),
plugin_id=required_env("DIFY_AGENT_PLUGIN_ID"),
plugin_provider=required_env("DIFY_AGENT_PROVIDER"),
plugin_daemon_url=required_env("PLUGIN_DAEMON_URL"),
plugin_daemon_api_key=required_env("PLUGIN_DAEMON_KEY"),
),
credentials=load_credentials(),
)
agent = Agent(model=model)
async with agent.run_stream("Explain the theory of relativity") as run:
async for piece in run.stream_output():
print(piece, end="", flush=True)
print(run.usage())
if __name__ == "__main__":
asyncio.run(main())
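
For reference, a hypothetical `dify-agent/.env` sketch covering the variables this script reads. Every value below is a placeholder (the plugin id, provider, model, and credential keys depend on the plugin you installed), and DIFY_AGENT_MODEL_CREDENTIALS_JSON must stay a single-line JSON object:

# all values are placeholders
DIFY_AGENT_TENANT_ID=00000000-0000-0000-0000-000000000000
DIFY_AGENT_PLUGIN_ID=langgenius/openai
DIFY_AGENT_PROVIDER=openai
DIFY_AGENT_MODEL_NAME=gpt-4o-mini
PLUGIN_DAEMON_URL=http://localhost:5002
PLUGIN_DAEMON_KEY=change-me
DIFY_AGENT_MODEL_CREDENTIALS_JSON={"api_key": "sk-..."}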


@@ -0,0 +1,5 @@
"""Adapters for using Dify components inside the local agent package."""
from .adapters.llm import DifyLLMAdapterModel, DifyPluginDaemonProvider
__all__ = ["DifyLLMAdapterModel", "DifyPluginDaemonProvider"]


@@ -0,0 +1 @@
"""Adapter integrations for Dify agent components."""


@@ -0,0 +1,6 @@
"""LLM adapters for Dify plugin-daemon integrations."""
from .model import DifyLLMAdapterModel
from .provider import DifyPluginDaemonProvider
__all__ = ["DifyLLMAdapterModel", "DifyPluginDaemonProvider"]


@@ -0,0 +1,798 @@
"""Bridge Dify plugin-daemon LLM invocations into Pydantic AI's model interface.
The API and agent layers are clients of the plugin daemon, not direct hosts of provider SDK
implementations. This adapter therefore targets the plugin-daemon dispatch protocol and maps
Pydantic AI messages into the daemon's Graphon-compatible request and stream response schema.
"""
from __future__ import annotations
import base64
import re
from collections.abc import AsyncGenerator, AsyncIterator, Mapping, Sequence
from contextlib import asynccontextmanager
from dataclasses import KW_ONLY, InitVar, dataclass, field
from datetime import datetime, timezone
from typing import cast
from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMUsage
from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentUnionTypes,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
from typing_extensions import assert_never, override
from pydantic_ai._parts_manager import ModelResponsePartsManager
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import (
AudioUrl,
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
CompactionPart,
DocumentUrl,
FilePart,
FinishReason,
ImageUrl,
ModelMessage,
ModelRequest,
ModelResponse,
ModelResponsePart,
ModelResponseStreamEvent,
MultiModalContent,
RetryPromptPart,
SystemPromptPart,
TextContent,
TextPart,
ThinkingPart,
ToolCallPart,
ToolReturnPart,
UploadedFile,
UserContent,
UserPromptPart,
VideoUrl,
)
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.profiles import ModelProfileSpec
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import RequestUsage
from .provider import DifyPluginDaemonLLMClient, DifyPluginDaemonProvider
_THINK_START = "<think>\n"
_THINK_END = "\n</think>"
_THINK_OPEN_TAG = "<think>"
_THINK_CLOSE_TAG = "</think>"
_THINK_TAG_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
_DETAIL_HIGH = "high"
@dataclass(slots=True)
class _DifyRequestInput:
credentials: dict[str, object]
prompt_messages: list[PromptMessage]
model_parameters: dict[str, object]
tools: list[PromptMessageTool] | None
stop_sequences: list[str] | None
@dataclass(slots=True)
class DifyLLMAdapterModel(Model[DifyPluginDaemonLLMClient]):
"""Use a Dify plugin-daemon LLM provider as a Pydantic AI model."""
model: str
daemon_provider: DifyPluginDaemonProvider
_: KW_ONLY
credentials: dict[str, object] = field(default_factory=dict, repr=False)
model_profile: InitVar[ModelProfileSpec | None] = None
model_settings: InitVar[ModelSettings | None] = None
def __post_init__(
self,
model_profile: ModelProfileSpec | None,
model_settings: ModelSettings | None,
) -> None:
Model.__init__(
self,
settings=model_settings,
profile=model_profile or self.daemon_provider.model_profile(self.model),
)
@property
@override
def provider(self) -> DifyPluginDaemonProvider:
return self.daemon_provider
@property
@override
def model_name(self) -> str:
return self.model
@property
@override
def system(self) -> str:
return self.daemon_provider.name
@override
async def request(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> ModelResponse:
prepared_settings, prepared_params = self.prepare_request(
model_settings, model_request_parameters
)
request_input = self._build_request_input(
messages, prepared_settings, prepared_params
)
response = DifyStreamedResponse(
model_request_parameters=prepared_params,
chunks=self.daemon_provider.client.iter_llm_result_chunks(
model=self.model_name,
credentials=request_input.credentials,
prompt_messages=request_input.prompt_messages,
model_parameters=request_input.model_parameters,
tools=request_input.tools,
stop=request_input.stop_sequences,
stream=False,
),
response_model_name=self.model_name,
provider_name_value=self.system,
)
async for _event in response:
pass
return response.get()
@asynccontextmanager
@override
async def request_stream(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
run_context: object | None = None,
) -> AsyncGenerator[StreamedResponse, None]:
del run_context
prepared_settings, prepared_params = self.prepare_request(
model_settings, model_request_parameters
)
request_input = self._build_request_input(
messages, prepared_settings, prepared_params
)
yield DifyStreamedResponse(
model_request_parameters=prepared_params,
chunks=self.daemon_provider.client.iter_llm_result_chunks(
model=self.model_name,
credentials=request_input.credentials,
prompt_messages=request_input.prompt_messages,
model_parameters=request_input.model_parameters,
tools=request_input.tools,
stop=request_input.stop_sequences,
stream=True,
),
response_model_name=self.model_name,
provider_name_value=self.system,
)
def _build_request_input(
self,
messages: Sequence[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> _DifyRequestInput:
return _DifyRequestInput(
credentials=dict(self.credentials),
prompt_messages=_map_messages_to_prompt_messages(
messages, model_request_parameters
),
model_parameters=_map_model_settings_to_parameters(model_settings),
tools=_map_tool_definitions_to_prompt_tools(model_request_parameters),
stop_sequences=_get_stop_sequences(model_settings),
)
@dataclass
class DifyStreamedResponse(StreamedResponse):
    """Map plugin-daemon LLMResultChunk streams onto Pydantic AI stream events."""
chunks: AsyncIterator[LLMResultChunk]
response_model_name: str
provider_name_value: str
_timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
_embedded_thinking_parser: "_EmbeddedThinkingParser" = field(
default_factory=lambda: _EmbeddedThinkingParser()
)
@override
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
async for chunk in self.chunks:
if chunk.delta.usage is not None:
self._usage: RequestUsage = _map_usage(chunk.delta.usage)
if chunk.delta.finish_reason is not None:
self.finish_reason: FinishReason | None = _normalize_finish_reason(
chunk.delta.finish_reason
)
for event in _chunk_to_stream_events(
self._parts_manager,
chunk,
self.provider_name_value,
self._embedded_thinking_parser,
):
yield event
for event in self._embedded_thinking_parser.flush(
self._parts_manager, self.provider_name_value
):
yield event
@property
@override
def model_name(self) -> str:
return self.response_model_name
@property
@override
def provider_name(self) -> str:
return self.provider_name_value
@property
@override
def provider_url(self) -> None:
return None
@property
@override
def timestamp(self) -> datetime:
return self._timestamp
def _map_messages_to_prompt_messages(
messages: Sequence[ModelMessage],
model_request_parameters: ModelRequestParameters,
) -> list[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for message in messages:
if isinstance(message, ModelRequest):
prompt_messages.extend(_map_model_request_to_prompt_messages(message))
elif isinstance(message, ModelResponse):
assistant_message = _map_model_response_to_prompt_message(message)
if assistant_message is not None:
prompt_messages.append(assistant_message)
else:
assert_never(message)
instruction_messages = [
SystemPromptMessage(content=part.content)
for part in (
Model._get_instruction_parts(messages, model_request_parameters) or []
)
if part.content.strip()
]
if instruction_messages:
insert_at = next(
(
index
for index, message in enumerate(prompt_messages)
if not isinstance(message, SystemPromptMessage)
),
len(prompt_messages),
)
prompt_messages[insert_at:insert_at] = instruction_messages
return prompt_messages
def _map_model_request_to_prompt_messages(message: ModelRequest) -> list[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for part in message.parts:
if isinstance(part, SystemPromptPart):
prompt_messages.append(SystemPromptMessage(content=part.content))
elif isinstance(part, UserPromptPart):
prompt_messages.append(
UserPromptMessage(content=_map_user_prompt_content(part.content))
)
elif isinstance(part, ToolReturnPart):
prompt_messages.append(_map_tool_return_part_to_prompt_message(part))
elif isinstance(part, RetryPromptPart):
if part.tool_name is None:
prompt_messages.append(UserPromptMessage(content=part.model_response()))
else:
prompt_messages.append(
ToolPromptMessage(
content=part.model_response(),
tool_call_id=part.tool_call_id,
name=part.tool_name,
)
)
else:
assert_never(part)
return prompt_messages
def _map_tool_return_part_to_prompt_message(part: ToolReturnPart) -> ToolPromptMessage:
items = part.content_items(mode="str")
if len(items) == 1 and isinstance(items[0], str):
content: str | list[PromptMessageContentUnionTypes] | None = items[0]
else:
content_items: list[PromptMessageContentUnionTypes] = []
for item in items:
if isinstance(item, str):
content_items.append(TextPromptMessageContent(data=item))
elif isinstance(item, CachePoint):
continue
elif _is_multi_modal_content(item):
content_items.append(_map_multi_modal_user_content(item))
else:
raise UnexpectedModelBehavior(
f"Unsupported daemon tool message content: {type(item).__name__}"
)
content = content_items or None
return ToolPromptMessage(
content=content, tool_call_id=part.tool_call_id, name=part.tool_name
)
def _map_model_response_to_prompt_message(
message: ModelResponse,
) -> AssistantPromptMessage | None:
content_parts: list[PromptMessageContentUnionTypes] = []
tool_calls: list[AssistantPromptMessage.ToolCall] = []
for part in message.parts:
if isinstance(part, TextPart):
if part.content:
content_parts.append(TextPromptMessageContent(data=part.content))
elif isinstance(part, ThinkingPart):
if part.content:
content_parts.append(
TextPromptMessageContent(
data=f"{_THINK_START}{part.content}{_THINK_END}"
)
)
elif isinstance(part, FilePart):
content_parts.append(_map_binary_content_to_prompt_content(part.content))
elif isinstance(part, ToolCallPart):
tool_calls.append(
AssistantPromptMessage.ToolCall(
id=part.tool_call_id or f"tool-call-{part.tool_name}",
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=part.tool_name,
arguments=part.args_as_json_str(),
),
)
)
elif isinstance(
part, BuiltinToolCallPart | BuiltinToolReturnPart | CompactionPart
):
raise UnexpectedModelBehavior(
f"Unsupported response part for daemon adapter: {type(part).__name__}"
)
else:
assert_never(part)
content = _normalize_prompt_content(content_parts)
if content is None and not tool_calls:
return None
return AssistantPromptMessage(content=content, tool_calls=tool_calls)
def _map_user_prompt_content(
content: str | Sequence[UserContent],
) -> str | list[PromptMessageContentUnionTypes] | None:
if isinstance(content, str):
return content
prompt_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, CachePoint):
continue
if isinstance(item, str):
prompt_content.append(TextPromptMessageContent(data=item))
elif isinstance(item, TextContent):
prompt_content.append(TextPromptMessageContent(data=item.content))
elif _is_multi_modal_content(item):
prompt_content.append(_map_multi_modal_user_content(item))
else:
raise UnexpectedModelBehavior(f"Unsupported user prompt content: {type(item).__name__}")
return _normalize_prompt_content(prompt_content)
def _is_multi_modal_content(item: object) -> bool:
return isinstance(
item,
ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent | UploadedFile,
)
def _map_multi_modal_user_content(
item: MultiModalContent,
) -> PromptMessageContentUnionTypes:
if isinstance(item, ImageUrl):
detail = (
ImagePromptMessageContent.DETAIL.HIGH
if _get_detail(item) == _DETAIL_HIGH
else ImagePromptMessageContent.DETAIL.LOW
)
return ImagePromptMessageContent(
url=item.url,
mime_type=item.media_type,
format=item.format,
filename=f"{item.identifier}.{item.format}",
detail=detail,
)
if isinstance(item, AudioUrl):
return AudioPromptMessageContent(
url=item.url,
mime_type=item.media_type,
format=item.format,
filename=f"{item.identifier}.{item.format}",
)
if isinstance(item, VideoUrl):
return VideoPromptMessageContent(
url=item.url,
mime_type=item.media_type,
format=item.format,
filename=f"{item.identifier}.{item.format}",
)
if isinstance(item, DocumentUrl):
return DocumentPromptMessageContent(
url=item.url,
mime_type=item.media_type,
format=item.format,
filename=f"{item.identifier}.{item.format}",
)
if isinstance(item, BinaryContent):
return _map_binary_content_to_prompt_content(item)
if isinstance(item, UploadedFile):
raise UnexpectedModelBehavior(
"UploadedFile content is not supported by the daemon adapter"
)
assert_never(item)
def _map_binary_content_to_prompt_content(
item: BinaryContent,
) -> PromptMessageContentUnionTypes:
filename = f"{item.identifier}.{item.format}"
if item.is_image:
detail = (
ImagePromptMessageContent.DETAIL.HIGH
if _get_detail(item) == _DETAIL_HIGH
else ImagePromptMessageContent.DETAIL.LOW
)
return ImagePromptMessageContent(
base64_data=item.base64,
mime_type=item.media_type,
format=item.format,
filename=filename,
detail=detail,
)
if item.is_audio:
return AudioPromptMessageContent(
base64_data=item.base64,
mime_type=item.media_type,
format=item.format,
filename=filename,
)
if item.is_video:
return VideoPromptMessageContent(
base64_data=item.base64,
mime_type=item.media_type,
format=item.format,
filename=filename,
)
if item.is_document:
return DocumentPromptMessageContent(
base64_data=item.base64,
mime_type=item.media_type,
format=item.format,
filename=filename,
)
raise UnexpectedModelBehavior(
f"Unsupported binary media type for daemon adapter: {item.media_type}"
)
def _normalize_prompt_content(
content: list[PromptMessageContentUnionTypes],
) -> str | list[PromptMessageContentUnionTypes] | None:
if not content:
return None
if len(content) == 1 and isinstance(content[0], TextPromptMessageContent):
return content[0].data
return content
def _map_tool_definitions_to_prompt_tools(
model_request_parameters: ModelRequestParameters,
) -> list[PromptMessageTool] | None:
tool_definitions = [
*model_request_parameters.function_tools,
*model_request_parameters.output_tools,
]
if not tool_definitions:
return None
return [
PromptMessageTool(
name=tool_definition.name,
description=tool_definition.description or "",
parameters=cast(dict[str, object], tool_definition.parameters_json_schema),
)
for tool_definition in tool_definitions
]
def _map_model_settings_to_parameters(model_settings: ModelSettings | None) -> dict[str, object]:
if not model_settings:
return {}
parameters: dict[str, object] = {
key: value
for key, value in model_settings.items()
if value is not None and key not in {"extra_body", "stop_sequences"}
}
extra_body = model_settings.get("extra_body")
if isinstance(extra_body, Mapping):
parameters.update(cast(Mapping[str, object], extra_body))
return parameters
def _get_stop_sequences(model_settings: ModelSettings | None) -> list[str] | None:
if not model_settings:
return None
return list(model_settings.get("stop_sequences") or []) or None
def _map_usage(usage: LLMUsage) -> RequestUsage:
return RequestUsage(
input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens
)
def _normalize_finish_reason(finish_reason: str) -> FinishReason:
lowered = finish_reason.lower()
if lowered in {"stop", "length", "content_filter", "error", "tool_call"}:
return cast(FinishReason, lowered)
if lowered in {"tool_calls", "function_call", "function_calls"}:
return "tool_call"
return "error"
def _chunk_to_stream_events(
parts_manager: ModelResponsePartsManager,
chunk: LLMResultChunk,
provider_name: str,
embedded_thinking_parser: "_EmbeddedThinkingParser",
) -> list[ModelResponseStreamEvent]:
events: list[ModelResponseStreamEvent] = []
message = chunk.delta.message
if isinstance(message.content, str):
if message.content:
events.extend(
embedded_thinking_parser.parse(
parts_manager, message.content, provider_name
)
)
elif isinstance(message.content, list):
for part in _map_assistant_content_to_response_parts(message.content):
if isinstance(part, TextPart):
events.extend(
parts_manager.handle_text_delta(
vendor_part_id=None,
content=part.content,
provider_name=provider_name,
)
)
else:
events.append(parts_manager.handle_part(vendor_part_id=None, part=part))
for index, tool_call in enumerate(message.tool_calls):
vendor_id = tool_call.id or f"chunk-{chunk.delta.index}-tool-{index}"
events.append(
parts_manager.handle_tool_call_part(
vendor_part_id=vendor_id,
tool_name=tool_call.function.name,
args=tool_call.function.arguments,
tool_call_id=tool_call.id,
provider_name=provider_name,
)
)
return events
def _map_assistant_content_to_response_parts(
content: Sequence[PromptMessageContentUnionTypes],
) -> list[ModelResponsePart]:
response_parts: list[ModelResponsePart] = []
for item in content:
if isinstance(item, TextPromptMessageContent):
if item.data:
response_parts.extend(_parse_assistant_text_parts(item.data))
elif isinstance(
item,
ImagePromptMessageContent
| AudioPromptMessageContent
| VideoPromptMessageContent
| DocumentPromptMessageContent,
):
if item.url:
raise UnexpectedModelBehavior(
"URL-based assistant multimodal output is not supported by the daemon adapter"
)
if not item.base64_data:
continue
response_parts.append(
FilePart(
content=BinaryContent(
data=base64.b64decode(item.base64_data),
media_type=item.mime_type,
),
provider_name=None,
)
)
else:
assert_never(item)
return response_parts
def _get_detail(item: ImageUrl | BinaryContent) -> str | None:
metadata = item.vendor_metadata or {}
detail = metadata.get("detail")
return detail if isinstance(detail, str) else None
def _parse_assistant_text_parts(content: str) -> list[ModelResponsePart]:
response_parts: list[ModelResponsePart] = []
cursor = 0
for match in _THINK_TAG_PATTERN.finditer(content):
if match.start() > cursor:
response_parts.append(
TextPart(content=content[cursor : match.start()], provider_name=None)
)
thinking_content = match.group(1).strip("\n")
if thinking_content:
response_parts.append(
ThinkingPart(content=thinking_content, provider_name=None)
)
cursor = match.end()
if cursor < len(content):
response_parts.append(TextPart(content=content[cursor:], provider_name=None))
if response_parts:
return response_parts
return [TextPart(content=content, provider_name=None)]
@dataclass(slots=True)
class _EmbeddedThinkingParser:
    """Split streamed assistant text into text and thinking deltas.

    A partially received <think> or </think> tag is buffered until the next
    chunk so tags split across chunk boundaries are still recognised.
    """
_pending: str = ""
_inside_thinking: bool = False
def parse(
self,
parts_manager: ModelResponsePartsManager,
content: str,
provider_name: str,
) -> list[ModelResponseStreamEvent]:
events: list[ModelResponseStreamEvent] = []
buffer = self._pending + content
self._pending = ""
while buffer:
if self._inside_thinking:
end_index = buffer.find(_THINK_CLOSE_TAG)
if end_index >= 0:
if end_index > 0:
events.extend(
parts_manager.handle_thinking_delta(
vendor_part_id=None,
content=buffer[:end_index],
provider_name=provider_name,
)
)
buffer = buffer[end_index + len(_THINK_CLOSE_TAG) :]
self._inside_thinking = False
continue
safe_content, self._pending = _split_incomplete_tag_suffix(
buffer, _THINK_CLOSE_TAG
)
if safe_content:
events.extend(
parts_manager.handle_thinking_delta(
vendor_part_id=None,
content=safe_content,
provider_name=provider_name,
)
)
break
start_index = buffer.find(_THINK_OPEN_TAG)
if start_index >= 0:
if start_index > 0:
events.extend(
parts_manager.handle_text_delta(
vendor_part_id=None,
content=buffer[:start_index],
provider_name=provider_name,
)
)
buffer = buffer[start_index + len(_THINK_OPEN_TAG) :]
self._inside_thinking = True
continue
safe_content, self._pending = _split_incomplete_tag_suffix(
buffer, _THINK_OPEN_TAG
)
if safe_content:
events.extend(
parts_manager.handle_text_delta(
vendor_part_id=None,
content=safe_content,
provider_name=provider_name,
)
)
break
return events
def flush(
self,
parts_manager: ModelResponsePartsManager,
provider_name: str,
) -> list[ModelResponseStreamEvent]:
if not self._pending:
return []
pending = self._pending
self._pending = ""
if self._inside_thinking:
return list(
parts_manager.handle_thinking_delta(
vendor_part_id=None,
content=pending,
provider_name=provider_name,
)
)
return list(
parts_manager.handle_text_delta(
vendor_part_id=None,
content=pending,
provider_name=provider_name,
)
)
def _split_incomplete_tag_suffix(content: str, tag: str) -> tuple[str, str]:
    """Split off a trailing partial occurrence of ``tag`` so it can be buffered for the next chunk."""
for suffix_length in range(len(tag) - 1, 0, -1):
if content.endswith(tag[:suffix_length]):
return content[:-suffix_length], content[-suffix_length:]
return content, ""
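
A minimal illustrative sketch of the think-tag splitting above (not part of the commit); it assumes the module is importable as dify_agent.adapters.llm.model, which the package __init__ files imply:

from dify_agent.adapters.llm.model import _parse_assistant_text_parts
# Embedded <think>...</think> spans become ThinkingPart items; the rest stays text.
parts = _parse_assistant_text_parts("before<think>reasoning</think>after")
assert [type(part).__name__ for part in parts] == ["TextPart", "ThinkingPart", "TextPart"]
assert parts[1].content == "reasoning"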


@@ -0,0 +1,252 @@
"""Dify plugin-daemon provider for Pydantic AI LLM adapters."""
from __future__ import annotations
import json
from collections.abc import AsyncIterator, Callable, Mapping
from dataclasses import dataclass, field
from typing import NoReturn
import httpx
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from pydantic import BaseModel
from typing_extensions import override
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UnexpectedModelBehavior, UserError
from pydantic_ai.providers import Provider
_DEFAULT_DAEMON_TIMEOUT: float | httpx.Timeout | None = 600.0
class PluginDaemonBasicResponse(BaseModel):
    """Envelope the plugin daemon wraps around every dispatched response payload."""
code: int
message: str
data: object | None = None
@dataclass(slots=True)
class DifyPluginDaemonLLMClient:
    """Async client for the plugin daemon's LLM invoke dispatch endpoint."""
plugin_daemon_url: str
plugin_daemon_api_key: str
tenant_id: str
plugin_id: str
provider: str
user_id: str | None
http_client: httpx.AsyncClient = field(repr=False)
def __post_init__(self) -> None:
self.plugin_daemon_url = self.plugin_daemon_url.rstrip("/")
async def iter_llm_result_chunks(
self,
*,
model: str,
credentials: dict[str, object],
prompt_messages: list[PromptMessage],
model_parameters: dict[str, object],
tools: list[PromptMessageTool] | None,
stop: list[str] | None,
stream: bool,
) -> AsyncIterator[LLMResultChunk]:
async for item in self._iter_stream_response(
model_name=model,
path=f"plugin/{self.tenant_id}/dispatch/llm/invoke",
request_data={
"provider": self.provider,
"model_type": "llm",
"model": model,
"credentials": credentials,
"prompt_messages": prompt_messages,
"model_parameters": model_parameters,
"tools": tools,
"stop": stop,
"stream": stream,
},
response_model=LLMResultChunk,
):
yield item
async def _iter_stream_response[T: BaseModel](
self,
*,
model_name: str,
path: str,
request_data: Mapping[str, object],
response_model: type[T],
) -> AsyncIterator[T]:
payload: dict[str, object] = {"data": _to_jsonable(request_data)}
if self.user_id is not None:
payload["user_id"] = self.user_id
headers = {
"X-Api-Key": self.plugin_daemon_api_key,
"X-Plugin-ID": self.plugin_id,
"Content-Type": "application/json",
}
url = f"{self.plugin_daemon_url}/{path}"
async with self.http_client.stream("POST", url, headers=headers, json=payload) as response:
if response.is_error:
body = (await response.aread()).decode("utf-8", errors="replace")
error = _decode_plugin_daemon_error_payload(body)
if error is not None:
_raise_plugin_daemon_error(
model_name=model_name,
error_type=error["error_type"],
message=error["message"],
status_code=response.status_code,
body=error,
)
raise ModelHTTPError(response.status_code, model_name, body or None)
async for raw_line in response.aiter_lines():
line = raw_line.strip()
if not line:
continue
if line.startswith("data:"):
line = line[5:].strip()
wrapped = PluginDaemonBasicResponse.model_validate_json(line)
if wrapped.code != 0:
error = _decode_plugin_daemon_error_payload(wrapped.message)
if error is not None:
_raise_plugin_daemon_error(
model_name=model_name,
error_type=error["error_type"],
message=error["message"],
body=error,
)
raise ModelAPIError(
model_name,
f"Plugin daemon returned error code {wrapped.code}: {wrapped.message}",
)
if wrapped.data is None:
raise UnexpectedModelBehavior("Plugin daemon returned an empty stream item")
yield response_model.model_validate(wrapped.data)
@dataclass(slots=True, kw_only=True)
class DifyPluginDaemonProvider(Provider[DifyPluginDaemonLLMClient]):
"""Pydantic AI provider for Dify plugin-daemon dispatch requests."""
tenant_id: str
plugin_id: str
plugin_provider: str
plugin_daemon_url: str
plugin_daemon_api_key: str = field(repr=False)
user_id: str | None = None
timeout: float | httpx.Timeout | None = _DEFAULT_DAEMON_TIMEOUT
_client: DifyPluginDaemonLLMClient = field(init=False, repr=False)
_own_http_client: httpx.AsyncClient | None = field(init=False, default=None, repr=False)
_http_client_factory: Callable[[], httpx.AsyncClient] | None = field(init=False, default=None, repr=False)
def __post_init__(self) -> None:
self.plugin_daemon_url = self.plugin_daemon_url.rstrip("/")
self._http_client_factory = self._make_http_client
http_client = self._make_http_client()
self._own_http_client = http_client
self._client = DifyPluginDaemonLLMClient(
plugin_daemon_url=self.plugin_daemon_url,
plugin_daemon_api_key=self.plugin_daemon_api_key,
tenant_id=self.tenant_id,
plugin_id=self.plugin_id,
provider=self.plugin_provider,
user_id=self.user_id,
http_client=http_client,
)
def _make_http_client(self) -> httpx.AsyncClient:
return httpx.AsyncClient(timeout=self.timeout, trust_env=False)
@override
def _set_http_client(self, http_client: httpx.AsyncClient) -> None:
self._client.http_client = http_client
@property
@override
def name(self) -> str:
return f"DifyPlugin/{self.plugin_provider}"
@property
@override
def base_url(self) -> str:
return self.plugin_daemon_url
@property
@override
def client(self) -> DifyPluginDaemonLLMClient:
return self._client
def _to_jsonable(value: object) -> object:
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
if isinstance(value, dict):
return {key: _to_jsonable(item) for key, item in value.items()}
if isinstance(value, list | tuple):
return [_to_jsonable(item) for item in value]
return value
def _decode_plugin_daemon_error_payload(raw_message: str) -> dict[str, str] | None:
try:
parsed = json.loads(raw_message)
except json.JSONDecodeError:
return None
if not isinstance(parsed, dict):
return None
error_type = parsed.get("error_type")
message = parsed.get("message")
if not isinstance(error_type, str) or not isinstance(message, str):
return None
return {"error_type": error_type, "message": message}
def _raise_plugin_daemon_error(
*,
model_name: str,
error_type: str,
message: str,
status_code: int | None = None,
body: object | None = None,
) -> NoReturn:
http_error_body = body or {"error_type": error_type, "message": message}
match error_type:
case "PluginInvokeError":
nested_error = _decode_plugin_daemon_error_payload(message)
if nested_error is not None:
_raise_plugin_daemon_error(
model_name=model_name,
error_type=nested_error["error_type"],
message=nested_error["message"],
status_code=status_code,
body=nested_error,
)
raise ModelAPIError(model_name, message)
case "PluginDaemonUnauthorizedError" | "InvokeAuthorizationError":
raise ModelHTTPError(status_code or 401, model_name, http_error_body)
case "PluginPermissionDeniedError":
raise ModelHTTPError(status_code or 403, model_name, http_error_body)
case (
"PluginDaemonBadRequestError"
| "InvokeBadRequestError"
| "CredentialsValidateFailedError"
| "PluginUniqueIdentifierError"
):
raise ModelHTTPError(status_code or 400, model_name, http_error_body)
case "EndpointSetupFailedError" | "TriggerProviderCredentialValidationError":
raise UserError(message)
case "PluginDaemonNotFoundError" | "PluginNotFoundError":
raise ModelHTTPError(status_code or 404, model_name, http_error_body)
case "InvokeRateLimitError":
raise ModelHTTPError(status_code or 429, model_name, http_error_body)
case "PluginDaemonInternalServerError" | "PluginDaemonInnerError":
raise ModelHTTPError(status_code or 500, model_name, http_error_body)
case "InvokeConnectionError" | "InvokeServerUnavailableError":
raise ModelHTTPError(status_code or 503, model_name, http_error_body)
case _:
raise ModelAPIError(model_name, f"{error_type}: {message}")
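
A minimal illustrative sketch of the error decoding above (not part of the commit); it assumes the module is importable as dify_agent.adapters.llm.provider, per the package __init__ imports:

from dify_agent.adapters.llm.provider import _decode_plugin_daemon_error_payload
# Daemon errors arrive as JSON strings carrying error_type/message fields; the
# decoder returns None for anything that does not match that shape.
error = _decode_plugin_daemon_error_payload(
    '{"error_type": "InvokeRateLimitError", "message": "too many requests"}'
)
assert error == {"error_type": "InvokeRateLimitError", "message": "too many requests"}
# _raise_plugin_daemon_error would map this error_type to ModelHTTPError(429, ...).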


@@ -0,0 +1,91 @@
import json
from decimal import Decimal
import httpx
from graphon.model_runtime.entities.llm_entities import (
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
)
from graphon.model_runtime.entities.message_entities import (
AssistantPromptMessage,
)
from pydantic import BaseModel
def make_usage(prompt_tokens: int = 3, completion_tokens: int = 5) -> LLMUsage:
return LLMUsage(
prompt_tokens=prompt_tokens,
prompt_unit_price=Decimal("0"),
prompt_price_unit=Decimal("0"),
prompt_price=Decimal("0"),
completion_tokens=completion_tokens,
completion_unit_price=Decimal("0"),
completion_price_unit=Decimal("0"),
completion_price=Decimal("0"),
total_tokens=prompt_tokens + completion_tokens,
total_price=Decimal("0"),
currency="USD",
latency=0.1,
)
def single_text_chunk(
text: str,
*,
prompt_tokens: int = 3,
completion_tokens: int = 5,
) -> list[LLMResultChunk]:
return [
LLMResultChunk(
model="demo-model",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text, tool_calls=[]),
usage=make_usage(
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
),
),
)
]
def wrap_plugin_daemon_stream_item(item: object) -> str:
if isinstance(item, BaseModel):
data = item.model_dump(mode="json")
else:
data = item
return f"data: {json.dumps({'code': 0, 'message': '', 'data': data})}\n\n"
def build_stream_response(*items: object, status_code: int = 200) -> httpx.Response:
body = "".join(wrap_plugin_daemon_stream_item(item) for item in items)
return httpx.Response(
status_code=status_code,
headers={"content-type": "text/event-stream"},
content=body.encode("utf-8"),
)
def build_error_response(
error_type: str, message: str, *, status_code: int
) -> httpx.Response:
return httpx.Response(
status_code=status_code,
headers={"content-type": "application/json"},
content=json.dumps({"error_type": error_type, "message": message}).encode(
"utf-8"
),
)
def build_stream_error(
error_type: str, message: str, *, code: int = -500
) -> httpx.Response:
return httpx.Response(
status_code=200,
headers={"content-type": "text/event-stream"},
content=(
f"data: {json.dumps({'code': code, 'message': json.dumps({'error_type': error_type, 'message': message}), 'data': None})}\n\n"
).encode("utf-8"),
)


@@ -0,0 +1,374 @@
import json
import unittest
from contextlib import asynccontextmanager
from typing import cast
from unittest.mock import patch
import httpx
from pydantic_ai.exceptions import ModelHTTPError, UserError
from pydantic_ai.messages import (
InstructionPart,
ModelRequest,
ModelResponse,
RetryPromptPart,
SystemPromptPart,
TextPart,
ThinkingPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.tools import ToolDefinition
from dify_agent.adapters.llm import DifyLLMAdapterModel, DifyPluginDaemonProvider
from ._test_support import (
AssistantPromptMessage,
LLMResultChunk,
LLMResultChunkDelta,
build_error_response,
build_stream_error,
build_stream_response,
make_usage,
single_text_chunk,
)
class DifyLLMAdapterModelTests(unittest.IsolatedAsyncioTestCase):
def make_provider(
self,
*,
user_id: str | None = None,
) -> DifyPluginDaemonProvider:
return DifyPluginDaemonProvider(
tenant_id="tenant-1",
plugin_id="langgenius/openai",
plugin_provider="openai",
plugin_daemon_url="http://plugin-daemon",
plugin_daemon_api_key="daemon-secret",
user_id=user_id,
)
@asynccontextmanager
async def mock_daemon_stream(self, handler: httpx.MockTransport):
@asynccontextmanager
async def mock_stream(
client: httpx.AsyncClient,
method: str,
url: str,
**kwargs: object,
):
request = client.build_request(
method,
url,
headers=cast(dict[str, str] | None, kwargs.get("headers")),
json=kwargs.get("json"),
)
yield handler.handle_request(request)
with patch.object(httpx.AsyncClient, "stream", new=mock_stream):
yield
async def test_request_uses_plugin_daemon_dispatch_contract(self) -> None:
messages = [
ModelRequest(
parts=[
SystemPromptPart("request system"),
UserPromptPart("hello"),
ToolReturnPart(
tool_name="lookup",
content={"city": "Paris"},
tool_call_id="tool-1",
),
RetryPromptPart(
content="try again", tool_name="lookup", tool_call_id="tool-1"
),
]
),
ModelResponse(
parts=[
TextPart(content="previous answer"),
ToolCallPart(
tool_name="lookup",
args='{"city":"Paris"}',
tool_call_id="tool-1",
),
]
),
]
request_parameters = ModelRequestParameters(
function_tools=[
ToolDefinition(
name="weather",
description="Look up the weather",
parameters_json_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
},
)
],
instruction_parts=[InstructionPart(content="be concise")],
)
def handler(request: httpx.Request) -> httpx.Response:
self.assertEqual(request.method, "POST")
self.assertEqual(request.url.path, "/plugin/tenant-1/dispatch/llm/invoke")
self.assertEqual(request.headers["X-Api-Key"], "daemon-secret")
self.assertEqual(request.headers["X-Plugin-ID"], "langgenius/openai")
payload = json.loads(request.content.decode("utf-8"))
self.assertEqual(payload["user_id"], "user-123")
data = payload["data"]
self.assertEqual(data["provider"], "openai")
self.assertEqual(data["model_type"], "llm")
self.assertEqual(data["model"], "demo-model")
self.assertEqual(data["credentials"], {"api_key": "secret"})
self.assertEqual(
data["model_parameters"],
{"temperature": 0.2, "max_tokens": 128, "logit_bias": {"1": 2}},
)
self.assertEqual(data["stop"], ["END"])
self.assertFalse(data["stream"])
self.assertEqual(data["tools"][0]["name"], "weather")
self.assertEqual(data["prompt_messages"][0]["role"], "system")
self.assertEqual(data["prompt_messages"][0]["content"], "request system")
self.assertEqual(data["prompt_messages"][1]["content"], "be concise")
self.assertEqual(data["prompt_messages"][2]["content"], "hello")
self.assertEqual(data["prompt_messages"][3]["role"], "tool")
self.assertEqual(data["prompt_messages"][4]["role"], "tool")
self.assertEqual(data["prompt_messages"][5]["role"], "assistant")
return build_stream_response(
LLMResultChunk(
model="demo-model",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content="adapter response", tool_calls=[]
),
usage=make_usage(prompt_tokens=11, completion_tokens=7),
),
)
)
async with self.mock_daemon_stream(httpx.MockTransport(handler)):
adapter = DifyLLMAdapterModel(
"demo-model",
self.make_provider(user_id="user-123"),
credentials={"api_key": "secret"},
model_settings={"temperature": 0.2, "stop_sequences": ["DEFAULT_STOP"]},
)
response = await adapter.request(
messages,
model_settings={"max_tokens": 128, "logit_bias": {"1": 2}, "stop_sequences": ["END"]},
model_request_parameters=request_parameters,
)
self.assertEqual(response.model_name, "demo-model")
self.assertEqual(response.provider_name, "DifyPlugin/openai")
self.assertEqual(response.usage.input_tokens, 11)
self.assertEqual(response.usage.output_tokens, 7)
self.assertEqual(response.parts[0].part_kind, "text")
self.assertEqual(cast(TextPart, response.parts[0]).content, "adapter response")
async def test_request_returns_a_response(self) -> None:
def handler(_request: httpx.Request) -> httpx.Response:
return build_stream_response(
*single_text_chunk(
"adapter response", prompt_tokens=11, completion_tokens=7
)
)
async with self.mock_daemon_stream(httpx.MockTransport(handler)):
adapter = DifyLLMAdapterModel(
"demo-model",
self.make_provider(),
credentials={"api_key": "secret"},
)
response = await adapter.request(
[ModelRequest(parts=[UserPromptPart("hello")])],
model_settings=None,
model_request_parameters=ModelRequestParameters(),
)
self.assertEqual(response.model_name, "demo-model")
self.assertEqual(response.parts[0].part_kind, "text")
self.assertEqual(cast(TextPart, response.parts[0]).content, "adapter response")
self.assertEqual(response.usage.input_tokens, 11)
self.assertEqual(response.usage.output_tokens, 7)
async def test_request_stream_yields_response_parts_and_usage(self) -> None:
def handler(_request: httpx.Request) -> httpx.Response:
return build_stream_response(
LLMResultChunk(
model="demo-model",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content="hello ", tool_calls=[]),
),
),
LLMResultChunk(
model="demo-model",
delta=LLMResultChunkDelta(
index=1,
message=AssistantPromptMessage(
content="",
tool_calls=[
AssistantPromptMessage.ToolCall(
id="call-1",
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name="weather",
arguments='{"city":"Paris"}',
),
)
],
),
),
),
LLMResultChunk(
model="demo-model",
delta=LLMResultChunkDelta(
index=2,
message=AssistantPromptMessage(content="world", tool_calls=[]),
usage=make_usage(prompt_tokens=6, completion_tokens=4),
finish_reason="tool_calls",
),
),
)
async with self.mock_daemon_stream(httpx.MockTransport(handler)):
adapter = DifyLLMAdapterModel(
"demo-model",
self.make_provider(),
credentials={"api_key": "secret"},
)
async with adapter.request_stream(
[ModelRequest(parts=[UserPromptPart("hello")])],
model_settings=None,
model_request_parameters=ModelRequestParameters(),
) as stream:
events = [event async for event in stream]
response = stream.get()
self.assertTrue(events)
self.assertEqual(response.usage.input_tokens, 6)
self.assertEqual(response.usage.output_tokens, 4)
self.assertEqual(response.finish_reason, "tool_call")
self.assertEqual(response.parts[0].part_kind, "text")
self.assertEqual(cast(TextPart, response.parts[0]).content, "hello ")
self.assertEqual(response.parts[1].part_kind, "tool-call")
self.assertEqual(cast(ToolCallPart, response.parts[1]).tool_name, "weather")
self.assertEqual(response.parts[2].part_kind, "text")
self.assertEqual(cast(TextPart, response.parts[2]).content, "world")
async def test_request_splits_embedded_thinking_tags_into_parts(self) -> None:
def handler(_request: httpx.Request) -> httpx.Response:
return build_stream_response(
*single_text_chunk("before<think>reasoning</think>after")
)
async with self.mock_daemon_stream(httpx.MockTransport(handler)):
adapter = DifyLLMAdapterModel(
"demo-model",
self.make_provider(),
credentials={"api_key": "secret"},
)
response = await adapter.request(
[ModelRequest(parts=[UserPromptPart("hello")])],
model_settings=None,
model_request_parameters=ModelRequestParameters(),
)
self.assertEqual(
[part.part_kind for part in response.parts], ["text", "thinking", "text"]
)
self.assertEqual(cast(TextPart, response.parts[0]).content, "before")
self.assertEqual(cast(ThinkingPart, response.parts[1]).content, "reasoning")
self.assertEqual(cast(TextPart, response.parts[2]).content, "after")
async def test_request_maps_stream_envelope_rate_limit_error_to_http_error(
self,
) -> None:
def handler(_request: httpx.Request) -> httpx.Response:
return build_stream_error(
"PluginInvokeError",
json.dumps(
{"error_type": "InvokeRateLimitError", "message": "too many"}
),
)
async with self.mock_daemon_stream(httpx.MockTransport(handler)):
adapter = DifyLLMAdapterModel(
"demo-model",
self.make_provider(),
credentials={"api_key": "secret"},
)
with self.assertRaises(ModelHTTPError) as context:
await adapter.request(
[ModelRequest(parts=[UserPromptPart("hello")])],
model_settings=None,
model_request_parameters=ModelRequestParameters(),
)
self.assertEqual(context.exception.status_code, 429)
self.assertEqual(
context.exception.body,
{"error_type": "InvokeRateLimitError", "message": "too many"},
)
async def test_request_maps_http_error_payload_to_http_error(self) -> None:
def handler(_request: httpx.Request) -> httpx.Response:
return build_error_response(
"PluginDaemonUnauthorizedError", "invalid api key", status_code=401
)
async with self.mock_daemon_stream(httpx.MockTransport(handler)):
adapter = DifyLLMAdapterModel(
"demo-model",
self.make_provider(),
credentials={"api_key": "secret"},
)
with self.assertRaises(ModelHTTPError) as context:
await adapter.request(
[ModelRequest(parts=[UserPromptPart("hello")])],
model_settings=None,
model_request_parameters=ModelRequestParameters(),
)
self.assertEqual(context.exception.status_code, 401)
self.assertEqual(
context.exception.body,
{
"error_type": "PluginDaemonUnauthorizedError",
"message": "invalid api key",
},
)
async def test_request_maps_endpoint_setup_error_to_user_error(self) -> None:
def handler(_request: httpx.Request) -> httpx.Response:
return build_stream_error(
"EndpointSetupFailedError", "missing endpoint config"
)
async with self.mock_daemon_stream(httpx.MockTransport(handler)):
adapter = DifyLLMAdapterModel(
"demo-model",
self.make_provider(),
credentials={"api_key": "secret"},
)
with self.assertRaises(UserError) as context:
await adapter.request(
[ModelRequest(parts=[UserPromptPart("hello")])],
model_settings=None,
model_request_parameters=ModelRequestParameters(),
)
self.assertEqual(str(context.exception), "missing endpoint config")