dify/api/core/plugin/impl/model_runtime.py
QuantumGhost f0b34bdeb4
feat(api): LLM polling support (#37462)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
2026-06-17 23:34:33 +00:00

713 lines
23 KiB
Python

from __future__ import annotations
import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from typing import IO, Any, Literal, cast, overload, override
from pydantic import ValidationError
from pydantic.json_schema import JsonValue
from redis import RedisError
from configs import dify_config
from core.llm_generator.output_parser.structured_output import (
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
)
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from core.plugin.plugin_service import PluginService
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.llm_entities import (
LLMPollingResult,
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result
from graphon.model_runtime.protocols.runtime import ModelRuntime
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
# `TS` means tenant scope
TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__"
# TODO(-LAN-): Move native structured-output invocation into Graphon's LLM node.
# TODO(-LAN-): Remove this Dify-side adapter once Graphon owns structured output end-to-end.
class _PluginStructuredOutputModelInstance:
"""Bind plugin model identity to the shared structured-output helper.
The structured-output parser is shared with legacy ``ModelInstance`` flows
and only needs an object exposing ``invoke_llm(...)``. ``PluginModelRuntime``
intentionally exposes a lower-level API where provider, model, and
credentials are passed per call. This adapter supplies the small bound
``invoke_llm`` surface the helper needs without constructing a full
``ModelInstance`` or reintroducing model-manager dependencies into the
plugin runtime path.
"""
def __init__(
self,
*,
runtime: PluginModelRuntime,
provider: str,
model: str,
credentials: dict[str, Any],
) -> None:
self._runtime = runtime
self._provider = provider
self._model = model
self._credentials = credentials
def invoke_llm(
self,
*,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict[str, Any] | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
callbacks: object | None = None,
) -> LLMResult | Generator[LLMResultChunk, None, None]:
del callbacks
if stream:
return self._runtime.invoke_llm(
provider=self._provider,
model=self._model,
credentials=self._credentials,
model_parameters=model_parameters or {},
prompt_messages=prompt_messages,
tools=list(tools) if tools else None,
stop=stop,
stream=True,
)
return self._runtime.invoke_llm(
provider=self._provider,
model=self._model,
credentials=self._credentials,
model_parameters=model_parameters or {},
prompt_messages=prompt_messages,
tools=list(tools) if tools else None,
stop=stop,
stream=False,
)
class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope.
Provider discovery goes through ``PluginService`` so the plugin lifecycle
methods and provider reads share one tenant-scoped cache owner.
"""
tenant_id: str
user_id: str | None
client: PluginModelClient
_plugin_service: type[PluginService]
def __init__(
self,
tenant_id: str,
user_id: str | None,
client: PluginModelClient,
plugin_service: type[PluginService],
) -> None:
if client is None:
raise ValueError("client is required.")
if plugin_service is None:
raise ValueError("plugin_service is required.")
self.tenant_id = tenant_id
self.user_id = user_id
self.client = client
self._plugin_service = plugin_service
@override
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
return self._plugin_service.fetch_plugin_model_providers(tenant_id=self.tenant_id, client=self.client)
@override
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
provider_schema = self._get_provider_schema(provider)
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
raise ValueError(f"Provider {provider} does not have small icon.")
file_name = (
provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us
)
elif icon_type.lower() == "icon_small_dark":
if not provider_schema.icon_small_dark:
raise ValueError(f"Provider {provider} does not have small dark icon.")
file_name = (
provider_schema.icon_small_dark.zh_hans
if lang.lower() == "zh_hans"
else provider_schema.icon_small_dark.en_us
)
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")
image_mime_types = {
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"bmp": "image/bmp",
"tiff": "image/tiff",
"tif": "image/tiff",
"webp": "image/webp",
"svg": "image/svg+xml",
"ico": "image/vnd.microsoft.icon",
"heif": "image/heif",
"heic": "image/heic",
}
extension = file_name.split(".")[-1]
mime_type = image_mime_types.get(extension, "image/png")
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
@override
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
credentials=credentials,
)
@override
def validate_model_credentials(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_model_credentials(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
@override
def get_model_schema(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> AIModelEntity | None:
cache_key = self._get_schema_cache_key(
provider=provider,
model_type=model_type,
model=model,
credentials=credentials,
)
cached_schema_json = None
try:
cached_schema_json = redis_client.get(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to read plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
if cached_schema_json:
try:
return AIModelEntity.model_validate_json(cached_schema_json)
except ValidationError:
logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True)
try:
redis_client.delete(cache_key)
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to delete invalid plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
plugin_id, provider_name = self._split_provider(provider)
schema = self.client.get_model_schema(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
if schema:
try:
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json())
except (RedisError, RuntimeError) as exc:
logger.warning(
"Failed to write plugin model schema cache for model %s: %s",
model,
str(exc),
exc_info=True,
)
return schema
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResult: ...
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
@override
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: bool,
) -> LLMResult | Generator[LLMResultChunk, None, None]:
plugin_id, provider_name = self._split_provider(provider)
result = self.client.invoke_llm(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
model_parameters=model_parameters,
prompt_messages=list(prompt_messages),
tools=tools,
stop=list(stop) if stop else None,
stream=stream,
)
if stream:
return result
return normalize_non_stream_runtime_result(
model=model,
prompt_messages=prompt_messages,
result=result,
)
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@override
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: bool,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
model_schema = self.get_model_schema(
provider=provider,
model_type=ModelType.LLM,
model=model,
credentials=credentials,
)
if model_schema is None:
raise ValueError(f"Model schema not found for {model}")
adapter = _PluginStructuredOutputModelInstance(
runtime=self,
provider=provider,
model=model,
credentials=credentials,
)
return invoke_llm_with_structured_output_helper(
provider=provider,
model_schema=model_schema,
model_instance=cast(Any, adapter),
prompt_messages=prompt_messages,
json_schema=json_schema,
model_parameters=model_parameters,
tools=None,
stop=list(stop) if stop else None,
stream=stream,
)
@override
def get_llm_num_tokens(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: Sequence[PromptMessageTool] | None,
) -> int:
if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
return 0
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
prompt_messages=list(prompt_messages),
tools=list(tools) if tools else None,
)
def start_llm_polling(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: Sequence[PromptMessageTool] | None,
stop: Sequence[str] | None,
json_schema: dict[str, Any] | None,
) -> LLMPollingResult:
"""Start a plugin-side polling job for long-running LLM invocations."""
plugin_id, provider_name = self._split_provider(provider)
return self.client.start_llm_polling(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
prompt_messages=list(prompt_messages),
model_parameters=model_parameters,
tools=list(tools) if tools else None,
stop=list(stop) if stop else None,
json_schema=json_schema,
)
def check_llm_polling(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
plugin_state: dict[str, JsonValue],
) -> LLMPollingResult:
"""Check the latest plugin-side polling state for an LLM invocation."""
plugin_id, provider_name = self._split_provider(provider)
return self.client.check_llm_polling(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
plugin_state=plugin_state,
)
@override
def invoke_text_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
input_type=input_type,
)
@override
def invoke_multimodal_embedding(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
documents: list[dict[str, Any]],
input_type: EmbeddingInputType,
) -> EmbeddingResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_embedding(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
documents=documents,
input_type=input_type,
)
@override
def get_text_embedding_num_tokens(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
texts: list[str],
) -> list[int]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
texts=texts,
)
@override
def invoke_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: str,
docs: list[str],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
@override
def invoke_multimodal_rerank(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None,
top_n: int | None,
) -> RerankResult:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_multimodal_rerank(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
query=query,
docs=docs,
score_threshold=score_threshold,
top_n=top_n,
)
@override
def invoke_tts(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
content_text: str,
voice: str,
) -> Iterable[bytes]:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_tts(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)
@override
def get_tts_model_voices(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
language: str | None,
) -> Any:
plugin_id, provider_name = self._split_provider(provider)
return self.client.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
language=language,
)
@override
def invoke_speech_to_text(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
file: IO[bytes],
) -> str:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
file=file,
)
@override
def invoke_moderation(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
text: str,
) -> bool:
plugin_id, provider_name = self._split_provider(provider)
return self.client.invoke_moderation(
tenant_id=self.tenant_id,
user_id=self.user_id,
plugin_id=plugin_id,
provider=provider_name,
model=model,
credentials=credentials,
text=text,
)
def _get_provider_schema(self, provider: str) -> ProviderEntity:
providers = self.fetch_model_providers()
provider_entity = next((item for item in providers if item.provider == provider), None)
if provider_entity is None:
provider_entity = next((item for item in providers if provider == item.provider_name), None)
if provider_entity is None:
raise ValueError(f"Invalid provider: {provider}")
return provider_entity
def _get_schema_cache_key(
self,
*,
provider: str,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
) -> str:
# The plugin daemon distinguishes ``None`` from an explicit empty-string
# caller id, so the cache must only collapse ``None`` into tenant scope.
cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id
cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}"
sorted_credentials = sorted(credentials.items()) if credentials else []
if not sorted_credentials:
return cache_key
hashed_credentials = ":".join(
[hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials]
)
return f"{cache_key}:{hashed_credentials}"
def _split_provider(self, provider: str) -> tuple[str, str]:
provider_id = ModelProviderID(provider)
return provider_id.plugin_id, provider_id.provider_name