feat(agent): add dify llm adapter
Commit 6ca07c4a4e (parent 70bd5439d0)
dify-agent/examples/run_pydantic_ai_agent.py (new file, 78 lines)
@@ -0,0 +1,78 @@
"""Run a Pydantic AI agent through the Dify plugin-daemon adapter.

Prerequisites:
- Start the plugin daemon from `dify-aio/dify/docker/docker-compose.middleware.yaml`.
- Run the Dify API with `dify-aio/dify/api/.env` so the daemon can resolve tenants/plugins.
- Fill `dify-agent/.env` with a real tenant, plugin, provider, model, and provider credentials.

Example:
    uv run --project dify-agent python examples/run_pydantic_ai_agent.py
"""

from __future__ import annotations

import asyncio
import json
import os
from pathlib import Path
from typing import Any

from pydantic_ai import Agent

from dify_agent import DifyLLMAdapterModel, DifyPluginDaemonProvider


PROJECT_ROOT = Path(__file__).resolve().parents[1]


def load_env_file(path: Path) -> None:
    """Load simple KEY=VALUE lines without adding a dotenv dependency."""
    if not path.exists():
        return

    for raw_line in path.read_text().splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        os.environ.setdefault(key.strip(), value.strip().strip('"').strip("'"))


def required_env(name: str) -> str:
    value = os.environ.get(name)
    if value:
        return value
    raise RuntimeError(f"Missing required environment variable: {name}")


def load_credentials() -> dict[str, Any]:
    raw_credentials = required_env("DIFY_AGENT_MODEL_CREDENTIALS_JSON")
    credentials = json.loads(raw_credentials)
    if not isinstance(credentials, dict):
        raise RuntimeError("DIFY_AGENT_MODEL_CREDENTIALS_JSON must be a JSON object")
    return credentials


async def main() -> None:
    load_env_file(PROJECT_ROOT / ".env")

    model = DifyLLMAdapterModel(
        required_env("DIFY_AGENT_MODEL_NAME"),
        DifyPluginDaemonProvider(
            tenant_id=required_env("DIFY_AGENT_TENANT_ID"),
            plugin_id=required_env("DIFY_AGENT_PLUGIN_ID"),
            plugin_provider=required_env("DIFY_AGENT_PROVIDER"),
            plugin_daemon_url=required_env("PLUGIN_DAEMON_URL"),
            plugin_daemon_api_key=required_env("PLUGIN_DAEMON_KEY"),
        ),
        credentials=load_credentials(),
    )
    agent = Agent(model=model)
    async with agent.run_stream("Explain the theory of relativity") as run:
        async for piece in run.stream_output():
            print(piece, end="", flush=True)
        print(run.usage())


if __name__ == "__main__":
    asyncio.run(main())
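As a reference for the `dify-agent/.env` step above, a minimal sketch of the variables the script reads via `required_env`. Every value below is a hypothetical placeholder to replace with identifiers and credentials from your own deployment; the credential keys depend on the installed provider plugin.

    # Hypothetical placeholders only; written as KEY=VALUE lines in dify-agent/.env.
    example_env = {
        "DIFY_AGENT_TENANT_ID": "<tenant-uuid>",
        "DIFY_AGENT_PLUGIN_ID": "langgenius/openai",      # plugin id as installed in the daemon
        "DIFY_AGENT_PROVIDER": "openai",                   # provider name inside that plugin
        "DIFY_AGENT_MODEL_NAME": "<model-name>",
        "PLUGIN_DAEMON_URL": "http://localhost:5002",      # wherever the plugin daemon listens
        "PLUGIN_DAEMON_KEY": "<daemon-api-key>",
        "DIFY_AGENT_MODEL_CREDENTIALS_JSON": '{"<credential-key>": "<secret>"}',
    }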
dify-agent/src/dify_agent/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""Adapters for using Dify components inside the local agent package."""

from .adapters.llm import DifyLLMAdapterModel, DifyPluginDaemonProvider

__all__ = ["DifyLLMAdapterModel", "DifyPluginDaemonProvider"]
dify-agent/src/dify_agent/adapters/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Adapter integrations for Dify agent components."""
dify-agent/src/dify_agent/adapters/llm/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
"""LLM adapters for Dify plugin-daemon integrations."""

from .model import DifyLLMAdapterModel
from .provider import DifyPluginDaemonProvider

__all__ = ["DifyLLMAdapterModel", "DifyPluginDaemonProvider"]
dify-agent/src/dify_agent/adapters/llm/model.py (new file, 798 lines)
@@ -0,0 +1,798 @@
"""Bridge Dify plugin-daemon LLM invocations into Pydantic AI's model interface.

The API and agent layers are clients of the plugin daemon, not direct hosts of provider SDK
implementations. This adapter therefore targets the plugin-daemon dispatch protocol and maps
Pydantic AI messages into the daemon's Graphon-compatible request and stream response schema.
"""

from __future__ import annotations

import base64
import re
from collections.abc import AsyncGenerator, AsyncIterator, Mapping, Sequence
from contextlib import asynccontextmanager
from dataclasses import KW_ONLY, InitVar, dataclass, field
from datetime import datetime, timezone
from typing import cast

from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMUsage
from graphon.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    AudioPromptMessageContent,
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentUnionTypes,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
    VideoPromptMessageContent,
)
from typing_extensions import assert_never, override

from pydantic_ai._parts_manager import ModelResponsePartsManager
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import (
    AudioUrl,
    BinaryContent,
    BuiltinToolCallPart,
    BuiltinToolReturnPart,
    CachePoint,
    CompactionPart,
    DocumentUrl,
    FilePart,
    FinishReason,
    ImageUrl,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    ModelResponseStreamEvent,
    MultiModalContent,
    RetryPromptPart,
    SystemPromptPart,
    TextContent,
    TextPart,
    ThinkingPart,
    ToolCallPart,
    ToolReturnPart,
    UploadedFile,
    UserContent,
    UserPromptPart,
    VideoUrl,
)
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.profiles import ModelProfileSpec
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import RequestUsage

from .provider import DifyPluginDaemonLLMClient, DifyPluginDaemonProvider

_THINK_START = "<think>\n"
_THINK_END = "\n</think>"
_THINK_OPEN_TAG = "<think>"
_THINK_CLOSE_TAG = "</think>"
_THINK_TAG_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
_DETAIL_HIGH = "high"


@dataclass(slots=True)
class _DifyRequestInput:
    credentials: dict[str, object]
    prompt_messages: list[PromptMessage]
    model_parameters: dict[str, object]
    tools: list[PromptMessageTool] | None
    stop_sequences: list[str] | None

@dataclass(slots=True)
class DifyLLMAdapterModel(Model[DifyPluginDaemonLLMClient]):
    """Use a Dify plugin-daemon LLM provider as a Pydantic AI model."""

    model: str
    daemon_provider: DifyPluginDaemonProvider
    _: KW_ONLY
    credentials: dict[str, object] = field(default_factory=dict, repr=False)
    model_profile: InitVar[ModelProfileSpec | None] = None
    model_settings: InitVar[ModelSettings | None] = None

    def __post_init__(
        self,
        model_profile: ModelProfileSpec | None,
        model_settings: ModelSettings | None,
    ) -> None:
        Model.__init__(
            self,
            settings=model_settings,
            profile=model_profile or self.daemon_provider.model_profile(self.model),
        )

    @property
    @override
    def provider(self) -> DifyPluginDaemonProvider:
        return self.daemon_provider

    @property
    @override
    def model_name(self) -> str:
        return self.model

    @property
    @override
    def system(self) -> str:
        return self.daemon_provider.name

    @override
    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        prepared_settings, prepared_params = self.prepare_request(
            model_settings, model_request_parameters
        )
        request_input = self._build_request_input(
            messages, prepared_settings, prepared_params
        )

        response = DifyStreamedResponse(
            model_request_parameters=prepared_params,
            chunks=self.daemon_provider.client.iter_llm_result_chunks(
                model=self.model_name,
                credentials=request_input.credentials,
                prompt_messages=request_input.prompt_messages,
                model_parameters=request_input.model_parameters,
                tools=request_input.tools,
                stop=request_input.stop_sequences,
                stream=False,
            ),
            response_model_name=self.model_name,
            provider_name_value=self.system,
        )
        async for _event in response:
            pass
        return response.get()

    @asynccontextmanager
    @override
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        run_context: object | None = None,
    ) -> AsyncGenerator[StreamedResponse, None]:
        del run_context
        prepared_settings, prepared_params = self.prepare_request(
            model_settings, model_request_parameters
        )
        request_input = self._build_request_input(
            messages, prepared_settings, prepared_params
        )

        yield DifyStreamedResponse(
            model_request_parameters=prepared_params,
            chunks=self.daemon_provider.client.iter_llm_result_chunks(
                model=self.model_name,
                credentials=request_input.credentials,
                prompt_messages=request_input.prompt_messages,
                model_parameters=request_input.model_parameters,
                tools=request_input.tools,
                stop=request_input.stop_sequences,
                stream=True,
            ),
            response_model_name=self.model_name,
            provider_name_value=self.system,
        )

    def _build_request_input(
        self,
        messages: Sequence[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> _DifyRequestInput:
        return _DifyRequestInput(
            credentials=dict(self.credentials),
            prompt_messages=_map_messages_to_prompt_messages(
                messages, model_request_parameters
            ),
            model_parameters=_map_model_settings_to_parameters(model_settings),
            tools=_map_tool_definitions_to_prompt_tools(model_request_parameters),
            stop_sequences=_get_stop_sequences(model_settings),
        )


@dataclass
class DifyStreamedResponse(StreamedResponse):
    chunks: AsyncIterator[LLMResultChunk]
    response_model_name: str
    provider_name_value: str
    _timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    _embedded_thinking_parser: "_EmbeddedThinkingParser" = field(
        default_factory=lambda: _EmbeddedThinkingParser()
    )

    @override
    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self.chunks:
            if chunk.delta.usage is not None:
                self._usage: RequestUsage = _map_usage(chunk.delta.usage)
            if chunk.delta.finish_reason is not None:
                self.finish_reason: FinishReason | None = _normalize_finish_reason(
                    chunk.delta.finish_reason
                )

            for event in _chunk_to_stream_events(
                self._parts_manager,
                chunk,
                self.provider_name_value,
                self._embedded_thinking_parser,
            ):
                yield event

        for event in self._embedded_thinking_parser.flush(
            self._parts_manager, self.provider_name_value
        ):
            yield event

    @property
    @override
    def model_name(self) -> str:
        return self.response_model_name

    @property
    @override
    def provider_name(self) -> str:
        return self.provider_name_value

    @property
    @override
    def provider_url(self) -> None:
        return None

    @property
    @override
    def timestamp(self) -> datetime:
        return self._timestamp


def _map_messages_to_prompt_messages(
    messages: Sequence[ModelMessage],
    model_request_parameters: ModelRequestParameters,
) -> list[PromptMessage]:
    prompt_messages: list[PromptMessage] = []

    for message in messages:
        if isinstance(message, ModelRequest):
            prompt_messages.extend(_map_model_request_to_prompt_messages(message))
        elif isinstance(message, ModelResponse):
            assistant_message = _map_model_response_to_prompt_message(message)
            if assistant_message is not None:
                prompt_messages.append(assistant_message)
        else:
            assert_never(message)

    instruction_messages = [
        SystemPromptMessage(content=part.content)
        for part in (
            Model._get_instruction_parts(messages, model_request_parameters) or []
        )
        if part.content.strip()
    ]
    if instruction_messages:
        insert_at = next(
            (
                index
                for index, message in enumerate(prompt_messages)
                if not isinstance(message, SystemPromptMessage)
            ),
            len(prompt_messages),
        )
        prompt_messages[insert_at:insert_at] = instruction_messages

    return prompt_messages


def _map_model_request_to_prompt_messages(message: ModelRequest) -> list[PromptMessage]:
    prompt_messages: list[PromptMessage] = []

    for part in message.parts:
        if isinstance(part, SystemPromptPart):
            prompt_messages.append(SystemPromptMessage(content=part.content))
        elif isinstance(part, UserPromptPart):
            prompt_messages.append(
                UserPromptMessage(content=_map_user_prompt_content(part.content))
            )
        elif isinstance(part, ToolReturnPart):
            prompt_messages.append(_map_tool_return_part_to_prompt_message(part))
        elif isinstance(part, RetryPromptPart):
            if part.tool_name is None:
                prompt_messages.append(UserPromptMessage(content=part.model_response()))
            else:
                prompt_messages.append(
                    ToolPromptMessage(
                        content=part.model_response(),
                        tool_call_id=part.tool_call_id,
                        name=part.tool_name,
                    )
                )
        else:
            assert_never(part)

    return prompt_messages


def _map_tool_return_part_to_prompt_message(part: ToolReturnPart) -> ToolPromptMessage:
    items = part.content_items(mode="str")
    if len(items) == 1 and isinstance(items[0], str):
        content: str | list[PromptMessageContentUnionTypes] | None = items[0]
    else:
        content_items: list[PromptMessageContentUnionTypes] = []
        for item in items:
            if isinstance(item, str):
                content_items.append(TextPromptMessageContent(data=item))
            elif isinstance(item, CachePoint):
                continue
            elif _is_multi_modal_content(item):
                content_items.append(_map_multi_modal_user_content(item))
            else:
                raise UnexpectedModelBehavior(
                    f"Unsupported daemon tool message content: {type(item).__name__}"
                )
        content = content_items or None

    return ToolPromptMessage(
        content=content, tool_call_id=part.tool_call_id, name=part.tool_name
    )


def _map_model_response_to_prompt_message(
    message: ModelResponse,
) -> AssistantPromptMessage | None:
    content_parts: list[PromptMessageContentUnionTypes] = []
    tool_calls: list[AssistantPromptMessage.ToolCall] = []

    for part in message.parts:
        if isinstance(part, TextPart):
            if part.content:
                content_parts.append(TextPromptMessageContent(data=part.content))
        elif isinstance(part, ThinkingPart):
            if part.content:
                content_parts.append(
                    TextPromptMessageContent(
                        data=f"{_THINK_START}{part.content}{_THINK_END}"
                    )
                )
        elif isinstance(part, FilePart):
            content_parts.append(_map_binary_content_to_prompt_content(part.content))
        elif isinstance(part, ToolCallPart):
            tool_calls.append(
                AssistantPromptMessage.ToolCall(
                    id=part.tool_call_id or f"tool-call-{part.tool_name}",
                    type="function",
                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                        name=part.tool_name,
                        arguments=part.args_as_json_str(),
                    ),
                )
            )
        elif isinstance(
            part, BuiltinToolCallPart | BuiltinToolReturnPart | CompactionPart
        ):
            raise UnexpectedModelBehavior(
                f"Unsupported response part for daemon adapter: {type(part).__name__}"
            )
        else:
            assert_never(part)

    content = _normalize_prompt_content(content_parts)
    if content is None and not tool_calls:
        return None

    return AssistantPromptMessage(content=content, tool_calls=tool_calls)


def _map_user_prompt_content(
    content: str | Sequence[UserContent],
) -> str | list[PromptMessageContentUnionTypes] | None:
    if isinstance(content, str):
        return content

    prompt_content: list[PromptMessageContentUnionTypes] = []
    for item in content:
        if isinstance(item, CachePoint):
            continue
        if isinstance(item, str):
            prompt_content.append(TextPromptMessageContent(data=item))
        elif isinstance(item, TextContent):
            prompt_content.append(TextPromptMessageContent(data=item.content))
        elif _is_multi_modal_content(item):
            prompt_content.append(_map_multi_modal_user_content(item))
        else:
            raise UnexpectedModelBehavior(f"Unsupported user prompt content: {type(item).__name__}")
    return _normalize_prompt_content(prompt_content)


def _is_multi_modal_content(item: object) -> bool:
    return isinstance(
        item,
        ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent | UploadedFile,
    )


def _map_multi_modal_user_content(
    item: MultiModalContent,
) -> PromptMessageContentUnionTypes:
    if isinstance(item, ImageUrl):
        detail = (
            ImagePromptMessageContent.DETAIL.HIGH
            if _get_detail(item) == _DETAIL_HIGH
            else ImagePromptMessageContent.DETAIL.LOW
        )
        return ImagePromptMessageContent(
            url=item.url,
            mime_type=item.media_type,
            format=item.format,
            filename=f"{item.identifier}.{item.format}",
            detail=detail,
        )
    if isinstance(item, AudioUrl):
        return AudioPromptMessageContent(
            url=item.url,
            mime_type=item.media_type,
            format=item.format,
            filename=f"{item.identifier}.{item.format}",
        )
    if isinstance(item, VideoUrl):
        return VideoPromptMessageContent(
            url=item.url,
            mime_type=item.media_type,
            format=item.format,
            filename=f"{item.identifier}.{item.format}",
        )
    if isinstance(item, DocumentUrl):
        return DocumentPromptMessageContent(
            url=item.url,
            mime_type=item.media_type,
            format=item.format,
            filename=f"{item.identifier}.{item.format}",
        )
    if isinstance(item, BinaryContent):
        return _map_binary_content_to_prompt_content(item)
    if isinstance(item, UploadedFile):
        raise UnexpectedModelBehavior(
            "UploadedFile content is not supported by the daemon adapter"
        )
    assert_never(item)


def _map_binary_content_to_prompt_content(
    item: BinaryContent,
) -> PromptMessageContentUnionTypes:
    filename = f"{item.identifier}.{item.format}"
    if item.is_image:
        detail = (
            ImagePromptMessageContent.DETAIL.HIGH
            if _get_detail(item) == _DETAIL_HIGH
            else ImagePromptMessageContent.DETAIL.LOW
        )
        return ImagePromptMessageContent(
            base64_data=item.base64,
            mime_type=item.media_type,
            format=item.format,
            filename=filename,
            detail=detail,
        )
    if item.is_audio:
        return AudioPromptMessageContent(
            base64_data=item.base64,
            mime_type=item.media_type,
            format=item.format,
            filename=filename,
        )
    if item.is_video:
        return VideoPromptMessageContent(
            base64_data=item.base64,
            mime_type=item.media_type,
            format=item.format,
            filename=filename,
        )
    if item.is_document:
        return DocumentPromptMessageContent(
            base64_data=item.base64,
            mime_type=item.media_type,
            format=item.format,
            filename=filename,
        )
    raise UnexpectedModelBehavior(
        f"Unsupported binary media type for daemon adapter: {item.media_type}"
    )


def _normalize_prompt_content(
    content: list[PromptMessageContentUnionTypes],
) -> str | list[PromptMessageContentUnionTypes] | None:
    if not content:
        return None
    if len(content) == 1 and isinstance(content[0], TextPromptMessageContent):
        return content[0].data
    return content


def _map_tool_definitions_to_prompt_tools(
    model_request_parameters: ModelRequestParameters,
) -> list[PromptMessageTool] | None:
    tool_definitions = [
        *model_request_parameters.function_tools,
        *model_request_parameters.output_tools,
    ]
    if not tool_definitions:
        return None

    return [
        PromptMessageTool(
            name=tool_definition.name,
            description=tool_definition.description or "",
            parameters=cast(dict[str, object], tool_definition.parameters_json_schema),
        )
        for tool_definition in tool_definitions
    ]


def _map_model_settings_to_parameters(model_settings: ModelSettings | None) -> dict[str, object]:
    if not model_settings:
        return {}

    parameters: dict[str, object] = {
        key: value
        for key, value in model_settings.items()
        if value is not None and key not in {"extra_body", "stop_sequences"}
    }

    extra_body = model_settings.get("extra_body")
    if isinstance(extra_body, Mapping):
        parameters.update(cast(Mapping[str, object], extra_body))

    return parameters


def _get_stop_sequences(model_settings: ModelSettings | None) -> list[str] | None:
    if not model_settings:
        return None
    return list(model_settings.get("stop_sequences") or []) or None


def _map_usage(usage: LLMUsage) -> RequestUsage:
    return RequestUsage(
        input_tokens=usage.prompt_tokens, output_tokens=usage.completion_tokens
    )


def _normalize_finish_reason(finish_reason: str) -> FinishReason:
    lowered = finish_reason.lower()
    if lowered in {"stop", "length", "content_filter", "error", "tool_call"}:
        return cast(FinishReason, lowered)
    if lowered in {"tool_calls", "function_call", "function_calls"}:
        return "tool_call"
    return "error"


def _chunk_to_stream_events(
    parts_manager: ModelResponsePartsManager,
    chunk: LLMResultChunk,
    provider_name: str,
    embedded_thinking_parser: "_EmbeddedThinkingParser",
) -> list[ModelResponseStreamEvent]:
    events: list[ModelResponseStreamEvent] = []
    message = chunk.delta.message

    if isinstance(message.content, str):
        if message.content:
            events.extend(
                embedded_thinking_parser.parse(
                    parts_manager, message.content, provider_name
                )
            )
    elif isinstance(message.content, list):
        for part in _map_assistant_content_to_response_parts(message.content):
            if isinstance(part, TextPart):
                events.extend(
                    parts_manager.handle_text_delta(
                        vendor_part_id=None,
                        content=part.content,
                        provider_name=provider_name,
                    )
                )
            else:
                events.append(parts_manager.handle_part(vendor_part_id=None, part=part))

    for index, tool_call in enumerate(message.tool_calls):
        vendor_id = tool_call.id or f"chunk-{chunk.delta.index}-tool-{index}"
        events.append(
            parts_manager.handle_tool_call_part(
                vendor_part_id=vendor_id,
                tool_name=tool_call.function.name,
                args=tool_call.function.arguments,
                tool_call_id=tool_call.id,
                provider_name=provider_name,
            )
        )

    return events


def _map_assistant_content_to_response_parts(
    content: Sequence[PromptMessageContentUnionTypes],
) -> list[ModelResponsePart]:
    response_parts: list[ModelResponsePart] = []

    for item in content:
        if isinstance(item, TextPromptMessageContent):
            if item.data:
                response_parts.extend(_parse_assistant_text_parts(item.data))
        elif isinstance(
            item,
            ImagePromptMessageContent
            | AudioPromptMessageContent
            | VideoPromptMessageContent
            | DocumentPromptMessageContent,
        ):
            if item.url:
                raise UnexpectedModelBehavior(
                    "URL-based assistant multimodal output is not supported by the daemon adapter"
                )
            if not item.base64_data:
                continue
            response_parts.append(
                FilePart(
                    content=BinaryContent(
                        data=base64.b64decode(item.base64_data),
                        media_type=item.mime_type,
                    ),
                    provider_name=None,
                )
            )
        else:
            assert_never(item)

    return response_parts


def _get_detail(item: ImageUrl | BinaryContent) -> str | None:
    metadata = item.vendor_metadata or {}
    detail = metadata.get("detail")
    return detail if isinstance(detail, str) else None


def _parse_assistant_text_parts(content: str) -> list[ModelResponsePart]:
    response_parts: list[ModelResponsePart] = []
    cursor = 0

    for match in _THINK_TAG_PATTERN.finditer(content):
        if match.start() > cursor:
            response_parts.append(
                TextPart(content=content[cursor : match.start()], provider_name=None)
            )

        thinking_content = match.group(1).strip("\n")
        if thinking_content:
            response_parts.append(
                ThinkingPart(content=thinking_content, provider_name=None)
            )
        cursor = match.end()

    if cursor < len(content):
        response_parts.append(TextPart(content=content[cursor:], provider_name=None))

    if response_parts:
        return response_parts
    return [TextPart(content=content, provider_name=None)]


@dataclass(slots=True)
class _EmbeddedThinkingParser:
    _pending: str = ""
    _inside_thinking: bool = False

    def parse(
        self,
        parts_manager: ModelResponsePartsManager,
        content: str,
        provider_name: str,
    ) -> list[ModelResponseStreamEvent]:
        events: list[ModelResponseStreamEvent] = []
        buffer = self._pending + content
        self._pending = ""

        while buffer:
            if self._inside_thinking:
                end_index = buffer.find(_THINK_CLOSE_TAG)
                if end_index >= 0:
                    if end_index > 0:
                        events.extend(
                            parts_manager.handle_thinking_delta(
                                vendor_part_id=None,
                                content=buffer[:end_index],
                                provider_name=provider_name,
                            )
                        )
                    buffer = buffer[end_index + len(_THINK_CLOSE_TAG) :]
                    self._inside_thinking = False
                    continue

                safe_content, self._pending = _split_incomplete_tag_suffix(
                    buffer, _THINK_CLOSE_TAG
                )
                if safe_content:
                    events.extend(
                        parts_manager.handle_thinking_delta(
                            vendor_part_id=None,
                            content=safe_content,
                            provider_name=provider_name,
                        )
                    )
                break

            start_index = buffer.find(_THINK_OPEN_TAG)
            if start_index >= 0:
                if start_index > 0:
                    events.extend(
                        parts_manager.handle_text_delta(
                            vendor_part_id=None,
                            content=buffer[:start_index],
                            provider_name=provider_name,
                        )
                    )
                buffer = buffer[start_index + len(_THINK_OPEN_TAG) :]
                self._inside_thinking = True
                continue

            safe_content, self._pending = _split_incomplete_tag_suffix(
                buffer, _THINK_OPEN_TAG
            )
            if safe_content:
                events.extend(
                    parts_manager.handle_text_delta(
                        vendor_part_id=None,
                        content=safe_content,
                        provider_name=provider_name,
                    )
                )
            break

        return events

    def flush(
        self,
        parts_manager: ModelResponsePartsManager,
        provider_name: str,
    ) -> list[ModelResponseStreamEvent]:
        if not self._pending:
            return []

        pending = self._pending
        self._pending = ""
        if self._inside_thinking:
            return list(
                parts_manager.handle_thinking_delta(
                    vendor_part_id=None,
                    content=pending,
                    provider_name=provider_name,
                )
            )
        return list(
            parts_manager.handle_text_delta(
                vendor_part_id=None,
                content=pending,
                provider_name=provider_name,
            )
        )


def _split_incomplete_tag_suffix(content: str, tag: str) -> tuple[str, str]:
    for suffix_length in range(len(tag) - 1, 0, -1):
        if content.endswith(tag[:suffix_length]):
            return content[:-suffix_length], content[-suffix_length:]
    return content, ""
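A minimal sketch (not part of the commit) of the buffering contract `_split_incomplete_tag_suffix` gives `_EmbeddedThinkingParser`: when a streamed chunk ends in a prefix of `<think>` or `</think>`, that prefix is held back in `_pending` so a tag split across two chunks is still recognized once the next chunk arrives.

    # Illustrative only; mirrors _split_incomplete_tag_suffix from model.py above.
    def split_incomplete_tag_suffix(content: str, tag: str) -> tuple[str, str]:
        for suffix_length in range(len(tag) - 1, 0, -1):
            if content.endswith(tag[:suffix_length]):
                return content[:-suffix_length], content[-suffix_length:]
        return content, ""

    # "answer <thi" ends with a prefix of "<think>", so "<thi" is buffered and only
    # "answer " is emitted as a text delta; the tag completes on the next chunk.
    assert split_incomplete_tag_suffix("answer <thi", "<think>") == ("answer ", "<thi")
    # A chunk with no dangling tag prefix passes through unchanged.
    assert split_incomplete_tag_suffix("answer done", "<think>") == ("answer done", "")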
dify-agent/src/dify_agent/adapters/llm/provider.py (new file, 252 lines)
@@ -0,0 +1,252 @@
"""Dify plugin-daemon provider for Pydantic AI LLM adapters."""

from __future__ import annotations

import json
from collections.abc import AsyncIterator, Callable, Mapping
from dataclasses import dataclass, field
from typing import NoReturn

import httpx
from graphon.model_runtime.entities.llm_entities import LLMResultChunk
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from pydantic import BaseModel
from typing_extensions import override

from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UnexpectedModelBehavior, UserError
from pydantic_ai.providers import Provider

_DEFAULT_DAEMON_TIMEOUT: float | httpx.Timeout | None = 600.0


class PluginDaemonBasicResponse(BaseModel):
    code: int
    message: str
    data: object | None = None


@dataclass(slots=True)
class DifyPluginDaemonLLMClient:
    plugin_daemon_url: str
    plugin_daemon_api_key: str
    tenant_id: str
    plugin_id: str
    provider: str
    user_id: str | None
    http_client: httpx.AsyncClient = field(repr=False)

    def __post_init__(self) -> None:
        self.plugin_daemon_url = self.plugin_daemon_url.rstrip("/")

    async def iter_llm_result_chunks(
        self,
        *,
        model: str,
        credentials: dict[str, object],
        prompt_messages: list[PromptMessage],
        model_parameters: dict[str, object],
        tools: list[PromptMessageTool] | None,
        stop: list[str] | None,
        stream: bool,
    ) -> AsyncIterator[LLMResultChunk]:
        async for item in self._iter_stream_response(
            model_name=model,
            path=f"plugin/{self.tenant_id}/dispatch/llm/invoke",
            request_data={
                "provider": self.provider,
                "model_type": "llm",
                "model": model,
                "credentials": credentials,
                "prompt_messages": prompt_messages,
                "model_parameters": model_parameters,
                "tools": tools,
                "stop": stop,
                "stream": stream,
            },
            response_model=LLMResultChunk,
        ):
            yield item

    async def _iter_stream_response[T: BaseModel](
        self,
        *,
        model_name: str,
        path: str,
        request_data: Mapping[str, object],
        response_model: type[T],
    ) -> AsyncIterator[T]:
        payload: dict[str, object] = {"data": _to_jsonable(request_data)}
        if self.user_id is not None:
            payload["user_id"] = self.user_id

        headers = {
            "X-Api-Key": self.plugin_daemon_api_key,
            "X-Plugin-ID": self.plugin_id,
            "Content-Type": "application/json",
        }
        url = f"{self.plugin_daemon_url}/{path}"

        async with self.http_client.stream("POST", url, headers=headers, json=payload) as response:
            if response.is_error:
                body = (await response.aread()).decode("utf-8", errors="replace")
                error = _decode_plugin_daemon_error_payload(body)
                if error is not None:
                    _raise_plugin_daemon_error(
                        model_name=model_name,
                        error_type=error["error_type"],
                        message=error["message"],
                        status_code=response.status_code,
                        body=error,
                    )
                raise ModelHTTPError(response.status_code, model_name, body or None)

            async for raw_line in response.aiter_lines():
                line = raw_line.strip()
                if not line:
                    continue
                if line.startswith("data:"):
                    line = line[5:].strip()

                wrapped = PluginDaemonBasicResponse.model_validate_json(line)
                if wrapped.code != 0:
                    error = _decode_plugin_daemon_error_payload(wrapped.message)
                    if error is not None:
                        _raise_plugin_daemon_error(
                            model_name=model_name,
                            error_type=error["error_type"],
                            message=error["message"],
                            body=error,
                        )
                    raise ModelAPIError(
                        model_name,
                        f"Plugin daemon returned error code {wrapped.code}: {wrapped.message}",
                    )
                if wrapped.data is None:
                    raise UnexpectedModelBehavior("Plugin daemon returned an empty stream item")
                yield response_model.model_validate(wrapped.data)


@dataclass(slots=True, kw_only=True)
class DifyPluginDaemonProvider(Provider[DifyPluginDaemonLLMClient]):
    """Pydantic AI provider for Dify plugin-daemon dispatch requests."""

    tenant_id: str
    plugin_id: str
    plugin_provider: str
    plugin_daemon_url: str
    plugin_daemon_api_key: str = field(repr=False)
    user_id: str | None = None
    timeout: float | httpx.Timeout | None = _DEFAULT_DAEMON_TIMEOUT
    _client: DifyPluginDaemonLLMClient = field(init=False, repr=False)
    _own_http_client: httpx.AsyncClient | None = field(init=False, default=None, repr=False)
    _http_client_factory: Callable[[], httpx.AsyncClient] | None = field(init=False, default=None, repr=False)

    def __post_init__(self) -> None:
        self.plugin_daemon_url = self.plugin_daemon_url.rstrip("/")
        self._http_client_factory = self._make_http_client
        http_client = self._make_http_client()
        self._own_http_client = http_client
        self._client = DifyPluginDaemonLLMClient(
            plugin_daemon_url=self.plugin_daemon_url,
            plugin_daemon_api_key=self.plugin_daemon_api_key,
            tenant_id=self.tenant_id,
            plugin_id=self.plugin_id,
            provider=self.plugin_provider,
            user_id=self.user_id,
            http_client=http_client,
        )

    def _make_http_client(self) -> httpx.AsyncClient:
        return httpx.AsyncClient(timeout=self.timeout, trust_env=False)

    @override
    def _set_http_client(self, http_client: httpx.AsyncClient) -> None:
        self._client.http_client = http_client

    @property
    @override
    def name(self) -> str:
        return f"DifyPlugin/{self.plugin_provider}"

    @property
    @override
    def base_url(self) -> str:
        return self.plugin_daemon_url

    @property
    @override
    def client(self) -> DifyPluginDaemonLLMClient:
        return self._client


def _to_jsonable(value: object) -> object:
    if isinstance(value, BaseModel):
        return value.model_dump(mode="json")
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, list | tuple):
        return [_to_jsonable(item) for item in value]
    return value


def _decode_plugin_daemon_error_payload(raw_message: str) -> dict[str, str] | None:
    try:
        parsed = json.loads(raw_message)
    except json.JSONDecodeError:
        return None

    if not isinstance(parsed, dict):
        return None

    error_type = parsed.get("error_type")
    message = parsed.get("message")
    if not isinstance(error_type, str) or not isinstance(message, str):
        return None
    return {"error_type": error_type, "message": message}


def _raise_plugin_daemon_error(
    *,
    model_name: str,
    error_type: str,
    message: str,
    status_code: int | None = None,
    body: object | None = None,
) -> NoReturn:
    http_error_body = body or {"error_type": error_type, "message": message}

    match error_type:
        case "PluginInvokeError":
            nested_error = _decode_plugin_daemon_error_payload(message)
            if nested_error is not None:
                _raise_plugin_daemon_error(
                    model_name=model_name,
                    error_type=nested_error["error_type"],
                    message=nested_error["message"],
                    status_code=status_code,
                    body=nested_error,
                )
            raise ModelAPIError(model_name, message)
        case "PluginDaemonUnauthorizedError" | "InvokeAuthorizationError":
            raise ModelHTTPError(status_code or 401, model_name, http_error_body)
        case "PluginPermissionDeniedError":
            raise ModelHTTPError(status_code or 403, model_name, http_error_body)
        case (
            "PluginDaemonBadRequestError"
            | "InvokeBadRequestError"
            | "CredentialsValidateFailedError"
            | "PluginUniqueIdentifierError"
        ):
            raise ModelHTTPError(status_code or 400, model_name, http_error_body)
        case "EndpointSetupFailedError" | "TriggerProviderCredentialValidationError":
            raise UserError(message)
        case "PluginDaemonNotFoundError" | "PluginNotFoundError":
            raise ModelHTTPError(status_code or 404, model_name, http_error_body)
        case "InvokeRateLimitError":
            raise ModelHTTPError(status_code or 429, model_name, http_error_body)
        case "PluginDaemonInternalServerError" | "PluginDaemonInnerError":
            raise ModelHTTPError(status_code or 500, model_name, http_error_body)
        case "InvokeConnectionError" | "InvokeServerUnavailableError":
            raise ModelHTTPError(status_code or 503, model_name, http_error_body)
        case _:
            raise ModelAPIError(model_name, f"{error_type}: {message}")
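For orientation, a hedged sketch of the exchange `DifyPluginDaemonLLMClient` drives (field values are placeholders): the POST body wraps the dispatch arguments under `data` plus an optional `user_id`, and each `text/event-stream` line carries one `LLMResultChunk` inside a `{"code", "message", "data"}` envelope, where a non-zero `code` signals a serialized daemon error payload.

    # Illustrative shapes only, following _iter_stream_response above; values are placeholders.
    request_body = {
        "user_id": "user-123",  # omitted when no user id is configured
        "data": {
            "provider": "openai",          # plugin_provider
            "model_type": "llm",
            "model": "demo-model",
            "credentials": {"api_key": "<secret>"},  # keys depend on the provider plugin
            "prompt_messages": [{"role": "user", "content": "hello"}],
            "model_parameters": {"temperature": 0.2},
            "tools": None,
            "stop": None,
            "stream": True,
        },
    }

    # One stream line on success; a non-zero code instead carries an error payload in "message".
    stream_line = 'data: {"code": 0, "message": "", "data": {"model": "demo-model", "delta": {"index": 0, "message": {"role": "assistant", "content": "hi"}}}}'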
dify-agent/tests/unit/dify_agent/adapters/llm/_test_support.py (new file, 91 lines; path inferred from the `._test_support` import in the test module below)
@@ -0,0 +1,91 @@
import json
from decimal import Decimal

import httpx
from graphon.model_runtime.entities.llm_entities import (
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
)
from graphon.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
)
from pydantic import BaseModel


def make_usage(prompt_tokens: int = 3, completion_tokens: int = 5) -> LLMUsage:
    return LLMUsage(
        prompt_tokens=prompt_tokens,
        prompt_unit_price=Decimal("0"),
        prompt_price_unit=Decimal("0"),
        prompt_price=Decimal("0"),
        completion_tokens=completion_tokens,
        completion_unit_price=Decimal("0"),
        completion_price_unit=Decimal("0"),
        completion_price=Decimal("0"),
        total_tokens=prompt_tokens + completion_tokens,
        total_price=Decimal("0"),
        currency="USD",
        latency=0.1,
    )


def single_text_chunk(
    text: str,
    *,
    prompt_tokens: int = 3,
    completion_tokens: int = 5,
) -> list[LLMResultChunk]:
    return [
        LLMResultChunk(
            model="demo-model",
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(content=text, tool_calls=[]),
                usage=make_usage(
                    prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
                ),
            ),
        )
    ]


def wrap_plugin_daemon_stream_item(item: object) -> str:
    if isinstance(item, BaseModel):
        data = item.model_dump(mode="json")
    else:
        data = item
    return f"data: {json.dumps({'code': 0, 'message': '', 'data': data})}\n\n"


def build_stream_response(*items: object, status_code: int = 200) -> httpx.Response:
    body = "".join(wrap_plugin_daemon_stream_item(item) for item in items)
    return httpx.Response(
        status_code=status_code,
        headers={"content-type": "text/event-stream"},
        content=body.encode("utf-8"),
    )


def build_error_response(
    error_type: str, message: str, *, status_code: int
) -> httpx.Response:
    return httpx.Response(
        status_code=status_code,
        headers={"content-type": "application/json"},
        content=json.dumps({"error_type": error_type, "message": message}).encode(
            "utf-8"
        ),
    )


def build_stream_error(
    error_type: str, message: str, *, code: int = -500
) -> httpx.Response:
    return httpx.Response(
        status_code=200,
        headers={"content-type": "text/event-stream"},
        content=(
            f"data: {json.dumps({'code': code, 'message': json.dumps({'error_type': error_type, 'message': message}), 'data': None})}\n\n"
        ).encode("utf-8"),
    )
dify-agent/tests/unit/dify_agent/adapters/llm/test_model.py (new file, 374 lines)
@@ -0,0 +1,374 @@
import json
import unittest
from contextlib import asynccontextmanager
from typing import cast
from unittest.mock import patch

import httpx
from pydantic_ai.exceptions import ModelHTTPError, UserError
from pydantic_ai.messages import (
    InstructionPart,
    ModelRequest,
    ModelResponse,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ThinkingPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.tools import ToolDefinition

from dify_agent.adapters.llm import DifyLLMAdapterModel, DifyPluginDaemonProvider

from ._test_support import (
    AssistantPromptMessage,
    LLMResultChunk,
    LLMResultChunkDelta,
    build_error_response,
    build_stream_error,
    build_stream_response,
    make_usage,
    single_text_chunk,
)


class DifyLLMAdapterModelTests(unittest.IsolatedAsyncioTestCase):
    def make_provider(
        self,
        *,
        user_id: str | None = None,
    ) -> DifyPluginDaemonProvider:
        return DifyPluginDaemonProvider(
            tenant_id="tenant-1",
            plugin_id="langgenius/openai",
            plugin_provider="openai",
            plugin_daemon_url="http://plugin-daemon",
            plugin_daemon_api_key="daemon-secret",
            user_id=user_id,
        )

    @asynccontextmanager
    async def mock_daemon_stream(self, handler: httpx.MockTransport):
        @asynccontextmanager
        async def mock_stream(
            client: httpx.AsyncClient,
            method: str,
            url: str,
            **kwargs: object,
        ):
            request = client.build_request(
                method,
                url,
                headers=cast(dict[str, str] | None, kwargs.get("headers")),
                json=kwargs.get("json"),
            )
            yield handler.handle_request(request)

        with patch.object(httpx.AsyncClient, "stream", new=mock_stream):
            yield

    async def test_request_uses_plugin_daemon_dispatch_contract(self) -> None:
        messages = [
            ModelRequest(
                parts=[
                    SystemPromptPart("request system"),
                    UserPromptPart("hello"),
                    ToolReturnPart(
                        tool_name="lookup",
                        content={"city": "Paris"},
                        tool_call_id="tool-1",
                    ),
                    RetryPromptPart(
                        content="try again", tool_name="lookup", tool_call_id="tool-1"
                    ),
                ]
            ),
            ModelResponse(
                parts=[
                    TextPart(content="previous answer"),
                    ToolCallPart(
                        tool_name="lookup",
                        args='{"city":"Paris"}',
                        tool_call_id="tool-1",
                    ),
                ]
            ),
        ]
        request_parameters = ModelRequestParameters(
            function_tools=[
                ToolDefinition(
                    name="weather",
                    description="Look up the weather",
                    parameters_json_schema={
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                    },
                )
            ],
            instruction_parts=[InstructionPart(content="be concise")],
        )

        def handler(request: httpx.Request) -> httpx.Response:
            self.assertEqual(request.method, "POST")
            self.assertEqual(request.url.path, "/plugin/tenant-1/dispatch/llm/invoke")
            self.assertEqual(request.headers["X-Api-Key"], "daemon-secret")
            self.assertEqual(request.headers["X-Plugin-ID"], "langgenius/openai")

            payload = json.loads(request.content.decode("utf-8"))
            self.assertEqual(payload["user_id"], "user-123")
            data = payload["data"]
            self.assertEqual(data["provider"], "openai")
            self.assertEqual(data["model_type"], "llm")
            self.assertEqual(data["model"], "demo-model")
            self.assertEqual(data["credentials"], {"api_key": "secret"})
            self.assertEqual(
                data["model_parameters"],
                {"temperature": 0.2, "max_tokens": 128, "logit_bias": {"1": 2}},
            )
            self.assertEqual(data["stop"], ["END"])
            self.assertFalse(data["stream"])
            self.assertEqual(data["tools"][0]["name"], "weather")
            self.assertEqual(data["prompt_messages"][0]["role"], "system")
            self.assertEqual(data["prompt_messages"][0]["content"], "request system")
            self.assertEqual(data["prompt_messages"][1]["content"], "be concise")
            self.assertEqual(data["prompt_messages"][2]["content"], "hello")
            self.assertEqual(data["prompt_messages"][3]["role"], "tool")
            self.assertEqual(data["prompt_messages"][4]["role"], "tool")
            self.assertEqual(data["prompt_messages"][5]["role"], "assistant")
            return build_stream_response(
                LLMResultChunk(
                    model="demo-model",
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content="adapter response", tool_calls=[]
                        ),
                        usage=make_usage(prompt_tokens=11, completion_tokens=7),
                    ),
                )
            )

        async with self.mock_daemon_stream(httpx.MockTransport(handler)):
            adapter = DifyLLMAdapterModel(
                "demo-model",
                self.make_provider(user_id="user-123"),
                credentials={"api_key": "secret"},
                model_settings={"temperature": 0.2, "stop_sequences": ["DEFAULT_STOP"]},
            )

            response = await adapter.request(
                messages,
                model_settings={"max_tokens": 128, "logit_bias": {"1": 2}, "stop_sequences": ["END"]},
                model_request_parameters=request_parameters,
            )

        self.assertEqual(response.model_name, "demo-model")
        self.assertEqual(response.provider_name, "DifyPlugin/openai")
        self.assertEqual(response.usage.input_tokens, 11)
        self.assertEqual(response.usage.output_tokens, 7)
        self.assertEqual(response.parts[0].part_kind, "text")
        self.assertEqual(cast(TextPart, response.parts[0]).content, "adapter response")

    async def test_request_returns_a_response(self) -> None:
        def handler(_request: httpx.Request) -> httpx.Response:
            return build_stream_response(
                *single_text_chunk(
                    "adapter response", prompt_tokens=11, completion_tokens=7
                )
            )

        async with self.mock_daemon_stream(httpx.MockTransport(handler)):
            adapter = DifyLLMAdapterModel(
                "demo-model",
                self.make_provider(),
                credentials={"api_key": "secret"},
            )

            response = await adapter.request(
                [ModelRequest(parts=[UserPromptPart("hello")])],
                model_settings=None,
                model_request_parameters=ModelRequestParameters(),
            )

        self.assertEqual(response.model_name, "demo-model")
        self.assertEqual(response.parts[0].part_kind, "text")
        self.assertEqual(cast(TextPart, response.parts[0]).content, "adapter response")
        self.assertEqual(response.usage.input_tokens, 11)
        self.assertEqual(response.usage.output_tokens, 7)

    async def test_request_stream_yields_response_parts_and_usage(self) -> None:
        def handler(_request: httpx.Request) -> httpx.Response:
            return build_stream_response(
                LLMResultChunk(
                    model="demo-model",
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(content="hello ", tool_calls=[]),
                    ),
                ),
                LLMResultChunk(
                    model="demo-model",
                    delta=LLMResultChunkDelta(
                        index=1,
                        message=AssistantPromptMessage(
                            content="",
                            tool_calls=[
                                AssistantPromptMessage.ToolCall(
                                    id="call-1",
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name="weather",
                                        arguments='{"city":"Paris"}',
                                    ),
                                )
                            ],
                        ),
                    ),
                ),
                LLMResultChunk(
                    model="demo-model",
                    delta=LLMResultChunkDelta(
                        index=2,
                        message=AssistantPromptMessage(content="world", tool_calls=[]),
                        usage=make_usage(prompt_tokens=6, completion_tokens=4),
                        finish_reason="tool_calls",
                    ),
                ),
            )

        async with self.mock_daemon_stream(httpx.MockTransport(handler)):
            adapter = DifyLLMAdapterModel(
                "demo-model",
                self.make_provider(),
                credentials={"api_key": "secret"},
            )

            async with adapter.request_stream(
                [ModelRequest(parts=[UserPromptPart("hello")])],
                model_settings=None,
                model_request_parameters=ModelRequestParameters(),
            ) as stream:
                events = [event async for event in stream]
                response = stream.get()

        self.assertTrue(events)
        self.assertEqual(response.usage.input_tokens, 6)
        self.assertEqual(response.usage.output_tokens, 4)
        self.assertEqual(response.finish_reason, "tool_call")
        self.assertEqual(response.parts[0].part_kind, "text")
        self.assertEqual(cast(TextPart, response.parts[0]).content, "hello ")
        self.assertEqual(response.parts[1].part_kind, "tool-call")
        self.assertEqual(cast(ToolCallPart, response.parts[1]).tool_name, "weather")
        self.assertEqual(response.parts[2].part_kind, "text")
        self.assertEqual(cast(TextPart, response.parts[2]).content, "world")

    async def test_request_splits_embedded_thinking_tags_into_parts(self) -> None:
        def handler(_request: httpx.Request) -> httpx.Response:
            return build_stream_response(
                *single_text_chunk("before<think>reasoning</think>after")
            )

        async with self.mock_daemon_stream(httpx.MockTransport(handler)):
            adapter = DifyLLMAdapterModel(
                "demo-model",
                self.make_provider(),
                credentials={"api_key": "secret"},
            )

            response = await adapter.request(
                [ModelRequest(parts=[UserPromptPart("hello")])],
                model_settings=None,
                model_request_parameters=ModelRequestParameters(),
            )

        self.assertEqual(
            [part.part_kind for part in response.parts], ["text", "thinking", "text"]
        )
        self.assertEqual(cast(TextPart, response.parts[0]).content, "before")
        self.assertEqual(cast(ThinkingPart, response.parts[1]).content, "reasoning")
        self.assertEqual(cast(TextPart, response.parts[2]).content, "after")

    async def test_request_maps_stream_envelope_rate_limit_error_to_http_error(
        self,
    ) -> None:
        def handler(_request: httpx.Request) -> httpx.Response:
            return build_stream_error(
                "PluginInvokeError",
                json.dumps(
                    {"error_type": "InvokeRateLimitError", "message": "too many"}
                ),
            )

        async with self.mock_daemon_stream(httpx.MockTransport(handler)):
            adapter = DifyLLMAdapterModel(
                "demo-model",
                self.make_provider(),
                credentials={"api_key": "secret"},
            )

            with self.assertRaises(ModelHTTPError) as context:
                await adapter.request(
                    [ModelRequest(parts=[UserPromptPart("hello")])],
                    model_settings=None,
                    model_request_parameters=ModelRequestParameters(),
                )

        self.assertEqual(context.exception.status_code, 429)
        self.assertEqual(
            context.exception.body,
            {"error_type": "InvokeRateLimitError", "message": "too many"},
        )

    async def test_request_maps_http_error_payload_to_http_error(self) -> None:
        def handler(_request: httpx.Request) -> httpx.Response:
            return build_error_response(
                "PluginDaemonUnauthorizedError", "invalid api key", status_code=401
            )

        async with self.mock_daemon_stream(httpx.MockTransport(handler)):
            adapter = DifyLLMAdapterModel(
                "demo-model",
                self.make_provider(),
                credentials={"api_key": "secret"},
            )

            with self.assertRaises(ModelHTTPError) as context:
                await adapter.request(
                    [ModelRequest(parts=[UserPromptPart("hello")])],
                    model_settings=None,
                    model_request_parameters=ModelRequestParameters(),
                )

        self.assertEqual(context.exception.status_code, 401)
        self.assertEqual(
            context.exception.body,
            {
                "error_type": "PluginDaemonUnauthorizedError",
                "message": "invalid api key",
            },
        )

    async def test_request_maps_endpoint_setup_error_to_user_error(self) -> None:
        def handler(_request: httpx.Request) -> httpx.Response:
            return build_stream_error(
                "EndpointSetupFailedError", "missing endpoint config"
            )

        async with self.mock_daemon_stream(httpx.MockTransport(handler)):
            adapter = DifyLLMAdapterModel(
                "demo-model",
                self.make_provider(),
                credentials={"api_key": "secret"},
            )

            with self.assertRaises(UserError) as context:
                await adapter.request(
                    [ModelRequest(parts=[UserPromptPart("hello")])],
                    model_settings=None,
                    model_request_parameters=ModelRequestParameters(),
                )

        self.assertEqual(str(context.exception), "missing endpoint config")