mirror of
https://github.com/langgenius/dify.git
synced 2026-06-17 23:21:12 +08:00
feat(agent): wire knowledge base retrieval into runtime (#37577)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
8782da42c8
commit
0ea0647dd0
@ -33,6 +33,7 @@ from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAge
|
||||
from clients.agent_backend.request_builder import (
|
||||
AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
DIFY_KNOWLEDGE_BASE_LAYER_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
@ -47,6 +48,7 @@ from clients.agent_backend.request_builder import (
|
||||
__all__ = [
|
||||
"AGENT_SOUL_PROMPT_LAYER_ID",
|
||||
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
|
||||
"DIFY_KNOWLEDGE_BASE_LAYER_ID",
|
||||
"DIFY_PLUGIN_TOOLS_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
|
||||
@ -32,6 +32,7 @@ from dify_agent.layers.execution_context import (
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
DifyExecutionContextLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.knowledge import DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, DifyKnowledgeBaseLayerConfig
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.layers.shell import DIFY_SHELL_LAYER_TYPE_ID, DifyShellLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
@ -55,6 +56,7 @@ AGENT_APP_USER_PROMPT_LAYER_ID = "agent_app_user_prompt"
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
|
||||
DIFY_DRIVE_LAYER_ID = "drive"
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
|
||||
DIFY_KNOWLEDGE_BASE_LAYER_ID = "knowledge"
|
||||
DIFY_ASK_HUMAN_LAYER_ID = "ask_human"
|
||||
DIFY_SHELL_LAYER_ID = "shell"
|
||||
|
||||
@ -139,6 +141,7 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
knowledge: DifyKnowledgeBaseLayerConfig | None = None
|
||||
# Drive Skills & Files declaration (dify.drive) — an index the agent pulls
|
||||
# through the back proxy, never inline content; see AGENT_DRIVE_MANIFEST_ENABLED.
|
||||
drive_config: DifyDriveLayerConfig | None = None
|
||||
@ -185,6 +188,7 @@ class AgentBackendAgentAppRunInput(BaseModel):
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
knowledge: DifyKnowledgeBaseLayerConfig | None = None
|
||||
# Drive Skills & Files declaration (dify.drive) — an index the agent pulls
|
||||
# through the back proxy, never inline content; see AGENT_DRIVE_MANIFEST_ENABLED.
|
||||
drive_config: DifyDriveLayerConfig | None = None
|
||||
@ -221,7 +225,7 @@ class AgentBackendRunRequestBuilder:
|
||||
|
||||
Layer graph: optional Agent Soul system prompt → user prompt →
|
||||
execution context → optional history (multi-turn) → LLM → optional
|
||||
plugin tools → optional structured output. Mirrors the workflow-node
|
||||
plugin tools / knowledge search → optional structured output. Mirrors the workflow-node
|
||||
layer ordering minus the workflow-job / previous-node prompt.
|
||||
"""
|
||||
layers: list[RunLayerSpec] = []
|
||||
@ -300,6 +304,17 @@ class AgentBackendRunRequestBuilder:
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.knowledge is not None and run_input.knowledge.dataset_ids:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_KNOWLEDGE_BASE_LAYER_ID,
|
||||
type=DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.knowledge,
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.ask_human_config is not None:
|
||||
# Human-in-the-loop ask_human deferred tool (dify.ask_human). A call ends
|
||||
# the run with a deferred_tool_call; the caller pauses (workflow HITL) and
|
||||
@ -398,7 +413,12 @@ class AgentBackendRunRequestBuilder:
|
||||
)
|
||||
|
||||
def build_for_workflow_node(self, run_input: AgentBackendWorkflowNodeRunInput) -> CreateRunRequest:
|
||||
"""Build a workflow Agent Node run request without defining another wire schema."""
|
||||
"""Build a workflow Agent Node run request without defining another wire schema.
|
||||
|
||||
Layer graph mirrors the workflow surface: prompts → execution context →
|
||||
optional drive/history → LLM → optional plugin tools / knowledge search
|
||||
→ optional auxiliary layers such as ask_human, shell, and structured output.
|
||||
"""
|
||||
layers: list[RunLayerSpec] = []
|
||||
if run_input.agent_soul_prompt:
|
||||
layers.append(
|
||||
@ -483,6 +503,17 @@ class AgentBackendRunRequestBuilder:
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.knowledge is not None and run_input.knowledge.dataset_ids:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_KNOWLEDGE_BASE_LAYER_ID,
|
||||
type=DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.knowledge,
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.ask_human_config is not None:
|
||||
# Human-in-the-loop ask_human deferred tool (dify.ask_human). A call ends
|
||||
# the run with a deferred_tool_call; the caller pauses (workflow HITL) and
|
||||
|
||||
@ -9,7 +9,7 @@ api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="Inner API",
|
||||
description="Internal APIs for enterprise features, billing, and plugin communication",
|
||||
description="Internal APIs for enterprise features, billing, knowledge retrieval, and plugin communication",
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
@ -17,6 +17,7 @@ inner_api_ns = Namespace("inner_api", description="Internal API operations", pat
|
||||
|
||||
from . import mail as _mail
|
||||
from .app import dsl as _app_dsl
|
||||
from .knowledge import retrieval as _knowledge_retrieval
|
||||
from .plugin import agent_drive as _agent_drive
|
||||
from .plugin import plugin as _plugin
|
||||
from .workspace import workspace as _workspace
|
||||
@ -26,6 +27,7 @@ api.add_namespace(inner_api_ns)
|
||||
__all__ = [
|
||||
"_agent_drive",
|
||||
"_app_dsl",
|
||||
"_knowledge_retrieval",
|
||||
"_mail",
|
||||
"_plugin",
|
||||
"_workspace",
|
||||
|
||||
1
api/controllers/inner_api/knowledge/__init__.py
Normal file
1
api/controllers/inner_api/knowledge/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Inner knowledge retrieval endpoints."""
|
||||
110
api/controllers/inner_api/knowledge/retrieval.py
Normal file
110
api/controllers/inner_api/knowledge/retrieval.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Inner API endpoint for tenant-scoped knowledge retrieval.
|
||||
|
||||
This controller is a thin HTTP wrapper around
|
||||
``services.knowledge_retrieval_inner_service.InnerKnowledgeRetrievalService``.
|
||||
It intentionally keeps authorization simple: shared inner API key plus
|
||||
tenant-scoped app/dataset validation in the service layer.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import inner_api_only
|
||||
from core.workflow.nodes.knowledge_retrieval import exc as retrieval_exc
|
||||
from libs.exception import BaseHTTPException
|
||||
from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveRequest, InnerKnowledgeRetrieveResponse
|
||||
from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError, InnerKnowledgeRetrievalServiceError
|
||||
from services.knowledge_retrieval_inner_service import InnerKnowledgeRetrievalService
|
||||
|
||||
|
||||
class InnerKnowledgeRetrievalHttpError(BaseHTTPException):
|
||||
error_code = "knowledge_retrieve_failed"
|
||||
description = "Knowledge retrieval failed."
|
||||
code = 500
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
error_code: str | None = None,
|
||||
description: str | None = None,
|
||||
status_code: int | None = None,
|
||||
) -> None:
|
||||
if error_code is not None:
|
||||
self.error_code = error_code
|
||||
if description is not None:
|
||||
self.description = description
|
||||
if status_code is not None:
|
||||
self.code = status_code
|
||||
super().__init__(self.description)
|
||||
|
||||
|
||||
register_schema_models(inner_api_ns, InnerKnowledgeRetrieveRequest)
|
||||
register_response_schema_models(inner_api_ns, InnerKnowledgeRetrieveResponse)
|
||||
|
||||
|
||||
@inner_api_ns.route("/knowledge/retrieve")
|
||||
class InnerKnowledgeRetrieveApi(Resource):
|
||||
"""Retrieve knowledge from one or more datasets within the caller tenant."""
|
||||
|
||||
@inner_api_only
|
||||
@inner_api_ns.doc("inner_knowledge_retrieve")
|
||||
@inner_api_ns.doc(description="Retrieve knowledge for trusted internal callers")
|
||||
@inner_api_ns.expect(inner_api_ns.models[InnerKnowledgeRetrieveRequest.__name__])
|
||||
@inner_api_ns.response(
|
||||
200,
|
||||
"Knowledge retrieved successfully",
|
||||
inner_api_ns.models[InnerKnowledgeRetrieveResponse.__name__],
|
||||
)
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
400: "Invalid request body",
|
||||
401: "Unauthorized - invalid inner API key",
|
||||
403: "Caller tenant does not own the requested resource",
|
||||
404: "App or dataset not found",
|
||||
422: "Invalid retrieval configuration",
|
||||
429: "Knowledge retrieval rate limited",
|
||||
502: "External knowledge retrieval failed",
|
||||
500: "Unexpected knowledge retrieval failure",
|
||||
}
|
||||
)
|
||||
def post(self) -> dict[str, object]:
|
||||
"""Validate the payload, run retrieval, and return workflow-style sources."""
|
||||
try:
|
||||
payload = InnerKnowledgeRetrieveRequest.model_validate(inner_api_ns.payload or {})
|
||||
except ValidationError as exc:
|
||||
raise InnerKnowledgeRetrievalHttpError(
|
||||
error_code="invalid_request",
|
||||
description=str(exc),
|
||||
status_code=400,
|
||||
) from exc
|
||||
|
||||
try:
|
||||
response = InnerKnowledgeRetrievalService().retrieve(payload)
|
||||
except InnerKnowledgeRetrievalServiceError as exc:
|
||||
raise InnerKnowledgeRetrievalHttpError(
|
||||
error_code=exc.error_code,
|
||||
description=exc.description,
|
||||
status_code=exc.status_code,
|
||||
) from exc
|
||||
except retrieval_exc.RateLimitExceededError as exc:
|
||||
raise InnerKnowledgeRetrievalHttpError(
|
||||
error_code="knowledge_rate_limited",
|
||||
description=str(exc),
|
||||
status_code=429,
|
||||
) from exc
|
||||
except ExternalKnowledgeRetrievalError as exc:
|
||||
raise InnerKnowledgeRetrievalHttpError(
|
||||
error_code="external_knowledge_failed",
|
||||
description=str(exc),
|
||||
status_code=502,
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
raise InnerKnowledgeRetrievalHttpError(
|
||||
error_code="retrieval_config_invalid",
|
||||
description=str(exc),
|
||||
status_code=422,
|
||||
) from exc
|
||||
|
||||
return response.model_dump(mode="json", by_alias=True)
|
||||
@ -8,39 +8,39 @@ from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from libs.exception import BaseHTTPException
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
class InnerApiUnauthorizedError(BaseHTTPException):
|
||||
error_code = "inner_api_unauthorized"
|
||||
description = "Unauthorized."
|
||||
code = 401
|
||||
|
||||
|
||||
def inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Restrict access to callers authenticated with the shared inner API key."""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
|
||||
abort(401)
|
||||
raise InnerApiUnauthorizedError()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return inner_api_only(view)
|
||||
|
||||
|
||||
def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get("X-Inner-Api-Key")
|
||||
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
|
||||
abort(401)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
return inner_api_only(view)
|
||||
|
||||
|
||||
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
@ -2,8 +2,10 @@
|
||||
|
||||
Mirrors the workflow ``WorkflowAgentRuntimeRequestBuilder`` but for the Agent
|
||||
App surface: the user prompt is the chat message (no workflow-node job / no
|
||||
previous-node context), and multi-turn continuity flows through the
|
||||
conversation-keyed ``session_snapshot`` plus the history layer.
|
||||
previous-node context), multi-turn continuity flows through the
|
||||
conversation-keyed ``session_snapshot`` plus the history layer, and Agent Soul
|
||||
knowledge config is mapped into the same fixed ``dify.knowledge_base`` layer
|
||||
used by workflow runs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -36,6 +38,7 @@ from core.workflow.nodes.agent_v2.runtime_request_builder import (
|
||||
append_runtime_warnings,
|
||||
build_ask_human_layer_config,
|
||||
build_drive_layer_config,
|
||||
build_knowledge_layer_config,
|
||||
build_shell_layer_config,
|
||||
)
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
@ -123,6 +126,7 @@ class AgentAppRuntimeRequestBuilder:
|
||||
if dify_config.AGENT_DRIVE_MANIFEST_ENABLED:
|
||||
drive_config, drive_warnings = build_drive_layer_config(agent_soul, agent_id=context.agent_id)
|
||||
append_runtime_warnings(metadata, drive_warnings)
|
||||
knowledge_config = build_knowledge_layer_config(agent_soul)
|
||||
|
||||
request = self._request_builder.build_for_agent_app(
|
||||
AgentBackendAgentAppRunInput(
|
||||
@ -156,6 +160,7 @@ class AgentAppRuntimeRequestBuilder:
|
||||
or None,
|
||||
user_prompt=context.user_query,
|
||||
tools=tools_layer,
|
||||
knowledge=knowledge_config,
|
||||
drive_config=drive_config,
|
||||
ask_human_config=build_ask_human_layer_config(agent_soul),
|
||||
include_shell=dify_config.AGENT_SHELL_ENABLED,
|
||||
|
||||
@ -123,7 +123,7 @@ class DatasetRetrieval:
|
||||
if not available_datasets_ids:
|
||||
return []
|
||||
|
||||
if not request.query:
|
||||
if not request.query and not request.attachment_ids:
|
||||
return []
|
||||
|
||||
metadata_filter_document_ids, metadata_condition = None, None
|
||||
|
||||
@ -13,6 +13,7 @@ SUPPORTED_AGENT_BACKEND_FEATURES = frozenset(
|
||||
"structured_output",
|
||||
"tools.dify_tools",
|
||||
"tools.cli_tools",
|
||||
"knowledge",
|
||||
"env",
|
||||
"sandbox",
|
||||
# ENG-623: exposed at runtime as the dify.drive declaration layer
|
||||
@ -26,7 +27,6 @@ SUPPORTED_AGENT_BACKEND_FEATURES = frozenset(
|
||||
|
||||
RESERVED_AGENT_BACKEND_FEATURES = frozenset(
|
||||
{
|
||||
"knowledge",
|
||||
"memory",
|
||||
}
|
||||
)
|
||||
@ -80,6 +80,9 @@ def build_runtime_feature_manifest(
|
||||
)
|
||||
|
||||
reserved_status = dict.fromkeys(sorted(RESERVED_AGENT_BACKEND_FEATURES), "reserved_not_executed")
|
||||
reserved_status["knowledge"] = (
|
||||
"supported_by_knowledge_layer" if list_configured_knowledge_dataset_ids(agent_soul) else "not_configured"
|
||||
)
|
||||
reserved_status["skills_files"] = (
|
||||
"supported_by_drive_manifest" if drive_manifest_enabled else "drive_manifest_disabled"
|
||||
)
|
||||
@ -97,6 +100,17 @@ def build_runtime_feature_manifest(
|
||||
}
|
||||
|
||||
|
||||
def list_configured_knowledge_dataset_ids(agent_soul: AgentSoulConfig) -> list[str]:
|
||||
"""Return the normalized knowledge dataset ids that can produce a runtime layer.
|
||||
|
||||
``build_runtime_feature_manifest()`` and ``build_knowledge_layer_config()``
|
||||
must stay aligned: both decide knowledge support from this effective,
|
||||
non-blank dataset-id set rather than from raw
|
||||
``agent_soul.knowledge.datasets`` entries.
|
||||
"""
|
||||
return [dataset_id for dataset in agent_soul.knowledge.datasets if (dataset_id := (dataset.id or "").strip())]
|
||||
|
||||
|
||||
def _get_nested(value: dict[str, Any], path: str) -> Any:
|
||||
current: Any = value
|
||||
for part in path.split("."):
|
||||
|
||||
@ -16,6 +16,7 @@ from dify_agent.layers.execution_context import (
|
||||
DifyExecutionContextLayerConfig,
|
||||
DifyExecutionContextUserFrom,
|
||||
)
|
||||
from dify_agent.layers.knowledge import DifyKnowledgeBaseLayerConfig, DifyKnowledgeRetrievalConfig
|
||||
from dify_agent.layers.shell import (
|
||||
DifyShellCliToolConfig,
|
||||
DifyShellEnvVarConfig,
|
||||
@ -40,6 +41,7 @@ from graphon.file import FileTransferMethod
|
||||
from graphon.variables.segments import Segment
|
||||
from models.agent import Agent, AgentConfigSnapshot, WorkflowAgentNodeBinding
|
||||
from models.agent_config_entities import (
|
||||
AgentKnowledgeQueryConfig,
|
||||
AgentSoulConfig,
|
||||
DeclaredArrayItem,
|
||||
DeclaredOutputChildConfig,
|
||||
@ -60,6 +62,7 @@ from services.agent.prompt_mentions import (
|
||||
|
||||
from .output_failure_orchestrator import retry_idempotency_key
|
||||
from .plugin_tools_builder import WorkflowAgentPluginToolsBuilder, WorkflowAgentPluginToolsBuildError
|
||||
from .runtime_feature_manifest import build_runtime_feature_manifest, list_configured_knowledge_dataset_ids
|
||||
|
||||
_DENIED_PERMISSION_STATUSES = frozenset({"unauthorized", "denied", "forbidden", "invalid", "unavailable"})
|
||||
_DANGEROUS_FLAG_KEYS = ("dangerous", "dangerous_command", "requires_confirmation")
|
||||
@ -69,7 +72,6 @@ _DANGEROUS_ACK_KEYS = (
|
||||
"risk_accepted",
|
||||
"approved",
|
||||
)
|
||||
from .runtime_feature_manifest import build_runtime_feature_manifest
|
||||
|
||||
|
||||
class WorkflowAgentRuntimeRequestBuildError(ValueError):
|
||||
@ -183,6 +185,7 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
if dify_config.AGENT_DRIVE_MANIFEST_ENABLED:
|
||||
drive_config, drive_warnings = build_drive_layer_config(agent_soul, agent_id=context.agent.id)
|
||||
append_runtime_warnings(metadata, drive_warnings)
|
||||
knowledge_config = build_knowledge_layer_config(agent_soul)
|
||||
|
||||
request = self._request_builder.build_for_workflow_node(
|
||||
AgentBackendWorkflowNodeRunInput(
|
||||
@ -197,10 +200,11 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
model_settings=agent_soul.model.model_settings.model_dump(mode="json", exclude_none=True),
|
||||
),
|
||||
# The execution-context layer is now the only public protocol
|
||||
# carrier for Dify tenant/user/run identifiers. ``user_id`` must
|
||||
# be forwarded here because downstream plugin-daemon provider and
|
||||
# tool clients read it from this layer rather than from any
|
||||
# parallel top-level request field.
|
||||
# carrier for Dify tenant/user/run identifiers. ``user_id`` and
|
||||
# ``user_from`` must be forwarded here because downstream plugin-
|
||||
# daemon provider/tool clients and knowledge-base layers read
|
||||
# caller identity from this layer rather than from any parallel
|
||||
# top-level request field.
|
||||
execution_context=DifyExecutionContextLayerConfig(
|
||||
tenant_id=context.dify_context.tenant_id,
|
||||
user_id=context.dify_context.user_id,
|
||||
@ -221,6 +225,7 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
user_prompt=user_prompt,
|
||||
output=self._build_output_config(node_job.declared_outputs),
|
||||
tools=tools_layer,
|
||||
knowledge=knowledge_config,
|
||||
drive_config=drive_config,
|
||||
ask_human_config=build_ask_human_layer_config(agent_soul),
|
||||
include_shell=dify_config.AGENT_SHELL_ENABLED,
|
||||
@ -534,6 +539,45 @@ def build_shell_layer_config(agent_soul: AgentSoulConfig) -> DifyShellLayerConfi
|
||||
)
|
||||
|
||||
|
||||
def build_knowledge_layer_config(agent_soul: AgentSoulConfig) -> DifyKnowledgeBaseLayerConfig | None:
|
||||
"""Map Agent Soul knowledge config into the fixed Dify knowledge-base layer.
|
||||
|
||||
Normalization intentionally matches the current dify-agent runtime contract:
|
||||
|
||||
- blank or missing dataset ids are ignored;
|
||||
- if no valid dataset ids remain, no knowledge layer is injected;
|
||||
- retrieval mode is always forced to ``multiple`` in this first wiring pass;
|
||||
- ``top_k`` falls back to a stable runtime default when the soul omits it;
|
||||
- ``score_threshold`` is only forwarded when the product config explicitly
|
||||
enables it, otherwise the layer keeps the disabled/default ``0.0`` value;
|
||||
- metadata filtering stays at the layer DTO default (disabled).
|
||||
"""
|
||||
dataset_ids = list_configured_knowledge_dataset_ids(agent_soul)
|
||||
if not dataset_ids:
|
||||
return None
|
||||
|
||||
query_config = agent_soul.knowledge.query_config
|
||||
return DifyKnowledgeBaseLayerConfig(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieval=DifyKnowledgeRetrievalConfig(
|
||||
mode="multiple",
|
||||
top_k=_knowledge_top_k(query_config),
|
||||
score_threshold=_knowledge_score_threshold(query_config),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _knowledge_top_k(query_config: AgentKnowledgeQueryConfig) -> int:
|
||||
top_k = query_config.top_k
|
||||
return top_k if isinstance(top_k, int) and top_k >= 1 else 4
|
||||
|
||||
|
||||
def _knowledge_score_threshold(query_config: AgentKnowledgeQueryConfig) -> float:
|
||||
if query_config.score_threshold_enabled and query_config.score_threshold is not None:
|
||||
return query_config.score_threshold
|
||||
return 0.0
|
||||
|
||||
|
||||
def build_ask_human_layer_config(agent_soul: AgentSoulConfig) -> DifyAskHumanLayerConfig | None:
|
||||
"""Enable the dify.ask_human deferred tool when the soul configures human involvement.
|
||||
|
||||
|
||||
210
api/services/entities/knowledge_retrieval_inner.py
Normal file
210
api/services/entities/knowledge_retrieval_inner.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""DTOs for the inner knowledge retrieval API.
|
||||
|
||||
These models define the stable HTTP contract for trusted internal callers and
|
||||
the response shape returned by the workflow knowledge retrieval stack.
|
||||
|
||||
Key cross-field invariants live here because callers cannot infer them from
|
||||
scalar field types alone: ``dataset_ids`` must be non-empty, either ``query``
|
||||
or ``attachment_ids`` is required, ``single`` retrieval requires both ``query``
|
||||
and ``retrieval.model``, ``automatic`` metadata filtering requires
|
||||
``model_config``, and ``manual`` metadata filtering requires conditions. The
|
||||
response reuses workflow ``Source`` items plus serialized ``llm_usage``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import WeightsDict
|
||||
from core.rag.entities.metadata_entities import SupportedComparisonOperator
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import Source
|
||||
from fields.base import ResponseModel
|
||||
|
||||
type JsonScalar = str | int | float | bool | None
|
||||
type JsonValue = JsonScalar | list[JsonScalar] | dict[str, JsonScalar]
|
||||
type MetadataValue = str | list[str] | int | float | None
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveCaller(BaseModel):
|
||||
"""Execution context provided by the trusted internal caller."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
tenant_id: str = Field(min_length=1)
|
||||
user_id: str = Field(min_length=1)
|
||||
app_id: str = Field(min_length=1)
|
||||
user_from: Literal["account", "end-user"]
|
||||
invoke_from: str = Field(min_length=1)
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveModelConfig(BaseModel):
|
||||
"""Model configuration used by single-retrieval or metadata filtering."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
provider: str = Field(min_length=1)
|
||||
name: str = Field(min_length=1)
|
||||
mode: str = Field(min_length=1)
|
||||
completion_params: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveRerankingModelConfig(BaseModel):
|
||||
"""Reranking model configuration for multiple retrieval mode."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
provider: str = Field(min_length=1)
|
||||
model: str = Field(min_length=1)
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveRetrievalConfig(BaseModel):
|
||||
"""Retrieval strategy and its mode-specific configuration."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
mode: Literal["multiple", "single"]
|
||||
top_k: int | None = Field(default=None, ge=1)
|
||||
score_threshold: float = 0.0
|
||||
reranking_mode: str = "reranking_model"
|
||||
reranking_enable: bool = True
|
||||
reranking_model: InnerKnowledgeRetrieveRerankingModelConfig | None = None
|
||||
weights: WeightsDict | None = None
|
||||
model: InnerKnowledgeRetrieveModelConfig | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode_specific_fields(self) -> InnerKnowledgeRetrieveRetrievalConfig:
|
||||
if self.mode == "single" and self.model is None:
|
||||
raise ValueError("retrieval.model is required for single mode")
|
||||
if self.mode == "multiple" and self.top_k is None:
|
||||
raise ValueError("retrieval.top_k is required for multiple mode")
|
||||
return self
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveMetadataCondition(BaseModel):
|
||||
"""Single metadata filter condition."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str = Field(min_length=1)
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: MetadataValue = None
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveMetadataConditions(BaseModel):
|
||||
"""Boolean composition for metadata filter conditions."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[InnerKnowledgeRetrieveMetadataCondition] | None = None
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveMetadataFilteringConfig(BaseModel):
|
||||
"""Metadata filtering configuration forwarded to workflow retrieval.
|
||||
|
||||
``automatic`` mode requires ``model_config`` so downstream metadata model
|
||||
planning has the necessary LLM settings. ``manual`` mode requires
|
||||
non-empty conditions because workflow retrieval expects explicit filters
|
||||
instead of a bare mode switch.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
mode: Literal["disabled", "automatic", "manual"] = "disabled"
|
||||
metadata_model_config: InnerKnowledgeRetrieveModelConfig | None = Field(default=None, alias="model_config")
|
||||
conditions: InnerKnowledgeRetrieveMetadataConditions | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode_specific_fields(self) -> InnerKnowledgeRetrieveMetadataFilteringConfig:
|
||||
if self.mode == "automatic" and self.metadata_model_config is None:
|
||||
raise ValueError("metadata_filtering.model_config is required for automatic mode")
|
||||
if self.mode == "manual" and (self.conditions is None or not self.conditions.conditions):
|
||||
raise ValueError("metadata_filtering.conditions is required for manual mode")
|
||||
return self
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveRequest(BaseModel):
|
||||
"""Top-level request payload for the inner knowledge retrieval endpoint.
|
||||
|
||||
Request validation enforces the endpoint's behavioral contract: callers
|
||||
must provide at least one dataset ID, at least one of ``query`` or
|
||||
``attachment_ids``, and a text query for ``single`` retrieval mode.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
caller: InnerKnowledgeRetrieveCaller
|
||||
dataset_ids: list[str]
|
||||
query: str | None = None
|
||||
retrieval: InnerKnowledgeRetrieveRetrievalConfig
|
||||
metadata_filtering: InnerKnowledgeRetrieveMetadataFilteringConfig = Field(
|
||||
default_factory=InnerKnowledgeRetrieveMetadataFilteringConfig
|
||||
)
|
||||
attachment_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("dataset_ids", "attachment_ids")
|
||||
@classmethod
|
||||
def validate_non_empty_items(cls, value: list[str]) -> list[str]:
|
||||
if any(not item.strip() for item in value):
|
||||
raise ValueError("list items must not be empty")
|
||||
return value
|
||||
|
||||
@field_validator("query")
|
||||
@classmethod
|
||||
def normalize_query(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
normalized = value.strip()
|
||||
return normalized or None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_request(self) -> InnerKnowledgeRetrieveRequest:
|
||||
if not self.dataset_ids:
|
||||
raise ValueError("dataset_ids must contain at least one item")
|
||||
if not self.query and not self.attachment_ids:
|
||||
raise ValueError("query or attachment_ids is required")
|
||||
if self.retrieval.mode == "single" and not self.query:
|
||||
raise ValueError("query is required for single mode")
|
||||
return self
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveUsage(ResponseModel):
|
||||
"""Serialized LLM usage payload returned by dataset retrieval."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="forbid",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
prompt_unit_price: str
|
||||
completion_unit_price: str
|
||||
prompt_price_unit: str
|
||||
completion_price_unit: str
|
||||
prompt_price: str
|
||||
completion_price: str
|
||||
total_price: str
|
||||
currency: str | None = None
|
||||
latency: float | int
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveResponse(ResponseModel):
|
||||
"""Workflow-style retrieval results plus accumulated usage."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="forbid",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
|
||||
results: list[Source]
|
||||
usage: InnerKnowledgeRetrieveUsage
|
||||
49
api/services/errors/knowledge_retrieval.py
Normal file
49
api/services/errors/knowledge_retrieval.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Service errors for the inner knowledge retrieval API."""
|
||||
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class InnerKnowledgeRetrievalServiceError(BaseServiceError):
|
||||
"""Base service error with a stable HTTP mapping contract."""
|
||||
|
||||
error_code = "knowledge_retrieve_failed"
|
||||
status_code = 500
|
||||
default_description = "Knowledge retrieval failed."
|
||||
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description or self.default_description
|
||||
ValueError.__init__(self, self.description)
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveAppNotFoundError(InnerKnowledgeRetrievalServiceError):
|
||||
error_code = "app_not_found"
|
||||
status_code = 404
|
||||
default_description = "App not found."
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveAppTenantMismatchError(InnerKnowledgeRetrievalServiceError):
|
||||
error_code = "app_tenant_mismatch"
|
||||
status_code = 403
|
||||
default_description = "App does not belong to caller tenant."
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveDatasetNotFoundError(InnerKnowledgeRetrievalServiceError):
|
||||
error_code = "dataset_not_found"
|
||||
status_code = 404
|
||||
default_description = "Dataset not found."
|
||||
|
||||
|
||||
class InnerKnowledgeRetrieveDatasetTenantMismatchError(InnerKnowledgeRetrievalServiceError):
|
||||
error_code = "dataset_tenant_mismatch"
|
||||
status_code = 403
|
||||
default_description = "Dataset does not belong to caller tenant."
|
||||
|
||||
|
||||
class ExternalKnowledgeRetrievalError(ValueError):
|
||||
"""Raised when an external dataset retrieval dependency fails.
|
||||
|
||||
This stays a ``ValueError`` subclass for compatibility with existing callers
|
||||
that already treat external retrieval failures as generic retrieval errors,
|
||||
while still giving inner API controllers a dedicated error type to map to
|
||||
``502 external_knowledge_failed``.
|
||||
"""
|
||||
@ -22,6 +22,7 @@ from services.entities.external_knowledge_entities.external_knowledge_entities i
|
||||
ExternalKnowledgeApiSetting,
|
||||
)
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError
|
||||
|
||||
|
||||
class ExternalDatasetService:
|
||||
@ -309,13 +310,22 @@ class ExternalDatasetService:
|
||||
external_retrieval_parameters: dict[str, Any],
|
||||
metadata_condition: MetadataFilteringCondition | None = None,
|
||||
):
|
||||
"""Fetch retrieval records from an external knowledge provider.
|
||||
|
||||
Success requires a tenant-scoped binding plus API template and a ``200``
|
||||
response body shaped like ``{"records": [...]}``. All dependency
|
||||
failures, non-200 responses, and malformed success payloads must be
|
||||
normalized to ``ExternalKnowledgeRetrievalError`` so callers—especially
|
||||
the inner knowledge retrieval API—can consistently expose
|
||||
``502 external_knowledge_failed``.
|
||||
"""
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings)
|
||||
.where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
raise ValueError("external knowledge binding not found")
|
||||
raise ExternalKnowledgeRetrievalError("external knowledge binding not found")
|
||||
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
@ -326,7 +336,7 @@ class ExternalDatasetService:
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None or external_knowledge_api.settings is None:
|
||||
raise ValueError("external api template not found")
|
||||
raise ExternalKnowledgeRetrievalError("external api template not found")
|
||||
|
||||
settings = json.loads(external_knowledge_api.settings)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
@ -344,16 +354,34 @@ class ExternalDatasetService:
|
||||
"metadata_condition": metadata_condition.model_dump() if metadata_condition else None,
|
||||
}
|
||||
|
||||
response = ExternalDatasetService.process_external_api(
|
||||
ExternalKnowledgeApiSetting(
|
||||
url=f"{settings.get('endpoint')}/retrieval",
|
||||
request_method="post",
|
||||
headers=headers,
|
||||
params=request_params,
|
||||
),
|
||||
None,
|
||||
)
|
||||
try:
|
||||
response = ExternalDatasetService.process_external_api(
|
||||
ExternalKnowledgeApiSetting(
|
||||
url=f"{settings.get('endpoint')}/retrieval",
|
||||
request_method="post",
|
||||
headers=headers,
|
||||
params=request_params,
|
||||
),
|
||||
None,
|
||||
)
|
||||
except ExternalKnowledgeRetrievalError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise ExternalKnowledgeRetrievalError(str(exc)) from exc
|
||||
|
||||
if response.status_code == 200:
|
||||
return cast(list[Any], response.json().get("records", []))
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except Exception as exc:
|
||||
raise ExternalKnowledgeRetrievalError("invalid external knowledge response") from exc
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
raise ExternalKnowledgeRetrievalError("invalid external knowledge response")
|
||||
|
||||
records = response_payload.get("records", [])
|
||||
if not isinstance(records, list):
|
||||
raise ExternalKnowledgeRetrievalError("invalid external knowledge response")
|
||||
|
||||
return cast(list[Any], records)
|
||||
else:
|
||||
raise ValueError(response.text)
|
||||
raise ExternalKnowledgeRetrievalError(response.text)
|
||||
|
||||
145
api/services/knowledge_retrieval_inner_service.py
Normal file
145
api/services/knowledge_retrieval_inner_service.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Service wrapper for the inner knowledge retrieval API.
|
||||
|
||||
This service keeps the internal HTTP contract small while reusing the workflow
|
||||
retrieval stack in ``core.rag.retrieval.dataset_retrieval.DatasetRetrieval``.
|
||||
The only authorization enforced here is tenant ownership of the caller app and
|
||||
requested datasets.
|
||||
|
||||
It intentionally does not check ``dataset.enable_api`` or user-level dataset
|
||||
permissions. After the caller app and requested datasets pass tenant-scoped
|
||||
prechecks, dataset availability and "no usable document" cases are delegated to
|
||||
``DatasetRetrieval`` and may legitimately produce an empty result list instead
|
||||
of a separate validation error.
|
||||
"""
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes.llm.entities import ModelConfig
|
||||
from models.dataset import Dataset
|
||||
from models.model import App
|
||||
from services.entities.knowledge_retrieval_inner import (
|
||||
InnerKnowledgeRetrieveRequest,
|
||||
InnerKnowledgeRetrieveResponse,
|
||||
InnerKnowledgeRetrieveUsage,
|
||||
)
|
||||
from services.errors.knowledge_retrieval import (
|
||||
InnerKnowledgeRetrieveAppNotFoundError,
|
||||
InnerKnowledgeRetrieveAppTenantMismatchError,
|
||||
InnerKnowledgeRetrieveDatasetNotFoundError,
|
||||
InnerKnowledgeRetrieveDatasetTenantMismatchError,
|
||||
)
|
||||
|
||||
|
||||
class InnerKnowledgeRetrievalService:
|
||||
"""Validate inner caller scope and delegate to workflow dataset retrieval."""
|
||||
|
||||
def retrieve(self, request: InnerKnowledgeRetrieveRequest) -> InnerKnowledgeRetrieveResponse:
|
||||
"""Run tenant-scoped retrieval for a trusted internal caller.
|
||||
|
||||
This method only rejects caller app existence/tenant mismatches and
|
||||
requested dataset existence/tenant mismatches. It deliberately leaves
|
||||
``dataset.enable_api``, user-level dataset permissions, and
|
||||
availability/no-usable-document handling to ``DatasetRetrieval`` so the
|
||||
inner API stays aligned with workflow retrieval semantics, including
|
||||
returning ``[]`` when datasets are present but yield no retrievable
|
||||
content.
|
||||
|
||||
Raises:
|
||||
InnerKnowledgeRetrieveAppNotFoundError: The caller app does not exist.
|
||||
InnerKnowledgeRetrieveAppTenantMismatchError: The caller app is outside the caller tenant.
|
||||
InnerKnowledgeRetrieveDatasetNotFoundError: At least one requested dataset does not exist.
|
||||
InnerKnowledgeRetrieveDatasetTenantMismatchError:
|
||||
At least one requested dataset is outside the caller tenant.
|
||||
"""
|
||||
self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id)
|
||||
self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids)
|
||||
|
||||
rag = DatasetRetrieval()
|
||||
results = rag.knowledge_retrieval(request=self._to_rag_request(request))
|
||||
return InnerKnowledgeRetrieveResponse(
|
||||
results=results,
|
||||
usage=InnerKnowledgeRetrieveUsage.model_validate(jsonable_encoder(rag.llm_usage)),
|
||||
)
|
||||
|
||||
def _validate_caller_app(self, *, tenant_id: str, app_id: str) -> None:
|
||||
app = db.session.scalar(select(App).where(App.id == app_id).limit(1))
|
||||
if app is None:
|
||||
raise InnerKnowledgeRetrieveAppNotFoundError(f"App '{app_id}' not found")
|
||||
if app.tenant_id != tenant_id:
|
||||
raise InnerKnowledgeRetrieveAppTenantMismatchError(
|
||||
f"App '{app_id}' does not belong to tenant '{tenant_id}'"
|
||||
)
|
||||
|
||||
def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str]) -> None:
|
||||
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
|
||||
|
||||
found_ids = {dataset.id for dataset in datasets}
|
||||
missing_ids = sorted(set(dataset_ids) - found_ids)
|
||||
if missing_ids:
|
||||
raise InnerKnowledgeRetrieveDatasetNotFoundError(f"Datasets not found: {', '.join(missing_ids)}")
|
||||
|
||||
mismatched_ids = sorted(dataset.id for dataset in datasets if dataset.tenant_id != tenant_id)
|
||||
if mismatched_ids:
|
||||
raise InnerKnowledgeRetrieveDatasetTenantMismatchError(
|
||||
f"Datasets do not belong to tenant '{tenant_id}': {', '.join(mismatched_ids)}"
|
||||
)
|
||||
|
||||
def _to_rag_request(self, request: InnerKnowledgeRetrieveRequest) -> KnowledgeRetrievalRequest:
|
||||
metadata_model_config = request.metadata_filtering.metadata_model_config
|
||||
metadata_conditions = request.metadata_filtering.conditions
|
||||
|
||||
return KnowledgeRetrievalRequest(
|
||||
tenant_id=request.caller.tenant_id,
|
||||
user_id=request.caller.user_id,
|
||||
app_id=request.caller.app_id,
|
||||
user_from=request.caller.user_from,
|
||||
dataset_ids=request.dataset_ids,
|
||||
query=request.query,
|
||||
retrieval_mode=request.retrieval.mode,
|
||||
model_provider=request.retrieval.model.provider if request.retrieval.model else None,
|
||||
completion_params=request.retrieval.model.completion_params if request.retrieval.model else None,
|
||||
model_mode=request.retrieval.model.mode if request.retrieval.model else None,
|
||||
model_name=request.retrieval.model.name if request.retrieval.model else None,
|
||||
metadata_model_config=ModelConfig.model_validate(metadata_model_config.model_dump(mode="python"))
|
||||
if metadata_model_config
|
||||
else None,
|
||||
metadata_filtering_conditions=(
|
||||
MetadataFilteringCondition(
|
||||
logical_operator=metadata_conditions.logical_operator,
|
||||
conditions=(
|
||||
[
|
||||
Condition(
|
||||
name=condition.name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=condition.value,
|
||||
)
|
||||
for condition in metadata_conditions.conditions
|
||||
]
|
||||
if metadata_conditions.conditions is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
if metadata_conditions is not None
|
||||
else None
|
||||
),
|
||||
metadata_filtering_mode=request.metadata_filtering.mode,
|
||||
top_k=request.retrieval.top_k or 0,
|
||||
score_threshold=request.retrieval.score_threshold,
|
||||
reranking_mode=request.retrieval.reranking_mode,
|
||||
reranking_model=(
|
||||
{
|
||||
"reranking_provider_name": request.retrieval.reranking_model.provider,
|
||||
"reranking_model_name": request.retrieval.reranking_model.model,
|
||||
}
|
||||
if request.retrieval.reranking_model is not None
|
||||
else None
|
||||
),
|
||||
weights=request.retrieval.weights,
|
||||
reranking_enable=request.retrieval.reranking_enable,
|
||||
attachment_ids=request.attachment_ids or None,
|
||||
)
|
||||
@ -15,6 +15,7 @@ from dify_agent.layers.dify_plugin import (
|
||||
DifyPluginToolsLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.execution_context import DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig
|
||||
from dify_agent.layers.knowledge import DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, DifyKnowledgeBaseLayerConfig
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID
|
||||
from dify_agent.layers.shell import DIFY_SHELL_LAYER_TYPE_ID, DifyShellEnvVarConfig, DifyShellLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
@ -28,6 +29,7 @@ from pydantic import ValidationError
|
||||
from clients.agent_backend import (
|
||||
AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
DIFY_KNOWLEDGE_BASE_LAYER_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
@ -155,6 +157,25 @@ def test_request_builder_adds_dify_plugin_tools_layer_when_configured():
|
||||
assert tools_config.tools[0].tool_name == "current_time"
|
||||
|
||||
|
||||
def test_request_builder_adds_knowledge_layer_when_configured():
|
||||
run_input = _run_input()
|
||||
run_input.knowledge = DifyKnowledgeBaseLayerConfig.model_validate(
|
||||
{
|
||||
"dataset_ids": ["dataset-1"],
|
||||
"retrieval": {"mode": "multiple", "top_k": 4},
|
||||
}
|
||||
)
|
||||
|
||||
request = AgentBackendRunRequestBuilder().build_for_workflow_node(run_input)
|
||||
layers = {layer.name: layer for layer in request.composition.layers}
|
||||
|
||||
assert DIFY_KNOWLEDGE_BASE_LAYER_ID in layers
|
||||
assert layers[DIFY_KNOWLEDGE_BASE_LAYER_ID].type == DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID
|
||||
assert layers[DIFY_KNOWLEDGE_BASE_LAYER_ID].deps == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}
|
||||
knowledge_config = cast(DifyKnowledgeBaseLayerConfig, layers[DIFY_KNOWLEDGE_BASE_LAYER_ID].config)
|
||||
assert knowledge_config.dataset_ids == ["dataset-1"]
|
||||
|
||||
|
||||
def test_request_builder_can_delete_on_exit_for_cleanup_paths():
|
||||
run_input = _run_input()
|
||||
run_input.suspend_on_exit = False
|
||||
@ -329,6 +350,25 @@ def test_agent_app_request_builder_adds_shell_layer_when_include_shell():
|
||||
assert shell_config.env[0].name == "APP_ENV"
|
||||
|
||||
|
||||
def test_agent_app_request_builder_adds_knowledge_layer_when_configured():
|
||||
run_input = _agent_app_input()
|
||||
run_input.knowledge = DifyKnowledgeBaseLayerConfig.model_validate(
|
||||
{
|
||||
"dataset_ids": ["dataset-1", "dataset-2"],
|
||||
"retrieval": {"mode": "multiple", "top_k": 2},
|
||||
}
|
||||
)
|
||||
|
||||
request = AgentBackendRunRequestBuilder().build_for_agent_app(run_input)
|
||||
layers = {layer.name: layer for layer in request.composition.layers}
|
||||
|
||||
assert DIFY_KNOWLEDGE_BASE_LAYER_ID in layers
|
||||
assert layers[DIFY_KNOWLEDGE_BASE_LAYER_ID].type == DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID
|
||||
assert layers[DIFY_KNOWLEDGE_BASE_LAYER_ID].deps == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}
|
||||
knowledge_config = cast(DifyKnowledgeBaseLayerConfig, layers[DIFY_KNOWLEDGE_BASE_LAYER_ID].config)
|
||||
assert knowledge_config.dataset_ids == ["dataset-1", "dataset-2"]
|
||||
|
||||
|
||||
# ── ENG-635 / ENG-638: ask_human layer injection + deferred_tool_results ─────
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from controllers.inner_api.wraps import (
|
||||
billing_inner_api_only,
|
||||
enterprise_inner_api_only,
|
||||
enterprise_inner_api_user_auth,
|
||||
inner_api_only,
|
||||
plugin_inner_api_only,
|
||||
)
|
||||
from models.model import EndUser
|
||||
@ -154,6 +155,57 @@ class TestEnterpriseInnerApiOnly:
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
|
||||
class TestInnerApiOnly:
|
||||
"""Test inner_api_only decorator."""
|
||||
|
||||
def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
|
||||
@inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
result = protected_view()
|
||||
|
||||
assert result == "success"
|
||||
|
||||
def test_should_return_404_when_inner_api_disabled(self, app: Flask):
|
||||
@inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
with app.test_request_context():
|
||||
with patch.object(dify_config, "INNER_API", False):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_should_return_401_when_api_key_missing(self, app: Flask):
|
||||
@inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
with app.test_request_context(headers={}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
def test_should_return_401_when_api_key_invalid(self, app: Flask):
|
||||
@inner_api_only
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}):
|
||||
with patch.object(dify_config, "INNER_API", True):
|
||||
with patch.object(dify_config, "INNER_API_KEY", "valid_key"):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
protected_view()
|
||||
assert exc_info.value.code == 401
|
||||
|
||||
|
||||
class TestEnterpriseInnerApiUserAuth:
|
||||
"""Test enterprise_inner_api_user_auth decorator for HMAC-based user authentication"""
|
||||
|
||||
|
||||
@ -0,0 +1,233 @@
|
||||
"""Unit tests for the inner knowledge retrieval controller."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.inner_api import bp as inner_api_bp
|
||||
from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import Source, SourceMetadata
|
||||
from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveResponse, InnerKnowledgeRetrieveUsage
|
||||
from services.errors.knowledge_retrieval import (
|
||||
ExternalKnowledgeRetrievalError,
|
||||
InnerKnowledgeRetrieveAppNotFoundError,
|
||||
InnerKnowledgeRetrieveDatasetTenantMismatchError,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def inner_api_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(inner_api_bp)
|
||||
return app
|
||||
|
||||
|
||||
def _headers(api_key: str | None = "inner-key") -> dict[str, str]:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key is not None:
|
||||
headers["X-Inner-Api-Key"] = api_key
|
||||
return headers
|
||||
|
||||
|
||||
def _payload() -> dict[str, object]:
|
||||
return {
|
||||
"caller": {
|
||||
"tenant_id": "tenant-1",
|
||||
"user_id": "user-1",
|
||||
"app_id": "app-1",
|
||||
"user_from": "account",
|
||||
"invoke_from": "workflow",
|
||||
},
|
||||
"dataset_ids": ["dataset-1"],
|
||||
"query": "reset password",
|
||||
"retrieval": {
|
||||
"mode": "multiple",
|
||||
"top_k": 4,
|
||||
},
|
||||
"metadata_filtering": {
|
||||
"mode": "disabled",
|
||||
},
|
||||
"attachment_ids": [],
|
||||
}
|
||||
|
||||
|
||||
class TestInnerKnowledgeRetrieveApi:
|
||||
def test_post_returns_401_when_api_key_missing(self, inner_api_app: Flask):
|
||||
with patch("configs.dify_config.INNER_API", True):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json=_payload(),
|
||||
headers=_headers(api_key=None),
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.get_json()["code"] == "inner_api_unauthorized"
|
||||
|
||||
def test_post_returns_401_when_api_key_invalid(self, inner_api_app: Flask):
|
||||
with patch("configs.dify_config.INNER_API", True), patch("configs.dify_config.INNER_API_KEY", "inner-key"):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json=_payload(),
|
||||
headers=_headers(api_key="wrong-key"),
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
assert response.get_json()["code"] == "inner_api_unauthorized"
|
||||
|
||||
def test_post_returns_400_for_invalid_body(self, inner_api_app: Flask):
|
||||
with patch("configs.dify_config.INNER_API", True), patch("configs.dify_config.INNER_API_KEY", "inner-key"):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json={"caller": {"tenant_id": "tenant-1"}},
|
||||
headers=_headers(),
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["code"] == "invalid_request"
|
||||
|
||||
@patch("controllers.inner_api.knowledge.retrieval.InnerKnowledgeRetrievalService.retrieve")
|
||||
def test_post_returns_404_for_service_not_found_error(self, mock_retrieve, inner_api_app: Flask):
|
||||
mock_retrieve.side_effect = InnerKnowledgeRetrieveAppNotFoundError("app missing")
|
||||
|
||||
with patch("configs.dify_config.INNER_API", True), patch("configs.dify_config.INNER_API_KEY", "inner-key"):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json=_payload(),
|
||||
headers=_headers(),
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.get_json()["code"] == "app_not_found"
|
||||
|
||||
@patch("controllers.inner_api.knowledge.retrieval.InnerKnowledgeRetrievalService.retrieve")
|
||||
def test_post_returns_403_for_service_forbidden_error(self, mock_retrieve, inner_api_app: Flask):
|
||||
mock_retrieve.side_effect = InnerKnowledgeRetrieveDatasetTenantMismatchError("wrong tenant")
|
||||
|
||||
with patch("configs.dify_config.INNER_API", True), patch("configs.dify_config.INNER_API_KEY", "inner-key"):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json=_payload(),
|
||||
headers=_headers(),
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.get_json()["code"] == "dataset_tenant_mismatch"
|
||||
|
||||
@patch("controllers.inner_api.knowledge.retrieval.InnerKnowledgeRetrievalService.retrieve")
|
||||
def test_post_returns_422_for_retrieval_config_value_error(self, mock_retrieve, inner_api_app: Flask):
|
||||
mock_retrieve.side_effect = ValueError("invalid reranking config")
|
||||
|
||||
with patch("configs.dify_config.INNER_API", True), patch("configs.dify_config.INNER_API_KEY", "inner-key"):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json=_payload(),
|
||||
headers=_headers(),
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
assert response.get_json()["code"] == "retrieval_config_invalid"
|
||||
|
||||
@patch("controllers.inner_api.knowledge.retrieval.InnerKnowledgeRetrievalService.retrieve")
|
||||
def test_post_returns_429_for_rate_limit_error(self, mock_retrieve, inner_api_app: Flask):
|
||||
mock_retrieve.side_effect = RateLimitExceededError("knowledge rate limited")
|
||||
|
||||
with patch("configs.dify_config.INNER_API", True), patch("configs.dify_config.INNER_API_KEY", "inner-key"):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json=_payload(),
|
||||
headers=_headers(),
|
||||
)
|
||||
|
||||
assert response.status_code == 429
|
||||
assert response.get_json()["code"] == "knowledge_rate_limited"
|
||||
|
||||
def test_post_returns_400_for_manual_metadata_without_conditions(self, inner_api_app: Flask):
|
||||
payload = _payload()
|
||||
payload["metadata_filtering"] = {"mode": "manual"}
|
||||
|
||||
with patch("configs.dify_config.INNER_API", True), patch("configs.dify_config.INNER_API_KEY", "inner-key"):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json=payload,
|
||||
headers=_headers(),
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["code"] == "invalid_request"
|
||||
|
||||
def test_post_returns_400_for_automatic_metadata_without_model_config(self, inner_api_app: Flask):
|
||||
payload = _payload()
|
||||
payload["metadata_filtering"] = {"mode": "automatic"}
|
||||
|
||||
with patch("configs.dify_config.INNER_API", True), patch("configs.dify_config.INNER_API_KEY", "inner-key"):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json=payload,
|
||||
headers=_headers(),
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.get_json()["code"] == "invalid_request"
|
||||
|
||||
@patch("controllers.inner_api.knowledge.retrieval.InnerKnowledgeRetrievalService.retrieve")
|
||||
def test_post_returns_502_for_external_knowledge_failure(self, mock_retrieve, inner_api_app: Flask):
|
||||
mock_retrieve.side_effect = ExternalKnowledgeRetrievalError("upstream failed")
|
||||
|
||||
with patch("configs.dify_config.INNER_API", True), patch("configs.dify_config.INNER_API_KEY", "inner-key"):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json=_payload(),
|
||||
headers=_headers(),
|
||||
)
|
||||
|
||||
assert response.status_code == 502
|
||||
assert response.get_json()["code"] == "external_knowledge_failed"
|
||||
|
||||
@patch("controllers.inner_api.knowledge.retrieval.InnerKnowledgeRetrievalService.retrieve")
|
||||
def test_post_returns_service_response(self, mock_retrieve, inner_api_app: Flask):
|
||||
mock_retrieve.return_value = InnerKnowledgeRetrieveResponse(
|
||||
results=[
|
||||
Source(
|
||||
metadata=SourceMetadata(
|
||||
dataset_id="dataset-1",
|
||||
dataset_name="Docs",
|
||||
document_id="document-1",
|
||||
document_name="FAQ.md",
|
||||
data_source_type="upload_file",
|
||||
),
|
||||
title="FAQ.md",
|
||||
files=[],
|
||||
content="Reset your password from settings.",
|
||||
summary=None,
|
||||
)
|
||||
],
|
||||
usage=InnerKnowledgeRetrieveUsage(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
prompt_unit_price="0",
|
||||
completion_unit_price="0",
|
||||
prompt_price_unit="0.001",
|
||||
completion_price_unit="0.001",
|
||||
prompt_price="0",
|
||||
completion_price="0",
|
||||
total_price="0",
|
||||
currency="USD",
|
||||
latency=0,
|
||||
),
|
||||
)
|
||||
|
||||
with patch("configs.dify_config.INNER_API", True), patch("configs.dify_config.INNER_API_KEY", "inner-key"):
|
||||
response = inner_api_app.test_client().post(
|
||||
"/inner/api/knowledge/retrieve",
|
||||
json=_payload(),
|
||||
headers=_headers(),
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert data["results"][0]["metadata"]["_source"] == "knowledge"
|
||||
assert data["results"][0]["title"] == "FAQ.md"
|
||||
assert data["usage"]["total_tokens"] == 0
|
||||
@ -144,6 +144,40 @@ class TestAgentAppRuntimeRequestBuilder:
|
||||
assert result.redacted_request["composition"]["layers"][-1]["config"]["credentials"] == "[REDACTED]"
|
||||
assert result.metadata["conversation_id"] == "conv-1"
|
||||
|
||||
def test_build_maps_agent_soul_knowledge_to_knowledge_layer(self):
|
||||
soul = AgentSoulConfig.model_validate(
|
||||
{
|
||||
"model": {
|
||||
"plugin_id": "langgenius/openai",
|
||||
"model_provider": "langgenius/openai/openai",
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
"knowledge": {
|
||||
"datasets": [{"id": "dataset-1"}, {"id": "dataset-2"}],
|
||||
"query_config": {
|
||||
"top_k": 3,
|
||||
"score_threshold": 0.5,
|
||||
"score_threshold_enabled": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
builder = AgentAppRuntimeRequestBuilder(
|
||||
credentials_provider=_FakeCredentialsProvider(),
|
||||
plugin_tools_builder=_NoToolsBuilder(), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
result = builder.build(_ctx(soul))
|
||||
|
||||
knowledge = next(layer for layer in result.request.composition.layers if layer.name == "knowledge")
|
||||
assert knowledge.type == "dify.knowledge_base"
|
||||
assert knowledge.deps == {"execution_context": "execution_context"}
|
||||
dumped_config = knowledge.config.model_dump(mode="json", by_alias=True)
|
||||
assert dumped_config["dataset_ids"] == ["dataset-1", "dataset-2"]
|
||||
assert dumped_config["retrieval"]["mode"] == "multiple"
|
||||
assert dumped_config["retrieval"]["top_k"] == 3
|
||||
assert dumped_config["retrieval"]["score_threshold"] == 0.0
|
||||
|
||||
def test_build_raises_when_model_missing(self):
|
||||
builder = AgentAppRuntimeRequestBuilder(
|
||||
credentials_provider=_FakeCredentialsProvider(),
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
"""Focused tests for attachment-aware dataset retrieval entry behavior."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
|
||||
|
||||
|
||||
def test_knowledge_retrieval_allows_attachment_only_requests() -> None:
|
||||
retrieval = DatasetRetrieval()
|
||||
available_dataset = MagicMock(id="dataset-1")
|
||||
|
||||
request = KnowledgeRetrievalRequest(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
user_from="account",
|
||||
dataset_ids=["dataset-1"],
|
||||
query=None,
|
||||
retrieval_mode="multiple",
|
||||
top_k=4,
|
||||
score_threshold=0.0,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_enable=True,
|
||||
attachment_ids=["attachment-1"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(retrieval, "_check_knowledge_rate_limit"),
|
||||
patch.object(retrieval, "_get_available_datasets", return_value=[available_dataset]),
|
||||
patch.object(retrieval, "multiple_retrieve", return_value=[]) as mock_multiple,
|
||||
):
|
||||
result = retrieval.knowledge_retrieval(request)
|
||||
|
||||
assert result == []
|
||||
mock_multiple.assert_called_once()
|
||||
@ -496,6 +496,119 @@ def test_builds_workflow_run_request_with_dify_plugin_tools_layer():
|
||||
assert plugin_tools_builder.last_invoke_from == context.dify_context.invoke_from
|
||||
|
||||
|
||||
def test_build_maps_agent_soul_knowledge_to_knowledge_layer_config():
|
||||
context = _context()
|
||||
snapshot = AgentConfigSnapshot(
|
||||
id="snapshot-1",
|
||||
tenant_id="tenant-1",
|
||||
agent_id="agent-1",
|
||||
version=1,
|
||||
config_snapshot=AgentSoulConfig.model_validate(
|
||||
{
|
||||
"prompt": {"system_prompt": "You are careful."},
|
||||
"model": {
|
||||
"plugin_id": "langgenius/openai",
|
||||
"model_provider": "openai",
|
||||
"model": "gpt-test",
|
||||
},
|
||||
"knowledge": {
|
||||
"datasets": [{"id": "dataset-1"}, {"id": " "}, {"id": "dataset-2"}],
|
||||
"query_config": {
|
||||
"top_k": 6,
|
||||
"score_threshold": 0.4,
|
||||
"score_threshold_enabled": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
context = replace(context, snapshot=snapshot)
|
||||
|
||||
result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()).build(context)
|
||||
|
||||
dumped = result.request.model_dump(mode="json")
|
||||
layers = {layer["name"]: layer for layer in dumped["composition"]["layers"]}
|
||||
knowledge_layer = layers["knowledge"]
|
||||
assert knowledge_layer["type"] == "dify.knowledge_base"
|
||||
assert knowledge_layer["deps"] == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}
|
||||
assert knowledge_layer["config"] == {
|
||||
"dataset_ids": ["dataset-1", "dataset-2"],
|
||||
"retrieval": {
|
||||
"mode": "multiple",
|
||||
"top_k": 6,
|
||||
"score_threshold": 0.4,
|
||||
"reranking_mode": "reranking_model",
|
||||
"reranking_enable": True,
|
||||
"reranking_model": None,
|
||||
"weights": None,
|
||||
"model": None,
|
||||
},
|
||||
"metadata_filtering": {"mode": "disabled", "metadata_model_config": None, "conditions": None},
|
||||
"max_result_content_chars": 2000,
|
||||
"max_observation_chars": 12000,
|
||||
}
|
||||
|
||||
|
||||
def test_build_knowledge_layer_uses_stable_default_top_k_when_query_config_omits_it():
|
||||
context = _context()
|
||||
snapshot = AgentConfigSnapshot(
|
||||
id="snapshot-1",
|
||||
tenant_id="tenant-1",
|
||||
agent_id="agent-1",
|
||||
version=1,
|
||||
config_snapshot=AgentSoulConfig.model_validate(
|
||||
{
|
||||
"prompt": {"system_prompt": "You are careful."},
|
||||
"model": {
|
||||
"plugin_id": "langgenius/openai",
|
||||
"model_provider": "openai",
|
||||
"model": "gpt-test",
|
||||
},
|
||||
"knowledge": {
|
||||
"datasets": [{"id": "dataset-1"}],
|
||||
"query_config": {},
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
context = replace(context, snapshot=snapshot)
|
||||
|
||||
result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()).build(context)
|
||||
|
||||
dumped = result.request.model_dump(mode="json")
|
||||
knowledge_layer = next(layer for layer in dumped["composition"]["layers"] if layer["name"] == "knowledge")
|
||||
assert knowledge_layer["config"]["retrieval"]["top_k"] == 4
|
||||
|
||||
|
||||
def test_build_skips_knowledge_layer_when_agent_soul_has_no_valid_dataset_ids():
|
||||
context = _context()
|
||||
snapshot = AgentConfigSnapshot(
|
||||
id="snapshot-1",
|
||||
tenant_id="tenant-1",
|
||||
agent_id="agent-1",
|
||||
version=1,
|
||||
config_snapshot=AgentSoulConfig.model_validate(
|
||||
{
|
||||
"prompt": {"system_prompt": "You are careful."},
|
||||
"model": {
|
||||
"plugin_id": "langgenius/openai",
|
||||
"model_provider": "openai",
|
||||
"model": "gpt-test",
|
||||
},
|
||||
"knowledge": {
|
||||
"datasets": [{"id": " "}, {}],
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
context = replace(context, snapshot=snapshot)
|
||||
|
||||
result = WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider()).build(context)
|
||||
|
||||
dumped = result.request.model_dump(mode="json")
|
||||
assert all(layer["name"] != "knowledge" for layer in dumped["composition"]["layers"])
|
||||
|
||||
|
||||
def test_build_passes_saved_session_snapshot_to_agent_backend_request():
|
||||
session_snapshot = CompositorSessionSnapshot(layers=[])
|
||||
context = replace(_context(), session_snapshot=session_snapshot)
|
||||
@ -868,3 +981,36 @@ def test_feature_manifest_marks_human_supported_when_configured():
|
||||
assert manifest["reserved_status"]["human"] == "supported_by_ask_human_hitl"
|
||||
# configured human no longer produces a "not executed" warning
|
||||
assert all("human" not in w["section"] for w in manifest["unsupported_runtime_warnings"])
|
||||
|
||||
|
||||
def test_feature_manifest_marks_knowledge_supported_without_warning_when_configured():
|
||||
from core.workflow.nodes.agent_v2.runtime_feature_manifest import build_runtime_feature_manifest
|
||||
|
||||
soul = AgentSoulConfig.model_validate(
|
||||
{
|
||||
"knowledge": {
|
||||
"datasets": [{"id": "dataset-1", "name": "Product Docs"}],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
manifest = build_runtime_feature_manifest(soul)
|
||||
assert "knowledge" in manifest["supported"]
|
||||
assert "knowledge" not in manifest["reserved"]
|
||||
assert manifest["reserved_status"]["knowledge"] == "supported_by_knowledge_layer"
|
||||
assert all("knowledge" not in w["section"] for w in manifest["unsupported_runtime_warnings"])
|
||||
|
||||
|
||||
def test_feature_manifest_treats_blank_knowledge_dataset_ids_as_not_configured():
|
||||
from core.workflow.nodes.agent_v2.runtime_feature_manifest import build_runtime_feature_manifest
|
||||
|
||||
soul = AgentSoulConfig.model_validate(
|
||||
{
|
||||
"knowledge": {
|
||||
"datasets": [{"id": " "}, {}],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
manifest = build_runtime_feature_manifest(soul)
|
||||
assert manifest["reserved_status"]["knowledge"] == "not_configured"
|
||||
|
||||
@ -21,6 +21,7 @@ from services.entities.external_knowledge_entities.external_knowledge_entities i
|
||||
ExternalKnowledgeApiSetting,
|
||||
)
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
|
||||
@ -1558,7 +1559,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
with pytest.raises(ExternalKnowledgeRetrievalError, match="external knowledge binding not found"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {})
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
@ -1569,7 +1570,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
mock_db.session.scalar.side_effect = [binding, None]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="external api template not found"):
|
||||
with pytest.raises(ExternalKnowledgeRetrievalError, match="external api template not found"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {})
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@ -1643,7 +1644,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
mock_process.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Internal Server Error: Database connection failed"):
|
||||
with pytest.raises(ExternalKnowledgeRetrievalError, match="Internal Server Error: Database connection failed"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||
)
|
||||
@ -1684,7 +1685,7 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
mock_process.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match=re.escape(error_message)):
|
||||
with pytest.raises(ExternalKnowledgeRetrievalError, match=re.escape(error_message)):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(tenant_id, dataset_id, "query", {"top_k": 5})
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@ -1703,7 +1704,79 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
mock_process.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ExternalKnowledgeRetrievalError):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||
)
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_invalid_json_response(self, mock_db, mock_process, factory):
|
||||
"""Test malformed JSON success responses are normalized to external retrieval errors."""
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
|
||||
mock_process.return_value = mock_response
|
||||
|
||||
with pytest.raises(ExternalKnowledgeRetrievalError, match="invalid external knowledge response"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||
)
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_invalid_success_payload_shape(self, mock_db, mock_process, factory):
|
||||
"""Test malformed success payload shapes are normalized to external retrieval errors."""
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = ["not-a-dict"]
|
||||
mock_process.return_value = mock_response
|
||||
|
||||
with pytest.raises(ExternalKnowledgeRetrievalError, match="invalid external knowledge response"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||
)
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_invalid_records_shape(self, mock_db, mock_process, factory):
|
||||
"""Test non-list records payloads are normalized to external retrieval errors."""
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"records": {"unexpected": "shape"}}
|
||||
mock_process.return_value = mock_response
|
||||
|
||||
with pytest.raises(ExternalKnowledgeRetrievalError, match="invalid external knowledge response"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||
)
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_wraps_transport_errors(self, mock_db, mock_process, factory):
|
||||
"""Test transport/runtime failures are normalized to external retrieval errors."""
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
api = factory.create_external_knowledge_api_mock()
|
||||
|
||||
mock_db.session.scalar.side_effect = [binding, api]
|
||||
mock_process.side_effect = RuntimeError("connection reset by peer")
|
||||
|
||||
with pytest.raises(ExternalKnowledgeRetrievalError, match="connection reset by peer"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
"tenant-123", "dataset-123", "query", {"top_k": 5}
|
||||
)
|
||||
|
||||
@ -0,0 +1,218 @@
|
||||
"""Unit tests for the inner knowledge retrieval service."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import Source, SourceMetadata
|
||||
from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveRequest
|
||||
from services.errors.knowledge_retrieval import (
|
||||
InnerKnowledgeRetrieveAppNotFoundError,
|
||||
InnerKnowledgeRetrieveAppTenantMismatchError,
|
||||
InnerKnowledgeRetrieveDatasetNotFoundError,
|
||||
InnerKnowledgeRetrieveDatasetTenantMismatchError,
|
||||
)
|
||||
from services.knowledge_retrieval_inner_service import InnerKnowledgeRetrievalService
|
||||
|
||||
|
||||
def _build_request(**overrides):
|
||||
payload = {
|
||||
"caller": {
|
||||
"tenant_id": "tenant-1",
|
||||
"user_id": "user-1",
|
||||
"app_id": "app-1",
|
||||
"user_from": "account",
|
||||
"invoke_from": "workflow",
|
||||
},
|
||||
"dataset_ids": ["dataset-1", "dataset-2"],
|
||||
"query": "how to reset password",
|
||||
"retrieval": {
|
||||
"mode": "multiple",
|
||||
"top_k": 4,
|
||||
"score_threshold": 0.25,
|
||||
"reranking_mode": "reranking_model",
|
||||
"reranking_enable": True,
|
||||
"reranking_model": {
|
||||
"provider": "cohere",
|
||||
"model": "rerank-english-v3.0",
|
||||
},
|
||||
},
|
||||
"metadata_filtering": {
|
||||
"mode": "manual",
|
||||
"conditions": {
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "category",
|
||||
"comparison_operator": "contains",
|
||||
"value": "pricing",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
"attachment_ids": ["attachment-1"],
|
||||
}
|
||||
payload.update(overrides)
|
||||
return InnerKnowledgeRetrieveRequest.model_validate(payload)
|
||||
|
||||
|
||||
def _build_source() -> Source:
|
||||
return Source(
|
||||
metadata=SourceMetadata(
|
||||
dataset_id="dataset-1",
|
||||
dataset_name="Docs",
|
||||
document_id="document-1",
|
||||
document_name="FAQ.md",
|
||||
data_source_type="upload_file",
|
||||
),
|
||||
title="FAQ.md",
|
||||
files=[],
|
||||
content="Reset your password from settings.",
|
||||
summary=None,
|
||||
)
|
||||
|
||||
|
||||
class TestInnerKnowledgeRetrievalService:
|
||||
@patch("services.knowledge_retrieval_inner_service.DatasetRetrieval")
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_db, mock_rag_cls):
|
||||
request = _build_request()
|
||||
mock_app = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
dataset_1 = MagicMock(id="dataset-1", tenant_id="tenant-1", enable_api=False)
|
||||
dataset_2 = MagicMock(id="dataset-2", tenant_id="tenant-1", enable_api=True)
|
||||
mock_db.session.scalar.return_value = mock_app
|
||||
mock_db.session.scalars.return_value.all.return_value = [dataset_1, dataset_2]
|
||||
|
||||
rag = MagicMock()
|
||||
rag.knowledge_retrieval.return_value = [_build_source()]
|
||||
rag.llm_usage = {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"prompt_unit_price": "0",
|
||||
"completion_unit_price": "0",
|
||||
"prompt_price_unit": "0.001",
|
||||
"completion_price_unit": "0.001",
|
||||
"prompt_price": "0",
|
||||
"completion_price": "0",
|
||||
"total_price": "0",
|
||||
"currency": "USD",
|
||||
"latency": 0,
|
||||
}
|
||||
mock_rag_cls.return_value = rag
|
||||
|
||||
response = InnerKnowledgeRetrievalService().retrieve(request)
|
||||
|
||||
rag_request = rag.knowledge_retrieval.call_args.kwargs["request"]
|
||||
assert rag_request.tenant_id == "tenant-1"
|
||||
assert rag_request.app_id == "app-1"
|
||||
assert rag_request.user_id == "user-1"
|
||||
assert rag_request.dataset_ids == ["dataset-1", "dataset-2"]
|
||||
assert rag_request.query == "how to reset password"
|
||||
assert rag_request.retrieval_mode == "multiple"
|
||||
assert rag_request.top_k == 4
|
||||
assert rag_request.score_threshold == 0.25
|
||||
assert rag_request.reranking_model == {
|
||||
"reranking_provider_name": "cohere",
|
||||
"reranking_model_name": "rerank-english-v3.0",
|
||||
}
|
||||
assert rag_request.metadata_filtering_mode == "manual"
|
||||
assert rag_request.metadata_filtering_conditions is not None
|
||||
metadata_conditions = rag_request.metadata_filtering_conditions.model_dump(mode="python")
|
||||
assert metadata_conditions["logical_operator"] == "and"
|
||||
assert metadata_conditions["conditions"] is not None
|
||||
assert metadata_conditions["conditions"][0]["name"] == "category"
|
||||
assert rag_request.attachment_ids == ["attachment-1"]
|
||||
assert response.results[0].title == "FAQ.md"
|
||||
assert response.usage.currency == "USD"
|
||||
|
||||
@patch("services.knowledge_retrieval_inner_service.DatasetRetrieval")
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_maps_single_request(self, mock_db, mock_rag_cls):
|
||||
request = _build_request(
|
||||
dataset_ids=["dataset-1"],
|
||||
retrieval={
|
||||
"mode": "single",
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o-mini",
|
||||
"mode": "chat",
|
||||
"completion_params": {"temperature": 0},
|
||||
},
|
||||
},
|
||||
metadata_filtering={
|
||||
"mode": "automatic",
|
||||
"model_config": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o-mini",
|
||||
"mode": "chat",
|
||||
"completion_params": {"temperature": 0},
|
||||
},
|
||||
},
|
||||
attachment_ids=[],
|
||||
)
|
||||
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")]
|
||||
|
||||
rag = MagicMock()
|
||||
rag.knowledge_retrieval.return_value = []
|
||||
rag.llm_usage = {
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 3,
|
||||
"prompt_unit_price": "0",
|
||||
"completion_unit_price": "0",
|
||||
"prompt_price_unit": "0.001",
|
||||
"completion_price_unit": "0.001",
|
||||
"prompt_price": "0",
|
||||
"completion_price": "0",
|
||||
"total_price": "0",
|
||||
"currency": "USD",
|
||||
"latency": 1,
|
||||
}
|
||||
mock_rag_cls.return_value = rag
|
||||
|
||||
InnerKnowledgeRetrievalService().retrieve(request)
|
||||
|
||||
rag_request = rag.knowledge_retrieval.call_args.kwargs["request"]
|
||||
assert rag_request.retrieval_mode == "single"
|
||||
assert rag_request.model_provider == "openai"
|
||||
assert rag_request.model_name == "gpt-4o-mini"
|
||||
assert rag_request.model_mode == "chat"
|
||||
assert rag_request.completion_params == {"temperature": 0}
|
||||
assert rag_request.metadata_filtering_mode == "automatic"
|
||||
assert rag_request.metadata_model_config is not None
|
||||
assert rag_request.metadata_model_config.provider == "openai"
|
||||
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_raises_when_app_missing(self, mock_db):
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(InnerKnowledgeRetrieveAppNotFoundError):
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request())
|
||||
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_raises_when_app_belongs_to_other_tenant(self, mock_db):
|
||||
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2")
|
||||
|
||||
with pytest.raises(InnerKnowledgeRetrieveAppTenantMismatchError):
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request())
|
||||
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_raises_when_dataset_missing(self, mock_db):
|
||||
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")]
|
||||
|
||||
with pytest.raises(InnerKnowledgeRetrieveDatasetNotFoundError):
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request())
|
||||
|
||||
@patch("services.knowledge_retrieval_inner_service.db")
|
||||
def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self, mock_db):
|
||||
mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1")
|
||||
mock_db.session.scalars.return_value.all.return_value = [
|
||||
MagicMock(id="dataset-1", tenant_id="tenant-1"),
|
||||
MagicMock(id="dataset-2", tenant_id="tenant-2"),
|
||||
]
|
||||
|
||||
with pytest.raises(InnerKnowledgeRetrieveDatasetTenantMismatchError):
|
||||
InnerKnowledgeRetrievalService().retrieve(_build_request())
|
||||
@ -2,7 +2,8 @@
|
||||
|
||||
Implementation layers live in sibling modules and require server-side runtime
|
||||
dependencies. Keep this package root import-safe for client code that only
|
||||
needs to build run requests.
|
||||
needs to build run requests. Knowledge layers read ``user_from`` from the same
|
||||
DTO, but that runtime implementation still lives in sibling modules.
|
||||
"""
|
||||
|
||||
from dify_agent.layers.execution_context.configs import (
|
||||
|
||||
@ -4,9 +4,11 @@ This layer carries both Dify product execution context (tenant, user, workflow,
|
||||
invoke source) and Agent backend runtime mode. The product-facing fields are
|
||||
used by trusted server-side boundaries such as the Agent Stub when they
|
||||
need to reconstruct Dify API file-access scope without granting the sandbox any
|
||||
direct inner-API credentials. Server-only plugin-daemon settings are injected
|
||||
by the runtime provider factory and therefore do not appear in this public
|
||||
schema.
|
||||
direct inner-API credentials. Knowledge-base layers also read ``user_from`` from
|
||||
this shared config so the inner Dify API can distinguish platform-user and
|
||||
end-user searches without making that caller identity model-controlled.
|
||||
Server-only plugin-daemon settings are injected by the runtime provider factory
|
||||
and therefore do not appear in this public schema.
|
||||
"""
|
||||
|
||||
from typing import ClassVar, Final, Literal, TypeAlias
|
||||
@ -42,7 +44,7 @@ class DifyExecutionContextLayerConfig(LayerConfig):
|
||||
|
||||
tenant_id: str
|
||||
user_id: str | None = None
|
||||
user_from: DifyExecutionContextUserFrom
|
||||
user_from: DifyExecutionContextUserFrom | None = None
|
||||
app_id: str | None = None
|
||||
workflow_id: str | None = None
|
||||
workflow_run_id: str | None = None
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
"""Runtime Dify execution-context layer.
|
||||
|
||||
The public config carries Dify-owned execution identifiers plus the tenant/user
|
||||
daemon context needed by plugin-backed business layers. Server-only daemon URL
|
||||
daemon context needed by plugin-backed business layers and the caller identity
|
||||
needed by knowledge-base layers. Server-only daemon URL
|
||||
and API key are injected by the provider factory. The layer is intentionally
|
||||
config/settings-only under Agenton's state-only core: it does not open, cache,
|
||||
close, or snapshot HTTP clients, and its lifecycle hooks remain the inherited
|
||||
@ -29,7 +30,7 @@ from dify_agent.layers.execution_context.configs import (
|
||||
class DifyExecutionContextLayer(PlainLayer[NoLayerDeps, DifyExecutionContextLayerConfig, EmptyRuntimeState]):
|
||||
"""Layer that carries Dify execution context without owning live resources."""
|
||||
|
||||
type_id: ClassVar[str] = DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID
|
||||
type_id: ClassVar[str | None] = DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID
|
||||
|
||||
config: DifyExecutionContextLayerConfig
|
||||
daemon_url: str
|
||||
|
||||
27
dify-agent/src/dify_agent/layers/knowledge/__init__.py
Normal file
27
dify-agent/src/dify_agent/layers/knowledge/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""Client-safe exports for Dify knowledge-base layer DTOs and type ids.
|
||||
|
||||
Implementation layers and HTTP clients live in sibling modules so this package
|
||||
root stays import-safe for callers that only need to construct run requests.
|
||||
"""
|
||||
|
||||
from dify_agent.layers.knowledge.configs import (
|
||||
DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID,
|
||||
DifyKnowledgeBaseLayerConfig,
|
||||
DifyKnowledgeMetadataCondition,
|
||||
DifyKnowledgeMetadataConditions,
|
||||
DifyKnowledgeMetadataFilteringConfig,
|
||||
DifyKnowledgeModelConfig,
|
||||
DifyKnowledgeRerankingModelConfig,
|
||||
DifyKnowledgeRetrievalConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID",
|
||||
"DifyKnowledgeBaseLayerConfig",
|
||||
"DifyKnowledgeMetadataCondition",
|
||||
"DifyKnowledgeMetadataConditions",
|
||||
"DifyKnowledgeMetadataFilteringConfig",
|
||||
"DifyKnowledgeModelConfig",
|
||||
"DifyKnowledgeRerankingModelConfig",
|
||||
"DifyKnowledgeRetrievalConfig",
|
||||
]
|
||||
214
dify-agent/src/dify_agent/layers/knowledge/client.py
Normal file
214
dify-agent/src/dify_agent/layers/knowledge/client.py
Normal file
@ -0,0 +1,214 @@
|
||||
"""Async client for the Dify API inner knowledge retrieval endpoint.
|
||||
|
||||
This wrapper owns only request/response mapping and error normalization for
|
||||
``POST /inner/api/knowledge/retrieve``. The shared ``httpx.AsyncClient`` is
|
||||
supplied by the FastAPI lifespan/runtime and must stay open for the caller.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, ValidationError
|
||||
|
||||
from dify_agent.layers.knowledge.configs import (
|
||||
DifyKnowledgeMetadataFilteringConfig,
|
||||
DifyKnowledgeRetrievalConfig,
|
||||
)
|
||||
|
||||
|
||||
class DifyKnowledgeBaseClientError(RuntimeError):
|
||||
"""Raised when the inner knowledge retrieval HTTP boundary fails."""
|
||||
|
||||
status_code: int | None
|
||||
error_code: str | None
|
||||
retryable: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
status_code: int | None = None,
|
||||
error_code: str | None = None,
|
||||
retryable: bool,
|
||||
) -> None:
|
||||
self.status_code = status_code
|
||||
self.error_code = error_code
|
||||
self.retryable = retryable
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class _DifyKnowledgeCaller(BaseModel):
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
app_id: str
|
||||
user_from: str
|
||||
invoke_from: str
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class _DifyKnowledgeRetrieveRequest(BaseModel):
|
||||
caller: _DifyKnowledgeCaller
|
||||
dataset_ids: list[str]
|
||||
query: str
|
||||
retrieval: dict[str, JsonValue]
|
||||
metadata_filtering: dict[str, JsonValue]
|
||||
attachment_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class DifyKnowledgeResultMetadata(BaseModel):
|
||||
source: str | None = Field(default=None, alias="_source")
|
||||
dataset_id: str | None = None
|
||||
dataset_name: str | None = None
|
||||
document_id: str | None = None
|
||||
document_name: str | None = None
|
||||
score: float | int | str | None = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow", populate_by_name=True)
|
||||
|
||||
|
||||
class DifyKnowledgeResult(BaseModel):
|
||||
metadata: DifyKnowledgeResultMetadata
|
||||
title: str | None = None
|
||||
files: list[JsonValue] | None = None
|
||||
content: str | None = None
|
||||
summary: str | None = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class DifyKnowledgeRetrieveResponse(BaseModel):
|
||||
results: list[DifyKnowledgeResult]
|
||||
usage: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DifyKnowledgeBaseClient:
|
||||
"""Boundary client for the Dify API inner knowledge retrieval endpoint."""
|
||||
|
||||
base_url: str
|
||||
api_key: str = field(repr=False)
|
||||
http_client: httpx.AsyncClient = field(repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.base_url = self.base_url.rstrip("/")
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
app_id: str,
|
||||
user_from: str,
|
||||
invoke_from: str,
|
||||
dataset_ids: list[str],
|
||||
query: str,
|
||||
retrieval: DifyKnowledgeRetrievalConfig,
|
||||
metadata_filtering: DifyKnowledgeMetadataFilteringConfig,
|
||||
) -> DifyKnowledgeRetrieveResponse:
|
||||
"""Call the inner API and return parsed retrieval results.
|
||||
|
||||
Raises:
|
||||
DifyKnowledgeBaseClientError: For HTTP, transport, or response-shape
|
||||
failures. Only ``429``, ``502``, and transport/timeout failures
|
||||
are marked retryable because the model may continue gracefully in
|
||||
those temporary-unavailable cases.
|
||||
"""
|
||||
request_payload = _DifyKnowledgeRetrieveRequest(
|
||||
caller=_DifyKnowledgeCaller(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
dataset_ids=dataset_ids,
|
||||
query=query,
|
||||
retrieval=retrieval.to_request_payload(),
|
||||
metadata_filtering=metadata_filtering.to_request_payload(),
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.http_client.post(
|
||||
f"{self.base_url}/inner/api/knowledge/retrieve",
|
||||
headers={
|
||||
"X-Inner-Api-Key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=request_payload.model_dump(mode="json", by_alias=True),
|
||||
)
|
||||
except (httpx.InvalidURL, httpx.UnsupportedProtocol) as exc:
|
||||
raise DifyKnowledgeBaseClientError(
|
||||
f"Knowledge base search is misconfigured: {exc}",
|
||||
retryable=False,
|
||||
) from exc
|
||||
except httpx.TimeoutException as exc:
|
||||
raise DifyKnowledgeBaseClientError(
|
||||
"Knowledge base search timed out.",
|
||||
retryable=True,
|
||||
) from exc
|
||||
except httpx.RequestError as exc:
|
||||
raise DifyKnowledgeBaseClientError(
|
||||
f"Knowledge base search request failed: {exc}",
|
||||
retryable=True,
|
||||
) from exc
|
||||
|
||||
if response.status_code >= 400:
|
||||
raise _build_http_error(response)
|
||||
|
||||
try:
|
||||
return DifyKnowledgeRetrieveResponse.model_validate_json(response.text)
|
||||
except ValidationError as exc:
|
||||
raise DifyKnowledgeBaseClientError(
|
||||
"Invalid knowledge retrieval response from Dify API.",
|
||||
status_code=response.status_code,
|
||||
error_code="invalid_response",
|
||||
retryable=False,
|
||||
) from exc
|
||||
|
||||
|
||||
def _build_http_error(response: httpx.Response) -> DifyKnowledgeBaseClientError:
|
||||
detail = _decode_error_detail(response)
|
||||
retryable = response.status_code in {429, 502}
|
||||
message = detail["message"] or f"HTTP {response.status_code}"
|
||||
return DifyKnowledgeBaseClientError(
|
||||
message,
|
||||
status_code=response.status_code,
|
||||
error_code=detail["error_code"],
|
||||
retryable=retryable,
|
||||
)
|
||||
|
||||
|
||||
def _decode_error_detail(response: httpx.Response) -> dict[str, str | None]:
|
||||
raw_body = response.text
|
||||
try:
|
||||
payload = response.json()
|
||||
except json.JSONDecodeError:
|
||||
payload = None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
error_code = payload.get("code")
|
||||
message = payload.get("message")
|
||||
return {
|
||||
"error_code": error_code if isinstance(error_code, str) else None,
|
||||
"message": message if isinstance(message, str) and message else raw_body or f"HTTP {response.status_code}",
|
||||
}
|
||||
|
||||
return {"error_code": None, "message": raw_body or f"HTTP {response.status_code}"}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DifyKnowledgeBaseClient",
|
||||
"DifyKnowledgeBaseClientError",
|
||||
"DifyKnowledgeResult",
|
||||
"DifyKnowledgeResultMetadata",
|
||||
"DifyKnowledgeRetrieveResponse",
|
||||
]
|
||||
200
dify-agent/src/dify_agent/layers/knowledge/configs.py
Normal file
200
dify-agent/src/dify_agent/layers/knowledge/configs.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""Client-safe DTOs for the Dify knowledge-base Agenton layer.
|
||||
|
||||
The public layer config exposes only static retrieval controls: dataset ids,
|
||||
retrieval strategy, metadata filtering, and observation-size limits. The agent
|
||||
model itself should only ever see a single ``query`` tool argument; tenant/
|
||||
app/user context comes from the execution-context layer and the actual
|
||||
retrieval is delegated to the Dify API inner endpoint. Tool naming is not
|
||||
caller-configurable: the runtime always exposes the same stable knowledge-base
|
||||
search tool.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, Final, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator, model_validator
|
||||
|
||||
from agenton.layers import LayerConfig
|
||||
|
||||
type DifyKnowledgeMetadataComparisonOperator = Literal[
|
||||
"contains",
|
||||
"not contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
"=",
|
||||
"≠",
|
||||
">",
|
||||
"<",
|
||||
"≥",
|
||||
"≤",
|
||||
"before",
|
||||
"after",
|
||||
]
|
||||
|
||||
DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID: Final[str] = "dify.knowledge_base"
|
||||
|
||||
|
||||
class DifyKnowledgeModelConfig(BaseModel):
|
||||
"""Static model configuration forwarded to the inner retrieval API."""
|
||||
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
completion_params: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class DifyKnowledgeRerankingModelConfig(BaseModel):
|
||||
"""Reranking model settings for multiple-mode retrieval."""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class DifyKnowledgeRetrievalConfig(BaseModel):
|
||||
"""Static retrieval controls mirrored into the inner API request."""
|
||||
|
||||
mode: Literal["multiple", "single"]
|
||||
top_k: int | None = Field(default=None, ge=1)
|
||||
score_threshold: float = 0.0
|
||||
reranking_mode: str = "reranking_model"
|
||||
reranking_enable: bool = True
|
||||
reranking_model: DifyKnowledgeRerankingModelConfig | None = None
|
||||
weights: dict[str, JsonValue] | None = None
|
||||
model: DifyKnowledgeModelConfig | None = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode_specific_fields(self) -> DifyKnowledgeRetrievalConfig:
|
||||
if self.mode == "multiple" and self.top_k is None:
|
||||
raise ValueError("retrieval.top_k is required for multiple mode")
|
||||
if self.mode == "single" and self.model is None:
|
||||
raise ValueError("retrieval.model is required for single mode")
|
||||
return self
|
||||
|
||||
def to_request_payload(self) -> dict[str, JsonValue]:
|
||||
"""Serialize the retrieval config into the inner API request shape."""
|
||||
payload: dict[str, JsonValue] = {
|
||||
"mode": self.mode,
|
||||
"score_threshold": self.score_threshold,
|
||||
"reranking_mode": self.reranking_mode,
|
||||
"reranking_enable": self.reranking_enable,
|
||||
}
|
||||
if self.mode == "multiple":
|
||||
payload["top_k"] = self.top_k
|
||||
payload["reranking_model"] = (
|
||||
self.reranking_model.model_dump(mode="json") if self.reranking_model is not None else None
|
||||
)
|
||||
payload["weights"] = self.weights
|
||||
else:
|
||||
payload["model"] = self.model.model_dump(mode="json") if self.model is not None else None
|
||||
return payload
|
||||
|
||||
|
||||
class DifyKnowledgeMetadataCondition(BaseModel):
|
||||
"""One manual metadata filter clause."""
|
||||
|
||||
name: str
|
||||
comparison_operator: DifyKnowledgeMetadataComparisonOperator
|
||||
value: str | list[str] | int | float | None = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class DifyKnowledgeMetadataConditions(BaseModel):
|
||||
"""Boolean composition for manual metadata filtering."""
|
||||
|
||||
logical_operator: Literal["and", "or"] = "and"
|
||||
conditions: list[DifyKnowledgeMetadataCondition]
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class DifyKnowledgeMetadataFilteringConfig(BaseModel):
|
||||
"""Static metadata filtering controls for the inner API request."""
|
||||
|
||||
mode: Literal["disabled", "automatic", "manual"] = "disabled"
|
||||
metadata_model_config: DifyKnowledgeModelConfig | None = Field(default=None, alias="model_config")
|
||||
conditions: DifyKnowledgeMetadataConditions | None = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mode_specific_fields(self) -> DifyKnowledgeMetadataFilteringConfig:
|
||||
if self.mode == "automatic" and self.metadata_model_config is None:
|
||||
raise ValueError("metadata_filtering.model_config is required for automatic mode")
|
||||
if self.mode == "manual" and (self.conditions is None or not self.conditions.conditions):
|
||||
raise ValueError("metadata_filtering.conditions is required for manual mode")
|
||||
return self
|
||||
|
||||
def to_request_payload(self) -> dict[str, JsonValue]:
|
||||
"""Serialize metadata filtering using the inner API request field names."""
|
||||
if self.mode == "disabled":
|
||||
return {"mode": self.mode}
|
||||
|
||||
payload: dict[str, JsonValue] = {"mode": self.mode}
|
||||
if self.metadata_model_config is not None:
|
||||
payload["model_config"] = self.metadata_model_config.model_dump(mode="json")
|
||||
if self.conditions is not None:
|
||||
payload["conditions"] = self.conditions.model_dump(mode="json")
|
||||
return payload
|
||||
|
||||
|
||||
class DifyKnowledgeBaseLayerConfig(LayerConfig):
|
||||
"""Public config for one model-visible knowledge search tool.
|
||||
|
||||
The model only gets to choose whether to call the tool and what ``query``
|
||||
to send. Dataset ids, retrieval settings, metadata filtering, and caller
|
||||
context remain config/runtime concerns outside the model-visible tool
|
||||
schema. The tool name and description are fixed by the layer runtime and do
|
||||
not appear in the public config DTO.
|
||||
"""
|
||||
|
||||
dataset_ids: list[str]
|
||||
retrieval: DifyKnowledgeRetrievalConfig
|
||||
metadata_filtering: DifyKnowledgeMetadataFilteringConfig = Field(
|
||||
default_factory=DifyKnowledgeMetadataFilteringConfig
|
||||
)
|
||||
max_result_content_chars: int = Field(default=2000, ge=1)
|
||||
max_observation_chars: int = Field(default=12000, ge=1)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("dataset_ids")
|
||||
@classmethod
|
||||
def validate_dataset_ids(cls, value: list[str]) -> list[str]:
|
||||
if not value:
|
||||
raise ValueError("dataset_ids must contain at least one item")
|
||||
normalized_ids = [item.strip() for item in value]
|
||||
if any(not item for item in normalized_ids):
|
||||
raise ValueError("dataset_ids must not contain blank items")
|
||||
return normalized_ids
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_observation_limits(self) -> DifyKnowledgeBaseLayerConfig:
|
||||
if self.max_observation_chars < self.max_result_content_chars:
|
||||
raise ValueError("max_observation_chars must be greater than or equal to max_result_content_chars")
|
||||
return self
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID",
|
||||
"DifyKnowledgeBaseLayerConfig",
|
||||
"DifyKnowledgeMetadataCondition",
|
||||
"DifyKnowledgeMetadataConditions",
|
||||
"DifyKnowledgeMetadataFilteringConfig",
|
||||
"DifyKnowledgeModelConfig",
|
||||
"DifyKnowledgeRerankingModelConfig",
|
||||
"DifyKnowledgeRetrievalConfig",
|
||||
]
|
||||
285
dify-agent/src/dify_agent/layers/knowledge/layer.py
Normal file
285
dify-agent/src/dify_agent/layers/knowledge/layer.py
Normal file
@ -0,0 +1,285 @@
|
||||
"""Dify knowledge-base layer exposing one model-visible search tool.
|
||||
|
||||
The layer depends on ``DifyExecutionContextLayer`` for tenant/app/user/invoke
|
||||
identity, keeps retrieval controls in config only, and borrows a lifespan-owned
|
||||
HTTP client for each tool invocation. It never owns live clients or stores
|
||||
retrieved source content in layer state. Tool identity is intentionally fixed at
|
||||
runtime: callers cannot rename the knowledge tool or override its description
|
||||
through public layer config because the model-visible surface must stay stable
|
||||
across API-side Agent Soul mappings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import ClassVar, cast
|
||||
|
||||
import httpx
|
||||
from pydantic_ai import RunContext, Tool
|
||||
from pydantic_ai.tools import ToolDefinition
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from agenton.layers import LayerDeps, PlainLayer
|
||||
from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer
|
||||
from dify_agent.layers.knowledge.client import (
|
||||
DifyKnowledgeBaseClient,
|
||||
DifyKnowledgeBaseClientError,
|
||||
DifyKnowledgeRetrieveResponse,
|
||||
)
|
||||
from dify_agent.layers.knowledge.configs import DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, DifyKnowledgeBaseLayerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Fixed model-visible tool identity. These stay module-private on purpose so the
|
||||
# public DTO cannot grow a parallel naming contract that diverges from the
|
||||
# runtime knowledge-search surface.
|
||||
_KNOWLEDGE_BASE_TOOL_NAME = "knowledge_base_search"
|
||||
_KNOWLEDGE_BASE_TOOL_DESCRIPTION = "Search configured knowledge bases for information relevant to the query."
|
||||
BLANK_QUERY_OBSERVATION = "knowledge base search requires a non-empty query"
|
||||
NO_RESULTS_OBSERVATION = "No relevant knowledge base results were found."
|
||||
TEMPORARY_UNAVAILABLE_OBSERVATION = (
|
||||
"Knowledge base search is temporarily unavailable. Please continue without it if possible."
|
||||
)
|
||||
QUERY_TOOL_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query for the configured knowledge bases.",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
|
||||
class DifyKnowledgeBaseDeps(LayerDeps):
|
||||
"""Dependencies required by ``DifyKnowledgeBaseLayer``."""
|
||||
|
||||
execution_context: DifyExecutionContextLayer # pyright: ignore[reportUninitializedInstanceVariable]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DifyKnowledgeBaseLayer(PlainLayer[DifyKnowledgeBaseDeps, DifyKnowledgeBaseLayerConfig]):
|
||||
"""Layer that resolves one config-scoped knowledge search tool."""
|
||||
|
||||
type_id: ClassVar[str | None] = DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID
|
||||
|
||||
config: DifyKnowledgeBaseLayerConfig
|
||||
dify_api_inner_url: str
|
||||
dify_api_inner_api_key: str
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_config(cls, config: DifyKnowledgeBaseLayerConfig) -> Self:
|
||||
"""Reject construction without server-injected Dify API settings."""
|
||||
del config
|
||||
raise TypeError(
|
||||
"DifyKnowledgeBaseLayer requires server-side Dify API settings and must use a provider factory."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config_with_settings(
|
||||
cls,
|
||||
config: DifyKnowledgeBaseLayerConfig,
|
||||
*,
|
||||
dify_api_inner_url: str,
|
||||
dify_api_inner_api_key: str,
|
||||
) -> Self:
|
||||
"""Create the layer from public config plus server-only API settings."""
|
||||
return cls(
|
||||
config=DifyKnowledgeBaseLayerConfig.model_validate(config),
|
||||
dify_api_inner_url=dify_api_inner_url,
|
||||
dify_api_inner_api_key=dify_api_inner_api_key,
|
||||
)
|
||||
|
||||
async def get_tools(self, *, http_client: httpx.AsyncClient) -> list[Tool[object]]:
|
||||
"""Build one Pydantic AI tool that exposes only ``query`` to the model.
|
||||
|
||||
Knowledge tools depend on execution-context identity that is optional for
|
||||
other run types but mandatory here: ``tenant_id``, ``user_id``,
|
||||
``user_from``, ``app_id``, and ``invoke_from`` must all be present before
|
||||
any HTTP request is attempted. Tool execution then follows a strict
|
||||
observation policy:
|
||||
|
||||
- blank ``query`` returns a local validation observation;
|
||||
- retryable client failures (timeouts, connection failures, HTTP
|
||||
``429``/``502``) become a temporary-unavailable observation;
|
||||
- non-retryable client failures are raised so the run fails fast.
|
||||
"""
|
||||
if http_client.is_closed:
|
||||
raise RuntimeError("DifyKnowledgeBaseLayer.get_tools() requires an open shared HTTP client.")
|
||||
|
||||
execution_context = self.deps.execution_context.config
|
||||
caller = _build_caller_context(execution_context)
|
||||
client = DifyKnowledgeBaseClient(
|
||||
base_url=self.dify_api_inner_url,
|
||||
api_key=self.dify_api_inner_api_key,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
async def knowledge_base_search(_ctx: RunContext[object], query: str) -> str:
|
||||
normalized_query = query.strip()
|
||||
if not normalized_query:
|
||||
return BLANK_QUERY_OBSERVATION
|
||||
try:
|
||||
response = await client.retrieve(
|
||||
tenant_id=caller["tenant_id"],
|
||||
user_id=caller["user_id"],
|
||||
app_id=caller["app_id"],
|
||||
user_from=caller["user_from"],
|
||||
invoke_from=caller["invoke_from"],
|
||||
dataset_ids=list(self.config.dataset_ids),
|
||||
query=normalized_query,
|
||||
retrieval=self.config.retrieval,
|
||||
metadata_filtering=self.config.metadata_filtering,
|
||||
)
|
||||
except DifyKnowledgeBaseClientError as exc:
|
||||
if exc.retryable:
|
||||
logger.warning(
|
||||
"knowledge base search temporarily unavailable",
|
||||
extra={
|
||||
"tenant_id": caller["tenant_id"],
|
||||
"app_id": caller["app_id"],
|
||||
"invoke_from": caller["invoke_from"],
|
||||
"error_code": exc.error_code,
|
||||
"status_code": exc.status_code,
|
||||
},
|
||||
)
|
||||
return TEMPORARY_UNAVAILABLE_OBSERVATION
|
||||
logger.error(
|
||||
"knowledge base search failed",
|
||||
extra={
|
||||
"tenant_id": caller["tenant_id"],
|
||||
"app_id": caller["app_id"],
|
||||
"invoke_from": caller["invoke_from"],
|
||||
"error_code": exc.error_code,
|
||||
"status_code": exc.status_code,
|
||||
},
|
||||
)
|
||||
raise
|
||||
return _format_observation(response, self.config)
|
||||
|
||||
async def prepare_tool_definition(_ctx: RunContext[object], tool_def: ToolDefinition) -> ToolDefinition:
|
||||
return ToolDefinition(
|
||||
name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters_json_schema=QUERY_TOOL_SCHEMA,
|
||||
strict=tool_def.strict,
|
||||
sequential=tool_def.sequential,
|
||||
metadata=tool_def.metadata,
|
||||
timeout=tool_def.timeout,
|
||||
defer_loading=tool_def.defer_loading,
|
||||
kind=tool_def.kind,
|
||||
return_schema=tool_def.return_schema,
|
||||
include_return_schema=tool_def.include_return_schema,
|
||||
)
|
||||
|
||||
return [
|
||||
Tool(
|
||||
knowledge_base_search,
|
||||
takes_ctx=True,
|
||||
name=_KNOWLEDGE_BASE_TOOL_NAME,
|
||||
description=_KNOWLEDGE_BASE_TOOL_DESCRIPTION,
|
||||
prepare=prepare_tool_definition,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _build_caller_context(execution_context: object) -> dict[str, str]:
|
||||
"""Extract the inner-API caller identity from execution-context config.
|
||||
|
||||
The public execution-context DTO keeps several fields optional for general
|
||||
runs, but knowledge retrieval requires all of ``tenant_id``, ``user_id``,
|
||||
``user_from``, ``app_id``, and ``invoke_from``. Missing or blank values are
|
||||
rejected here so misconfigured runs fail before transport rather than being
|
||||
softened into tool observations.
|
||||
"""
|
||||
tenant_id = getattr(execution_context, "tenant_id", None)
|
||||
user_id = getattr(execution_context, "user_id", None)
|
||||
user_from = getattr(execution_context, "user_from", None)
|
||||
app_id = getattr(execution_context, "app_id", None)
|
||||
invoke_from = getattr(execution_context, "invoke_from", None)
|
||||
|
||||
missing_fields = [
|
||||
field_name
|
||||
for field_name, value in (
|
||||
("tenant_id", tenant_id),
|
||||
("user_id", user_id),
|
||||
("user_from", user_from),
|
||||
("app_id", app_id),
|
||||
("invoke_from", invoke_from),
|
||||
)
|
||||
if not isinstance(value, str) or not value.strip()
|
||||
]
|
||||
if missing_fields:
|
||||
joined_fields = ", ".join(missing_fields)
|
||||
raise ValueError(f"Dify knowledge base layer requires execution context fields: {joined_fields}")
|
||||
|
||||
normalized_tenant_id = cast(str, tenant_id).strip()
|
||||
normalized_user_id = cast(str, user_id).strip()
|
||||
normalized_user_from = cast(str, user_from).strip()
|
||||
normalized_app_id = cast(str, app_id).strip()
|
||||
normalized_invoke_from = cast(str, invoke_from).strip()
|
||||
|
||||
return {
|
||||
"tenant_id": normalized_tenant_id,
|
||||
"user_id": normalized_user_id,
|
||||
"user_from": normalized_user_from,
|
||||
"app_id": normalized_app_id,
|
||||
"invoke_from": normalized_invoke_from,
|
||||
}
|
||||
|
||||
|
||||
def _format_observation(response: DifyKnowledgeRetrieveResponse, config: DifyKnowledgeBaseLayerConfig) -> str:
|
||||
"""Render inner-API retrieval results into the model-visible tool response.
|
||||
|
||||
The formatting contract is intentionally simple and stable for the model:
|
||||
|
||||
- empty ``results`` returns ``NO_RESULTS_OBSERVATION``;
|
||||
- non-empty results become a numbered list headed by
|
||||
``"Knowledge base search results:"``;
|
||||
- each item includes title plus dataset/document/score metadata when those
|
||||
fields are present;
|
||||
- each content snippet is truncated by ``max_result_content_chars``;
|
||||
- the final observation is truncated by ``max_observation_chars``.
|
||||
"""
|
||||
if not response.results:
|
||||
return NO_RESULTS_OBSERVATION
|
||||
|
||||
lines = ["Knowledge base search results:"]
|
||||
for index, result in enumerate(response.results, start=1):
|
||||
metadata = result.metadata
|
||||
title = result.title or metadata.document_name or "Untitled"
|
||||
lines.append(f"{index}. Title: {title}")
|
||||
if metadata.dataset_name:
|
||||
lines.append(f" Dataset: {metadata.dataset_name}")
|
||||
if metadata.document_name:
|
||||
lines.append(f" Document: {metadata.document_name}")
|
||||
if metadata.score is not None:
|
||||
lines.append(f" Score: {metadata.score}")
|
||||
content = _truncate_text(result.content or result.summary or "", config.max_result_content_chars)
|
||||
if content:
|
||||
lines.append(f" Content: {content}")
|
||||
lines.append("")
|
||||
|
||||
return _truncate_text("\n".join(lines).rstrip(), config.max_observation_chars)
|
||||
|
||||
|
||||
def _truncate_text(text: str, max_chars: int) -> str:
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
if max_chars <= 3:
|
||||
return text[:max_chars]
|
||||
return f"{text[: max_chars - 3]}..."
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BLANK_QUERY_OBSERVATION",
|
||||
"DifyKnowledgeBaseDeps",
|
||||
"DifyKnowledgeBaseLayer",
|
||||
"NO_RESULTS_OBSERVATION",
|
||||
"QUERY_TOOL_SCHEMA",
|
||||
"TEMPORARY_UNAVAILABLE_OBSERVATION",
|
||||
]
|
||||
@ -4,23 +4,23 @@ Only explicitly allowed provider type ids are constructible here. The default
|
||||
provider set contains prompt layers, the optional pydantic-ai history layer, the
|
||||
state-free Dify structured output layer, the optional Dify ask-human layer, the
|
||||
Dify execution-context layer, the stateful Dify shell layer, and the Dify
|
||||
plugin business-layer family:
|
||||
plugin/knowledge business-layer family:
|
||||
|
||||
- ``dify.drive`` for the inert Skills & Files drive declaration,
|
||||
- ``dify.execution_context`` for shared tenant/user/run daemon context,
|
||||
- ``dify.shell`` for shellctl-backed shell job control,
|
||||
- ``dify.plugin.llm`` for plugin-backed model selection, and
|
||||
- ``dify.plugin.tools`` for prepared plugin tool exposure.
|
||||
- ``dify.plugin.llm`` for plugin-backed model selection,
|
||||
- ``dify.plugin.tools`` for prepared plugin tool exposure, and
|
||||
- ``dify.knowledge_base`` for inner-API-backed knowledge search tools.
|
||||
|
||||
Public DTOs provide Dify context plus plugin/model/tool data, while server-only
|
||||
plugin daemon settings are injected through the provider factory for
|
||||
``DifyExecutionContextLayer`` and the optional shellctl entrypoint/auth token plus
|
||||
client factory plus optional Agent Stub URL/token issuer are injected for
|
||||
``DifyShellLayer``. The resulting ``Compositor``
|
||||
remains Agenton state-only at the snapshot boundary: live resources such as
|
||||
HTTP clients are injected by runtime-owned providers, may be held on active
|
||||
layer instances inside ``resource_context()``, and never enter session
|
||||
snapshots.
|
||||
plugin daemon settings and Dify API inner settings are injected through provider
|
||||
factories. Optional shellctl entrypoint/auth token, client factory, and Agent
|
||||
Stub URL/token issuer are injected for ``DifyShellLayer``. The resulting
|
||||
``Compositor`` remains Agenton state-only at the snapshot boundary: live
|
||||
resources such as HTTP clients are injected by runtime-owned providers, may be
|
||||
held on active layer instances inside ``resource_context()``, and never enter
|
||||
session snapshots.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
@ -41,6 +41,8 @@ from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer
|
||||
from dify_agent.layers.drive.layer import DifyDriveLayer
|
||||
from dify_agent.layers.execution_context.configs import DifyExecutionContextLayerConfig
|
||||
from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer
|
||||
from dify_agent.layers.knowledge.configs import DifyKnowledgeBaseLayerConfig
|
||||
from dify_agent.layers.knowledge.layer import DifyKnowledgeBaseLayer
|
||||
from dify_agent.layers.output.output_layer import DifyOutputLayer
|
||||
from dify_agent.layers.shell.configs import DifyShellLayerConfig
|
||||
from dify_agent.layers.shell.layer import DifyShellLayer, create_shellctl_client_factory
|
||||
@ -52,6 +54,8 @@ def create_default_layer_providers(
|
||||
*,
|
||||
plugin_daemon_url: str = "http://localhost:5002",
|
||||
plugin_daemon_api_key: str = "",
|
||||
dify_api_inner_url: str = "http://localhost:5001",
|
||||
dify_api_inner_api_key: str = "",
|
||||
shellctl_entrypoint: str | None = None,
|
||||
shellctl_auth_token: str | None = None,
|
||||
agent_stub_url: str | None = None,
|
||||
@ -109,6 +113,14 @@ def create_default_layer_providers(
|
||||
),
|
||||
LayerProvider.from_layer_type(DifyPluginLLMLayer),
|
||||
LayerProvider.from_layer_type(DifyPluginToolsLayer),
|
||||
LayerProvider.from_factory(
|
||||
layer_type=DifyKnowledgeBaseLayer,
|
||||
create=lambda config: DifyKnowledgeBaseLayer.from_config_with_settings(
|
||||
DifyKnowledgeBaseLayerConfig.model_validate(config),
|
||||
dify_api_inner_url=dify_api_inner_url,
|
||||
dify_api_inner_api_key=dify_api_inner_api_key,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -69,6 +69,7 @@ class RunScheduler:
|
||||
runner_factory: RunRunnerFactory
|
||||
layer_providers: tuple[LayerProviderInput, ...]
|
||||
plugin_daemon_http_client: httpx.AsyncClient
|
||||
dify_api_http_client: httpx.AsyncClient
|
||||
_lifecycle_lock: asyncio.Lock
|
||||
|
||||
def __init__(
|
||||
@ -76,6 +77,7 @@ class RunScheduler:
|
||||
*,
|
||||
store: RunStore,
|
||||
plugin_daemon_http_client: httpx.AsyncClient,
|
||||
dify_api_http_client: httpx.AsyncClient,
|
||||
shutdown_grace_seconds: float = 30,
|
||||
layer_providers: tuple[LayerProviderInput, ...] | None = None,
|
||||
runner_factory: RunRunnerFactory | None = None,
|
||||
@ -85,6 +87,7 @@ class RunScheduler:
|
||||
self.active_tasks = {}
|
||||
self.stopping = False
|
||||
self.plugin_daemon_http_client = plugin_daemon_http_client
|
||||
self.dify_api_http_client = dify_api_http_client
|
||||
self.layer_providers = layer_providers if layer_providers is not None else create_default_layer_providers()
|
||||
self.runner_factory = runner_factory or self._default_runner_factory
|
||||
self._lifecycle_lock = asyncio.Lock()
|
||||
@ -141,6 +144,7 @@ class RunScheduler:
|
||||
request=request,
|
||||
run_id=record.run_id,
|
||||
plugin_daemon_http_client=self.plugin_daemon_http_client,
|
||||
dify_api_http_client=self.dify_api_http_client,
|
||||
layer_providers=self.layer_providers,
|
||||
)
|
||||
|
||||
|
||||
@ -19,7 +19,9 @@ publishes that deferred request through the normal ``run_succeeded`` event as
|
||||
``deferred_tool_call`` instead of a final ``output``. Invalid structured outputs
|
||||
or invalid deferred-tool behavior still trigger normal retries/failures before
|
||||
Dify Agent emits success. Layers still never own the FastAPI lifespan-owned
|
||||
plugin daemon HTTP client.
|
||||
plugin daemon or Dify API inner HTTP clients. Successful terminal events contain
|
||||
both the JSON-safe final output or deferred tool call and the session snapshot;
|
||||
there are no separate output or snapshot events to correlate.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterable
|
||||
@ -38,6 +40,7 @@ from agenton.layers.types import PydanticAITool
|
||||
from dify_agent.layers.ask_human.layer import get_ask_human_layer, validate_ask_human_layer_composition
|
||||
from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer
|
||||
from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer
|
||||
from dify_agent.layers.knowledge.layer import DifyKnowledgeBaseLayer
|
||||
from dify_agent.protocol.schemas import (
|
||||
CreateRunRequest,
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
@ -91,6 +94,7 @@ class AgentRunRunner:
|
||||
run_id: str
|
||||
layer_providers: tuple[LayerProviderInput, ...]
|
||||
plugin_daemon_http_client: httpx.AsyncClient
|
||||
dify_api_http_client: httpx.AsyncClient
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -99,12 +103,14 @@ class AgentRunRunner:
|
||||
request: CreateRunRequest,
|
||||
run_id: str,
|
||||
plugin_daemon_http_client: httpx.AsyncClient,
|
||||
dify_api_http_client: httpx.AsyncClient,
|
||||
layer_providers: tuple[LayerProviderInput, ...] | None = None,
|
||||
) -> None:
|
||||
self.sink = sink
|
||||
self.request = request
|
||||
self.run_id = run_id
|
||||
self.plugin_daemon_http_client = plugin_daemon_http_client
|
||||
self.dify_api_http_client = dify_api_http_client
|
||||
self.layer_providers = layer_providers if layer_providers is not None else create_default_layer_providers()
|
||||
|
||||
async def run(self) -> None:
|
||||
@ -187,7 +193,11 @@ class AgentRunRunner:
|
||||
ask_human_layer = get_ask_human_layer(run)
|
||||
llm_layer = run.get_layer(DIFY_AGENT_MODEL_LAYER_ID, DifyPluginLLMLayer)
|
||||
model = llm_layer.get_model(http_client=self.plugin_daemon_http_client)
|
||||
tools = await _resolve_run_tools(run, http_client=self.plugin_daemon_http_client)
|
||||
tools = await _resolve_run_tools(
|
||||
run,
|
||||
plugin_daemon_http_client=self.plugin_daemon_http_client,
|
||||
dify_api_http_client=self.dify_api_http_client,
|
||||
)
|
||||
except (KeyError, TypeError, RuntimeError, ValueError) as exc:
|
||||
raise AgentRunValidationError(str(exc)) from exc
|
||||
|
||||
@ -266,14 +276,17 @@ def _resolve_deferred_tool_results(request: CreateRunRequest) -> DeferredToolRes
|
||||
async def _resolve_run_tools(
|
||||
run: Any,
|
||||
*,
|
||||
http_client: httpx.AsyncClient,
|
||||
plugin_daemon_http_client: httpx.AsyncClient,
|
||||
dify_api_http_client: httpx.AsyncClient,
|
||||
) -> list[PydanticAITool[object]]:
|
||||
"""Return the static compositor tools plus any Dify plugin runtime tools."""
|
||||
"""Return the static compositor tools plus any Dify runtime tools."""
|
||||
resolved_tools = list(cast(list[PydanticAITool[object]], run.tools))
|
||||
for slot in run.slots.values():
|
||||
layer = slot.layer
|
||||
if isinstance(layer, DifyPluginToolsLayer):
|
||||
resolved_tools.extend(await layer.get_tools(http_client=http_client))
|
||||
resolved_tools.extend(await layer.get_tools(http_client=plugin_daemon_http_client))
|
||||
if isinstance(layer, DifyKnowledgeBaseLayer):
|
||||
resolved_tools.extend(await layer.get_tools(http_client=dify_api_http_client))
|
||||
_validate_unique_tool_names(resolved_tools)
|
||||
return resolved_tools
|
||||
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
"""FastAPI application factory for the Dify Agent run server.
|
||||
|
||||
The HTTP process owns Redis clients, one shared plugin daemon ``httpx.AsyncClient``,
|
||||
route wiring, and a process-local scheduler. Run execution happens in background
|
||||
``asyncio`` tasks rather than request handlers, so client disconnects do not
|
||||
cancel the agent runtime. Redis persists run records and per-run event streams
|
||||
with configured retention only; it is not used as a job queue. Agenton layers and
|
||||
providers stay state-only: they borrow the lifespan-owned plugin daemon client
|
||||
through the runner and receive shell-layer server settings through provider
|
||||
construction rather than reading environment variables themselves. The standard
|
||||
server always mounts the HTTP Agent Stub router and additionally starts the
|
||||
optional grpclib Agent Stub server when ``DIFY_AGENT_STUB_URL`` uses ``grpc://``.
|
||||
The HTTP process owns Redis clients plus separate shared ``httpx.AsyncClient``
|
||||
instances for plugin-daemon and Dify API inner calls, route wiring, and a
|
||||
process-local scheduler. Run execution happens in background ``asyncio`` tasks
|
||||
rather than request handlers, so client disconnects do not cancel the agent
|
||||
runtime. Redis persists run records and per-run event streams with configured
|
||||
retention only; it is not used as a job queue. Agenton layers and providers
|
||||
stay state-only: they borrow the lifespan-owned clients through the runner and
|
||||
receive shell-layer server settings through provider construction rather than
|
||||
reading environment variables themselves. The standard server always mounts the
|
||||
HTTP Agent Stub router and additionally starts the optional grpclib Agent Stub
|
||||
server when ``DIFY_AGENT_STUB_URL`` uses ``grpc://``.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
@ -39,6 +40,8 @@ def create_app(settings: ServerSettings | None = None) -> FastAPI:
|
||||
layer_providers = create_default_layer_providers(
|
||||
plugin_daemon_url=resolved_settings.plugin_daemon_url,
|
||||
plugin_daemon_api_key=resolved_settings.plugin_daemon_api_key,
|
||||
dify_api_inner_url=resolved_settings.dify_api_inner_url,
|
||||
dify_api_inner_api_key=resolved_settings.dify_api_inner_api_key or "",
|
||||
shellctl_entrypoint=resolved_settings.shellctl_entrypoint,
|
||||
shellctl_auth_token=resolved_settings.shellctl_auth_token,
|
||||
agent_stub_url=resolved_settings.agent_stub_url,
|
||||
@ -53,6 +56,7 @@ def create_app(settings: ServerSettings | None = None) -> FastAPI:
|
||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
redis = Redis.from_url(resolved_settings.redis_url)
|
||||
plugin_daemon_http_client = create_plugin_daemon_http_client(resolved_settings)
|
||||
dify_api_inner_http_client = create_dify_api_inner_http_client(resolved_settings)
|
||||
store = RedisRunStore(
|
||||
redis,
|
||||
prefix=resolved_settings.redis_prefix,
|
||||
@ -61,6 +65,7 @@ def create_app(settings: ServerSettings | None = None) -> FastAPI:
|
||||
scheduler = RunScheduler(
|
||||
store=store,
|
||||
plugin_daemon_http_client=plugin_daemon_http_client,
|
||||
dify_api_http_client=dify_api_inner_http_client,
|
||||
shutdown_grace_seconds=resolved_settings.shutdown_grace_seconds,
|
||||
layer_providers=layer_providers,
|
||||
)
|
||||
@ -83,6 +88,7 @@ def create_app(settings: ServerSettings | None = None) -> FastAPI:
|
||||
if grpc_server is not None:
|
||||
await grpc_server.aclose()
|
||||
await scheduler.shutdown()
|
||||
await dify_api_inner_http_client.aclose()
|
||||
await plugin_daemon_http_client.aclose()
|
||||
await redis.aclose()
|
||||
|
||||
@ -112,17 +118,33 @@ def create_plugin_daemon_http_client(settings: ServerSettings) -> httpx.AsyncCli
|
||||
process and must be closed by the app lifespan after the scheduler has stopped
|
||||
using it.
|
||||
"""
|
||||
return _create_shared_http_client(settings)
|
||||
|
||||
|
||||
def create_dify_api_inner_http_client(settings: ServerSettings) -> httpx.AsyncClient:
|
||||
"""Create the lifespan-owned Dify API inner HTTP client.
|
||||
|
||||
The Dify API inner client intentionally shares the generic outbound HTTP
|
||||
timeout and connection-pool settings with the plugin daemon client so
|
||||
operational tuning stays in one place while endpoint URL/API keys remain
|
||||
distinct server settings.
|
||||
"""
|
||||
return _create_shared_http_client(settings)
|
||||
|
||||
|
||||
def _create_shared_http_client(settings: ServerSettings) -> httpx.AsyncClient:
|
||||
"""Build one shared HTTP client using generic outbound timeout/pool settings."""
|
||||
return httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=settings.plugin_daemon_connect_timeout,
|
||||
read=settings.plugin_daemon_read_timeout,
|
||||
write=settings.plugin_daemon_write_timeout,
|
||||
pool=settings.plugin_daemon_pool_timeout,
|
||||
connect=settings.outbound_http_connect_timeout,
|
||||
read=settings.outbound_http_read_timeout,
|
||||
write=settings.outbound_http_write_timeout,
|
||||
pool=settings.outbound_http_pool_timeout,
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_connections=settings.plugin_daemon_max_connections,
|
||||
max_keepalive_connections=settings.plugin_daemon_max_keepalive_connections,
|
||||
keepalive_expiry=settings.plugin_daemon_keepalive_expiry,
|
||||
max_connections=settings.outbound_http_max_connections,
|
||||
max_keepalive_connections=settings.outbound_http_max_keepalive_connections,
|
||||
keepalive_expiry=settings.outbound_http_keepalive_expiry,
|
||||
),
|
||||
trust_env=False,
|
||||
)
|
||||
@ -131,4 +153,4 @@ def create_plugin_daemon_http_client(settings: ServerSettings) -> httpx.AsyncCli
|
||||
app = create_app()
|
||||
|
||||
|
||||
__all__ = ["app", "create_app", "create_plugin_daemon_http_client"]
|
||||
__all__ = ["app", "create_app", "create_dify_api_inner_http_client", "create_plugin_daemon_http_client"]
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
"""Configuration for the FastAPI run server.
|
||||
|
||||
Plugin daemon HTTP client settings describe the single FastAPI lifespan-owned
|
||||
``httpx.AsyncClient`` shared by local run tasks. Layers and Agenton providers do
|
||||
not own that client, so these settings are process resource limits rather than
|
||||
per-run lifecycle knobs. The Agent Stub now also uses this main server settings
|
||||
model directly: the public Agent Stub URL, server secret, optional gRPC bind
|
||||
override, and optional Dify inner API file-request settings all live here under
|
||||
the longstanding ``DIFY_AGENT_...`` environment-variable namespace.
|
||||
Outbound HTTP client settings describe the FastAPI lifespan-owned
|
||||
``httpx.AsyncClient`` instances shared by local run tasks for plugin-daemon and
|
||||
Dify API inner calls. Layers and Agenton providers do not own those clients, so
|
||||
these settings are process resource limits rather than per-run lifecycle knobs.
|
||||
Endpoint URLs and API keys stay service-specific. The Agent Stub also uses this
|
||||
settings model directly: the public Agent Stub URL, server secret, optional gRPC
|
||||
bind override, and optional Dify inner API file-request settings all live here
|
||||
under the longstanding ``DIFY_AGENT_...`` environment-variable namespace.
|
||||
"""
|
||||
|
||||
from typing import ClassVar
|
||||
@ -23,7 +24,7 @@ DEFAULT_RUN_RETENTION_SECONDS = 3 * 24 * 60 * 60
|
||||
|
||||
|
||||
class ServerSettings(BaseSettings):
|
||||
"""Environment-backed settings for Redis, scheduling, plugin, and shell access."""
|
||||
"""Environment-backed settings for Redis, scheduling, outbound HTTP, and shell access."""
|
||||
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
redis_prefix: str = "dify-agent"
|
||||
@ -31,6 +32,7 @@ class ServerSettings(BaseSettings):
|
||||
run_retention_seconds: int = Field(default=DEFAULT_RUN_RETENTION_SECONDS, ge=1)
|
||||
plugin_daemon_url: str = "http://localhost:5002"
|
||||
plugin_daemon_api_key: str = ""
|
||||
dify_api_inner_url: str = "http://localhost:5001"
|
||||
dify_api_base_url: str | None = None
|
||||
dify_api_inner_api_key: str | None = None
|
||||
shellctl_entrypoint: str | None = None
|
||||
@ -38,13 +40,13 @@ class ServerSettings(BaseSettings):
|
||||
agent_stub_url: str | None = Field(default=None, validation_alias="DIFY_AGENT_STUB_URL")
|
||||
agent_stub_grpc_bind_address: str | None = Field(default=None, validation_alias="DIFY_AGENT_STUB_GRPC_BIND_ADDRESS")
|
||||
server_secret_key: str | None = None
|
||||
plugin_daemon_connect_timeout: float = Field(default=10.0, ge=0)
|
||||
plugin_daemon_read_timeout: float = Field(default=600.0, ge=0)
|
||||
plugin_daemon_write_timeout: float = Field(default=30.0, ge=0)
|
||||
plugin_daemon_pool_timeout: float = Field(default=10.0, ge=0)
|
||||
plugin_daemon_max_connections: int = Field(default=100, ge=1)
|
||||
plugin_daemon_max_keepalive_connections: int = Field(default=20, ge=0)
|
||||
plugin_daemon_keepalive_expiry: float = Field(default=30.0, ge=0)
|
||||
outbound_http_connect_timeout: float = Field(default=10.0, ge=0)
|
||||
outbound_http_read_timeout: float = Field(default=600.0, ge=0)
|
||||
outbound_http_write_timeout: float = Field(default=30.0, ge=0)
|
||||
outbound_http_pool_timeout: float = Field(default=10.0, ge=0)
|
||||
outbound_http_max_connections: int = Field(default=100, ge=1)
|
||||
outbound_http_max_keepalive_connections: int = Field(default=20, ge=0)
|
||||
outbound_http_keepalive_expiry: float = Field(default=30.0, ge=0)
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
env_prefix="DIFY_AGENT_",
|
||||
@ -116,7 +118,7 @@ class ServerSettings(BaseSettings):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_agent_stub_requirements(self) -> "ServerSettings":
|
||||
"""Require the server secret and Dify API file settings in valid pairs."""
|
||||
"""Require Agent Stub settings while allowing knowledge-only inner API keys."""
|
||||
if self.agent_stub_url is not None and self.server_secret_key is None:
|
||||
raise ValueError("DIFY_AGENT_SERVER_SECRET_KEY is required when DIFY_AGENT_STUB_URL is set.")
|
||||
if self.agent_stub_grpc_bind_address is not None:
|
||||
@ -124,8 +126,8 @@ class ServerSettings(BaseSettings):
|
||||
raise ValueError("DIFY_AGENT_STUB_URL is required when DIFY_AGENT_STUB_GRPC_BIND_ADDRESS is set.")
|
||||
if not parse_agent_stub_endpoint(self.agent_stub_url).is_grpc:
|
||||
raise ValueError("DIFY_AGENT_STUB_GRPC_BIND_ADDRESS requires a grpc:// DIFY_AGENT_STUB_URL.")
|
||||
if (self.dify_api_base_url is None) != (self.dify_api_inner_api_key is None):
|
||||
raise ValueError("DIFY_AGENT_DIFY_API_BASE_URL and DIFY_AGENT_DIFY_API_INNER_API_KEY must be set together.")
|
||||
if self.dify_api_base_url is not None and self.dify_api_inner_api_key is None:
|
||||
raise ValueError("DIFY_AGENT_DIFY_API_INNER_API_KEY is required when DIFY_AGENT_DIFY_API_BASE_URL is set.")
|
||||
return self
|
||||
|
||||
def create_agent_stub_token_codec(self) -> AgentStubTokenCodec | None:
|
||||
|
||||
@ -0,0 +1,248 @@
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from dify_agent.layers.knowledge.client import DifyKnowledgeBaseClient, DifyKnowledgeBaseClientError
|
||||
from dify_agent.layers.knowledge.configs import (
|
||||
DifyKnowledgeMetadataFilteringConfig,
|
||||
DifyKnowledgeRetrievalConfig,
|
||||
)
|
||||
|
||||
|
||||
def _retrieval_config() -> DifyKnowledgeRetrievalConfig:
|
||||
return DifyKnowledgeRetrievalConfig(mode="multiple", top_k=4, score_threshold=0.2)
|
||||
|
||||
|
||||
def _metadata_filtering() -> DifyKnowledgeMetadataFilteringConfig:
|
||||
return DifyKnowledgeMetadataFilteringConfig(mode="disabled")
|
||||
|
||||
|
||||
def test_knowledge_client_posts_inner_api_request_with_static_controls() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
assert str(request.url) == "http://dify-api/inner/api/knowledge/retrieve"
|
||||
assert request.headers["X-Inner-Api-Key"] == "inner-secret"
|
||||
payload = json.loads(request.content.decode("utf-8"))
|
||||
assert payload == {
|
||||
"caller": {
|
||||
"tenant_id": "tenant-1",
|
||||
"user_id": "user-1",
|
||||
"app_id": "app-1",
|
||||
"user_from": "account",
|
||||
"invoke_from": "agent_app",
|
||||
},
|
||||
"dataset_ids": ["dataset-1"],
|
||||
"query": "reset password",
|
||||
"retrieval": {
|
||||
"mode": "multiple",
|
||||
"top_k": 4,
|
||||
"score_threshold": 0.2,
|
||||
"reranking_mode": "reranking_model",
|
||||
"reranking_enable": True,
|
||||
"reranking_model": None,
|
||||
"weights": None,
|
||||
},
|
||||
"metadata_filtering": {"mode": "disabled"},
|
||||
"attachment_ids": [],
|
||||
}
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"results": [
|
||||
{
|
||||
"metadata": {
|
||||
"_source": "knowledge",
|
||||
"dataset_name": "Docs",
|
||||
"document_name": "FAQ.md",
|
||||
"score": 0.9,
|
||||
},
|
||||
"title": "FAQ",
|
||||
"files": [],
|
||||
"content": "Use the reset link.",
|
||||
"summary": None,
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
},
|
||||
)
|
||||
|
||||
async def scenario() -> None:
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client:
|
||||
client = DifyKnowledgeBaseClient(
|
||||
base_url="http://dify-api",
|
||||
api_key="inner-secret",
|
||||
http_client=http_client,
|
||||
)
|
||||
response = await client.retrieve(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
user_from="account",
|
||||
invoke_from="agent_app",
|
||||
dataset_ids=["dataset-1"],
|
||||
query="reset password",
|
||||
retrieval=_retrieval_config(),
|
||||
metadata_filtering=_metadata_filtering(),
|
||||
)
|
||||
assert response.results[0].metadata.dataset_name == "Docs"
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_client_marks_retryable_http_failures() -> None:
|
||||
async def scenario() -> None:
|
||||
async with httpx.AsyncClient(
|
||||
transport=httpx.MockTransport(
|
||||
lambda _request: httpx.Response(
|
||||
502, json={"code": "external_knowledge_failed", "message": "bad gateway"}
|
||||
)
|
||||
)
|
||||
) as http_client:
|
||||
client = DifyKnowledgeBaseClient(
|
||||
base_url="http://dify-api", api_key="inner-secret", http_client=http_client
|
||||
)
|
||||
with pytest.raises(DifyKnowledgeBaseClientError) as exc_info:
|
||||
_ = await client.retrieve(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
user_from="account",
|
||||
invoke_from="agent_app",
|
||||
dataset_ids=["dataset-1"],
|
||||
query="reset password",
|
||||
retrieval=_retrieval_config(),
|
||||
metadata_filtering=_metadata_filtering(),
|
||||
)
|
||||
assert exc_info.value.status_code == 502
|
||||
assert exc_info.value.error_code == "external_knowledge_failed"
|
||||
assert exc_info.value.retryable is True
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_client_marks_non_retryable_http_failures() -> None:
|
||||
async def scenario() -> None:
|
||||
async with httpx.AsyncClient(
|
||||
transport=httpx.MockTransport(
|
||||
lambda _request: httpx.Response(
|
||||
403,
|
||||
json={"code": "dataset_tenant_mismatch", "message": "forbidden"},
|
||||
)
|
||||
)
|
||||
) as http_client:
|
||||
client = DifyKnowledgeBaseClient(
|
||||
base_url="http://dify-api", api_key="inner-secret", http_client=http_client
|
||||
)
|
||||
with pytest.raises(DifyKnowledgeBaseClientError) as exc_info:
|
||||
_ = await client.retrieve(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
user_from="account",
|
||||
invoke_from="agent_app",
|
||||
dataset_ids=["dataset-1"],
|
||||
query="reset password",
|
||||
retrieval=_retrieval_config(),
|
||||
metadata_filtering=_metadata_filtering(),
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.error_code == "dataset_tenant_mismatch"
|
||||
assert exc_info.value.retryable is False
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_client_rejects_malformed_success_response() -> None:
|
||||
async def scenario() -> None:
|
||||
async with httpx.AsyncClient(
|
||||
transport=httpx.MockTransport(lambda _request: httpx.Response(200, json={"bad": []}))
|
||||
) as http_client:
|
||||
client = DifyKnowledgeBaseClient(
|
||||
base_url="http://dify-api", api_key="inner-secret", http_client=http_client
|
||||
)
|
||||
with pytest.raises(DifyKnowledgeBaseClientError) as exc_info:
|
||||
_ = await client.retrieve(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
user_from="account",
|
||||
invoke_from="agent_app",
|
||||
dataset_ids=["dataset-1"],
|
||||
query="reset password",
|
||||
retrieval=_retrieval_config(),
|
||||
metadata_filtering=_metadata_filtering(),
|
||||
)
|
||||
assert exc_info.value.error_code == "invalid_response"
|
||||
assert exc_info.value.retryable is False
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_factory",
|
||||
[
|
||||
lambda request: httpx.ReadTimeout("timed out", request=request),
|
||||
lambda request: httpx.ConnectError("connection failed", request=request),
|
||||
],
|
||||
)
|
||||
def test_knowledge_client_marks_transport_failures_retryable(error_factory) -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
raise error_factory(request)
|
||||
|
||||
async def scenario() -> None:
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client:
|
||||
client = DifyKnowledgeBaseClient(
|
||||
base_url="http://dify-api", api_key="inner-secret", http_client=http_client
|
||||
)
|
||||
with pytest.raises(DifyKnowledgeBaseClientError) as exc_info:
|
||||
_ = await client.retrieve(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
user_from="account",
|
||||
invoke_from="agent_app",
|
||||
dataset_ids=["dataset-1"],
|
||||
query="reset password",
|
||||
retrieval=_retrieval_config(),
|
||||
metadata_filtering=_metadata_filtering(),
|
||||
)
|
||||
assert exc_info.value.retryable is True
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_client_treats_invalid_url_errors_as_non_retryable_configuration_error() -> None:
|
||||
async def scenario() -> None:
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
http_client.post = AsyncMock(side_effect=httpx.UnsupportedProtocol("unsupported protocol"))
|
||||
client = DifyKnowledgeBaseClient(
|
||||
base_url="http://dify-api", api_key="inner-secret", http_client=http_client
|
||||
)
|
||||
with pytest.raises(DifyKnowledgeBaseClientError) as exc_info:
|
||||
_ = await client.retrieve(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
user_from="account",
|
||||
invoke_from="agent_app",
|
||||
dataset_ids=["dataset-1"],
|
||||
query="reset password",
|
||||
retrieval=_retrieval_config(),
|
||||
metadata_filtering=_metadata_filtering(),
|
||||
)
|
||||
assert exc_info.value.retryable is False
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from dify_agent.layers.knowledge import DifyKnowledgeBaseLayerConfig
|
||||
|
||||
|
||||
def _valid_config() -> dict[str, object]:
|
||||
return {
|
||||
"dataset_ids": ["dataset-1"],
|
||||
"retrieval": {
|
||||
"mode": "multiple",
|
||||
"top_k": 4,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_knowledge_base_config_accepts_valid_multiple_mode() -> None:
|
||||
config = DifyKnowledgeBaseLayerConfig.model_validate(_valid_config())
|
||||
|
||||
assert config.dataset_ids == ["dataset-1"]
|
||||
assert config.retrieval.top_k == 4
|
||||
assert config.metadata_filtering.mode == "disabled"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload, expected_message",
|
||||
[
|
||||
({"dataset_ids": [], "retrieval": {"mode": "multiple", "top_k": 4}}, "dataset_ids"),
|
||||
({"tool_name": "knowledge_base_search", **_valid_config()}, "Extra inputs are not permitted"),
|
||||
({"tool_description": "Search knowledge", **_valid_config()}, "Extra inputs are not permitted"),
|
||||
({"dataset_ids": ["dataset-1"], "retrieval": {"mode": "multiple"}}, "top_k"),
|
||||
({"dataset_ids": ["dataset-1"], "retrieval": {"mode": "single"}}, "retrieval.model"),
|
||||
(
|
||||
{
|
||||
"dataset_ids": ["dataset-1"],
|
||||
"retrieval": {"mode": "multiple", "top_k": 4},
|
||||
"metadata_filtering": {"mode": "automatic"},
|
||||
},
|
||||
"metadata_filtering.model_config",
|
||||
),
|
||||
(
|
||||
{
|
||||
"dataset_ids": ["dataset-1"],
|
||||
"retrieval": {"mode": "multiple", "top_k": 4},
|
||||
"metadata_filtering": {"mode": "manual"},
|
||||
},
|
||||
"metadata_filtering.conditions",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_knowledge_base_config_rejects_invalid_inputs(payload: dict[str, object], expected_message: str) -> None:
|
||||
with pytest.raises(ValidationError, match=expected_message):
|
||||
_ = DifyKnowledgeBaseLayerConfig.model_validate(payload)
|
||||
|
||||
|
||||
def test_knowledge_base_config_rejects_observation_limit_smaller_than_result_limit() -> None:
|
||||
with pytest.raises(ValidationError, match="max_observation_chars"):
|
||||
_ = DifyKnowledgeBaseLayerConfig.model_validate(
|
||||
{
|
||||
"dataset_ids": ["dataset-1"],
|
||||
"retrieval": {"mode": "multiple", "top_k": 4},
|
||||
"max_result_content_chars": 50,
|
||||
"max_observation_chars": 20,
|
||||
}
|
||||
)
|
||||
417
dify-agent/tests/local/dify_agent/layers/knowledge/test_layer.py
Normal file
417
dify-agent/tests/local/dify_agent/layers/knowledge/test_layer.py
Normal file
@ -0,0 +1,417 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic_ai import Tool
|
||||
|
||||
from agenton.compositor import Compositor, LayerNode, LayerProvider
|
||||
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
|
||||
from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer
|
||||
from dify_agent.layers.knowledge.client import DifyKnowledgeBaseClientError
|
||||
from dify_agent.layers.knowledge.configs import DifyKnowledgeBaseLayerConfig
|
||||
from dify_agent.layers.knowledge.layer import (
|
||||
BLANK_QUERY_OBSERVATION,
|
||||
DifyKnowledgeBaseLayer,
|
||||
NO_RESULTS_OBSERVATION,
|
||||
TEMPORARY_UNAVAILABLE_OBSERVATION,
|
||||
)
|
||||
|
||||
|
||||
def _execution_context_config(**overrides: object) -> DifyExecutionContextLayerConfig:
|
||||
payload: dict[str, object] = {
|
||||
"tenant_id": "tenant-1",
|
||||
"user_id": "user-1",
|
||||
"user_from": "account",
|
||||
"app_id": "app-1",
|
||||
"agent_mode": "agent_app",
|
||||
"invoke_from": "web-app",
|
||||
}
|
||||
payload.update(overrides)
|
||||
return DifyExecutionContextLayerConfig.model_validate(payload)
|
||||
|
||||
|
||||
def _knowledge_config(**overrides: object) -> DifyKnowledgeBaseLayerConfig:
|
||||
payload: dict[str, object] = {
|
||||
"dataset_ids": ["dataset-1"],
|
||||
"retrieval": {"mode": "multiple", "top_k": 4},
|
||||
}
|
||||
payload.update(overrides)
|
||||
return DifyKnowledgeBaseLayerConfig.model_validate(payload)
|
||||
|
||||
|
||||
def _execution_context_provider() -> LayerProvider[DifyExecutionContextLayer]:
|
||||
return LayerProvider.from_factory(
|
||||
layer_type=DifyExecutionContextLayer,
|
||||
create=lambda config: DifyExecutionContextLayer.from_config_with_settings(
|
||||
DifyExecutionContextLayerConfig.model_validate(config),
|
||||
daemon_url="http://plugin-daemon",
|
||||
daemon_api_key="daemon-secret",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _knowledge_provider() -> LayerProvider[DifyKnowledgeBaseLayer]:
|
||||
return LayerProvider.from_factory(
|
||||
layer_type=DifyKnowledgeBaseLayer,
|
||||
create=lambda config: DifyKnowledgeBaseLayer.from_config_with_settings(
|
||||
DifyKnowledgeBaseLayerConfig.model_validate(config),
|
||||
dify_api_inner_url="http://dify-api",
|
||||
dify_api_inner_api_key="inner-secret",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_knowledge_layer_exposes_one_query_only_tool_definition() -> None:
|
||||
async def scenario() -> None:
|
||||
compositor = Compositor(
|
||||
[
|
||||
LayerNode("execution_context", _execution_context_provider()),
|
||||
LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}),
|
||||
]
|
||||
)
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
async with compositor.enter(
|
||||
configs={
|
||||
"execution_context": _execution_context_config(),
|
||||
"knowledge": _knowledge_config(),
|
||||
}
|
||||
) as run:
|
||||
knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer)
|
||||
tool = (await knowledge_layer.get_tools(http_client=http_client))[0]
|
||||
tool_def = await tool.prepare_tool_def(None) # pyright: ignore[reportArgumentType]
|
||||
assert isinstance(tool, Tool)
|
||||
assert tool.name == "knowledge_base_search"
|
||||
assert tool.description == "Search configured knowledge bases for information relevant to the query."
|
||||
assert tool_def is not None
|
||||
assert (
|
||||
tool_def.description == "Search configured knowledge bases for information relevant to the query."
|
||||
)
|
||||
assert tool_def.parameters_json_schema == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query for the configured knowledge bases.",
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_layer_rejects_blank_query_locally() -> None:
|
||||
async def scenario() -> None:
|
||||
compositor = Compositor(
|
||||
[
|
||||
LayerNode("execution_context", _execution_context_provider()),
|
||||
LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}),
|
||||
]
|
||||
)
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
async with compositor.enter(
|
||||
configs={
|
||||
"execution_context": _execution_context_config(),
|
||||
"knowledge": _knowledge_config(),
|
||||
}
|
||||
) as run:
|
||||
knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer)
|
||||
tool = (await knowledge_layer.get_tools(http_client=http_client))[0]
|
||||
result = await tool.function_schema.call({"query": " "}, None) # pyright: ignore[reportArgumentType]
|
||||
assert result == BLANK_QUERY_OBSERVATION
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("field_name", "field_value"),
|
||||
[
|
||||
("user_id", None),
|
||||
("user_from", None),
|
||||
("app_id", None),
|
||||
],
|
||||
)
|
||||
def test_knowledge_layer_fails_fast_when_execution_context_is_missing_required_fields(
|
||||
field_name: str,
|
||||
field_value: object,
|
||||
) -> None:
|
||||
async def scenario() -> None:
|
||||
compositor = Compositor(
|
||||
[
|
||||
LayerNode("execution_context", _execution_context_provider()),
|
||||
LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}),
|
||||
]
|
||||
)
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
async with compositor.enter(
|
||||
configs={
|
||||
"execution_context": _execution_context_config(),
|
||||
"knowledge": _knowledge_config(),
|
||||
}
|
||||
) as run:
|
||||
execution_context_layer = run.get_layer("execution_context", DifyExecutionContextLayer)
|
||||
setattr(execution_context_layer.config, field_name, field_value)
|
||||
knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer)
|
||||
with pytest.raises(ValueError, match=field_name):
|
||||
_ = await knowledge_layer.get_tools(http_client=http_client)
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_layer_formats_results_and_truncates_observation() -> None:
|
||||
def handler(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"results": [
|
||||
{
|
||||
"metadata": {
|
||||
"_source": "knowledge",
|
||||
"dataset_name": "Docs",
|
||||
"document_name": "Guide.md",
|
||||
"score": 0.9,
|
||||
},
|
||||
"title": "Guide",
|
||||
"files": [],
|
||||
"content": "ABCDEFGHIJKL",
|
||||
"summary": None,
|
||||
}
|
||||
],
|
||||
"usage": {},
|
||||
},
|
||||
)
|
||||
|
||||
async def scenario() -> None:
|
||||
compositor = Compositor(
|
||||
[
|
||||
LayerNode("execution_context", _execution_context_provider()),
|
||||
LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}),
|
||||
]
|
||||
)
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client:
|
||||
async with compositor.enter(
|
||||
configs={
|
||||
"execution_context": _execution_context_config(),
|
||||
"knowledge": _knowledge_config(max_result_content_chars=8, max_observation_chars=160),
|
||||
}
|
||||
) as run:
|
||||
knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer)
|
||||
tool = (await knowledge_layer.get_tools(http_client=http_client))[0]
|
||||
result = await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType]
|
||||
assert result.startswith("Knowledge base search results:\n1. Title: Guide")
|
||||
assert "Dataset: Docs" in result
|
||||
assert "Document: Guide.md" in result
|
||||
assert "Score: 0.9" in result
|
||||
assert "Content: ABCDE..." in result
|
||||
assert len(result) <= 160
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_layer_returns_no_results_observation() -> None:
|
||||
async def scenario() -> None:
|
||||
compositor = Compositor(
|
||||
[
|
||||
LayerNode("execution_context", _execution_context_provider()),
|
||||
LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}),
|
||||
]
|
||||
)
|
||||
async with httpx.AsyncClient(
|
||||
transport=httpx.MockTransport(lambda _request: httpx.Response(200, json={"results": [], "usage": {}}))
|
||||
) as http_client:
|
||||
async with compositor.enter(
|
||||
configs={
|
||||
"execution_context": _execution_context_config(),
|
||||
"knowledge": _knowledge_config(),
|
||||
}
|
||||
) as run:
|
||||
knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer)
|
||||
tool = (await knowledge_layer.get_tools(http_client=http_client))[0]
|
||||
result = await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType]
|
||||
assert result == NO_RESULTS_OBSERVATION
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_layer_converts_retryable_failures_into_observation() -> None:
|
||||
async def scenario() -> None:
|
||||
compositor = Compositor(
|
||||
[
|
||||
LayerNode("execution_context", _execution_context_provider()),
|
||||
LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}),
|
||||
]
|
||||
)
|
||||
async with httpx.AsyncClient(
|
||||
transport=httpx.MockTransport(
|
||||
lambda _request: httpx.Response(429, json={"code": "knowledge_rate_limited", "message": "slow down"})
|
||||
)
|
||||
) as http_client:
|
||||
async with compositor.enter(
|
||||
configs={
|
||||
"execution_context": _execution_context_config(),
|
||||
"knowledge": _knowledge_config(),
|
||||
}
|
||||
) as run:
|
||||
knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer)
|
||||
tool = (await knowledge_layer.get_tools(http_client=http_client))[0]
|
||||
result = await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType]
|
||||
assert result == TEMPORARY_UNAVAILABLE_OBSERVATION
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"transport_error",
|
||||
[
|
||||
lambda request: httpx.ReadTimeout("timed out", request=request),
|
||||
lambda request: httpx.ConnectError("connection failed", request=request),
|
||||
],
|
||||
)
|
||||
def test_knowledge_layer_converts_retryable_transport_failures_into_observation(transport_error) -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
raise transport_error(request)
|
||||
|
||||
async def scenario() -> None:
|
||||
compositor = Compositor(
|
||||
[
|
||||
LayerNode("execution_context", _execution_context_provider()),
|
||||
LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}),
|
||||
]
|
||||
)
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client:
|
||||
async with compositor.enter(
|
||||
configs={
|
||||
"execution_context": _execution_context_config(),
|
||||
"knowledge": _knowledge_config(),
|
||||
}
|
||||
) as run:
|
||||
knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer)
|
||||
tool = (await knowledge_layer.get_tools(http_client=http_client))[0]
|
||||
result = await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType]
|
||||
assert result == TEMPORARY_UNAVAILABLE_OBSERVATION
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_layer_raises_non_retryable_client_errors() -> None:
|
||||
async def scenario() -> None:
|
||||
compositor = Compositor(
|
||||
[
|
||||
LayerNode("execution_context", _execution_context_provider()),
|
||||
LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}),
|
||||
]
|
||||
)
|
||||
async with httpx.AsyncClient(
|
||||
transport=httpx.MockTransport(
|
||||
lambda _request: httpx.Response(403, json={"code": "dataset_tenant_mismatch", "message": "forbidden"})
|
||||
)
|
||||
) as http_client:
|
||||
async with compositor.enter(
|
||||
configs={
|
||||
"execution_context": _execution_context_config(),
|
||||
"knowledge": _knowledge_config(),
|
||||
}
|
||||
) as run:
|
||||
knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer)
|
||||
tool = (await knowledge_layer.get_tools(http_client=http_client))[0]
|
||||
with pytest.raises(DifyKnowledgeBaseClientError) as exc_info:
|
||||
await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType]
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_layer_raises_for_malformed_success_responses() -> None:
|
||||
async def scenario() -> None:
|
||||
compositor = Compositor(
|
||||
[
|
||||
LayerNode("execution_context", _execution_context_provider()),
|
||||
LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}),
|
||||
]
|
||||
)
|
||||
async with httpx.AsyncClient(
|
||||
transport=httpx.MockTransport(lambda _request: httpx.Response(200, json={"bad": []}))
|
||||
) as http_client:
|
||||
async with compositor.enter(
|
||||
configs={
|
||||
"execution_context": _execution_context_config(),
|
||||
"knowledge": _knowledge_config(),
|
||||
}
|
||||
) as run:
|
||||
knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer)
|
||||
tool = (await knowledge_layer.get_tools(http_client=http_client))[0]
|
||||
with pytest.raises(DifyKnowledgeBaseClientError) as exc_info:
|
||||
await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType]
|
||||
assert exc_info.value.error_code == "invalid_response"
|
||||
assert exc_info.value.retryable is False
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_knowledge_layer_sends_execution_context_and_static_config_to_inner_api() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
payload = json.loads(request.content.decode("utf-8"))
|
||||
assert request.headers["X-Inner-Api-Key"] == "inner-secret"
|
||||
assert payload["caller"] == {
|
||||
"tenant_id": "tenant-1",
|
||||
"user_id": "user-1",
|
||||
"app_id": "app-1",
|
||||
"user_from": "account",
|
||||
"invoke_from": "web-app",
|
||||
}
|
||||
assert payload["dataset_ids"] == ["dataset-1", "dataset-2"]
|
||||
assert payload["query"] == "reset"
|
||||
assert payload["retrieval"]["top_k"] == 2
|
||||
assert payload["metadata_filtering"] == {
|
||||
"mode": "manual",
|
||||
"conditions": {
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "category",
|
||||
"comparison_operator": "contains",
|
||||
"value": "auth",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
return httpx.Response(200, json={"results": [], "usage": {}})
|
||||
|
||||
async def scenario() -> None:
|
||||
compositor = Compositor(
|
||||
[
|
||||
LayerNode("execution_context", _execution_context_provider()),
|
||||
LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}),
|
||||
]
|
||||
)
|
||||
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client:
|
||||
async with compositor.enter(
|
||||
configs={
|
||||
"execution_context": _execution_context_config(),
|
||||
"knowledge": _knowledge_config(
|
||||
dataset_ids=["dataset-1", "dataset-2"],
|
||||
retrieval={"mode": "multiple", "top_k": 2},
|
||||
metadata_filtering={
|
||||
"mode": "manual",
|
||||
"conditions": {
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "category",
|
||||
"comparison_operator": "contains",
|
||||
"value": "auth",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
),
|
||||
}
|
||||
) as run:
|
||||
knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer)
|
||||
tool = (await knowledge_layer.get_tools(http_client=http_client))[0]
|
||||
result = await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType]
|
||||
assert result == NO_RESULTS_OBSERVATION
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -120,6 +120,7 @@ def test_create_run_starts_background_task_and_returns_running() -> None:
|
||||
scheduler = RunScheduler(
|
||||
store=store,
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
runner_factory=lambda _record, _request: ControlledRunner(started=started, release=release),
|
||||
)
|
||||
|
||||
@ -144,6 +145,7 @@ def test_shutdown_marks_unfinished_runs_failed_and_appends_event() -> None:
|
||||
scheduler = RunScheduler(
|
||||
store=store,
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
shutdown_grace_seconds=0,
|
||||
runner_factory=lambda _record, _request: ControlledRunner(started=started, release=asyncio.Event()),
|
||||
)
|
||||
@ -165,7 +167,7 @@ def test_create_run_accepts_blank_prompt_and_runner_fails_asynchronously() -> No
|
||||
async def scenario() -> None:
|
||||
store = FakeStore()
|
||||
async with httpx.AsyncClient() as client:
|
||||
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
|
||||
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, dify_api_http_client=client)
|
||||
|
||||
record = await scheduler.create_run(_request(["", " "]))
|
||||
await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1)
|
||||
@ -182,7 +184,7 @@ def test_create_run_accepts_invalid_output_schema_and_runner_fails_asynchronousl
|
||||
async def scenario() -> None:
|
||||
store = FakeStore()
|
||||
async with httpx.AsyncClient() as client:
|
||||
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
|
||||
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, dify_api_http_client=client)
|
||||
|
||||
record = await scheduler.create_run(
|
||||
_request(
|
||||
@ -205,7 +207,12 @@ def test_create_run_honors_explicit_empty_layer_providers_by_failing_after_persi
|
||||
async def scenario() -> None:
|
||||
store = FakeStore()
|
||||
async with httpx.AsyncClient() as client:
|
||||
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, layer_providers=())
|
||||
scheduler = RunScheduler(
|
||||
store=store,
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
layer_providers=(),
|
||||
)
|
||||
|
||||
record = await scheduler.create_run(_request())
|
||||
await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1)
|
||||
@ -222,7 +229,7 @@ def test_create_run_accepts_closed_session_snapshot_and_runner_fails_asynchronou
|
||||
async def scenario() -> None:
|
||||
store = FakeStore()
|
||||
async with httpx.AsyncClient() as client:
|
||||
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
|
||||
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, dify_api_http_client=client)
|
||||
request = _request()
|
||||
request.session_snapshot = CompositorSessionSnapshot(
|
||||
layers=[
|
||||
@ -248,7 +255,7 @@ def test_create_run_accepts_closed_session_snapshot_and_runner_fails_asynchronou
|
||||
def test_create_run_rejects_after_shutdown_starts() -> None:
|
||||
async def scenario() -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
scheduler = RunScheduler(store=FakeStore(), plugin_daemon_http_client=client)
|
||||
scheduler = RunScheduler(store=FakeStore(), plugin_daemon_http_client=client, dify_api_http_client=client)
|
||||
await scheduler.shutdown()
|
||||
|
||||
with pytest.raises(SchedulerStoppingError):
|
||||
@ -261,7 +268,7 @@ def test_create_run_rejects_invalid_request_after_shutdown_without_persisting()
|
||||
async def scenario() -> None:
|
||||
store = FakeStore()
|
||||
async with httpx.AsyncClient() as client:
|
||||
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client)
|
||||
scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, dify_api_http_client=client)
|
||||
await scheduler.shutdown()
|
||||
|
||||
with pytest.raises(SchedulerStoppingError):
|
||||
@ -282,6 +289,7 @@ def test_shutdown_waits_for_in_flight_create_to_register_before_cancelling() ->
|
||||
scheduler = RunScheduler(
|
||||
store=store,
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
shutdown_grace_seconds=0,
|
||||
runner_factory=lambda _record, _request: ControlledRunner(
|
||||
started=runner_started, release=asyncio.Event()
|
||||
|
||||
@ -41,6 +41,8 @@ from dify_agent.layers.dify_plugin.configs import (
|
||||
)
|
||||
from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer
|
||||
from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer
|
||||
from dify_agent.layers.knowledge.configs import DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, DifyKnowledgeBaseLayerConfig
|
||||
from dify_agent.layers.knowledge.layer import DifyKnowledgeBaseLayer
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.protocol import DIFY_AGENT_HISTORY_LAYER_ID, DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID
|
||||
from dify_agent.protocol.schemas import (
|
||||
@ -357,6 +359,7 @@ def test_runner_emits_terminal_success_and_snapshot(monkeypatch: pytest.MonkeyPa
|
||||
request=request,
|
||||
run_id="run-1",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
assert seen_clients == [client]
|
||||
assert client.is_closed is False
|
||||
@ -406,6 +409,7 @@ def test_runner_preserves_explicit_json_null_output(monkeypatch: pytest.MonkeyPa
|
||||
request=request,
|
||||
run_id="run-null-output",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -462,6 +466,7 @@ def test_runner_emits_deferred_tool_call_and_persists_pending_history(monkeypatc
|
||||
request=request,
|
||||
run_id="run-ask-human",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -558,6 +563,7 @@ def test_runner_resumes_with_deferred_tool_results_and_no_user_prompt(monkeypatc
|
||||
request=request,
|
||||
run_id="run-ask-human-initial",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
initial_terminal = sink.events["run-ask-human-initial"][-1]
|
||||
@ -582,6 +588,7 @@ def test_runner_resumes_with_deferred_tool_results_and_no_user_prompt(monkeypatc
|
||||
request=resumed_request,
|
||||
run_id="run-ask-human-resume",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -657,6 +664,7 @@ def test_runner_can_emit_second_deferred_tool_call_after_resume(monkeypatch: pyt
|
||||
request=request,
|
||||
run_id="run-ask-human-turn-1",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
first_terminal = sink.events["run-ask-human-turn-1"][-1]
|
||||
@ -681,6 +689,7 @@ def test_runner_can_emit_second_deferred_tool_call_after_resume(monkeypatch: pyt
|
||||
request=resumed_request,
|
||||
run_id="run-ask-human-turn-2",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -736,6 +745,7 @@ def test_runner_rejects_deferred_tool_call_without_history_layer(monkeypatch: py
|
||||
request=request,
|
||||
run_id="run-ask-human-no-history",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -785,6 +795,7 @@ def test_runner_rejects_resume_with_deferred_tool_results_without_history_layer(
|
||||
request=request,
|
||||
run_id="run-ask-human-resume-no-history",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -823,6 +834,7 @@ def test_runner_rejects_multiple_deferred_tool_calls(monkeypatch: pytest.MonkeyP
|
||||
request=request,
|
||||
run_id="run-ask-human-multi",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -861,6 +873,7 @@ def test_runner_rejects_deferred_approval_requests(monkeypatch: pytest.MonkeyPat
|
||||
request=request,
|
||||
run_id="run-ask-human-approval",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -960,6 +973,7 @@ def test_runner_passes_dynamic_dify_plugin_tools_to_agent(monkeypatch: pytest.Mo
|
||||
request=request,
|
||||
run_id="run-tools",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -970,6 +984,105 @@ def test_runner_passes_dynamic_dify_plugin_tools_to_agent(monkeypatch: pytest.Mo
|
||||
assert terminal.data.output == "done"
|
||||
|
||||
|
||||
def test_runner_passes_dynamic_dify_knowledge_tools_to_agent(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
seen_tools: list[Tool[object]] = []
|
||||
|
||||
async def knowledge_tool() -> str:
|
||||
return "knowledge"
|
||||
|
||||
def fake_get_model(_self: DifyPluginLLMLayer, *, http_client: httpx.AsyncClient):
|
||||
assert http_client.is_closed is False
|
||||
return TestModel(custom_output_text="done") # pyright: ignore[reportReturnType]
|
||||
|
||||
async def fake_get_tools(self: DifyKnowledgeBaseLayer, *, http_client: httpx.AsyncClient) -> list[Tool[object]]:
|
||||
assert self.config.dataset_ids == ["dataset-1"]
|
||||
assert http_client.headers.get("X-Test-Client") == "dify-api"
|
||||
return [Tool(knowledge_tool, name="knowledge_base_search")]
|
||||
|
||||
class FakeResult:
|
||||
output: str = "done"
|
||||
|
||||
def new_messages(self) -> list[ModelMessage]:
|
||||
return []
|
||||
|
||||
class FakeAgent:
|
||||
async def run(self, *_args: object, **_kwargs: object) -> FakeResult:
|
||||
return FakeResult()
|
||||
|
||||
def fake_create_agent(model: object, *, tools: list[Tool[object]], output_type: object) -> FakeAgent:
|
||||
del model, output_type
|
||||
seen_tools.extend(tools)
|
||||
return FakeAgent()
|
||||
|
||||
monkeypatch.setattr(DifyPluginLLMLayer, "get_model", fake_get_model)
|
||||
monkeypatch.setattr(DifyKnowledgeBaseLayer, "get_tools", fake_get_tools)
|
||||
monkeypatch.setattr("dify_agent.runtime.runner.create_agent", fake_create_agent)
|
||||
|
||||
request = CreateRunRequest(
|
||||
composition=RunComposition(
|
||||
layers=[
|
||||
RunLayerSpec(
|
||||
name="prompt",
|
||||
type="plain.prompt",
|
||||
config=PromptLayerConfig(prefix="system", user="hello"),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name="execution_context",
|
||||
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
config=DifyExecutionContextLayerConfig(
|
||||
tenant_id="tenant-1",
|
||||
user_id="user-1",
|
||||
user_from="account",
|
||||
app_id="app-1",
|
||||
agent_mode="workflow_run",
|
||||
invoke_from="service-api",
|
||||
),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type="dify.plugin.llm",
|
||||
deps={"execution_context": "execution_context"},
|
||||
config=DifyPluginLLMLayerConfig(
|
||||
plugin_id="langgenius/openai",
|
||||
model_provider="openai",
|
||||
model="demo-model",
|
||||
credentials={"api_key": "secret"},
|
||||
),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name="knowledge",
|
||||
type=DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID,
|
||||
deps={"execution_context": "execution_context"},
|
||||
config=DifyKnowledgeBaseLayerConfig.model_validate(
|
||||
{
|
||||
"dataset_ids": ["dataset-1"],
|
||||
"retrieval": {"mode": "multiple", "top_k": 4},
|
||||
}
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
sink = InMemoryRunEventSink()
|
||||
|
||||
async def scenario() -> None:
|
||||
async with (
|
||||
httpx.AsyncClient() as plugin_client,
|
||||
httpx.AsyncClient(headers={"X-Test-Client": "dify-api"}) as dify_api_client,
|
||||
):
|
||||
await AgentRunRunner(
|
||||
sink=sink,
|
||||
request=request,
|
||||
run_id="run-knowledge-tools",
|
||||
plugin_daemon_http_client=plugin_client,
|
||||
dify_api_http_client=dify_api_client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
assert [tool.name for tool in seen_tools] == ["knowledge_base_search"]
|
||||
|
||||
|
||||
def test_runner_rejects_duplicate_tool_names_across_dynamic_tool_layers(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
@ -1075,6 +1188,7 @@ def test_runner_rejects_duplicate_tool_names_across_dynamic_tool_layers(
|
||||
request=request,
|
||||
run_id="run-duplicate-tools",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1182,6 +1296,7 @@ def test_runner_rejects_duplicate_tool_names_between_static_and_dynamic_tools(
|
||||
request=request,
|
||||
run_id="run-static-dynamic-duplicate-tools",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
layer_providers=layer_providers,
|
||||
).run()
|
||||
|
||||
@ -1297,6 +1412,7 @@ def test_runner_rejects_duplicate_tool_names_between_shell_and_other_layers(
|
||||
request=request,
|
||||
run_id="run-shell-duplicate-tools",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
layer_providers=layer_providers,
|
||||
).run()
|
||||
|
||||
@ -1325,6 +1441,7 @@ def test_runner_passes_temporary_system_prompt_prefix_without_history_layer(monk
|
||||
request=_request("current user"),
|
||||
run_id="run-no-history",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1368,6 +1485,7 @@ def test_runner_prepends_current_system_prompt_to_stored_history_and_appends_onl
|
||||
request=request,
|
||||
run_id="run-history",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1418,6 +1536,7 @@ def test_runner_with_empty_history_layer_still_sends_system_prompt_and_saves_onl
|
||||
request=request,
|
||||
run_id="run-empty-history",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1468,6 +1587,7 @@ def test_runner_failure_with_history_layer_emits_failed_terminal_event_without_s
|
||||
request=request,
|
||||
run_id="run-history-failure",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1499,6 +1619,7 @@ def test_runner_applies_on_exit_overrides_to_success_snapshot(monkeypatch: pytes
|
||||
request=request,
|
||||
run_id="run-exit",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1559,6 +1680,7 @@ def test_runner_passes_output_layer_spec_to_agent_and_serializes_structured_resu
|
||||
request=request,
|
||||
run_id="run-structured-output",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
first_terminal = sink.events["run-structured-output"][-1]
|
||||
@ -1572,6 +1694,7 @@ def test_runner_passes_output_layer_spec_to_agent_and_serializes_structured_resu
|
||||
request=resumed_request,
|
||||
run_id="run-structured-output-resume",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1645,6 +1768,7 @@ def test_runner_retries_invalid_structured_output_and_eventually_succeeds(monkey
|
||||
request=request,
|
||||
run_id="run-output-retry-success",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1698,6 +1822,7 @@ def test_runner_fails_when_invalid_structured_output_exhausts_retries(monkeypatc
|
||||
request=request,
|
||||
run_id="run-output-retry-failed",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1734,6 +1859,7 @@ def test_runner_rejects_invalid_output_layer_before_model_resolution(monkeypatch
|
||||
request=request,
|
||||
run_id="run-invalid-output",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1808,6 +1934,7 @@ def test_runner_rejects_misnamed_output_layer_before_model_resolution(monkeypatc
|
||||
request=request,
|
||||
run_id="run-misnamed-output",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1894,6 +2021,7 @@ def test_runner_rejects_multiple_output_layers_before_model_resolution(monkeypat
|
||||
request=request,
|
||||
run_id="run-duplicate-output",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -1965,6 +2093,7 @@ def test_runner_rejects_reserved_output_name_with_wrong_layer_type_before_model_
|
||||
request=request,
|
||||
run_id="run-wrong-output-type",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -2009,6 +2138,7 @@ def test_runner_rejects_misnamed_output_layer_before_provider_checks() -> None:
|
||||
request=request,
|
||||
run_id="run-misnamed-output-before-providers",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
layer_providers=(),
|
||||
).run()
|
||||
|
||||
@ -2033,6 +2163,7 @@ def test_runner_rejects_unknown_on_exit_layer_id() -> None:
|
||||
request=request,
|
||||
run_id="run-unknown-signal",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -2053,6 +2184,7 @@ def test_runner_honors_explicit_empty_layer_providers() -> None:
|
||||
request=request,
|
||||
run_id="run-empty-providers",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
layer_providers=(),
|
||||
).run()
|
||||
|
||||
@ -2074,6 +2206,7 @@ def test_runner_fails_empty_user_prompts() -> None:
|
||||
request=request,
|
||||
run_id="run-2",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -2094,6 +2227,7 @@ def test_runner_fails_blank_string_user_prompt_list() -> None:
|
||||
request=request,
|
||||
run_id="run-3",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -2114,6 +2248,7 @@ def test_runner_requires_llm_layer_id() -> None:
|
||||
request=request,
|
||||
run_id="run-4",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -2153,6 +2288,7 @@ def test_runner_rejects_closed_session_snapshot_as_validation_error() -> None:
|
||||
request=request,
|
||||
run_id="run-closed-snapshot",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -2205,6 +2341,7 @@ def test_runner_treats_missing_shell_entrypoint_as_validation_error() -> None:
|
||||
request=request,
|
||||
run_id="run-missing-shell-entrypoint",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
).run()
|
||||
|
||||
asyncio.run(scenario())
|
||||
@ -2282,6 +2419,7 @@ def test_runner_treats_invalid_shell_snapshot_offsets_as_validation_error() -> N
|
||||
request=request,
|
||||
run_id="run-invalid-shell-offset",
|
||||
plugin_daemon_http_client=client,
|
||||
dify_api_http_client=client,
|
||||
layer_providers=create_default_layer_providers(shellctl_entrypoint="http://shellctl"),
|
||||
).run()
|
||||
|
||||
|
||||
@ -13,10 +13,12 @@ from shell_session_manager.shellctl.client import ShellctlClient
|
||||
import dify_agent.server.app as app_module
|
||||
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
|
||||
from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer
|
||||
from dify_agent.layers.knowledge.configs import DifyKnowledgeBaseLayerConfig
|
||||
from dify_agent.layers.knowledge.layer import DifyKnowledgeBaseLayer
|
||||
from dify_agent.layers.shell import DifyShellLayerConfig
|
||||
from dify_agent.layers.shell.layer import DifyShellLayer
|
||||
from dify_agent.runtime.compositor_factory import DifyAgentLayerProvider
|
||||
from dify_agent.server.app import create_app, create_plugin_daemon_http_client
|
||||
from dify_agent.server.app import create_app, create_dify_api_inner_http_client, create_plugin_daemon_http_client
|
||||
from dify_agent.server.settings import ServerSettings
|
||||
from dify_agent.storage.redis_run_store import RedisRunStore
|
||||
|
||||
@ -67,6 +69,7 @@ class FakeRunScheduler:
|
||||
shutdown_grace_seconds: float
|
||||
layer_providers: tuple[DifyAgentLayerProvider, ...]
|
||||
plugin_daemon_http_client: FakePluginDaemonHttpClient
|
||||
dify_api_http_client: FakePluginDaemonHttpClient
|
||||
shutdown_called: bool
|
||||
|
||||
def __init__(
|
||||
@ -74,6 +77,7 @@ class FakeRunScheduler:
|
||||
*,
|
||||
store: object,
|
||||
plugin_daemon_http_client: FakePluginDaemonHttpClient,
|
||||
dify_api_http_client: FakePluginDaemonHttpClient,
|
||||
shutdown_grace_seconds: float,
|
||||
layer_providers: tuple[DifyAgentLayerProvider, ...],
|
||||
) -> None:
|
||||
@ -81,6 +85,7 @@ class FakeRunScheduler:
|
||||
self.shutdown_grace_seconds = shutdown_grace_seconds
|
||||
self.layer_providers = layer_providers
|
||||
self.plugin_daemon_http_client = plugin_daemon_http_client
|
||||
self.dify_api_http_client = dify_api_http_client
|
||||
self.shutdown_called = False
|
||||
self.created.append(self)
|
||||
|
||||
@ -160,7 +165,22 @@ class FakeHttpxModule:
|
||||
|
||||
|
||||
def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fake_redis, fake_http_client = _patch_app_lifecycle(monkeypatch)
|
||||
fake_redis = FakeRedis()
|
||||
fake_http_client = FakePluginDaemonHttpClient()
|
||||
fake_dify_api_http_client = FakePluginDaemonHttpClient()
|
||||
FakeRunScheduler.created.clear()
|
||||
FakeRedisModule.fake_redis = fake_redis
|
||||
monkeypatch.setattr(app_module, "Redis", FakeRedisModule)
|
||||
monkeypatch.setattr(app_module, "RunScheduler", FakeRunScheduler)
|
||||
|
||||
def fake_create_plugin_daemon_http_client(_settings: ServerSettings) -> FakePluginDaemonHttpClient:
|
||||
return fake_http_client
|
||||
|
||||
def fake_create_dify_api_inner_http_client(_settings: ServerSettings) -> FakePluginDaemonHttpClient:
|
||||
return fake_dify_api_http_client
|
||||
|
||||
monkeypatch.setattr(app_module, "create_plugin_daemon_http_client", fake_create_plugin_daemon_http_client)
|
||||
monkeypatch.setattr(app_module, "create_dify_api_inner_http_client", fake_create_dify_api_inner_http_client)
|
||||
|
||||
settings = ServerSettings(
|
||||
redis_url="redis://example.invalid/0",
|
||||
@ -169,19 +189,20 @@ def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pyt
|
||||
run_retention_seconds=7,
|
||||
plugin_daemon_url="http://plugin-daemon",
|
||||
plugin_daemon_api_key="daemon-secret",
|
||||
dify_api_inner_url="http://dify-api",
|
||||
shellctl_entrypoint="http://shellctl",
|
||||
shellctl_auth_token="shell-secret",
|
||||
agent_stub_url="https://agent.example.com/agent-stub",
|
||||
server_secret_key=_base64url_secret(b"1" * 32),
|
||||
dify_api_base_url="https://api.example.com",
|
||||
dify_api_inner_api_key="inner-secret",
|
||||
plugin_daemon_connect_timeout=1,
|
||||
plugin_daemon_read_timeout=2,
|
||||
plugin_daemon_write_timeout=3,
|
||||
plugin_daemon_pool_timeout=4,
|
||||
plugin_daemon_max_connections=5,
|
||||
plugin_daemon_max_keepalive_connections=3,
|
||||
plugin_daemon_keepalive_expiry=6,
|
||||
outbound_http_connect_timeout=1,
|
||||
outbound_http_read_timeout=2,
|
||||
outbound_http_write_timeout=3,
|
||||
outbound_http_pool_timeout=4,
|
||||
outbound_http_max_connections=5,
|
||||
outbound_http_max_keepalive_connections=3,
|
||||
outbound_http_keepalive_expiry=6,
|
||||
)
|
||||
|
||||
with TestClient(create_app(settings)):
|
||||
@ -207,6 +228,18 @@ def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pyt
|
||||
assert isinstance(shell_layer, DifyShellLayer)
|
||||
assert execution_context_layer.daemon_url == "http://plugin-daemon"
|
||||
assert execution_context_layer.daemon_api_key == "daemon-secret"
|
||||
knowledge_provider = next(provider for provider in layer_providers if provider.type_id == "dify.knowledge_base")
|
||||
knowledge_layer = knowledge_provider.create_layer(
|
||||
DifyKnowledgeBaseLayerConfig.model_validate(
|
||||
{
|
||||
"dataset_ids": ["dataset-1"],
|
||||
"retrieval": {"mode": "multiple", "top_k": 2},
|
||||
}
|
||||
)
|
||||
)
|
||||
assert isinstance(knowledge_layer, DifyKnowledgeBaseLayer)
|
||||
assert knowledge_layer.dify_api_inner_url == "http://dify-api"
|
||||
assert knowledge_layer.dify_api_inner_api_key == "inner-secret"
|
||||
assert shell_layer.shellctl_entrypoint == "http://shellctl"
|
||||
assert shell_layer.agent_stub_url == "https://agent.example.com/agent-stub"
|
||||
shellctl_client = shell_layer.shellctl_client_factory("http://shellctl")
|
||||
@ -216,6 +249,8 @@ def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pyt
|
||||
http_client = scheduler.plugin_daemon_http_client
|
||||
assert http_client is fake_http_client
|
||||
assert http_client.is_closed is False
|
||||
assert scheduler.dify_api_http_client is fake_dify_api_http_client
|
||||
assert scheduler.dify_api_http_client.is_closed is False
|
||||
store = scheduler.store
|
||||
assert isinstance(store, RedisRunStore)
|
||||
assert store.run_retention_seconds == 7
|
||||
@ -229,6 +264,7 @@ def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pyt
|
||||
)
|
||||
|
||||
assert FakeRunScheduler.created[0].shutdown_called is True
|
||||
assert FakeRunScheduler.created[0].dify_api_http_client.is_closed is True
|
||||
assert FakeRunScheduler.created[0].plugin_daemon_http_client.is_closed is True
|
||||
assert fake_redis.closed is True
|
||||
|
||||
@ -326,21 +362,75 @@ def test_create_app_starts_and_stops_agent_stub_grpc_server_for_grpc_url(monkeyp
|
||||
assert fake_redis.closed is True
|
||||
|
||||
|
||||
def test_create_plugin_daemon_http_client_uses_configured_httpx_construction_args(
|
||||
def test_create_plugin_daemon_http_client_uses_generic_outbound_httpx_construction_args(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(app_module, "httpx", FakeHttpxModule)
|
||||
|
||||
client = create_plugin_daemon_http_client(ServerSettings())
|
||||
client = create_plugin_daemon_http_client(
|
||||
ServerSettings(
|
||||
outbound_http_connect_timeout=1,
|
||||
outbound_http_read_timeout=2,
|
||||
outbound_http_write_timeout=3,
|
||||
outbound_http_pool_timeout=4,
|
||||
outbound_http_max_connections=5,
|
||||
outbound_http_max_keepalive_connections=3,
|
||||
outbound_http_keepalive_expiry=6,
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(client, FakePluginDaemonHttpClient)
|
||||
assert isinstance(client.timeout, FakeTimeout)
|
||||
assert client.timeout.connect == 10
|
||||
assert client.timeout.read == 600
|
||||
assert client.timeout.write == 30
|
||||
assert client.timeout.pool == 10
|
||||
assert client.timeout.connect == 1
|
||||
assert client.timeout.read == 2
|
||||
assert client.timeout.write == 3
|
||||
assert client.timeout.pool == 4
|
||||
assert isinstance(client.limits, FakeLimits)
|
||||
assert client.limits.max_connections == 100
|
||||
assert client.limits.max_keepalive_connections == 20
|
||||
assert client.limits.keepalive_expiry == 30
|
||||
assert client.limits.max_connections == 5
|
||||
assert client.limits.max_keepalive_connections == 3
|
||||
assert client.limits.keepalive_expiry == 6
|
||||
assert client.trust_env is False
|
||||
|
||||
|
||||
def test_create_dify_api_inner_http_client_uses_generic_outbound_httpx_construction_args(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(app_module, "httpx", FakeHttpxModule)
|
||||
|
||||
client = create_dify_api_inner_http_client(
|
||||
ServerSettings(
|
||||
outbound_http_connect_timeout=1,
|
||||
outbound_http_read_timeout=2,
|
||||
outbound_http_write_timeout=3,
|
||||
outbound_http_pool_timeout=4,
|
||||
outbound_http_max_connections=5,
|
||||
outbound_http_max_keepalive_connections=3,
|
||||
outbound_http_keepalive_expiry=6,
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(client, FakePluginDaemonHttpClient)
|
||||
assert isinstance(client.timeout, FakeTimeout)
|
||||
assert client.timeout.connect == 1
|
||||
assert client.timeout.read == 2
|
||||
assert client.timeout.write == 3
|
||||
assert client.timeout.pool == 4
|
||||
assert isinstance(client.limits, FakeLimits)
|
||||
assert client.limits.max_connections == 5
|
||||
assert client.limits.max_keepalive_connections == 3
|
||||
assert client.limits.keepalive_expiry == 6
|
||||
assert client.trust_env is False
|
||||
|
||||
|
||||
def test_server_settings_use_generic_outbound_http_args_for_shared_clients() -> None:
|
||||
model_fields = ServerSettings.model_fields
|
||||
|
||||
assert "dify_api_inner_url" in model_fields
|
||||
assert "dify_api_inner_api_key" in model_fields
|
||||
assert "outbound_http_connect_timeout" in model_fields
|
||||
assert "outbound_http_read_timeout" in model_fields
|
||||
assert "outbound_http_write_timeout" in model_fields
|
||||
assert "outbound_http_pool_timeout" in model_fields
|
||||
assert "outbound_http_max_connections" in model_fields
|
||||
assert "outbound_http_max_keepalive_connections" in model_fields
|
||||
assert "outbound_http_keepalive_expiry" in model_fields
|
||||
|
||||
@ -129,12 +129,13 @@ def test_server_settings_normalizes_dify_api_base_url_from_env(monkeypatch: pyte
|
||||
assert settings.dify_api_inner_api_key == "inner-secret"
|
||||
|
||||
|
||||
def test_server_settings_requires_dify_api_base_url_and_key_together() -> None:
|
||||
with pytest.raises(ValidationError, match="DIFY_AGENT_DIFY_API_BASE_URL"):
|
||||
def test_server_settings_requires_inner_api_key_when_dify_api_base_url_is_set() -> None:
|
||||
with pytest.raises(ValidationError, match="DIFY_AGENT_DIFY_API_INNER_API_KEY"):
|
||||
_ = ServerSettings(dify_api_base_url="https://api.example.com")
|
||||
|
||||
with pytest.raises(ValidationError, match="DIFY_AGENT_DIFY_API_BASE_URL"):
|
||||
_ = ServerSettings(dify_api_inner_api_key="inner-secret")
|
||||
settings = ServerSettings(dify_api_inner_api_key="inner-secret")
|
||||
assert settings.dify_api_inner_api_key == "inner-secret"
|
||||
assert settings.dify_api_base_url is None
|
||||
|
||||
|
||||
def test_server_settings_rejects_dify_api_base_url_with_query_or_fragment() -> None:
|
||||
|
||||
@ -84,6 +84,8 @@ def test_protocol_and_dify_plugin_exports_do_not_import_server_only_modules() ->
|
||||
"dify_agent.layers.ask_human.layer",
|
||||
"dify_agent.layers.dify_plugin.llm_layer",
|
||||
"dify_agent.layers.dify_plugin.tools_layer",
|
||||
"dify_agent.layers.knowledge.client",
|
||||
"dify_agent.layers.knowledge.layer",
|
||||
"dify_agent.layers.output.output_layer",
|
||||
"dify_agent.layers.shell.layer",
|
||||
"dify_agent.runtime",
|
||||
@ -103,6 +105,7 @@ def test_protocol_and_dify_plugin_exports_do_not_import_server_only_modules() ->
|
||||
"dify_agent.layers.execution_context",
|
||||
"dify_agent.layers.ask_human",
|
||||
"dify_agent.layers.dify_plugin",
|
||||
"dify_agent.layers.knowledge",
|
||||
"dify_agent.layers.output",
|
||||
"dify_agent.layers.shell",
|
||||
],
|
||||
@ -112,6 +115,7 @@ def test_protocol_and_dify_plugin_exports_do_not_import_server_only_modules() ->
|
||||
"assert dify_agent_layers_execution_context.__all__ == ['DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID', 'DifyExecutionContextAgentMode', 'DifyExecutionContextInvokeFrom', 'DifyExecutionContextLayerConfig', 'DifyExecutionContextUserFrom']",
|
||||
"assert dify_agent_layers_ask_human.__all__ == ['AskHumanAction', 'AskHumanActionStyle', 'AskHumanField', 'AskHumanFieldType', 'AskHumanFileField', 'AskHumanFileListField', 'AskHumanParagraphField', 'AskHumanResultStatus', 'AskHumanSelectField', 'AskHumanSelectOption', 'AskHumanSelectedAction', 'AskHumanToolArgs', 'AskHumanToolResult', 'AskHumanUrgency', 'DEFAULT_ASK_HUMAN_TOOL_DESCRIPTION', 'DIFY_ASK_HUMAN_LAYER_TYPE_ID', 'DifyAskHumanLayerConfig']",
|
||||
"assert dify_agent_layers_dify_plugin.__all__ == ['DIFY_PLUGIN_LLM_LAYER_TYPE_ID', 'DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID', 'DifyPluginCredentialValue', 'DifyPluginLLMLayerConfig', 'DifyPluginToolCredentialType', 'DifyPluginToolConfig', 'DifyPluginToolOption', 'DifyPluginToolParameter', 'DifyPluginToolParameterForm', 'DifyPluginToolParameterType', 'DifyPluginToolsLayerConfig', 'DifyPluginToolValue']",
|
||||
"assert dify_agent_layers_knowledge.__all__ == ['DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID', 'DifyKnowledgeBaseLayerConfig', 'DifyKnowledgeMetadataCondition', 'DifyKnowledgeMetadataConditions', 'DifyKnowledgeMetadataFilteringConfig', 'DifyKnowledgeModelConfig', 'DifyKnowledgeRerankingModelConfig', 'DifyKnowledgeRetrievalConfig']",
|
||||
"assert dify_agent_layers_output.__all__ == ['DIFY_OUTPUT_LAYER_TYPE_ID', 'DifyOutputLayerConfig']",
|
||||
"assert dify_agent_layers_shell.__all__ == ['DIFY_SHELL_LAYER_TYPE_ID', 'DifyShellCliToolConfig', 'DifyShellEnvVarConfig', 'DifyShellLayerConfig', 'DifyShellSandboxConfig', 'DifyShellSecretRefConfig']",
|
||||
],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user