From 0ea0647dd0e04b2a3a9e92e5f23dd0898ad6a487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Wed, 17 Jun 2026 18:27:38 +0900 Subject: [PATCH] feat(agent): wire knowledge base retrieval into runtime (#37577) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/clients/agent_backend/__init__.py | 2 + api/clients/agent_backend/request_builder.py | 35 +- api/controllers/inner_api/__init__.py | 4 +- .../inner_api/knowledge/__init__.py | 1 + .../inner_api/knowledge/retrieval.py | 110 +++++ api/controllers/inner_api/wraps.py | 32 +- .../apps/agent_app/runtime_request_builder.py | 9 +- api/core/rag/retrieval/dataset_retrieval.py | 2 +- .../agent_v2/runtime_feature_manifest.py | 16 +- .../nodes/agent_v2/runtime_request_builder.py | 54 ++- .../entities/knowledge_retrieval_inner.py | 210 +++++++++ api/services/errors/knowledge_retrieval.py | 49 ++ api/services/external_knowledge_service.py | 54 ++- .../knowledge_retrieval_inner_service.py | 145 ++++++ .../agent_backend/test_request_builder.py | 40 ++ .../controllers/inner_api/test_auth_wraps.py | 52 +++ .../inner_api/test_knowledge_retrieval.py | 233 ++++++++++ .../agent_app/test_runtime_request_builder.py | 34 ++ ...test_dataset_retrieval_attachment_entry.py | 36 ++ .../agent_v2/test_runtime_request_builder.py | 146 ++++++ .../services/test_external_dataset_service.py | 83 +++- .../test_knowledge_retrieval_inner_service.py | 218 +++++++++ .../layers/execution_context/__init__.py | 3 +- .../layers/execution_context/configs.py | 10 +- .../layers/execution_context/layer.py | 5 +- .../dify_agent/layers/knowledge/__init__.py | 27 ++ .../src/dify_agent/layers/knowledge/client.py | 214 +++++++++ .../dify_agent/layers/knowledge/configs.py | 200 +++++++++ .../src/dify_agent/layers/knowledge/layer.py | 285 ++++++++++++ .../dify_agent/runtime/compositor_factory.py | 34 +- .../src/dify_agent/runtime/run_scheduler.py | 4 + dify-agent/src/dify_agent/runtime/runner.py | 
23 +- dify-agent/src/dify_agent/server/app.py | 58 ++- dify-agent/src/dify_agent/server/settings.py | 38 +- .../dify_agent/layers/knowledge/__init__.py | 0 .../layers/knowledge/test_client.py | 248 +++++++++++ .../layers/knowledge/test_configs.py | 65 +++ .../dify_agent/layers/knowledge/test_layer.py | 417 ++++++++++++++++++ .../dify_agent/runtime/test_run_scheduler.py | 20 +- .../local/dify_agent/runtime/test_runner.py | 138 ++++++ .../tests/local/dify_agent/server/test_app.py | 126 +++++- .../local/dify_agent/server/test_settings.py | 9 +- .../dify_agent/test_import_boundaries.py | 4 + 43 files changed, 3360 insertions(+), 133 deletions(-) create mode 100644 api/controllers/inner_api/knowledge/__init__.py create mode 100644 api/controllers/inner_api/knowledge/retrieval.py create mode 100644 api/services/entities/knowledge_retrieval_inner.py create mode 100644 api/services/errors/knowledge_retrieval.py create mode 100644 api/services/knowledge_retrieval_inner_service.py create mode 100644 api/tests/unit_tests/controllers/inner_api/test_knowledge_retrieval.py create mode 100644 api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_attachment_entry.py create mode 100644 api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py create mode 100644 dify-agent/src/dify_agent/layers/knowledge/__init__.py create mode 100644 dify-agent/src/dify_agent/layers/knowledge/client.py create mode 100644 dify-agent/src/dify_agent/layers/knowledge/configs.py create mode 100644 dify-agent/src/dify_agent/layers/knowledge/layer.py create mode 100644 dify-agent/tests/local/dify_agent/layers/knowledge/__init__.py create mode 100644 dify-agent/tests/local/dify_agent/layers/knowledge/test_client.py create mode 100644 dify-agent/tests/local/dify_agent/layers/knowledge/test_configs.py create mode 100644 dify-agent/tests/local/dify_agent/layers/knowledge/test_layer.py diff --git a/api/clients/agent_backend/__init__.py b/api/clients/agent_backend/__init__.py index 
238c48a9de..b9032c521e 100644 --- a/api/clients/agent_backend/__init__.py +++ b/api/clients/agent_backend/__init__.py @@ -33,6 +33,7 @@ from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAge from clients.agent_backend.request_builder import ( AGENT_SOUL_PROMPT_LAYER_ID, DIFY_EXECUTION_CONTEXT_LAYER_ID, + DIFY_KNOWLEDGE_BASE_LAYER_ID, DIFY_PLUGIN_TOOLS_LAYER_ID, WORKFLOW_NODE_JOB_PROMPT_LAYER_ID, WORKFLOW_USER_PROMPT_LAYER_ID, @@ -47,6 +48,7 @@ from clients.agent_backend.request_builder import ( __all__ = [ "AGENT_SOUL_PROMPT_LAYER_ID", "DIFY_EXECUTION_CONTEXT_LAYER_ID", + "DIFY_KNOWLEDGE_BASE_LAYER_ID", "DIFY_PLUGIN_TOOLS_LAYER_ID", "WORKFLOW_NODE_JOB_PROMPT_LAYER_ID", "WORKFLOW_USER_PROMPT_LAYER_ID", diff --git a/api/clients/agent_backend/request_builder.py b/api/clients/agent_backend/request_builder.py index 55944929dd..c245a09e97 100644 --- a/api/clients/agent_backend/request_builder.py +++ b/api/clients/agent_backend/request_builder.py @@ -32,6 +32,7 @@ from dify_agent.layers.execution_context import ( DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, DifyExecutionContextLayerConfig, ) +from dify_agent.layers.knowledge import DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, DifyKnowledgeBaseLayerConfig from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig from dify_agent.layers.shell import DIFY_SHELL_LAYER_TYPE_ID, DifyShellLayerConfig from dify_agent.protocol import ( @@ -55,6 +56,7 @@ AGENT_APP_USER_PROMPT_LAYER_ID = "agent_app_user_prompt" DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context" DIFY_DRIVE_LAYER_ID = "drive" DIFY_PLUGIN_TOOLS_LAYER_ID = "tools" +DIFY_KNOWLEDGE_BASE_LAYER_ID = "knowledge" DIFY_ASK_HUMAN_LAYER_ID = "ask_human" DIFY_SHELL_LAYER_ID = "shell" @@ -139,6 +141,7 @@ class AgentBackendWorkflowNodeRunInput(BaseModel): idempotency_key: str | None = None output: AgentBackendOutputConfig | None = None tools: DifyPluginToolsLayerConfig | None = None + knowledge: DifyKnowledgeBaseLayerConfig | None = None # 
Drive Skills & Files declaration (dify.drive) — an index the agent pulls # through the back proxy, never inline content; see AGENT_DRIVE_MANIFEST_ENABLED. drive_config: DifyDriveLayerConfig | None = None @@ -185,6 +188,7 @@ class AgentBackendAgentAppRunInput(BaseModel): idempotency_key: str | None = None output: AgentBackendOutputConfig | None = None tools: DifyPluginToolsLayerConfig | None = None + knowledge: DifyKnowledgeBaseLayerConfig | None = None # Drive Skills & Files declaration (dify.drive) — an index the agent pulls # through the back proxy, never inline content; see AGENT_DRIVE_MANIFEST_ENABLED. drive_config: DifyDriveLayerConfig | None = None @@ -221,7 +225,7 @@ class AgentBackendRunRequestBuilder: Layer graph: optional Agent Soul system prompt → user prompt → execution context → optional history (multi-turn) → LLM → optional - plugin tools → optional structured output. Mirrors the workflow-node + plugin tools / knowledge search → optional structured output. Mirrors the workflow-node layer ordering minus the workflow-job / previous-node prompt. """ layers: list[RunLayerSpec] = [] @@ -300,6 +304,17 @@ class AgentBackendRunRequestBuilder: ) ) + if run_input.knowledge is not None and run_input.knowledge.dataset_ids: + layers.append( + RunLayerSpec( + name=DIFY_KNOWLEDGE_BASE_LAYER_ID, + type=DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, + deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}, + metadata=run_input.metadata, + config=run_input.knowledge, + ) + ) + if run_input.ask_human_config is not None: # Human-in-the-loop ask_human deferred tool (dify.ask_human). 
A call ends # the run with a deferred_tool_call; the caller pauses (workflow HITL) and @@ -398,7 +413,12 @@ class AgentBackendRunRequestBuilder: ) def build_for_workflow_node(self, run_input: AgentBackendWorkflowNodeRunInput) -> CreateRunRequest: - """Build a workflow Agent Node run request without defining another wire schema.""" + """Build a workflow Agent Node run request without defining another wire schema. + + Layer graph mirrors the workflow surface: prompts → execution context → + optional drive/history → LLM → optional plugin tools / knowledge search + → optional auxiliary layers such as ask_human, shell, and structured output. + """ layers: list[RunLayerSpec] = [] if run_input.agent_soul_prompt: layers.append( @@ -483,6 +503,17 @@ class AgentBackendRunRequestBuilder: ) ) + if run_input.knowledge is not None and run_input.knowledge.dataset_ids: + layers.append( + RunLayerSpec( + name=DIFY_KNOWLEDGE_BASE_LAYER_ID, + type=DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, + deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}, + metadata=run_input.metadata, + config=run_input.knowledge, + ) + ) + if run_input.ask_human_config is not None: # Human-in-the-loop ask_human deferred tool (dify.ask_human). A call ends # the run with a deferred_tool_call; the caller pauses (workflow HITL) and diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index c782c93ffd..c0e079eeb2 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -9,7 +9,7 @@ api = ExternalApi( bp, version="1.0", title="Inner API", - description="Internal APIs for enterprise features, billing, and plugin communication", + description="Internal APIs for enterprise features, billing, knowledge retrieval, and plugin communication", ) # Create namespace @@ -17,6 +17,7 @@ inner_api_ns = Namespace("inner_api", description="Internal API operations", pat from . 
import mail as _mail from .app import dsl as _app_dsl +from .knowledge import retrieval as _knowledge_retrieval from .plugin import agent_drive as _agent_drive from .plugin import plugin as _plugin from .workspace import workspace as _workspace @@ -26,6 +27,7 @@ api.add_namespace(inner_api_ns) __all__ = [ "_agent_drive", "_app_dsl", + "_knowledge_retrieval", "_mail", "_plugin", "_workspace", diff --git a/api/controllers/inner_api/knowledge/__init__.py b/api/controllers/inner_api/knowledge/__init__.py new file mode 100644 index 0000000000..20c447fa77 --- /dev/null +++ b/api/controllers/inner_api/knowledge/__init__.py @@ -0,0 +1 @@ +"""Inner knowledge retrieval endpoints.""" diff --git a/api/controllers/inner_api/knowledge/retrieval.py b/api/controllers/inner_api/knowledge/retrieval.py new file mode 100644 index 0000000000..ef33fbda51 --- /dev/null +++ b/api/controllers/inner_api/knowledge/retrieval.py @@ -0,0 +1,110 @@ +"""Inner API endpoint for tenant-scoped knowledge retrieval. + +This controller is a thin HTTP wrapper around +``services.knowledge_retrieval_inner_service.InnerKnowledgeRetrievalService``. +It intentionally keeps authorization simple: shared inner API key plus +tenant-scoped app/dataset validation in the service layer. 
+""" + +from flask_restx import Resource +from pydantic import ValidationError + +from controllers.common.schema import register_response_schema_models, register_schema_models +from controllers.inner_api import inner_api_ns +from controllers.inner_api.wraps import inner_api_only +from core.workflow.nodes.knowledge_retrieval import exc as retrieval_exc +from libs.exception import BaseHTTPException +from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveRequest, InnerKnowledgeRetrieveResponse +from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError, InnerKnowledgeRetrievalServiceError +from services.knowledge_retrieval_inner_service import InnerKnowledgeRetrievalService + + +class InnerKnowledgeRetrievalHttpError(BaseHTTPException): + error_code = "knowledge_retrieve_failed" + description = "Knowledge retrieval failed." + code = 500 + + def __init__( + self, + *, + error_code: str | None = None, + description: str | None = None, + status_code: int | None = None, + ) -> None: + if error_code is not None: + self.error_code = error_code + if description is not None: + self.description = description + if status_code is not None: + self.code = status_code + super().__init__(self.description) + + +register_schema_models(inner_api_ns, InnerKnowledgeRetrieveRequest) +register_response_schema_models(inner_api_ns, InnerKnowledgeRetrieveResponse) + + +@inner_api_ns.route("/knowledge/retrieve") +class InnerKnowledgeRetrieveApi(Resource): + """Retrieve knowledge from one or more datasets within the caller tenant.""" + + @inner_api_only + @inner_api_ns.doc("inner_knowledge_retrieve") + @inner_api_ns.doc(description="Retrieve knowledge for trusted internal callers") + @inner_api_ns.expect(inner_api_ns.models[InnerKnowledgeRetrieveRequest.__name__]) + @inner_api_ns.response( + 200, + "Knowledge retrieved successfully", + inner_api_ns.models[InnerKnowledgeRetrieveResponse.__name__], + ) + @inner_api_ns.doc( + responses={ + 400: "Invalid 
request body", + 401: "Unauthorized - invalid inner API key", + 403: "Caller tenant does not own the requested resource", + 404: "App or dataset not found", + 422: "Invalid retrieval configuration", + 429: "Knowledge retrieval rate limited", + 502: "External knowledge retrieval failed", + 500: "Unexpected knowledge retrieval failure", + } + ) + def post(self) -> dict[str, object]: + """Validate the payload, run retrieval, and return workflow-style sources.""" + try: + payload = InnerKnowledgeRetrieveRequest.model_validate(inner_api_ns.payload or {}) + except ValidationError as exc: + raise InnerKnowledgeRetrievalHttpError( + error_code="invalid_request", + description=str(exc), + status_code=400, + ) from exc + + try: + response = InnerKnowledgeRetrievalService().retrieve(payload) + except InnerKnowledgeRetrievalServiceError as exc: + raise InnerKnowledgeRetrievalHttpError( + error_code=exc.error_code, + description=exc.description, + status_code=exc.status_code, + ) from exc + except retrieval_exc.RateLimitExceededError as exc: + raise InnerKnowledgeRetrievalHttpError( + error_code="knowledge_rate_limited", + description=str(exc), + status_code=429, + ) from exc + except ExternalKnowledgeRetrievalError as exc: + raise InnerKnowledgeRetrievalHttpError( + error_code="external_knowledge_failed", + description=str(exc), + status_code=502, + ) from exc + except ValueError as exc: + raise InnerKnowledgeRetrievalHttpError( + error_code="retrieval_config_invalid", + description=str(exc), + status_code=422, + ) from exc + + return response.model_dump(mode="json", by_alias=True) diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 95181b93cf..999932c98e 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -8,39 +8,39 @@ from flask import abort, request from configs import dify_config from core.db.session_factory import session_factory +from libs.exception import BaseHTTPException from models.model 
import EndUser -def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]: +class InnerApiUnauthorizedError(BaseHTTPException): + error_code = "inner_api_unauthorized" + description = "Unauthorized." + code = 401 + + +def inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]: + """Restrict access to callers authenticated with the shared inner API key.""" + @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs) -> R: if not dify_config.INNER_API: abort(404) - # get header 'X-Inner-Api-Key' inner_api_key = request.headers.get("X-Inner-Api-Key") if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: - abort(401) + raise InnerApiUnauthorizedError() return view(*args, **kwargs) return decorated +def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]: + return inner_api_only(view) + + def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]: - @wraps(view) - def decorated(*args: P.args, **kwargs: P.kwargs) -> R: - if not dify_config.INNER_API: - abort(404) - - # get header 'X-Inner-Api-Key' - inner_api_key = request.headers.get("X-Inner-Api-Key") - if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: - abort(401) - - return view(*args, **kwargs) - - return decorated + return inner_api_only(view) def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]: diff --git a/api/core/app/apps/agent_app/runtime_request_builder.py b/api/core/app/apps/agent_app/runtime_request_builder.py index 71cc0385f9..01206b12db 100644 --- a/api/core/app/apps/agent_app/runtime_request_builder.py +++ b/api/core/app/apps/agent_app/runtime_request_builder.py @@ -2,8 +2,10 @@ Mirrors the workflow ``WorkflowAgentRuntimeRequestBuilder`` but for the Agent App surface: the user prompt is the chat message (no workflow-node job / no -previous-node context), and multi-turn continuity flows through the -conversation-keyed ``session_snapshot`` plus the history layer. 
+previous-node context), multi-turn continuity flows through the +conversation-keyed ``session_snapshot`` plus the history layer, and Agent Soul +knowledge config is mapped into the same fixed ``dify.knowledge_base`` layer +used by workflow runs. """ from __future__ import annotations @@ -36,6 +38,7 @@ from core.workflow.nodes.agent_v2.runtime_request_builder import ( append_runtime_warnings, build_ask_human_layer_config, build_drive_layer_config, + build_knowledge_layer_config, build_shell_layer_config, ) from models.agent_config_entities import AgentSoulConfig @@ -123,6 +126,7 @@ class AgentAppRuntimeRequestBuilder: if dify_config.AGENT_DRIVE_MANIFEST_ENABLED: drive_config, drive_warnings = build_drive_layer_config(agent_soul, agent_id=context.agent_id) append_runtime_warnings(metadata, drive_warnings) + knowledge_config = build_knowledge_layer_config(agent_soul) request = self._request_builder.build_for_agent_app( AgentBackendAgentAppRunInput( @@ -156,6 +160,7 @@ class AgentAppRuntimeRequestBuilder: or None, user_prompt=context.user_query, tools=tools_layer, + knowledge=knowledge_config, drive_config=drive_config, ask_human_config=build_ask_human_layer_config(agent_soul), include_shell=dify_config.AGENT_SHELL_ENABLED, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index f4e850d34e..474c9f90c7 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -123,7 +123,7 @@ class DatasetRetrieval: if not available_datasets_ids: return [] - if not request.query: + if not request.query and not request.attachment_ids: return [] metadata_filter_document_ids, metadata_condition = None, None diff --git a/api/core/workflow/nodes/agent_v2/runtime_feature_manifest.py b/api/core/workflow/nodes/agent_v2/runtime_feature_manifest.py index 8e0578d1a1..65c5d42e91 100644 --- a/api/core/workflow/nodes/agent_v2/runtime_feature_manifest.py +++ 
b/api/core/workflow/nodes/agent_v2/runtime_feature_manifest.py @@ -13,6 +13,7 @@ SUPPORTED_AGENT_BACKEND_FEATURES = frozenset( "structured_output", "tools.dify_tools", "tools.cli_tools", + "knowledge", "env", "sandbox", # ENG-623: exposed at runtime as the dify.drive declaration layer @@ -26,7 +27,6 @@ SUPPORTED_AGENT_BACKEND_FEATURES = frozenset( RESERVED_AGENT_BACKEND_FEATURES = frozenset( { - "knowledge", "memory", } ) @@ -80,6 +80,9 @@ def build_runtime_feature_manifest( ) reserved_status = dict.fromkeys(sorted(RESERVED_AGENT_BACKEND_FEATURES), "reserved_not_executed") + reserved_status["knowledge"] = ( + "supported_by_knowledge_layer" if list_configured_knowledge_dataset_ids(agent_soul) else "not_configured" + ) reserved_status["skills_files"] = ( "supported_by_drive_manifest" if drive_manifest_enabled else "drive_manifest_disabled" ) @@ -97,6 +100,17 @@ def build_runtime_feature_manifest( } +def list_configured_knowledge_dataset_ids(agent_soul: AgentSoulConfig) -> list[str]: + """Return the normalized knowledge dataset ids that can produce a runtime layer. + + ``build_runtime_feature_manifest()`` and ``build_knowledge_layer_config()`` + must stay aligned: both decide knowledge support from this effective, + non-blank dataset-id set rather than from raw + ``agent_soul.knowledge.datasets`` entries. 
+ """ + return [dataset_id for dataset in agent_soul.knowledge.datasets if (dataset_id := (dataset.id or "").strip())] + + def _get_nested(value: dict[str, Any], path: str) -> Any: current: Any = value for part in path.split("."): diff --git a/api/core/workflow/nodes/agent_v2/runtime_request_builder.py b/api/core/workflow/nodes/agent_v2/runtime_request_builder.py index 53c657e8ef..8aaa4fcc1d 100644 --- a/api/core/workflow/nodes/agent_v2/runtime_request_builder.py +++ b/api/core/workflow/nodes/agent_v2/runtime_request_builder.py @@ -16,6 +16,7 @@ from dify_agent.layers.execution_context import ( DifyExecutionContextLayerConfig, DifyExecutionContextUserFrom, ) +from dify_agent.layers.knowledge import DifyKnowledgeBaseLayerConfig, DifyKnowledgeRetrievalConfig from dify_agent.layers.shell import ( DifyShellCliToolConfig, DifyShellEnvVarConfig, @@ -40,6 +41,7 @@ from graphon.file import FileTransferMethod from graphon.variables.segments import Segment from models.agent import Agent, AgentConfigSnapshot, WorkflowAgentNodeBinding from models.agent_config_entities import ( + AgentKnowledgeQueryConfig, AgentSoulConfig, DeclaredArrayItem, DeclaredOutputChildConfig, @@ -60,6 +62,7 @@ from services.agent.prompt_mentions import ( from .output_failure_orchestrator import retry_idempotency_key from .plugin_tools_builder import WorkflowAgentPluginToolsBuilder, WorkflowAgentPluginToolsBuildError +from .runtime_feature_manifest import build_runtime_feature_manifest, list_configured_knowledge_dataset_ids _DENIED_PERMISSION_STATUSES = frozenset({"unauthorized", "denied", "forbidden", "invalid", "unavailable"}) _DANGEROUS_FLAG_KEYS = ("dangerous", "dangerous_command", "requires_confirmation") @@ -69,7 +72,6 @@ _DANGEROUS_ACK_KEYS = ( "risk_accepted", "approved", ) -from .runtime_feature_manifest import build_runtime_feature_manifest class WorkflowAgentRuntimeRequestBuildError(ValueError): @@ -183,6 +185,7 @@ class WorkflowAgentRuntimeRequestBuilder: if 
dify_config.AGENT_DRIVE_MANIFEST_ENABLED: drive_config, drive_warnings = build_drive_layer_config(agent_soul, agent_id=context.agent.id) append_runtime_warnings(metadata, drive_warnings) + knowledge_config = build_knowledge_layer_config(agent_soul) request = self._request_builder.build_for_workflow_node( AgentBackendWorkflowNodeRunInput( @@ -197,10 +200,11 @@ class WorkflowAgentRuntimeRequestBuilder: model_settings=agent_soul.model.model_settings.model_dump(mode="json", exclude_none=True), ), # The execution-context layer is now the only public protocol - # carrier for Dify tenant/user/run identifiers. ``user_id`` must - # be forwarded here because downstream plugin-daemon provider and - # tool clients read it from this layer rather than from any - # parallel top-level request field. + # carrier for Dify tenant/user/run identifiers. ``user_id`` and + # ``user_from`` must be forwarded here because downstream plugin- + # daemon provider/tool clients and knowledge-base layers read + # caller identity from this layer rather than from any parallel + # top-level request field. execution_context=DifyExecutionContextLayerConfig( tenant_id=context.dify_context.tenant_id, user_id=context.dify_context.user_id, @@ -221,6 +225,7 @@ class WorkflowAgentRuntimeRequestBuilder: user_prompt=user_prompt, output=self._build_output_config(node_job.declared_outputs), tools=tools_layer, + knowledge=knowledge_config, drive_config=drive_config, ask_human_config=build_ask_human_layer_config(agent_soul), include_shell=dify_config.AGENT_SHELL_ENABLED, @@ -534,6 +539,45 @@ def build_shell_layer_config(agent_soul: AgentSoulConfig) -> DifyShellLayerConfi ) +def build_knowledge_layer_config(agent_soul: AgentSoulConfig) -> DifyKnowledgeBaseLayerConfig | None: + """Map Agent Soul knowledge config into the fixed Dify knowledge-base layer. 
+ + Normalization intentionally matches the current dify-agent runtime contract: + + - blank or missing dataset ids are ignored; + - if no valid dataset ids remain, no knowledge layer is injected; + - retrieval mode is always forced to ``multiple`` in this first wiring pass; + - ``top_k`` falls back to a stable runtime default when the soul omits it; + - ``score_threshold`` is only forwarded when the product config explicitly + enables it, otherwise the layer keeps the disabled/default ``0.0`` value; + - metadata filtering stays at the layer DTO default (disabled). + """ + dataset_ids = list_configured_knowledge_dataset_ids(agent_soul) + if not dataset_ids: + return None + + query_config = agent_soul.knowledge.query_config + return DifyKnowledgeBaseLayerConfig( + dataset_ids=dataset_ids, + retrieval=DifyKnowledgeRetrievalConfig( + mode="multiple", + top_k=_knowledge_top_k(query_config), + score_threshold=_knowledge_score_threshold(query_config), + ), + ) + + +def _knowledge_top_k(query_config: AgentKnowledgeQueryConfig) -> int: + top_k = query_config.top_k + return top_k if isinstance(top_k, int) and top_k >= 1 else 4 + + +def _knowledge_score_threshold(query_config: AgentKnowledgeQueryConfig) -> float: + if query_config.score_threshold_enabled and query_config.score_threshold is not None: + return query_config.score_threshold + return 0.0 + + def build_ask_human_layer_config(agent_soul: AgentSoulConfig) -> DifyAskHumanLayerConfig | None: """Enable the dify.ask_human deferred tool when the soul configures human involvement. diff --git a/api/services/entities/knowledge_retrieval_inner.py b/api/services/entities/knowledge_retrieval_inner.py new file mode 100644 index 0000000000..86276b8017 --- /dev/null +++ b/api/services/entities/knowledge_retrieval_inner.py @@ -0,0 +1,210 @@ +"""DTOs for the inner knowledge retrieval API. 
+ +These models define the stable HTTP contract for trusted internal callers and +the response shape returned by the workflow knowledge retrieval stack. + +Key cross-field invariants live here because callers cannot infer them from +scalar field types alone: ``dataset_ids`` must be non-empty, either ``query`` +or ``attachment_ids`` is required, ``single`` retrieval requires both ``query`` +and ``retrieval.model``, ``automatic`` metadata filtering requires +``model_config``, and ``manual`` metadata filtering requires conditions. The +response reuses workflow ``Source`` items plus serialized ``llm_usage``. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from core.rag.data_post_processor.data_post_processor import WeightsDict +from core.rag.entities.metadata_entities import SupportedComparisonOperator +from core.workflow.nodes.knowledge_retrieval.retrieval import Source +from fields.base import ResponseModel + +type JsonScalar = str | int | float | bool | None +type JsonValue = JsonScalar | list[JsonScalar] | dict[str, JsonScalar] +type MetadataValue = str | list[str] | int | float | None + + +class InnerKnowledgeRetrieveCaller(BaseModel): + """Execution context provided by the trusted internal caller.""" + + model_config = ConfigDict(extra="forbid") + + tenant_id: str = Field(min_length=1) + user_id: str = Field(min_length=1) + app_id: str = Field(min_length=1) + user_from: Literal["account", "end-user"] + invoke_from: str = Field(min_length=1) + + +class InnerKnowledgeRetrieveModelConfig(BaseModel): + """Model configuration used by single-retrieval or metadata filtering.""" + + model_config = ConfigDict(extra="forbid") + + provider: str = Field(min_length=1) + name: str = Field(min_length=1) + mode: str = Field(min_length=1) + completion_params: dict[str, JsonValue] = Field(default_factory=dict) + + +class 
InnerKnowledgeRetrieveRerankingModelConfig(BaseModel): + """Reranking model configuration for multiple retrieval mode.""" + + model_config = ConfigDict(extra="forbid") + + provider: str = Field(min_length=1) + model: str = Field(min_length=1) + + +class InnerKnowledgeRetrieveRetrievalConfig(BaseModel): + """Retrieval strategy and its mode-specific configuration.""" + + model_config = ConfigDict(extra="forbid") + + mode: Literal["multiple", "single"] + top_k: int | None = Field(default=None, ge=1) + score_threshold: float = 0.0 + reranking_mode: str = "reranking_model" + reranking_enable: bool = True + reranking_model: InnerKnowledgeRetrieveRerankingModelConfig | None = None + weights: WeightsDict | None = None + model: InnerKnowledgeRetrieveModelConfig | None = None + + @model_validator(mode="after") + def validate_mode_specific_fields(self) -> InnerKnowledgeRetrieveRetrievalConfig: + if self.mode == "single" and self.model is None: + raise ValueError("retrieval.model is required for single mode") + if self.mode == "multiple" and self.top_k is None: + raise ValueError("retrieval.top_k is required for multiple mode") + return self + + +class InnerKnowledgeRetrieveMetadataCondition(BaseModel): + """Single metadata filter condition.""" + + model_config = ConfigDict(extra="forbid") + + name: str = Field(min_length=1) + comparison_operator: SupportedComparisonOperator + value: MetadataValue = None + + +class InnerKnowledgeRetrieveMetadataConditions(BaseModel): + """Boolean composition for metadata filter conditions.""" + + model_config = ConfigDict(extra="forbid") + + logical_operator: Literal["and", "or"] | None = "and" + conditions: list[InnerKnowledgeRetrieveMetadataCondition] | None = None + + +class InnerKnowledgeRetrieveMetadataFilteringConfig(BaseModel): + """Metadata filtering configuration forwarded to workflow retrieval. + + ``automatic`` mode requires ``model_config`` so downstream metadata model + planning has the necessary LLM settings. 
``manual`` mode requires + non-empty conditions because workflow retrieval expects explicit filters + instead of a bare mode switch. + """ + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + mode: Literal["disabled", "automatic", "manual"] = "disabled" + metadata_model_config: InnerKnowledgeRetrieveModelConfig | None = Field(default=None, alias="model_config") + conditions: InnerKnowledgeRetrieveMetadataConditions | None = None + + @model_validator(mode="after") + def validate_mode_specific_fields(self) -> InnerKnowledgeRetrieveMetadataFilteringConfig: + if self.mode == "automatic" and self.metadata_model_config is None: + raise ValueError("metadata_filtering.model_config is required for automatic mode") + if self.mode == "manual" and (self.conditions is None or not self.conditions.conditions): + raise ValueError("metadata_filtering.conditions is required for manual mode") + return self + + +class InnerKnowledgeRetrieveRequest(BaseModel): + """Top-level request payload for the inner knowledge retrieval endpoint. + + Request validation enforces the endpoint's behavioral contract: callers + must provide at least one dataset ID, at least one of ``query`` or + ``attachment_ids``, and a text query for ``single`` retrieval mode. 
+ """ + + model_config = ConfigDict(extra="forbid") + + caller: InnerKnowledgeRetrieveCaller + dataset_ids: list[str] + query: str | None = None + retrieval: InnerKnowledgeRetrieveRetrievalConfig + metadata_filtering: InnerKnowledgeRetrieveMetadataFilteringConfig = Field( + default_factory=InnerKnowledgeRetrieveMetadataFilteringConfig + ) + attachment_ids: list[str] = Field(default_factory=list) + + @field_validator("dataset_ids", "attachment_ids") + @classmethod + def validate_non_empty_items(cls, value: list[str]) -> list[str]: + if any(not item.strip() for item in value): + raise ValueError("list items must not be empty") + return value + + @field_validator("query") + @classmethod + def normalize_query(cls, value: str | None) -> str | None: + if value is None: + return None + normalized = value.strip() + return normalized or None + + @model_validator(mode="after") + def validate_request(self) -> InnerKnowledgeRetrieveRequest: + if not self.dataset_ids: + raise ValueError("dataset_ids must contain at least one item") + if not self.query and not self.attachment_ids: + raise ValueError("query or attachment_ids is required") + if self.retrieval.mode == "single" and not self.query: + raise ValueError("query is required for single mode") + return self + + +class InnerKnowledgeRetrieveUsage(ResponseModel): + """Serialized LLM usage payload returned by dataset retrieval.""" + + model_config = ConfigDict( + from_attributes=True, + extra="forbid", + populate_by_name=True, + serialize_by_alias=True, + protected_namespaces=(), + ) + + prompt_tokens: int + completion_tokens: int + total_tokens: int + prompt_unit_price: str + completion_unit_price: str + prompt_price_unit: str + completion_price_unit: str + prompt_price: str + completion_price: str + total_price: str + currency: str | None = None + latency: float | int + + +class InnerKnowledgeRetrieveResponse(ResponseModel): + """Workflow-style retrieval results plus accumulated usage.""" + + model_config = ConfigDict( + 
class InnerKnowledgeRetrievalServiceError(BaseServiceError):
    """Root of the inner knowledge retrieval error family.

    Each subclass pins three class attributes that together form a stable
    HTTP mapping contract: ``error_code``, ``status_code`` and
    ``default_description`` (used when no explicit description is given).
    """

    error_code = "knowledge_retrieve_failed"
    status_code = 500
    default_description = "Knowledge retrieval failed."

    def __init__(self, description: str | None = None):
        # A missing or empty description falls back to the class-level default.
        message = description if description else self.default_description
        self.description = message
        # ValueError is initialised directly rather than through super();
        # presumably BaseServiceError derives from ValueError - confirm there.
        ValueError.__init__(self, message)


class InnerKnowledgeRetrieveAppNotFoundError(InnerKnowledgeRetrievalServiceError):
    """The caller app does not exist (404)."""

    error_code = "app_not_found"
    status_code = 404
    default_description = "App not found."


class InnerKnowledgeRetrieveAppTenantMismatchError(InnerKnowledgeRetrievalServiceError):
    """The caller app exists but belongs to another tenant (403)."""

    error_code = "app_tenant_mismatch"
    status_code = 403
    default_description = "App does not belong to caller tenant."


class InnerKnowledgeRetrieveDatasetNotFoundError(InnerKnowledgeRetrievalServiceError):
    """At least one requested dataset does not exist (404)."""

    error_code = "dataset_not_found"
    status_code = 404
    default_description = "Dataset not found."


class InnerKnowledgeRetrieveDatasetTenantMismatchError(InnerKnowledgeRetrievalServiceError):
    """At least one requested dataset belongs to another tenant (403)."""

    error_code = "dataset_tenant_mismatch"
    status_code = 403
    default_description = "Dataset does not belong to caller tenant."


class ExternalKnowledgeRetrievalError(ValueError):
    """Signals that an external dataset retrieval dependency failed.

    It remains a ``ValueError`` subclass so existing callers that treat
    external retrieval failures as generic retrieval errors keep working,
    while inner API controllers get a dedicated type they can map to
    ``502 external_knowledge_failed``.
    """
+ """ external_knowledge_binding = db.session.scalar( select(ExternalKnowledgeBindings) .where(ExternalKnowledgeBindings.dataset_id == dataset_id, ExternalKnowledgeBindings.tenant_id == tenant_id) .limit(1) ) if not external_knowledge_binding: - raise ValueError("external knowledge binding not found") + raise ExternalKnowledgeRetrievalError("external knowledge binding not found") external_knowledge_api = db.session.scalar( select(ExternalKnowledgeApis) @@ -326,7 +336,7 @@ class ExternalDatasetService: .limit(1) ) if external_knowledge_api is None or external_knowledge_api.settings is None: - raise ValueError("external api template not found") + raise ExternalKnowledgeRetrievalError("external api template not found") settings = json.loads(external_knowledge_api.settings) headers = {"Content-Type": "application/json"} @@ -344,16 +354,34 @@ class ExternalDatasetService: "metadata_condition": metadata_condition.model_dump() if metadata_condition else None, } - response = ExternalDatasetService.process_external_api( - ExternalKnowledgeApiSetting( - url=f"{settings.get('endpoint')}/retrieval", - request_method="post", - headers=headers, - params=request_params, - ), - None, - ) + try: + response = ExternalDatasetService.process_external_api( + ExternalKnowledgeApiSetting( + url=f"{settings.get('endpoint')}/retrieval", + request_method="post", + headers=headers, + params=request_params, + ), + None, + ) + except ExternalKnowledgeRetrievalError: + raise + except Exception as exc: + raise ExternalKnowledgeRetrievalError(str(exc)) from exc + if response.status_code == 200: - return cast(list[Any], response.json().get("records", [])) + try: + response_payload = response.json() + except Exception as exc: + raise ExternalKnowledgeRetrievalError("invalid external knowledge response") from exc + + if not isinstance(response_payload, dict): + raise ExternalKnowledgeRetrievalError("invalid external knowledge response") + + records = response_payload.get("records", []) + if not 
isinstance(records, list): + raise ExternalKnowledgeRetrievalError("invalid external knowledge response") + + return cast(list[Any], records) else: - raise ValueError(response.text) + raise ExternalKnowledgeRetrievalError(response.text) diff --git a/api/services/knowledge_retrieval_inner_service.py b/api/services/knowledge_retrieval_inner_service.py new file mode 100644 index 0000000000..fccc81c4a2 --- /dev/null +++ b/api/services/knowledge_retrieval_inner_service.py @@ -0,0 +1,145 @@ +"""Service wrapper for the inner knowledge retrieval API. + +This service keeps the internal HTTP contract small while reusing the workflow +retrieval stack in ``core.rag.retrieval.dataset_retrieval.DatasetRetrieval``. +The only authorization enforced here is tenant ownership of the caller app and +requested datasets. + +It intentionally does not check ``dataset.enable_api`` or user-level dataset +permissions. After the caller app and requested datasets pass tenant-scoped +prechecks, dataset availability and "no usable document" cases are delegated to +``DatasetRetrieval`` and may legitimately produce an empty result list instead +of a separate validation error. 
class InnerKnowledgeRetrievalService:
    """Check the caller's tenant scope, then hand off to workflow dataset retrieval."""

    def retrieve(self, request: InnerKnowledgeRetrieveRequest) -> InnerKnowledgeRetrieveResponse:
        """Run tenant-scoped retrieval on behalf of a trusted internal caller.

        Only two things are rejected here: a caller app that is missing or
        owned by another tenant, and requested datasets that are missing or
        owned by another tenant. ``dataset.enable_api``, user-level dataset
        permissions and availability / no-usable-document handling are left
        to ``DatasetRetrieval`` on purpose, so this API follows workflow
        retrieval semantics - including an empty result list when datasets
        exist but yield nothing retrievable.

        Raises:
            InnerKnowledgeRetrieveAppNotFoundError: The caller app does not exist.
            InnerKnowledgeRetrieveAppTenantMismatchError: The caller app is outside the caller tenant.
            InnerKnowledgeRetrieveDatasetNotFoundError: At least one requested dataset does not exist.
            InnerKnowledgeRetrieveDatasetTenantMismatchError:
                At least one requested dataset is outside the caller tenant.
        """
        caller = request.caller
        self._validate_caller_app(tenant_id=caller.tenant_id, app_id=caller.app_id)
        self._validate_datasets(tenant_id=caller.tenant_id, dataset_ids=request.dataset_ids)

        retriever = DatasetRetrieval()
        sources = retriever.knowledge_retrieval(request=self._to_rag_request(request))
        # Usage is read from the retriever only after the retrieval call has run.
        usage = InnerKnowledgeRetrieveUsage.model_validate(jsonable_encoder(retriever.llm_usage))
        return InnerKnowledgeRetrieveResponse(results=sources, usage=usage)

    def _validate_caller_app(self, *, tenant_id: str, app_id: str) -> None:
        """Ensure the caller app exists and is owned by ``tenant_id``."""
        lookup = select(App).where(App.id == app_id).limit(1)
        caller_app = db.session.scalar(lookup)

        if caller_app is None:
            raise InnerKnowledgeRetrieveAppNotFoundError(f"App '{app_id}' not found")
        if caller_app.tenant_id == tenant_id:
            return
        raise InnerKnowledgeRetrieveAppTenantMismatchError(
            f"App '{app_id}' does not belong to tenant '{tenant_id}'"
        )

    def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str]) -> None:
        """Ensure every requested dataset exists and is owned by ``tenant_id``.

        Missing datasets are reported before tenant mismatches; both error
        messages list the offending IDs in sorted order.
        """
        rows = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()

        known_ids = {row.id for row in rows}
        absent_ids = sorted(set(dataset_ids).difference(known_ids))
        if absent_ids:
            raise InnerKnowledgeRetrieveDatasetNotFoundError(f"Datasets not found: {', '.join(absent_ids)}")

        foreign_ids = sorted(row.id for row in rows if row.tenant_id != tenant_id)
        if foreign_ids:
            raise InnerKnowledgeRetrieveDatasetTenantMismatchError(
                f"Datasets do not belong to tenant '{tenant_id}': {', '.join(foreign_ids)}"
            )

    def _to_rag_request(self, request: InnerKnowledgeRetrieveRequest) -> KnowledgeRetrievalRequest:
        """Translate the inner API payload into a workflow ``KnowledgeRetrievalRequest``."""
        caller = request.caller
        retrieval = request.retrieval
        filtering = request.metadata_filtering

        # Metadata-filter model: re-validated into the runtime ModelConfig type.
        filter_model = filtering.metadata_model_config
        rag_filter_model = None
        if filter_model:
            rag_filter_model = ModelConfig.model_validate(filter_model.model_dump(mode="python"))

        # Manual metadata conditions: converted item by item, keeping None as None.
        filter_conditions = filtering.conditions
        rag_filter_conditions = None
        if filter_conditions is not None:
            converted = None
            if filter_conditions.conditions is not None:
                converted = [
                    Condition(
                        name=item.name,
                        comparison_operator=item.comparison_operator,
                        value=item.value,
                    )
                    for item in filter_conditions.conditions
                ]
            rag_filter_conditions = MetadataFilteringCondition(
                logical_operator=filter_conditions.logical_operator,
                conditions=converted,
            )

        # Reranking model travels as a plain dict keyed by the workflow field names.
        reranker = retrieval.reranking_model
        rag_reranker = None
        if reranker is not None:
            rag_reranker = {
                "reranking_provider_name": reranker.provider,
                "reranking_model_name": reranker.model,
            }

        llm = retrieval.model
        return KnowledgeRetrievalRequest(
            tenant_id=caller.tenant_id,
            user_id=caller.user_id,
            app_id=caller.app_id,
            user_from=caller.user_from,
            dataset_ids=request.dataset_ids,
            query=request.query,
            retrieval_mode=retrieval.mode,
            model_provider=llm.provider if llm else None,
            completion_params=llm.completion_params if llm else None,
            model_mode=llm.mode if llm else None,
            model_name=llm.name if llm else None,
            metadata_model_config=rag_filter_model,
            metadata_filtering_conditions=rag_filter_conditions,
            metadata_filtering_mode=filtering.mode,
            # A missing/zero top_k is forwarded as 0.
            top_k=retrieval.top_k or 0,
            score_threshold=retrieval.score_threshold,
            reranking_mode=retrieval.reranking_mode,
            reranking_model=rag_reranker,
            weights=retrieval.weights,
            reranking_enable=retrieval.reranking_enable,
            # An empty attachment list is forwarded as None.
            attachment_ids=request.attachment_ids or None,
        )
def test_request_builder_adds_knowledge_layer_when_configured():
    """A workflow-node run input with knowledge config yields a knowledge layer."""
    knowledge_settings = {
        "dataset_ids": ["dataset-1"],
        "retrieval": {"mode": "multiple", "top_k": 4},
    }
    run_input = _run_input()
    run_input.knowledge = DifyKnowledgeBaseLayerConfig.model_validate(knowledge_settings)

    request = AgentBackendRunRequestBuilder().build_for_workflow_node(run_input)

    layers_by_name = {}
    for layer in request.composition.layers:
        layers_by_name[layer.name] = layer

    assert DIFY_KNOWLEDGE_BASE_LAYER_ID in layers_by_name
    knowledge_layer = layers_by_name[DIFY_KNOWLEDGE_BASE_LAYER_ID]
    assert knowledge_layer.type == DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID
    # The knowledge layer depends on the execution context layer only.
    assert knowledge_layer.deps == {"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID}
    knowledge_config = cast(DifyKnowledgeBaseLayerConfig, knowledge_layer.config)
    assert knowledge_config.dataset_ids == ["dataset-1"]
class TestInnerApiOnly:
    """Behavioural checks for the ``inner_api_only`` decorator."""

    @staticmethod
    def _guarded_view():
        """Return a trivial view function wrapped with ``inner_api_only``."""

        @inner_api_only
        def protected_view():
            return "success"

        return protected_view

    def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask):
        view = self._guarded_view()

        with (
            app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}),
            patch.object(dify_config, "INNER_API", True),
            patch.object(dify_config, "INNER_API_KEY", "valid_key"),
        ):
            outcome = view()

        assert outcome == "success"

    def test_should_return_404_when_inner_api_disabled(self, app: Flask):
        view = self._guarded_view()

        with (
            app.test_request_context(),
            patch.object(dify_config, "INNER_API", False),
            pytest.raises(HTTPException) as exc_info,
        ):
            view()

        assert exc_info.value.code == 404

    def test_should_return_401_when_api_key_missing(self, app: Flask):
        view = self._guarded_view()

        with (
            app.test_request_context(headers={}),
            patch.object(dify_config, "INNER_API", True),
            patch.object(dify_config, "INNER_API_KEY", "valid_key"),
            pytest.raises(HTTPException) as exc_info,
        ):
            view()

        assert exc_info.value.code == 401

    def test_should_return_401_when_api_key_invalid(self, app: Flask):
        view = self._guarded_view()

        with (
            app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}),
            patch.object(dify_config, "INNER_API", True),
            patch.object(dify_config, "INNER_API_KEY", "valid_key"),
            pytest.raises(HTTPException) as exc_info,
        ):
            view()

        assert exc_info.value.code == 401
_RETRIEVE_ENDPOINT = "/inner/api/knowledge/retrieve"
_RETRIEVE_TARGET = "controllers.inner_api.knowledge.retrieval.InnerKnowledgeRetrievalService.retrieve"


@pytest.fixture
def inner_api_app() -> Flask:
    """Return a minimal Flask app that exposes only the inner API blueprint."""
    app = Flask(__name__)
    app.config["TESTING"] = True
    app.register_blueprint(inner_api_bp)
    return app


def _headers(api_key: str | None = "inner-key") -> dict[str, str]:
    """Build JSON request headers; ``api_key=None`` omits the inner API key header."""
    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["X-Inner-Api-Key"] = api_key
    return headers


def _payload() -> dict[str, object]:
    """Return a fresh, valid request body for the retrieve endpoint."""
    return {
        "caller": {
            "tenant_id": "tenant-1",
            "user_id": "user-1",
            "app_id": "app-1",
            "user_from": "account",
            "invoke_from": "workflow",
        },
        "dataset_ids": ["dataset-1"],
        "query": "reset password",
        "retrieval": {
            "mode": "multiple",
            "top_k": 4,
        },
        "metadata_filtering": {
            "mode": "disabled",
        },
        "attachment_ids": [],
    }


class TestInnerKnowledgeRetrieveApi:
    """HTTP-level contract tests for ``POST /inner/api/knowledge/retrieve``.

    Every request goes through ``_post``, which pins both ``INNER_API`` and
    ``INNER_API_KEY``. The missing-key test previously patched only
    ``INNER_API``, so its result depended on whatever ``INNER_API_KEY`` the
    ambient configuration happened to hold.
    """

    @staticmethod
    def _post(app: Flask, *, payload: dict[str, object] | None = None, api_key: str | None = "inner-key"):
        """POST to the retrieve endpoint with the inner API enabled.

        Args:
            app: Flask app carrying the inner API blueprint.
            payload: JSON body; defaults to a fresh ``_payload()``.
            api_key: Value for ``X-Inner-Api-Key``; ``None`` omits the header.

        Returns:
            The Flask test client response.
        """
        body = _payload() if payload is None else payload
        with (
            patch("configs.dify_config.INNER_API", True),
            patch("configs.dify_config.INNER_API_KEY", "inner-key"),
        ):
            return app.test_client().post(_RETRIEVE_ENDPOINT, json=body, headers=_headers(api_key=api_key))

    @staticmethod
    def _assert_error(response, status_code: int, code: str) -> None:
        """Assert both the HTTP status and the ``code`` field of an error response."""
        assert response.status_code == status_code
        assert response.get_json()["code"] == code

    def test_post_returns_401_when_api_key_missing(self, inner_api_app: Flask):
        response = self._post(inner_api_app, api_key=None)

        self._assert_error(response, 401, "inner_api_unauthorized")

    def test_post_returns_401_when_api_key_invalid(self, inner_api_app: Flask):
        response = self._post(inner_api_app, api_key="wrong-key")

        self._assert_error(response, 401, "inner_api_unauthorized")

    def test_post_returns_400_for_invalid_body(self, inner_api_app: Flask):
        response = self._post(inner_api_app, payload={"caller": {"tenant_id": "tenant-1"}})

        self._assert_error(response, 400, "invalid_request")

    @patch(_RETRIEVE_TARGET)
    def test_post_returns_404_for_service_not_found_error(self, mock_retrieve, inner_api_app: Flask):
        mock_retrieve.side_effect = InnerKnowledgeRetrieveAppNotFoundError("app missing")

        response = self._post(inner_api_app)

        self._assert_error(response, 404, "app_not_found")

    @patch(_RETRIEVE_TARGET)
    def test_post_returns_403_for_service_forbidden_error(self, mock_retrieve, inner_api_app: Flask):
        mock_retrieve.side_effect = InnerKnowledgeRetrieveDatasetTenantMismatchError("wrong tenant")

        response = self._post(inner_api_app)

        self._assert_error(response, 403, "dataset_tenant_mismatch")

    @patch(_RETRIEVE_TARGET)
    def test_post_returns_422_for_retrieval_config_value_error(self, mock_retrieve, inner_api_app: Flask):
        mock_retrieve.side_effect = ValueError("invalid reranking config")

        response = self._post(inner_api_app)

        self._assert_error(response, 422, "retrieval_config_invalid")

    @patch(_RETRIEVE_TARGET)
    def test_post_returns_429_for_rate_limit_error(self, mock_retrieve, inner_api_app: Flask):
        mock_retrieve.side_effect = RateLimitExceededError("knowledge rate limited")

        response = self._post(inner_api_app)

        self._assert_error(response, 429, "knowledge_rate_limited")

    def test_post_returns_400_for_manual_metadata_without_conditions(self, inner_api_app: Flask):
        payload = _payload()
        payload["metadata_filtering"] = {"mode": "manual"}

        response = self._post(inner_api_app, payload=payload)

        self._assert_error(response, 400, "invalid_request")

    def test_post_returns_400_for_automatic_metadata_without_model_config(self, inner_api_app: Flask):
        payload = _payload()
        payload["metadata_filtering"] = {"mode": "automatic"}

        response = self._post(inner_api_app, payload=payload)

        self._assert_error(response, 400, "invalid_request")

    @patch(_RETRIEVE_TARGET)
    def test_post_returns_502_for_external_knowledge_failure(self, mock_retrieve, inner_api_app: Flask):
        mock_retrieve.side_effect = ExternalKnowledgeRetrievalError("upstream failed")

        response = self._post(inner_api_app)

        self._assert_error(response, 502, "external_knowledge_failed")

    @patch(_RETRIEVE_TARGET)
    def test_post_returns_service_response(self, mock_retrieve, inner_api_app: Flask):
        mock_retrieve.return_value = InnerKnowledgeRetrieveResponse(
            results=[
                Source(
                    metadata=SourceMetadata(
                        dataset_id="dataset-1",
                        dataset_name="Docs",
                        document_id="document-1",
                        document_name="FAQ.md",
                        data_source_type="upload_file",
                    ),
                    title="FAQ.md",
                    files=[],
                    content="Reset your password from settings.",
                    summary=None,
                )
            ],
            usage=InnerKnowledgeRetrieveUsage(
                prompt_tokens=0,
                completion_tokens=0,
                total_tokens=0,
                prompt_unit_price="0",
                completion_unit_price="0",
                prompt_price_unit="0.001",
                completion_price_unit="0.001",
                prompt_price="0",
                completion_price="0",
                total_price="0",
                currency="USD",
                latency=0,
            ),
        )

        response = self._post(inner_api_app)

        assert response.status_code == 200
        data = response.get_json()
        assert data["results"][0]["metadata"]["_source"] == "knowledge"
        assert data["results"][0]["title"] == "FAQ.md"
        assert data["usage"]["total_tokens"] == 0
def test_build_maps_agent_soul_knowledge_to_knowledge_layer(self):
    """Agent Soul ``knowledge`` settings become a ``dify.knowledge_base`` layer.

    With ``score_threshold_enabled`` set to ``False`` the configured 0.5
    threshold is expected to surface as ``0.0`` in the layer config.
    """
    model_section = {
        "plugin_id": "langgenius/openai",
        "model_provider": "langgenius/openai/openai",
        "model": "gpt-4o-mini",
    }
    knowledge_section = {
        "datasets": [{"id": "dataset-1"}, {"id": "dataset-2"}],
        "query_config": {
            "top_k": 3,
            "score_threshold": 0.5,
            "score_threshold_enabled": False,
        },
    }
    soul = AgentSoulConfig.model_validate({"model": model_section, "knowledge": knowledge_section})
    builder = AgentAppRuntimeRequestBuilder(
        credentials_provider=_FakeCredentialsProvider(),
        plugin_tools_builder=_NoToolsBuilder(),  # type: ignore[arg-type]
    )

    result = builder.build(_ctx(soul))

    composed_layers = result.request.composition.layers
    knowledge = next(layer for layer in composed_layers if layer.name == "knowledge")
    assert knowledge.type == "dify.knowledge_base"
    assert knowledge.deps == {"execution_context": "execution_context"}

    dumped_config = knowledge.config.model_dump(mode="json", by_alias=True)
    assert dumped_config["dataset_ids"] == ["dataset-1", "dataset-2"]
    retrieval = dumped_config["retrieval"]
    assert retrieval["mode"] == "multiple"
    assert retrieval["top_k"] == 3
    assert retrieval["score_threshold"] == 0.0
def test_knowledge_retrieval_allows_attachment_only_requests() -> None:
    """A request with attachments but no text query still reaches ``multiple_retrieve``."""
    retrieval = DatasetRetrieval()
    request_fields = {
        "tenant_id": "tenant-1",
        "user_id": "user-1",
        "app_id": "app-1",
        "user_from": "account",
        "dataset_ids": ["dataset-1"],
        "query": None,
        "retrieval_mode": "multiple",
        "top_k": 4,
        "score_threshold": 0.0,
        "reranking_mode": "reranking_model",
        "reranking_enable": True,
        "attachment_ids": ["attachment-1"],
    }
    request = KnowledgeRetrievalRequest(**request_fields)
    dataset_stub = MagicMock(id="dataset-1")

    # Rate limiting and dataset lookup are stubbed; only the dispatch is under test.
    skip_rate_limit = patch.object(retrieval, "_check_knowledge_rate_limit")
    one_available_dataset = patch.object(retrieval, "_get_available_datasets", return_value=[dataset_stub])
    empty_multiple_retrieve = patch.object(retrieval, "multiple_retrieve", return_value=[])

    with skip_rate_limit, one_available_dataset, empty_multiple_retrieve as mock_multiple:
        result = retrieval.knowledge_retrieval(request)

    assert result == []
    mock_multiple.assert_called_once()
def test_build_knowledge_layer_uses_stable_default_top_k_when_query_config_omits_it():
    """An empty ``query_config`` results in ``top_k == 4`` on the knowledge layer."""
    soul_config = AgentSoulConfig.model_validate(
        {
            "prompt": {"system_prompt": "You are careful."},
            "model": {
                "plugin_id": "langgenius/openai",
                "model_provider": "openai",
                "model": "gpt-test",
            },
            "knowledge": {
                "datasets": [{"id": "dataset-1"}],
                "query_config": {},
            },
        }
    )
    snapshot = AgentConfigSnapshot(
        id="snapshot-1",
        tenant_id="tenant-1",
        agent_id="agent-1",
        version=1,
        config_snapshot=soul_config,
    )
    context = replace(_context(), snapshot=snapshot)

    builder = WorkflowAgentRuntimeRequestBuilder(credentials_provider=FakeCredentialsProvider())
    result = builder.build(context)

    dumped = result.request.model_dump(mode="json")
    dumped_layers = dumped["composition"]["layers"]
    knowledge_layer = next(layer for layer in dumped_layers if layer["name"] == "knowledge")
    assert knowledge_layer["config"]["retrieval"]["top_k"] == 4
+ manifest = build_runtime_feature_manifest(soul) + assert "knowledge" in manifest["supported"] + assert "knowledge" not in manifest["reserved"] + assert manifest["reserved_status"]["knowledge"] == "supported_by_knowledge_layer" + assert all("knowledge" not in w["section"] for w in manifest["unsupported_runtime_warnings"]) + + +def test_feature_manifest_treats_blank_knowledge_dataset_ids_as_not_configured(): + from core.workflow.nodes.agent_v2.runtime_feature_manifest import build_runtime_feature_manifest + + soul = AgentSoulConfig.model_validate( + { + "knowledge": { + "datasets": [{"id": " "}, {}], + } + } + ) + + manifest = build_runtime_feature_manifest(soul) + assert manifest["reserved_status"]["knowledge"] == "not_configured" diff --git a/api/tests/unit_tests/services/test_external_dataset_service.py b/api/tests/unit_tests/services/test_external_dataset_service.py index fdea0ba869..143283c0ae 100644 --- a/api/tests/unit_tests/services/test_external_dataset_service.py +++ b/api/tests/unit_tests/services/test_external_dataset_service.py @@ -21,6 +21,7 @@ from services.entities.external_knowledge_entities.external_knowledge_entities i ExternalKnowledgeApiSetting, ) from services.errors.dataset import DatasetNameDuplicateError +from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError from services.external_knowledge_service import ExternalDatasetService @@ -1558,7 +1559,7 @@ class TestExternalDatasetServiceFetchRetrieval: mock_db.session.scalar.return_value = None # Act & Assert - with pytest.raises(ValueError, match="external knowledge binding not found"): + with pytest.raises(ExternalKnowledgeRetrievalError, match="external knowledge binding not found"): ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {}) @patch("services.external_knowledge_service.db") @@ -1569,7 +1570,7 @@ class TestExternalDatasetServiceFetchRetrieval: mock_db.session.scalar.side_effect = [binding, None] # Act & Assert - 
with pytest.raises(ValueError, match="external api template not found"): + with pytest.raises(ExternalKnowledgeRetrievalError, match="external api template not found"): ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {}) @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @@ -1643,7 +1644,7 @@ class TestExternalDatasetServiceFetchRetrieval: mock_process.return_value = mock_response # Act & Assert - with pytest.raises(Exception, match="Internal Server Error: Database connection failed"): + with pytest.raises(ExternalKnowledgeRetrievalError, match="Internal Server Error: Database connection failed"): ExternalDatasetService.fetch_external_knowledge_retrieval( "tenant-123", "dataset-123", "query", {"top_k": 5} ) @@ -1684,7 +1685,7 @@ class TestExternalDatasetServiceFetchRetrieval: mock_process.return_value = mock_response # Act & Assert - with pytest.raises(ValueError, match=re.escape(error_message)): + with pytest.raises(ExternalKnowledgeRetrievalError, match=re.escape(error_message)): ExternalDatasetService.fetch_external_knowledge_retrieval(tenant_id, dataset_id, "query", {"top_k": 5}) @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") @@ -1703,7 +1704,79 @@ class TestExternalDatasetServiceFetchRetrieval: mock_process.return_value = mock_response # Act & Assert - with pytest.raises(ValueError): + with pytest.raises(ExternalKnowledgeRetrievalError): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) + + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_invalid_json_response(self, mock_db, mock_process, factory): + """Test malformed JSON success responses are normalized to external retrieval errors.""" + binding = 
factory.create_external_knowledge_binding_mock() + api = factory.create_external_knowledge_api_mock() + + mock_db.session.scalar.side_effect = [binding, api] + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + mock_process.return_value = mock_response + + with pytest.raises(ExternalKnowledgeRetrievalError, match="invalid external knowledge response"): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) + + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_invalid_success_payload_shape(self, mock_db, mock_process, factory): + """Test malformed success payload shapes are normalized to external retrieval errors.""" + binding = factory.create_external_knowledge_binding_mock() + api = factory.create_external_knowledge_api_mock() + + mock_db.session.scalar.side_effect = [binding, api] + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = ["not-a-dict"] + mock_process.return_value = mock_response + + with pytest.raises(ExternalKnowledgeRetrievalError, match="invalid external knowledge response"): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) + + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_invalid_records_shape(self, mock_db, mock_process, factory): + """Test non-list records payloads are normalized to external retrieval errors.""" + binding = factory.create_external_knowledge_binding_mock() + api = factory.create_external_knowledge_api_mock() + + mock_db.session.scalar.side_effect = [binding, api] + + mock_response = 
MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"records": {"unexpected": "shape"}} + mock_process.return_value = mock_response + + with pytest.raises(ExternalKnowledgeRetrievalError, match="invalid external knowledge response"): + ExternalDatasetService.fetch_external_knowledge_retrieval( + "tenant-123", "dataset-123", "query", {"top_k": 5} + ) + + @patch("services.external_knowledge_service.ExternalDatasetService.process_external_api") + @patch("services.external_knowledge_service.db") + def test_fetch_external_knowledge_retrieval_wraps_transport_errors(self, mock_db, mock_process, factory): + """Test transport/runtime failures are normalized to external retrieval errors.""" + binding = factory.create_external_knowledge_binding_mock() + api = factory.create_external_knowledge_api_mock() + + mock_db.session.scalar.side_effect = [binding, api] + mock_process.side_effect = RuntimeError("connection reset by peer") + + with pytest.raises(ExternalKnowledgeRetrievalError, match="connection reset by peer"): ExternalDatasetService.fetch_external_knowledge_retrieval( "tenant-123", "dataset-123", "query", {"top_k": 5} ) diff --git a/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py b/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py new file mode 100644 index 0000000000..287d787ad7 --- /dev/null +++ b/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py @@ -0,0 +1,218 @@ +"""Unit tests for the inner knowledge retrieval service.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from core.workflow.nodes.knowledge_retrieval.retrieval import Source, SourceMetadata +from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveRequest +from services.errors.knowledge_retrieval import ( + InnerKnowledgeRetrieveAppNotFoundError, + InnerKnowledgeRetrieveAppTenantMismatchError, + InnerKnowledgeRetrieveDatasetNotFoundError, + 
InnerKnowledgeRetrieveDatasetTenantMismatchError, +) +from services.knowledge_retrieval_inner_service import InnerKnowledgeRetrievalService + + +def _build_request(**overrides): + payload = { + "caller": { + "tenant_id": "tenant-1", + "user_id": "user-1", + "app_id": "app-1", + "user_from": "account", + "invoke_from": "workflow", + }, + "dataset_ids": ["dataset-1", "dataset-2"], + "query": "how to reset password", + "retrieval": { + "mode": "multiple", + "top_k": 4, + "score_threshold": 0.25, + "reranking_mode": "reranking_model", + "reranking_enable": True, + "reranking_model": { + "provider": "cohere", + "model": "rerank-english-v3.0", + }, + }, + "metadata_filtering": { + "mode": "manual", + "conditions": { + "logical_operator": "and", + "conditions": [ + { + "name": "category", + "comparison_operator": "contains", + "value": "pricing", + } + ], + }, + }, + "attachment_ids": ["attachment-1"], + } + payload.update(overrides) + return InnerKnowledgeRetrieveRequest.model_validate(payload) + + +def _build_source() -> Source: + return Source( + metadata=SourceMetadata( + dataset_id="dataset-1", + dataset_name="Docs", + document_id="document-1", + document_name="FAQ.md", + data_source_type="upload_file", + ), + title="FAQ.md", + files=[], + content="Reset your password from settings.", + summary=None, + ) + + +class TestInnerKnowledgeRetrievalService: + @patch("services.knowledge_retrieval_inner_service.DatasetRetrieval") + @patch("services.knowledge_retrieval_inner_service.db") + def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_db, mock_rag_cls): + request = _build_request() + mock_app = MagicMock(id="app-1", tenant_id="tenant-1") + dataset_1 = MagicMock(id="dataset-1", tenant_id="tenant-1", enable_api=False) + dataset_2 = MagicMock(id="dataset-2", tenant_id="tenant-1", enable_api=True) + mock_db.session.scalar.return_value = mock_app + mock_db.session.scalars.return_value.all.return_value = [dataset_1, dataset_2] + + rag = MagicMock() + 
rag.knowledge_retrieval.return_value = [_build_source()] + rag.llm_usage = { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "prompt_unit_price": "0", + "completion_unit_price": "0", + "prompt_price_unit": "0.001", + "completion_price_unit": "0.001", + "prompt_price": "0", + "completion_price": "0", + "total_price": "0", + "currency": "USD", + "latency": 0, + } + mock_rag_cls.return_value = rag + + response = InnerKnowledgeRetrievalService().retrieve(request) + + rag_request = rag.knowledge_retrieval.call_args.kwargs["request"] + assert rag_request.tenant_id == "tenant-1" + assert rag_request.app_id == "app-1" + assert rag_request.user_id == "user-1" + assert rag_request.dataset_ids == ["dataset-1", "dataset-2"] + assert rag_request.query == "how to reset password" + assert rag_request.retrieval_mode == "multiple" + assert rag_request.top_k == 4 + assert rag_request.score_threshold == 0.25 + assert rag_request.reranking_model == { + "reranking_provider_name": "cohere", + "reranking_model_name": "rerank-english-v3.0", + } + assert rag_request.metadata_filtering_mode == "manual" + assert rag_request.metadata_filtering_conditions is not None + metadata_conditions = rag_request.metadata_filtering_conditions.model_dump(mode="python") + assert metadata_conditions["logical_operator"] == "and" + assert metadata_conditions["conditions"] is not None + assert metadata_conditions["conditions"][0]["name"] == "category" + assert rag_request.attachment_ids == ["attachment-1"] + assert response.results[0].title == "FAQ.md" + assert response.usage.currency == "USD" + + @patch("services.knowledge_retrieval_inner_service.DatasetRetrieval") + @patch("services.knowledge_retrieval_inner_service.db") + def test_retrieve_maps_single_request(self, mock_db, mock_rag_cls): + request = _build_request( + dataset_ids=["dataset-1"], + retrieval={ + "mode": "single", + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "mode": "chat", + "completion_params": 
{"temperature": 0}, + }, + }, + metadata_filtering={ + "mode": "automatic", + "model_config": { + "provider": "openai", + "name": "gpt-4o-mini", + "mode": "chat", + "completion_params": {"temperature": 0}, + }, + }, + attachment_ids=[], + ) + mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] + + rag = MagicMock() + rag.knowledge_retrieval.return_value = [] + rag.llm_usage = { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3, + "prompt_unit_price": "0", + "completion_unit_price": "0", + "prompt_price_unit": "0.001", + "completion_price_unit": "0.001", + "prompt_price": "0", + "completion_price": "0", + "total_price": "0", + "currency": "USD", + "latency": 1, + } + mock_rag_cls.return_value = rag + + InnerKnowledgeRetrievalService().retrieve(request) + + rag_request = rag.knowledge_retrieval.call_args.kwargs["request"] + assert rag_request.retrieval_mode == "single" + assert rag_request.model_provider == "openai" + assert rag_request.model_name == "gpt-4o-mini" + assert rag_request.model_mode == "chat" + assert rag_request.completion_params == {"temperature": 0} + assert rag_request.metadata_filtering_mode == "automatic" + assert rag_request.metadata_model_config is not None + assert rag_request.metadata_model_config.provider == "openai" + + @patch("services.knowledge_retrieval_inner_service.db") + def test_retrieve_raises_when_app_missing(self, mock_db): + mock_db.session.scalar.return_value = None + + with pytest.raises(InnerKnowledgeRetrieveAppNotFoundError): + InnerKnowledgeRetrievalService().retrieve(_build_request()) + + @patch("services.knowledge_retrieval_inner_service.db") + def test_retrieve_raises_when_app_belongs_to_other_tenant(self, mock_db): + mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2") + + with pytest.raises(InnerKnowledgeRetrieveAppTenantMismatchError): + 
InnerKnowledgeRetrievalService().retrieve(_build_request()) + + @patch("services.knowledge_retrieval_inner_service.db") + def test_retrieve_raises_when_dataset_missing(self, mock_db): + mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] + + with pytest.raises(InnerKnowledgeRetrieveDatasetNotFoundError): + InnerKnowledgeRetrievalService().retrieve(_build_request()) + + @patch("services.knowledge_retrieval_inner_service.db") + def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self, mock_db): + mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_db.session.scalars.return_value.all.return_value = [ + MagicMock(id="dataset-1", tenant_id="tenant-1"), + MagicMock(id="dataset-2", tenant_id="tenant-2"), + ] + + with pytest.raises(InnerKnowledgeRetrieveDatasetTenantMismatchError): + InnerKnowledgeRetrievalService().retrieve(_build_request()) diff --git a/dify-agent/src/dify_agent/layers/execution_context/__init__.py b/dify-agent/src/dify_agent/layers/execution_context/__init__.py index f1534bceff..aaf031501b 100644 --- a/dify-agent/src/dify_agent/layers/execution_context/__init__.py +++ b/dify-agent/src/dify_agent/layers/execution_context/__init__.py @@ -2,7 +2,8 @@ Implementation layers live in sibling modules and require server-side runtime dependencies. Keep this package root import-safe for client code that only -needs to build run requests. +needs to build run requests. Knowledge layers read ``user_from`` from the same +DTO, but that runtime implementation still lives in sibling modules. 
""" from dify_agent.layers.execution_context.configs import ( diff --git a/dify-agent/src/dify_agent/layers/execution_context/configs.py b/dify-agent/src/dify_agent/layers/execution_context/configs.py index 2b042add7b..21abe9b948 100644 --- a/dify-agent/src/dify_agent/layers/execution_context/configs.py +++ b/dify-agent/src/dify_agent/layers/execution_context/configs.py @@ -4,9 +4,11 @@ This layer carries both Dify product execution context (tenant, user, workflow, invoke source) and Agent backend runtime mode. The product-facing fields are used by trusted server-side boundaries such as the Agent Stub when they need to reconstruct Dify API file-access scope without granting the sandbox any -direct inner-API credentials. Server-only plugin-daemon settings are injected -by the runtime provider factory and therefore do not appear in this public -schema. +direct inner-API credentials. Knowledge-base layers also read ``user_from`` from +this shared config so the inner Dify API can distinguish platform-user and +end-user searches without making that caller identity model-controlled. +Server-only plugin-daemon settings are injected by the runtime provider factory +and therefore do not appear in this public schema. """ from typing import ClassVar, Final, Literal, TypeAlias @@ -42,7 +44,7 @@ class DifyExecutionContextLayerConfig(LayerConfig): tenant_id: str user_id: str | None = None - user_from: DifyExecutionContextUserFrom + user_from: DifyExecutionContextUserFrom | None = None app_id: str | None = None workflow_id: str | None = None workflow_run_id: str | None = None diff --git a/dify-agent/src/dify_agent/layers/execution_context/layer.py b/dify-agent/src/dify_agent/layers/execution_context/layer.py index 06ef41ecf4..55a1920f53 100644 --- a/dify-agent/src/dify_agent/layers/execution_context/layer.py +++ b/dify-agent/src/dify_agent/layers/execution_context/layer.py @@ -1,7 +1,8 @@ """Runtime Dify execution-context layer. 
The public config carries Dify-owned execution identifiers plus the tenant/user -daemon context needed by plugin-backed business layers. Server-only daemon URL +daemon context needed by plugin-backed business layers and the caller identity +needed by knowledge-base layers. Server-only daemon URL and API key are injected by the provider factory. The layer is intentionally config/settings-only under Agenton's state-only core: it does not open, cache, close, or snapshot HTTP clients, and its lifecycle hooks remain the inherited @@ -29,7 +30,7 @@ from dify_agent.layers.execution_context.configs import ( class DifyExecutionContextLayer(PlainLayer[NoLayerDeps, DifyExecutionContextLayerConfig, EmptyRuntimeState]): """Layer that carries Dify execution context without owning live resources.""" - type_id: ClassVar[str] = DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID + type_id: ClassVar[str | None] = DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID config: DifyExecutionContextLayerConfig daemon_url: str diff --git a/dify-agent/src/dify_agent/layers/knowledge/__init__.py b/dify-agent/src/dify_agent/layers/knowledge/__init__.py new file mode 100644 index 0000000000..569512d800 --- /dev/null +++ b/dify-agent/src/dify_agent/layers/knowledge/__init__.py @@ -0,0 +1,27 @@ +"""Client-safe exports for Dify knowledge-base layer DTOs and type ids. + +Implementation layers and HTTP clients live in sibling modules so this package +root stays import-safe for callers that only need to construct run requests. 
+""" + +from dify_agent.layers.knowledge.configs import ( + DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, + DifyKnowledgeBaseLayerConfig, + DifyKnowledgeMetadataCondition, + DifyKnowledgeMetadataConditions, + DifyKnowledgeMetadataFilteringConfig, + DifyKnowledgeModelConfig, + DifyKnowledgeRerankingModelConfig, + DifyKnowledgeRetrievalConfig, +) + +__all__ = [ + "DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID", + "DifyKnowledgeBaseLayerConfig", + "DifyKnowledgeMetadataCondition", + "DifyKnowledgeMetadataConditions", + "DifyKnowledgeMetadataFilteringConfig", + "DifyKnowledgeModelConfig", + "DifyKnowledgeRerankingModelConfig", + "DifyKnowledgeRetrievalConfig", +] diff --git a/dify-agent/src/dify_agent/layers/knowledge/client.py b/dify-agent/src/dify_agent/layers/knowledge/client.py new file mode 100644 index 0000000000..b80e363190 --- /dev/null +++ b/dify-agent/src/dify_agent/layers/knowledge/client.py @@ -0,0 +1,214 @@ +"""Async client for the Dify API inner knowledge retrieval endpoint. + +This wrapper owns only request/response mapping and error normalization for +``POST /inner/api/knowledge/retrieve``. The shared ``httpx.AsyncClient`` is +supplied by the FastAPI lifespan/runtime and must stay open for the caller. 
+""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import ClassVar + +import httpx +from pydantic import BaseModel, ConfigDict, Field, JsonValue, ValidationError + +from dify_agent.layers.knowledge.configs import ( + DifyKnowledgeMetadataFilteringConfig, + DifyKnowledgeRetrievalConfig, +) + + +class DifyKnowledgeBaseClientError(RuntimeError): + """Raised when the inner knowledge retrieval HTTP boundary fails.""" + + status_code: int | None + error_code: str | None + retryable: bool + + def __init__( + self, + message: str, + *, + status_code: int | None = None, + error_code: str | None = None, + retryable: bool, + ) -> None: + self.status_code = status_code + self.error_code = error_code + self.retryable = retryable + super().__init__(message) + + +class _DifyKnowledgeCaller(BaseModel): + tenant_id: str + user_id: str + app_id: str + user_from: str + invoke_from: str + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class _DifyKnowledgeRetrieveRequest(BaseModel): + caller: _DifyKnowledgeCaller + dataset_ids: list[str] + query: str + retrieval: dict[str, JsonValue] + metadata_filtering: dict[str, JsonValue] + attachment_ids: list[str] = Field(default_factory=list) + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class DifyKnowledgeResultMetadata(BaseModel): + source: str | None = Field(default=None, alias="_source") + dataset_id: str | None = None + dataset_name: str | None = None + document_id: str | None = None + document_name: str | None = None + score: float | int | str | None = None + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow", populate_by_name=True) + + +class DifyKnowledgeResult(BaseModel): + metadata: DifyKnowledgeResultMetadata + title: str | None = None + files: list[JsonValue] | None = None + content: str | None = None + summary: str | None = None + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class 
DifyKnowledgeRetrieveResponse(BaseModel): + results: list[DifyKnowledgeResult] + usage: dict[str, JsonValue] = Field(default_factory=dict) + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +@dataclass(slots=True) +class DifyKnowledgeBaseClient: + """Boundary client for the Dify API inner knowledge retrieval endpoint.""" + + base_url: str + api_key: str = field(repr=False) + http_client: httpx.AsyncClient = field(repr=False) + + def __post_init__(self) -> None: + self.base_url = self.base_url.rstrip("/") + + async def retrieve( + self, + *, + tenant_id: str, + user_id: str, + app_id: str, + user_from: str, + invoke_from: str, + dataset_ids: list[str], + query: str, + retrieval: DifyKnowledgeRetrievalConfig, + metadata_filtering: DifyKnowledgeMetadataFilteringConfig, + ) -> DifyKnowledgeRetrieveResponse: + """Call the inner API and return parsed retrieval results. + + Raises: + DifyKnowledgeBaseClientError: For HTTP, transport, or response-shape + failures. Only ``429``, ``502``, and transport/timeout failures + are marked retryable because the model may continue gracefully in + those temporary-unavailable cases. 
+ """ + request_payload = _DifyKnowledgeRetrieveRequest( + caller=_DifyKnowledgeCaller( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from=user_from, + invoke_from=invoke_from, + ), + dataset_ids=dataset_ids, + query=query, + retrieval=retrieval.to_request_payload(), + metadata_filtering=metadata_filtering.to_request_payload(), + ) + + try: + response = await self.http_client.post( + f"{self.base_url}/inner/api/knowledge/retrieve", + headers={ + "X-Inner-Api-Key": self.api_key, + "Content-Type": "application/json", + }, + json=request_payload.model_dump(mode="json", by_alias=True), + ) + except (httpx.InvalidURL, httpx.UnsupportedProtocol) as exc: + raise DifyKnowledgeBaseClientError( + f"Knowledge base search is misconfigured: {exc}", + retryable=False, + ) from exc + except httpx.TimeoutException as exc: + raise DifyKnowledgeBaseClientError( + "Knowledge base search timed out.", + retryable=True, + ) from exc + except httpx.RequestError as exc: + raise DifyKnowledgeBaseClientError( + f"Knowledge base search request failed: {exc}", + retryable=True, + ) from exc + + if response.status_code >= 400: + raise _build_http_error(response) + + try: + return DifyKnowledgeRetrieveResponse.model_validate_json(response.text) + except ValidationError as exc: + raise DifyKnowledgeBaseClientError( + "Invalid knowledge retrieval response from Dify API.", + status_code=response.status_code, + error_code="invalid_response", + retryable=False, + ) from exc + + +def _build_http_error(response: httpx.Response) -> DifyKnowledgeBaseClientError: + detail = _decode_error_detail(response) + retryable = response.status_code in {429, 502} + message = detail["message"] or f"HTTP {response.status_code}" + return DifyKnowledgeBaseClientError( + message, + status_code=response.status_code, + error_code=detail["error_code"], + retryable=retryable, + ) + + +def _decode_error_detail(response: httpx.Response) -> dict[str, str | None]: + raw_body = response.text + try: + payload 
= response.json() + except json.JSONDecodeError: + payload = None + + if isinstance(payload, dict): + error_code = payload.get("code") + message = payload.get("message") + return { + "error_code": error_code if isinstance(error_code, str) else None, + "message": message if isinstance(message, str) and message else raw_body or f"HTTP {response.status_code}", + } + + return {"error_code": None, "message": raw_body or f"HTTP {response.status_code}"} + + +__all__ = [ + "DifyKnowledgeBaseClient", + "DifyKnowledgeBaseClientError", + "DifyKnowledgeResult", + "DifyKnowledgeResultMetadata", + "DifyKnowledgeRetrieveResponse", +] diff --git a/dify-agent/src/dify_agent/layers/knowledge/configs.py b/dify-agent/src/dify_agent/layers/knowledge/configs.py new file mode 100644 index 0000000000..9ada075d1c --- /dev/null +++ b/dify-agent/src/dify_agent/layers/knowledge/configs.py @@ -0,0 +1,200 @@ +"""Client-safe DTOs for the Dify knowledge-base Agenton layer. + +The public layer config exposes only static retrieval controls: dataset ids, +retrieval strategy, metadata filtering, and observation-size limits. The agent +model itself should only ever see a single ``query`` tool argument; tenant/ +app/user context comes from the execution-context layer and the actual +retrieval is delegated to the Dify API inner endpoint. Tool naming is not +caller-configurable: the runtime always exposes the same stable knowledge-base +search tool. 
+""" + +from __future__ import annotations + +from typing import ClassVar, Final, Literal + +from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator, model_validator + +from agenton.layers import LayerConfig + +type DifyKnowledgeMetadataComparisonOperator = Literal[ + "contains", + "not contains", + "start with", + "end with", + "is", + "is not", + "empty", + "not empty", + "in", + "not in", + "=", + "≠", + ">", + "<", + "≥", + "≤", + "before", + "after", +] + +DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID: Final[str] = "dify.knowledge_base" + + +class DifyKnowledgeModelConfig(BaseModel): + """Static model configuration forwarded to the inner retrieval API.""" + + provider: str + name: str + mode: str + completion_params: dict[str, JsonValue] = Field(default_factory=dict) + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class DifyKnowledgeRerankingModelConfig(BaseModel): + """Reranking model settings for multiple-mode retrieval.""" + + provider: str + model: str + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class DifyKnowledgeRetrievalConfig(BaseModel): + """Static retrieval controls mirrored into the inner API request.""" + + mode: Literal["multiple", "single"] + top_k: int | None = Field(default=None, ge=1) + score_threshold: float = 0.0 + reranking_mode: str = "reranking_model" + reranking_enable: bool = True + reranking_model: DifyKnowledgeRerankingModelConfig | None = None + weights: dict[str, JsonValue] | None = None + model: DifyKnowledgeModelConfig | None = None + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + @model_validator(mode="after") + def validate_mode_specific_fields(self) -> DifyKnowledgeRetrievalConfig: + if self.mode == "multiple" and self.top_k is None: + raise ValueError("retrieval.top_k is required for multiple mode") + if self.mode == "single" and self.model is None: + raise ValueError("retrieval.model is required for single mode") + return self + + def 
to_request_payload(self) -> dict[str, JsonValue]: + """Serialize the retrieval config into the inner API request shape.""" + payload: dict[str, JsonValue] = { + "mode": self.mode, + "score_threshold": self.score_threshold, + "reranking_mode": self.reranking_mode, + "reranking_enable": self.reranking_enable, + } + if self.mode == "multiple": + payload["top_k"] = self.top_k + payload["reranking_model"] = ( + self.reranking_model.model_dump(mode="json") if self.reranking_model is not None else None + ) + payload["weights"] = self.weights + else: + payload["model"] = self.model.model_dump(mode="json") if self.model is not None else None + return payload + + +class DifyKnowledgeMetadataCondition(BaseModel): + """One manual metadata filter clause.""" + + name: str + comparison_operator: DifyKnowledgeMetadataComparisonOperator + value: str | list[str] | int | float | None = None + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class DifyKnowledgeMetadataConditions(BaseModel): + """Boolean composition for manual metadata filtering.""" + + logical_operator: Literal["and", "or"] = "and" + conditions: list[DifyKnowledgeMetadataCondition] + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") + + +class DifyKnowledgeMetadataFilteringConfig(BaseModel): + """Static metadata filtering controls for the inner API request.""" + + mode: Literal["disabled", "automatic", "manual"] = "disabled" + metadata_model_config: DifyKnowledgeModelConfig | None = Field(default=None, alias="model_config") + conditions: DifyKnowledgeMetadataConditions | None = None + + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", populate_by_name=True) + + @model_validator(mode="after") + def validate_mode_specific_fields(self) -> DifyKnowledgeMetadataFilteringConfig: + if self.mode == "automatic" and self.metadata_model_config is None: + raise ValueError("metadata_filtering.model_config is required for automatic mode") + if self.mode == "manual" and 
class DifyKnowledgeMetadataFilteringConfig(BaseModel):
    """Config-owned metadata filtering controls for the inner API request.

    The model settings are stored as ``metadata_model_config`` but are read
    from and written to the wire under the ``model_config`` key, because
    ``model_config`` itself is taken by the pydantic configuration attribute.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", populate_by_name=True)

    mode: Literal["disabled", "automatic", "manual"] = "disabled"
    metadata_model_config: DifyKnowledgeModelConfig | None = Field(default=None, alias="model_config")
    conditions: DifyKnowledgeMetadataConditions | None = None

    @model_validator(mode="after")
    def validate_mode_specific_fields(self) -> DifyKnowledgeMetadataFilteringConfig:
        """Require the inputs each filtering mode depends on.

        Raises:
            ValueError: automatic mode has no model settings, or manual mode
                has no (or an empty) condition list.
        """
        if self.mode == "automatic":
            if self.metadata_model_config is None:
                raise ValueError("metadata_filtering.model_config is required for automatic mode")
        elif self.mode == "manual":
            clauses = self.conditions.conditions if self.conditions is not None else []
            if not clauses:
                raise ValueError("metadata_filtering.conditions is required for manual mode")
        return self

    def to_request_payload(self) -> dict[str, JsonValue]:
        """Return the ``metadata_filtering`` object of the inner API request.

        Disabled mode emits only ``mode``. Other modes additionally emit
        ``model_config`` and ``conditions`` whenever those values are set.
        """
        payload: dict[str, JsonValue] = {"mode": self.mode}
        if self.mode == "disabled":
            return payload

        model_settings = self.metadata_model_config
        if model_settings is not None:
            payload["model_config"] = model_settings.model_dump(mode="json")
        clause_group = self.conditions
        if clause_group is not None:
            payload["conditions"] = clause_group.model_dump(mode="json")
        return payload


class DifyKnowledgeBaseLayerConfig(LayerConfig):
    """Public config for one model-visible knowledge search tool.

    The model decides only whether to call the tool and which ``query`` to
    send. Dataset ids, retrieval settings, metadata filtering, and caller
    context stay config/runtime concerns outside the model-visible tool schema.
    Tool name and description are fixed by the layer runtime and are therefore
    not part of this DTO.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")

    dataset_ids: list[str]
    retrieval: DifyKnowledgeRetrievalConfig
    metadata_filtering: DifyKnowledgeMetadataFilteringConfig = Field(
        default_factory=DifyKnowledgeMetadataFilteringConfig
    )
    max_result_content_chars: int = Field(default=2000, ge=1)
    max_observation_chars: int = Field(default=12000, ge=1)

    @field_validator("dataset_ids")
    @classmethod
    def validate_dataset_ids(cls, value: list[str]) -> list[str]:
        """Strip every dataset id and reject empty lists or blank entries.

        Raises:
            ValueError: the list is empty or an id is blank after stripping.
        """
        if not value:
            raise ValueError("dataset_ids must contain at least one item")
        stripped_ids: list[str] = []
        for raw_id in value:
            candidate = raw_id.strip()
            if not candidate:
                raise ValueError("dataset_ids must not contain blank items")
            stripped_ids.append(candidate)
        return stripped_ids

    @model_validator(mode="after")
    def validate_observation_limits(self) -> DifyKnowledgeBaseLayerConfig:
        """Ensure the whole-observation budget is not smaller than the per-result budget.

        Raises:
            ValueError: ``max_observation_chars`` is below
                ``max_result_content_chars``.
        """
        if self.max_result_content_chars > self.max_observation_chars:
            raise ValueError("max_observation_chars must be greater than or equal to max_result_content_chars")
        return self


__all__ = [
    "DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID",
    "DifyKnowledgeBaseLayerConfig",
    "DifyKnowledgeMetadataCondition",
    "DifyKnowledgeMetadataConditions",
    "DifyKnowledgeMetadataFilteringConfig",
    "DifyKnowledgeModelConfig",
    "DifyKnowledgeRerankingModelConfig",
    "DifyKnowledgeRetrievalConfig",
]
Tool identity is intentionally fixed at +runtime: callers cannot rename the knowledge tool or override its description +through public layer config because the model-visible surface must stay stable +across API-side Agent Soul mappings. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import logging +from typing import ClassVar, cast + +import httpx +from pydantic_ai import RunContext, Tool +from pydantic_ai.tools import ToolDefinition +from typing_extensions import Self, override + +from agenton.layers import LayerDeps, PlainLayer +from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer +from dify_agent.layers.knowledge.client import ( + DifyKnowledgeBaseClient, + DifyKnowledgeBaseClientError, + DifyKnowledgeRetrieveResponse, +) +from dify_agent.layers.knowledge.configs import DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, DifyKnowledgeBaseLayerConfig + +logger = logging.getLogger(__name__) + +# Fixed model-visible tool identity. These stay module-private on purpose so the +# public DTO cannot grow a parallel naming contract that diverges from the +# runtime knowledge-search surface. +_KNOWLEDGE_BASE_TOOL_NAME = "knowledge_base_search" +_KNOWLEDGE_BASE_TOOL_DESCRIPTION = "Search configured knowledge bases for information relevant to the query." +BLANK_QUERY_OBSERVATION = "knowledge base search requires a non-empty query" +NO_RESULTS_OBSERVATION = "No relevant knowledge base results were found." +TEMPORARY_UNAVAILABLE_OBSERVATION = ( + "Knowledge base search is temporarily unavailable. Please continue without it if possible." 
+) +QUERY_TOOL_SCHEMA = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query for the configured knowledge bases.", + } + }, + "required": ["query"], + "additionalProperties": False, +} + + +class DifyKnowledgeBaseDeps(LayerDeps): + """Dependencies required by ``DifyKnowledgeBaseLayer``.""" + + execution_context: DifyExecutionContextLayer # pyright: ignore[reportUninitializedInstanceVariable] + + +@dataclass(slots=True) +class DifyKnowledgeBaseLayer(PlainLayer[DifyKnowledgeBaseDeps, DifyKnowledgeBaseLayerConfig]): + """Layer that resolves one config-scoped knowledge search tool.""" + + type_id: ClassVar[str | None] = DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID + + config: DifyKnowledgeBaseLayerConfig + dify_api_inner_url: str + dify_api_inner_api_key: str + + @classmethod + @override + def from_config(cls, config: DifyKnowledgeBaseLayerConfig) -> Self: + """Reject construction without server-injected Dify API settings.""" + del config + raise TypeError( + "DifyKnowledgeBaseLayer requires server-side Dify API settings and must use a provider factory." + ) + + @classmethod + def from_config_with_settings( + cls, + config: DifyKnowledgeBaseLayerConfig, + *, + dify_api_inner_url: str, + dify_api_inner_api_key: str, + ) -> Self: + """Create the layer from public config plus server-only API settings.""" + return cls( + config=DifyKnowledgeBaseLayerConfig.model_validate(config), + dify_api_inner_url=dify_api_inner_url, + dify_api_inner_api_key=dify_api_inner_api_key, + ) + + async def get_tools(self, *, http_client: httpx.AsyncClient) -> list[Tool[object]]: + """Build one Pydantic AI tool that exposes only ``query`` to the model. + + Knowledge tools depend on execution-context identity that is optional for + other run types but mandatory here: ``tenant_id``, ``user_id``, + ``user_from``, ``app_id``, and ``invoke_from`` must all be present before + any HTTP request is attempted. 
Tool execution then follows a strict + observation policy: + + - blank ``query`` returns a local validation observation; + - retryable client failures (timeouts, connection failures, HTTP + ``429``/``502``) become a temporary-unavailable observation; + - non-retryable client failures are raised so the run fails fast. + """ + if http_client.is_closed: + raise RuntimeError("DifyKnowledgeBaseLayer.get_tools() requires an open shared HTTP client.") + + execution_context = self.deps.execution_context.config + caller = _build_caller_context(execution_context) + client = DifyKnowledgeBaseClient( + base_url=self.dify_api_inner_url, + api_key=self.dify_api_inner_api_key, + http_client=http_client, + ) + + async def knowledge_base_search(_ctx: RunContext[object], query: str) -> str: + normalized_query = query.strip() + if not normalized_query: + return BLANK_QUERY_OBSERVATION + try: + response = await client.retrieve( + tenant_id=caller["tenant_id"], + user_id=caller["user_id"], + app_id=caller["app_id"], + user_from=caller["user_from"], + invoke_from=caller["invoke_from"], + dataset_ids=list(self.config.dataset_ids), + query=normalized_query, + retrieval=self.config.retrieval, + metadata_filtering=self.config.metadata_filtering, + ) + except DifyKnowledgeBaseClientError as exc: + if exc.retryable: + logger.warning( + "knowledge base search temporarily unavailable", + extra={ + "tenant_id": caller["tenant_id"], + "app_id": caller["app_id"], + "invoke_from": caller["invoke_from"], + "error_code": exc.error_code, + "status_code": exc.status_code, + }, + ) + return TEMPORARY_UNAVAILABLE_OBSERVATION + logger.error( + "knowledge base search failed", + extra={ + "tenant_id": caller["tenant_id"], + "app_id": caller["app_id"], + "invoke_from": caller["invoke_from"], + "error_code": exc.error_code, + "status_code": exc.status_code, + }, + ) + raise + return _format_observation(response, self.config) + + async def prepare_tool_definition(_ctx: RunContext[object], tool_def: 
ToolDefinition) -> ToolDefinition: + return ToolDefinition( + name=tool_def.name, + description=tool_def.description, + parameters_json_schema=QUERY_TOOL_SCHEMA, + strict=tool_def.strict, + sequential=tool_def.sequential, + metadata=tool_def.metadata, + timeout=tool_def.timeout, + defer_loading=tool_def.defer_loading, + kind=tool_def.kind, + return_schema=tool_def.return_schema, + include_return_schema=tool_def.include_return_schema, + ) + + return [ + Tool( + knowledge_base_search, + takes_ctx=True, + name=_KNOWLEDGE_BASE_TOOL_NAME, + description=_KNOWLEDGE_BASE_TOOL_DESCRIPTION, + prepare=prepare_tool_definition, + ) + ] + + +def _build_caller_context(execution_context: object) -> dict[str, str]: + """Extract the inner-API caller identity from execution-context config. + + The public execution-context DTO keeps several fields optional for general + runs, but knowledge retrieval requires all of ``tenant_id``, ``user_id``, + ``user_from``, ``app_id``, and ``invoke_from``. Missing or blank values are + rejected here so misconfigured runs fail before transport rather than being + softened into tool observations. 
+ """ + tenant_id = getattr(execution_context, "tenant_id", None) + user_id = getattr(execution_context, "user_id", None) + user_from = getattr(execution_context, "user_from", None) + app_id = getattr(execution_context, "app_id", None) + invoke_from = getattr(execution_context, "invoke_from", None) + + missing_fields = [ + field_name + for field_name, value in ( + ("tenant_id", tenant_id), + ("user_id", user_id), + ("user_from", user_from), + ("app_id", app_id), + ("invoke_from", invoke_from), + ) + if not isinstance(value, str) or not value.strip() + ] + if missing_fields: + joined_fields = ", ".join(missing_fields) + raise ValueError(f"Dify knowledge base layer requires execution context fields: {joined_fields}") + + normalized_tenant_id = cast(str, tenant_id).strip() + normalized_user_id = cast(str, user_id).strip() + normalized_user_from = cast(str, user_from).strip() + normalized_app_id = cast(str, app_id).strip() + normalized_invoke_from = cast(str, invoke_from).strip() + + return { + "tenant_id": normalized_tenant_id, + "user_id": normalized_user_id, + "user_from": normalized_user_from, + "app_id": normalized_app_id, + "invoke_from": normalized_invoke_from, + } + + +def _format_observation(response: DifyKnowledgeRetrieveResponse, config: DifyKnowledgeBaseLayerConfig) -> str: + """Render inner-API retrieval results into the model-visible tool response. + + The formatting contract is intentionally simple and stable for the model: + + - empty ``results`` returns ``NO_RESULTS_OBSERVATION``; + - non-empty results become a numbered list headed by + ``"Knowledge base search results:"``; + - each item includes title plus dataset/document/score metadata when those + fields are present; + - each content snippet is truncated by ``max_result_content_chars``; + - the final observation is truncated by ``max_observation_chars``. 
def _format_observation(response: DifyKnowledgeRetrieveResponse, config: DifyKnowledgeBaseLayerConfig) -> str:
    """Turn an inner-API retrieval response into the text the model sees.

    Formatting contract:

    - no results -> ``NO_RESULTS_OBSERVATION``;
    - otherwise a numbered list under ``"Knowledge base search results:"``;
    - the title falls back to the document name, then to ``"Untitled"``;
    - dataset, document and score lines appear only when those values exist;
    - the snippet is ``content`` (or ``summary`` when content is empty), cut to
      ``config.max_result_content_chars``;
    - the assembled text is finally cut to ``config.max_observation_chars``.
    """
    results = response.results
    if not results:
        return NO_RESULTS_OBSERVATION

    rendered = ["Knowledge base search results:"]
    for position, hit in enumerate(results, start=1):
        meta = hit.metadata
        heading = hit.title or meta.document_name or "Untitled"
        entry = [f"{position}. Title: {heading}"]
        if meta.dataset_name:
            entry.append(f" Dataset: {meta.dataset_name}")
        if meta.document_name:
            entry.append(f" Document: {meta.document_name}")
        if meta.score is not None:
            entry.append(f" Score: {meta.score}")
        snippet = _truncate_text(hit.content or hit.summary or "", config.max_result_content_chars)
        if snippet:
            entry.append(f" Content: {snippet}")
        # Blank separator after each entry; the trailing one is stripped below.
        entry.append("")
        rendered.extend(entry)

    observation = "\n".join(rendered).rstrip()
    return _truncate_text(observation, config.max_observation_chars)


def _truncate_text(text: str, max_chars: int) -> str:
    """Limit ``text`` to ``max_chars`` characters, marking cuts with ``...``.

    Text within the limit is returned unchanged. When the limit is three
    characters or fewer there is no room for the marker, so the text is cut
    without it.
    """
    marker = "..."
    if len(text) <= max_chars:
        return text
    if max_chars <= len(marker):
        return text[:max_chars]
    return text[: max_chars - len(marker)] + marker


__all__ = [
    "BLANK_QUERY_OBSERVATION",
    "DifyKnowledgeBaseDeps",
    "DifyKnowledgeBaseLayer",
    "NO_RESULTS_OBSERVATION",
    "QUERY_TOOL_SCHEMA",
    "TEMPORARY_UNAVAILABLE_OBSERVATION",
]
The default provider set contains prompt layers, the optional pydantic-ai history layer, the state-free Dify structured output layer, the optional Dify ask-human layer, the Dify execution-context layer, the stateful Dify shell layer, and the Dify -plugin business-layer family: +plugin/knowledge business-layer family: - ``dify.drive`` for the inert Skills & Files drive declaration, - ``dify.execution_context`` for shared tenant/user/run daemon context, - ``dify.shell`` for shellctl-backed shell job control, -- ``dify.plugin.llm`` for plugin-backed model selection, and -- ``dify.plugin.tools`` for prepared plugin tool exposure. +- ``dify.plugin.llm`` for plugin-backed model selection, +- ``dify.plugin.tools`` for prepared plugin tool exposure, and +- ``dify.knowledge_base`` for inner-API-backed knowledge search tools. Public DTOs provide Dify context plus plugin/model/tool data, while server-only -plugin daemon settings are injected through the provider factory for -``DifyExecutionContextLayer`` and the optional shellctl entrypoint/auth token plus -client factory plus optional Agent Stub URL/token issuer are injected for -``DifyShellLayer``. The resulting ``Compositor`` -remains Agenton state-only at the snapshot boundary: live resources such as -HTTP clients are injected by runtime-owned providers, may be held on active -layer instances inside ``resource_context()``, and never enter session -snapshots. +plugin daemon settings and Dify API inner settings are injected through provider +factories. Optional shellctl entrypoint/auth token, client factory, and Agent +Stub URL/token issuer are injected for ``DifyShellLayer``. The resulting +``Compositor`` remains Agenton state-only at the snapshot boundary: live +resources such as HTTP clients are injected by runtime-owned providers, may be +held on active layer instances inside ``resource_context()``, and never enter +session snapshots. 
""" from collections.abc import Mapping, Sequence @@ -41,6 +41,8 @@ from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer from dify_agent.layers.drive.layer import DifyDriveLayer from dify_agent.layers.execution_context.configs import DifyExecutionContextLayerConfig from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer +from dify_agent.layers.knowledge.configs import DifyKnowledgeBaseLayerConfig +from dify_agent.layers.knowledge.layer import DifyKnowledgeBaseLayer from dify_agent.layers.output.output_layer import DifyOutputLayer from dify_agent.layers.shell.configs import DifyShellLayerConfig from dify_agent.layers.shell.layer import DifyShellLayer, create_shellctl_client_factory @@ -52,6 +54,8 @@ def create_default_layer_providers( *, plugin_daemon_url: str = "http://localhost:5002", plugin_daemon_api_key: str = "", + dify_api_inner_url: str = "http://localhost:5001", + dify_api_inner_api_key: str = "", shellctl_entrypoint: str | None = None, shellctl_auth_token: str | None = None, agent_stub_url: str | None = None, @@ -109,6 +113,14 @@ def create_default_layer_providers( ), LayerProvider.from_layer_type(DifyPluginLLMLayer), LayerProvider.from_layer_type(DifyPluginToolsLayer), + LayerProvider.from_factory( + layer_type=DifyKnowledgeBaseLayer, + create=lambda config: DifyKnowledgeBaseLayer.from_config_with_settings( + DifyKnowledgeBaseLayerConfig.model_validate(config), + dify_api_inner_url=dify_api_inner_url, + dify_api_inner_api_key=dify_api_inner_api_key, + ), + ), ) diff --git a/dify-agent/src/dify_agent/runtime/run_scheduler.py b/dify-agent/src/dify_agent/runtime/run_scheduler.py index 9dfc93b846..4186b6afd7 100644 --- a/dify-agent/src/dify_agent/runtime/run_scheduler.py +++ b/dify-agent/src/dify_agent/runtime/run_scheduler.py @@ -69,6 +69,7 @@ class RunScheduler: runner_factory: RunRunnerFactory layer_providers: tuple[LayerProviderInput, ...] 
plugin_daemon_http_client: httpx.AsyncClient + dify_api_http_client: httpx.AsyncClient _lifecycle_lock: asyncio.Lock def __init__( @@ -76,6 +77,7 @@ class RunScheduler: *, store: RunStore, plugin_daemon_http_client: httpx.AsyncClient, + dify_api_http_client: httpx.AsyncClient, shutdown_grace_seconds: float = 30, layer_providers: tuple[LayerProviderInput, ...] | None = None, runner_factory: RunRunnerFactory | None = None, @@ -85,6 +87,7 @@ class RunScheduler: self.active_tasks = {} self.stopping = False self.plugin_daemon_http_client = plugin_daemon_http_client + self.dify_api_http_client = dify_api_http_client self.layer_providers = layer_providers if layer_providers is not None else create_default_layer_providers() self.runner_factory = runner_factory or self._default_runner_factory self._lifecycle_lock = asyncio.Lock() @@ -141,6 +144,7 @@ class RunScheduler: request=request, run_id=record.run_id, plugin_daemon_http_client=self.plugin_daemon_http_client, + dify_api_http_client=self.dify_api_http_client, layer_providers=self.layer_providers, ) diff --git a/dify-agent/src/dify_agent/runtime/runner.py b/dify-agent/src/dify_agent/runtime/runner.py index 8a6d7b9bd9..d10b1843e9 100644 --- a/dify-agent/src/dify_agent/runtime/runner.py +++ b/dify-agent/src/dify_agent/runtime/runner.py @@ -19,7 +19,9 @@ publishes that deferred request through the normal ``run_succeeded`` event as ``deferred_tool_call`` instead of a final ``output``. Invalid structured outputs or invalid deferred-tool behavior still trigger normal retries/failures before Dify Agent emits success. Layers still never own the FastAPI lifespan-owned -plugin daemon HTTP client. +plugin daemon or Dify API inner HTTP clients. Successful terminal events contain +both the JSON-safe final output or deferred tool call and the session snapshot; +there are no separate output or snapshot events to correlate. 
""" from collections.abc import AsyncIterable @@ -38,6 +40,7 @@ from agenton.layers.types import PydanticAITool from dify_agent.layers.ask_human.layer import get_ask_human_layer, validate_ask_human_layer_composition from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer +from dify_agent.layers.knowledge.layer import DifyKnowledgeBaseLayer from dify_agent.protocol.schemas import ( CreateRunRequest, DIFY_AGENT_MODEL_LAYER_ID, @@ -91,6 +94,7 @@ class AgentRunRunner: run_id: str layer_providers: tuple[LayerProviderInput, ...] plugin_daemon_http_client: httpx.AsyncClient + dify_api_http_client: httpx.AsyncClient def __init__( self, @@ -99,12 +103,14 @@ class AgentRunRunner: request: CreateRunRequest, run_id: str, plugin_daemon_http_client: httpx.AsyncClient, + dify_api_http_client: httpx.AsyncClient, layer_providers: tuple[LayerProviderInput, ...] | None = None, ) -> None: self.sink = sink self.request = request self.run_id = run_id self.plugin_daemon_http_client = plugin_daemon_http_client + self.dify_api_http_client = dify_api_http_client self.layer_providers = layer_providers if layer_providers is not None else create_default_layer_providers() async def run(self) -> None: @@ -187,7 +193,11 @@ class AgentRunRunner: ask_human_layer = get_ask_human_layer(run) llm_layer = run.get_layer(DIFY_AGENT_MODEL_LAYER_ID, DifyPluginLLMLayer) model = llm_layer.get_model(http_client=self.plugin_daemon_http_client) - tools = await _resolve_run_tools(run, http_client=self.plugin_daemon_http_client) + tools = await _resolve_run_tools( + run, + plugin_daemon_http_client=self.plugin_daemon_http_client, + dify_api_http_client=self.dify_api_http_client, + ) except (KeyError, TypeError, RuntimeError, ValueError) as exc: raise AgentRunValidationError(str(exc)) from exc @@ -266,14 +276,17 @@ def _resolve_deferred_tool_results(request: CreateRunRequest) -> DeferredToolRes async def _resolve_run_tools( 
run: Any, *, - http_client: httpx.AsyncClient, + plugin_daemon_http_client: httpx.AsyncClient, + dify_api_http_client: httpx.AsyncClient, ) -> list[PydanticAITool[object]]: - """Return the static compositor tools plus any Dify plugin runtime tools.""" + """Return the static compositor tools plus any Dify runtime tools.""" resolved_tools = list(cast(list[PydanticAITool[object]], run.tools)) for slot in run.slots.values(): layer = slot.layer if isinstance(layer, DifyPluginToolsLayer): - resolved_tools.extend(await layer.get_tools(http_client=http_client)) + resolved_tools.extend(await layer.get_tools(http_client=plugin_daemon_http_client)) + if isinstance(layer, DifyKnowledgeBaseLayer): + resolved_tools.extend(await layer.get_tools(http_client=dify_api_http_client)) _validate_unique_tool_names(resolved_tools) return resolved_tools diff --git a/dify-agent/src/dify_agent/server/app.py b/dify-agent/src/dify_agent/server/app.py index f4eab601a2..dab6caaca0 100644 --- a/dify-agent/src/dify_agent/server/app.py +++ b/dify-agent/src/dify_agent/server/app.py @@ -1,15 +1,16 @@ """FastAPI application factory for the Dify Agent run server. -The HTTP process owns Redis clients, one shared plugin daemon ``httpx.AsyncClient``, -route wiring, and a process-local scheduler. Run execution happens in background -``asyncio`` tasks rather than request handlers, so client disconnects do not -cancel the agent runtime. Redis persists run records and per-run event streams -with configured retention only; it is not used as a job queue. Agenton layers and -providers stay state-only: they borrow the lifespan-owned plugin daemon client -through the runner and receive shell-layer server settings through provider -construction rather than reading environment variables themselves. The standard -server always mounts the HTTP Agent Stub router and additionally starts the -optional grpclib Agent Stub server when ``DIFY_AGENT_STUB_URL`` uses ``grpc://``. 
+The HTTP process owns Redis clients plus separate shared ``httpx.AsyncClient`` +instances for plugin-daemon and Dify API inner calls, route wiring, and a +process-local scheduler. Run execution happens in background ``asyncio`` tasks +rather than request handlers, so client disconnects do not cancel the agent +runtime. Redis persists run records and per-run event streams with configured +retention only; it is not used as a job queue. Agenton layers and providers +stay state-only: they borrow the lifespan-owned clients through the runner and +receive shell-layer server settings through provider construction rather than +reading environment variables themselves. The standard server always mounts the +HTTP Agent Stub router and additionally starts the optional grpclib Agent Stub +server when ``DIFY_AGENT_STUB_URL`` uses ``grpc://``. """ from collections.abc import AsyncGenerator @@ -39,6 +40,8 @@ def create_app(settings: ServerSettings | None = None) -> FastAPI: layer_providers = create_default_layer_providers( plugin_daemon_url=resolved_settings.plugin_daemon_url, plugin_daemon_api_key=resolved_settings.plugin_daemon_api_key, + dify_api_inner_url=resolved_settings.dify_api_inner_url, + dify_api_inner_api_key=resolved_settings.dify_api_inner_api_key or "", shellctl_entrypoint=resolved_settings.shellctl_entrypoint, shellctl_auth_token=resolved_settings.shellctl_auth_token, agent_stub_url=resolved_settings.agent_stub_url, @@ -53,6 +56,7 @@ def create_app(settings: ServerSettings | None = None) -> FastAPI: async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: redis = Redis.from_url(resolved_settings.redis_url) plugin_daemon_http_client = create_plugin_daemon_http_client(resolved_settings) + dify_api_inner_http_client = create_dify_api_inner_http_client(resolved_settings) store = RedisRunStore( redis, prefix=resolved_settings.redis_prefix, @@ -61,6 +65,7 @@ def create_app(settings: ServerSettings | None = None) -> FastAPI: scheduler = RunScheduler( 
store=store, plugin_daemon_http_client=plugin_daemon_http_client, + dify_api_http_client=dify_api_inner_http_client, shutdown_grace_seconds=resolved_settings.shutdown_grace_seconds, layer_providers=layer_providers, ) @@ -83,6 +88,7 @@ def create_app(settings: ServerSettings | None = None) -> FastAPI: if grpc_server is not None: await grpc_server.aclose() await scheduler.shutdown() + await dify_api_inner_http_client.aclose() await plugin_daemon_http_client.aclose() await redis.aclose() @@ -112,17 +118,33 @@ def create_plugin_daemon_http_client(settings: ServerSettings) -> httpx.AsyncCli process and must be closed by the app lifespan after the scheduler has stopped using it. """ + return _create_shared_http_client(settings) + + +def create_dify_api_inner_http_client(settings: ServerSettings) -> httpx.AsyncClient: + """Create the lifespan-owned Dify API inner HTTP client. + + The Dify API inner client intentionally shares the generic outbound HTTP + timeout and connection-pool settings with the plugin daemon client so + operational tuning stays in one place while endpoint URL/API keys remain + distinct server settings. 
+ """ + return _create_shared_http_client(settings) + + +def _create_shared_http_client(settings: ServerSettings) -> httpx.AsyncClient: + """Build one shared HTTP client using generic outbound timeout/pool settings.""" return httpx.AsyncClient( timeout=httpx.Timeout( - connect=settings.plugin_daemon_connect_timeout, - read=settings.plugin_daemon_read_timeout, - write=settings.plugin_daemon_write_timeout, - pool=settings.plugin_daemon_pool_timeout, + connect=settings.outbound_http_connect_timeout, + read=settings.outbound_http_read_timeout, + write=settings.outbound_http_write_timeout, + pool=settings.outbound_http_pool_timeout, ), limits=httpx.Limits( - max_connections=settings.plugin_daemon_max_connections, - max_keepalive_connections=settings.plugin_daemon_max_keepalive_connections, - keepalive_expiry=settings.plugin_daemon_keepalive_expiry, + max_connections=settings.outbound_http_max_connections, + max_keepalive_connections=settings.outbound_http_max_keepalive_connections, + keepalive_expiry=settings.outbound_http_keepalive_expiry, ), trust_env=False, ) @@ -131,4 +153,4 @@ def create_plugin_daemon_http_client(settings: ServerSettings) -> httpx.AsyncCli app = create_app() -__all__ = ["app", "create_app", "create_plugin_daemon_http_client"] +__all__ = ["app", "create_app", "create_dify_api_inner_http_client", "create_plugin_daemon_http_client"] diff --git a/dify-agent/src/dify_agent/server/settings.py b/dify-agent/src/dify_agent/server/settings.py index 7c24fbb9f9..2b4aff62e5 100644 --- a/dify-agent/src/dify_agent/server/settings.py +++ b/dify-agent/src/dify_agent/server/settings.py @@ -1,12 +1,13 @@ """Configuration for the FastAPI run server. -Plugin daemon HTTP client settings describe the single FastAPI lifespan-owned -``httpx.AsyncClient`` shared by local run tasks. Layers and Agenton providers do -not own that client, so these settings are process resource limits rather than -per-run lifecycle knobs. 
The Agent Stub now also uses this main server settings -model directly: the public Agent Stub URL, server secret, optional gRPC bind -override, and optional Dify inner API file-request settings all live here under -the longstanding ``DIFY_AGENT_...`` environment-variable namespace. +Outbound HTTP client settings describe the FastAPI lifespan-owned +``httpx.AsyncClient`` instances shared by local run tasks for plugin-daemon and +Dify API inner calls. Layers and Agenton providers do not own those clients, so +these settings are process resource limits rather than per-run lifecycle knobs. +Endpoint URLs and API keys stay service-specific. The Agent Stub also uses this +settings model directly: the public Agent Stub URL, server secret, optional gRPC +bind override, and optional Dify inner API file-request settings all live here +under the longstanding ``DIFY_AGENT_...`` environment-variable namespace. """ from typing import ClassVar @@ -23,7 +24,7 @@ DEFAULT_RUN_RETENTION_SECONDS = 3 * 24 * 60 * 60 class ServerSettings(BaseSettings): - """Environment-backed settings for Redis, scheduling, plugin, and shell access.""" + """Environment-backed settings for Redis, scheduling, outbound HTTP, and shell access.""" redis_url: str = "redis://localhost:6379/0" redis_prefix: str = "dify-agent" @@ -31,6 +32,7 @@ class ServerSettings(BaseSettings): run_retention_seconds: int = Field(default=DEFAULT_RUN_RETENTION_SECONDS, ge=1) plugin_daemon_url: str = "http://localhost:5002" plugin_daemon_api_key: str = "" + dify_api_inner_url: str = "http://localhost:5001" dify_api_base_url: str | None = None dify_api_inner_api_key: str | None = None shellctl_entrypoint: str | None = None @@ -38,13 +40,13 @@ class ServerSettings(BaseSettings): agent_stub_url: str | None = Field(default=None, validation_alias="DIFY_AGENT_STUB_URL") agent_stub_grpc_bind_address: str | None = Field(default=None, validation_alias="DIFY_AGENT_STUB_GRPC_BIND_ADDRESS") server_secret_key: str | None = None - 
plugin_daemon_connect_timeout: float = Field(default=10.0, ge=0) - plugin_daemon_read_timeout: float = Field(default=600.0, ge=0) - plugin_daemon_write_timeout: float = Field(default=30.0, ge=0) - plugin_daemon_pool_timeout: float = Field(default=10.0, ge=0) - plugin_daemon_max_connections: int = Field(default=100, ge=1) - plugin_daemon_max_keepalive_connections: int = Field(default=20, ge=0) - plugin_daemon_keepalive_expiry: float = Field(default=30.0, ge=0) + outbound_http_connect_timeout: float = Field(default=10.0, ge=0) + outbound_http_read_timeout: float = Field(default=600.0, ge=0) + outbound_http_write_timeout: float = Field(default=30.0, ge=0) + outbound_http_pool_timeout: float = Field(default=10.0, ge=0) + outbound_http_max_connections: int = Field(default=100, ge=1) + outbound_http_max_keepalive_connections: int = Field(default=20, ge=0) + outbound_http_keepalive_expiry: float = Field(default=30.0, ge=0) model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict( env_prefix="DIFY_AGENT_", @@ -116,7 +118,7 @@ class ServerSettings(BaseSettings): @model_validator(mode="after") def validate_agent_stub_requirements(self) -> "ServerSettings": - """Require the server secret and Dify API file settings in valid pairs.""" + """Require Agent Stub settings while allowing knowledge-only inner API keys.""" if self.agent_stub_url is not None and self.server_secret_key is None: raise ValueError("DIFY_AGENT_SERVER_SECRET_KEY is required when DIFY_AGENT_STUB_URL is set.") if self.agent_stub_grpc_bind_address is not None: @@ -124,8 +126,8 @@ class ServerSettings(BaseSettings): raise ValueError("DIFY_AGENT_STUB_URL is required when DIFY_AGENT_STUB_GRPC_BIND_ADDRESS is set.") if not parse_agent_stub_endpoint(self.agent_stub_url).is_grpc: raise ValueError("DIFY_AGENT_STUB_GRPC_BIND_ADDRESS requires a grpc:// DIFY_AGENT_STUB_URL.") - if (self.dify_api_base_url is None) != (self.dify_api_inner_api_key is None): - raise ValueError("DIFY_AGENT_DIFY_API_BASE_URL and 
DIFY_AGENT_DIFY_API_INNER_API_KEY must be set together.") + if self.dify_api_base_url is not None and self.dify_api_inner_api_key is None: + raise ValueError("DIFY_AGENT_DIFY_API_INNER_API_KEY is required when DIFY_AGENT_DIFY_API_BASE_URL is set.") return self def create_agent_stub_token_codec(self) -> AgentStubTokenCodec | None: diff --git a/dify-agent/tests/local/dify_agent/layers/knowledge/__init__.py b/dify-agent/tests/local/dify_agent/layers/knowledge/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dify-agent/tests/local/dify_agent/layers/knowledge/test_client.py b/dify-agent/tests/local/dify_agent/layers/knowledge/test_client.py new file mode 100644 index 0000000000..9e2ca5462f --- /dev/null +++ b/dify-agent/tests/local/dify_agent/layers/knowledge/test_client.py @@ -0,0 +1,248 @@ +import json +from unittest.mock import AsyncMock + +import httpx +import pytest + +from dify_agent.layers.knowledge.client import DifyKnowledgeBaseClient, DifyKnowledgeBaseClientError +from dify_agent.layers.knowledge.configs import ( + DifyKnowledgeMetadataFilteringConfig, + DifyKnowledgeRetrievalConfig, +) + + +def _retrieval_config() -> DifyKnowledgeRetrievalConfig: + return DifyKnowledgeRetrievalConfig(mode="multiple", top_k=4, score_threshold=0.2) + + +def _metadata_filtering() -> DifyKnowledgeMetadataFilteringConfig: + return DifyKnowledgeMetadataFilteringConfig(mode="disabled") + + +def test_knowledge_client_posts_inner_api_request_with_static_controls() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert str(request.url) == "http://dify-api/inner/api/knowledge/retrieve" + assert request.headers["X-Inner-Api-Key"] == "inner-secret" + payload = json.loads(request.content.decode("utf-8")) + assert payload == { + "caller": { + "tenant_id": "tenant-1", + "user_id": "user-1", + "app_id": "app-1", + "user_from": "account", + "invoke_from": "agent_app", + }, + "dataset_ids": ["dataset-1"], + "query": "reset password", + "retrieval": 
{ + "mode": "multiple", + "top_k": 4, + "score_threshold": 0.2, + "reranking_mode": "reranking_model", + "reranking_enable": True, + "reranking_model": None, + "weights": None, + }, + "metadata_filtering": {"mode": "disabled"}, + "attachment_ids": [], + } + return httpx.Response( + 200, + json={ + "results": [ + { + "metadata": { + "_source": "knowledge", + "dataset_name": "Docs", + "document_name": "FAQ.md", + "score": 0.9, + }, + "title": "FAQ", + "files": [], + "content": "Use the reset link.", + "summary": None, + } + ], + "usage": {}, + }, + ) + + async def scenario() -> None: + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client: + client = DifyKnowledgeBaseClient( + base_url="http://dify-api", + api_key="inner-secret", + http_client=http_client, + ) + response = await client.retrieve( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="account", + invoke_from="agent_app", + dataset_ids=["dataset-1"], + query="reset password", + retrieval=_retrieval_config(), + metadata_filtering=_metadata_filtering(), + ) + assert response.results[0].metadata.dataset_name == "Docs" + + import asyncio + + asyncio.run(scenario()) + + +def test_knowledge_client_marks_retryable_http_failures() -> None: + async def scenario() -> None: + async with httpx.AsyncClient( + transport=httpx.MockTransport( + lambda _request: httpx.Response( + 502, json={"code": "external_knowledge_failed", "message": "bad gateway"} + ) + ) + ) as http_client: + client = DifyKnowledgeBaseClient( + base_url="http://dify-api", api_key="inner-secret", http_client=http_client + ) + with pytest.raises(DifyKnowledgeBaseClientError) as exc_info: + _ = await client.retrieve( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="account", + invoke_from="agent_app", + dataset_ids=["dataset-1"], + query="reset password", + retrieval=_retrieval_config(), + metadata_filtering=_metadata_filtering(), + ) + assert exc_info.value.status_code == 
502 + assert exc_info.value.error_code == "external_knowledge_failed" + assert exc_info.value.retryable is True + + import asyncio + + asyncio.run(scenario()) + + +def test_knowledge_client_marks_non_retryable_http_failures() -> None: + async def scenario() -> None: + async with httpx.AsyncClient( + transport=httpx.MockTransport( + lambda _request: httpx.Response( + 403, + json={"code": "dataset_tenant_mismatch", "message": "forbidden"}, + ) + ) + ) as http_client: + client = DifyKnowledgeBaseClient( + base_url="http://dify-api", api_key="inner-secret", http_client=http_client + ) + with pytest.raises(DifyKnowledgeBaseClientError) as exc_info: + _ = await client.retrieve( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="account", + invoke_from="agent_app", + dataset_ids=["dataset-1"], + query="reset password", + retrieval=_retrieval_config(), + metadata_filtering=_metadata_filtering(), + ) + assert exc_info.value.status_code == 403 + assert exc_info.value.error_code == "dataset_tenant_mismatch" + assert exc_info.value.retryable is False + + import asyncio + + asyncio.run(scenario()) + + +def test_knowledge_client_rejects_malformed_success_response() -> None: + async def scenario() -> None: + async with httpx.AsyncClient( + transport=httpx.MockTransport(lambda _request: httpx.Response(200, json={"bad": []})) + ) as http_client: + client = DifyKnowledgeBaseClient( + base_url="http://dify-api", api_key="inner-secret", http_client=http_client + ) + with pytest.raises(DifyKnowledgeBaseClientError) as exc_info: + _ = await client.retrieve( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="account", + invoke_from="agent_app", + dataset_ids=["dataset-1"], + query="reset password", + retrieval=_retrieval_config(), + metadata_filtering=_metadata_filtering(), + ) + assert exc_info.value.error_code == "invalid_response" + assert exc_info.value.retryable is False + + import asyncio + + asyncio.run(scenario()) + + 
+@pytest.mark.parametrize( + "error_factory", + [ + lambda request: httpx.ReadTimeout("timed out", request=request), + lambda request: httpx.ConnectError("connection failed", request=request), + ], +) +def test_knowledge_client_marks_transport_failures_retryable(error_factory) -> None: + def handler(request: httpx.Request) -> httpx.Response: + raise error_factory(request) + + async def scenario() -> None: + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client: + client = DifyKnowledgeBaseClient( + base_url="http://dify-api", api_key="inner-secret", http_client=http_client + ) + with pytest.raises(DifyKnowledgeBaseClientError) as exc_info: + _ = await client.retrieve( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="account", + invoke_from="agent_app", + dataset_ids=["dataset-1"], + query="reset password", + retrieval=_retrieval_config(), + metadata_filtering=_metadata_filtering(), + ) + assert exc_info.value.retryable is True + + import asyncio + + asyncio.run(scenario()) + + +def test_knowledge_client_treats_invalid_url_errors_as_non_retryable_configuration_error() -> None: + async def scenario() -> None: + async with httpx.AsyncClient() as http_client: + http_client.post = AsyncMock(side_effect=httpx.UnsupportedProtocol("unsupported protocol")) + client = DifyKnowledgeBaseClient( + base_url="http://dify-api", api_key="inner-secret", http_client=http_client + ) + with pytest.raises(DifyKnowledgeBaseClientError) as exc_info: + _ = await client.retrieve( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="account", + invoke_from="agent_app", + dataset_ids=["dataset-1"], + query="reset password", + retrieval=_retrieval_config(), + metadata_filtering=_metadata_filtering(), + ) + assert exc_info.value.retryable is False + + import asyncio + + asyncio.run(scenario()) diff --git a/dify-agent/tests/local/dify_agent/layers/knowledge/test_configs.py 
b/dify-agent/tests/local/dify_agent/layers/knowledge/test_configs.py new file mode 100644 index 0000000000..f28939e329 --- /dev/null +++ b/dify-agent/tests/local/dify_agent/layers/knowledge/test_configs.py @@ -0,0 +1,65 @@ +import pytest +from pydantic import ValidationError + +from dify_agent.layers.knowledge import DifyKnowledgeBaseLayerConfig + + +def _valid_config() -> dict[str, object]: + return { + "dataset_ids": ["dataset-1"], + "retrieval": { + "mode": "multiple", + "top_k": 4, + }, + } + + +def test_knowledge_base_config_accepts_valid_multiple_mode() -> None: + config = DifyKnowledgeBaseLayerConfig.model_validate(_valid_config()) + + assert config.dataset_ids == ["dataset-1"] + assert config.retrieval.top_k == 4 + assert config.metadata_filtering.mode == "disabled" + + +@pytest.mark.parametrize( + "payload, expected_message", + [ + ({"dataset_ids": [], "retrieval": {"mode": "multiple", "top_k": 4}}, "dataset_ids"), + ({"tool_name": "knowledge_base_search", **_valid_config()}, "Extra inputs are not permitted"), + ({"tool_description": "Search knowledge", **_valid_config()}, "Extra inputs are not permitted"), + ({"dataset_ids": ["dataset-1"], "retrieval": {"mode": "multiple"}}, "top_k"), + ({"dataset_ids": ["dataset-1"], "retrieval": {"mode": "single"}}, "retrieval.model"), + ( + { + "dataset_ids": ["dataset-1"], + "retrieval": {"mode": "multiple", "top_k": 4}, + "metadata_filtering": {"mode": "automatic"}, + }, + "metadata_filtering.model_config", + ), + ( + { + "dataset_ids": ["dataset-1"], + "retrieval": {"mode": "multiple", "top_k": 4}, + "metadata_filtering": {"mode": "manual"}, + }, + "metadata_filtering.conditions", + ), + ], +) +def test_knowledge_base_config_rejects_invalid_inputs(payload: dict[str, object], expected_message: str) -> None: + with pytest.raises(ValidationError, match=expected_message): + _ = DifyKnowledgeBaseLayerConfig.model_validate(payload) + + +def test_knowledge_base_config_rejects_observation_limit_smaller_than_result_limit() 
-> None: + with pytest.raises(ValidationError, match="max_observation_chars"): + _ = DifyKnowledgeBaseLayerConfig.model_validate( + { + "dataset_ids": ["dataset-1"], + "retrieval": {"mode": "multiple", "top_k": 4}, + "max_result_content_chars": 50, + "max_observation_chars": 20, + } + ) diff --git a/dify-agent/tests/local/dify_agent/layers/knowledge/test_layer.py b/dify-agent/tests/local/dify_agent/layers/knowledge/test_layer.py new file mode 100644 index 0000000000..5db74d6f45 --- /dev/null +++ b/dify-agent/tests/local/dify_agent/layers/knowledge/test_layer.py @@ -0,0 +1,417 @@ +import asyncio +import json + +import httpx +import pytest +from pydantic_ai import Tool + +from agenton.compositor import Compositor, LayerNode, LayerProvider +from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig +from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer +from dify_agent.layers.knowledge.client import DifyKnowledgeBaseClientError +from dify_agent.layers.knowledge.configs import DifyKnowledgeBaseLayerConfig +from dify_agent.layers.knowledge.layer import ( + BLANK_QUERY_OBSERVATION, + DifyKnowledgeBaseLayer, + NO_RESULTS_OBSERVATION, + TEMPORARY_UNAVAILABLE_OBSERVATION, +) + + +def _execution_context_config(**overrides: object) -> DifyExecutionContextLayerConfig: + payload: dict[str, object] = { + "tenant_id": "tenant-1", + "user_id": "user-1", + "user_from": "account", + "app_id": "app-1", + "agent_mode": "agent_app", + "invoke_from": "web-app", + } + payload.update(overrides) + return DifyExecutionContextLayerConfig.model_validate(payload) + + +def _knowledge_config(**overrides: object) -> DifyKnowledgeBaseLayerConfig: + payload: dict[str, object] = { + "dataset_ids": ["dataset-1"], + "retrieval": {"mode": "multiple", "top_k": 4}, + } + payload.update(overrides) + return DifyKnowledgeBaseLayerConfig.model_validate(payload) + + +def _execution_context_provider() -> LayerProvider[DifyExecutionContextLayer]: + return 
LayerProvider.from_factory( + layer_type=DifyExecutionContextLayer, + create=lambda config: DifyExecutionContextLayer.from_config_with_settings( + DifyExecutionContextLayerConfig.model_validate(config), + daemon_url="http://plugin-daemon", + daemon_api_key="daemon-secret", + ), + ) + + +def _knowledge_provider() -> LayerProvider[DifyKnowledgeBaseLayer]: + return LayerProvider.from_factory( + layer_type=DifyKnowledgeBaseLayer, + create=lambda config: DifyKnowledgeBaseLayer.from_config_with_settings( + DifyKnowledgeBaseLayerConfig.model_validate(config), + dify_api_inner_url="http://dify-api", + dify_api_inner_api_key="inner-secret", + ), + ) + + +def test_knowledge_layer_exposes_one_query_only_tool_definition() -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient() as http_client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "knowledge": _knowledge_config(), + } + ) as run: + knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer) + tool = (await knowledge_layer.get_tools(http_client=http_client))[0] + tool_def = await tool.prepare_tool_def(None) # pyright: ignore[reportArgumentType] + assert isinstance(tool, Tool) + assert tool.name == "knowledge_base_search" + assert tool.description == "Search configured knowledge bases for information relevant to the query." + assert tool_def is not None + assert ( + tool_def.description == "Search configured knowledge bases for information relevant to the query." 
+ ) + assert tool_def.parameters_json_schema == { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query for the configured knowledge bases.", + } + }, + "required": ["query"], + "additionalProperties": False, + } + + asyncio.run(scenario()) + + +def test_knowledge_layer_rejects_blank_query_locally() -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient() as http_client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "knowledge": _knowledge_config(), + } + ) as run: + knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer) + tool = (await knowledge_layer.get_tools(http_client=http_client))[0] + result = await tool.function_schema.call({"query": " "}, None) # pyright: ignore[reportArgumentType] + assert result == BLANK_QUERY_OBSERVATION + + asyncio.run(scenario()) + + +@pytest.mark.parametrize( + ("field_name", "field_value"), + [ + ("user_id", None), + ("user_from", None), + ("app_id", None), + ], +) +def test_knowledge_layer_fails_fast_when_execution_context_is_missing_required_fields( + field_name: str, + field_value: object, +) -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient() as http_client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "knowledge": _knowledge_config(), + } + ) as run: + execution_context_layer = run.get_layer("execution_context", DifyExecutionContextLayer) + setattr(execution_context_layer.config, field_name, field_value) + knowledge_layer = 
run.get_layer("knowledge", DifyKnowledgeBaseLayer) + with pytest.raises(ValueError, match=field_name): + _ = await knowledge_layer.get_tools(http_client=http_client) + + asyncio.run(scenario()) + + +def test_knowledge_layer_formats_results_and_truncates_observation() -> None: + def handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={ + "results": [ + { + "metadata": { + "_source": "knowledge", + "dataset_name": "Docs", + "document_name": "Guide.md", + "score": 0.9, + }, + "title": "Guide", + "files": [], + "content": "ABCDEFGHIJKL", + "summary": None, + } + ], + "usage": {}, + }, + ) + + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "knowledge": _knowledge_config(max_result_content_chars=8, max_observation_chars=160), + } + ) as run: + knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer) + tool = (await knowledge_layer.get_tools(http_client=http_client))[0] + result = await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType] + assert result.startswith("Knowledge base search results:\n1. Title: Guide") + assert "Dataset: Docs" in result + assert "Document: Guide.md" in result + assert "Score: 0.9" in result + assert "Content: ABCDE..." 
in result + assert len(result) <= 160 + + asyncio.run(scenario()) + + +def test_knowledge_layer_returns_no_results_observation() -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient( + transport=httpx.MockTransport(lambda _request: httpx.Response(200, json={"results": [], "usage": {}})) + ) as http_client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "knowledge": _knowledge_config(), + } + ) as run: + knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer) + tool = (await knowledge_layer.get_tools(http_client=http_client))[0] + result = await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType] + assert result == NO_RESULTS_OBSERVATION + + asyncio.run(scenario()) + + +def test_knowledge_layer_converts_retryable_failures_into_observation() -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient( + transport=httpx.MockTransport( + lambda _request: httpx.Response(429, json={"code": "knowledge_rate_limited", "message": "slow down"}) + ) + ) as http_client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "knowledge": _knowledge_config(), + } + ) as run: + knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer) + tool = (await knowledge_layer.get_tools(http_client=http_client))[0] + result = await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType] + assert result == TEMPORARY_UNAVAILABLE_OBSERVATION + + asyncio.run(scenario()) + + 
+@pytest.mark.parametrize( + "transport_error", + [ + lambda request: httpx.ReadTimeout("timed out", request=request), + lambda request: httpx.ConnectError("connection failed", request=request), + ], +) +def test_knowledge_layer_converts_retryable_transport_failures_into_observation(transport_error) -> None: + def handler(request: httpx.Request) -> httpx.Response: + raise transport_error(request) + + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "knowledge": _knowledge_config(), + } + ) as run: + knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer) + tool = (await knowledge_layer.get_tools(http_client=http_client))[0] + result = await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType] + assert result == TEMPORARY_UNAVAILABLE_OBSERVATION + + asyncio.run(scenario()) + + +def test_knowledge_layer_raises_non_retryable_client_errors() -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient( + transport=httpx.MockTransport( + lambda _request: httpx.Response(403, json={"code": "dataset_tenant_mismatch", "message": "forbidden"}) + ) + ) as http_client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "knowledge": _knowledge_config(), + } + ) as run: + knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer) + tool = (await knowledge_layer.get_tools(http_client=http_client))[0] + with 
pytest.raises(DifyKnowledgeBaseClientError) as exc_info: + await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType] + assert exc_info.value.status_code == 403 + + asyncio.run(scenario()) + + +def test_knowledge_layer_raises_for_malformed_success_responses() -> None: + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient( + transport=httpx.MockTransport(lambda _request: httpx.Response(200, json={"bad": []})) + ) as http_client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "knowledge": _knowledge_config(), + } + ) as run: + knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer) + tool = (await knowledge_layer.get_tools(http_client=http_client))[0] + with pytest.raises(DifyKnowledgeBaseClientError) as exc_info: + await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType] + assert exc_info.value.error_code == "invalid_response" + assert exc_info.value.retryable is False + + asyncio.run(scenario()) + + +def test_knowledge_layer_sends_execution_context_and_static_config_to_inner_api() -> None: + def handler(request: httpx.Request) -> httpx.Response: + payload = json.loads(request.content.decode("utf-8")) + assert request.headers["X-Inner-Api-Key"] == "inner-secret" + assert payload["caller"] == { + "tenant_id": "tenant-1", + "user_id": "user-1", + "app_id": "app-1", + "user_from": "account", + "invoke_from": "web-app", + } + assert payload["dataset_ids"] == ["dataset-1", "dataset-2"] + assert payload["query"] == "reset" + assert payload["retrieval"]["top_k"] == 2 + assert payload["metadata_filtering"] == { + "mode": "manual", + "conditions": { + "logical_operator": "and", + "conditions": [ + { + "name": "category", + 
"comparison_operator": "contains", + "value": "auth", + } + ], + }, + } + return httpx.Response(200, json={"results": [], "usage": {}}) + + async def scenario() -> None: + compositor = Compositor( + [ + LayerNode("execution_context", _execution_context_provider()), + LayerNode("knowledge", _knowledge_provider(), deps={"execution_context": "execution_context"}), + ] + ) + async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client: + async with compositor.enter( + configs={ + "execution_context": _execution_context_config(), + "knowledge": _knowledge_config( + dataset_ids=["dataset-1", "dataset-2"], + retrieval={"mode": "multiple", "top_k": 2}, + metadata_filtering={ + "mode": "manual", + "conditions": { + "logical_operator": "and", + "conditions": [ + { + "name": "category", + "comparison_operator": "contains", + "value": "auth", + } + ], + }, + }, + ), + } + ) as run: + knowledge_layer = run.get_layer("knowledge", DifyKnowledgeBaseLayer) + tool = (await knowledge_layer.get_tools(http_client=http_client))[0] + result = await tool.function_schema.call({"query": "reset"}, None) # pyright: ignore[reportArgumentType] + assert result == NO_RESULTS_OBSERVATION + + asyncio.run(scenario()) diff --git a/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py b/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py index a4a5ad8429..e1124560ac 100644 --- a/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py +++ b/dify-agent/tests/local/dify_agent/runtime/test_run_scheduler.py @@ -120,6 +120,7 @@ def test_create_run_starts_background_task_and_returns_running() -> None: scheduler = RunScheduler( store=store, plugin_daemon_http_client=client, + dify_api_http_client=client, runner_factory=lambda _record, _request: ControlledRunner(started=started, release=release), ) @@ -144,6 +145,7 @@ def test_shutdown_marks_unfinished_runs_failed_and_appends_event() -> None: scheduler = RunScheduler( store=store, 
plugin_daemon_http_client=client, + dify_api_http_client=client, shutdown_grace_seconds=0, runner_factory=lambda _record, _request: ControlledRunner(started=started, release=asyncio.Event()), ) @@ -165,7 +167,7 @@ def test_create_run_accepts_blank_prompt_and_runner_fails_asynchronously() -> No async def scenario() -> None: store = FakeStore() async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) + scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, dify_api_http_client=client) record = await scheduler.create_run(_request(["", " "])) await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1) @@ -182,7 +184,7 @@ def test_create_run_accepts_invalid_output_schema_and_runner_fails_asynchronousl async def scenario() -> None: store = FakeStore() async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) + scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, dify_api_http_client=client) record = await scheduler.create_run( _request( @@ -205,7 +207,12 @@ def test_create_run_honors_explicit_empty_layer_providers_by_failing_after_persi async def scenario() -> None: store = FakeStore() async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, layer_providers=()) + scheduler = RunScheduler( + store=store, + plugin_daemon_http_client=client, + dify_api_http_client=client, + layer_providers=(), + ) record = await scheduler.create_run(_request()) await asyncio.wait_for(scheduler.active_tasks[record.run_id], timeout=1) @@ -222,7 +229,7 @@ def test_create_run_accepts_closed_session_snapshot_and_runner_fails_asynchronou async def scenario() -> None: store = FakeStore() async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) + scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, 
dify_api_http_client=client) request = _request() request.session_snapshot = CompositorSessionSnapshot( layers=[ @@ -248,7 +255,7 @@ def test_create_run_accepts_closed_session_snapshot_and_runner_fails_asynchronou def test_create_run_rejects_after_shutdown_starts() -> None: async def scenario() -> None: async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=FakeStore(), plugin_daemon_http_client=client) + scheduler = RunScheduler(store=FakeStore(), plugin_daemon_http_client=client, dify_api_http_client=client) await scheduler.shutdown() with pytest.raises(SchedulerStoppingError): @@ -261,7 +268,7 @@ def test_create_run_rejects_invalid_request_after_shutdown_without_persisting() async def scenario() -> None: store = FakeStore() async with httpx.AsyncClient() as client: - scheduler = RunScheduler(store=store, plugin_daemon_http_client=client) + scheduler = RunScheduler(store=store, plugin_daemon_http_client=client, dify_api_http_client=client) await scheduler.shutdown() with pytest.raises(SchedulerStoppingError): @@ -282,6 +289,7 @@ def test_shutdown_waits_for_in_flight_create_to_register_before_cancelling() -> scheduler = RunScheduler( store=store, plugin_daemon_http_client=client, + dify_api_http_client=client, shutdown_grace_seconds=0, runner_factory=lambda _record, _request: ControlledRunner( started=runner_started, release=asyncio.Event() diff --git a/dify-agent/tests/local/dify_agent/runtime/test_runner.py b/dify-agent/tests/local/dify_agent/runtime/test_runner.py index c910b7c3dd..f5ddeb7236 100644 --- a/dify-agent/tests/local/dify_agent/runtime/test_runner.py +++ b/dify-agent/tests/local/dify_agent/runtime/test_runner.py @@ -41,6 +41,8 @@ from dify_agent.layers.dify_plugin.configs import ( ) from dify_agent.layers.dify_plugin.llm_layer import DifyPluginLLMLayer from dify_agent.layers.dify_plugin.tools_layer import DifyPluginToolsLayer +from dify_agent.layers.knowledge.configs import DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, 
DifyKnowledgeBaseLayerConfig +from dify_agent.layers.knowledge.layer import DifyKnowledgeBaseLayer from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig from dify_agent.protocol import DIFY_AGENT_HISTORY_LAYER_ID, DIFY_AGENT_MODEL_LAYER_ID, DIFY_AGENT_OUTPUT_LAYER_ID from dify_agent.protocol.schemas import ( @@ -357,6 +359,7 @@ def test_runner_emits_terminal_success_and_snapshot(monkeypatch: pytest.MonkeyPa request=request, run_id="run-1", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() assert seen_clients == [client] assert client.is_closed is False @@ -406,6 +409,7 @@ def test_runner_preserves_explicit_json_null_output(monkeypatch: pytest.MonkeyPa request=request, run_id="run-null-output", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -462,6 +466,7 @@ def test_runner_emits_deferred_tool_call_and_persists_pending_history(monkeypatc request=request, run_id="run-ask-human", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -558,6 +563,7 @@ def test_runner_resumes_with_deferred_tool_results_and_no_user_prompt(monkeypatc request=request, run_id="run-ask-human-initial", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() initial_terminal = sink.events["run-ask-human-initial"][-1] @@ -582,6 +588,7 @@ def test_runner_resumes_with_deferred_tool_results_and_no_user_prompt(monkeypatc request=resumed_request, run_id="run-ask-human-resume", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -657,6 +664,7 @@ def test_runner_can_emit_second_deferred_tool_call_after_resume(monkeypatch: pyt request=request, run_id="run-ask-human-turn-1", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() first_terminal = sink.events["run-ask-human-turn-1"][-1] @@ -681,6 +689,7 @@ def 
test_runner_can_emit_second_deferred_tool_call_after_resume(monkeypatch: pyt request=resumed_request, run_id="run-ask-human-turn-2", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -736,6 +745,7 @@ def test_runner_rejects_deferred_tool_call_without_history_layer(monkeypatch: py request=request, run_id="run-ask-human-no-history", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -785,6 +795,7 @@ def test_runner_rejects_resume_with_deferred_tool_results_without_history_layer( request=request, run_id="run-ask-human-resume-no-history", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -823,6 +834,7 @@ def test_runner_rejects_multiple_deferred_tool_calls(monkeypatch: pytest.MonkeyP request=request, run_id="run-ask-human-multi", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -861,6 +873,7 @@ def test_runner_rejects_deferred_approval_requests(monkeypatch: pytest.MonkeyPat request=request, run_id="run-ask-human-approval", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -960,6 +973,7 @@ def test_runner_passes_dynamic_dify_plugin_tools_to_agent(monkeypatch: pytest.Mo request=request, run_id="run-tools", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -970,6 +984,105 @@ def test_runner_passes_dynamic_dify_plugin_tools_to_agent(monkeypatch: pytest.Mo assert terminal.data.output == "done" +def test_runner_passes_dynamic_dify_knowledge_tools_to_agent(monkeypatch: pytest.MonkeyPatch) -> None: + seen_tools: list[Tool[object]] = [] + + async def knowledge_tool() -> str: + return "knowledge" + + def fake_get_model(_self: DifyPluginLLMLayer, *, http_client: httpx.AsyncClient): + assert http_client.is_closed is False + return TestModel(custom_output_text="done") # pyright: 
ignore[reportReturnType] + + async def fake_get_tools(self: DifyKnowledgeBaseLayer, *, http_client: httpx.AsyncClient) -> list[Tool[object]]: + assert self.config.dataset_ids == ["dataset-1"] + assert http_client.headers.get("X-Test-Client") == "dify-api" + return [Tool(knowledge_tool, name="knowledge_base_search")] + + class FakeResult: + output: str = "done" + + def new_messages(self) -> list[ModelMessage]: + return [] + + class FakeAgent: + async def run(self, *_args: object, **_kwargs: object) -> FakeResult: + return FakeResult() + + def fake_create_agent(model: object, *, tools: list[Tool[object]], output_type: object) -> FakeAgent: + del model, output_type + seen_tools.extend(tools) + return FakeAgent() + + monkeypatch.setattr(DifyPluginLLMLayer, "get_model", fake_get_model) + monkeypatch.setattr(DifyKnowledgeBaseLayer, "get_tools", fake_get_tools) + monkeypatch.setattr("dify_agent.runtime.runner.create_agent", fake_create_agent) + + request = CreateRunRequest( + composition=RunComposition( + layers=[ + RunLayerSpec( + name="prompt", + type="plain.prompt", + config=PromptLayerConfig(prefix="system", user="hello"), + ), + RunLayerSpec( + name="execution_context", + type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID, + config=DifyExecutionContextLayerConfig( + tenant_id="tenant-1", + user_id="user-1", + user_from="account", + app_id="app-1", + agent_mode="workflow_run", + invoke_from="service-api", + ), + ), + RunLayerSpec( + name=DIFY_AGENT_MODEL_LAYER_ID, + type="dify.plugin.llm", + deps={"execution_context": "execution_context"}, + config=DifyPluginLLMLayerConfig( + plugin_id="langgenius/openai", + model_provider="openai", + model="demo-model", + credentials={"api_key": "secret"}, + ), + ), + RunLayerSpec( + name="knowledge", + type=DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID, + deps={"execution_context": "execution_context"}, + config=DifyKnowledgeBaseLayerConfig.model_validate( + { + "dataset_ids": ["dataset-1"], + "retrieval": {"mode": "multiple", "top_k": 4}, + } + ), + ), 
+ ] + ) + ) + sink = InMemoryRunEventSink() + + async def scenario() -> None: + async with ( + httpx.AsyncClient() as plugin_client, + httpx.AsyncClient(headers={"X-Test-Client": "dify-api"}) as dify_api_client, + ): + await AgentRunRunner( + sink=sink, + request=request, + run_id="run-knowledge-tools", + plugin_daemon_http_client=plugin_client, + dify_api_http_client=dify_api_client, + ).run() + + asyncio.run(scenario()) + + assert [tool.name for tool in seen_tools] == ["knowledge_base_search"] + + def test_runner_rejects_duplicate_tool_names_across_dynamic_tool_layers( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -1075,6 +1188,7 @@ def test_runner_rejects_duplicate_tool_names_across_dynamic_tool_layers( request=request, run_id="run-duplicate-tools", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1182,6 +1296,7 @@ def test_runner_rejects_duplicate_tool_names_between_static_and_dynamic_tools( request=request, run_id="run-static-dynamic-duplicate-tools", plugin_daemon_http_client=client, + dify_api_http_client=client, layer_providers=layer_providers, ).run() @@ -1297,6 +1412,7 @@ def test_runner_rejects_duplicate_tool_names_between_shell_and_other_layers( request=request, run_id="run-shell-duplicate-tools", plugin_daemon_http_client=client, + dify_api_http_client=client, layer_providers=layer_providers, ).run() @@ -1325,6 +1441,7 @@ def test_runner_passes_temporary_system_prompt_prefix_without_history_layer(monk request=_request("current user"), run_id="run-no-history", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1368,6 +1485,7 @@ def test_runner_prepends_current_system_prompt_to_stored_history_and_appends_onl request=request, run_id="run-history", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1418,6 +1536,7 @@ def test_runner_with_empty_history_layer_still_sends_system_prompt_and_saves_onl 
request=request, run_id="run-empty-history", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1468,6 +1587,7 @@ def test_runner_failure_with_history_layer_emits_failed_terminal_event_without_s request=request, run_id="run-history-failure", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1499,6 +1619,7 @@ def test_runner_applies_on_exit_overrides_to_success_snapshot(monkeypatch: pytes request=request, run_id="run-exit", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1559,6 +1680,7 @@ def test_runner_passes_output_layer_spec_to_agent_and_serializes_structured_resu request=request, run_id="run-structured-output", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() first_terminal = sink.events["run-structured-output"][-1] @@ -1572,6 +1694,7 @@ def test_runner_passes_output_layer_spec_to_agent_and_serializes_structured_resu request=resumed_request, run_id="run-structured-output-resume", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1645,6 +1768,7 @@ def test_runner_retries_invalid_structured_output_and_eventually_succeeds(monkey request=request, run_id="run-output-retry-success", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1698,6 +1822,7 @@ def test_runner_fails_when_invalid_structured_output_exhausts_retries(monkeypatc request=request, run_id="run-output-retry-failed", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1734,6 +1859,7 @@ def test_runner_rejects_invalid_output_layer_before_model_resolution(monkeypatch request=request, run_id="run-invalid-output", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1808,6 +1934,7 @@ def 
test_runner_rejects_misnamed_output_layer_before_model_resolution(monkeypatc request=request, run_id="run-misnamed-output", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1894,6 +2021,7 @@ def test_runner_rejects_multiple_output_layers_before_model_resolution(monkeypat request=request, run_id="run-duplicate-output", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -1965,6 +2093,7 @@ def test_runner_rejects_reserved_output_name_with_wrong_layer_type_before_model_ request=request, run_id="run-wrong-output-type", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -2009,6 +2138,7 @@ def test_runner_rejects_misnamed_output_layer_before_provider_checks() -> None: request=request, run_id="run-misnamed-output-before-providers", plugin_daemon_http_client=client, + dify_api_http_client=client, layer_providers=(), ).run() @@ -2033,6 +2163,7 @@ def test_runner_rejects_unknown_on_exit_layer_id() -> None: request=request, run_id="run-unknown-signal", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -2053,6 +2184,7 @@ def test_runner_honors_explicit_empty_layer_providers() -> None: request=request, run_id="run-empty-providers", plugin_daemon_http_client=client, + dify_api_http_client=client, layer_providers=(), ).run() @@ -2074,6 +2206,7 @@ def test_runner_fails_empty_user_prompts() -> None: request=request, run_id="run-2", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -2094,6 +2227,7 @@ def test_runner_fails_blank_string_user_prompt_list() -> None: request=request, run_id="run-3", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -2114,6 +2248,7 @@ def test_runner_requires_llm_layer_id() -> None: request=request, run_id="run-4", plugin_daemon_http_client=client, + 
dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -2153,6 +2288,7 @@ def test_runner_rejects_closed_session_snapshot_as_validation_error() -> None: request=request, run_id="run-closed-snapshot", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -2205,6 +2341,7 @@ def test_runner_treats_missing_shell_entrypoint_as_validation_error() -> None: request=request, run_id="run-missing-shell-entrypoint", plugin_daemon_http_client=client, + dify_api_http_client=client, ).run() asyncio.run(scenario()) @@ -2282,6 +2419,7 @@ def test_runner_treats_invalid_shell_snapshot_offsets_as_validation_error() -> N request=request, run_id="run-invalid-shell-offset", plugin_daemon_http_client=client, + dify_api_http_client=client, layer_providers=create_default_layer_providers(shellctl_entrypoint="http://shellctl"), ).run() diff --git a/dify-agent/tests/local/dify_agent/server/test_app.py b/dify-agent/tests/local/dify_agent/server/test_app.py index 534b42e764..b12a636381 100644 --- a/dify-agent/tests/local/dify_agent/server/test_app.py +++ b/dify-agent/tests/local/dify_agent/server/test_app.py @@ -13,10 +13,12 @@ from shell_session_manager.shellctl.client import ShellctlClient import dify_agent.server.app as app_module from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig from dify_agent.layers.execution_context.layer import DifyExecutionContextLayer +from dify_agent.layers.knowledge.configs import DifyKnowledgeBaseLayerConfig +from dify_agent.layers.knowledge.layer import DifyKnowledgeBaseLayer from dify_agent.layers.shell import DifyShellLayerConfig from dify_agent.layers.shell.layer import DifyShellLayer from dify_agent.runtime.compositor_factory import DifyAgentLayerProvider -from dify_agent.server.app import create_app, create_plugin_daemon_http_client +from dify_agent.server.app import create_app, create_dify_api_inner_http_client, create_plugin_daemon_http_client from dify_agent.server.settings 
import ServerSettings from dify_agent.storage.redis_run_store import RedisRunStore @@ -67,6 +69,7 @@ class FakeRunScheduler: shutdown_grace_seconds: float layer_providers: tuple[DifyAgentLayerProvider, ...] plugin_daemon_http_client: FakePluginDaemonHttpClient + dify_api_http_client: FakePluginDaemonHttpClient shutdown_called: bool def __init__( @@ -74,6 +77,7 @@ class FakeRunScheduler: *, store: object, plugin_daemon_http_client: FakePluginDaemonHttpClient, + dify_api_http_client: FakePluginDaemonHttpClient, shutdown_grace_seconds: float, layer_providers: tuple[DifyAgentLayerProvider, ...], ) -> None: @@ -81,6 +85,7 @@ class FakeRunScheduler: self.shutdown_grace_seconds = shutdown_grace_seconds self.layer_providers = layer_providers self.plugin_daemon_http_client = plugin_daemon_http_client + self.dify_api_http_client = dify_api_http_client self.shutdown_called = False self.created.append(self) @@ -160,7 +165,22 @@ class FakeHttpxModule: def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pytest.MonkeyPatch) -> None: - fake_redis, fake_http_client = _patch_app_lifecycle(monkeypatch) + fake_redis = FakeRedis() + fake_http_client = FakePluginDaemonHttpClient() + fake_dify_api_http_client = FakePluginDaemonHttpClient() + FakeRunScheduler.created.clear() + FakeRedisModule.fake_redis = fake_redis + monkeypatch.setattr(app_module, "Redis", FakeRedisModule) + monkeypatch.setattr(app_module, "RunScheduler", FakeRunScheduler) + + def fake_create_plugin_daemon_http_client(_settings: ServerSettings) -> FakePluginDaemonHttpClient: + return fake_http_client + + def fake_create_dify_api_inner_http_client(_settings: ServerSettings) -> FakePluginDaemonHttpClient: + return fake_dify_api_http_client + + monkeypatch.setattr(app_module, "create_plugin_daemon_http_client", fake_create_plugin_daemon_http_client) + monkeypatch.setattr(app_module, "create_dify_api_inner_http_client", fake_create_dify_api_inner_http_client) settings = ServerSettings( 
redis_url="redis://example.invalid/0", @@ -169,19 +189,20 @@ def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pyt run_retention_seconds=7, plugin_daemon_url="http://plugin-daemon", plugin_daemon_api_key="daemon-secret", + dify_api_inner_url="http://dify-api", shellctl_entrypoint="http://shellctl", shellctl_auth_token="shell-secret", agent_stub_url="https://agent.example.com/agent-stub", server_secret_key=_base64url_secret(b"1" * 32), dify_api_base_url="https://api.example.com", dify_api_inner_api_key="inner-secret", - plugin_daemon_connect_timeout=1, - plugin_daemon_read_timeout=2, - plugin_daemon_write_timeout=3, - plugin_daemon_pool_timeout=4, - plugin_daemon_max_connections=5, - plugin_daemon_max_keepalive_connections=3, - plugin_daemon_keepalive_expiry=6, + outbound_http_connect_timeout=1, + outbound_http_read_timeout=2, + outbound_http_write_timeout=3, + outbound_http_pool_timeout=4, + outbound_http_max_connections=5, + outbound_http_max_keepalive_connections=3, + outbound_http_keepalive_expiry=6, ) with TestClient(create_app(settings)): @@ -207,6 +228,18 @@ def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pyt assert isinstance(shell_layer, DifyShellLayer) assert execution_context_layer.daemon_url == "http://plugin-daemon" assert execution_context_layer.daemon_api_key == "daemon-secret" + knowledge_provider = next(provider for provider in layer_providers if provider.type_id == "dify.knowledge_base") + knowledge_layer = knowledge_provider.create_layer( + DifyKnowledgeBaseLayerConfig.model_validate( + { + "dataset_ids": ["dataset-1"], + "retrieval": {"mode": "multiple", "top_k": 2}, + } + ) + ) + assert isinstance(knowledge_layer, DifyKnowledgeBaseLayer) + assert knowledge_layer.dify_api_inner_url == "http://dify-api" + assert knowledge_layer.dify_api_inner_api_key == "inner-secret" assert shell_layer.shellctl_entrypoint == "http://shellctl" assert shell_layer.agent_stub_url == 
"https://agent.example.com/agent-stub" shellctl_client = shell_layer.shellctl_client_factory("http://shellctl") @@ -216,6 +249,8 @@ def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pyt http_client = scheduler.plugin_daemon_http_client assert http_client is fake_http_client assert http_client.is_closed is False + assert scheduler.dify_api_http_client is fake_dify_api_http_client + assert scheduler.dify_api_http_client.is_closed is False store = scheduler.store assert isinstance(store, RedisRunStore) assert store.run_retention_seconds == 7 @@ -229,6 +264,7 @@ def test_create_app_creates_scheduler_and_closes_after_shutdown(monkeypatch: pyt ) assert FakeRunScheduler.created[0].shutdown_called is True + assert FakeRunScheduler.created[0].dify_api_http_client.is_closed is True assert FakeRunScheduler.created[0].plugin_daemon_http_client.is_closed is True assert fake_redis.closed is True @@ -326,21 +362,75 @@ def test_create_app_starts_and_stops_agent_stub_grpc_server_for_grpc_url(monkeyp assert fake_redis.closed is True -def test_create_plugin_daemon_http_client_uses_configured_httpx_construction_args( +def test_create_plugin_daemon_http_client_uses_generic_outbound_httpx_construction_args( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setattr(app_module, "httpx", FakeHttpxModule) - client = create_plugin_daemon_http_client(ServerSettings()) + client = create_plugin_daemon_http_client( + ServerSettings( + outbound_http_connect_timeout=1, + outbound_http_read_timeout=2, + outbound_http_write_timeout=3, + outbound_http_pool_timeout=4, + outbound_http_max_connections=5, + outbound_http_max_keepalive_connections=3, + outbound_http_keepalive_expiry=6, + ) + ) assert isinstance(client, FakePluginDaemonHttpClient) assert isinstance(client.timeout, FakeTimeout) - assert client.timeout.connect == 10 - assert client.timeout.read == 600 - assert client.timeout.write == 30 - assert client.timeout.pool == 10 + assert client.timeout.connect == 1 + 
assert client.timeout.read == 2 + assert client.timeout.write == 3 + assert client.timeout.pool == 4 assert isinstance(client.limits, FakeLimits) - assert client.limits.max_connections == 100 - assert client.limits.max_keepalive_connections == 20 - assert client.limits.keepalive_expiry == 30 + assert client.limits.max_connections == 5 + assert client.limits.max_keepalive_connections == 3 + assert client.limits.keepalive_expiry == 6 assert client.trust_env is False + + +def test_create_dify_api_inner_http_client_uses_generic_outbound_httpx_construction_args( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(app_module, "httpx", FakeHttpxModule) + + client = create_dify_api_inner_http_client( + ServerSettings( + outbound_http_connect_timeout=1, + outbound_http_read_timeout=2, + outbound_http_write_timeout=3, + outbound_http_pool_timeout=4, + outbound_http_max_connections=5, + outbound_http_max_keepalive_connections=3, + outbound_http_keepalive_expiry=6, + ) + ) + + assert isinstance(client, FakePluginDaemonHttpClient) + assert isinstance(client.timeout, FakeTimeout) + assert client.timeout.connect == 1 + assert client.timeout.read == 2 + assert client.timeout.write == 3 + assert client.timeout.pool == 4 + assert isinstance(client.limits, FakeLimits) + assert client.limits.max_connections == 5 + assert client.limits.max_keepalive_connections == 3 + assert client.limits.keepalive_expiry == 6 + assert client.trust_env is False + + +def test_server_settings_use_generic_outbound_http_args_for_shared_clients() -> None: + model_fields = ServerSettings.model_fields + + assert "dify_api_inner_url" in model_fields + assert "dify_api_inner_api_key" in model_fields + assert "outbound_http_connect_timeout" in model_fields + assert "outbound_http_read_timeout" in model_fields + assert "outbound_http_write_timeout" in model_fields + assert "outbound_http_pool_timeout" in model_fields + assert "outbound_http_max_connections" in model_fields + assert 
"outbound_http_max_keepalive_connections" in model_fields + assert "outbound_http_keepalive_expiry" in model_fields diff --git a/dify-agent/tests/local/dify_agent/server/test_settings.py b/dify-agent/tests/local/dify_agent/server/test_settings.py index 07b8e09f53..fb444f840c 100644 --- a/dify-agent/tests/local/dify_agent/server/test_settings.py +++ b/dify-agent/tests/local/dify_agent/server/test_settings.py @@ -129,12 +129,13 @@ def test_server_settings_normalizes_dify_api_base_url_from_env(monkeypatch: pyte assert settings.dify_api_inner_api_key == "inner-secret" -def test_server_settings_requires_dify_api_base_url_and_key_together() -> None: - with pytest.raises(ValidationError, match="DIFY_AGENT_DIFY_API_BASE_URL"): +def test_server_settings_requires_inner_api_key_when_dify_api_base_url_is_set() -> None: + with pytest.raises(ValidationError, match="DIFY_AGENT_DIFY_API_INNER_API_KEY"): _ = ServerSettings(dify_api_base_url="https://api.example.com") - with pytest.raises(ValidationError, match="DIFY_AGENT_DIFY_API_BASE_URL"): - _ = ServerSettings(dify_api_inner_api_key="inner-secret") + settings = ServerSettings(dify_api_inner_api_key="inner-secret") + assert settings.dify_api_inner_api_key == "inner-secret" + assert settings.dify_api_base_url is None def test_server_settings_rejects_dify_api_base_url_with_query_or_fragment() -> None: diff --git a/dify-agent/tests/local/dify_agent/test_import_boundaries.py b/dify-agent/tests/local/dify_agent/test_import_boundaries.py index 0ac1d77615..ecc2d54857 100644 --- a/dify-agent/tests/local/dify_agent/test_import_boundaries.py +++ b/dify-agent/tests/local/dify_agent/test_import_boundaries.py @@ -84,6 +84,8 @@ def test_protocol_and_dify_plugin_exports_do_not_import_server_only_modules() -> "dify_agent.layers.ask_human.layer", "dify_agent.layers.dify_plugin.llm_layer", "dify_agent.layers.dify_plugin.tools_layer", + "dify_agent.layers.knowledge.client", + "dify_agent.layers.knowledge.layer", 
"dify_agent.layers.output.output_layer", "dify_agent.layers.shell.layer", "dify_agent.runtime", @@ -103,6 +105,7 @@ def test_protocol_and_dify_plugin_exports_do_not_import_server_only_modules() -> "dify_agent.layers.execution_context", "dify_agent.layers.ask_human", "dify_agent.layers.dify_plugin", + "dify_agent.layers.knowledge", "dify_agent.layers.output", "dify_agent.layers.shell", ], @@ -112,6 +115,7 @@ def test_protocol_and_dify_plugin_exports_do_not_import_server_only_modules() -> "assert dify_agent_layers_execution_context.__all__ == ['DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID', 'DifyExecutionContextAgentMode', 'DifyExecutionContextInvokeFrom', 'DifyExecutionContextLayerConfig', 'DifyExecutionContextUserFrom']", "assert dify_agent_layers_ask_human.__all__ == ['AskHumanAction', 'AskHumanActionStyle', 'AskHumanField', 'AskHumanFieldType', 'AskHumanFileField', 'AskHumanFileListField', 'AskHumanParagraphField', 'AskHumanResultStatus', 'AskHumanSelectField', 'AskHumanSelectOption', 'AskHumanSelectedAction', 'AskHumanToolArgs', 'AskHumanToolResult', 'AskHumanUrgency', 'DEFAULT_ASK_HUMAN_TOOL_DESCRIPTION', 'DIFY_ASK_HUMAN_LAYER_TYPE_ID', 'DifyAskHumanLayerConfig']", "assert dify_agent_layers_dify_plugin.__all__ == ['DIFY_PLUGIN_LLM_LAYER_TYPE_ID', 'DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID', 'DifyPluginCredentialValue', 'DifyPluginLLMLayerConfig', 'DifyPluginToolCredentialType', 'DifyPluginToolConfig', 'DifyPluginToolOption', 'DifyPluginToolParameter', 'DifyPluginToolParameterForm', 'DifyPluginToolParameterType', 'DifyPluginToolsLayerConfig', 'DifyPluginToolValue']", + "assert dify_agent_layers_knowledge.__all__ == ['DIFY_KNOWLEDGE_BASE_LAYER_TYPE_ID', 'DifyKnowledgeBaseLayerConfig', 'DifyKnowledgeMetadataCondition', 'DifyKnowledgeMetadataConditions', 'DifyKnowledgeMetadataFilteringConfig', 'DifyKnowledgeModelConfig', 'DifyKnowledgeRerankingModelConfig', 'DifyKnowledgeRetrievalConfig']", "assert dify_agent_layers_output.__all__ == ['DIFY_OUTPUT_LAYER_TYPE_ID', 
'DifyOutputLayerConfig']", "assert dify_agent_layers_shell.__all__ == ['DIFY_SHELL_LAYER_TYPE_ID', 'DifyShellCliToolConfig', 'DifyShellEnvVarConfig', 'DifyShellLayerConfig', 'DifyShellSandboxConfig', 'DifyShellSecretRefConfig']", ],